In [1]:
import torch
import torch.utils.data
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image
import os
import random

In [2]:
batch_size=512
epochs=300
seed=1
cuda=False and torch.cuda.is_available()
log_interval=1
r_dim=128
z_dim=128
result_path="results_np_rz_y_hat/"

In [3]:

torch.manual_seed(seed)
random.seed(seed)
device = torch.device("cpu") #"cuda" if args.cuda else 

kwargs = {'num_workers': 1, 'pin_memory': True} if cuda else {}

train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=True, download=True,
                   transform=transforms.ToTensor()),
    batch_size=batch_size, shuffle=True, **kwargs)

test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False, transform=transforms.ToTensor()),
    batch_size=batch_size, shuffle=True, **kwargs)


In [4]:
def get_context_idx(N):
    # generate the indeces of the N context points in a flattened image
    idx = random.sample(range(0, 784), N)
    idx = torch.tensor(idx, device=device)
    return idx
def generate_grid(h, w):
    rows = torch.linspace(0, 1, h, device=device)
    cols = torch.linspace(0, 1, w, device=device)
    grid = torch.stack([cols.repeat(h, 1).t().contiguous().view(-1), rows.repeat(w)], dim=1)
    grid = grid.unsqueeze(0)
    return grid
def idx_to_y(idx, data):
    # get the [0;1] pixel intensity at each index
    y = torch.index_select(data, dim=1, index=idx)
    return y
def idx_to_x(idx, batch_size):
    # From flat idx to 2d coordinates of the 28x28 grid. E.g. 35 -> (1, 7)
    # Equivalent to np.unravel_index()
    x = torch.index_select(x_grid, dim=1, index=idx)
    x = x.expand(batch_size, -1, -1)
    return x

In [None]:
class NP(nn.Module):
    def __init__(self, r_dim,z_dim):
        super(NP, self).__init__()
        self.r_dim = r_dim
        self.z_dim = z_dim
        
        self.h_1 = nn.Linear(3, 400)
        self.h_2 = nn.Linear(400, 400)
        self.h_3 = nn.Linear(400, self.r_dim)

        self.latent_layer_1=nn.Linear(3,400)
        self.latent_layer_2=nn.Linear(400,400)
        self.latent_layer_3=nn.Linear(400,self.r_dim)
        
        self.s_to_z_mean = nn.Linear(self.r_dim, self.z_dim)
        self.s_to_z_logvar = nn.Linear(self.r_dim, self.z_dim)

        self.g_1 = nn.Linear(self.r_dim+ self.z_dim + 2, 400)
        self.g_2 = nn.Linear(400,400)
        self.g_3 = nn.Linear(400,400)
        self.g_y = nn.Linear(400, 1)
        

    def h(self, x_y):
        x_y = F.relu(self.h_1(x_y))
        x_y = F.relu(self.h_2(x_y))
        x_y = F.relu(self.h_3(x_y))
        return x_y

    def latent(self, x_y):
        x_y = F.relu(self.latent_layer_1(x_y))
        x_y = F.relu(self.latent_layer_2(x_y))
        x_y = F.relu(self.latent_layer_3(x_y))
        return x_y

    def aggregate(self, r):
        return torch.mean(r, dim=1)

    def reparameterise(self, z):
        mu, std= z
        eps = torch.randn_like(std)
        z_sample = eps.mul(std).add_(mu)
        z_sample = z_sample.unsqueeze(1).expand(-1, 784, -1)
        return z_sample

    def g(self, r,z_sample, x_target):
        rz_et_x = torch.cat([r,z_sample, x_target], dim=2)
        input = F.relu(self.g_1(rz_et_x))
        input = F.relu(self.g_2(input))
        input = F.relu(self.g_3(input))
        y_hat=torch.sigmoid(self.g_y(input))
        return y_hat
    
    def xy_to_r_params(self, x, y):       
        x_y = torch.cat([x, y], dim=2)
        r_i = self.h(x_y)
        r = self.aggregate(r_i)
        return r
    
    def xy_to_z_params(self, x, y):
        
        x_y = torch.cat([x, y], dim=2)
        
        s_i = self.latent(x_y)
        s = self.aggregate(s_i)
        mu = self.s_to_z_mean(s)
        logvar = self.s_to_z_logvar(s)
        sigma=0.1+0.9*F.softplus(logvar)
        return mu, sigma

    def forward(self, x_context, y_context, x_all=None, y_all=None):
        #produire r
        r_context = self.xy_to_r_params(x_context, y_context)
        r_expand = r_context.unsqueeze(1).expand(-1, 784, -1)
        #produire z
        z_context = self.xy_to_z_params(x_context, y_context)  # (mu, logvar) of z
        if self.training:  # loss function will try to keep z_context close to z_all         
            z_all = self.xy_to_z_params(x_all, y_all)  
        else:  # at test time we don't have the image so we use only the context
            z_all = z_context
        z_sample = self.reparameterise(z_all)
        
        # reconstruct the whole image including the provided context points
        x_target = x_grid.expand(y_context.shape[0], -1, -1)
        
        y_hat = self.g(r_expand,z_sample, x_target)
    
        
        return y_hat,z_all, z_context

In [6]:
def kl_div_gaussians(mu_q, logvar_q, mu_p, logvar_p):
    var_p = torch.exp(logvar_p)
    kl_div = (torch.exp(logvar_q) + (mu_q - mu_p) ** 2) / var_p \
             - 1.0 \
             + logvar_p - logvar_q
    kl_div = 0.5 * kl_div.sum()
    return kl_div


def np_loss(y_hat, y, z_all, z_context):
    # get log probability
    #logprob=-(y - mu)**2 / (2 * sigma**2) - torch.log(sigma)
    logprob=F.binary_cross_entropy(y_hat, y, reduction="sum")
    # get KL divergence between prior and posterior
    KLD = kl_div_gaussians(z_all[0], z_all[1], z_context[0], z_context[1])
    # maximize prob and minimize KL divergence
    return logprob + KLD


In [7]:
model = NP(r_dim,z_dim).to(device)
optimizer = optim.Adam(model.parameters(), lr=5e-4)
x_grid = generate_grid(28, 28)
os.makedirs(result_path, exist_ok=True)

In [8]:
def train(epoch):
    model.train()
    train_loss = 0
    for batch_idx, (y_all, _) in enumerate(train_loader):
        
        batch_size = y_all.shape[0]
        
        y_all = y_all.to(device).view(batch_size, -1, 1)
        
        N = random.randint(1, 784)  # number of context points
        context_idx = get_context_idx(N)
        x_context = idx_to_x(context_idx, batch_size)
        y_context = idx_to_y(context_idx, y_all)
        x_all = x_grid.expand(batch_size, -1, -1)

        optimizer.zero_grad()
        
        y_hat, z_all, z_context = model(x_context, y_context, x_all, y_all)
        
        loss = np_loss(y_hat, y_all, z_all, z_context).sum(dim=0).mean()
        loss.backward()
        train_loss += loss.item()
        optimizer.step()
        if batch_idx % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(y_all), len(train_loader.dataset),
                       100. * batch_idx / len(train_loader),
                       loss.item() / len(y_all)))
    print('====> Epoch: {} Average loss: {:.4f}'.format(
        epoch, train_loss / len(train_loader.dataset)))

In [11]:
def test(epoch):
    model.eval()
    test_loss = 0
    with torch.no_grad():
        for i, (y_all, _) in enumerate(test_loader):
            y_all = y_all.to(device).view(y_all.shape[0], -1, 1)
            batch_size = y_all.shape[0]

            N = 300
            context_idx = get_context_idx(N)
            x_context = idx_to_x(context_idx, batch_size)
            y_context = idx_to_y(context_idx, y_all)
            
            y_hat, z_all, z_context = model(x_context, y_context)
            test_loss += np_loss(y_hat, y_all, z_all, z_context).sum(dim=0).mean().item()

            if i == 0:  # save PNG of reconstructed examples
                plot_Ns = [10, 100, 300, 784]
                num_examples = min(batch_size, 16)
                for N in plot_Ns:
                    recons = []
                    context_idx = get_context_idx(N)
                    x_context = idx_to_x(context_idx, batch_size)
                    y_context = idx_to_y(context_idx, y_all)
                    for d in range(20):
                        y_hat, _, _ = model(x_context, y_context)
                        recons.append(y_hat[:num_examples])     
                    recons = torch.cat(recons).view(-1, 1, 28, 28).expand(-1, 3, -1, -1)
                    background = torch.tensor([0., 0., 1.], device=device)
                    background = background.view(1, -1, 1).expand(num_examples, 3, 784).contiguous()
                    context_pixels = y_all[:num_examples].view(num_examples, 1, -1)[:, :, context_idx]
                    context_pixels = context_pixels.expand(num_examples, 3, -1)
                    background[:, :, context_idx] = context_pixels
                    comparison = torch.cat([background.view(-1, 3, 28, 28),
                                            recons])
                    save_image(comparison.cpu(),
                               result_path+"ep_" + str(epoch) +
                               "_nps_" + str(N) + ".png", nrow=num_examples)

    test_loss /= len(test_loader.dataset)
    print('====> Test set loss: {:.4f}'.format(test_loss))

In [16]:
for epoch in range(2, epochs + 1):
    train(epoch)
    test(epoch)



====> Epoch: 2 Average loss: 271.6481
====> Test set loss: 226.8204


====> Epoch: 3 Average loss: 205.2964
====> Test set loss: 198.6929




====> Epoch: 4 Average loss: 198.4832
====> Test set loss: 198.9482


====> Epoch: 5 Average loss: 196.7129
====> Test set loss: 196.4944




====> Epoch: 6 Average loss: 195.5947
====> Test set loss: 195.0649


====> Epoch: 7 Average loss: 193.3544
====> Test set loss: 190.6200




====> Epoch: 8 Average loss: 190.4227
====> Test set loss: 188.5436


====> Epoch: 9 Average loss: 186.4583
====> Test set loss: 186.4170




====> Epoch: 10 Average loss: 193.8673
====> Test set loss: 186.7386


====> Epoch: 11 Average loss: 190.7706
====> Test set loss: 193.9361




====> Epoch: 12 Average loss: 186.5296
====> Test set loss: 186.8469


====> Epoch: 13 Average loss: 183.6743
====> Test set loss: 185.1148




====> Epoch: 14 Average loss: 182.0539
====> Test set loss: 180.0911


====> Epoch: 15 Average loss: 177.0346
====> Test set loss: 175.7667




====> Epoch: 16 Average loss: 179.1769
====> Test set loss: 174.7323




====> Epoch: 17 Average loss: 183.8104
====> Test set loss: 177.3371


====> Epoch: 18 Average loss: 176.5619
====> Test set loss: 175.1365




====> Epoch: 19 Average loss: 174.2737
====> Test set loss: 178.8487


====> Epoch: 20 Average loss: 167.8889
====> Test set loss: 162.7790




====> Epoch: 21 Average loss: 166.9344
====> Test set loss: 160.4313


====> Epoch: 22 Average loss: 159.4464
====> Test set loss: 156.9041




====> Epoch: 23 Average loss: 157.7292
====> Test set loss: 155.0359


====> Epoch: 24 Average loss: 153.4796
====> Test set loss: 153.3646




====> Epoch: 25 Average loss: 156.1347
====> Test set loss: 153.3637


====> Epoch: 26 Average loss: 151.7163
====> Test set loss: 154.9638




====> Epoch: 27 Average loss: 157.8642
====> Test set loss: 152.1193


====> Epoch: 28 Average loss: 156.0870
====> Test set loss: 155.6299




====> Epoch: 29 Average loss: 149.8581
====> Test set loss: 147.8385


====> Epoch: 30 Average loss: 152.1526
====> Test set loss: 146.5790




====> Epoch: 31 Average loss: 144.7296
====> Test set loss: 151.7921


====> Epoch: 32 Average loss: 151.6413
====> Test set loss: 143.1599




====> Epoch: 33 Average loss: 145.0490
====> Test set loss: 141.3460


====> Epoch: 34 Average loss: 145.2976
====> Test set loss: 144.0465




====> Epoch: 35 Average loss: 144.7802
====> Test set loss: 137.1814


====> Epoch: 36 Average loss: 146.3379
====> Test set loss: 147.4621




====> Epoch: 37 Average loss: 139.6982
====> Test set loss: 139.5078


====> Epoch: 38 Average loss: 141.1521
====> Test set loss: 135.8808




====> Epoch: 39 Average loss: 136.1170
====> Test set loss: 138.4522


====> Epoch: 40 Average loss: 140.4647
====> Test set loss: 140.2586




====> Epoch: 41 Average loss: 137.5585
====> Test set loss: 136.9337


====> Epoch: 42 Average loss: 135.6045
====> Test set loss: 146.2316




====> Epoch: 43 Average loss: 140.4533
====> Test set loss: 133.4075


====> Epoch: 44 Average loss: 133.3637
====> Test set loss: 130.1368




====> Epoch: 45 Average loss: 132.2475
====> Test set loss: 133.6025


====> Epoch: 46 Average loss: 131.3031
====> Test set loss: 131.3144




====> Epoch: 47 Average loss: 134.7398
====> Test set loss: 136.1870




====> Epoch: 48 Average loss: 141.8334
====> Test set loss: 130.3469


====> Epoch: 49 Average loss: 131.7503
====> Test set loss: 126.5045




====> Epoch: 50 Average loss: 130.8775
====> Test set loss: 130.6835


====> Epoch: 51 Average loss: 151.6246
====> Test set loss: 143.4925




====> Epoch: 52 Average loss: 140.5931
====> Test set loss: 131.7125


====> Epoch: 53 Average loss: 135.1091
====> Test set loss: 135.0811




====> Epoch: 54 Average loss: 132.4119
====> Test set loss: 128.2852


====> Epoch: 55 Average loss: 128.4343
====> Test set loss: 129.7053




====> Epoch: 56 Average loss: 126.2339
====> Test set loss: 127.2367


====> Epoch: 57 Average loss: 128.0613
====> Test set loss: 125.2373




====> Epoch: 58 Average loss: 126.9715
====> Test set loss: 123.3704


====> Epoch: 59 Average loss: 127.3874
====> Test set loss: 121.1909




====> Epoch: 60 Average loss: 126.6901
====> Test set loss: 121.5065


====> Epoch: 61 Average loss: 124.8832
====> Test set loss: 129.6315




====> Epoch: 62 Average loss: 125.0846
====> Test set loss: 121.3969


====> Epoch: 63 Average loss: 123.3276
====> Test set loss: 127.1839




====> Epoch: 64 Average loss: 121.0731
====> Test set loss: 122.2801


====> Epoch: 65 Average loss: 122.9079
====> Test set loss: 119.6887




====> Epoch: 66 Average loss: 122.3001
====> Test set loss: 122.7166


====> Epoch: 67 Average loss: 118.0169
====> Test set loss: 117.3106




====> Epoch: 68 Average loss: 118.0478
====> Test set loss: 122.4760


====> Epoch: 69 Average loss: 117.9849
====> Test set loss: 117.7381




====> Epoch: 70 Average loss: 118.6042
====> Test set loss: 117.0509


====> Epoch: 71 Average loss: 116.5442
====> Test set loss: 116.7771




====> Epoch: 72 Average loss: 117.4733
====> Test set loss: 115.8575


====> Epoch: 73 Average loss: 117.0765
====> Test set loss: 118.9092




====> Epoch: 74 Average loss: 116.3710
====> Test set loss: 114.1362


====> Epoch: 75 Average loss: 114.3366
====> Test set loss: 112.4920




====> Epoch: 76 Average loss: 113.3613
====> Test set loss: 113.8279


====> Epoch: 77 Average loss: 113.8506
====> Test set loss: 114.4745




====> Epoch: 78 Average loss: 115.0469
====> Test set loss: 115.5333




====> Epoch: 79 Average loss: 111.7900
====> Test set loss: 112.3714


====> Epoch: 80 Average loss: 111.0489
====> Test set loss: 111.6565




====> Epoch: 81 Average loss: 112.7250
====> Test set loss: 110.4905


====> Epoch: 82 Average loss: 113.0304
====> Test set loss: 111.6302




====> Epoch: 83 Average loss: 110.4436
====> Test set loss: 110.8657


====> Epoch: 84 Average loss: 111.7570
====> Test set loss: 109.9233




====> Epoch: 85 Average loss: 111.5683
====> Test set loss: 120.7985


====> Epoch: 86 Average loss: 109.5784
====> Test set loss: 111.1585




====> Epoch: 87 Average loss: 108.6822
====> Test set loss: 108.3717


====> Epoch: 88 Average loss: 107.2406
====> Test set loss: 106.9642




====> Epoch: 89 Average loss: 107.8075
====> Test set loss: 112.7880


====> Epoch: 90 Average loss: 106.4204
====> Test set loss: 110.5479




====> Epoch: 91 Average loss: 105.2003
====> Test set loss: 106.9903


====> Epoch: 92 Average loss: 106.3120
====> Test set loss: 110.5292




====> Epoch: 93 Average loss: 106.6550
====> Test set loss: 106.4770


====> Epoch: 94 Average loss: 107.0586
====> Test set loss: 108.4339




====> Epoch: 95 Average loss: 108.8556
====> Test set loss: 108.7489


====> Epoch: 96 Average loss: 108.1986
====> Test set loss: 106.4457




====> Epoch: 97 Average loss: 103.5835
====> Test set loss: 107.4371


====> Epoch: 98 Average loss: 105.5094
====> Test set loss: 108.4181




====> Epoch: 99 Average loss: 104.7869
====> Test set loss: 107.3745


====> Epoch: 100 Average loss: 104.9939
====> Test set loss: 107.4733




====> Epoch: 101 Average loss: 106.1707
====> Test set loss: 106.2733


====> Epoch: 102 Average loss: 103.6061
====> Test set loss: 105.1580




====> Epoch: 103 Average loss: 104.0047
====> Test set loss: 110.7915


====> Epoch: 104 Average loss: 100.5821
====> Test set loss: 104.4398




====> Epoch: 105 Average loss: 100.5027
====> Test set loss: 104.4993


====> Epoch: 106 Average loss: 103.7807
====> Test set loss: 105.9006




====> Epoch: 107 Average loss: 105.8362
====> Test set loss: 110.1818


====> Epoch: 108 Average loss: 103.8004
====> Test set loss: 105.1655




====> Epoch: 109 Average loss: 106.0410
====> Test set loss: 105.1656


====> Epoch: 110 Average loss: 103.9989
====> Test set loss: 104.0933




====> Epoch: 111 Average loss: 102.1737
====> Test set loss: 103.7808


====> Epoch: 112 Average loss: 99.5351
====> Test set loss: 105.9088




====> Epoch: 113 Average loss: 101.5571
====> Test set loss: 104.6532


====> Epoch: 114 Average loss: 102.6755


====> Test set loss: 102.5710


====> Epoch: 115 Average loss: 98.7746
====> Test set loss: 110.8022




====> Epoch: 116 Average loss: 99.4445
====> Test set loss: 102.5843


====> Epoch: 117 Average loss: 99.6469
====> Test set loss: 103.2039




====> Epoch: 118 Average loss: 99.9273
====> Test set loss: 103.2896


====> Epoch: 119 Average loss: 101.0790
====> Test set loss: 116.0575




====> Epoch: 120 Average loss: 101.4933
====> Test set loss: 103.5606


====> Epoch: 121 Average loss: 98.4670
====> Test set loss: 101.7149




====> Epoch: 122 Average loss: 98.2480
====> Test set loss: 101.8965


====> Epoch: 123 Average loss: 100.4526
====> Test set loss: 113.1858




====> Epoch: 124 Average loss: 106.3190
====> Test set loss: 105.5617


====> Epoch: 125 Average loss: 98.8771
====> Test set loss: 102.7470




====> Epoch: 126 Average loss: 97.8545
====> Test set loss: 104.0309


====> Epoch: 127 Average loss: 97.8996
====> Test set loss: 102.7289




====> Epoch: 128 Average loss: 94.8661
====> Test set loss: 103.6210


====> Epoch: 129 Average loss: 97.9530
====> Test set loss: 102.7287




====> Epoch: 130 Average loss: 98.2484
====> Test set loss: 101.4310


====> Epoch: 131 Average loss: 94.5862
====> Test set loss: 99.9331




====> Epoch: 132 Average loss: 99.5705
====> Test set loss: 112.0986


====> Epoch: 133 Average loss: 98.5898
====> Test set loss: 101.4185




====> Epoch: 134 Average loss: 97.0461
====> Test set loss: 99.9011




====> Epoch: 135 Average loss: 96.9062
====> Test set loss: 100.1258


====> Epoch: 136 Average loss: 96.7813
====> Test set loss: 126.1521




====> Epoch: 137 Average loss: 99.1191
====> Test set loss: 103.3552


====> Epoch: 138 Average loss: 94.4388
====> Test set loss: 100.9114




====> Epoch: 139 Average loss: 98.4680
====> Test set loss: 100.9610


====> Epoch: 140 Average loss: 94.6780
====> Test set loss: 99.4151




====> Epoch: 141 Average loss: 97.5418
====> Test set loss: 101.7245


====> Epoch: 142 Average loss: 96.7394
====> Test set loss: 101.3645




====> Epoch: 143 Average loss: 96.2356
====> Test set loss: 98.9250


====> Epoch: 144 Average loss: 96.2794
====> Test set loss: 104.8067




====> Epoch: 145 Average loss: 94.7319
====> Test set loss: 99.5740


====> Epoch: 146 Average loss: 93.7123
====> Test set loss: 99.3107




====> Epoch: 147 Average loss: 94.6794
====> Test set loss: 99.7795


====> Epoch: 148 Average loss: 94.1316
====> Test set loss: 104.1407




====> Epoch: 149 Average loss: 97.6114
====> Test set loss: 101.2465


====> Epoch: 150 Average loss: 93.3345
====> Test set loss: 101.6197




====> Epoch: 151 Average loss: 92.9624
====> Test set loss: 103.0387


====> Epoch: 152 Average loss: 95.3774
====> Test set loss: 101.5461




====> Epoch: 153 Average loss: 90.9958
====> Test set loss: 100.4333


====> Epoch: 154 Average loss: 92.6427
====> Test set loss: 100.4875




====> Epoch: 155 Average loss: 94.2599
====> Test set loss: 100.6626


====> Epoch: 156 Average loss: 91.4526
====> Test set loss: 98.4717




KeyboardInterrupt: 

In [17]:
torch.save(model.state_dict(), result_path+"model.pt")

In [None]:
optimizer = optim.Adam(model.parameters(), lr=5e-4)
epochs=200

In [10]:

model.load_state_dict(torch.load(result_path+"model.pt"))
model.eval()

NP(
  (h_1): Linear(in_features=3, out_features=400, bias=True)
  (h_2): Linear(in_features=400, out_features=400, bias=True)
  (h_3): Linear(in_features=400, out_features=128, bias=True)
  (latent_layer_1): Linear(in_features=3, out_features=400, bias=True)
  (latent_layer_2): Linear(in_features=400, out_features=400, bias=True)
  (latent_layer_3): Linear(in_features=400, out_features=128, bias=True)
  (s_to_z_mean): Linear(in_features=128, out_features=128, bias=True)
  (s_to_z_logvar): Linear(in_features=128, out_features=128, bias=True)
  (g_1): Linear(in_features=258, out_features=400, bias=True)
  (g_2): Linear(in_features=400, out_features=400, bias=True)
  (g_3): Linear(in_features=400, out_features=400, bias=True)
  (g_y): Linear(in_features=400, out_features=1, bias=True)
)

In [12]:
for epoch in range(157, epochs + 157):
    
    test(epoch)
    train(epoch)



KeyboardInterrupt: 