In [None]:
from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data.sampler import SubsetRandomSampler

import matplotlib
import matplotlib.pyplot as plt

from IPython.core.debugger import Tracer

In [None]:
class Options:
    """Hyper-parameters and switches for the biased-MNIST (0 vs 1) experiment.

    Every attribute has a default. Keyword arguments override individual
    values, e.g. ``Options(epochs=1, no_cuda=True)``. Unknown names raise
    TypeError so a typo is not silently ignored.
    """

    def __init__(self, **overrides):
        self.batch_size = 64  # input batch size for training (default: 64)
        self.test_batch_size = 100  # input batch size for testing (default: 100)
        self.epochs = 10  # number of epochs to train (default: 10)
        self.lr = 0.01  # learning rate (default: 0.01)
        self.momentum = 0.5  # SGD momentum (default: 0.5)
        self.no_cuda = False  # disables CUDA training if True
        self.gpu = 2  # CUDA device index used when CUDA is enabled
        self.seed = 1  # random seed (default: 1)
        self.log_interval = 10  # how many batches to wait before logging training status
        # If True, prepare_data inverts every test image (255 - pixel).
        self.shift_test = True
        # Default count of train images per class to invert.
        # NOTE(review): not read anywhere in this notebook; main() receives the count explicitly.
        self.train_shift_per_class = 20

        unknown = set(overrides) - set(vars(self))
        if unknown:
            raise TypeError('Unknown option(s): %s' % ', '.join(sorted(unknown)))
        for name, value in overrides.items():
            setattr(self, name, value)


opt = Options()

In [None]:
def _keep_zeros_and_ones(dataset):
    """Restrict an MNIST dataset in place to digits 0 and 1; return (#zeros, #ones)."""
    is_zero = dataset.targets == 0
    is_one = dataset.targets == 1
    keep = is_zero | is_one
    # `.targets`/`.data` are the real attributes; `train_labels`/`test_data` etc.
    # are deprecated read-only aliases in current torchvision and cannot be assigned.
    dataset.targets = dataset.targets[keep]
    dataset.data = dataset.data[keep]
    return int(is_zero.sum()), int(is_one.sum())


def _indices_of(dataset, label):
    """Return a 1-D LongTensor of the positions in `dataset` whose target equals `label`."""
    # nonzero() returns shape (k, 1); flatten so each element is a scalar index.
    return (dataset.targets == label).nonzero().view(-1)


def _invert(dataset, indices):
    """Replace the uint8 images at `indices` by their negative (255 - pixel), in place."""
    dataset.data[indices] = 255 - dataset.data[indices]


def _show_samples(dataset, indices, title):
    """Print `title`, then print the pixel range of each sample at `indices` and display it."""
    print(title)
    for index in indices:
        image, _ = dataset[int(index)]
        print(image.max())
        print(image.min())
        plt.imshow(image[0], cmap='gray')
        plt.show()


def prepare_data(train_shift_per_class=0, kwargs=None):
    """Build 0-vs-1 MNIST loaders where part of the training set is pixel-inverted.

    The first `train_shift_per_class` training images of each class (0 and 1)
    are inverted (255 - pixel). If `opt.shift_test` is set, every test image is
    inverted as well. Sample counts are printed, and a few images from each
    group are displayed with matplotlib.

    Args:
        train_shift_per_class: number of training images per class to invert.
        kwargs: extra keyword arguments passed to every DataLoader. When None,
            uses {'num_workers': 1, 'pin_memory': True} if CUDA is enabled via
            `opt`, and no extras otherwise.

    Returns:
        (train_loader, test_loader, train_0_shift_loader, train_0_noshift_loader,
         train_1_shift_loader, train_1_noshift_loader)
    """
    if kwargs is None:
        use_cuda = not opt.no_cuda and torch.cuda.is_available()
        kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])
    mnist_train = datasets.MNIST('./mydata', train=True, download=True, transform=transform)
    mnist_test = datasets.MNIST('./mydata', train=False, transform=transform)

    # Binary task: keep only digits 0 and 1, recording per-class counts for the report.
    n0_train, n1_train = _keep_zeros_and_ones(mnist_train)
    n0_test, n1_test = _keep_zeros_and_ones(mnist_test)

    # Split each class of the filtered training set into shifted / non-shifted parts.
    idx0 = _indices_of(mnist_train, 0)
    idx1 = _indices_of(mnist_train, 1)
    idx0_shift, idx0_noshift = idx0[:train_shift_per_class], idx0[train_shift_per_class:]
    idx1_shift, idx1_noshift = idx1[:train_shift_per_class], idx1[train_shift_per_class:]

    # Bias the training set.
    _invert(mnist_train, idx0_shift)
    _invert(mnist_train, idx1_shift)
    if opt.shift_test:
        mnist_test.data = 255 - mnist_test.data

    train_loader = torch.utils.data.DataLoader(
        mnist_train, batch_size=opt.batch_size, shuffle=True, **kwargs)
    test_loader = torch.utils.data.DataLoader(
        mnist_test, batch_size=opt.test_batch_size, shuffle=True, **kwargs)

    def subset_loader(indices):
        # DataLoader rejects shuffle=True together with a sampler;
        # SubsetRandomSampler already draws its indices in random order.
        return torch.utils.data.DataLoader(
            mnist_train, batch_size=opt.batch_size,
            sampler=SubsetRandomSampler(indices.tolist()), **kwargs)

    train_0_shift_loader = subset_loader(idx0_shift)
    train_0_noshift_loader = subset_loader(idx0_noshift)
    train_1_shift_loader = subset_loader(idx1_shift)
    train_1_noshift_loader = subset_loader(idx1_noshift)

    # Check the number of samples in train and test sets.
    print('Number of train samples = %i' % len(mnist_train))
    print('Number of 0 in train set = %i' % n0_train)
    print('Number of 1 in train set = %i' % n1_train)
    print('Number of test samples = %i' % len(mnist_test))
    print('Number of 0 in test set = %i' % n0_test)
    print('Number of 1 in test set = %i' % n1_test)

    # Visualize up to three samples from each group (fewer if a group is smaller).
    _show_samples(mnist_train, idx0_shift[:3], 'Visualize shifted class 0 train samples')
    _show_samples(mnist_train, idx0_noshift[:3], 'Visualize non-shifted class 0 train samples')
    _show_samples(mnist_train, idx1_shift[:3], 'Visualize shifted class 1 train samples')
    _show_samples(mnist_train, idx1_noshift[:3], 'Visualize non-shifted class 1 train samples')
    _show_samples(mnist_test, range(min(6, len(mnist_test))), 'Visualize test samples')

    return (train_loader, test_loader, train_0_shift_loader, train_0_noshift_loader,
            train_1_shift_loader, train_1_noshift_loader)

In [None]:
def train_cyclegan(train_shift_per_class=0):
    """Legacy alias for main(); kept so existing calls keep working.

    Despite its name, this does not train a CycleGAN. Its body was a verbatim
    copy of main(), so it now delegates to main() to keep the two from drifting
    apart.

    Args:
        train_shift_per_class: number of training images per class to invert.

    Returns:
        The best test accuracy (percent) reported by main().
    """
    return main(train_shift_per_class)

In [None]:
class Net(nn.Module):
    """Small CNN mapping 1-channel 28x28 images to log-probabilities over 2 classes.

    Two conv(5x5) -> max-pool(2) -> ReLU stages shrink the spatial size
    28 -> 24 -> 12 -> 8 -> 4. The 20x4x4 = 320 features then pass through a
    50-unit hidden layer to 2 outputs. Dropout is applied after conv2
    (channel-wise) and after fc1.
    """

    def __init__(self):
        super(Net, self).__init__()
        # Layer attribute names are part of the state_dict keys; keep them stable.
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 2)

    def forward(self, x):
        stage1 = F.relu(F.max_pool2d(self.conv1(x), 2))

        conv2_out = self.conv2_drop(self.conv2(stage1))
        stage2 = F.relu(F.max_pool2d(conv2_out, 2))

        flat = stage2.view(-1, 320)
        hidden = F.relu(self.fc1(flat))
        hidden = F.dropout(hidden, training=self.training)

        logits = self.fc2(hidden)
        return F.log_softmax(logits, dim=1)

In [None]:
def train(opt, model, device, train_loader, optimizer, epoch):
    """Run one training epoch with NLL loss, logging every `opt.log_interval` batches.

    Args:
        opt: object providing `log_interval`.
        model: network returning log-probabilities (the loss is F.nll_loss).
        device: device that each batch is moved to.
        train_loader: iterable of (inputs, labels) batches with a `.dataset`.
        optimizer: optimizer stepped once per batch.
        epoch: epoch number, used only in the log line.
    """
    model.train()
    for step, batch in enumerate(train_loader):
        inputs, labels = batch
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        loss = F.nll_loss(model(inputs), labels)
        loss.backward()
        optimizer.step()

        if step % opt.log_interval != 0:
            continue
        seen = step * len(inputs)
        progress = 100. * step / len(train_loader)
        print('Train Epoch: {} [{}/{} ({:.2f}%)]\tLoss: {:.6f}'.format(
            epoch, seen, len(train_loader.dataset), progress, loss.item()))

In [None]:
def test(opt, model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss
            pred = output.max(1, keepdim=True)[1] # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    test_accu = 100. * correct / len(test_loader.dataset)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        test_accu))
    return test_accu

In [None]:
def main(train_shift_per_class=0):
    """Train Net on the partially inverted 0/1 MNIST split and return the best test accuracy.

    Args:
        train_shift_per_class: number of training images per class (0 and 1)
            whose pixels are inverted; forwarded to prepare_data().

    Returns:
        The highest test accuracy (percent) seen over `opt.epochs` epochs.
    """
    use_cuda = not opt.no_cuda and torch.cuda.is_available()
    torch.manual_seed(opt.seed)
    device = torch.device(opt.gpu if use_cuda else "cpu")

    # prepare_data also returns per-class shifted/non-shifted loaders;
    # only the full train loader and the test loader are needed here.
    loaders = prepare_data(train_shift_per_class)
    train_loader, test_loader = loaders[0], loaders[1]

    model = Net().to(device)
    optimizer = optim.SGD(model.parameters(), lr=opt.lr, momentum=opt.momentum)

    best_test_accu = 0
    for epoch in range(1, opt.epochs + 1):
        train(opt, model, device, train_loader, optimizer, epoch)
        test_accu = test(opt, model, device, test_loader)
        # NOTE(review): picking the best epoch by test accuracy is an optimistic
        # estimate; a held-out validation split would be unbiased.
        best_test_accu = max(best_test_accu, test_accu)
        print('\nTest set: Best accuracy: ({:.2f}%)\n'.format(best_test_accu))

    return best_test_accu

In [None]:
# Train one model per shift count and collect its best test accuracy.
# NOTE(review): only one shift count is run here, but the save/plot cells below
# label accu_vec with a 25-entry list (0, 10, ..., 200, 400, ..., 1000).
# Extend this list to match before saving or plotting.
shifts_per_class = [25]
accu_vec = []
for n_shift in shifts_per_class:
    accu_vec.append(main(n_shift))

In [None]:
# Show the collected accuracies. (A leftover time.sleep(36000) was removed:
# it blocked the kernel for 10 hours and made Restart & Run All impractical.)
print(accu_vec)

In [None]:
import numpy as np

# Shift counts (training images inverted per class) that the saved accuracies
# are expected to correspond to, in order.
num_shifted = [0,10,20,30,40,50,60,70,80,90,100,110,120,130,140,150,160,170,180,190,200,400,600,800,1000]
# np.savez stores both arrays independently and does not check that they line
# up, so refuse to write a file whose accuracies cannot be matched to shift counts.
if len(accu_vec) != len(num_shifted):
    raise ValueError(
        'accu_vec has %d entries but num_shifted has %d; update the sweep or '
        'num_shifted before saving' % (len(accu_vec), len(num_shifted)))
np.savez('test_accu_biased_mnist.npz', num_shifted=num_shifted, test_accu=accu_vec)

In [None]:
import seaborn as sns
import numpy as np
import pandas as pd

num_shifted = [0,10,20,30,40,50,60,70,80,90,100,110,120,130,140,150,160,170,180,190,200,400,600,800,1000]
# Only the first 21 entries (0..200 in steps of 10) are plotted.
n_plot = 21
# Fail with a clear message instead of pandas' generic length-mismatch ValueError.
if len(accu_vec) < n_plot:
    raise ValueError(
        'accu_vec has %d entries but %d are needed to plot against num_shifted[:%d]'
        % (len(accu_vec), n_plot, n_plot))
d = {'number of train samples shifted per class': num_shifted[0:n_plot],
     'test accuracy': accu_vec[0:n_plot]}
pdnumsqr = pd.DataFrame(d)

sns.set(style='darkgrid')
g = sns.lineplot(x='number of train samples shifted per class', y='test accuracy', data=pdnumsqr)
g.set_title('Test accuracy vs. number of inverted train samples per class')
fig = g.get_figure()
fig.savefig("accu_vs_numshift.png")
# Optional: g.set(xscale="log") for a log-scaled x-axis.