In [2]:
%load_ext autoreload

In [26]:
%autoreload
import os
import importlib
import itertools
import numpy as np

import torch
import torch.nn as nn
from torch.autograd import Variable
import torchvision.datasets as dset
from torchvision.utils import save_image
import torchvision.transforms as transforms

import MNISTM_dataset
from usps_dataset import USPS
from blocks import ResidualBlock, weights_initialization

In [28]:
use_cuda = torch.cuda.is_available()
root = 'data/'
batch_size = 32

latent_dimension = 10
img_channels = 1 # fixed for all
img_size = 32
# lr_C = 2e-4
learning_rate = 2e-4
beta_1 = 0.5 # ?
beta_2 = 0.999 # ?

## Define two datasets (MNIST, MNIST_M, SVHN):

In [29]:
os.makedirs('data/', exist_ok=True)

In [30]:
trans = transforms.Compose([transforms.Resize(img_size), transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5),
                                                                                        (0.5, 0.5, 0.5))])

# # 
# transforms.ToPILImage(mode='F'), 
usps_trans = transforms.Compose([transforms.Resize(img_size),
                                 transforms.ToTensor(), 
                                 transforms.Normalize((0.5, 0.5, 0.5),(0.5, 0.5, 0.5))])

In [31]:
training_set_original = dset.MNIST(root=root, train=True, transform=trans, download=False)
test_set_original = dset.MNIST(root=root, train=False, transform=trans, download=False)

In [32]:
training_set_m = MNISTM_dataset.MNISTM(root=root, train=True, transform=trans, download=False)
test_set_m = MNISTM_dataset.MNISTM(root=root, train=False, transform=trans, download=False)

In [33]:
# training_set_svhn = dset.SVHN(root=root, split='extra', transform=trans, download=False)
# test_set_svhn = dset.SVHN(root=root, split='test', transform=trans, download=False)

In [34]:
training_usps = USPS(root=root, train=True,transform=usps_trans, download=False)

In [35]:
dataloader_original = torch.utils.data.DataLoader(training_set_original, batch_size=batch_size, 
                                                  shuffle=True, num_workers=6, pin_memory=True, drop_last=True)
dataloader_MNISTM = torch.utils.data.DataLoader(training_set_m, batch_size=batch_size, 
                                                shuffle=True, num_workers=6, pin_memory=True)
# dataloader_svhn = torch.utils.data.DataLoader(training_set_svhn, batch_size=batch_size, 
#                                                 shuffle=True, num_workers=6, pin_memory=True, drop_last=True)

usps_data_loader = torch.utils.data.DataLoader(
        dataset=training_usps,
        batch_size=batch_size, shuffle=True, num_workers=6, pin_memory=True, drop_last=True)

dataloader_original_test = torch.utils.data.DataLoader(test_set_original, batch_size=batch_size, 
                                                  shuffle=False, num_workers=6, pin_memory=True)
dataloader_MNISTM_test = torch.utils.data.DataLoader(test_set_m, batch_size=batch_size, 
                                                shuffle=False, num_workers=6, pin_memory=True)
# dataloader_svhn_test = torch.utils.data.DataLoader(test_set_svhn, batch_size=batch_size, 
#                                                 shuffle=False, num_workers=6, pin_memory=True)

## Define generator, discriminator and classifier architectures


In [119]:
class Generator(nn.Module):
    def __init__(self, latent_dimension=10, img_channels=1, img_size=32):
        super().__init__()
        
        self.fc = nn.Linear(latent_dimension, img_channels*img_size**2)
        self.l1 = nn.Sequential(nn.Conv2d(img_channels*2, 64, 3, 1, 1), nn.ReLU(inplace=True))
        
        residual_blocks = []
        
        for i in range(6):
            residual_blocks.append(ResidualBlock())
        self.residual_blocks = nn.Sequential(*residual_blocks)
        
        self.l2 = nn.Sequential(nn.Conv2d(64, img_channels, 3, 1, 1), nn.Tanh())
         
    def forward(self, img, z):
        transformed_noise = self.fc(z).view(*img.shape)
        transformed_input = torch.cat((img, transformed_noise ), 1)
        out = self.l1(transformed_input)
        out = self.residual_blocks(out)
        img = self.l2(out)
        
        return img

In [120]:
class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        
        def block(in_features, out_features, normalize=True):
            layers = [nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),
                     nn.LeakyReLU(0.2, inplace=True)]
            if normalize:
                layers.append(nn.InstanceNorm2d(out_features))
            return layers
        
        self.model = nn.Sequential(
            *block(img_channels, 64, normalize=False),
            torch.nn.Dropout(p=0.1, inplace=True),
            *block(64, 128),
            torch.nn.Dropout(p=0.1, inplace=True),
            *block(128,256),
            torch.nn.Dropout(p=0.1, inplace=True),
            *block(256,512),
            torch.nn.Dropout(p=0.1, inplace=True),
            nn.Conv2d(512,1,3,1,1)
        )
        
    def forward(self, img):
        out = self.model(img)
        
        return out

In [121]:
class Classifier(nn.Module):
    def __init__(self, img_size=32, number_of_classes=10):
        super().__init__()
        
        input_size = img_size // 2**4
        
        def block(in_features, out_features, normalize=True):
            layers = [nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),
                     nn.LeakyReLU(0.2, inplace=True)]
            if normalize:
                layers.append(nn.InstanceNorm2d(out_features))
            return layers
        
        self.model = nn.Sequential(
            *block(img_channels, 64, normalize=False),
            *block(64, 128),
            *block(128,256),
            *block(256,512)
        ) # pop the last layer from the Discriminator
        
        self.classifier = nn.Sequential(
            nn.Linear(512*input_size**2, number_of_classes),
            nn.Softmax(dim=0)
        )
        
    def forward(self, img):
        features = self.model(img)
        features = features.view(features.size(0), -1)
        labels = self.classifier(features)
        return labels

In [174]:
generator, discriminator, classifier = Generator(), Discriminator(), Classifier()

if use_cuda: generator.cuda(); discriminator.cuda(); classifier.cuda()

generator.apply(weights_initialization); discriminator.apply(weights_initialization); classifier.apply(weights_initialization); 

## Define loss function and optimizers

In [175]:
MSE_loss = torch.nn.MSELoss()
task_loss = torch.nn.CrossEntropyLoss()

lambda_gan = 1 #0.05
lambda_task = 1 # Loss weights

In [176]:
opt_G = torch.optim.Adam(itertools.chain(generator.parameters(), classifier.parameters()),
                        lr = learning_rate, betas = (beta_1, beta_2))

opt_D = torch.optim.Adam(itertools.chain(discriminator.parameters()),
                        lr = learning_rate, betas = (beta_1, beta_2))

In [177]:
FloatTensor = torch.cuda.FloatTensor if use_cuda else torch.FloatTensor
LongTensor = torch.cuda.LongTensor if use_cuda else torch.LongTensor

## GAN Training

In [178]:
patch = int(img_size / 2**4)
patch = (1, patch, patch)

In [179]:
task_performance = []
target_performance = []

In [180]:
len(dataloader_original)

1875

In [181]:
n_epochs = 200

classifier.train()

for epoch in range(n_epochs):
    for i, ((imgs_A, labels_A), (imgs_B, labels_B)) in enumerate(zip(itertools.cycle(usps_data_loader),
                                                                     dataloader_original)):
        
        real = Variable(FloatTensor(batch_size, *patch).fill_(1.0), requires_grad=False)
        fake = Variable(FloatTensor(batch_size, *patch).fill_(0.0), requires_grad=False)

        imgs_A =  Variable(imgs_A.type(FloatTensor))#.expand(batch_size, 3, img_size, img_size) # expand if MNIST
        labels_A = Variable(labels_A.type(LongTensor).squeeze())
        imgs_B = Variable(imgs_B.type(FloatTensor))#.expand(batch_size, 3, img_size, img_size) # pointless expanding?
        
        ####################
        # Train generator
        ####################
        
        opt_G.zero_grad()
         
        # Generate noise: # used to be uniform -1 to 1 # sample from Normal instead?
        z = Variable(FloatTensor(np.random.normal(0,1, (batch_size, latent_dimension)))) 
        
        # Generate artificial images based on noise input:
        fake_B = generator(imgs_A, z)
        
        pred_fB = classifier(fake_B)
#         print(labels_A)
#         print(labels_B)
        task_loss_ = (task_loss(pred_fB, labels_A) + task_loss(classifier(imgs_A), labels_A)) / 2.
        
        g_loss = lambda_gan * MSE_loss(discriminator(fake_B), real) + lambda_task * task_loss_
        
        g_loss.backward()
        opt_G.step()
        
        ####################
        # Train discriminator
        ####################
        
        opt_D.zero_grad()
        
        real_loss = MSE_loss(discriminator(imgs_B), real)
        fake_loss = MSE_loss(discriminator(fake_B.detach()), fake)
        
        d_loss = (real_loss + fake_loss) / 2
        
        d_loss.backward()
        opt_D.step()
        
        #############################
        # Evaluate target performance
        #############################
        
        acc = np.mean(np.argmax(pred_fB.data.cpu().numpy(), axis=1) == labels_A.data.cpu().numpy())
        task_performance.append(acc)
        
        pred_B = classifier(imgs_B)
        target_acc = np.mean(np.argmax(pred_B.data.cpu().numpy(), axis=1) == labels_B.numpy())
        target_performance.append(target_acc)
        if len(target_performance) > 100:
            target_performance.pop(0)
        
        batches_done = len(dataloader_original) * epoch + i
        
        if batches_done%50==0:
            print ("[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f] \
            [CLF acc: %3d%% (%3d%%), target_acc: %3d%% (%3d%%)]" %
            (epoch, n_epochs,
            i, len(dataloader_original),
            d_loss.data[0], g_loss.data[0],
            100*acc, 100*np.mean(task_performance),
            100*target_acc, 100*np.mean(target_performance)))
        
        interval = 600
        
        if batches_done % interval == 0:
            sample = torch.cat((imgs_A.data[:5], fake_B.data[:5], imgs_B.data[:5]), -2)
            save_image(sample, 'usps_images/%d.png' % batches_done, nrow=int(np.sqrt(batch_size)), normalize=True)        















































































































































































































































## Save/Load trained models and results

In [60]:
# torch.save(generator,'generator_250.pt')
# torch.save(discriminator,'discriminator_250.pt')
# torch.save(classifier,'classifier_g_250.pt')
generator = torch.load('generator_250.pt')
discriminator = torch.load('discriminator_250.pt')
classifier = torch.load('classifier_g_250.pt')

In [36]:
# np.save('target_performance.npy', target_performance)
# np.save('task_performance.npy', task_performance)
task_performance = np.load('task_performance.npy')
target_performance = np.load('target_performance.npy')

## Train a clean MNIST-M-only classifier

In [130]:
class SimpleClassifier(nn.Module): # non-convolutional
    def __init__(self,input_size,hidden_size, num_classes=10):
        super().__init__()
        self.l1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.l2 = nn.Linear(hidden_size, num_classes)
        self.softmax = nn.Softmax(dim=0)

    def forward(self, img):
        out = self.l1(img)
        out = self.relu(out)
        out = self.l2(out)
        out = self.softmax(out)
        return out

In [131]:
simple_classifier = SimpleClassifier(input_size=img_channels*img_size**2, hidden_size=500).cuda()
simple_classifier.apply(weights_initialization);

classifier_performance = []

lr_C = 1e-4 # 1e-4 worked
opt_C = torch.optim.Adam(itertools.chain( simple_classifier.parameters()),
                        lr = lr_C, betas = (beta_1, beta_2))

In [160]:
n_epochs = 50
simple_classifier.train()

for epoch in range(n_epochs):
    for i, (imgs, labels) in enumerate(dataloader_original):
        
        imgs = Variable(imgs.view(-1,img_channels*img_size**2).type(FloatTensor))
        labels_var = Variable(labels.type(LongTensor))
        labels = labels.type(LongTensor)
        
        opt_C.zero_grad()
        
#         images.view(-1, 28*28))
#         print(imgs)
        output = simple_classifier(imgs)
        c_loss = task_loss(output, labels_var)
        c_loss.backward()
        opt_C.step()
        
        acc = np.mean(torch.max(output.data, 1)[1] == labels)
        
        classifier_performance.append(acc)
        
        print ("[Epoch %d/%d] [Batch %d/%d] [C loss: %f] \
                [CLF acc: %3d%% (%3d%%)]" %
                (epoch, n_epochs,
                i, len(dataloader_MNISTM),
                c_loss.data[0],
                100*acc, 100*np.mean(classifier_performance[-100:])))
        
        # watch out for overfit - we don't have validation accuracy here

















































































































Process Process-614:
Process Process-616:
Process Process-618:
Process Process-615:
Process Process-613:
Process Process-617:
Traceback (most recent call last):
  File "/usr/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
Traceback (most recent call last):
  File "/usr/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()




  File "/usr/lib/python3.6/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
Traceback (most recent call last):
Traceback (most recent call last):
  File "/usr/lib/python3.6/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
  File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py", line 55, in _worker_loop
    samples = collate_fn([dataset[i] for i in batch_indices])
Traceback (most recent call last):
  File "/usr/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
Traceback (most recent call last):
  File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py", line 50, in _worker_loop
    r = index_queue.get()
  File "/usr/lib/python3.6/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
  File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py", line 55, in _worker_loop
    samp

KeyboardInterrupt: 

## Validate on the original and domain-translated datasets

In [171]:
# turn off model traning/somehow switch to evaluation

simple_classifier.eval()

correct_orignal, correct_MNISTM, correct_translated, total1 = 0,0,0,0

for i, ((imgs_A, labels_A), (imgs_B, labels_B)) in enumerate(zip(usps_data_loader, dataloader_original)):        

        base_tensor = imgs_A.type(FloatTensor) ## .expand(batch_size, 3, img_size, img_size).contiguous()
        imgs_A_g =  Variable(base_tensor)
        imgs_A_c = Variable(base_tensor.view(-1,img_channels*img_size**2))
        
        labels_A = labels_A.type(LongTensor).squeeze()
        imgs_B = Variable(imgs_B.view(-1,img_channels*img_size**2).type(FloatTensor))
        labels_B = labels_B.type(LongTensor).squeeze()

#         z = Variable(FloatTensor(np.random.uniform(-1,1, (batch_size, latent_dimension))))
        z = Variable(FloatTensor(np.random.normal(0,1, (batch_size, latent_dimension)))) 
        original_output = simple_classifier(imgs_A_c)
        MNISTM_output = simple_classifier(imgs_B)
        translated_output = simple_classifier(generator(imgs_A_g, z).view(-1,img_channels*img_size**2))

#             original_loss = task_loss( original_prediction, labels_A)
#             MNISTM_loss = task_loss( MNISTM_prediction, labels_B)
#             translated_loss = task_loss( translated_prediction, labels_A)

        _, predicted = torch.max(original_output.data, 1) # untranslated val (should be lower)
        correct_orignal += (predicted == labels_A).sum()

        _, predicted = torch.max(MNISTM_output.data, 1) # proper val
        correct_MNISTM += (predicted == labels_B).sum()

        _, predicted = torch.max(translated_output.data, 1) # translated val (should be good)
        correct_translated += (predicted == labels_A).sum()

        total1 += labels_A.size(0)
        
        if i%50==0: print(i)

0
50
100
150
200


In [170]:
# lab

In [166]:
total1

7424

In [147]:
# predicted

In [173]:
correct_orignal/total1, correct_MNISTM/total1, correct_translated/total1 # (0.9779, 0.969, 0.9803)

(0.7992995689655172, 0.9451778017241379, 0.8755387931034483)

In [142]:
# (28484, 6990, 29733)

In [None]:
# produce a classifier plot of accuracy of various letters

In [113]:
# training set: (0.9775833333333334, 0.9827, 0.9805833333333334) 
# (as predicted - classifier_clean is probably overfitting)

In [None]:
# add some way to evaluate GAN overfitting (Vizualize the latent space before 12)

In [None]:
# use the below to guide the training on generated samples instead of just randomly selecting them

# class_correct = list(0. for i in range(10))
# class_total = list(0. for i in range(10))
# with torch.no_grad():
#     for data in testloader:
#         images, labels = data
#         outputs = net(images)
#         _, predicted = torch.max(outputs, 1)
#         c = (predicted == labels).squeeze()
#         for i in range(4):
#             label = labels[i]
#             class_correct[label] += c[i].item()
#             class_total[label] += 1


# for i in range(10):
#     print('Accuracy of %5s : %2d %%' % (
#         classes[i], 100 * class_correct[i] / class_total[i]))

In [None]:
Results

#1 MNIST and MNIST-M are too similar - domain translation effect is marginal, as mnist itself is an easier task
# hyperparameter tweaking based on paper - outline - it doesn't like training on SVNH
# show closes images on L2 distance from MNIST