# Convolutional autoencoder (CAE) built on a pretrained encoder
## Setup

In [1]:
import torch
import torch.nn as nn
from torchvision.utils import make_grid
from tensorboardX import SummaryWriter
from pa2_sample_code import get_datasets

train_data, eval_data = get_datasets()

# Carve a holdout split out of the training set: 80% of the samples are used
# for fitting and the remainder for validating hyper-parameter choices.
num_train_samples = len(train_data)
holdout_train_len = int(num_train_samples * 0.8)
holdout_eval_len = num_train_samples - holdout_train_len
holdout_train_data, holdout_eval_data = torch.utils.data.random_split(
    train_data,
    [holdout_train_len, holdout_eval_len],
)

## Define model

In [2]:
class CaeFromPretrained(nn.Module):
    """Convolutional autoencoder made of a pretrained encoder and a new decoder.

    The encoder is read from a checkpoint file; the decoder is a stack of
    transposed convolutions that expects a 32-channel feature map and emits a
    single-channel image squashed into [0, 1] by a final sigmoid.

    Args:
        encoder_path: path of the checkpoint file. The file must contain a
            mapping whose ``'model'`` entry is the (callable) encoder module.
            Defaults to ``'pretrained_encoder.pt'``.
    """

    def __init__(self, encoder_path='pretrained_encoder.pt'):
        super().__init__()

        # SECURITY: the checkpoint stores a whole pickled module, so loading
        # it executes arbitrary code from the file. Only load trusted files.
        # map_location='cpu' lets a checkpoint saved on a GPU load on a
        # CPU-only machine and keeps the encoder on the same device as the
        # freshly built decoder below.
        try:
            # PyTorch >= 2.6 defaults to weights_only=True, which refuses to
            # unpickle full module objects; opt out explicitly.
            checkpoint = torch.load(
                encoder_path, map_location='cpu', weights_only=False
            )
        except TypeError:
            # Older PyTorch releases do not know the weights_only argument.
            checkpoint = torch.load(encoder_path, map_location='cpu')
        self.encoder = checkpoint['model']

        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(in_channels=32, out_channels=16, kernel_size=3, stride=1, padding=0),
            nn.ReLU(),
            nn.ConvTranspose2d(in_channels=16, out_channels=8, kernel_size=3, stride=1, padding=0),
            nn.ReLU(),
            nn.ConvTranspose2d(in_channels=8, out_channels=8, kernel_size=3, stride=2, padding=0),
            nn.ReLU(),
            nn.ConvTranspose2d(in_channels=8, out_channels=4, kernel_size=3, stride=1, padding=0),
            nn.ReLU(),
            nn.ConvTranspose2d(in_channels=4, out_channels=1, kernel_size=4, stride=2, padding=0),
            nn.Sigmoid()
        )

        # Summed (not averaged) squared error; loss() normalises by batch size.
        self.loss_func = nn.MSELoss(reduction='sum')

    def forward(self, in_data):
        """Encode ``in_data`` and return the decoder's reconstruction.

        The returned tensor has already passed through the final sigmoid, so
        despite the ``logits`` name used by ``loss`` it holds values in [0, 1].
        """
        img_features = self.encoder(in_data)
        reconstruction = self.decoder(img_features)
        return reconstruction

    def loss(self, logits, in_data):
        """Return the summed squared error between ``logits`` and ``in_data``
        divided by the batch size (``logits.size(0)``)."""
        return self.loss_func(logits, in_data) / logits.size(0)

## Refactored runner function

In [3]:
def run(model, loaders, optimizer, writer, num_epoch=10, device='cpu'):
    """Train and evaluate ``model`` for ``num_epoch`` epochs.

    Every epoch performs one pass over ``loaders['train']`` with optimizer
    updates followed by one pass over ``loaders['eval']`` without updates.
    The mean per-batch loss of each pass is printed and, when ``writer`` is
    given, logged to TensorBoard together with the last eval batch and its
    reconstruction.

    Args:
        model: ``nn.Module`` that maps an input batch to an output batch and
            provides ``loss(logits, in_data)``. The model itself is not moved
            to ``device`` here; only the input batches are.
        loaders: mapping with ``'train'`` and ``'eval'`` entries, each an
            iterable yielding ``(in_data, labels)`` batches. Labels are
            ignored because the target of the loss is the input itself.
        optimizer: optimizer stepping the model's parameters.
        writer: object with ``add_scalars`` and ``add_image`` methods (for
            example a ``SummaryWriter``), or ``None`` to disable logging.
        num_epoch: number of train + eval rounds to run.
        device: device the input batches are moved to.

    Raises:
        ValueError: if a loader yields no batches.

    Note:
        When at least one epoch runs, the model is left in evaluation mode.
    """
    model_name = model.__class__.__name__

    def run_epoch(mode, epoch):
        is_train = mode == 'train'
        # Put mode-dependent layers (dropout, batch norm, ...) in the right
        # state for this pass.
        model.train(is_train)

        epoch_loss = 0.0
        num_batches = 0
        # No autograd graph is needed while evaluating.
        with torch.set_grad_enabled(is_train):
            for in_data, _ in loaders[mode]:
                in_data = in_data.to(device)

                if is_train:
                    optimizer.zero_grad()

                logits = model(in_data)
                batch_loss = model.loss(logits, in_data)

                epoch_loss += batch_loss.item()
                num_batches += 1

                if is_train:
                    batch_loss.backward()
                    optimizer.step()

        if num_batches == 0:
            raise ValueError('loader for mode %r produced no batches' % mode)

        # sum of all batch losses / number of batches
        epoch_loss /= num_batches

        print('epoch %d %s loss %.4f' % (epoch, mode, epoch_loss))
        # log to tensorboard
        if writer is not None:
            writer.add_scalars('%s_loss' % model_name,
                               tag_scalar_dict={mode: epoch_loss},
                               global_step=epoch)
            # After every eval pass, log the last batch and its reconstruction.
            if mode == 'eval':
                img_grid = make_grid(in_data.to('cpu'))
                writer.add_image('%s/eval_input' % model_name, img_grid, epoch)
                img_grid = make_grid(logits.to('cpu'))
                writer.add_image('%s/eval_output' % model_name, img_grid, epoch)

    for epoch in range(num_epoch):
        run_epoch('train', epoch)
        run_epoch('eval', epoch)

## Holdout Validation for choosing hyper-parameters

In [4]:
# Holdout validation: train a fresh model for every optimizer configuration
# on the holdout-train split and report its loss on the holdout-eval split.
for optim_conf in [
    {'optim': 'adam', 'lr': 0.001},
    {'optim': 'sgd', 'lr': 0.1},
    {'optim': 'sgd', 'lr': 0.01},
]:
    model = CaeFromPretrained()
    if optim_conf['optim'] == 'adam':
        optimizer = torch.optim.Adam(model.parameters(), lr=optim_conf['lr'])
    else:
        optimizer = torch.optim.SGD(model.parameters(), lr=optim_conf['lr'])
    conf_str = optim_conf['optim'] + '_' + str(optim_conf['lr'])
    print(conf_str)
    # One log directory per configuration so the curves can be compared.
    writer = SummaryWriter('./logs/cae/%s' % (conf_str))
    try:
        run(
            model=model,
            loaders={
                'train': torch.utils.data.DataLoader(holdout_train_data, batch_size=32, shuffle=True),
                'eval': torch.utils.data.DataLoader(holdout_eval_data, batch_size=32, shuffle=True),
            },
            optimizer=optimizer,
            writer=writer,
            num_epoch=10,
            device='cpu',
        )
    finally:
        # Flush pending events and release the event file even if training
        # fails part-way through.
        writer.close()

adam_0.001
epoch 0 train loss 59.8725
epoch 0 eval loss 36.5190
epoch 1 train loss 32.9882
epoch 1 eval loss 30.6163
epoch 2 train loss 29.1995
epoch 2 eval loss 28.0643
epoch 3 train loss 27.2804
epoch 3 eval loss 26.5362
epoch 4 train loss 26.0952
epoch 4 eval loss 25.5648
epoch 5 train loss 25.3106
epoch 5 eval loss 25.0196
epoch 6 train loss 24.7278
epoch 6 eval loss 24.4132
epoch 7 train loss 24.2700
epoch 7 eval loss 24.0777
epoch 8 train loss 23.8909
epoch 8 eval loss 23.9027
epoch 9 train loss 23.5714
epoch 9 eval loss 23.6148
sgd_0.1
epoch 0 train loss 111.2021
epoch 0 eval loss 110.6310
epoch 1 train loss 111.1424
epoch 1 eval loss 110.6310
epoch 2 train loss 111.1424
epoch 2 eval loss 110.6310
epoch 3 train loss 111.1424
epoch 3 eval loss 110.6310
epoch 4 train loss 111.1424
epoch 4 eval loss 110.6310
epoch 5 train loss 111.1424
epoch 5 eval loss 110.6310
epoch 6 train loss 111.1424
epoch 6 eval loss 110.6310
epoch 7 train loss 111.1424
epoch 7 eval loss 110.6310
epoch 8 tra

## Training final model
Hyper-parameters (selected using the holdout validation above):
Optimizer: SGD, learning rate 0.01

In [5]:
# Train the final model on the full training set and report its loss on the
# separate evaluation set.
model = CaeFromPretrained()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
conf_str = 'final'
print(conf_str)
writer = SummaryWriter('./logs/cae/%s' % (conf_str))
try:
    run(
        model=model,
        loaders={
            'train': torch.utils.data.DataLoader(train_data, batch_size=32, shuffle=True),
            'eval': torch.utils.data.DataLoader(eval_data, batch_size=32, shuffle=True),
        },
        optimizer=optimizer,
        writer=writer,
        num_epoch=20,
        device='cpu',
    )
finally:
    # Flush pending events and release the event file even if training fails
    # part-way through.
    writer.close()

final
epoch 0 train loss 56.8960
epoch 0 eval loss 37.1465
epoch 1 train loss 30.5123
epoch 1 eval loss 28.4783
epoch 2 train loss 25.9272
epoch 2 eval loss 25.6136
epoch 3 train loss 24.1849
epoch 3 eval loss 22.5642
epoch 4 train loss 23.0729
epoch 4 eval loss 22.3533
epoch 5 train loss 22.2945
epoch 5 eval loss 21.4649
epoch 6 train loss 21.7284
epoch 6 eval loss 21.8093
epoch 7 train loss 21.2337
epoch 7 eval loss 21.3707
epoch 8 train loss 20.8897
epoch 8 eval loss 21.1576
epoch 9 train loss 20.5797
epoch 9 eval loss 20.4619
epoch 10 train loss 20.3135
epoch 10 eval loss 19.9867
epoch 11 train loss 20.1514
epoch 11 eval loss 19.8553
epoch 12 train loss 19.9933
epoch 12 eval loss 20.2345
epoch 13 train loss 19.9044
epoch 13 eval loss 19.9957
epoch 14 train loss 19.8460
epoch 14 eval loss 19.6911
epoch 15 train loss 19.7230
epoch 15 eval loss 19.4345
epoch 16 train loss 19.6655
epoch 16 eval loss 19.3933
epoch 17 train loss 19.5704
epoch 17 eval loss 19.5384
epoch 18 train loss 19.5