In [1]:
import torch
import torch.utils.data
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image
from torch.distributions import Normal
from torch.distributions.kl import kl_divergence
import numpy as np
import os
import random
import matplotlib.pyplot as plt

In [64]:
batch_size=256
epochs=500
seed=1
cuda=False and torch.cuda.is_available()
r_dim=128
z_dim=128
result_path="results_np_z_y_hat_parmu/"

# Seed every RNG the notebook touches: `random` drives context-index sampling,
# torch drives rsample() and DataLoader shuffling; numpy is seeded for completeness.
torch.manual_seed(seed)
random.seed(seed)
np.random.seed(seed)
# Honour the `cuda` flag above (False by default, so this still resolves to the CPU).
device = torch.device("cuda" if cuda else "cpu")

# MNIST test split read from ./data (download is not requested, so it must already exist).
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('data', train=False, transform=transforms.ToTensor()),
    batch_size=batch_size, shuffle=True)

def get_context_idx(N):
    """Pick N distinct flat pixel indices out of the 784 pixels of a 28x28 image.

    Uses the stdlib `random` module (so `random.seed` controls it) and returns a
    1-D integer tensor placed on the global `device`.
    """
    chosen = random.sample(range(784), N)
    return torch.tensor(chosen, device=device)

def generate_grid(h, w):
    """Build the flattened coordinate grid of shape (1, h*w, 2) on the global `device`.

    With rows = linspace(0, 1, h) and cols = linspace(0, 1, w), entry k = j*h + i
    is the pair (cols[j], rows[i]).
    """
    row_vals = torch.linspace(0, 1, h, device=device)
    col_vals = torch.linspace(0, 1, w, device=device)
    first_coord = col_vals.repeat_interleave(h)   # each col value repeated h times in a row
    second_coord = row_vals.repeat(w)             # the whole row sequence tiled w times
    return torch.stack((first_coord, second_coord), dim=1)[None]
def idx_to_y(idx, data):
    """Gather the values at positions `idx` along dim 1 of `data`, for every batch row.

    Used with `data` holding flattened pixel intensities, e.g. (batch, 784, 1).
    """
    return data.index_select(1, idx)
def idx_to_xy(batch_size, N, data):
    """Sample an independent set of N context points for every image in a batch.

    Unlike `get_context_idx` + `idx_to_y` (one index set shared by the whole
    batch), each row here draws its own random pixel indices.

    Args:
        batch_size: number of leading rows of `data` to sample from.
        N: number of context points per image.
        data: pixel intensities, presumably (batch, 784, 1) like `y_all` - TODO confirm.

    Returns:
        (x, y): coordinates of shape (batch_size, N, 2) gathered from the global
        `x_grid`, and intensities of shape (batch_size, N, data.shape[-1]).
    """
    xs, ys = [], []
    for i in range(batch_size):
        idx = get_context_idx(N)
        xs.append(torch.index_select(x_grid[0], dim=0, index=idx))
        # select from row i only, not from the whole batch
        ys.append(torch.index_select(data[i], dim=0, index=idx))
    # Stack fresh per-row slices instead of writing into an expand()ed view,
    # whose rows alias the same memory and reject in-place writes.
    return torch.stack(xs), torch.stack(ys)

In [3]:
class NP(nn.Module):
    """Neural Process mapping context (coordinate, intensity) pairs to a per-pixel Gaussian.

    Encoder `h` embeds each concatenated (x, y) pair, the embeddings are mean-pooled
    into r, and r parameterises a diagonal Gaussian over the latent z. Decoder `g`
    maps (z, x) for every coordinate of the module-level `x_grid` to a Gaussian over
    the intensity.

    NOTE: `forward` reads the global `x_grid` (expected (1, P, 2), e.g. from
    `generate_grid`); it must exist before the model is called. Layer attribute
    names are the state_dict keys of saved checkpoints and must not be renamed.
    """

    def __init__(self, r_dim, z_dim):
        super(NP, self).__init__()
        self.r_dim = r_dim
        self.z_dim = z_dim

        # Encoder: input is a 2-D coordinate concatenated with a 1-D intensity.
        self.h_1 = nn.Linear(3, 400)
        self.h_2 = nn.Linear(400, 400)
        self.h_3 = nn.Linear(400, self.r_dim)

        # r -> parameters of q(z). `r_to_z_logvar` is squashed into a std, not a log-variance.
        self.r_to_z_mean = nn.Linear(self.r_dim, self.z_dim)
        self.r_to_z_logvar = nn.Linear(self.r_dim, self.z_dim)

        # Decoder: latent sample concatenated with a 2-D target coordinate.
        self.g_1 = nn.Linear(self.z_dim + 2, 400)
        self.g_2 = nn.Linear(400, 400)
        self.g_3 = nn.Linear(400, 400)
        self.g_4 = nn.Linear(400, 400)
        self.g_y_mu = nn.Linear(400, 1)
        self.g_y_sigma = nn.Linear(400, 1)

    def h(self, x_y):
        """Embed each (x, y) pair independently; last dim of `x_y` must be 3."""
        x_y = F.relu(self.h_1(x_y))
        x_y = F.relu(self.h_2(x_y))
        x_y = F.relu(self.h_3(x_y))
        return x_y

    def aggregate(self, r):
        """Mean-pool point embeddings over dim 1 (the points axis)."""
        return torch.mean(r, dim=1)

    def g(self, z_sample, x_target):
        """Decode (z, x) pairs into the mean and std of the predicted intensity."""
        z_et_x = torch.cat([z_sample, x_target], dim=2)
        hidden = F.relu(self.g_1(z_et_x))
        hidden = F.relu(self.g_2(hidden))
        hidden = F.relu(self.g_3(hidden))
        hidden = F.relu(self.g_4(hidden))
        y_mu = self.g_y_mu(hidden)
        # softplus >= 0, so the std is bounded below by 0.1.
        y_sigma = 0.1 + 0.9 * F.softplus(self.g_y_sigma(hidden))
        return y_mu, y_sigma

    def xy_to_z_params(self, x, y):
        """Encode a set of (x, y) points into (mu, sigma) of q(z); sigma lies in (0.1, 1.0)."""
        x_y = torch.cat([x, y], dim=2)
        r_i = self.h(x_y)
        r = self.aggregate(r_i)
        mu = self.r_to_z_mean(r)
        sigma = 0.1 + 0.9 * torch.sigmoid(self.r_to_z_logvar(r))
        return mu, sigma

    def forward(self, x_context, y_context, x_all=None, y_all=None):
        """Predict every pixel of `x_grid` from the context set.

        Args:
            x_context: context coordinates, (batch, N, 2).
            y_context: context intensities, (batch, N, 1).
            x_all: coordinates of the full target set; defaults to `x_grid`.
            y_all: intensities of the full target set. In training mode q_target is
                encoded from (x_all, y_all); when omitted it falls back to the
                context, which makes the KL term zero.

        Returns:
            (p_y_pred, q_target, q_context). In eval mode q_target is q_context.
        """
        z_context_mu, z_context_sigma = self.xy_to_z_params(x_context, y_context)
        q_context = Normal(z_context_mu, z_context_sigma)
        batch = y_context.shape[0]
        # Always reconstruct the whole image, including the provided context points.
        x_target = x_grid.expand(batch, -1, -1)

        if self.training:
            if y_all is not None:
                # Encode the full image, so the KL pulls q_context towards q_target.
                x_full = x_target if x_all is None else x_all.expand(batch, -1, -1)
                z_target_mu, z_target_sigma = self.xy_to_z_params(x_full, y_all)
            else:
                z_target_mu, z_target_sigma = z_context_mu, z_context_sigma
            q_target = Normal(z_target_mu, z_target_sigma)
            z_sample = q_target.rsample()
        else:
            # At test time only the context is available.
            q_target = q_context
            z_sample = q_context.rsample()

        z_sample = z_sample.unsqueeze(1).expand(-1, x_target.shape[1], -1)
        y_pred_mu, y_pred_sigma = self.g(z_sample, x_target)
        p_y_pred = Normal(y_pred_mu, y_pred_sigma)
        return p_y_pred, q_target, q_context


# In[6]:



def np_loss(p_y_pred, y_target, q_target, q_context):
    """Negative ELBO-style loss: -log p(y_target) plus KL(q_target || q_context).

    Both terms are averaged over the batch (dim 0) and summed over every remaining
    element: the target points for the likelihood, and the z_dim latent
    dimensions for the KL.
    """
    per_point_ll = p_y_pred.log_prob(y_target)
    nll = -per_point_ll.mean(dim=0).sum()
    kl_term = kl_divergence(q_target, q_context).mean(dim=0).sum()
    return nll + kl_term
    

In [4]:
# Full 28x28 coordinate grid, shape (1, 784, 2); NP.forward reads it as a global.
x_grid = generate_grid(28, 28)
model = NP(r_dim, z_dim).to(device)
# NOTE(review): torch.load unpickles the checkpoint - only load files you trust.
model.load_state_dict(
    torch.load(result_path + "model_128_1.pt", map_location='cpu')
)
model.eval()

In [5]:
def _context_panel(y_batch, context_idx, num_examples):
    """RGB images (num_examples, 3, 28, 28): context pixels in greyscale, all others blue."""
    background = torch.tensor([0., 0., 1.], device=y_batch.device)
    background = background.view(1, -1, 1).expand(num_examples, 3, 784).contiguous()
    context_pixels = y_batch[:num_examples].view(num_examples, 1, -1)[:, :, context_idx]
    background[:, :, context_idx] = context_pixels.expand(num_examples, 3, -1)
    return background.view(-1, 3, 28, 28)


def _to_rgb(images):
    """Concatenate a list of (n, 784, 1) tensors and tile them to 3-channel 28x28 images."""
    return torch.cat(images).view(-1, 1, 28, 28).expand(-1, 3, -1, -1)


def plot_img(epoch):
    """Save context/reconstruction grids for the first test batch.

    For each number of context points N in (10, 100, 300, 784), writes
    "nps_<N>.png" to the current directory. Rows from top to bottom: context
    (blue = unobserved), 3 predicted means, 3 predicted stds, 1 sample.

    Args:
        epoch: currently unused; kept for interface compatibility.
    """
    model.eval()
    plot_Ns = [10, 100, 300, 784]
    num_draws = 3
    with torch.no_grad():
        # Only the first batch is visualised.
        y_all, _ = next(iter(test_loader))
        n_batch = y_all.shape[0]
        y_all = y_all.to(device).view(n_batch, -1, 1)
        # Use the real batch size so a short batch cannot break the slicing below.
        num_examples = min(n_batch, 20)
        for N in plot_Ns:
            context_idx = get_context_idx(N)
            # Coordinates of the chosen pixels, shared by every image in the batch.
            x_context = torch.index_select(x_grid, dim=1, index=context_idx).expand(n_batch, -1, -1)
            y_context = idx_to_y(context_idx, y_all)
            samples, means, stds = [], [], []
            for d in range(num_draws):
                p_y_pred, _, _ = model(x_context, y_context)
                if d == 0:
                    samples.append(p_y_pred.rsample()[:num_examples])
                means.append(p_y_pred.mean[:num_examples])
                stds.append(p_y_pred.stddev[:num_examples])
            comparison = torch.cat([_context_panel(y_all, context_idx, num_examples),
                                    _to_rgb(means), _to_rgb(stds), _to_rgb(samples)])
            save_image(comparison.cpu(),
                       "nps_" + str(N) + ".png", nrow=num_examples)