In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
from torchvision.utils import save_image



In [2]:

import pickle
import gzip
import random
import numpy as np

MNIST_PATH = 'mnist_28.pkl.gz'

def load_mnist(path):
    
    with gzip.open(path, 'rb') as f:
        u = pickle._Unpickler(f)
        u.encoding = 'latin1'
        train, valid, test = u.load()

    train_x, train_y = train
    valid_x, valid_y = valid
    test_x,  test_y  = test        
    return train_x, train_y, valid_x, valid_y, test_x, test_y

# Loads data where data is split into class labels
def load_mnist_split(path = MNIST_PATH):
    train_x, train_y, valid_x, valid_y, test_x, test_y = load_mnist(path)
    
    def split_by_class(x, y, num_classes):
        result_x = [0]*num_classes
        result_y = [0]*num_classes
        for i in range(num_classes):
            idx_i = np.where(y == i)[0]
            result_x[i] = x[idx_i]
            result_y[i] = y[idx_i]
        return result_x, result_y
      
    train_x, train_y = split_by_class(train_x, train_y, 10)
    return train_x, train_y, valid_x, valid_y, test_x, test_y

def create_semisupervised(x, y, n_labeled):
    n_x = x[0].shape[0]
    n_classes = 10
    if n_labeled % n_classes != 0: 
        raise("n_labeled (wished number of labeled samples) not divisible by n_classes (number of classes)")
    n_labels_per_class = n_labeled//n_classes
    x_labeled = [0]*n_classes
    x_unlabeled = [0]*n_classes
    y_labeled = [0]*n_classes
    y_unlabeled = [0]*n_classes
    #p=range(10)#10% error
    #p=list(p)#10% error
    #random.shuffle(p)#10% error
    for i in range(n_classes):
        idx = range(x[i].shape[0])
        random.seed(1412)
        idx=list(idx)
        random.shuffle(idx)
        x_labeled[i]   = x[i][idx[:n_labels_per_class]]
        y_labeled[i]   = y[i][idx[:n_labels_per_class]]
     #   y_labeled[i][0]=p[i]#10% error
        x_unlabeled[i] = x[i][idx[n_labels_per_class:]]
        y_unlabeled[i] = y[i][idx[n_labels_per_class:]]
    return np.vstack(x_labeled), np.hstack(y_labeled), np.vstack(x_unlabeled), np.hstack(y_unlabeled)




In [3]:
def batch_generator(data, batch_size, num_epoch, shuffle = True):
    data = list(data)
    data = np.array(data)
    data_size = data.shape[0]
    num_batches_per_epoch = (data_size + batch_size - 1)//batch_size
    for epoch in range(num_epoch):
        if shuffle:
            shuffle_indices = np.random.permutation(np.arange(data_size))
            shuffled_data = data[shuffle_indices]
        else:
            shuffled_data = data

        for batch_idx in range(num_batches_per_epoch):
            start_idx = batch_idx * batch_size
            end_idx   = min((batch_idx + 1)*batch_size, data_size)
            yield(shuffled_data[start_idx:end_idx])
            


In [4]:
import math
from torch.utils.data import DataLoader,TensorDataset
data_size = 50000
batch_size = 100
n_batch_size=100
n_labelled = 100
n_epoch = 10
max_iter = n_epoch*(data_size-n_labelled)//n_batch_size

# load data from mnist
train_x, train_y, valid_x, valid_y, test_x, test_y = load_mnist_split()
# split training set
data_x_l, data_y_l, data_x_u, data_y_u = create_semisupervised(train_x, train_y, n_labelled)
data_x_u=torch.FloatTensor(data_x_u)
data_y_l=torch.LongTensor(data_y_l)
data_y_u=torch.LongTensor(data_y_u)
data_x_l=torch.FloatTensor(data_x_l)
test_x=torch.FloatTensor(test_x)
test_y=torch.LongTensor(test_y)
unlabelled_dataset=TensorDataset(data_x_u, data_y_u)
u_loader=DataLoader(dataset=unlabelled_dataset, batch_size=100, shuffle=True)
test_dataset=TensorDataset(test_x, test_y)
test_loader=DataLoader(dataset=test_dataset, batch_size=100, shuffle=True)

In [5]:
for x,y in u_loader:
        print(x.size())

torch.Size([100, 784])
torch.Size([100, 784])
torch.Size([100, 784])
torch.Size([100, 784])
torch.Size([100, 784])
torch.Size([100, 784])
torch.Size([100, 784])
torch.Size([100, 784])
torch.Size([100, 784])
torch.Size([100, 784])
torch.Size([100, 784])
torch.Size([100, 784])
torch.Size([100, 784])
torch.Size([100, 784])
torch.Size([100, 784])
torch.Size([100, 784])
torch.Size([100, 784])
torch.Size([100, 784])
torch.Size([100, 784])
torch.Size([100, 784])
torch.Size([100, 784])
torch.Size([100, 784])
torch.Size([100, 784])
torch.Size([100, 784])
torch.Size([100, 784])
torch.Size([100, 784])
torch.Size([100, 784])
torch.Size([100, 784])
torch.Size([100, 784])
torch.Size([100, 784])
torch.Size([100, 784])
torch.Size([100, 784])
torch.Size([100, 784])
torch.Size([100, 784])
torch.Size([100, 784])
torch.Size([100, 784])
torch.Size([100, 784])
torch.Size([100, 784])
torch.Size([100, 784])
torch.Size([100, 784])
torch.Size([100, 784])
torch.Size([100, 784])
torch.Size([100, 784])
torch.Size(

In [6]:
class M2(nn.Module):
    def __init__(self, x_dim,y_dim, h_dim1, h_dim2, z_dim):
        super(M2, self).__init__()
        
        # encoder part
        self.fc1 = nn.Linear(x_dim+y_dim, h_dim1)
        self.fc2 = nn.Linear(h_dim1, h_dim2)
        self.fc31 = nn.Linear(h_dim2, z_dim)
        self.fc32 = nn.Linear(h_dim2, z_dim)
        # decoder part
        self.fc4 = nn.Linear(z_dim+y_dim, h_dim2)
        self.fc5 = nn.Linear(h_dim2, h_dim1)
        self.fc6 = nn.Linear(h_dim1, x_dim)
        self.fc7 = nn.Linear(x_dim,h_dim1)
        self.fc8 = nn.Linear(h_dim1,y_dim)
    def encoder(self, x,y):#p(z|x,y)
        t=torch.cat((x,y),1)
        h =F.softplus(self.fc1(t))
        h = F.softplus(self.fc2(h))
        mu0=self.fc31(h)
        logvar=self.fc32(h) 
        return mu0, logvar
    
    def sampling(self, mu, log_var):
        std = torch.exp(0.5*log_var)
        eps = torch.randn_like(std)
        return eps.mul(std).add_(mu) # return z sample
        
    def decoder(self, z,y):
        t=torch.cat((z,y),1)
        h = F.softplus(self.fc4(t))
        h =F.softplus(self.fc5(h))
        return F.sigmoid(self.fc6(h)) 
    def classify(self,x):
        h=F.softplus(self.fc7(x))
        y_pred=F.softmax(self.fc8(h),dim=-1)
        return y_pred
    def predict(self,x):
        logits_=self.classify(x)
        y_p=logits_.detach().cpu().numpy()
        return np.argmax(y_p,axis=1)
    def forward(self, x,y):
        mu,log_var= self.encoder(x.view(-1, 784),y)
        z = self.sampling(mu, log_var)
        x_recon=self.decoder(z,y)
        return x_recon,z, mu, log_var

# build model
vae = M2(x_dim=784,y_dim=10, h_dim1= 512, h_dim2=512, z_dim=50)
if torch.cuda.is_available():
    vae.cuda()

In [7]:
optimizer = optim.Adam(vae.parameters(),lr=0.001)
# return reconstruction error + KL divergence losses
n_cluster=10
def L(x, x_recon, y, mu, logvar):
  
    def KLD(mu,logvar):
        return - 0.5*(1+logvar-mu.pow(2)-torch.exp(logvar))
    def log_bernoulli(p, x):
        epsilon = 1e-7
        return x * torch.log(p + epsilon) + (1-x) * torch.log(1-p + epsilon)
        
        # uniform dist
    prior_y = (1. / n_cluster) * torch.ones_like( y )
    logpy = - torch.sum(y * torch.log(prior_y + 1e-8), dim=1)

        # (batch_size, z_dim) -> batch_size
    kldloss = torch.sum(KLD(mu, logvar),1)
        # (batch_size, 784) -> batch_size,
    logpx   = torch.sum(log_bernoulli(x_recon, x), 1)
        
    loss = kldloss - logpx - logpy
    return loss

In [8]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [9]:
alpha=0.1*100
def train(epoch):
    vae.train()
    train_loss = 0
    for data,_ in u_loader:
        data = data.to(device)
        y_l_onehot=torch.zeros(n_labelled, n_cluster).scatter_(1, data_y_l.view(-1,1), 1)
        y_l_onehot=y_l_onehot.to(device)
        data_l=data_x_l.to(device)
#
       
        optimizer.zero_grad()
        
        x_recon_l,z_l, mu_l, log_var_l = vae(data_l,y_l_onehot)
        loss_L =  L(data_l, x_recon_l, y_l_onehot, mu_l, log_var_l)#labelled loss
        # encoder, unlabelled data
        z_u      = [0]*n_cluster
        mu_u     = [0]*n_cluster
        logvar_u =[0]*n_cluster
        y_us = []
        u=[]
        for i in range(n_cluster):
            _y = i * torch.ones(batch_size,1)
            y_us.append(torch.zeros(batch_size, n_cluster).scatter_(1, _y.long(), 1))
            y_u_onehot=y_us[i].to(device)
            x_recon_u,z_u, mu_u, log_var_u = vae(data,y_u_onehot)

            u.append(L(data, x_recon_u, y_u_onehot, mu_u, log_var_u).view(-1,1))#unlabelled loss
        #print("u",u)
        loss_u=torch.cat(u,1)
        #print("loss_u",loss_u)
        y_u_prob=vae.classify(data)
        U=torch.mul(y_u_prob, torch.sub(loss_u, -torch.log(y_u_prob+1e-8)))
        U_sum=U.sum(1)
        # Add auxiliary classification loss q(y|x)
        logits = vae.classify(data_l)
        
        # Regular cross entropy
        classication_loss = -torch.sum(y_l_onehot* torch.log(logits + 1e-8), dim=1).mean()

        loss =torch.mean(loss_L) +alpha * classication_loss + U_sum.mean()

       
        
        loss.backward()
        train_loss += loss.item()
        optimizer.step()
        
        
    print('====> Epoch: {} Average loss: {:.4f}'.format(epoch, train_loss / len(u_loader.dataset)))

In [10]:
from sklearn.metrics import accuracy_score
def test(loader):
    vae.eval()
    
    pre=[]
    tru=[]
 
    with torch.no_grad():
        for data, y in loader:
            data = data.cuda()
            tru.append(y.numpy())
            
            pre.append(vae.predict(data))
    tru=np.concatenate(tru,0)
    pre=np.concatenate(pre,0)
    print(accuracy_score(tru,pre))

In [11]:
for epoch in range(1, 51):
    train(epoch)
    print("train acc")
    test(u_loader)
    print("test acc")
    test(test_loader)
    



====> Epoch: 1 Average loss: 3.2859
train acc
0.7362925851703407
test acc
0.7433
====> Epoch: 2 Average loss: 2.1309
train acc
0.7554108216432865
test acc
0.7571
====> Epoch: 3 Average loss: 1.8779
train acc
0.7585170340681363
test acc
0.7575
====> Epoch: 4 Average loss: 1.7961
train acc
0.762184368737475
test acc
0.7599
====> Epoch: 5 Average loss: 1.7560
train acc
0.7659919839679359
test acc
0.7609
====> Epoch: 6 Average loss: 1.7286
train acc
0.7634869739478958
test acc
0.7584
====> Epoch: 7 Average loss: 1.7065
train acc
0.7834869739478958
test acc
0.7797
====> Epoch: 8 Average loss: 1.6890
train acc
0.786813627254509
test acc
0.7831
====> Epoch: 9 Average loss: 1.6734
train acc
0.7970541082164329
test acc
0.7931
====> Epoch: 10 Average loss: 1.6603
train acc
0.7987374749498998
test acc
0.7948
====> Epoch: 11 Average loss: 1.6481
train acc
0.8164128256513026
test acc
0.811
====> Epoch: 12 Average loss: 1.6377
train acc
0.8224048096192385
test acc
0.8209
====> Epoch: 13 Average loss

In [12]:
y_us=[]
_y = 1 * torch.ones(batch_size,1)
y_us.append(torch.zeros(batch_size, n_cluster).scatter_(1, _y.long(), 1))
y_u_onehot=y_us[i].to(device)
#x_recon_u,z_u, mu_u, log_var_u = vae(data_l,y_u_onehot)

NameError: ignored

In [None]:
y_us = []
u=[]
for i in range(n_cluster):
  _y = i * torch.ones(batch_size,1)
  #y_us.append(torch.zeros(batch_size, n_cluster).scatter_(1, _y, 1))



In [None]:
torch.ones(batch_size,1)