In [1]:
import glob
import os
import random
import time

import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as utils
from PIL import Image
# NOTE: this import rebinds `utils`, shadowing the torch.utils.data alias above.
from torchvision import transforms, utils

In [2]:
# Images: read the .npy arrays and divide by 255.0 (presumably mapping 8-bit
# pixel values into [0, 1] -- confirm against how the files were produced).
mnist_test_input = np.load("MNIST/testImages.npy") / 255.0
mnist_train_input = np.load("MNIST/trainImages.npy") / 255.0

# Labels: read as stored, no rescaling.
mnist_test_label = np.load("MNIST/testLabels.npy")
mnist_train_label = np.load("MNIST/trainLabels.npy")

# torch.Tensor(...) converts each array to the default tensor type.
testInputs, trainInputs = torch.Tensor(mnist_test_input), torch.Tensor(mnist_train_input)
testLabels, trainLabels = torch.Tensor(mnist_test_label), torch.Tensor(mnist_train_label)

# Sanity check: show the shape of every tensor that later cells consume.
for _loaded in (trainInputs, testInputs, testLabels, trainLabels):
    print(_loaded.shape)


torch.Size([60000, 1, 28, 28])
torch.Size([10000, 1, 28, 28])
torch.Size([10000, 10])
torch.Size([60000, 10])


In [12]:
class PuVAE(nn.Module):
    """Label-conditioned VAE with a classifier stacked on its reconstruction.

    Data flow in ``forward``:
    encoder(x) -> (mu, std) -> sampled z -> decoder([z, y]) -> reconstruction
    -> classifier(reconstruction).

    The layer geometry is hard-wired for single-channel 28x28 images and
    10-dimensional label vectors (see the shape comments on each layer).

    Args:
        num_chan: channel width of the encoder/decoder convolutions.
        cl_num_chan: base channel width of the classifier convolutions.
        bottleneck_size: dimensionality of the latent code ``z``.
        device: stored on ``self.device`` for callers. ``forward`` draws its
            noise directly on the device the activations live on, so it no
            longer needs this value.
    """

    def __init__(self, num_chan, cl_num_chan, bottleneck_size, device):
        super(PuVAE, self).__init__()
        self.num_chan = num_chan
        self.cl_num_chan = cl_num_chan
        self.bottleneck_size = bottleneck_size
        self.device = device

        # Dilated 4x4 convolutions (dilation=2 -> effective 7x7 window, no
        # padding), so every layer shrinks each spatial side by 6.
        self.encoder = nn.Sequential(
            # 1 x28x28 -> nc x22x22
            nn.Conv2d(1, num_chan, kernel_size=4, padding=0, dilation=2),
            nn.ReLU(),
            # nc x22x22 -> nc x16x16
            nn.Conv2d(num_chan, num_chan, kernel_size=4, padding=0, dilation=2),
            nn.ReLU(),
            # nc x16x16 -> nc x10x10
            nn.Conv2d(num_chan, num_chan, kernel_size=4, padding=0, dilation=2),
            nn.ReLU()
        )

        # Flattened encoder features (nc * 10 * 10) -> 1024 hidden units.
        # The encoder only sees the image; the label enters at the decoder.
        self.encoder_linear = nn.Sequential(
            nn.Linear(100 * num_chan, 1024),
            nn.ReLU(),
            nn.Linear(1024, 1024),
            nn.ReLU()
        )

        # Two heads on the shared hidden vector: latent mean and a raw
        # "uncertainty" that forward() maps through softplus to get std > 0.
        self.mean_layer = nn.Linear(1024, bottleneck_size)
        self.uncertainty_layer = nn.Linear(1024, bottleneck_size)

        # Original author's note: this layer does not exist in the actual
        # PuVAE paper.
        self.decoder_linear = nn.Sequential(
            nn.Linear(10 + bottleneck_size, 128),  # 10 extra inputs for the label y
            nn.ReLU(),
            nn.Linear(128, 512),
            nn.ReLU()
        )

        # ConvTranspose2d positional args: (in, out, kernel, stride, padding).
        self.decoder = nn.Sequential(
            # 512 x1x1 -> nc x4x4
            nn.ConvTranspose2d(512, num_chan, 4, 1, 0, bias=False),
            nn.ReLU(),
            # nc x4x4 -> nc x8x8
            nn.ConvTranspose2d(num_chan, num_chan, 4, 2, 1, bias=False),
            nn.ReLU(),
            # nc x8x8 -> nc x14x14
            nn.ConvTranspose2d(num_chan, num_chan, 4, 2, 2, bias=False),
            nn.ReLU(),
            # nc x14x14 -> 1 x28x28
            nn.ConvTranspose2d(num_chan, 1, 4, 2, 1, bias=False),
            nn.Sigmoid()
        )

        # Conv2d positional args: (in, out, kernel, stride, padding).
        self.classifier = nn.Sequential(
            # 1 x28x28 -> cl x14x14
            nn.Conv2d(1, cl_num_chan, 4, 2, 1, bias=False),
            nn.ReLU(),
            # cl x14x14 -> 2*cl x7x7
            nn.Conv2d(cl_num_chan, cl_num_chan * 2, 4, 2, 1, bias=False),
            nn.ReLU(),
            # 2*cl x7x7 -> 4*cl x3x3
            nn.Conv2d(cl_num_chan * 2, cl_num_chan * 4, 4, 2, 1, bias=False),
            nn.ReLU()
        )

        # Flattened size of the classifier feature map (4*cl channels, 3x3).
        kernel_length2 = self.cl_num_chan * 4 * 3 * 3

        self.classifying_linear = nn.Sequential(
            nn.Linear(kernel_length2, 512),
            nn.ReLU(),
            nn.Linear(512, 128),
            nn.ReLU(),
            nn.Linear(128, 10),
            # Explicit dim: normalise over the 10 class scores of each row.
            # The argument-less form relies on a deprecated implicit choice.
            nn.Softmax(dim=1)
        )

    def forward(self, x, y):
        """Encode ``x``, decode it conditioned on ``y``, classify the result.

        Args:
            x: image batch; the layer geometry requires shape (B, 1, 28, 28).
            y: label batch of shape (B, 10), concatenated to the latent code.

        Returns:
            Tuple ``(z, mu, std, x, c)``: the sampled latent code, its mean,
            its softplus-positive scale, the reconstruction (B, 1, 28, 28)
            and the classifier's softmax output (B, 10).

        Noise is sampled on every call, regardless of train/eval mode.
        """
        # go through encoder
        h = self.encoder(x)
        h = h.view(-1, 100 * self.num_chan)
        h = self.encoder_linear(h)
        mu = self.mean_layer(h)
        std = F.softplus(self.uncertainty_layer(h))

        # Random sampling (reparameterisation). randn_like creates the noise
        # with mu's device and dtype, avoiding a host->device copy per call.
        eps = torch.randn_like(mu)
        # NOTE: the noise is scaled by a fixed 0.1, so z is drawn with
        # standard deviation 0.1 * std rather than std.
        z = mu + std * eps * 0.1

        # go through decoder
        z_e = torch.cat((z, y), 1)
        z_e = self.decoder_linear(z_e)
        z_e = z_e.view(-1, 512, 1, 1)
        x = self.decoder(z_e)

        # go through classifier
        c = self.classifier(x)
        c = c.view(-1, self.cl_num_chan * 4 * 3 * 3)
        c = self.classifying_linear(c)

        return z, mu, std, x, c

    
def recon_kld_ce_loss(true_x, x, mu, std, true_c, c,
                      rc_weight=0.01, kld_weight=0.1, ce_weight=10, eps=1e-8):
    """Weighted sum of reconstruction, KL and classification losses.

    Args:
        true_x: target images, values in [0, 1].
        x: reconstructed images, values in [0, 1] (same shape as ``true_x``).
        mu: latent means.
        std: latent standard deviations (non-negative, same shape as ``mu``).
        true_c: target class vectors, values in [0, 1].
        c: predicted class probabilities (same shape as ``true_c``).
        rc_weight: multiplier of the reconstruction term (default 0.01).
        kld_weight: multiplier of the KL term (default 0.1).
        ce_weight: multiplier of the classification term (default 10).
        eps: lower bound applied to ``std`` inside the logarithm only.

    Returns:
        Scalar tensor ``rc_weight * BCE(x, true_x) + kld_weight * KLD
        + ce_weight * BCE(c, true_c)``. Both BCE terms are element-wise
        means; note the classification term is a binary cross-entropy over
        the class vector, not a categorical cross-entropy.
    """
    rc_loss = F.binary_cross_entropy(x, true_x)
    ce_loss = F.binary_cross_entropy(c, true_c)

    # KL( N(mu, std^2) || N(0, 1) ) per element:
    #   0.5 * (mu^2 + std^2 - 1 - log(std^2)),  with log(std^2) = 2 * log(std).
    # std comes from a softplus and can underflow to exactly 0; clamping
    # inside the log keeps the loss (and its gradient) finite in that case.
    log_var = std.clamp_min(eps).log() * 2
    kld = torch.mean(mu.pow(2) + std.pow(2) - 1 - log_var) * 0.5

    # Out-of-place arithmetic: nothing the autograd graph may still need is
    # overwritten.
    return rc_loss * rc_weight + kld * kld_weight + ce_loss * ce_weight

In [13]:
def train(model, device, optiomizer, trainInputs, trainLabels, batch_size):
    """Run one training epoch over ``trainInputs`` in sequential mini-batches.

    Args:
        model: module whose ``forward(data, label)`` returns
            ``(z, mu, std, x, c)``.
        device: device each mini-batch is moved to before the forward pass.
        optiomizer: optimizer used for the updates. The misspelling is part
            of the existing signature and is kept so keyword callers do not
            break.
        trainInputs: input tensor, sliced along dim 0 (no shuffling).
        trainLabels: label tensor aligned with ``trainInputs``.
        batch_size: number of samples per mini-batch; must be positive.

    Returns:
        Mean of the per-batch losses for this epoch, or 0.0 when
        ``trainInputs`` is empty.

    Raises:
        ValueError: if ``batch_size`` is not positive (the loop would never
            advance).

    The loss is computed by the module-level ``recon_kld_ce_loss``.
    """
    if batch_size <= 0:
        raise ValueError("batch_size must be a positive integer")

    # Use the optimizer that was passed in; previously the body silently
    # picked up a global named `optimizer` instead.
    optimizer = optiomizer

    model.train()
    total_loss = 0.0
    num_batches = 0

    for start in range(0, len(trainInputs), batch_size):
        data = trainInputs[start:start + batch_size].to(device)
        label = trainLabels[start:start + batch_size].to(device)

        optimizer.zero_grad()  # clear gradients left from the previous step

        z, mu, std, x, c = model(data, label)
        loss = recon_kld_ce_loss(data, x, mu, std, label, c)

        loss.backward()   # back propagation
        optimizer.step()  # apply the parameter update

        # Accumulate (the original `=+` overwrote the total on every batch).
        total_loss += loss.item()
        num_batches += 1

    return total_loss / num_batches if num_batches else 0.0
    
def test(model, device, testInputs, testLabels,batch_size):
    """Evaluate the model and return the mean loss per sample.

    Args:
        model: module whose ``forward(data, label)`` returns
            ``(z, mu, std, x, c)``; it is switched to eval mode.
        device: device each mini-batch is moved to before the forward pass.
        testInputs: input tensor, sliced along dim 0.
        testLabels: label tensor aligned with ``testInputs``.
        batch_size: number of samples per mini-batch; must be positive.

    Returns:
        Size-weighted average of the per-batch losses, i.e. each batch's
        (mean-reduced) loss counts in proportion to the number of samples in
        it. Returns 0.0 when ``testInputs`` is empty.

    Raises:
        ValueError: if ``batch_size`` is not positive.

    The loss is computed by the module-level ``recon_kld_ce_loss``.
    """
    if batch_size <= 0:
        raise ValueError("batch_size must be a positive integer")

    model.eval()

    num_samples = len(testInputs)
    if num_samples == 0:
        return 0.0

    weighted_loss = 0.0
    with torch.no_grad():
        for start in range(0, num_samples, batch_size):
            data = testInputs[start:start + batch_size].to(device)
            label = testLabels[start:start + batch_size].to(device)

            z, mu, std, x, c = model(data, label)
            loss = recon_kld_ce_loss(data, x, mu, std, label, c)

            # Accumulate (the original `=+` kept only the last batch), and
            # weight by batch size so a short final batch is not over-counted
            # when dividing by the number of samples below.
            weighted_loss += loss.item() * len(data)

    return weighted_loss / num_samples


def purify(model, device, testInputs, num_classes=10):
    """Reconstruct each input under every class label and keep the closest.

    For every class k the model is run on the whole batch with the one-hot
    label for k; for each sample, the reconstruction with the smallest mean
    squared error to the input is returned.

    Args:
        model: module whose ``forward(data, label)`` returns a 5-tuple with
            the reconstruction (same shape as ``data``) in position 3.
        device: device the inputs and labels are placed on.
        testInputs: batch of images, shape (B, ...).
        num_classes: number of one-hot labels to try (default 10).

    Returns:
        Tensor with the same shape as ``testInputs`` holding, per sample, the
        best-matching reconstruction.

    The model is called ``num_classes`` times in label order 0..num_classes-1
    under ``torch.no_grad()``; its train/eval mode is left untouched.
    """
    data = testInputs.to(device)
    batch_size = data.shape[0]

    # Row k is the one-hot vector of class k.
    one_hot = torch.eye(num_classes, device=device)

    with torch.no_grad():
        reconstructions = []
        for k in range(num_classes):
            label = one_hot[k].expand(batch_size, num_classes)
            _, _, _, x_k, _ = model(data, label)
            reconstructions.append(x_k)

        # (num_classes, B, ...): one reconstruction of the batch per label.
        image_bar = torch.stack(reconstructions)

        # Broadcasting the input over the class axis avoids materialising
        # num_classes copies of the batch.
        sq_err = (data.unsqueeze(0) - image_bar).pow(2)
        mse = sq_err.flatten(start_dim=2).mean(dim=2)  # (num_classes, B)

        # For each sample, the label whose reconstruction is closest.
        _, best_class = torch.min(mse, dim=0)
        sample_idx = torch.arange(batch_size, device=device)

        return image_bar[best_class, sample_idx]

In [14]:
# Training driver. Depends on PuVAE, train, test and the train*/test* tensors
# defined in earlier cells.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
cpu = torch.device("cpu")

# Seed torch so weight initialisation and latent sampling are repeatable.
SEED = 0
torch.manual_seed(SEED)

epoch = 500        # number of training epochs
batch_size = 500
EVAL_EVERY = 1     # evaluate and dump a preview image every N epochs
SAMPLE_INDEX = 50  # which sample of the preview batch is printed / saved
OUTPUT_DIR = "output"

# PuVAE(num_chan, cl_num_chan, bottleneck_size, device)
model = PuVAE(32, 8, 100, device).to(device)

tlr = 0.0001
lrtime  = 10  # not used in this cell; kept in case another cell reads it
accs = []     # not used in this cell; kept in case another cell reads it
optimizer = optim.Adam(model.parameters(), lr = tlr)

# cv2.imwrite does not create missing directories, so make sure the target
# folder exists before the first preview is written.
os.makedirs(OUTPUT_DIR, exist_ok=True)

currTime = time.time()
for epochu in range(epoch):
    train(model, device, optimizer, trainInputs, trainLabels, batch_size)

    if epochu % EVAL_EVERY == 0:

        loss = test(model, device, testInputs, testLabels, batch_size)

        print("Epoch: ", epochu, " loss: ", loss)

        # Preview: run the first training batch through the model.
        fn = trainInputs[0:batch_size].to(device)
        fn_l = trainLabels[0:batch_size].to(device)

        # Inference only: no autograd graph is needed for the preview.
        with torch.no_grad():
            z, mu, std, x, c = model(fn, fn_l)
        print(c.to(cpu).detach().numpy()[SAMPLE_INDEX])

        outputs = x.to(cpu).detach().numpy()
        output_sample = outputs[SAMPLE_INDEX].reshape(28, 28)
        # Reconstruction values are sigmoid outputs; scale by 255 for saving.
        cv2.imwrite(OUTPUT_DIR + "/" + str(epochu) + ".jpg", output_sample * 255)

        # Seconds elapsed since training started.
        print(time.time()-currTime)

torch.Size([500, 1, 28, 28])
torch.Size([500, 1, 28, 28])
torch.Size([500, 1, 28, 28])
torch.Size([500, 1, 28, 28])
torch.Size([500, 1, 28, 28])
torch.Size([500, 1, 28, 28])
torch.Size([500, 1, 28, 28])
torch.Size([500, 1, 28, 28])
torch.Size([500, 1, 28, 28])
torch.Size([500, 1, 28, 28])
torch.Size([500, 1, 28, 28])
torch.Size([500, 1, 28, 28])
torch.Size([500, 1, 28, 28])
torch.Size([500, 1, 28, 28])
torch.Size([500, 1, 28, 28])
torch.Size([500, 1, 28, 28])
torch.Size([500, 1, 28, 28])
torch.Size([500, 1, 28, 28])
torch.Size([500, 1, 28, 28])
torch.Size([500, 1, 28, 28])
torch.Size([500, 1, 28, 28])
torch.Size([500, 1, 28, 28])
torch.Size([500, 1, 28, 28])
torch.Size([500, 1, 28, 28])
torch.Size([500, 1, 28, 28])
torch.Size([500, 1, 28, 28])
torch.Size([500, 1, 28, 28])
torch.Size([500, 1, 28, 28])
torch.Size([500, 1, 28, 28])
torch.Size([500, 1, 28, 28])
torch.Size([500, 1, 28, 28])


KeyboardInterrupt: 

In [None]:
def purify2(model, device, testInputs, max_dump=99):
    """Like ``purify``, but also dumps the class-0 reconstructions to disk.

    For every class k in 0..9 the model is run on the whole batch with the
    one-hot label for k; for each sample, the reconstruction with the
    smallest mean squared error to the input is returned. As a debugging
    aid, the reconstructions obtained with label 0 are written to
    ``delmo/<i>delmo.jpg``.

    Args:
        model: module whose ``forward(data, label)`` returns a 5-tuple with
            the reconstruction (same shape as ``data``) in position 3.
        device: device the inputs and labels are placed on.
        testInputs: batch of images; the dump reshapes each reconstruction
            to 28x28, so single-channel 28x28 images are required.
        max_dump: upper bound on the number of images written (default 99,
            the count the original loop was hard-coded to).

    Returns:
        Tensor with the same shape as ``testInputs`` holding, per sample, the
        best-matching reconstruction.
    """
    num_classes = 10

    data = testInputs.to(device)
    batch_size = data.shape[0]

    # Row k is the one-hot vector of class k.
    one_hot = torch.eye(num_classes, device=device)

    with torch.no_grad():
        reconstructions = []
        for k in range(num_classes):
            label = one_hot[k].expand(batch_size, num_classes)
            _, _, _, x_k, _ = model(data, label)
            reconstructions.append(x_k)

        # (num_classes, B, ...): one reconstruction of the batch per label.
        image_bar = torch.stack(reconstructions)

        sq_err = (data.unsqueeze(0) - image_bar).pow(2)
        mse = sq_err.flatten(start_dim=2).mean(dim=2)  # (num_classes, B)

        # For each sample, the label whose reconstruction is closest.
        _, best_class = torch.min(mse, dim=0)
        sample_idx = torch.arange(batch_size, device=device)

        # Debug dump of the label-0 reconstructions. cv2.imwrite does not
        # create missing directories, so make sure the folder exists.
        os.makedirs("delmo", exist_ok=True)
        outputs = reconstructions[0].cpu().numpy()

        # Never index past the batch: the original fixed range(99) raised an
        # IndexError for batches smaller than 99.
        for i in range(min(max_dump, len(outputs))):
            output_sample = outputs[i].reshape(28, 28)
            cv2.imwrite("delmo/" + str(i) + "delmo.jpg", output_sample * 255)

        return image_bar[best_class, sample_idx]

# Purify a slice of the test set and save each original next to its purified
# version under purify/ as <i>real.jpg and <i>puri.jpg.
# Depends on model, device, testInputs and purify2 from earlier cells.
sample_inputs = testInputs[101:200]

purified_batch = purify2(model, device, sample_inputs)
asd = purified_batch  # original name, kept in case other cells still read it

# cv2.imwrite does not create missing directories.
os.makedirs("purify", exist_ok=True)

# .cpu() instead of .to(cpu): no reliance on a `cpu` variable from another cell.
real_images = sample_inputs.cpu().detach().numpy()
purified_images = purified_batch.cpu().detach().numpy()

# Iterate over however many samples the slice produced instead of a fixed 99.
for i in range(len(real_images)):
    cv2.imwrite("purify/" + str(i) + "real.jpg", real_images[i].reshape(28, 28) * 255)
    cv2.imwrite("purify/" + str(i) + "puri.jpg", purified_images[i].reshape(28, 28) * 255)

In [None]:
# Persist only the model's state_dict (its parameters and buffers), not the
# module object itself; reloading requires constructing a PuVAE with the same
# constructor arguments first.
CHECKPOINT_PATH = "PuVAE_param4_100"
torch.save(model.state_dict(), CHECKPOINT_PATH)
