# Load datasets

Download datasets from bucket

In [1]:
!rm data/casia-100/.DS_Store
!rm data/digiface_subjects_0-1999_72_imgs/.DS_Store

rm: data/casia-100/.DS_Store: No such file or directory
rm: data/digiface_subjects_0-1999_72_imgs/.DS_Store: No such file or directory


Load test and train sets

In [2]:
import torch
from torchvision import datasets, models, transforms
from torch.utils.data import DataLoader

train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(), 
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),  
])

train_data_path = 'data/train/casia-100'
test_data_path = 'data/test/casia-100'

batch_size = 256

train_data = datasets.ImageFolder(train_data_path, transform=train_transform)
train_loader = torch.utils.data.DataLoader(
    train_data, shuffle=True, batch_size=batch_size)

test_data = datasets.ImageFolder(test_data_path, transform=train_transform)
test_loader = torch.utils.data.DataLoader(
    test_data, shuffle=True, batch_size=batch_size)

In [3]:
iterator = iter(train_loader)
image, label = next(iterator)
print(len(train_data))
print(len(test_data))

print("image", image[0].shape)
print("label", label[0])

11790
2896
image torch.Size([3, 112, 112])
label tensor(9)


Load digiface pretrain data

In [4]:
pretrain_data_path = 'data/digiface_subjects_0-1999_72_imgs'

pretrain_data = datasets.ImageFolder(pretrain_data_path, transform=train_transform)
pretrain_loader = torch.utils.data.DataLoader(
    pretrain_data, shuffle=True, batch_size=batch_size)

Load casia pretrain data

In [5]:
pretrain_data_path_casia = 'data/casia-144000'

pretrain_data_casia = datasets.ImageFolder(pretrain_data_path_casia, transform=train_transform)
pretrain_loader_casia = torch.utils.data.DataLoader(
    pretrain_data_casia, shuffle=True, batch_size=batch_size)

# Init the model

In [6]:
from torchvision.models import resnet18, ResNet18_Weights
from torch import nn, optim
import os
model = resnet18()

device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"

# Train and validate the model

functions for training and validating

In [7]:
import time
import datetime
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt

def get_accuracy(model: nn.Module, train=False):
    correct = 0
    total = 0
    n = 0
    loader = train_loader if train else test_loader 
    predictions, labels = [], []
    with torch.no_grad():
        for imgs, labels in iter(loader):
            imgs, labels = imgs.to(device), labels.to(device)
            model.eval()
            output = model(imgs).data
            pred = output.max(1)[1]
            correct += pred.eq(labels.data).sum().item()
            total += imgs.shape[0]
            predictions.append(pred)
            labels.append(labels)
            n += 1
    return correct / total, predictions, labels

def plot_training_curve(iters, losses, batches, train_acc, val_acc):
    plt.title("Learning Curve")
    plt.plot(iters, losses, label="Train")
    plt.xlabel("Iterations")
    plt.ylabel("Loss")
    #save plot to file
    time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    plt.savefig("figures/loss_curve_{}.png".format(time))

    plt.title("Learning Curve")
    plt.plot(batches, train_acc, label="Train")
    plt.plot(batches, val_acc, label="Validation")
    plt.xlabel("Iterations")
    plt.ylabel("Training Accuracy")
    plt.legend(loc='best')
    time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    plt.savefig("figures/learning_curve_{}.png".format(time))

def train(
    learning_rate = 0.1,
    num_epochs = 22,
    weight_decay = 0.0,
    momentum = 0.9,
    output_to_file = True,
    always_output = True, 
    scheduling = False,
    lr_milestones = [8, 16, 20],
    lr_gamma = 0.1,
    eval = True,
    data_loader = train_loader,
    loss_output_mod = 10,
    filename_prefix = 'model',
    savestate = True,
    ):

    loss_fn = nn.CrossEntropyLoss()
    optimizer = optim.SGD(
        model.parameters(),
        lr=learning_rate,
        momentum=momentum,
        weight_decay=weight_decay
        )

    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=lr_milestones, gamma=lr_gamma)

    if output_to_file:
        outputfile = open('output_'+datetime.datetime.now().strftime("%Y%m%d-%H%M%S")+'.txt', 'w')

    def output(text):
        if output_to_file:
            outputfile.write(text + '\n')
        else:
            print(text)

    iters, losses, train_acc, val_acc = [], [], [], []

    n = 0
    for epoch in range(num_epochs):
        epoch_tic = time.perf_counter()
        for imgs, labels in iter(data_loader):
            tic = time.perf_counter()
            imgs = imgs.to(device)
            labels = labels.to(device)


            model.train()

            out = model(imgs)
            print(out.shape)
            print(labels.shape)
            
            loss = loss_fn(out, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            toc = time.perf_counter()

            if (n % loss_output_mod == 0) or always_output:
                output('epoch: {}, iter: {}, loss: {}, time: {}'.format(epoch, n, loss, toc - tic))
            if eval:
                iters.append(n)
                losses.append(float(loss)/batch_size) # compute *average* loss
            n += 1

        if eval: 
            curr_train_acc = get_accuracy(model, train=True)[0]
            curr_val_acc = get_accuracy(model, train=False)[0]
            train_acc.append(curr_train_acc) # compute training accuracy 
            val_acc.append(curr_val_acc)  # compute validation accuracy
        
        if savestate:
            torch.save(model.state_dict(), 'model-states/' + filename_prefix + '_epoch_' + str(epoch) + '.pt')
        
        if scheduling:
            lr_scheduler.step() 

        epoch_toc = time.perf_counter()
        output('epoch: {}, time: {}'.format(epoch, epoch_toc - epoch_tic))
        if output_to_file:
            outputfile.flush()

    if eval: 
        output("Final Training Accuracy: {}".format(train_acc[-1]))
        output("Final Validation Accuracy: {}".format(val_acc[-1]))
        plot_training_curve(iters, losses, train_acc, val_acc)
    else: 
        torch.save(model.state_dict(), 'model-states/' + filename_prefix + '_final.pt')
        output("Final Training Accuracy: {}".format(get_accuracy(model, train=True)))
        # output("Final Validation Accuracy: {}".format(get_accuracy(model, train=False)))
    
    if output_to_file:
        outputfile.close()
    
    if eval:
        return get_accuracy(model, train=True)


Perform experiment

Pretrain on digiface

In [8]:
import copy

num_classes_pretrain = 2000
model.fc = nn.Linear(512, num_classes_pretrain)
model.to(device)

train(
    eval=False,
    data_loader=pretrain_loader,
    learning_rate=1e-3,
    num_epochs=20,
    output_to_file=False,
    always_output=False,
    scheduling=False,
    loss_output_mod=100,
    savestate=False,
    filename_prefix='digiface_pretrained',
      )

pretrained_model_digiface = copy.deepcopy(model)
#load state dict from model-states/digiface_pretrained_final.pt
pretrained_model_digiface.load_state_dict(torch.load('model-states/digiface_pretrained_final.pt'))

torch.Size([256, 2000])
torch.Size([256])
epoch: 0, iter: 0, loss: 7.707782745361328, time: 0.779980915998749
torch.Size([256, 2000])
torch.Size([256])
torch.Size([256, 2000])
torch.Size([256])
torch.Size([256, 2000])
torch.Size([256])
torch.Size([256, 2000])
torch.Size([256])
torch.Size([256, 2000])
torch.Size([256])
torch.Size([256, 2000])
torch.Size([256])
torch.Size([256, 2000])
torch.Size([256])
torch.Size([256, 2000])
torch.Size([256])
torch.Size([256, 2000])
torch.Size([256])
torch.Size([256, 2000])
torch.Size([256])
torch.Size([256, 2000])
torch.Size([256])
torch.Size([256, 2000])
torch.Size([256])
torch.Size([256, 2000])
torch.Size([256])
torch.Size([256, 2000])
torch.Size([256])
torch.Size([256, 2000])
torch.Size([256])
torch.Size([256, 2000])
torch.Size([256])
torch.Size([256, 2000])
torch.Size([256])
torch.Size([256, 2000])
torch.Size([256])
torch.Size([256, 2000])
torch.Size([256])
torch.Size([256, 2000])
torch.Size([256])
torch.Size([256, 2000])
torch.Size([256])
torch.Si

KeyboardInterrupt: 

Train and eval on first 100 casia

In [None]:
model = pretrained_model_digiface
model.fc = nn.Linear(512, 100)
model.to(device)

acc, predictions, labels = train(
    num_epochs=20,
    learning_rate=0.1,
    eval=True,
    data_loader=train_loader,
    output_to_file=False,
    always_output=True,
    scheduling=True,
    filename_prefix='digiface_casia',
    savestate=True,
    )

Pretrain on casia

In [None]:
num_classes_pretrain = len(os.listdir('data/casia-144000/'))
print(num_classes_pretrain)

1994


In [None]:
import copy

model = resnet18()
model.fc = nn.Linear(512, num_classes_pretrain)
model.to(device)

train(
    eval=False,
    data_loader=pretrain_loader_casia,
    learning_rate=1e-3,
    num_epochs=20,
    output_to_file=False,
    always_output=False,
    scheduling=False,
    loss_output_mod=100,
    savestate=False,
    filename_prefix='casia_pretrained',
      )

pretrained_model_casia = copy.deepcopy(model)
pretrained_model_casia.load_state_dict(torch.load('model-states/casia_pretrained_final.pt'))

In [None]:
model = pretrained_model_casia
model.fc = nn.Linear(512, 100)
model.to(device)

acc, predictions, labels = train(
    num_epochs=20,
    learning_rate=0.1,
    eval=True,
    data_loader=train_loader,
    output_to_file=False,
    always_output=True,
    scheduling=True,
    filename_prefix='casia_casia',
    savestate=True,
    )

Load a resnet model with imagenet weights and train/eval again

In [None]:
model = resnet18(weights = ResNet18_Weights.IMAGENET1K_V1)
model.fc = nn.Linear(512, 100)
model.to(device)

acc, predictions, labels = train(
    num_epochs=20,
    learning_rate=0.1,
    eval=True,
    data_loader=train_loader,
    output_to_file=False,
    always_output=True,
    scheduling=True,
    filename_prefix='imagenet_casia',
    savestate=True,
    )

Load a clean resnet18 without any weights and train/eval again

In [None]:
model = resnet18()
model.fc = nn.Linear(512, 100)
model.to(device)
acc, predictions, labels = train(
    num_epochs=20,
    learning_rate=0.1,
    eval=True,
    data_loader=train_loader,
    output_to_file=False,
    always_output=True,
    scheduling=True,
    filename_prefix='imagenet_casia',
    savestate=True,
    )

Reset experminent