In [None]:
import torchvision.models as models
import torch.nn as nn
import torch
import loaders
import hyperparameters
import training

In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"
## setup current model and load weights
# ResNet-50 with its final fully-connected layer replaced by a num_classes-way
# linear head, so the architecture matches the saved checkpoint before loading.
resnet50 = models.resnet50()
num_classes = 7
resnet50.fc = nn.Linear(resnet50.fc.in_features, num_classes)
# map_location remaps stored tensors onto the current device. Without it,
# torch.load restores each tensor to the device it was saved from, which fails
# on a CPU-only machine if the checkpoint was written during a GPU run.
pretrained = torch.load("./saved_model/resnet50_on_FER.pth", map_location=device)
# The checkpoint is a mapping; the model weights are stored under "state_dict".
resnet50.load_state_dict(pretrained["state_dict"])
resnet50.to(device)
# Switch to evaluation mode for the validation cells that follow; the trailing
# ';' suppresses the very long model repr the notebook would otherwise display.
resnet50.eval();

In [None]:
# Multi-class cross-entropy on the model's logits (default reduction: batch mean),
# placed on the same device as the model.
loss_func = nn.CrossEntropyLoss()
loss_func = loss_func.to(device)

In [None]:
## evaluate on full face
# Loss/accuracy of the loaded model on the mask="full" evaluation split.
full_val_loader = loaders.get_loader(
    mask="full", train=False, shuffle=False
)
# Re-assert eval mode so this cell gives the same result when re-run after the
# training cell, which may leave the model in train mode. In train mode,
# BatchNorm uses batch statistics and updates its running stats even under no_grad.
resnet50.eval()
with torch.no_grad():
    val_loss = 0.0
    total_examples = 0
    correct_examples = 0
    for inputs, targets in full_val_loader:
        # copy inputs to device
        inputs = inputs.to(device)
        targets = targets.to(device)
        out = resnet50(inputs)
        batch_size = targets.shape[0]
        # loss_func returns the batch *mean*; weight it by batch size so the
        # final value is a true per-sample mean even if the last batch is smaller.
        val_loss += loss_func(out, targets).item() * batch_size
        # count the number of correctly predicted samples in the current batch
        predicted = out.argmax(dim=1)
        correct_examples += predicted.eq(targets).sum().item()
        total_examples += batch_size
if total_examples == 0:
    raise ValueError("full-face validation loader yielded no examples")
avg_loss = val_loss / total_examples
avg_acc = correct_examples / total_examples
print(
    "Validation loss: %.4f, Validation accuracy: %.4f" % (avg_loss, avg_acc)
)

In [None]:
## evaluate on lower face
# Loss/accuracy of the current model on the mask="lower" evaluation split.
lower_mask_val_loader = loaders.get_loader(
    mask="lower", train=False, shuffle=False
)
# Re-assert eval mode so this cell gives the same result when re-run after the
# training cell, which may leave the model in train mode. In train mode,
# BatchNorm uses batch statistics and updates its running stats even under no_grad.
resnet50.eval()
with torch.no_grad():
    val_loss = 0.0
    total_examples = 0
    correct_examples = 0
    for inputs, targets in lower_mask_val_loader:
        # copy inputs to device
        inputs = inputs.to(device)
        targets = targets.to(device)
        out = resnet50(inputs)
        batch_size = targets.shape[0]
        # loss_func returns the batch *mean*; weight it by batch size so the
        # final value is a true per-sample mean even if the last batch is smaller.
        val_loss += loss_func(out, targets).item() * batch_size
        # count the number of correctly predicted samples in the current batch
        predicted = out.argmax(dim=1)
        correct_examples += predicted.eq(targets).sum().item()
        total_examples += batch_size
if total_examples == 0:
    raise ValueError("lower-face validation loader yielded no examples")
avg_loss = val_loss / total_examples
avg_acc = correct_examples / total_examples
print(
    "Validation loss: %.4f, Validation accuracy: %.4f" % (avg_loss, avg_acc)
)

In [None]:
# Training split for the mask="lower" variant of the data, with shuffling
# requested from get_loader; consumed by the training cell below.
lower_mask_train_loader = loaders.get_loader(train=True, shuffle=True, mask="lower")

In [None]:
# Continue training the loaded model on the lower-mask training split, passing
# both the lower-mask and full-face validation loaders to train_model.
# All arguments are positional; their meaning is defined by training.train_model
# (not visible here). The checkpoint is passed as "lower.pth" alongside
# hyperparameters.CHECKPOINT_FOLDER.
# NOTE(review): 15, 0.1 and 5 are unexplained magic numbers — confirm what they
# are against train_model's signature and promote them to named constants.
training.train_model(
    resnet50, device,
    hyperparameters.CHECKPOINT_FOLDER, "lower.pth",
    hyperparameters.LR, hyperparameters.MOMENTUM, hyperparameters.REG,
    15,
    lower_mask_train_loader, lower_mask_val_loader,
    loss_func,
    full_val_loader,
    0.1, 5
)