In [1]:
import numpy as np
import torch
from torch.autograd import Variable
from timeit import default_timer as timer
import torch.nn as nn
import torch.nn.functional as F
import random
import matplotlib.pyplot as plt
from torch.nn.utils import weight_norm


import torchvision
from torchvision import datasets, transforms

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [2]:
def load_mnist():
    
    transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize((0.1307,), (0.3081,))])
    
    train_set = datasets.MNIST(root = "../data",
                               transform = transform,
                               train = True,
                               download = True)
    test_set = datasets.MNIST(root = "../data",
                              transform = transform,
                              train = False)
    return train_set, test_set

In [3]:
class GaussianNoise(nn.Module):
    def __init__(self, batch_size=100, input_shape = (1,28,28),std=0.15):
        super(GaussianNoise, self).__init__()
        self.shape = (batch_size,) + input_shape
        self.noise = Variable(torch.zeros(self.shape).cuda())
        self.std = std
    
    def forward(self, x):
        self.noise.data.normal_(0, std=self.std)
        return x + self.noise

In [4]:
class CNN(torch.nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.GaussianNoise = GaussianNoise()
        self.conv = torch.nn.Sequential(torch.nn.Conv2d(1, 16, 3, padding=1),
                                         torch.nn.MaxPool2d(3, stride=2, padding=1),
                                         torch.nn.ReLU(),
                                         torch.nn.Conv2d(16, 32, 3, padding=1),
                                         torch.nn.MaxPool2d(3, stride=2, padding=1),
                                         torch.nn.ReLU())
        self.dense = torch.nn.Sequential(torch.nn.Dropout(p=0.5),
                                         torch.nn.Linear(32 * 7 * 7, 10))
    def forward(self, x):
        if(self.training):
            x = self.GaussianNoise(x)
        x = self.conv(x)
        x = x.view(-1, 32 * 7 * 7)
        x = self.dense(x)
        return x
            

In [5]:
def ramp_up(epoch, max_epochs, max_val, mult):
    if epoch == 0:
        return 0.
    elif epoch >= max_epochs:
        return max_val
    return max_val * np.exp(mult * (1. - float(epoch) / max_epochs) ** 2)

def weight_schedule(epoch, max_epochs, max_val, mult, n_labeled, n_samples):
    max_val = max_val * (float(n_labeled) / n_samples)
    return ramp_up(epoch, max_epochs, max_val, mult)


In [6]:
def sample_train(train_dataset, test_dataset, batch_size, n_labels, n_classes):
    
    n = len(train_dataset)
    
    random.seed(5)
    
    cnt = 0
    labels_index = torch.zeros(n_labels)
    unlabel_index = torch.zeros(n - n_labels)
    labels_class = n_labels // n_classes # num of labeled instances in each class
    
    for i in range(n_classes):
        class_items = (train_dataset.targets == i).nonzero()[:,0]
        n_class = len(class_items) # num instances in this class
        rand_index = np.random.permutation(np.arange(n_class)) # pertub the index
        labels_index[i * labels_class : (i + 1) * labels_class] = class_items[rand_index[:labels_class]]
        unlabel_index[cnt:cnt+n_class-labels_class] = class_items[rand_index[labels_class:]]
        cnt += n_class - labels_class
    
    unlabel_index = unlabel_index.long() #tensors used as indices must be long or byte tensors

    train_dataset.targets[unlabel_index] = -1
    train_loader = torch.utils.data.DataLoader(dataset = train_dataset,
                                              batch_size = batch_size,
                                              shuffle = False)
    test_loader = torch.utils.data.DataLoader(dataset = test_dataset,
                                              batch_size = batch_size,
                                              shuffle = False)
    return train_loader, test_loader

In [7]:
def temporal_loss(out1, out2, w, labels):
    #Supervised Loss
    def mse_loss(out1, out2):
        quad_diff = torch.sum((F.softmax(out1, dim=1) - F.softmax(out2, dim=1)) ** 2)
        return quad_diff / out1.data.nelement()
    
    def masked_crossentropy(out, labels):
        labeled_index = (labels >= 0)
        nnz = torch.nonzero(labeled_index)
        num_labeled = len(nnz)
        
        #check if labeled samples in batch, return 0 if none
        if num_labeled > 0:
            masked_outputs = torch.index_select(out, 0, nnz.view(num_labeled))
            masked_labels = labels[labeled_index]                                 
            loss = F.cross_entropy(masked_outputs, masked_labels)
            return loss, num_labeled
        return Variable(torch.FloatTensor([0.]).cuda(), requires_grad=False), 0
    
    sup_loss, num_labeled = masked_crossentropy(out1, labels)
    unsup_loss = mse_loss(out1, out2)
    return sup_loss + w * unsup_loss, sup_loss, unsup_loss, num_labeled
        
    

In [8]:
def train(model, n_labels = 100, n_epochs=300, batch_size=100,
          max_epochs=80, max_val=30., ramp_up_mult=-5.,
          n_classes=10,n_samples=60000, alpha=0.6, learning_rate=0.002):
    
    train_dataset, test_dataset = load_mnist()
    n_train = len(train_dataset)
    
    #build model
    model.cuda()
    
    train_loader, test_loader = sample_train(train_dataset, test_dataset, batch_size,
                                            n_labels, n_classes)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, betas=(0.9, 0.99))
    
    #train
    model.train()
    losses = []
    supvised_loss = []
    unsupvised_loss = []
    
    
    Z = torch.zeros(n_train, n_classes).float().cuda()
    z = torch.zeros(n_train, n_classes).float().cuda()
    outputs = torch.zeros(n_train, n_classes).float().cuda()
    
    for epoch in range(n_epochs):
        
        #unsupervised loss weight
        w = weight_schedule(epoch, max_epochs, max_val, ramp_up_mult, n_labels, n_samples)
        if (epoch + 1) % 10 == 0:
            print('Unsupervised loss weight: {}'.format(w))
            
        w = torch.autograd.Variable(torch.FloatTensor([w]).cuda(), requires_grad=False)
        
        l = []
        sup_l = []
        unsup_l = []
        for i, (images, labels) in enumerate(train_loader):
            
            t = timer()
            images = Variable(images.cuda())
            labels = Variable(labels.cuda(), requires_grad=False)
            
            optimizer.zero_grad()
            
            out = model(images)
            
            zcomp = Variable(z[i * batch_size: (i+1)*batch_size], requires_grad=False)
            loss, sup_loss, unsup_loss, num_sup = temporal_loss(out, zcomp, w, labels)
            
            #save outputs and losses
            outputs[i * batch_size: (i+1) * batch_size] = out.data.clone()
            l.append(loss.item())
            sup_l.append(num_sup * sup_loss.item())
            unsup_l.append(unsup_loss.item())
            
            #backprop
            loss.backward()
            optimizer.step()
            
            if (epoch + 1) % 10 == 0:
                if i + 1 == 2 * 300:
                    print ('Epoch [%d/%d], Step [%d/%d], Loss: %.6f, Time (this epoch): %.2f s' 
                           %(epoch + 1, n_epochs, i + 1, len(train_dataset) // batch_size, np.mean(l), timer() - t))
                elif (i + 1) % 300 == 0:
                    print ('Epoch [%d/%d], Step [%d/%d], Loss: %.6f' 
                           %(epoch + 1, n_epochs, i + 1, len(train_dataset) // batch_size, np.mean(l)))

                
                
        Z = alpha * Z + (1. - alpha) * outputs
        z = Z * (1. / (1. - alpha ** (epoch + 1)))
        
        #handle metrics, losses, etc
        ave_loss = np.mean(l)
        losses.append(ave_loss)
        supvised_loss.append((1. / n_labels) * np.sum(sup_l))
        unsupvised_loss.append(np.mean(unsup_l))
        

    model.eval()
    with torch.no_grad():
        correct = 0
        total = 0
        for i, (samples, labels) in enumerate(test_loader):
            samples = Variable(samples.cuda())
            labels = Variable(labels.cuda())
            out = model(samples)
            _,pred = torch.max(out.data, 1)
            total += labels.size(0)
            correct += (pred == labels.data).sum()
        print("Test Accuracy: %.4f", 1.0 * correct.item() / total)
    return losses, supvised_loss, unsupvised_loss

In [9]:
model = CNN()
losses, supervised_loss, unsupervised_loss = train(model)

Unsupervised loss weight: 0.0009740835053472244
Epoch [10/300], Step [300/600], Loss: 0.098268
Epoch [10/300], Step [600/600], Loss: 0.070119, Time (this epoch): 0.01 s
Unsupervised loss weight: 0.0027318847495129677
Epoch [20/300], Step [300/600], Loss: 0.024207
Epoch [20/300], Step [600/600], Loss: 0.044935, Time (this epoch): 0.01 s
Unsupervised loss weight: 0.0065534508315721485
Epoch [30/300], Step [300/600], Loss: 0.014359
Epoch [30/300], Step [600/600], Loss: 0.007275, Time (this epoch): 0.01 s
Unsupervised loss weight: 0.013446808012990042
Epoch [40/300], Step [300/600], Loss: 0.120434
Epoch [40/300], Step [600/600], Loss: 0.060298, Time (this epoch): 0.01 s
Unsupervised loss weight: 0.023599883024449233
Epoch [50/300], Step [300/600], Loss: 0.000698
Epoch [50/300], Step [600/600], Loss: 0.053739, Time (this epoch): 0.01 s
Unsupervised loss weight: 0.035427620478902945
Epoch [60/300], Step [300/600], Loss: 0.051736
Epoch [60/300], Step [600/600], Loss: 0.026030, Time (this epoc