In [1]:
from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets

from datasets import limits

from PIL import Image
from utils import NormalizeRangeTanh, UnNormalizeRangeTanh
import torchvision.transforms as transforms
from torch.autograd import Variable
import digits_model
import torchvision

import scipy.io as sio
import numpy as np
import matplotlib.pyplot as plt

In [74]:
class ZeroPadBottom(object):
    ''' Zero pads a batch of image tensors on the bottom up to a given height.

    Input (B, C, H, W) is padded along the H axis (dim 2) with zeros so the
    output is (B, C, size, W). The original rows are kept unchanged on top.

    Args:
        size: target height after padding; must be >= the input height H.
        use_gpu: kept for backward compatibility only. The padding is always
            created on the same device and with the same dtype as the input,
            so it can no longer disagree with where the input lives.

    Raises:
        ValueError: if the input is already taller than ``size``.
    '''
    def __init__(self, size, use_gpu=True):
        self.size = size
        self.use_gpu = use_gpu

    def __call__(self, sample):
        B, C, H, W = sample.size()
        diff = self.size - H
        if diff < 0:
            raise ValueError(
                'ZeroPadBottom: input height {} exceeds target size {}'.format(H, self.size))
        # new_zeros inherits device and dtype from `sample`, so torch.cat never sees
        # mismatched devices (CPU input with use_gpu=True used to crash) and the
        # padding does not require grad.
        padding = sample.new_zeros(B, C, diff, W)
        return torch.cat((sample, padding), dim=2)
# Shared un-normalisation transform; make_image applies it to each batch before tiling.
# NOTE(review): presumably the inverse of NormalizeRangeTanh (used in SVHN_transform),
# mapping tensors back to a displayable range — confirm in utils.
unnormRange = UnNormalizeRangeTanh()

In [4]:
# Build the frozen feature extractor F: instantiate the current digits_model.F
# architecture and copy in every pretrained SVHN weight it shares by name and shape.
# NOTE(review): torch.load unpickles the whole checkpoint; only load trusted files.
f_old_model = torch.load('./pretrained_model/model_F_SVHN_NormRange.tar')['best_model']
f_old_dict = f_old_model.state_dict()
f_new_model = digits_model.F(3, False)
f_new_dict = f_new_model.state_dict()
# Keep only pretrained tensors that exist in the new architecture with the same shape.
f_pretrained = {k: v for k, v in f_old_dict.items()
                if k in f_new_dict and v.size() == f_new_dict[k].size()}
# Merge into the new model's FULL state dict. Loading just the filtered subset would make
# the strict load_state_dict fail on any key missing from the checkpoint; merging keeps
# those parameters at their fresh initialisation instead.
f_new_dict.update(f_pretrained)
f_new_model.load_state_dict(f_new_dict)
f_model = f_new_model

# F is used only as a fixed feature extractor: freeze its weights and switch to eval mode.
for param in f_model.parameters():
    param.requires_grad = False
f_model = f_model.eval()

In [60]:
# Presumably the trained generator checkpoint (the 'best_model' entry of the saved dict).
# NOTE(review): g_model is not referenced by any visible cell — train()/test() read a
# global `model['G']` that is never defined in this notebook. Confirm whether
# `model = g_model` (or similar) was intended.
# NOTE(review): torch.load unpickles arbitrary objects; only load trusted checkpoints.
g_model = torch.load('./final_models/fin_model.tar')['best_model']

In [61]:
# SVHN images -> tensors -> NormalizeRangeTanh. limits.LimitDataset appears to cap each
# split (training: 'extra' split, 1024; evaluation: 'test' split, 256 — the test log
# below reports 256 samples).
SVHN_transform = transforms.Compose([
    transforms.ToTensor(),
    NormalizeRangeTanh(),
])

s_train_set = limits.LimitDataset(
    torchvision.datasets.SVHN(root='./data/svhn', split='extra', download=False,
                              transform=SVHN_transform),
    1024,
)
s_train_loader = torch.utils.data.DataLoader(
    s_train_set, batch_size=128, shuffle=True, num_workers=8)

s_test_set = limits.LimitDataset(
    torchvision.datasets.SVHN(root='./data/svhn/', split='test', download=False,
                              transform=SVHN_transform),
    256,
)
s_test_loader = torch.utils.data.DataLoader(
    s_test_set, batch_size=128, shuffle=False, num_workers=8)

In [96]:
class Net(nn.Module):
    """Small LeNet-style digit classifier returning per-class log-probabilities.

    Expects single-channel 28x28 images: two 5x5 convolutions, each followed by
    2x2 max pooling, reduce them to 20 feature maps of 4x4 (= 320 features)
    before fc1. Output is (B, 10) log-softmax scores.
    """
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        # Flatten per sample, keeping the batch dimension. view(-1, 320) could silently
        # fold a wrongly sized input into a different batch size; now a mismatch fails
        # loudly in fc1.
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)
    
def make_image(M):
    """Tile the first 16 images of every slice of M into a 4-per-row grid.

    Iterates over dim 0 of M; each slice is passed through unnormRange (first 16
    entries only) and torchvision.utils.make_grid with nrow=4. The resulting
    grids are stacked along a new dim 0.
    """
    grids = []
    for batch in torch.unbind(M, dim=0):
        first_sixteen = unnormRange(batch[:16])
        grids.append(torchvision.utils.make_grid(first_sixteen, nrow=4))
    return torch.stack(grids, dim=0)

def train(classifier, device, train_loader, optimizer, epoch,
          generator=None, feature_extractor=None, log_every=None):
    """Train `classifier` for one epoch on generator outputs.

    Each batch of loader images is passed through the frozen feature extractor
    and the generator. The generator output is cropped to rows/cols 2..29
    (28x28, the input size Net expects), and the classifier is fitted to the
    batch's original labels with NLL loss.

    Args:
        classifier: module returning log-probabilities (e.g. Net).
        device: torch.device used for data, labels and the helper modules.
        train_loader: DataLoader yielding (images, labels).
        optimizer: optimizer over the classifier's parameters only.
        epoch: epoch number, used only in the progress message.
        generator: module mapping features to images. Defaults to the global
            ``model['G']`` (NOTE(review): ``model`` is not defined in any visible cell).
        feature_extractor: module mapping images to features; defaults to the global f_model.
        log_every: print progress every this many batches; defaults to the global log_interval.
    """
    if generator is None:
        generator = model['G']
    if feature_extractor is None:
        feature_extractor = f_model
    if log_every is None:
        log_every = log_interval
    # Keep the helper modules on the same device as the data (nn.Module.to moves in place),
    # instead of forcing G onto the CPU while the data lives on `device`.
    # NOTE(review): the generator's train/eval mode is not set here — confirm it is intended.
    generator.to(device)
    feature_extractor.to(device)

    classifier.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        # Only the classifier is optimised, so build no autograd graph through F and G.
        # Otherwise backward() would also accumulate unused gradients in G every step.
        with torch.no_grad():
            s_g = generator(feature_extractor(data))
        s_g = s_g[:, :, 2:30, 2:30]
        optimizer.zero_grad()
        output = classifier(s_g)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % log_every == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))

def test(classifier, device, test_loader):
    classifier.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            # Call Generator
            data = Variable(data.cpu().float())
            s_f = f_model(data)
            s_g = model['G'].cpu()(s_f)
            s_g = s_g[:, :, 2:30, 2:30]
            if torch.cuda.is_available():
                s_g.cuda()
            output = classifier(s_g)
            test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss
            pred = output.max(1, keepdim=True)[1] # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))

In [71]:
# Run configuration. CUDA is deliberately forced off; torch.cuda.is_available() is the
# commented-out alternative.
use_cuda = False  # torch.cuda.is_available()
seed = 1
torch.manual_seed(seed)  # seeds torch's global RNG (later Net init and shuffling draw from it)

device = torch.device("cuda") if use_cuda else torch.device("cpu")
# NOTE(review): kwargs is not passed to any DataLoader in the visible cells.
kwargs = dict(num_workers=1, pin_memory=True) if use_cuda else {}
log_interval = 50  # train() prints progress every `log_interval` batches

In [97]:
# Fresh classifier trained on generator outputs; SGD with momentum over its parameters only.
classifier = Net()
classifier = classifier.to(device)
optimizer = optim.SGD(params=classifier.parameters(), lr=0.01, momentum=0.5)

In [98]:
# Train for `epochs` epochs, evaluating on the held-out SVHN subset after each one.
# NOTE(review): in the logged run below, accuracy plateaus around 84-85% while the
# average test loss climbs from ~0.73 to ~0.88, i.e. later epochs mostly overfit.
epochs = 200
for epoch in range(1, epochs + 1):
    train(classifier, device, train_loader=s_train_loader, optimizer=optimizer, epoch=epoch)
    test(classifier, device, test_loader=s_test_loader)


Test set: Average loss: 2.2846, Accuracy: 28/256 (11%)


Test set: Average loss: 2.2700, Accuracy: 43/256 (17%)


Test set: Average loss: 2.2575, Accuracy: 75/256 (29%)


Test set: Average loss: 2.2430, Accuracy: 98/256 (38%)


Test set: Average loss: 2.2264, Accuracy: 79/256 (31%)


Test set: Average loss: 2.2070, Accuracy: 77/256 (30%)


Test set: Average loss: 2.1845, Accuracy: 61/256 (24%)


Test set: Average loss: 2.1595, Accuracy: 59/256 (23%)


Test set: Average loss: 2.1324, Accuracy: 66/256 (26%)


Test set: Average loss: 2.1061, Accuracy: 77/256 (30%)


Test set: Average loss: 2.0742, Accuracy: 83/256 (32%)


Test set: Average loss: 2.0418, Accuracy: 94/256 (37%)


Test set: Average loss: 1.9959, Accuracy: 97/256 (38%)


Test set: Average loss: 1.9436, Accuracy: 109/256 (43%)


Test set: Average loss: 1.8825, Accuracy: 129/256 (50%)


Test set: Average loss: 1.8124, Accuracy: 137/256 (54%)


Test set: Average loss: 1.7319, Accuracy: 144/256 (56%)


Test set: Average loss: 1.


Test set: Average loss: 0.7332, Accuracy: 213/256 (83%)


Test set: Average loss: 0.7384, Accuracy: 215/256 (84%)


Test set: Average loss: 0.7310, Accuracy: 214/256 (84%)


Test set: Average loss: 0.7418, Accuracy: 215/256 (84%)


Test set: Average loss: 0.7376, Accuracy: 215/256 (84%)


Test set: Average loss: 0.7408, Accuracy: 214/256 (84%)


Test set: Average loss: 0.7576, Accuracy: 214/256 (84%)


Test set: Average loss: 0.7572, Accuracy: 214/256 (84%)


Test set: Average loss: 0.7472, Accuracy: 216/256 (84%)


Test set: Average loss: 0.7410, Accuracy: 216/256 (84%)


Test set: Average loss: 0.7556, Accuracy: 216/256 (84%)


Test set: Average loss: 0.7617, Accuracy: 214/256 (84%)


Test set: Average loss: 0.7472, Accuracy: 217/256 (85%)


Test set: Average loss: 0.7540, Accuracy: 214/256 (84%)


Test set: Average loss: 0.7737, Accuracy: 216/256 (84%)


Test set: Average loss: 0.7589, Accuracy: 216/256 (84%)


Test set: Average loss: 0.7603, Accuracy: 214/256 (84%)


Test set: Ave


Test set: Average loss: 0.8469, Accuracy: 215/256 (84%)


Test set: Average loss: 0.8595, Accuracy: 217/256 (85%)


Test set: Average loss: 0.8609, Accuracy: 215/256 (84%)


Test set: Average loss: 0.8727, Accuracy: 217/256 (85%)


Test set: Average loss: 0.8443, Accuracy: 215/256 (84%)


Test set: Average loss: 0.8433, Accuracy: 217/256 (85%)


Test set: Average loss: 0.8665, Accuracy: 219/256 (86%)


Test set: Average loss: 0.8607, Accuracy: 218/256 (85%)


Test set: Average loss: 0.8584, Accuracy: 217/256 (85%)


Test set: Average loss: 0.8594, Accuracy: 217/256 (85%)


Test set: Average loss: 0.8711, Accuracy: 216/256 (84%)


Test set: Average loss: 0.8770, Accuracy: 216/256 (84%)


Test set: Average loss: 0.8766, Accuracy: 214/256 (84%)


Test set: Average loss: 0.8997, Accuracy: 214/256 (84%)


Test set: Average loss: 0.8894, Accuracy: 215/256 (84%)


Test set: Average loss: 0.8923, Accuracy: 213/256 (83%)


Test set: Average loss: 0.8844, Accuracy: 214/256 (84%)


Test set: Ave