In [18]:
import torchvision.models as models
import torch.nn as nn
import torch
from torchvision import transforms
from dataset import FacialImageData
from torch.utils.data import DataLoader
from custom_transforms import ImgMask
import torch.optim as optim
import os

In [9]:
device = "cuda" if torch.cuda.is_available() else "cpu"

In [5]:
## setup current model and load weights
resnet50 = models.resnet50()
num_classes = 7
resnet50.fc = nn.Linear(resnet50.fc.in_features, num_classes)
pretrained = torch.load("./saved_model/resnet50_on_FER.pth")
resnet50.load_state_dict(pretrained["state_dict"])
resnet50.to(device)

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

In [6]:
resnet50.eval()

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

In [10]:
## make some evaluation on full image (no mask)
VAL_BATCH_SIZE = 100
full_val_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(
            (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)
        )
    ]
)
full_val_set = FacialImageData(
    directory="./data/test",
    transform=full_val_transform
)
full_val_loader = DataLoader(
    full_val_set,
    batch_size=VAL_BATCH_SIZE,
    shuffle=False,
    num_workers=2
)

In [11]:
loss_func = nn.CrossEntropyLoss().to(device)

In [12]:
with torch.no_grad():
    val_loss = 0
    total_examples = 0
    correct_examples = 0
    for batch_idx, (inputs, targets) in enumerate(full_val_loader):
        # copy inputs to device
        inputs = inputs.to(device)
        targets = targets.to(device)
        # compute the output and loss
        out = resnet50(inputs)
        loss = loss_func(out, targets)
        # count the number of correctly predicted samples
        # in the current batch
        _, predicted = torch.max(out, 1)
        correct = predicted.eq(targets).sum()
        val_loss += loss.detach().cpu()
        total_examples += targets.shape[0]
        correct_examples += correct.item()
avg_loss = val_loss / len(full_val_loader)
avg_acc = correct_examples / total_examples
print(
    "Validation loss: %.4f, Validation accuracy: %.4f" % (avg_loss, avg_acc)
)

Validation loss: 0.9731, Validation accuracy: 0.6649


In [14]:
## make some evaluation on upper masked image (upper mask)
upper_mask_val_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        ImgMask([1, 2]),
        transforms.Normalize(
            (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)
        )
    ]
)
upper_mask_val_set = FacialImageData(
    directory="./data/test",
    transform=upper_mask_val_transform
)
upper_mask_val_loader = DataLoader(
    upper_mask_val_set,
    batch_size=VAL_BATCH_SIZE,
    shuffle=False,
    num_workers=2
)

In [15]:
with torch.no_grad():
    val_loss = 0
    total_examples = 0
    correct_examples = 0
    for batch_idx, (inputs, targets) in enumerate(upper_mask_val_loader):
        # copy inputs to device
        inputs = inputs.to(device)
        targets = targets.to(device)
        # compute the output and loss
        out = resnet50(inputs)
        loss = loss_func(out, targets)
        # count the number of correctly predicted samples
        # in the current batch
        _, predicted = torch.max(out, 1)
        correct = predicted.eq(targets).sum()
        val_loss += loss.detach().cpu()
        total_examples += targets.shape[0]
        correct_examples += correct.item()
avg_loss = val_loss / len(upper_mask_val_loader)
avg_acc = correct_examples / total_examples
print(
    "Validation loss: %.4f, Validation accuracy: %.4f" % (avg_loss, avg_acc)
)

Validation loss: 1.9188, Validation accuracy: 0.4489


In [16]:
## further fine tune the model
## define hyperparameters
LR = 0.01
MOMENTUM = 0.9
REG = 1e-4
NUM_EPOCH = 15

In [19]:
## trainning data
TRAIN_BATCH_SIZE = 128
upper_mask_train_transform = transforms.Compose(
    [
        # transforms.RandomCrop(48, padding=4),
        # transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        ImgMask([1, 2]),
        transforms.Normalize(
            (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)
        )
    ]
)
upper_mask_train_set = FacialImageData(
    directory="./data/test",
    transform=upper_mask_train_transform
)
upper_mask_train_loader = DataLoader(
    upper_mask_train_set,
    batch_size=TRAIN_BATCH_SIZE,
    shuffle=True,
    num_workers=2
)

In [20]:
## create optimizer
optimizer = optim.SGD(
    resnet50.parameters(),
    lr=LR,
    momentum=MOMENTUM,
    weight_decay=REG
)

# the folder where the trained model is saved
CHECKPOINT_FOLDER = "./saved_model"
best_val_acc = 0
resnet50 = resnet50.to(device)
current_learning_rate = LR


train_loss_hist = []
train_acc_hist = []
test_loss_hist = []
test_acc_hist = []
# start the training/validation process
print("==> Training starts!")
print("="*50)
for i in range(0, NUM_EPOCH):
    print("Epoch %d:" %i)
    ## Train on the train set
    #####################################################################
    # switch to train mode
    resnet50.train()
    # this help you compute the training accuracy
    total_examples = 0
    correct_examples = 0
    train_loss = 0 # track training loss if you want
    # Train the model for 1 epoch.
    for batch_idx, (inputs, targets) in enumerate(upper_mask_train_loader):
        # copy inputs to device
        inputs = inputs.to(device)
        targets = targets.to(device)
        # compute the output and loss
        out = resnet50(inputs)
        loss = loss_func(out, targets)
        # zero the gradient
        optimizer.zero_grad()
        # backpropagation
        loss.backward()
        # apply gradient and update the weights
        optimizer.step()
        # count the number of correctly predicted samples in the current batch
        _, predicted = torch.max(out, 1)
        correct = predicted.eq(targets).sum()
        train_loss += loss.detach().cpu()
        total_examples += targets.shape[0]
        correct_examples += correct.item()
    avg_loss = train_loss / len(upper_mask_train_loader)
    train_loss_hist.append(avg_loss)
    avg_acc = correct_examples / total_examples
    train_acc_hist.append(avg_acc)
    print("Training loss: %.4f, Training accuracy: %.4f" %(avg_loss, avg_acc))
    ######################################################################

    # Validate on the validation dataset (masked)
    ######################################################################
    # switch to eval mode
    resnet50.eval()
    # this help you compute the validation accuracy
    total_examples = 0
    correct_examples = 0
    val_loss = 0 # track the validation loss
    # disable gradient during validation, which can save GPU memory
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(upper_mask_val_loader):
            # copy inputs to device
            inputs = inputs.to(device)
            targets = targets.to(device)
            # compute the output and loss
            out = resnet50(inputs)
            loss = loss_func(out, targets)
            # count the number of correctly predicted samples
            # in the current batch
            _, predicted = torch.max(out, 1)
            correct = predicted.eq(targets).sum()
            val_loss += loss.detach().cpu()
            total_examples += targets.shape[0]
            correct_examples += correct.item()
    avg_loss = val_loss / len(upper_mask_val_loader)
    test_loss_hist.append(avg_loss)
    avg_acc = correct_examples / total_examples
    test_acc_hist.append(avg_acc)
    print(
        "Validation loss: %.4f, Validation accuracy: %.4f" % (avg_loss, avg_acc)
    )
    ######################################################################

    # save the model checkpoint
    if avg_acc > best_val_acc:
        best_val_acc = avg_acc
        if not os.path.exists(CHECKPOINT_FOLDER):
           os.makedirs(CHECKPOINT_FOLDER)
        print("Saving ...")
        state = {'state_dict': resnet50.state_dict(),
                'epoch': i,
                'lr': current_learning_rate}
        torch.save(state, os.path.join(CHECKPOINT_FOLDER, 'resnet50_on_upper_masked.pth'))
    print('')
    ## decay learning rate
    # if i % DECAY_EPOCH == 0 and i != 0:
    #     current_learning_rate *= LR_DECAY
    #     for param_group in optimizer.param_groups:
    #         param_group["lr"] = current_learning_rate
    #         print(f"Current learning rate has decayed to %f" %current_learning_rate)

print("="*50)
print(
    f"==> Optimization finished! Best validation accuracy: {best_val_acc:.4f}"
)

==> Training starts!
Epoch 0:
Training loss: 1.2988, Training accuracy: 0.5038


KeyboardInterrupt: 