In [2]:
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as utils 

import cv2
import glob

from PIL import Image
from torchvision import transforms, utils
import random

import os
import time

In [3]:

# ADD your model here
class WrapperMNIST(nn.Module):
    """MNIST CNN classifier that first purifies its input with a PuVAE.

    The conv/linear layers use the same attribute names and sizes as
    ``Model_MNIST`` in this notebook, so their state_dict keys match.

    Args:
        in_channels: number of channels of the image fed to ``conv1_1``.
        num_classes: number of logits produced by ``fc2``.
    """

    def __init__(self, in_channels, num_classes):
        super().__init__()

        self.in_channels = in_channels
        self.num_classes = num_classes

        # Stage 1: two 3x3 convs (padding=1 keeps H and W), then 2x2 pooling.
        self.conv1_1 = nn.Conv2d(self.in_channels, 32, kernel_size=3, padding=1)
        self.conv1_2 = nn.Conv2d(32, 32, kernel_size=3, padding=1)
        self.maxpool1 = nn.MaxPool2d(kernel_size=2)

        # Stage 2: same pattern, widened to 64 channels.
        self.conv2_1 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.conv2_2 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.maxpool2 = nn.MaxPool2d(kernel_size=2)

        # 7*7*64 requires the twice-pooled map to be 7x7 (e.g. a 28x28 input).
        self.fc1 = nn.Linear(7 * 7 * 64, 200)
        self.fc2 = nn.Linear(200, self.num_classes)

    def forward(self, x, puvae, source_classifier, device):
        """Purify ``x`` with ``puvae``, then return class logits.

        The caller should put ``puvae`` and ``source_classifier`` in eval mode;
        this method does not change their mode.

        Args:
            x: input image batch (moved to ``device`` inside ``purify``).
            puvae: trained PuVAE used to reconstruct ``x``.
            source_classifier: classifier passed through to ``purify``.
            device: device on which purification runs.

        Returns:
            Logits of shape (batch, num_classes).
        """
        h = purify(source_classifier, puvae, device, x)

        stages = (
            (self.conv1_1, self.conv1_2, self.maxpool1),
            (self.conv2_1, self.conv2_2, self.maxpool2),
        )
        for first_conv, second_conv, pool in stages:
            h = F.relu(first_conv(h))
            h = F.relu(second_conv(h))
            h = pool(h)

        h = h.view(h.size(0), -1)
        return self.fc2(F.relu(self.fc1(h)))







class Model_MNIST(nn.Module):
    """Plain CNN classifier: two conv-conv-pool stages and a two-layer MLP head.

    Args:
        in_channels: number of channels of the input image.
        num_classes: number of logits produced by ``fc2``.
    """

    def __init__(self, in_channels, num_classes):
        super().__init__()

        self.in_channels = in_channels
        self.num_classes = num_classes

        # Stage 1: in_channels -> 32 maps; padding=1 preserves H and W.
        self.conv1_1 = nn.Conv2d(self.in_channels, 32, kernel_size=3, padding=1)
        self.conv1_2 = nn.Conv2d(32, 32, kernel_size=3, padding=1)
        self.maxpool1 = nn.MaxPool2d(kernel_size=2)  # halves H and W

        # Stage 2: 32 -> 64 maps.
        self.conv2_1 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.conv2_2 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.maxpool2 = nn.MaxPool2d(kernel_size=2)

        # Head: 7*7*64 requires the pooled map to be 7x7 (e.g. a 28x28 input).
        self.fc1 = nn.Linear(7 * 7 * 64, 200)
        self.fc2 = nn.Linear(200, self.num_classes)

    def forward(self, x):
        """Return class logits of shape (batch, num_classes) for image batch ``x``."""
        def conv_block(t, conv_a, conv_b, pool):
            # conv -> ReLU -> conv -> ReLU -> max-pool
            return pool(F.relu(conv_b(F.relu(conv_a(t)))))

        feats = conv_block(x, self.conv1_1, self.conv1_2, self.maxpool1)
        feats = conv_block(feats, self.conv2_1, self.conv2_2, self.maxpool2)

        flat = feats.view(feats.size(0), -1)
        logits = self.fc2(F.relu(self.fc1(flat)))
        return logits
    
# assumes everything is moved to GPU
def test_MNIST(model, testInputs, testLabels):

    correct = 0.0
    i = 0
    total_loss = 0.0
    with torch.no_grad():
        data = testInputs
        target = testLabels
        output = model(data)

        return output
    
    

class PuVAE(nn.Module):
    """Class-conditional VAE that reconstructs 1x28x28 images (PuVAE-style).

    The encoder sees the image together with a 10-dim label vector ``y``. The
    decoder rebuilds an image from a latent sample concatenated with the same
    ``y``. The reconstruction is then passed through ``source_classifier``.

    Args:
        num_chan: channel width of the encoder/decoder conv layers.
        cl_num_chan: stored as an attribute; not used by this class.
        bottleneck_size: latent dimensionality.
        device: stored as ``self.device`` for backward compatibility. The
            forward pass draws its noise on ``mu``'s device rather than this.
    """

    def __init__(self, num_chan, cl_num_chan, bottleneck_size, device):
        super(PuVAE, self).__init__()
        self.num_chan = num_chan
        self.cl_num_chan = cl_num_chan  # NOTE(review): unused in this class
        self.bottleneck_size = bottleneck_size
        self.device = device

        self.encoder = nn.Sequential(
            # 1x28x28 -> nc x22x22 (kernel 4, dilation 2 => effective size 7)
            nn.Conv2d(1, num_chan, kernel_size=4, padding=0, dilation=2),
            nn.ReLU(),
            # nc x22x22 -> nc x16x16
            nn.Conv2d(num_chan, num_chan, kernel_size=4, padding=0, dilation=2),
            nn.ReLU(),
            # nc x16x16 -> nc x10x10
            nn.Conv2d(num_chan, num_chan, kernel_size=4, padding=0, dilation=2),
            nn.ReLU()
        )

        # Input: flattened nc*10*10 encoder features + 10-dim label vector.
        self.encoder_linear = nn.Sequential(
            nn.Linear(100 * num_chan + 10, 1024),
            nn.ReLU(),
            nn.Linear(1024, 1024),
            nn.ReLU()
        )

        self.mean_layer = nn.Linear(1024, bottleneck_size)
        self.uncertainty_layer = nn.Linear(1024, bottleneck_size)

        # Latent + 10-dim label -> 512 features, reshaped to 512x1x1 for the
        # decoder. This layer does not exist in the actual PuVAE paper.
        self.decoder_linear = nn.Sequential(
            nn.Linear(10 + bottleneck_size, 128),
            nn.ReLU(),
            nn.Linear(128, 512),
            nn.ReLU()
        )

        self.decoder = nn.Sequential(
            # 512x1x1 -> nc x4x4
            nn.ConvTranspose2d(512, num_chan, 4, 1, 0, bias=False),
            nn.ReLU(),
            # nc x4x4 -> nc x8x8
            nn.ConvTranspose2d(num_chan, num_chan, 4, 2, 1, bias=False),
            nn.ReLU(),
            # nc x8x8 -> nc x14x14
            nn.ConvTranspose2d(num_chan, num_chan, 4, 2, 2, bias=False),
            nn.ReLU(),
            # nc x14x14 -> 1x28x28, squashed into [0, 1]
            nn.ConvTranspose2d(num_chan, 1, 4, 2, 1, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x, y, source_classifier):
        """Encode ``x`` conditioned on ``y``, sample a latent, decode, and classify.

        Args:
            x: image batch, expected (batch, 1, 28, 28). The encoder output
               must flatten to 100 * num_chan features per sample.
            y: label vectors of shape (batch, 10), concatenated to both the
               encoder features and the latent sample.
            source_classifier: callable applied to the reconstruction.

        Returns:
            Tuple ``(z, mu, std, x_reconstructed, classifier_output)``.
        """
        h = self.encoder(x)
        # flatten(1) keeps the batch dimension explicit. A wrongly sized input
        # now fails in encoder_linear instead of possibly being regrouped
        # across samples by view(-1, ...).
        h = h.flatten(1)

        h = self.encoder_linear(torch.cat((h, y), 1))
        mu = self.mean_layer(h)
        std = F.softplus(self.uncertainty_layer(h))  # strictly positive std

        # Reparameterisation with noise scaled by 0.1. randn_like creates the
        # noise on mu's device/dtype, so it follows the module after .to()
        # (self.device may be stale) and avoids a host->device copy per call.
        eps = torch.randn_like(mu).mul_(0.1)
        z = mu + std * eps

        z_e = self.decoder_linear(torch.cat((z, y), 1))
        x = self.decoder(z_e.view(-1, 512, 1, 1))

        c = source_classifier(x)

        return z, mu, std, x, c

    
def recon_kld_ce_loss(true_x, x, mu, std, true_c, c):
    """Weighted PuVAE training loss.

    total = 0.01 * BCE(x, true_x)
          + 0.1  * 0.5 * mean(mu^2 + std^2 - 1 - 2*log(std))
          + 10   * BCEWithLogits(c, true_c)

    Args:
        true_x: target images (values in [0, 1]).
        x: reconstructed images (probabilities, e.g. sigmoid outputs).
        mu, std: latent mean and standard deviation.
        true_c: classification targets, same shape as ``c``.
        c: classifier logits.

    Returns:
        Scalar loss tensor.
    """
    # Named "ce" in the original, but it is binary cross-entropy on logits.
    class_term = F.binary_cross_entropy_with_logits(c, true_c)
    recon_term = F.binary_cross_entropy(x, true_x)

    # Gaussian KL divergence to N(0, I), averaged over all elements.
    kld = torch.mean(mu.pow(2) + std.pow(2) - 1 - 2 * std.log()) * 0.5

    return recon_term * 0.01 + kld * 0.1 + class_term * 10


def purify(sc, model, device, testInputs):
    """Replace each image by its closest class-conditional PuVAE reconstruction.

    For every label k in 0..9, ``model`` reconstructs the whole batch
    conditioned on one-hot(k). For each sample, the reconstruction with the
    lowest mean squared error to the input is returned. Runs under
    ``torch.no_grad()``.

    Args:
        sc: source classifier, forwarded to ``model`` as its third argument.
        model: PuVAE-like callable ``model(x, y, sc)`` returning a 5-tuple
            whose 4th element is the reconstruction.
        device: device the inputs and label vectors are placed on.
        testInputs: image batch, presumably (batch, 1, 28, 28).

    Returns:
        Tensor of shape (batch, *reconstruction_shape): one chosen
        reconstruction per input sample.
    """
    num_classes = 10
    batch_size = testInputs.shape[0]

    with torch.no_grad():
        data = testInputs.to(device)
        one_hot = torch.eye(num_classes, device=device)

        recons = []
        for k in range(num_classes):
            label = one_hot[k].expand(batch_size, num_classes)
            # PuVAE.forward requires the source classifier as its 3rd argument.
            _, _, _, x_k, _ = model(data, label, sc)
            recons.append(x_k)
        image_bar = torch.stack(recons)  # (num_classes, batch, ...)

        # MSE per (class, sample); flatten(2) also works for an empty batch.
        mse = (image_bar - data.unsqueeze(0)).pow(2).flatten(2).mean(dim=2)

        best = mse.argmin(dim=0)  # (batch,) winning class per sample
        return image_bar[best, torch.arange(batch_size, device=data.device)]

In [None]:
# Load the trained PuVAE purifier.
# `device` is not defined in any earlier cell shown, so define it here.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

PUVAE_CHECKPOINT = "Name"  # TODO: replace with the path to the trained PuVAE weights
model = PuVAE(32, 8, 50, device).to(device)
# load_state_dict expects a mapping, not a file path, so load the file first.
# NOTE: torch.load unpickles the file; only load trusted checkpoints.
puvae_state = torch.load(PUVAE_CHECKPOINT, map_location="cpu")
# The PuVAE file format is not known here: accept either a bare state_dict
# or a {"state_dict": ...} wrapper like the classifier checkpoint below.
if isinstance(puvae_state, dict) and "state_dict" in puvae_state:
    puvae_state = puvae_state["state_dict"]
model.load_state_dict(puvae_state)
model.eval()

# Load the source classifier.
f = Model_MNIST(1, 10).to(device)
model_name = "_Model_FMNIST"
checkpoint_path = "best%s.pth.tar" % model_name
checkpoint = torch.load(checkpoint_path, map_location="cpu")
f.load_state_dict(checkpoint["state_dict"])
f.eval()