In [26]:
from torchvision import datasets, transforms
import torch
import numpy as np

np.random.seed(0)
torch.manual_seed(0)

data_path = 'images5em'
dataset = datasets.ImageFolder(
    root=data_path,
    transform= transforms.ToTensor()
)
dataset.classes

['anger', 'happiness', 'neutral', 'sadness', 'surprise']

In [27]:
M=100
dataloader = torch.utils.data.DataLoader(dataset, batch_size=M, shuffle=True)
images, labels = next(iter(dataloader))
images.shape # 100 images, 3 components (r,g,b) and 350 x 350 pixels

torch.Size([100, 3, 350, 350])

In [28]:
images = images.mean(axis=1)
print(images.shape)


torch.Size([100, 350, 350])


In [29]:
# note that r=g=b for grey images (all of these) so mean across (r,g,b) 
# then squash the 350x350 matrices into 122500 arrays
#new_col=np.zeros(100)

images=torch.unsqueeze(images, 1) 

#images = images.mean(axis=1)
print(images.shape)


torch.Size([100, 1, 350, 350])


In [132]:
#x,y=dataset[0]




In [34]:
from torch.autograd import Variable
import torch.nn.functional as F

class CNN(torch.nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        #Layers
        #ouput will be 175x175 "images"
        self.layer1 = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels=1, out_channels=32, kernel_size=5, stride=1, padding =2),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size=2, stride=2,padding=0)
        )
        #output will be 88x88 images
        self.layer2 = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5, stride=1, padding =2),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        )
        #output will be 44x44 images
        self.layer3 = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels=64, out_channels=128, kernel_size=5, stride=1, padding =2),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        )
        #output will be 22x22 images
        self.layer4 = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels=128, out_channels=256, kernel_size=5, stride=1, padding =2),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        )
        #output will be 11x11 images
        self.layer5 = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels=256, out_channels=512, kernel_size=5, stride=1, padding =2),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        )
        #avoid overfitting
        self.drop_out = torch.nn.Dropout()
        #fully connected layers
        self.fc1 = torch.nn.Linear(11 * 11 * 512, 1000)
        self.fc2 = torch.nn.Linear(1000, 5)
        
    def forward(self,x):
        xout = self.layer1(x)
        xout = self.layer2(xout)
        xout = self.layer3(xout)
        xout = self.layer4(xout)
        xout = self.layer5(xout)
        xout = xout.reshape(xout.size(0), -1)
        xout = self.drop_out(xout)
        xout = self.fc1(xout)
        xout = self.fc2(xout)
        return xout

In [None]:
model = CNN()
learning_rate=.001
n_epochs=100
running_loss=[]

# Loss function (cross entropy log loss)
criterion = torch.nn.CrossEntropyLoss()
#optimizer which strongly penalizes high confidence in the wrong answer
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

for epoch in range(n_epochs):
    
    #train
    for i in range(M):
        
        outputs = model(images)
        loss = criterion(outputs, labels)
        running_loss.append(loss.item())
        
        #backprop and optimization step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        total = M
        _, predicted = torch.max(outputs.data, 1)
        correct = (predicted == labels).sum().item()
        
        if(( i+1 ) %10 == 0):
            print("Epoch ", epoch + 1)
            print("Loss ", loss.item() )
            print("Percent correct ", correct/total * 100)

Epoch  0
Loss  1.5976682901382446
Percent correct  49.0
Epoch  0
Loss  6.909214019775391
Percent correct  49.0
Epoch  0
Loss  3.1908750534057617
Percent correct  47.0
Epoch  0
Loss  1.1456141471862793
Percent correct  47.0
