In [1]:
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from model import ExpressionClassifier
import torch
from torchvision import transforms

In [2]:
train_transforms = transforms.Compose([
    transforms.RandomRotation(30),
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                         std=[0.229, 0.224, 0.225])
])

valid_transforms = transforms.Compose([
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                         std=[0.229, 0.224, 0.225])
])

In [3]:
train_set = ImageFolder(root='/home/khairulimam/datasets/expressions/IMFDB/train/', transform=train_transforms)
valid_set = ImageFolder(root='/home/khairulimam/datasets/expressions/IMFDB/valid/', transform=valid_transforms)

train_loader = DataLoader(train_set, batch_size=32, shuffle=True)
valid_loader = DataLoader(valid_set, batch_size=32, shuffle=True)

In [4]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = ExpressionClassifier(num_classes=7)
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
model.to(device)
model = torch.nn.DataParallel(model)
criterion = torch.nn.CrossEntropyLoss()

In [34]:
def train(model, imgs, lbls):
    model.train()
    imgs = imgs.to(device)
    lbls = lbls.to(device)
    
    logits = model(imgs)
    loss = criterion(logits, lbls)
    
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()

def validate(model, imgs, lbls):
    model.eval()
    with torch.no_grad():
        imgs = imgs.to(device)
        lbls = lbls.to(device)
        
        logits = model(imgs)
        
        _, predictions = torch.max(logits, 1)
        
        trues = predictions == lbls
        return trues.sum().item()
        

In [None]:
for epoch in range(10):
    lossses = list()
    accuracies = list()
    for idx, (imgs, lbls) in enumerate(train_loader):
        loss = train(model, imgs, lbls)
        lossses.append(loss)
    print(epoch, 'train loss', sum(lossses)/len(lossses))
    for idx, (imgs, lbls) in enumerate(valid_loader):
        accuracy = validate(model, imgs, lbls)
        accuracies.append(accuracy)
    print(epoch, 'valid accuracies', sum(accuracies)/len(accuracies))

0 train loss 1.639694983236621
0 valid accuracies 10.109289617486338
1 train loss 1.6176352529563955
1 valid accuracies 10.628415300546449
2 train loss 1.595262550862036
2 valid accuracies 10.972677595628415
3 train loss 1.579426805549693
3 valid accuracies 10.830601092896174
