In [None]:
%matplotlib inline

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import torchvision
import torchvision.transforms as transforms
from torchvision.utils import make_grid
from torch.distributions.multivariate_normal import MultivariateNormal

import matplotlib.pyplot as plt
import numpy as np

In [None]:
# Random Gaussian-noise transform used to artificially damage the dataset.
class AddGaussianNoise(object):
    """Transform that adds independent Gaussian noise to every element of a tensor.

    Args:
        mean: Mean of the additive noise.
        std: Standard deviation of the additive noise.

    The result is ``tensor + noise * std + mean`` and is not clamped.
    """

    def __init__(self, mean=0., std=1.):
        self.std = std
        self.mean = mean

    def __call__(self, tensor):
        """Return a noisy copy of ``tensor``; the input itself is not modified."""
        # torch.randn allocates on the CPU by default; sample on the input's
        # device so the transform also works for tensors that live elsewhere.
        noise = torch.randn(tensor.size(), device=tensor.device)
        return tensor + noise * self.std + self.mean

    def __repr__(self):
        return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std)

In [None]:
# Noise standard deviations (sigma) of the Gaussian corruption that are
# evaluated for image reconstruction. Loaders are keyed by str(sigma).
sigmas = [0.05, 0.25, 0.5]
train_loaders = {}
valid_loaders = {}
test_loaders = {}
no_noise_test_loaders = {}

batch_size = 128

# The clean (noise-free) test set does not depend on sigma, so it is built
# once and shared by every per-sigma loader instead of being rebuilt each
# iteration.
transform_no_noise = transforms.Compose([
    transforms.ToTensor(),
])
no_noise_test = torchvision.datasets.KMNIST(root='../data', train=False, download=True, transform=transform_no_noise)

for sigma in sigmas:
    # Other corruptions (e.g. transforms.GaussianBlur, transforms.RandomErasing)
    # can be inserted after ToTensor() in place of the additive noise.
    transform = transforms.Compose([
        transforms.ToTensor(),
        AddGaussianNoise(0., sigma),
    ])

    dataset = torchvision.datasets.KMNIST(root='../data', train=True, download=True, transform=transform)
    test_set = torchvision.datasets.KMNIST(root='../data', train=False, download=True, transform=transform)
    # NOTE: random_split draws from torch's global RNG, which this notebook
    # only seeds in a later cell; pass a seeded `generator=` here if the
    # train/validation split has to be reproducible.
    train_set, val_set = torch.utils.data.random_split(dataset, [48000, 12000])

    train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=0)
    val_loader = torch.utils.data.DataLoader(val_set, batch_size=batch_size, shuffle=True, num_workers=0)
    test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=0)
    no_noise_test_loader = torch.utils.data.DataLoader(no_noise_test, batch_size = batch_size, shuffle=False, num_workers=0)

    train_loaders[str(sigma)] = train_loader
    valid_loaders[str(sigma)] = val_loader
    test_loaders[str(sigma)] = test_loader
    no_noise_test_loaders[str(sigma)] = no_noise_test_loader

# Fall back to the CPU when no GPU is present instead of failing on the first
# .to(device) call.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
class BVAE(nn.Module):
    """Beta-VAE: a fully connected VAE whose KL term is weighted by ``beta``.

    The network itself is a standard VAE (one hidden layer in the encoder and
    one in the decoder); ``beta`` only matters for the loss. The default of 4
    was determined through a grid search over the range (2, 6).

    Args:
        n_in: Size of the flattened input (and of the reconstruction).
        n_hid: Width of the hidden layer in both encoder and decoder.
        z_dim: Dimensionality of the latent space.
        beta: Weight of the KL term. It is stored on the module as ``self.beta``
            so callers can read it back; ``forward`` does not use it.
    """

    def __init__(self, n_in, n_hid, z_dim, beta: float = 4):
        super(BVAE, self).__init__()
        # Previously this argument was accepted and silently dropped.
        self.beta = beta
        self.fc1 = nn.Linear(n_in, n_hid)
        self.fc21 = nn.Linear(n_hid, z_dim)  # posterior mean head
        self.fc22 = nn.Linear(n_hid, z_dim)  # posterior log-variance head
        self.fc3 = nn.Linear(z_dim, n_hid)
        self.fc4 = nn.Linear(n_hid, n_in)

    def encode(self, x):
        """Map inputs to the posterior parameters ``(mu, logvar)``."""
        h1 = F.relu(self.fc1(x))
        return self.fc21(h1), self.fc22(h1)

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, I)``."""
        stdev = torch.exp(0.5*logvar)
        eps = torch.randn_like(stdev)
        return mu + eps*stdev

    def decode(self, z):
        """Map latent codes to reconstructions squashed into [0, 1] by a sigmoid."""
        h3 = F.relu(self.fc3(z))
        return torch.sigmoid(self.fc4(h3))

    def forward(self, x):
        """Return ``(reconstruction, mu, logvar)`` for a batch of flattened inputs."""
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar
    
def beta_loss_function(recon_x, x, mu, logvar, Beta):
    """Beta-VAE objective for one batch.

    Args:
        recon_x: Reconstructions produced by the decoder.
        x: Target tensor with the same shape as ``recon_x``.
        mu: Posterior means.
        logvar: Posterior log-variances.
        Beta: Weight applied to the KL term.

    Returns:
        Scalar tensor ``BCE + Beta * KLD``, both terms summed over all elements.
    """
    # Reconstruction term: binary cross-entropy summed over every element.
    reconstruction = F.binary_cross_entropy(recon_x, x, reduction='sum')

    # Closed-form KL(q_phi(z|x) || p(z)) for a diagonal Gaussian posterior.
    kl_terms = 1 + logvar - mu.pow(2) - logvar.exp()
    divergence = -0.5 * kl_terms.sum()

    return reconstruction + Beta*divergence

In [None]:
class FAE(nn.Module):
    """Network used for the Fisher autoencoder experiments.

    A fully connected encoder/decoder pair with a Gaussian latent code.

    Args:
        n_in: Size of the flattened input (and of the reconstruction).
        n_hid: Width of the hidden layer in both encoder and decoder.
        z_dim: Dimensionality of the latent space.
        num_classes: Accepted for interface compatibility; not used by the network.
    """

    def __init__(self, n_in, n_hid, z_dim,num_classes=10):
        super().__init__()

        # Encoder: shared hidden layer followed by two heads.
        self.fc1 = nn.Linear(n_in, n_hid)
        self.fc21 = nn.Linear(n_hid, z_dim)  # mean head
        self.fc22 = nn.Linear(n_hid, z_dim)  # log-variance head
        # Decoder.
        self.fc3 = nn.Linear(z_dim, n_hid)
        self.fc4 = nn.Linear(n_hid, n_in)

    def encode(self, x):
        """Return the latent distribution parameters ``(mu, logvar)`` for ``x``."""
        hidden = F.relu(self.fc1(x))
        mean_head = self.fc21(hidden)
        logvar_head = self.fc22(hidden)
        return mean_head, logvar_head

    def reparameterize(self, mu, logvar):
        """Draw ``z = mu + eps * std`` with ``std = exp(0.5 * logvar)`` and ``eps ~ N(0, I)``."""
        scale = torch.exp(0.5*logvar)
        noise = torch.randn_like(scale)
        return mu + noise*scale

    def decode(self, z):
        """Decode latent codes into reconstructions passed through a sigmoid."""
        hidden = F.relu(self.fc3(z))
        return torch.sigmoid(self.fc4(hidden))

    def forward(self, x):
        """Encode, sample a latent code and decode; returns ``(reconstruction, mu, logvar)``."""
        mu, logvar = self.encode(x)
        sample = self.reparameterize(mu, logvar)
        return self.decode(sample), mu, logvar

def fisher_loss_function(recon_x, x, mu, logvar, model):
    """Fisher-autoencoder loss for one batch.

    A fresh latent sample ``z = mu + eps * sigma`` is drawn here (independently
    of the sample the model used to produce ``recon_x``) and the loss is the sum
    of three squared-L2 terms:

    * ``ABC``: ``|| A - B + C ||^2`` with
      ``A = -eps / sigma`` (gradient w.r.t. z of log q(z|x)),
      ``B = -z`` (gradient w.r.t. z of the standard-normal log prior) and
      ``C = C_grad(z, x, model)`` (gradient w.r.t. z of the summed BCE, i.e. of
      the negative log-likelihood, which is why it enters with a plus sign).
    * ``D``: ``|| x - recon_x ||^2``, the reconstruction error.
    * ``E``: ``|| E_grad(z, x, model) ||^2``, the gradient w.r.t. x of log q(z|x).

    NOTE: ``C_grad`` and ``E_grad`` return plain gradient tensors without an
    autograd graph, so ``C`` and ``E`` act as constants when this loss is
    back-propagated.

    Args:
        recon_x: Reconstructions, same shape as ``x``.
        x: Flattened input batch.
        mu: Posterior means.
        logvar: Posterior log-variances.
        model: Module providing ``encode`` and ``decode``.

    Returns:
        Scalar tensor ``ABC + D + E``.
    """
    sigma = torch.exp(0.5*logvar)
    eps = torch.randn_like(sigma)
    z = mu + eps*sigma

    # gradient with respect to z of the derived distribution q(z|x)
    A = -eps/sigma

    # derived gradient of the prior standard normal distribution
    B = -(z)

    # gradient with respect to z of the negative log-likelihood of x given z
    C = C_grad(z,x,model)

    # The squared L2 norms are computed as explicit sums of squares:
    # torch.linalg.norm(t, ord=2) on a 2-D (batch, features) tensor returns the
    # spectral norm (largest singular value), not the Euclidean norm.
    ABC = (A - B + C).pow(2).sum()

    # squared L2 distance between the original images and their reconstructions
    D = (x - recon_x).pow(2).sum()

    # squared L2 norm of the gradient with respect to x of q(z|x)
    E = E_grad(z,x,model).pow(2).sum()

    loss = ABC + D + E
    return loss

def C_grad(z,x,model):
    """Gradient w.r.t. ``z`` of the summed BCE between ``model.decode(z)`` and ``x``.

    This is a point estimate, at the sampled ``z``, of the gradient of the
    negative log-likelihood of ``x`` given ``z``.

    Args:
        z: Sampled latent codes; detached internally, the caller's tensor is untouched.
        x: Target batch; detached internally.
        model: Module providing ``decode``.

    Returns:
        Tensor with the shape of ``z`` and no autograd graph attached.
    """
    x = x.detach()
    # Detach so that z is a leaf tensor we are allowed to differentiate with
    # respect to.
    z = z.detach().requires_grad_(True)
    # enable_grad keeps this working when the caller evaluates under no_grad.
    with torch.enable_grad():
        p_x_z = model.decode(z)  # decode the sampled latent into a reconstruction
        # summed cross entropy = negative log-likelihood of x under the decoder
        nll = F.binary_cross_entropy(p_x_z, x, reduction='sum')
        # autograd.grad returns only d(nll)/dz. Calling nll.backward() instead
        # would also accumulate into the decoder parameters' .grad and leak
        # into the next optimizer step.
        (grad_z,) = torch.autograd.grad(nll, z)
    return grad_z

def E_grad(z,x,model):
    """Gradient w.r.t. the input ``x`` of the batch-mean log-density ``log q(z|x)``.

    ``x`` is encoded into ``(mu, logvar)``, a Gaussian with mean ``mu`` and
    diagonal covariance ``exp(logvar)`` is built per sample, and the mean over
    the batch of ``log_prob(z)`` is differentiated with respect to ``x``.

    Args:
        z: Sampled latent codes; detached internally.
        x: Input batch; detached internally, the caller's tensor is untouched.
        model: Module providing ``encode``.

    Returns:
        Tensor with the shape of ``x`` and no autograd graph attached.
    """
    z = z.detach()
    # Detach so that x is a leaf tensor we are allowed to differentiate with
    # respect to.
    x = x.detach().requires_grad_(True)
    # enable_grad keeps this working when the caller evaluates under no_grad.
    with torch.enable_grad():
        # encode the input into the parameters that define the latent distribution
        mu, logvar = model.encode(x)

        # One diagonal covariance matrix per sample:
        # (batch, z_dim) -> (batch, z_dim, z_dim), built on mu's own device.
        # The variances are detached, as they were by the former numpy round
        # trip, so only the mean path contributes to the gradient.
        covariance = torch.diag_embed(torch.exp(logvar).detach())

        q_z_x = MultivariateNormal(mu, covariance)
        loglike = q_z_x.log_prob(z)
        # Mean over the batch because the differentiated quantity must be a
        # scalar. autograd.grad returns only the gradient w.r.t. x; calling
        # .backward() instead would also accumulate into the encoder
        # parameters' .grad.
        (grad_x,) = torch.autograd.grad(loglike.mean(), x)
    return grad_x

In [None]:
def fisher_train(model, device, train_loader, valid_loader, optimizer, epoch):
    """Train the Fisher autoencoder for one epoch, then evaluate it on the validation set.

    Args:
        model: Network returning ``(reconstruction, mu, logvar)`` from ``model(data)``.
        device: Device the batches are moved to.
        train_loader: Loader yielding ``(data, target)`` training batches.
        valid_loader: Loader yielding ``(data, target)`` validation batches.
        optimizer: Optimizer stepping ``model``'s parameters.
        epoch: Epoch number, used only in the printed progress lines.

    Returns:
        The validation loss: the summed batch losses scaled by
        ``valid_loader.batch_size / len(valid_loader.dataset)``.
    """
    train_loss = 0
    model.train()
    # Progress is printed at the start and around the midpoint of the epoch;
    # max() avoids a zero modulus when the loader has fewer than two batches.
    log_every = max(1, len(train_loader)//2)
    for batch_idx, (data, target) in enumerate(train_loader):
        data = data.view(data.size(0),-1)  # flatten every image to a vector
        data = data.to(device)

        optimizer.zero_grad()
        output, mu, logvar = model(data)
        loss = fisher_loss_function(output, data, mu, logvar, model)
        loss.backward()

        optimizer.step()
        train_loss += loss.item()
        if batch_idx % log_every == 0:
            print('Train({})[{:.0f}%]: Loss: {:.4f}'.format(
                epoch, 100. * batch_idx / len(train_loader), train_loss/(batch_idx+1)))

    model.eval()
    valid_loss = 0
    # Gradients stay enabled here: fisher_loss_function obtains its C and E
    # terms through autograd (C_grad / E_grad).
    for data, target in valid_loader:
        data = data.view(data.size(0),-1)
        data = data.to(device)
        output, mu, logvar = model(data)
        loss = fisher_loss_function(output, data, mu, logvar, model)
        valid_loss += loss.item() # sum up batch loss
    # Normalise with the loader that was actually passed in, not with the
    # notebook-level `val_loader` / `batch_size` globals.
    valid_loss = (valid_loss*valid_loader.batch_size)/len(valid_loader.dataset)
    print('Valid({}): Loss: {:.4f}'.format(
        epoch, valid_loss))
    return valid_loss

def fisher_test(model, device, test_loader, epoch):
    """Evaluate the Fisher autoencoder's loss on a test loader.

    Args:
        model: Network returning ``(reconstruction, mu, logvar)`` from ``model(data)``.
        device: Device the batches are moved to.
        test_loader: Loader yielding ``(data, target)`` test batches.
        epoch: Epoch number, used only in the printed line.

    Returns:
        The test loss: the summed batch losses scaled by
        ``test_loader.batch_size / len(test_loader.dataset)``.
    """
    model.eval()
    test_loss = 0
    # Gradients stay enabled here: fisher_loss_function obtains its C and E
    # terms through autograd (C_grad / E_grad).
    for data, target in test_loader:
        data = data.view(data.size(0),-1)  # flatten every image to a vector
        data = data.to(device)
        output, mu, logvar = model(data)
        loss = fisher_loss_function(output, data, mu, logvar, model)
        test_loss += loss.item() # sum up batch loss
    # Use the loader's own batch size rather than the notebook-level global so
    # the scale stays right for loaders built with a different batch size.
    test_loss = (test_loss*test_loader.batch_size)/len(test_loader.dataset)
    print('Test({}): Loss: {:.4f}'.format(
        epoch, test_loss))
    return test_loss

In [None]:
def beta_train(model, device, train_loader, valid_loader, optimizer, epoch, Beta=4):
    """Train the beta-VAE for one epoch, then evaluate it on the validation set.

    Args:
        model: Network returning ``(reconstruction, mu, logvar)`` from ``model(data)``.
        device: Device the batches are moved to.
        train_loader: Loader yielding ``(data, target)`` training batches.
        valid_loader: Loader yielding ``(data, target)`` validation batches.
        optimizer: Optimizer stepping ``model``'s parameters.
        epoch: Epoch number, used only in the printed progress lines.
        Beta: Weight of the KL term passed to ``beta_loss_function``.

    Returns:
        The validation loss: the summed batch losses scaled by
        ``valid_loader.batch_size / len(valid_loader.dataset)``.
    """
    train_loss = 0
    model.train()
    # Progress is printed at the start and around the midpoint of the epoch;
    # max() avoids a zero modulus when the loader has fewer than two batches.
    log_every = max(1, len(train_loader)//2)
    for batch_idx, (data, target) in enumerate(train_loader):
        data = data.view(data.size(0),-1)  # flatten every image to a vector
        data = data.to(device)

        optimizer.zero_grad()
        output, mu, logvar = model(data)
        loss = beta_loss_function(output, data, mu, logvar, Beta)
        loss.backward()

        optimizer.step()
        train_loss += loss.item()
        if batch_idx % log_every == 0:
            print('Train({})[{:.0f}%]: Loss: {:.4f}'.format(
                epoch, 100. * batch_idx / len(train_loader), train_loss/(batch_idx+1)))

    model.eval()
    valid_loss = 0
    with torch.no_grad():
        for data, target in valid_loader:
            data = data.view(data.size(0),-1)
            data = data.to(device)
            output, mu, logvar = model(data)
            loss = beta_loss_function(output, data, mu, logvar, Beta)
            valid_loss += loss.item() # sum up batch loss
    # Normalise with the loader that was actually passed in, not with the
    # notebook-level `val_loader` / `batch_size` globals.
    valid_loss = (valid_loss*valid_loader.batch_size)/len(valid_loader.dataset)
    print('Valid({}): Loss: {:.4f}'.format(
        epoch, valid_loss))
    return valid_loss

def test_beta(model, device, test_loader, epoch, Beta):
    """Evaluate the beta-VAE's loss on a test loader.

    Args:
        model: Network returning ``(reconstruction, mu, logvar)`` from ``model(data)``.
        device: Device the batches are moved to.
        test_loader: Loader yielding ``(data, target)`` test batches.
        epoch: Epoch number, used only in the printed line.
        Beta: Weight of the KL term passed to ``beta_loss_function``.

    Returns:
        The test loss: the summed batch losses scaled by
        ``test_loader.batch_size / len(test_loader.dataset)``.
    """
    model.eval()
    test_loss = 0

    # Pure evaluation: no autograd graph is needed, matching the validation
    # pass in beta_train.
    with torch.no_grad():
        for data, target in test_loader:
            data = data.view(data.size(0),-1)  # flatten every image to a vector
            data = data.to(device)
            output, mu, logvar = model(data)
            loss = beta_loss_function(output, data, mu, logvar, Beta)
            test_loss += loss.item() # sum up batch loss
    # Use the loader's own batch size rather than the notebook-level global.
    test_loss = (test_loss*test_loader.batch_size)/len(test_loader.dataset)
    print('Test({}): Loss: {:.4f}'.format(
        epoch, test_loss))
    return test_loss

In [None]:
# Experiment configuration shared by the training and plotting cells.
seed = 1  # RNG seed passed to torch.manual_seed / torch.cuda.manual_seed below
num_epochs = 10  # epochs trained per noise level
lr = 0.001  # learning rate handed to optim.Adam
n_in = 28*28  # flattened input size; reconstructions are reshaped to 28x28 for plotting
n_hid = 400  # hidden-layer width of the encoder and decoder
z_dim = 20  # latent dimensionality

# Beta was tuned over the range 2 to 6 in steps of 0.25; 4 gave the lowest validation loss
Beta = 4

# Rebind the device string chosen in the data-loading cell to a torch.device.
device = torch.device(device)
# NOTE(review): the random_split in the data-loading cell runs before this
# seed is set, so the train/validation split is not covered by it.
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

In [None]:
# Validation loss per epoch for each noise level, keyed by str(sigma); used to
# plot validation loss over epochs for the different noise levels.
fisher_validation_scores = {}

# Train one model per noise level and test the best checkpoint.
# To obtain results for a different model, change the model class, the
# train/test functions and the checkpoint path below.
for key in sigmas:
    valid_scores = []
    key = str(key)
    train_loader = train_loaders[key]
    valid_loader = valid_loaders[key]
    test_loader = test_loaders[key]
    # change model location for fisher
    checkpoint_path = str(key)+'_saved_fisher_model.pth'
    # change model for fisher
    model = FAE(n_in, n_hid, z_dim).to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)

    # The best-so-far loss must be initialised once per model. Resetting it
    # inside the epoch loop made every epoch look like an improvement, so the
    # "best" checkpoint was always just the last epoch.
    min_validation_loss = np.inf
    for epoch in range(1, num_epochs+1):
        # change train function for fisher
        val_loss = fisher_train(model, device, train_loader, valid_loader, optimizer, epoch)
        valid_scores.append(val_loss)
        # Lower is better; compare the values directly rather than their
        # magnitudes so a negative loss cannot be ranked as worse.
        if val_loss < min_validation_loss:
            min_validation_loss = val_loss
            torch.save(model.state_dict(), checkpoint_path)

    # map_location lets a checkpoint written on one device be loaded on another.
    model_best_state_dict = torch.load(checkpoint_path, map_location=device)
    model2 = FAE(n_in, n_hid, z_dim).to(device)
    model2.load_state_dict(model_best_state_dict)
    test_loss = fisher_test(model2, device, test_loader, epoch)
    print(test_loss)
    fisher_validation_scores[key] = valid_scores

In [None]:
# Plot the validation loss per epoch, one curve per noise level.
epoch_axis = np.arange(1, num_epochs+1, 1)
fig, ax = plt.subplots()
for key, scores in fisher_validation_scores.items():
    ax.plot(epoch_axis, scores, label = key)

ax.legend()
ax.set_xlabel("Epochs")
ax.set_ylabel("Validation Loss")
ax.set_title("Validation Loss over Epochs For Fisher AE Gaussian Noise")
fig.savefig('fisher_ae_blackbox.png')
plt.show()

In [None]:
# For every noise level plot, top to bottom: original image, noisy image,
# reconstructed image (sample index 3 of the first test batch).
# NOTE: the next cell reads `key`, `data`, `test_data`, `nn_test_data` and
# `nn_test_target` as they are left by the final iteration of this loop.
for key in sigmas:
    title = "Image Reconstruction: N(0, "+str(key)+")"
    key = str(key)
    test_data, test_target = next(iter(test_loaders[key]))
    nn_test_data, nn_test_target = next(iter(no_noise_test_loaders[key]))
    # map_location lets a checkpoint written on one device be loaded on another.
    model_best_state_dict = torch.load(str(key)+'_saved_fisher_model.pth', map_location=device)
    model_best = FAE(n_in, n_hid, z_dim).to(device)
    model_best.load_state_dict(model_best_state_dict)
    model_best.eval()

    # test_data stays on the CPU for plotting; only the flattened copy is
    # moved to the model's device.
    data = test_data.view(test_data.size(0),-1).to(device)
    with torch.no_grad():  # inference only, no autograd graph needed
        output, _, _ = model_best(data)
    output = output.to('cpu')

    f, axarr = plt.subplots(3,1, figsize=(15, 15))
    f.suptitle(title, fontsize = 20)
    f.tight_layout()
    # three images stacked vertically
    axarr[0].imshow(nn_test_data[3][0])
    axarr[1].imshow(test_data[3][0])
    axarr[2].imshow(output[3].reshape(28, 28))
    plt.savefig('full_figure_{}_fisher_blackbox.png'.format(key))
    

In [None]:
# Plot black box images from the saved gaussian model.
# This cell reuses `key`, `data`, `test_data`, `nn_test_data` and
# `nn_test_target` exactly as the reconstruction cell above left them after its
# final iteration.

# `labels` was never defined anywhere in this notebook. It is indexed by class
# and used to pick a sample of the plotted batch, so presumably it is the
# class labels of that batch -- TODO confirm.
labels = nn_test_target.tolist()

# Load the checkpoint and run the batch once; neither depends on the digit, so
# there is no reason to repeat them for every column of the grid.
model_best_state_dict = torch.load(str(key)+'_saved_fisher_model.pth', map_location=device)
model_best = FAE(n_in, n_hid, z_dim).to(device)
model_best.load_state_dict(model_best_state_dict)
model_best.eval()
with torch.no_grad():
    output, _, _ = model_best(data)
output = output.to('cpu')

# One column per class; rows: clean image, noisy image, reconstruction.
fig = plt.figure(figsize = (10,3))
for dig in range(10):
    idx = labels.index(dig)  # first sample of this class in the batch
    fig.add_subplot(3,10,dig+1)
    plt.imshow(nn_test_data[idx][0])
    plt.axis('off')
    fig.add_subplot(3,10,10+dig+1)
    plt.imshow(test_data[idx][0])
    plt.axis('off')
    fig.add_subplot(3,10,20+dig+1)
    plt.imshow(output[idx].reshape(28,28))
    plt.axis('off')
plt.tight_layout()

# Black-box experiment: blank out a square of a clean image, reconstruct it
# with the sigma=0.05 model and paste the reconstructed square back in.
I = nn_test_data[3][0]
# Clone before masking: `BBI = I` would alias the batch and zero the square
# inside nn_test_data itself.
BBI = I.clone()
BBI[15:25,2:12] = 0

# The masked image has to exist before it can be fed to the model.
model_best_state_dict = torch.load(str(0.05)+'_saved_fisher_model.pth', map_location=device)
model_best = FAE(n_in, n_hid, z_dim).to(device)
model_best.load_state_dict(model_best_state_dict)
model_best.eval()
with torch.no_grad():
    output, _, _ = model_best(BBI.reshape(1,784).to(device))
output = output.to('cpu')

# Clone again so the masked input is still available for display.
BBIR = BBI.clone()
BBIR[15:25,2:12] = output.reshape(28,28)[15:25,2:12]

# Draw the three stages side by side in their own figure; bare plt.imshow
# calls would all land on the last subplot of the grid above.
fig_bb, axes_bb = plt.subplots(1, 3, figsize=(9, 3))
panels = (
    (BBI, 'Masked input'),
    (output.reshape(28,28), 'Reconstruction'),
    (BBIR, 'Masked input + reconstructed patch'),
)
for ax, (image, panel_title) in zip(axes_bb, panels):
    ax.imshow(image)
    ax.set_title(panel_title)
    ax.axis('off')
fig_bb.tight_layout()
