In [1]:
import argparse
import os
import time
import shutil
from collections import OrderedDict
from tqdm.auto import tqdm

import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

import torch
import torch.nn.functional as F
import torch.optim as optim
import torch.nn as nn
import torch.utils.model_zoo as model_zoo

from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# required to download pretrained model
import ssl
ssl._create_default_https_context = ssl._create_unverified_context

In [2]:
def ensure_dir(path):
    if not os.path.exists(path):
        print("Creating folder {}".format(path))
        os.makedirs(path)

In [3]:
def load_data(batch_size, data_root, num_workers=1):
    
    def _load_data(data_root, train, batch_size):
        return torch.utils.data.DataLoader(
        datasets.MNIST(root=data_root, train=train, download=True,
                       transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,))
                       ])),
        batch_size=batch_size, shuffle=True)
        
        
    train_loader = _load_data(data_root, True, batch_size)
    test_loader = _load_data(data_root, False, batch_size)
    
    return train_loader, test_loader

train_loader, test_loader = load_data(batch_size=200, 
                                      data_root='tmp/public_dataset/pytorch/mnist-data', 
                                      num_workers=1)

  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


In [4]:
model_urls = {
    'mnist': 'http://ml.cs.tsinghua.edu.cn/~chenxi/pytorch-models/mnist-b07bb66b.pth'
}

class MLP(nn.Module):
    def __init__(self, input_dims, n_hiddens, n_class, display=False):
        super(MLP, self).__init__()
        assert isinstance(input_dims, int), 'Please provide int for input_dims'
        self.input_dims = input_dims
        current_dims = input_dims
        layers = OrderedDict()

        if isinstance(n_hiddens, int):
            n_hiddens = [n_hiddens]
        else:
            n_hiddens = list(n_hiddens)
        for i, n_hidden in enumerate(n_hiddens):
            layers['fc{}'.format(i+1)] = nn.Linear(current_dims, n_hidden)
            layers['relu{}'.format(i+1)] = nn.ReLU()
#             layers['drop{}'.format(i+1)] = nn.Dropout(0.2)
            current_dims = n_hidden
        layers['out'] = nn.Linear(current_dims, n_class)

        self.model= nn.Sequential(layers)
        if display:
            print(self.model)

    def forward(self, input):
        input = input.view(input.size(0), -1)
        assert input.size(1) == self.input_dims
        return self.model.forward(input)

def mnist(input_dims=784, n_hiddens=[256, 256], n_class=10, pretrained=False):
    model = MLP(input_dims, n_hiddens, n_class)
    if pretrained:
        print('Loading pretrained model')
        m = model_zoo.load_url(model_urls['mnist'], map_location=torch.device('cpu'))
        state_dict = m.state_dict() if isinstance(m, nn.Module) else m
        model.load_state_dict(state_dict)
    return model

In [5]:
def eval(model, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    for data, target in test_loader:
        indx_target = target.clone()

        with torch.no_grad():
            data, target = Variable(data), Variable(target)
            output = model(data)
            test_loss += F.cross_entropy(output, target).data
            pred = output.data.max(1)[1]  # get the index of the max log-probability
            correct += pred.eq(indx_target).sum()

    test_loss = test_loss / len(test_loader) # average over number of mini-batch
    acc = 100. * correct / len(test_loader.dataset)
    
    return {
        'test_loss':test_loss.item(),
        'test_acc':acc.item()
    }

In [14]:
def train_epoch(model, train_loader, optimizer):
    model.train()
    train_loss = 0
    correct = 0
    for batch_idx, (data, target) in tqdm(enumerate(train_loader), total=len(train_loader), leave=False):
        indx_target = target.clone()
        print('Max value: ',  data.max(), data[0].max())
        assert False
        data, target = Variable(data), Variable(target)

        optimizer.zero_grad()
        output = model(data)
        
        pred = output.data.max(1)[1]
        correct += pred.eq(indx_target).sum()
        
        loss = F.cross_entropy(output, target)
        
        loss.backward()
        train_loss += loss.data
        
        optimizer.step()
        
    train_loss = train_loss / len(train_loader)
    acc = 100. * correct / len(train_loader.dataset)
    
    return {
        'train_loss':train_loss.item(),
        'train_acc':acc.item()
    }

In [7]:
def model_snapshot(model, new_file, old_file=None):
    if isinstance(model, torch.nn.DataParallel):
        model = model.module
    if old_file and os.path.exists(old_file):
        os.remove(old_file)

    state_dict = OrderedDict()
    for k, v in model.state_dict().items():
        state_dict[k] = v
    torch.save(state_dict, new_file)

In [8]:
def train(model, train_loader, test_loader, logdir):
    ensure_dir(logdir)
    
    best_acc = 0
    old_file = None
    epochs = 30
    start_time = time.time()

    history = []
    
    optimizer = optim.SGD(model.parameters(), lr=0.01, weight_decay=0.0001, momentum=0.9)
    
    for epoch in tqdm(range(epochs)):

        train_result = train_epoch(model, train_loader, optimizer)
        test_result = eval(model, test_loader)

        history.append(train_result | test_result)

        if test_result['test_acc'] > best_acc:
            new_file = os.path.join(logdir, 'best-{}.pth'.format(epoch))
            model_snapshot(model, new_file, old_file=old_file)
            best_acc = test_result['test_acc']
            old_file = new_file
    return history           

In [15]:
pretrained_model = mnist(input_dims=784, n_hiddens=[256, 256], n_class=10, pretrained=True)
# torch.save(pretrained_model, 'pretrained_model')

pretrained_history = train(pretrained_model, train_loader, test_loader, logdir='pretrained')
# pre_df = pd.DataFrame(pretrained_history)

# torch.save(pretrained_model, 'pretrained_model_finetuned')

# pre_df.to_csv('pretrained_history.csv', index=False)
# pre_df.head()

Loading pretrained model


  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/300 [00:00<?, ?it/s]

Max value:  tensor(2.8215) tensor(2.8215)


AssertionError: 

In [None]:
new_model = mnist(input_dims=784, n_hiddens=[256, 256], n_class=10, pretrained=False)
new_history = train(new_model, train_loader, test_loader, logdir='new_model')
torch.save(new_model, 'new_model.data')
new_df = pd.DataFrame(new_history)
new_df.to_csv('not_pretrained_history.csv', index=False)
new_df.head()