In [1]:
import torch
import torch.utils.data
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image
import os
import random

In [2]:
# --- Optimisation schedule -------------------------------------------------
batch_size = 256   # mini-batch size used by both the train and the test loader
epochs = 300       # number of train/test passes scheduled by the main loop
log_interval = 1   # train() prints the batch loss every `log_interval` batches

# --- Reproducibility -------------------------------------------------------
seed = 1           # handed to torch.manual_seed and random.seed

# --- Hardware --------------------------------------------------------------
# `False and ...` short-circuits: this flag is always False and the CUDA
# availability probe is never executed.
cuda = False and torch.cuda.is_available()

# --- Model widths (constructor arguments of CNP) ----------------------------
r_dim = 128
z_dim = 128


In [3]:

# Seed the two random sources this notebook draws from: torch (weight
# initialisation, torch.rand) and Python's `random` (context-point sampling).
torch.manual_seed(seed)
random.seed(seed)

# Follow the `cuda` flag instead of hard-coding "cpu", so the device can never
# disagree with the loader options chosen from the same flag just below.
device = torch.device("cuda" if cuda else "cpu")

# A worker process and pinned host memory are only requested when running on CUDA.
kwargs = {'num_workers': 1, 'pin_memory': True} if cuda else {}

# Training split; downloaded into ../data if it is not there yet.
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=True, download=True,
                   transform=transforms.ToTensor()),
    batch_size=batch_size, shuffle=True, **kwargs)

# Test split; no `download=True` here, so it relies on the files already being
# present under ../data (the training-set call above requests the download).
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False, transform=transforms.ToTensor()),
    batch_size=batch_size, shuffle=True, **kwargs)


In [4]:
def get_context_idx(N):
    """Draw N distinct flat pixel positions to serve as context points.

    Positions are sampled without replacement from 0..783 (the indices of a
    flattened 28x28 image) using Python's `random` module, so the result is
    controlled by `random.seed`.

    Returns:
        1-D integer tensor of length N, placed on the module-level `device`.
    """
    chosen = random.sample(range(784), N)
    return torch.tensor(chosen, device=device)

In [5]:
def generate_grid(h, w):
    """Return normalised (row, col) coordinates for every pixel of an h x w image.

    Entry k of the result corresponds to flat pixel index k of a row-major
    flattened image (k = r * w + c) and holds (rows[r], cols[c]), where `rows`
    and `cols` are evenly spaced over [0, 1].

    The previous construction indexed the coordinates as (cols[k // h],
    rows[k % h]), which only agrees with row-major flattening when h == w.
    For square grids the output is unchanged.

    Args:
        h: number of image rows.
        w: number of image columns.

    Returns:
        Tensor of shape (1, h * w, 2) on the module-level `device`.
    """
    rows = torch.linspace(0, 1, h, device=device)
    cols = torch.linspace(0, 1, w, device=device)
    # Row coordinate is constant along a row: r0 repeated w times, then r1, ...
    row_coord = rows.unsqueeze(1).expand(h, w).reshape(-1)
    # Column coordinate cycles through all columns once per row.
    col_coord = cols.repeat(h)
    grid = torch.stack([row_coord, col_coord], dim=1)
    # Leading singleton dimension so callers can expand it over the batch.
    return grid.unsqueeze(0)

In [6]:
def idx_to_y(idx, data):
    """Gather the entries of `data` at positions `idx` along dimension 1.

    The callers in this notebook pass flattened images of shape
    (batch, pixels, 1), so this yields the pixel intensities of the selected
    context points for every image in the batch.
    """
    return data.index_select(1, idx)


In [7]:
def idx_to_x(idx, batch_size):
    """Look up the grid coordinates of flat pixel indices, repeated per batch item.

    Selects the rows of the module-level `x_grid` at positions `idx` (along
    dimension 1) and expands the leading dimension to `batch_size`, so every
    image in the batch gets the same context coordinates. `x_grid` must have
    a leading dimension of size 1 for the expand to succeed.

    Returns:
        Tensor of shape (batch_size, len(idx), x_grid.shape[2]).
    """
    picked = x_grid.index_select(1, idx)
    return picked.expand(batch_size, -1, -1)

In [16]:
class CNP(nn.Module):
    """Encoder / mean-aggregator / decoder network over (coordinate, intensity) points.

    Each context point (2 coordinates + 1 intensity) is encoded by `h`, the
    encodings are averaged into one representation `r` per batch item, and the
    decoder `g` maps [r, target coordinate] to a mean and a scale for every
    point of the module-level `x_grid`.

    Args:
        r_dim: width of the aggregated representation r.
        z_dim: output width of the two `r_to_z_*` layers. Those layers are
            created but never called by `forward`.
    """

    def __init__(self, r_dim, z_dim):
        super(CNP, self).__init__()
        self.r_dim = r_dim
        self.z_dim = z_dim

        # Encoder h: (x1, x2, y) -> r_dim features.
        self.h_1 = nn.Linear(3, 256)
        self.h_2 = nn.Linear(256, 256)
        self.h_3 = nn.Linear(256, self.r_dim)

        # Unused by forward(); kept (and kept in this position) so the
        # parameter initialisation order and the state_dict keys stay the same.
        self.r_to_z_mean = nn.Linear(self.r_dim, self.z_dim)
        self.r_to_z_logvar = nn.Linear(self.r_dim, self.z_dim)

        # Decoder g receives [r, x_target], so its input width is r_dim + 2.
        # (It used to be z_dim + 2, which only worked while r_dim == z_dim.)
        self.g_1 = nn.Linear(self.r_dim + 2, 256)
        self.g_2 = nn.Linear(256, 256)
        self.g_3 = nn.Linear(256, 256)
        self.g_mu = nn.Linear(256, 1)
        self.g_sigma = nn.Linear(256, 1)

    def h(self, x_y):
        """Encode concatenated (x, y) points; ReLU follows every layer, including the last."""
        x_y = F.relu(self.h_1(x_y))
        x_y = F.relu(self.h_2(x_y))
        x_y = F.relu(self.h_3(x_y))
        return x_y

    def aggregate(self, r):
        """Average the per-point encodings over dimension 1 (the point dimension)."""
        return torch.mean(r, dim=1)

    def g(self, rep, x_target):
        """Decode [rep, x_target] (concatenated on dim 2) into a mean and a raw scale.

        Returns:
            (mu, log_sigma), each with a trailing dimension of 1. `log_sigma`
            is the unconstrained output of `g_sigma`; `forward` passes it
            through a softplus rather than an exponential.
        """
        hidden = torch.cat([rep, x_target], dim=2)
        hidden = F.relu(self.g_1(hidden))
        hidden = F.relu(self.g_2(hidden))
        hidden = F.relu(self.g_3(hidden))
        mu = self.g_mu(hidden)
        log_sigma = self.g_sigma(hidden)
        return mu, log_sigma

    def xy_to_r_params(self, x, y):
        """Encode every (x, y) context pair and average them into one representation.

        Despite the name, this returns the aggregated representation r itself,
        not distribution parameters.
        """
        x_y = torch.cat([x, y], dim=2)
        r_i = self.h(x_y)
        return self.aggregate(r_i)

    def forward(self, x_context, y_context, inTest=False):
        """Predict a mean and a scale for every point of `x_grid`.

        Args:
            x_context: context coordinates, concatenated with `y_context` on dim 2.
            y_context: context intensities; its first dimension is taken as the batch size.
            inTest: when True, the representation computed from the context is
                discarded and replaced by values drawn with `torch.rand`, a
                fresh draw for every target point.

        Returns:
            (mu, sigma), one value per grid point and batch item;
            sigma = 0.1 + 0.9 * softplus(raw scale), hence always above 0.1.
        """
        r_context = self.xy_to_r_params(x_context, y_context)

        # Reconstruct the whole image, including the provided context points.
        x_target = x_grid.expand(y_context.shape[0], -1, -1)

        # One copy of r per target point; the count follows the grid rather
        # than a hard-coded 784.
        r_expand = r_context.unsqueeze(1).expand(-1, x_target.shape[1], -1)

        # Replace r with a random tensor in the test phase. Created with the
        # dtype/device of r so it still matches the decoder weights when the
        # model does not live on the default device.
        if inTest:
            r_expand = torch.rand(r_expand.size(), dtype=r_expand.dtype,
                                  device=r_expand.device)

        mu, log_sigma = self.g(r_expand, x_target)

        # Bound the scale away from zero.
        sigma = 0.1 + 0.9 * F.softplus(log_sigma)

        return mu, sigma
        

In [9]:
def np_loss(mu, sigma, y):
    """Element-wise Gaussian log-likelihood of `y` under N(mu, sigma^2).

    The additive constant -0.5 * log(2 * pi) is left out. Note that this is a
    log-probability, not a loss: callers negate it before minimising.

    Returns:
        Tensor with the broadcast shape of the inputs, or None when `y` is None.
    """
    if y is None:
        return None
    residual = y - mu
    return -residual**2 / (2 * sigma**2) - sigma.log()

In [10]:
# Folder that receives the per-epoch comparison images and the saved weights.
os.makedirs("results_cnp_r_alea/", exist_ok=True)

# Coordinates of all 28 * 28 pixel positions; read as a global by idx_to_x
# and by CNP.forward.
x_grid = generate_grid(28, 28)

# Network and its optimiser.
model = CNP(r_dim, z_dim).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-4)

In [11]:
def train(epoch):
    """Run one optimisation pass over `train_loader`, printing progress.

    For every batch a random number of context points (1 to 784) is drawn and
    shared by all images of the batch; the model then predicts every pixel and
    is trained on the negated `np_loss`, summed over the batch dimension and
    averaged over the remaining dimensions.

    Args:
        epoch: epoch number, used only in the printed messages.
    """
    model.train()
    running_loss = 0
    n_samples = len(train_loader.dataset)
    n_batches = len(train_loader)

    for batch_idx, (images, _) in enumerate(train_loader):
        batch_size = images.shape[0]

        # Flatten every image to (batch, pixels, 1).
        y_all = images.to(device).view(batch_size, -1, 1)

        # Context set: how many points first, then which ones.
        N = random.randint(1, 784)
        context_idx = get_context_idx(N)
        x_context = idx_to_x(context_idx, batch_size)
        y_context = idx_to_y(context_idx, y_all)

        optimizer.zero_grad()
        mu, sigma = model(x_context, y_context)

        log_p = np_loss(mu, sigma, y_all)
        loss = -log_p.sum(dim=0).mean()
        loss.backward()
        loss_value = loss.item()
        running_loss += loss_value
        optimizer.step()

        if batch_idx % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(y_all), n_samples,
                100. * batch_idx / n_batches,
                loss_value / len(y_all)))

    print('====> Epoch: {} Average loss: {:.4f}'.format(
        epoch, running_loss / n_samples))

In [12]:
def test(epoch):
    """Evaluate on `test_loader`, print the average loss and save sample images.

    Loss: every batch uses 300 random context points and the model is run with
    `inTest=True`; the negated `np_loss` is summed over the batch dimension,
    averaged over the remaining dimensions, accumulated over batches and
    finally divided by the number of test images.

    Images (first batch only): for each context size in [10, 100, 300, 784] a
    PNG is written to results_cnp_r_alea/. With `nrow=num_examples` the grid
    holds one row of context pixels on a blue background, then 5 rows of
    predicted means and 5 rows of predicted scales, one row per forward pass.

    Args:
        epoch: epoch number, used in the image file names.
    """
    model.eval()
    test_loss = 0
    with torch.no_grad():
        for i, (y_all, _) in enumerate(test_loader):
            y_all = y_all.to(device).view(y_all.shape[0], -1, 1)
            batch_size = y_all.shape[0]

            N = 300
            context_idx = get_context_idx(N)
            x_context = idx_to_x(context_idx, batch_size)
            y_context = idx_to_y(context_idx, y_all)

            mu, sigma = model(x_context, y_context, inTest=True)
            test_loss += -np_loss(mu, sigma, y_all).sum(dim=0).mean().item()

            if i == 0:  # save PNG of reconstructed examples
                plot_Ns = [10, 100, 300, 784]
                num_examples = min(batch_size, 16)
                for N in plot_Ns:
                    mu_draws = []
                    sigma_draws = []
                    context_idx = get_context_idx(N)
                    x_context = idx_to_x(context_idx, batch_size)
                    y_context = idx_to_y(context_idx, y_all)
                    for _draw in range(5):
                        # r is a random tensor: output 5 times to see its
                        # influence. inTest=True is required for that; without
                        # it the five passes are identical in eval mode.
                        mu, sigma = model(x_context, y_context, inTest=True)
                        mu_draws.append(mu[:num_examples])
                        sigma_draws.append(sigma[:num_examples])
                    # Grey-scale maps replicated over 3 channels for the RGB grid.
                    recons = torch.cat(mu_draws).view(-1, 1, 28, 28).expand(-1, 3, -1, -1)
                    recons1 = torch.cat(sigma_draws).view(-1, 1, 28, 28).expand(-1, 3, -1, -1)
                    # Blue canvas; the context pixels are painted over it.
                    background = torch.tensor([0., 0., 1.], device=device)
                    background = background.view(1, -1, 1).expand(num_examples, 3, 784).contiguous()
                    context_pixels = y_all[:num_examples].view(num_examples, 1, -1)[:, :, context_idx]
                    context_pixels = context_pixels.expand(num_examples, 3, -1)
                    background[:, :, context_idx] = context_pixels
                    comparison = torch.cat([background.view(-1, 3, 28, 28),
                                            recons, recons1])
                    save_image(comparison.cpu(),
                               'results_cnp_r_alea/ep_' + str(epoch) +
                               '_cps_' + str(N) + '.png', nrow=num_examples)

    test_loss /= len(test_loader.dataset)
    print('====> Test set loss: {:.4f}'.format(test_loss))

In [13]:
# Main schedule: one training pass followed by one evaluation pass per epoch,
# with epochs numbered from 1 to `epochs` inclusive.
for epoch in range(1, 1 + epochs):
    train(epoch)
    test(epoch)



====> Epoch: 1 Average loss: -0.6174
====> Test set loss: -0.5830


====> Epoch: 2 Average loss: -0.7511
====> Test set loss: 5.7385




====> Epoch: 3 Average loss: -0.9127
====> Test set loss: 51.1758


====> Epoch: 4 Average loss: -1.0967
====> Test set loss: 32.3399




====> Epoch: 5 Average loss: -1.2640
====> Test set loss: 5.7438


====> Epoch: 6 Average loss: -1.3181
====> Test set loss: 3.9333




====> Epoch: 7 Average loss: -1.3485
====> Test set loss: 3.3140


====> Epoch: 8 Average loss: -1.3449
====> Test set loss: 2.8341




====> Epoch: 9 Average loss: -1.3687
====> Test set loss: 2.7720


====> Epoch: 10 Average loss: -1.3802
====> Test set loss: 2.8450




====> Epoch: 11 Average loss: -1.3924
====> Test set loss: 2.8236


====> Epoch: 12 Average loss: -1.3610
====> Test set loss: 2.6965




====> Epoch: 13 Average loss: -1.3949
====> Test set loss: 2.6757


====> Epoch: 14 Average loss: -1.3994
====> Test set loss: 2.7026




====> Epoch: 15 Average loss: -1.3996
====> Test set loss: 2.6625


====> Epoch: 16 Average loss: -1.4083
====> Test set loss: 2.6583




====> Epoch: 17 Average loss: -1.4055
====> Test set loss: 2.6756


====> Epoch: 18 Average loss: -1.4023
====> Test set loss: 2.7028


====> Epoch: 19 Average loss: -1.4093


====> Test set loss: 2.6716


====> Epoch: 20 Average loss: -1.4133
====> Test set loss: 2.6762


====> Epoch: 21 Average loss: -1.4127
====> Test set loss: 2.6669




====> Epoch: 22 Average loss: -1.4130
====> Test set loss: 2.6801


====> Epoch: 23 Average loss: -1.4232
====> Test set loss: 2.6774




====> Epoch: 24 Average loss: -1.4002
====> Test set loss: 2.6963


====> Epoch: 25 Average loss: -1.4113
====> Test set loss: 2.6943




====> Epoch: 26 Average loss: -1.4216
====> Test set loss: 2.6817


====> Epoch: 27 Average loss: -1.4203
====> Test set loss: 2.6827




====> Epoch: 28 Average loss: -1.4349
====> Test set loss: 2.6808


====> Epoch: 29 Average loss: -1.4328
====> Test set loss: 2.6851




====> Epoch: 30 Average loss: -1.4328
====> Test set loss: 2.6918


====> Epoch: 31 Average loss: -1.4438
====> Test set loss: 2.7305




====> Epoch: 32 Average loss: -1.4490
====> Test set loss: 2.6783


====> Epoch: 33 Average loss: -1.4462
====> Test set loss: 2.6750




====> Epoch: 34 Average loss: -1.4393
====> Test set loss: 2.7205


====> Epoch: 35 Average loss: -1.4419
====> Test set loss: 2.8177




====> Epoch: 36 Average loss: -1.4511
====> Test set loss: 2.6968


====> Epoch: 37 Average loss: -1.4630
====> Test set loss: 2.7491




====> Epoch: 38 Average loss: -1.4535
====> Test set loss: 2.7357


====> Epoch: 39 Average loss: -1.4459
====> Test set loss: 2.6621




====> Epoch: 40 Average loss: -1.4703
====> Test set loss: 2.6644


====> Epoch: 41 Average loss: -1.4707
====> Test set loss: 2.6643




====> Epoch: 42 Average loss: -1.4731
====> Test set loss: 2.6704


====> Epoch: 43 Average loss: -1.4790
====> Test set loss: 2.6763




====> Epoch: 44 Average loss: -1.4766
====> Test set loss: 2.6913


====> Epoch: 45 Average loss: -1.4719
====> Test set loss: 2.9099




====> Epoch: 46 Average loss: -1.4960
====> Test set loss: 2.9637


====> Epoch: 47 Average loss: -1.4905
====> Test set loss: 2.8395




====> Epoch: 48 Average loss: -1.4873
====> Test set loss: 2.7382


====> Epoch: 49 Average loss: -1.4935
====> Test set loss: 3.0393




====> Epoch: 50 Average loss: -1.4950
====> Test set loss: 2.7460


====> Epoch: 51 Average loss: -1.4909
====> Test set loss: 2.7186




====> Epoch: 52 Average loss: -1.5012
====> Test set loss: 2.9762


====> Epoch: 53 Average loss: -1.5059
====> Test set loss: 2.7152




====> Epoch: 54 Average loss: -1.5046
====> Test set loss: 2.7425


====> Epoch: 55 Average loss: -1.4888
====> Test set loss: 3.8958




====> Epoch: 56 Average loss: -1.5248
====> Test set loss: 3.1549


====> Epoch: 57 Average loss: -1.5118
====> Test set loss: 3.4175


====> Epoch: 58 Average loss: -1.4999


====> Test set loss: 4.3520


====> Epoch: 59 Average loss: -1.5103
====> Test set loss: 4.2136


====> Epoch: 60 Average loss: -1.5228
====> Test set loss: 3.6700




====> Epoch: 61 Average loss: -1.5253
====> Test set loss: 3.5526


====> Epoch: 62 Average loss: -1.5172
====> Test set loss: 3.1428




====> Epoch: 63 Average loss: -1.5176
====> Test set loss: 3.9475


====> Epoch: 64 Average loss: -1.5287
====> Test set loss: 3.0853




====> Epoch: 65 Average loss: -1.5390
====> Test set loss: 3.4927


====> Epoch: 66 Average loss: -1.5267
====> Test set loss: 2.7785




====> Epoch: 67 Average loss: -1.5304
====> Test set loss: 3.3950


====> Epoch: 68 Average loss: -1.5479
====> Test set loss: 2.9908




====> Epoch: 69 Average loss: -1.5340
====> Test set loss: 3.0399


====> Epoch: 70 Average loss: -1.5514
====> Test set loss: 3.2057




====> Epoch: 71 Average loss: -1.5476
====> Test set loss: 2.7371


====> Epoch: 72 Average loss: -1.5528
====> Test set loss: 2.8323




====> Epoch: 73 Average loss: -1.5432
====> Test set loss: 3.1027


====> Epoch: 74 Average loss: -1.5341
====> Test set loss: 2.8163




====> Epoch: 75 Average loss: -1.5452
====> Test set loss: 2.7883


====> Epoch: 76 Average loss: -1.5556
====> Test set loss: 2.7796




====> Epoch: 77 Average loss: -1.5491
====> Test set loss: 3.4349


====> Epoch: 78 Average loss: -1.5600
====> Test set loss: 4.0719




====> Epoch: 79 Average loss: -1.5656
====> Test set loss: 3.4416


====> Epoch: 80 Average loss: -1.5701
====> Test set loss: 3.1198




====> Epoch: 81 Average loss: -1.5385
====> Test set loss: 3.9374


====> Epoch: 82 Average loss: -1.5643
====> Test set loss: 4.6135




====> Epoch: 83 Average loss: -1.5645
====> Test set loss: 4.1381


====> Epoch: 84 Average loss: -1.5564
====> Test set loss: 7.1544




====> Epoch: 85 Average loss: -1.5828
====> Test set loss: 6.1912


====> Epoch: 86 Average loss: -1.5783
====> Test set loss: 4.1302




====> Epoch: 87 Average loss: -1.5850
====> Test set loss: 4.2365


====> Epoch: 88 Average loss: -1.5762
====> Test set loss: 5.2610




====> Epoch: 89 Average loss: -1.5894
====> Test set loss: 4.0344


====> Epoch: 90 Average loss: -1.5711
====> Test set loss: 4.0408




====> Epoch: 91 Average loss: -1.5765
====> Test set loss: 3.8530


====> Epoch: 92 Average loss: -1.5899
====> Test set loss: 3.2762




====> Epoch: 93 Average loss: -1.5760
====> Test set loss: 4.4246


====> Epoch: 94 Average loss: -1.5950
====> Test set loss: 3.4827




====> Epoch: 95 Average loss: -1.6013
====> Test set loss: 3.3340


====> Epoch: 96 Average loss: -1.6058
====> Test set loss: 3.2699


====> Epoch: 97 Average loss: -1.5984


====> Test set loss: 3.2031


====> Epoch: 98 Average loss: -1.6049
====> Test set loss: 3.1036


====> Epoch: 99 Average loss: -1.6114
====> Test set loss: 2.9117




====> Epoch: 100 Average loss: -1.6048
====> Test set loss: 2.9613


====> Epoch: 101 Average loss: -1.6020


====> Test set loss: 2.9042


====> Epoch: 102 Average loss: -1.5791
====> Test set loss: 2.8925




KeyboardInterrupt: 

In [15]:
# Write the model's parameters (its state_dict, not the module object) to the
# results folder.
save_path = "./results_cnp_r_alea/model_param.pt"
torch.save(model.state_dict(), f=save_path)
