In [None]:
import torch
from torchvision import transforms, datasets, models
from torch import nn, optim
import numpy as np
import matplotlib.pyplot as plt
import os
import torch.nn.functional as F
from IPython.display import clear_output
from time import sleep

# Select the compute device once for the whole notebook: the first CUDA GPU
# when PyTorch can see one, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

## Create Model

In [None]:
# Backbone: a ResNet-50 whose final fully-connected layer is replaced by a
# 2048 -> 100 linear layer so that its parameter shapes match the checkpoint
# stored in `model_file`.
model = models.resnet50()
model.fc = nn.Linear(2048, 100)

model_file = 'model.pth'
# map_location='cpu' lets a checkpoint that was saved from a GPU load on a
# CPU-only machine as well; the assembled model is moved to `device` below.
# torch.load unpickles the file, so only load checkpoints you trust.
state_dict = torch.load(model_file, map_location='cpu')
model.load_state_dict(state_dict)

# Only `linear_reg` is handed to the optimizer, so the backbone weights are
# never updated. Turning off requires_grad for them stops autograd from
# computing (and accumulating) backbone gradients that nothing consumes.
for backbone_param in model.parameters():
    backbone_param.requires_grad_(False)

# Trainable head: maps the backbone's 100 outputs to 10 classes.
linear_reg = nn.Linear(100, 10)
model_ds = nn.Sequential(model, linear_reg)
#print(model)

model_ds.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(linear_reg.parameters(), lr=0.001)
# NOTE(review): step_size=70 means the learning rate is only decayed after 70
# calls to scheduler.step(); confirm that many calls actually happen in a run.
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=70, gamma=0.1)

## Load Synthetic MNIST

In [None]:
# Root folder holding the `imgs_train` / `imgs_valid` image folders.
# NOTE(review): absolute, machine-specific path -- adjust when running elsewhere.
DATA_DIR = '/scratch/vvg239/headcam/synthetic_digits'

# Preprocessing shared by both splits: resize, centre-crop, convert to a
# tensor, then normalise each channel with the mean/std values that
# torchvision documents for its ImageNet-trained models.
# NOTE(review): 245/244 are unusual sizes (256/224 is the common ResNet
# convention) -- confirm they are intentional.
transform = transforms.Compose([
    transforms.Resize(245),
    transforms.CenterCrop(244),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# ImageFolder takes the class label of each image from its sub-directory.
train_set = datasets.ImageFolder(os.path.join(DATA_DIR, 'imgs_train'), transform=transform)
val_set = datasets.ImageFolder(os.path.join(DATA_DIR, 'imgs_valid'), transform=transform)
print(len(train_set))

train_loader = torch.utils.data.DataLoader(train_set, batch_size=64, shuffle=True)

# The validation pass only counts correct predictions, which does not depend
# on sample order, so shuffling is switched off to keep evaluation deterministic.
val_loader = torch.utils.data.DataLoader(val_set, batch_size=64, shuffle=False)

## Train the Network

In [None]:
def train(epoch):
    """Run one optimisation pass over every batch in ``train_loader``.

    Puts ``model_ds`` into training mode, performs a forward/backward pass and
    an optimizer step per batch, and prints a one-line summary at the end.

    Args:
        epoch: Epoch number; used only in the printed summary line.

    Returns:
        The sum of the per-batch loss values (not an average), so its scale
        grows with the number of batches in ``train_loader``.
    """
    model_ds.train()

    # Read the learning rate up front so the logged value is the one that was
    # in effect when this epoch started.
    current_lr = optimizer.param_groups[0]['lr']
    epoch_loss = 0

    for batch_images, batch_labels in train_loader:
        batch_images = batch_images.to(device)
        batch_labels = batch_labels.to(device)

        # Gradients accumulate by default, so clear them before each step.
        optimizer.zero_grad()
        outputs = model_ds(batch_images)
        batch_loss = criterion(outputs, batch_labels)
        batch_loss.backward()
        optimizer.step()

        epoch_loss += batch_loss.item()

    summary = "Epoch: %3d | Running Loss: %f | LR : %.9f" % (epoch, epoch_loss, current_lr)
    print(summary)
    return epoch_loss


def validation():
    """Count how many samples in ``val_loader`` the model classifies correctly.

    Switches ``model_ds`` to evaluation mode and runs it over every batch
    with gradient tracking disabled.

    Side effect: advances ``scheduler`` by one step on every call.
    NOTE(review): this ties the learning-rate schedule to how often
    validation runs rather than to the number of training epochs -- confirm
    that this is intended.

    Returns:
        int: Number of correctly classified samples; 0 when ``val_loader``
        yields no batches.
    """
    model_ds.eval()
    scheduler.step()

    # Kept as a plain Python int so that an empty loader returns 0 instead of
    # failing on ``int.item()``.
    correct = 0
    with torch.no_grad():
        for data, target in val_loader:
            data, target = data.to(device), target.to(device)
            output = model_ds(data)
            # Predicted class = index of the largest output score per sample.
            pred = output.argmax(dim=1)
            correct += (pred == target).sum().item()
    return correct

In [None]:
epochs = 50
VAL_EVERY = 5  # run validation on epochs 0, 5, 10, ...

# One entry per epoch: the value returned by train(), i.e. the loss summed
# over that epoch's batches.
train_losses = []
for epoch in range(epochs):
    train_losses.append(train(epoch))
    if epoch % VAL_EVERY == 0:
        # validation() returns the number of correct samples (despite the
        # variable name) and also advances the LR scheduler by one step.
        accuracy = validation()
        print("Correct Samples = %d || Accuracy = %.2f" %(accuracy, accuracy/len(val_set)))

# Explicit figure/axes with labels so the plot is readable on its own.
fig, ax = plt.subplots()
ax.plot(train_losses)
ax.set(title="Training loss per epoch", xlabel="Epoch", ylabel="Summed batch loss")
plt.show()