# Semi-Supervised Learning of MNIST CNN
Jared Nielsen

### Steps to Semi-Supervised Metrics
- Separate MNIST into `dataset_A` and `dataset_B`. Hide the labels from `dataset_B`.
- Instantiate `model_A` and `model_B` of the same architecture.  
- Train `model_A` on `dataset_A`. Use `model_A` to predict the labels for `dataset_B`. 
- Train `model_B` on `dataset_B_augmented`. Use `model_B` to predict the labels for `dataset_A`.

### Issues
- If a CNN gets 98% test performance on MNIST, then it will get at worst 0.98^2 = 96% performance on transfer learning.

### Comparable Architectures
- MLP vs CNN?

### Ideas
- Should I use `mag` to serialize the models? **Yes.**
- Fashion-MNIST instead of MNIST?

In [1]:
import torch
import torchvision
import numpy as np
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, random_split, TensorDataset
from tqdm import tqdm, tqdm_notebook
from time import sleep

import mnist_cnn
from mnist_cnn import Net, model_A_path, model_B_path

### Load `dataset_A` and `dataset_B`

In [2]:
root_dir = "../data/"
batch_size_train = 64
batch_size_test = 1000
learning_rate = 0.01
momentum = 0.5
log_interval = 10

random_seed = 1
torch.backends.cudnn.enabled = False
torch.manual_seed(random_seed)

mnist_train = torchvision.datasets.MNIST(root_dir, train=True, download=True,
                             transform=torchvision.transforms.Compose([
                               torchvision.transforms.ToTensor(),
                               torchvision.transforms.Normalize(
                                 (0.1307,), (0.3081,))
                             ]))
n_digits = len(mnist_train)
dataset_A, dataset_B = random_split(mnist_train, [n_digits // 2, n_digits - n_digits // 2])
loader_A, loader_B = [DataLoader(dataset, batch_size=batch_size_train, shuffle=True) 
                      for dataset in (dataset_A, dataset_B)]

### Train `model_A` on `dataset_A`

In [4]:
model_A = Net()
opt_A = optim.SGD(model_A.parameters(), lr=learning_rate,
                 momentum=momentum)

n_epochs = 5

def train(model_A, optimizer_A, epoch, train_loader):
    model_A.train()
    for batch_idx, (data, target) in tqdm_notebook(enumerate(train_loader), desc='epoch {}'.format(epoch),
                                                  total=len(train_loader)):
        optimizer_A.zero_grad()
        output = model_A(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer_A.step()
        if batch_idx % log_interval == 0:
            torch.save(model_A.state_dict(), model_A_path)
            
def test(network, test_loader):
    network.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            output = network(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()
            pred = output.data.max(1, keepdim=True)[1]
            correct += pred.eq(target.data.view_as(pred)).sum()
        test_loss /= len(test_loader.dataset)
        acc = 100 * correct.item() / len(test_loader.dataset)
        print('\nTest set: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format(
            test_loss, correct, len(test_loader.dataset), acc))
        
for i_epoch in range(n_epochs):
    train(model_A=model_A, optimizer_A=opt_A, epoch=i_epoch,
          train_loader=loader_A)
    test(network=model_A, test_loader=loader_A)

HBox(children=(IntProgress(value=0, description='epoch 0', max=469, style=ProgressStyle(description_width='ini…



Test set: Avg. loss: 0.3526, Accuracy: 26890/30000 (89.63%)



HBox(children=(IntProgress(value=0, description='epoch 1', max=469, style=ProgressStyle(description_width='ini…



Test set: Avg. loss: 0.1900, Accuracy: 28323/30000 (94.41%)



HBox(children=(IntProgress(value=0, description='epoch 2', max=469, style=ProgressStyle(description_width='ini…



Test set: Avg. loss: 0.1422, Accuracy: 28789/30000 (95.96%)



HBox(children=(IntProgress(value=0, description='epoch 3', max=469, style=ProgressStyle(description_width='ini…



Test set: Avg. loss: 0.1201, Accuracy: 28900/30000 (96.33%)



HBox(children=(IntProgress(value=0, description='epoch 4', max=469, style=ProgressStyle(description_width='ini…



Test set: Avg. loss: 0.1064, Accuracy: 29037/30000 (96.79%)



### Load `model_A` with trained weights

In [5]:
model_A = Net()
model_A.load_state_dict(torch.load(model_A_path))

### Use `model_A` to label `dataset_B`, Train `model_B` on `dataset_B_hat`

In [6]:
model_B = Net()
opt_B = optim.SGD(model_B.parameters(), lr=learning_rate,
                 momentum=momentum)

n_epochs = 5

def train_with_transfer_labels(model_A, model_B, optimizer_B, epoch, train_loader):
    model_A.eval()
    model_B.train()
    for batch_idx, (data, target) in tqdm_notebook(enumerate(train_loader), desc='epoch {}'.format(epoch),
                                                  total=len(train_loader)):
        target_hat = model_A(data)
        target_hat = torch.argmax(target_hat, dim=1)
        optimizer_B.zero_grad()
        output = model_B(data)
        loss_hat = F.nll_loss(output, target_hat)
        loss_hat.backward()
        optimizer_B.step()
        
        if batch_idx % log_interval == 0:
            torch.save(model_B.state_dict(), model_B_path)
        
for i_epoch in range(n_epochs):
    train_with_transfer_labels(model_A=model_A, model_B=model_B, optimizer_B=opt_B,
                              epoch=i_epoch, train_loader=loader_B)
    test(network=model_B, test_loader=loader_A)

HBox(children=(IntProgress(value=0, description='epoch 0', max=469, style=ProgressStyle(description_width='ini…



Test set: Avg. loss: 0.3422, Accuracy: 26984/30000 (89.95%)



HBox(children=(IntProgress(value=0, description='epoch 1', max=469, style=ProgressStyle(description_width='ini…



Test set: Avg. loss: 0.2328, Accuracy: 27898/30000 (92.99%)



HBox(children=(IntProgress(value=0, description='epoch 2', max=469, style=ProgressStyle(description_width='ini…



Test set: Avg. loss: 0.1952, Accuracy: 28244/30000 (94.15%)



HBox(children=(IntProgress(value=0, description='epoch 3', max=469, style=ProgressStyle(description_width='ini…



Test set: Avg. loss: 0.1679, Accuracy: 28494/30000 (94.98%)



HBox(children=(IntProgress(value=0, description='epoch 4', max=469, style=ProgressStyle(description_width='ini…



Test set: Avg. loss: 0.1628, Accuracy: 28588/30000 (95.29%)

