# CAE from pretrained encoder
## Setup

In [1]:
import torch
import torch.nn as nn
from torchvision.utils import make_grid
from tensorboardX import SummaryWriter
from pa2_sample_code import get_datasets

train_data, eval_data = get_datasets()

# Carve an 80/20 holdout split out of the training set for hyper-parameter
# selection; eval_data is kept aside for the final model only.
n_train_total = len(train_data)
holdout_train_len = int(n_train_total * 0.8)
holdout_eval_len = n_train_total - holdout_train_len
holdout_train_data, holdout_eval_data = torch.utils.data.random_split(
    train_data,
    [holdout_train_len, holdout_eval_len],
)

## Define model

In [2]:
class CaeFromPretrained(nn.Module):
    """Convolutional auto-encoder whose encoder is loaded from a checkpoint.

    The encoder is the module stored under the ``'model'`` key of a saved
    checkpoint dict; the decoder is built here and ends in a ``Sigmoid``, so
    reconstructions lie in [0, 1].
    """

    def __init__(self, encoder_path='pretrained_encoder.pt', map_location=None):
        """Load the pretrained encoder and build the decoder.

        Args:
            encoder_path: checkpoint file whose ``'model'`` entry is a pickled
                encoder module (defaults to the original hard-coded path).
            map_location: forwarded to ``torch.load``; e.g. ``'cpu'`` to load a
                GPU-saved checkpoint on a CPU-only machine.
        """
        super(CaeFromPretrained, self).__init__()

        self.encoder = self._load_encoder(encoder_path, map_location)

        # Each ConvTranspose2d (padding=0) maps a size s to
        # (s - 1) * stride + kernel_size, so an HxW feature map decodes to
        # (4H + 24)x(4W + 24): a 1x1 encoding yields a 28x28 image.
        # The encoder must emit 32 channels (first layer's in_channels).
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(in_channels=32, out_channels=16, kernel_size=3, stride=1, padding=0),
            nn.ReLU(),
            nn.ConvTranspose2d(in_channels=16, out_channels=8, kernel_size=3, stride=1, padding=0),
            nn.ReLU(),
            nn.ConvTranspose2d(in_channels=8, out_channels=8, kernel_size=3, stride=2, padding=0),
            nn.ReLU(),
            nn.ConvTranspose2d(in_channels=8, out_channels=4, kernel_size=3, stride=1, padding=0),
            nn.ReLU(),
            nn.ConvTranspose2d(in_channels=4, out_channels=1, kernel_size=4, stride=2, padding=0),
            nn.Sigmoid()
        )

        # Summed (not averaged) squared error; loss() divides by batch size.
        self.loss_func = nn.MSELoss(reduction='sum')

    @staticmethod
    def _load_encoder(path, map_location):
        """Return the encoder module stored under ``'model'`` in ``path``."""
        # The checkpoint holds a whole pickled nn.Module, which torch >= 2.6
        # refuses to load under its ``weights_only=True`` default.
        # NOTE: unpickling can execute arbitrary code - only load trusted files.
        try:
            checkpoint = torch.load(path, map_location=map_location, weights_only=False)
        except TypeError:
            # torch < 1.13 does not accept the ``weights_only`` argument.
            checkpoint = torch.load(path, map_location=map_location)
        return checkpoint['model']

    def forward(self, in_data):
        """Encode ``in_data`` and return its reconstruction (values in [0, 1])."""
        img_features = self.encoder(in_data)
        # Called "logits" by the training code, but these are post-sigmoid.
        logits = self.decoder(img_features)
        return logits

    def loss(self, logits, in_data):
        """Return the summed squared reconstruction error divided by batch size.

        Args:
            logits: output of ``forward`` (post-sigmoid reconstructions).
            in_data: the original input batch, same shape as ``logits``.
        """
        return self.loss_func(logits, in_data) / logits.size(0)

## Refactored runner function

In [3]:
def run(model, loaders, optimizer, writer, num_epoch=10, device='cpu'):
    """Train ``model`` for ``num_epoch`` epochs, evaluating after each one.

    Every epoch makes one pass over ``loaders['train']`` (with optimizer
    updates) followed by one pass over ``loaders['eval']`` (no updates, no
    gradient tracking, model in eval mode). The mean per-batch loss of each
    pass is printed and, when ``writer`` is given, logged as a scalar; after
    each eval pass the inputs and outputs of the last eval batch are logged
    as image grids.

    Args:
        model: module exposing ``loss(outputs, in_data)``; moved to ``device``.
        loaders: dict with ``'train'`` and ``'eval'`` iterables yielding
            ``(in_data, labels)`` batches (labels are ignored).
        optimizer: optimizer over ``model``'s parameters.
        writer: TensorBoard ``SummaryWriter``, or ``None`` to disable logging.
        num_epoch: number of train+eval epochs.
        device: device the model and the batches are moved to.

    Returns:
        dict mapping ``'train'`` and ``'eval'`` to the list of per-epoch mean
        batch losses (``nan`` for a pass whose loader yielded no batches).
    """
    model.to(device)
    name = model.__class__.__name__
    history = {'train': [], 'eval': []}

    def log_images(epoch, in_data, outputs):
        """Write image grids of one eval batch's inputs and reconstructions."""
        # Flattened (N, 784) tensors are reshaped to 1-channel 28x28 images.
        if in_data.dim() == 2:
            in_data = in_data.view(-1, 1, 28, 28)
        if outputs.dim() == 2:
            outputs = outputs.view(-1, 1, 28, 28)
        writer.add_image('%s/eval_input' % name,
                         make_grid(in_data.detach().to('cpu')), epoch)
        writer.add_image('%s/eval_output' % name,
                         make_grid(outputs.detach().to('cpu')), epoch)

    def run_epoch(epoch, mode):
        """Make one pass over ``loaders[mode]`` and return its mean batch loss."""
        training = mode == 'train'
        # Switch dropout/batch-norm behaviour to match the pass.
        model.train(training)
        total_loss = 0.0
        num_batches = 0
        in_data = outputs = None
        # Autograd graphs are only needed for the training pass.
        with torch.set_grad_enabled(training):
            for in_data, _labels in loaders[mode]:
                in_data = in_data.to(device)
                if training:
                    optimizer.zero_grad()
                outputs = model(in_data)
                batch_loss = model.loss(outputs, in_data)
                if training:
                    batch_loss.backward()
                    optimizer.step()
                total_loss += batch_loss.item()
                num_batches += 1

        # Mean of the per-batch losses; nan when the loader was empty.
        epoch_loss = total_loss / num_batches if num_batches else float('nan')
        print('epoch %d %s loss %.4f' % (epoch, mode, epoch_loss))

        if writer is not None:
            writer.add_scalars('%s_loss' % name,
                               tag_scalar_dict={mode: epoch_loss},
                               global_step=epoch)
            if not training and outputs is not None:
                log_images(epoch, in_data, outputs)
        return epoch_loss

    for epoch in range(num_epoch):
        for mode in ('train', 'eval'):
            history[mode].append(run_epoch(epoch, mode))
    return history

## Holdout Validation for choosing hyper-parameters

In [4]:
# Compare optimizer settings on the holdout split. Each configuration gets a
# fresh model and its own TensorBoard run under ./logs/cae/<optim>_<lr>.
optimizer_classes = {'adam': torch.optim.Adam, 'sgd': torch.optim.SGD}
for optim_conf in [
    {'optim': 'adam', 'lr': 0.001},
    {'optim': 'sgd', 'lr': 0.1},
    {'optim': 'sgd', 'lr': 0.01},
]:
    model = CaeFromPretrained()
    # KeyError on an unknown optimizer name instead of silently using SGD.
    optimizer = optimizer_classes[optim_conf['optim']](model.parameters(), lr=optim_conf['lr'])
    conf_str = optim_conf['optim'] + '_' + str(optim_conf['lr'])
    print(conf_str)
    writer = SummaryWriter('./logs/cae/%s' % (conf_str))
    try:
        run(
            model=model,
            loaders={
                'train': torch.utils.data.DataLoader(holdout_train_data, batch_size=32, shuffle=True),
                'eval': torch.utils.data.DataLoader(holdout_eval_data, batch_size=32, shuffle=True)
            },
            optimizer=optimizer,
            writer=writer,
            num_epoch=10,
            device='cpu'
        )
    finally:
        # Flush buffered events so the final epochs/images reach TensorBoard.
        writer.close()

adam_0.001
epoch 0 train loss 57.0275
epoch 0 eval loss 37.5400
epoch 1 train loss 34.3430
epoch 1 eval loss 32.0027
epoch 2 train loss 30.9334
epoch 2 eval loss 29.8890
epoch 3 train loss 29.1871
epoch 3 eval loss 28.4027
epoch 4 train loss 28.1395
epoch 4 eval loss 27.5926
epoch 5 train loss 27.4126
epoch 5 eval loss 26.9978
epoch 6 train loss 26.8849
epoch 6 eval loss 26.5602
epoch 7 train loss 26.4478
epoch 7 eval loss 26.2031
epoch 8 train loss 25.8842
epoch 8 eval loss 25.4438
epoch 9 train loss 25.0411
epoch 9 eval loss 23.9046
sgd_0.1
epoch 0 train loss 111.0752
epoch 0 eval loss 111.1485
epoch 1 train loss 111.0128
epoch 1 eval loss 111.1485
epoch 2 train loss 111.0127
epoch 2 eval loss 111.1485
epoch 3 train loss 111.0127
epoch 3 eval loss 111.1484
epoch 4 train loss 111.0127
epoch 4 eval loss 111.1484
epoch 5 train loss 111.0127
epoch 5 eval loss 111.1484
epoch 6 train loss 111.0127
epoch 6 eval loss 111.1484
epoch 7 train loss 111.0126
epoch 7 eval loss 111.1483
epoch 8 tra

## Training final model
Hyper-parameters:
Optimization: Adam, learning rate 0.001; 15 epochs on the full training set

In [7]:
# Train the final model on the full training set with the chosen settings
# and evaluate on the held-back eval set.
model = CaeFromPretrained()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
conf_str = 'final'
print(conf_str)
writer = SummaryWriter('./logs/cae/%s' % (conf_str))
try:
    run(
        model=model,
        loaders={
            'train': torch.utils.data.DataLoader(train_data, batch_size=32, shuffle=True),
            'eval': torch.utils.data.DataLoader(eval_data, batch_size=32, shuffle=True)
        },
        optimizer=optimizer,
        writer=writer,
        num_epoch=15,
        device='cpu'
    )
finally:
    # Flush buffered events so the final epochs/images reach TensorBoard.
    writer.close()

final
epoch 0 train loss 50.6091
epoch 0 eval loss 34.6583
epoch 1 train loss 32.1074
epoch 1 eval loss 30.2909
epoch 2 train loss 29.0967
epoch 2 eval loss 28.4016
epoch 3 train loss 27.7402
epoch 3 eval loss 27.4155
epoch 4 train loss 26.8294
epoch 4 eval loss 26.3066
epoch 5 train loss 25.0460
epoch 5 eval loss 23.9013
epoch 6 train loss 23.5204
epoch 6 eval loss 23.2212
epoch 7 train loss 23.0278
epoch 7 eval loss 22.9713
epoch 8 train loss 22.6826
epoch 8 eval loss 22.7582
epoch 9 train loss 22.4043
epoch 9 eval loss 22.5769
epoch 10 train loss 22.1619
epoch 10 eval loss 22.1497
epoch 11 train loss 21.9616
epoch 11 eval loss 22.0626
epoch 12 train loss 21.7847
epoch 12 eval loss 21.8373
epoch 13 train loss 21.6437
epoch 13 eval loss 21.6873
epoch 14 train loss 21.5191
epoch 14 eval loss 21.6005
