# CNN from a pretrained encoder
## Setup

In [1]:
import torch
import torch.nn as nn
from tensorboardX import SummaryWriter
from pa2_sample_code import get_datasets

train_data, eval_data = get_datasets()

# Carve a holdout split out of the training set for hyper-parameter selection:
# 80% (rounded down) is used for fitting, the remainder for scoring.
n_train_total = len(train_data)
holdout_train_len = int(n_train_total * 0.8)
holdout_eval_len = n_train_total - holdout_train_len
holdout_split_sizes = [holdout_train_len, holdout_eval_len]
holdout_train_data, holdout_eval_data = torch.utils.data.random_split(
    train_data, holdout_split_sizes
)


## Define Model

In [2]:
class CnnFromPretrained(nn.Module):
    """Classifier made of a pretrained encoder followed by a small MLP head.

    The encoder is read from the ``'model'`` entry of a ``torch.load``-ed
    checkpoint. Its output for each sample is flattened into one feature
    vector of ``n_features`` values. That vector goes through
    Linear -> ReLU -> Linear to produce ``n_classes`` logits.

    Args:
        n_hidden: width of the MLP's single hidden layer.
        n_features: number of values the encoder emits per sample
            (default 32, the originally hard-coded size).
        n_classes: number of output classes (default 47).
        encoder_path: checkpoint file holding the encoder under ``'model'``.
    """

    def __init__(self, n_hidden, n_features=32, n_classes=47,
                 encoder_path='pretrained_encoder.pt'):
        super().__init__()

        # NOTE(review): the checkpoint stores a pickled module, so torch.load
        # unpickles arbitrary objects -- only load trusted files. Recent
        # PyTorch releases default to weights_only=True, which cannot restore
        # a whole module; confirm against the installed torch version.
        self.encoder = torch.load(encoder_path)['model']

        self.predictor = nn.Sequential(
            nn.Linear(in_features=n_features, out_features=n_hidden),
            nn.ReLU(),
            nn.Linear(in_features=n_hidden, out_features=n_classes),
        )

        # Summed loss; loss() divides by the batch size to get the mean.
        self.loss_func = nn.CrossEntropyLoss(reduction='sum')

    def forward(self, in_data):
        """Return ``(batch, n_classes)`` logits for a batch of inputs."""
        # Flatten whatever shape the encoder emits into one vector per sample.
        # reshape (unlike view) also works on non-contiguous encoder output,
        # and the first Linear layer still rejects a wrong feature count.
        img_features = self.encoder(in_data).reshape(in_data.size(0), -1)
        return self.predictor(img_features)

    def loss(self, logits, labels):
        """Return the mean cross-entropy over the batch."""
        return self.loss_func(logits, labels) / logits.size(0)

    def top_k_acc(self, logits, labels, k=1):
        """Return the fraction of samples whose label is among the top-k logits.

        Args:
            logits: ``(n, n_classes)`` scores.
            labels: ``(n,)`` integer class indices.
            k: how many of the highest-scoring classes count as a hit.

        Returns:
            A Python float in ``[0, 1]``.
        """
        _, top_k_pred = torch.topk(logits, k=k, dim=1)  # shape (n, k)
        # Broadcast labels (n, 1) against predictions (n, k). Top-k indices
        # are distinct, so each row contributes at most one hit.
        num_correct = top_k_pred.eq(labels.unsqueeze(1)).sum().item()
        return float(num_correct) / labels.size(0)
        

## Refactored runner function

In [3]:
def run(model, loaders, optimizer, writer, num_epoch=10, device='cpu'):
    """Train ``model`` for ``num_epoch`` epochs, evaluating after each one.

    Each epoch makes one optimisation pass over ``loaders['train']``, then one
    gradient-free pass over ``loaders['eval']``. Each pass prints its loss,
    top-1 accuracy and top-3 accuracy. When ``writer`` is given, the same
    values are logged under ``<ModelClassName>_loss`` / ``_top1`` / ``_top3``,
    with one curve per mode.

    Args:
        model: module exposing ``loss(logits, labels)`` and
            ``top_k_acc(logits, labels, k)``.
        loaders: dict with ``'train'`` and ``'eval'`` iterables of
            ``(in_data, labels)`` batches.
        optimizer: optimizer over the model's trainable parameters.
        writer: object with ``add_scalars`` (e.g. a ``SummaryWriter``), or
            ``None`` to disable logging.
        num_epoch: number of epochs to run.
        device: device every batch is moved to.
    """
    model_name = model.__class__.__name__

    def run_epoch(epoch, mode):
        is_train = mode == 'train'
        # Put layers such as dropout / batch-norm (if the encoder has any)
        # into the right mode, and skip autograd bookkeeping during eval.
        model.train(is_train)

        total_loss = total_top1 = total_top3 = 0.0
        n_samples = 0
        with torch.set_grad_enabled(is_train):
            for in_data, labels in loaders[mode]:
                in_data, labels = in_data.to(device), labels.to(device)

                if is_train:
                    optimizer.zero_grad()

                logits = model(in_data)
                batch_loss = model.loss(logits, labels)
                batch_top1 = model.top_k_acc(logits, labels, k=1)
                batch_top3 = model.top_k_acc(logits, labels, k=3)

                # Weight per-batch means by batch size so a smaller final
                # batch does not skew the epoch average.
                batch_size = labels.size(0)
                total_loss += batch_loss.item() * batch_size
                total_top1 += batch_top1 * batch_size
                total_top3 += batch_top3 * batch_size
                n_samples += batch_size

                if is_train:
                    batch_loss.backward()
                    optimizer.step()

        if n_samples == 0:
            print('epoch %d %s: no batches, skipping' % (epoch, mode))
            return

        epoch_loss = total_loss / n_samples
        epoch_top1 = total_top1 / n_samples
        epoch_top3 = total_top3 / n_samples

        print('epoch %d %s loss %.4f top1 %.4f top3 %.4f' % (epoch, mode, epoch_loss, epoch_top1, epoch_top3))
        if writer is not None:
            for metric, value in (('loss', epoch_loss),
                                  ('top1', epoch_top1),
                                  ('top3', epoch_top3)):
                writer.add_scalars('%s_%s' % (model_name, metric),
                                   tag_scalar_dict={mode: value},
                                   global_step=epoch)

    for epoch in range(num_epoch):
        run_epoch(epoch, 'train')
        run_epoch(epoch, 'eval')

## Holdout validation for choosing hyper-parameters

In [4]:
# Grid search over head width and optimizer settings. Each configuration is
# trained on holdout_train_data and scored on holdout_eval_data; its curves are
# written to ./logs/cnn_pretrained/<n_hidden>_<optim>_<lr>.
optimizer_classes = {'adam': torch.optim.Adam, 'sgd': torch.optim.SGD}
holdout_loaders = {
    'train': torch.utils.data.DataLoader(holdout_train_data, batch_size=32, shuffle=True),
    # Evaluation order does not affect the metrics, so no shuffling.
    'eval': torch.utils.data.DataLoader(holdout_eval_data, batch_size=32, shuffle=False),
}
for n_hidden in [32, 64]:
    for optim_conf in [
        {'optim': 'adam', 'lr': 0.001},
        {'optim': 'sgd', 'lr': 0.1},
        {'optim': 'sgd', 'lr': 0.01},
    ]:
        model = CnnFromPretrained(n_hidden=n_hidden)
        optimizer_cls = optimizer_classes[optim_conf['optim']]
        optimizer = optimizer_cls(model.parameters(), lr=optim_conf['lr'])
        conf_str = '%s_%s_%s' % (n_hidden, optim_conf['optim'], optim_conf['lr'])
        print(conf_str)
        writer = SummaryWriter('./logs/cnn_pretrained/%s' % conf_str)
        try:
            run(
                model=model,
                loaders=holdout_loaders,
                optimizer=optimizer,
                writer=writer,
                num_epoch=10,
                device='cpu',
            )
        finally:
            # Flush buffered events and release the event file.
            writer.close()

32_adam_0.001
epoch 0 train loss 3.0870 top1 0.2599 top3 0.4856
epoch 0 eval loss 2.3301 top1 0.4100 top3 0.6902
epoch 1 train loss 1.9963 top1 0.4715 top3 0.7420
epoch 1 eval loss 1.7845 top1 0.5110 top3 0.7787
epoch 2 train loss 1.6394 top1 0.5436 top3 0.7996
epoch 2 eval loss 1.5478 top1 0.5585 top3 0.8188
epoch 3 train loss 1.4738 top1 0.5769 top3 0.8260
epoch 3 eval loss 1.4218 top1 0.5889 top3 0.8335
epoch 4 train loss 1.3811 top1 0.5967 top3 0.8401
epoch 4 eval loss 1.3566 top1 0.6036 top3 0.8475
epoch 5 train loss 1.3201 top1 0.6115 top3 0.8478
epoch 5 eval loss 1.3087 top1 0.6130 top3 0.8548
epoch 6 train loss 1.2764 top1 0.6230 top3 0.8543
epoch 6 eval loss 1.2644 top1 0.6258 top3 0.8608
epoch 7 train loss 1.2434 top1 0.6300 top3 0.8605
epoch 7 eval loss 1.2343 top1 0.6368 top3 0.8642
epoch 8 train loss 1.2188 top1 0.6368 top3 0.8637
epoch 8 eval loss 1.2125 top1 0.6384 top3 0.8685
epoch 9 train loss 1.1974 top1 0.6428 top3 0.8675
epoch 9 eval loss 1.1922 top1 0.6445 top3 0.8

## Training the final model
Selected hyper-parameters:
- Hidden units: 64 (width of the single hidden layer)
- Optimizer: SGD, learning rate 0.1

In [6]:
# Train five independently initialised models with the selected configuration
# (64 hidden units, SGD with lr=0.1) on the full training set. Each one is
# scored on eval_data and logs to ./logs/cnn_pretrained/final_<i>.
final_loaders = {
    'train': torch.utils.data.DataLoader(train_data, batch_size=32, shuffle=True),
    # Evaluation order does not affect the metrics, so no shuffling.
    'eval': torch.utils.data.DataLoader(eval_data, batch_size=32, shuffle=False),
}
for i in range(5):
    model = CnnFromPretrained(n_hidden=64)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    conf_str = 'final_%d' % i
    print(conf_str)
    writer = SummaryWriter('./logs/cnn_pretrained/%s' % conf_str)
    try:
        run(
            model=model,
            loaders=final_loaders,
            optimizer=optimizer,
            writer=writer,
            num_epoch=20,
            device='cpu',
        )
    finally:
        # Flush buffered events and release the event file.
        writer.close()

final_0
epoch 0 train loss 2.8651 top1 0.3119 top3 0.5411
epoch 0 eval loss 1.8038 top1 0.5174 top3 0.7787
epoch 1 train loss 1.4926 top1 0.5711 top3 0.8227
epoch 1 eval loss 1.3382 top1 0.6087 top3 0.8438
epoch 2 train loss 1.2616 top1 0.6235 top3 0.8580
epoch 2 eval loss 1.2344 top1 0.6306 top3 0.8618
epoch 3 train loss 1.1841 top1 0.6424 top3 0.8697
epoch 3 eval loss 1.1815 top1 0.6436 top3 0.8739
epoch 4 train loss 1.1422 top1 0.6517 top3 0.8777
epoch 4 eval loss 1.1355 top1 0.6635 top3 0.8787
epoch 5 train loss 1.1138 top1 0.6629 top3 0.8821
epoch 5 eval loss 1.1265 top1 0.6592 top3 0.8803
epoch 6 train loss 1.0933 top1 0.6688 top3 0.8842
epoch 6 eval loss 1.1069 top1 0.6617 top3 0.8868
epoch 7 train loss 1.0763 top1 0.6714 top3 0.8888
epoch 7 eval loss 1.0773 top1 0.6740 top3 0.8864
epoch 8 train loss 1.0607 top1 0.6762 top3 0.8920
epoch 8 eval loss 1.0669 top1 0.6785 top3 0.8877
epoch 9 train loss 1.0451 top1 0.6797 top3 0.8945
epoch 9 eval loss 1.0448 top1 0.6852 top3 0.8924
ep

epoch 2 train loss 1.2560 top1 0.6254 top3 0.8587
epoch 2 eval loss 1.2397 top1 0.6302 top3 0.8591
epoch 3 train loss 1.1766 top1 0.6462 top3 0.8712
epoch 3 eval loss 1.1538 top1 0.6559 top3 0.8722
epoch 4 train loss 1.1322 top1 0.6567 top3 0.8784
epoch 4 eval loss 1.1412 top1 0.6582 top3 0.8766
epoch 5 train loss 1.1027 top1 0.6664 top3 0.8838
epoch 5 eval loss 1.1308 top1 0.6561 top3 0.8799
epoch 6 train loss 1.0793 top1 0.6718 top3 0.8872
epoch 6 eval loss 1.1070 top1 0.6609 top3 0.8860
epoch 7 train loss 1.0589 top1 0.6769 top3 0.8911
epoch 7 eval loss 1.0695 top1 0.6718 top3 0.8897
epoch 8 train loss 1.0388 top1 0.6815 top3 0.8944
epoch 8 eval loss 1.0615 top1 0.6774 top3 0.8931
epoch 9 train loss 1.0233 top1 0.6855 top3 0.8973
epoch 9 eval loss 1.0376 top1 0.6854 top3 0.8953
epoch 10 train loss 1.0060 top1 0.6896 top3 0.9001
epoch 10 eval loss 1.0197 top1 0.6857 top3 0.8985
epoch 11 train loss 0.9926 top1 0.6922 top3 0.9030
epoch 11 eval loss 1.0287 top1 0.6867 top3 0.8952
epoch 