# Semi-Supervised Learning of MNIST CNN
Jared Nielsen



# **Results**


### MNIST
`model_B` achieves 90.6% accuracy on `dataset_A` when `model_A` has 89.9% accuracy on `dataset_A`.  
`model_B` achieves 96.6% accuracy on `dataset_A` when `model_A` has 97.9% accuracy on `dataset_A`.

### Fashion-MNIST
`model_B` achieves 81.1% accuracy on `dataset_A` when `model_A` has 83.5% accuracy on `dataset_A`.  
`model_B` achieves 85.5% accuracy on `dataset_A` when `model_A` has 87.6% accuracy on `dataset_A`.

# **Notes**

### Steps to Semi-Supervised Metrics
- Separate MNIST into `dataset_A` and `dataset_B`. Hide the labels from `dataset_B`.
- Instantiate `model_A` and `model_B` of the same architecture.  
- Train `model_A` on `dataset_A`. Use `model_A` to predict the labels for `dataset_B`. 
- Train `model_B` on `dataset_B_hat`, i.e. `dataset_B` with the labels predicted by `model_A`. Use `model_B` to predict the labels for `dataset_A` and score it against the true labels.
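
The four steps above can be sketched end to end on toy data. This is a minimal illustration, not the notebook's method: it assumes a made-up 1-D two-class dataset and a hypothetical midpoint-threshold classifier (`fit_threshold`, `predict`, and `accuracy` are names invented here) in place of MNIST and the CNN.

```python
import random

def fit_threshold(xs, ys):
    """Fit a toy 1-D two-class model: the midpoint between the class means."""
    mean0 = sum(x for x, y in zip(xs, ys) if y == 0) / ys.count(0)
    mean1 = sum(x for x, y in zip(xs, ys) if y == 1) / ys.count(1)
    return (mean0 + mean1) / 2

def predict(threshold, xs):
    """Label a point 1 if it lies above the threshold, else 0."""
    return [int(x > threshold) for x in xs]

def accuracy(pred, true):
    return sum(p == t for p, t in zip(pred, true)) / len(true)

rng = random.Random(1)
# Toy stand-in for the image data: class 0 centred at 0, class 1 centred at 2.
labels = [i % 2 for i in range(2000)]
points = [rng.gauss(2.0 * y, 1.0) for y in labels]

# Step 1: split into A and B; B's true labels are never used for training.
xs_A, ys_A = points[:1000], labels[:1000]
xs_B = points[1000:]

# Steps 2-3: train model_A on A, then use it to label B.
model_A = fit_threshold(xs_A, ys_A)
ys_B_hat = predict(model_A, xs_B)

# Step 4: train model_B on the predicted labels, then score it on A's true labels.
model_B = fit_threshold(xs_B, ys_B_hat)
acc_A = accuracy(predict(model_A, xs_A), ys_A)
acc_B = accuracy(predict(model_B, xs_A), ys_A)
print(f"model_A on A: {acc_A:.3f}, model_B on A: {acc_B:.3f}")
```

The comparison of interest is the same as in the results: how close `model_B` gets to `model_A` without ever seeing a true label.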

### Issues
- If a CNN gets 98% test accuracy on MNIST, then `model_B`, which learns from `model_A`'s labels, inherits `model_A`'s errors on top of its own: expect roughly 0.98^2 ≈ 96% accuracy against the true labels.
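
The arithmetic behind this estimate can be made explicit. The sketch below assumes two quantities that the notebook does not measure separately: `teacher_acc` (how often `model_A` is right) and `agreement` (how often `model_B` reproduces `model_A`'s label). The function name is hypothetical. The product is the probability that `model_A` is right and `model_B` agrees with it, if the two events are independent; the second number is the lower bound on that probability with no independence assumption.

```python
def student_accuracy_estimates(teacher_acc, agreement):
    """Return (independent-error estimate, worst-case lower bound)."""
    # Teacher right AND student agrees, assuming the two are independent.
    independent = teacher_acc * agreement
    # Worst case: the student's disagreements all fall where the teacher was right.
    worst_case = max(0.0, teacher_acc + agreement - 1.0)
    return independent, worst_case

independent, worst_case = student_accuracy_estimates(0.98, 0.98)
print(f"independent-error estimate: {independent:.4f}")
print(f"worst-case lower bound:     {worst_case:.4f}")
```

At 98% the two figures differ by less than a tenth of a percentage point, so either reading gives about 96%.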

### Comparable Architectures
- MLP vs CNN?

### Ideas
- Should I use `mag` to serialize the models? **Yes.**
- Fashion-MNIST instead of MNIST?

In [1]:
import torch
import torchvision
import numpy as np
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, random_split, TensorDataset
from tqdm import tqdm
from tqdm.notebook import tqdm as tqdm_notebook  # `tqdm.tqdm_notebook` is deprecated
from time import sleep

import mnist_cnn
from mnist_cnn import Net, model_A_path, model_B_path

### Load `dataset_A` and `dataset_B`
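The split is a seeded random partition of the training set into two halves. A minimal standard-library sketch of the same idea follows; `split_in_half` is a hypothetical helper, and it will not reproduce the exact index assignment that `torch.utils.data.random_split` produces, since the two use different random number generators.

```python
import random

def split_in_half(n_items, seed=1):
    """Shuffle the indices 0..n_items-1 and cut them into two disjoint halves."""
    indices = list(range(n_items))
    random.Random(seed).shuffle(indices)
    half = n_items // 2
    return indices[:half], indices[half:]

# The Fashion-MNIST training set has 60,000 images.
idx_A, idx_B = split_in_half(60000)
print(len(idx_A), len(idx_B))
```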

In [32]:
batch_size_train = 64
batch_size_test = 1000
learning_rate = 1e-3 #0.01
momentum = 0.5
log_interval = 10

if torch.cuda.is_available():
    print('using cuda')
    device = torch.device('cuda')
else:
    print('using cpu')
    device = torch.device('cpu')

random_seed = 1
# torch.backends.cudnn.enabled = False
torch.manual_seed(random_seed)

root_dir = "../data/fashion-mnist/"
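# The variable keeps the name `mnist_train`, but this run loads Fashion-MNIST.
# (0.1307, 0.3081) are the MNIST mean/std; Fashion-MNIST's own statistics are
# approximately (0.2860, 0.3530). Left unchanged so the outputs below still apply.
mnist_train = torchvision.datasets.FashionMNIST(root_dir, train=True, download=True,
                             transform=torchvision.transforms.Compose([
                               torchvision.transforms.ToTensor(),
                               torchvision.transforms.Normalize(
                                 (0.1307,), (0.3081,))
                             ]))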
n_digits = len(mnist_train)
dataset_A, dataset_B = random_split(mnist_train, [n_digits // 2, n_digits - n_digits // 2])
loader_A, loader_B = [DataLoader(dataset, batch_size=batch_size_train, shuffle=True) 
                      for dataset in (dataset_A, dataset_B)]

using cuda


### Train `model_A` on `dataset_A`
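The training loop uses `F.nll_loss`, which expects log-probabilities and averages over the batch by default. `Net` is imported from `mnist_cnn` and not shown here, so the sketch below assumes its last layer is a log-softmax. The helper names and the logits are made up for illustration.

```python
import math

def log_softmax(logits):
    """Numerically stable log-softmax of one row of logits."""
    m = max(logits)
    log_sum = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - log_sum for z in logits]

def nll_loss(log_probs_batch, targets):
    """Mean negative log-probability assigned to the target class."""
    picked = [row[t] for row, t in zip(log_probs_batch, targets)]
    return -sum(picked) / len(targets)

# Two made-up samples with three classes each.
batch_logits = [[2.0, 0.5, -1.0], [0.1, 0.2, 3.0]]
targets = [0, 2]
log_probs = [log_softmax(row) for row in batch_logits]
loss = nll_loss(log_probs, targets)
print(f"mean NLL: {loss:.4f}")
```

If the network returned raw logits instead, `F.cross_entropy` would be the matching loss, since it applies the log-softmax itself.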

In [33]:
model_A = Net().to(device)
# SGD was tried first; Adam is the optimizer actually used for these results.
# opt_A = optim.SGD(model_A.parameters(), lr=learning_rate, momentum=momentum)
opt_A = optim.Adam(model_A.parameters(), lr=learning_rate)

n_epochs = 10

def train(model_A, optimizer_A, epoch, train_loader):
    """Run one epoch of supervised training on the true labels."""
    model_A.train()
    for batch_idx, (data, target) in tqdm_notebook(enumerate(train_loader), desc='epoch {}'.format(epoch),
                                                  total=len(train_loader)):
        data, target = data.to(device), target.to(device)
        optimizer_A.zero_grad()
        output = model_A(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer_A.step()
        # Checkpoint every `log_interval` batches. Updates made after the last
        # checkpoint of an epoch are not saved, so a reloaded model can differ slightly.
        if batch_idx % log_interval == 0:
            torch.save(model_A.state_dict(), model_A_path)
            
def test(network, test_loader):
    """Print the average NLL loss and accuracy of `network` on `test_loader`."""
    network.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = network(data)
            # Sum per batch so the division below yields a per-sample average.
            test_loss += F.nll_loss(output, target, reduction='sum').item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
    test_loss /= len(test_loader.dataset)
    acc = 100 * correct / len(test_loader.dataset)
    print('\nTest set: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset), acc))
        
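for i_epoch in range(n_epochs):
    train(model_A=model_A, optimizer_A=opt_A, epoch=i_epoch,
          train_loader=loader_A)
    # `loader_A` is also the training data, so the "Test set" figures printed
    # here are training accuracy for `model_A`.
    test(network=model_A, test_loader=loader_A)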


Test set: Avg. loss: 0.5664, Accuracy: 23437/30000 (78.12%)

Test set: Avg. loss: 0.4992, Accuracy: 24565/30000 (81.88%)

Test set: Avg. loss: 0.4441, Accuracy: 25128/30000 (83.76%)

Test set: Avg. loss: 0.4130, Accuracy: 25493/30000 (84.98%)

Test set: Avg. loss: 0.3965, Accuracy: 25578/30000 (85.26%)

Test set: Avg. loss: 0.3956, Accuracy: 25645/30000 (85.48%)

Test set: Avg. loss: 0.3784, Accuracy: 25763/30000 (85.88%)

Test set: Avg. loss: 0.3604, Accuracy: 26100/30000 (87.00%)

Test set: Avg. loss: 0.3499, Accuracy: 26179/30000 (87.26%)


Test set: Avg. loss: 0.3351, Accuracy: 26302/30000 (87.67%)



### Load `model_A` with trained weights

In [34]:
model_A = Net().to(device)
model_A.load_state_dict(torch.load(model_A_path, map_location=device))
test(network=model_A, test_loader=loader_A)
test(network=model_A, test_loader=loader_B)


Test set: Avg. loss: 0.3350, Accuracy: 26295/30000 (87.65%)


Test set: Avg. loss: 0.3597, Accuracy: 26012/30000 (86.71%)



### Use `model_A` to label `dataset_B`, Train `model_B` on `dataset_B_hat`
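`model_A` turns each image into one hard label by taking the argmax of its per-class output scores, and `model_B` is then trained on that label as if it were the truth. A minimal sketch with made-up log-scores follows; the numbers and the `argmax` helper are illustrative assumptions, not outputs of the notebook's model.

```python
def argmax(values):
    """Index of the largest entry (first one on ties)."""
    return max(range(len(values)), key=lambda i: values[i])

# Made-up per-class log-scores for three images and three classes.
log_scores_from_model_A = [
    [-0.1, -2.5, -4.0],   # confident: class 0
    [-3.0, -0.2, -2.0],   # confident: class 1
    [-1.2, -1.1, -1.0],   # nearly uniform, but still gets a hard label
]
pseudo_labels = [argmax(row) for row in log_scores_from_model_A]
print(pseudo_labels)  # [0, 1, 2]
```

Hard labels discard `model_A`'s confidence, so the uncertain third image counts as much as the first two during training.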

In [35]:
model_B = Net().to(device)
# SGD was tried first; Adam is the optimizer actually used for these results.
# opt_B = optim.SGD(model_B.parameters(), lr=learning_rate, momentum=momentum)
opt_B = optim.Adam(model_B.parameters(), lr=learning_rate)

n_epochs = 10

def train_with_transfer_labels(model_A, model_B, optimizer_B, epoch, train_loader):
    """Train `model_B` for one epoch on labels predicted by `model_A`.

    The true labels in `train_loader` are never used.
    """
    model_A.eval()
    model_B.train()
    for batch_idx, (data, _) in tqdm_notebook(enumerate(train_loader), desc='epoch {}'.format(epoch),
                                                  total=len(train_loader)):
        data = data.to(device)
        # model_A only supplies labels, so skip gradient tracking for it.
        with torch.no_grad():
            target_hat = torch.argmax(model_A(data), dim=1)
        optimizer_B.zero_grad()
        output = model_B(data)
        loss_hat = F.nll_loss(output, target_hat)
        loss_hat.backward()
        optimizer_B.step()

        if batch_idx % log_interval == 0:
            torch.save(model_B.state_dict(), model_B_path)
        
for i_epoch in range(n_epochs):
    train_with_transfer_labels(model_A=model_A, model_B=model_B, optimizer_B=opt_B,
                              epoch=i_epoch, train_loader=loader_B)
    test(network=model_B, test_loader=loader_A)


Test set: Avg. loss: 0.6161, Accuracy: 23490/30000 (78.30%)

Test set: Avg. loss: 0.5593, Accuracy: 24355/30000 (81.18%)

Test set: Avg. loss: 0.5512, Accuracy: 24571/30000 (81.90%)

Test set: Avg. loss: 0.5413, Accuracy: 25085/30000 (83.62%)

Test set: Avg. loss: 0.5431, Accuracy: 25364/30000 (84.55%)

Test set: Avg. loss: 0.5197, Accuracy: 25438/30000 (84.79%)

Test set: Avg. loss: 0.5021, Accuracy: 25526/30000 (85.09%)

Test set: Avg. loss: 0.5338, Accuracy: 25525/30000 (85.08%)

Test set: Avg. loss: 0.5454, Accuracy: 25503/30000 (85.01%)


Test set: Avg. loss: 0.5474, Accuracy: 25599/30000 (85.33%)



In [36]:
model_B = Net().to(device)
model_B.load_state_dict(torch.load(model_B_path, map_location=device))
test(network=model_B, test_loader=loader_A)
test(network=model_B, test_loader=loader_B)


Test set: Avg. loss: 0.5539, Accuracy: 25661/30000 (85.54%)


Test set: Avg. loss: 0.5563, Accuracy: 25589/30000 (85.30%)

