In [None]:
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.distributions as D
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.nn.utils.rnn import pad_sequence

from models import LocallyWeightedCNP, CNP

In [None]:
# Prefer the first CUDA device when one is visible; otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device.type)

# ---

# Seed both random streams used by this notebook (NumPy and PyTorch).
np.random.seed(42)
torch.manual_seed(42)

In [None]:
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split

def plot(data):
    """Draw every trajectory in ``data`` on the current matplotlib axes.

    Args:
        data: tensor (or array) whose first axis indexes trajectories and whose
            second axis indexes time steps. It may live on any device and may
            track gradients.

    All trajectories share one x-axis of ``data.shape[1]`` evenly spaced points
    on [0, 1].
    """
    # matplotlib converts its inputs to NumPy, which fails for CUDA tensors and
    # for tensors that require grad, so detach and copy to host memory first.
    trajectories = torch.as_tensor(data).detach().cpu()
    x = torch.linspace(0, 1, trajectories.shape[1]).numpy()
    for traj in trajectories:
        plt.plot(x, traj.numpy())

def cx_sigm(n=10, t_max=200):
    """Generate ``n`` sigmoid trajectories sampled at ``t_max`` points on [0, 1].

    Each trajectory is ``1 / (1 + exp(-c1 * (t - 0.5)))`` where the steepness
    ``c1`` has a magnitude drawn uniformly from [7, 30) and a random sign.

    Returns:
        Tensor of shape ``(n, t_max)``.
    """
    midpoint = 0.5

    # Steepness: the magnitude comes from torch's RNG, the sign from NumPy's RNG.
    magnitude = torch.rand(n, 1) * 23 + 7
    sign = torch.from_numpy(np.random.choice([-1, 1], (n, 1)))
    steepness = magnitude * sign

    time = torch.linspace(0, 1, t_max)
    exponent = -steepness * (time - midpoint)
    return 1 / (1 + torch.exp(exponent))


# n trajectories in total, each sampled at t_max time steps.
n, t_max = 51, 200

# Trajectories reshaped to (n, t_max, 1); the last one is set aside as y_test.
data = cx_sigm(n, t_max).view(n, t_max, 1).to(device)
y_test = data[-1]

# The remaining n-1 trajectories all share the same time grid on [0, 1].
time_grid = torch.linspace(0, 1, t_max).repeat(n - 1, 1).view(n - 1, t_max, 1).to(device)
x_train, x_val, y_train, y_val = train_test_split(time_grid, data[:-1], train_size=0.8)

for split in (x_train, y_train, x_val, y_val):
    print(split.shape)

In [None]:
# NOTE(review): n_max is not referenced in the visible cells - confirm it is still needed.
n_max = 10

# Dataset dimensions, read off the training tensors.
num_trajs, t_steps = x_train.shape[:2]
d_x = x_train.shape[-1]
d_y = y_train.shape[-1]

# Model and optimizer.
learning_rate = 1e-4
model = LocallyWeightedCNP((d_x, d_y)).to(device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

def predict_model(observations, x_target, plot=True, step=-1):
    """Run the model without gradients and optionally save a diagnostic plot.

    Args:
        observations: context tensor indexed as ``[batch, point, feature]``, with
            the x features first and the y features after them.
        x_target: query inputs indexed as ``[batch, point, feature]``.
        plot: when true, save one figure per output dimension to
            ``out/{step}_{i}_val.png``.
        step: label used in the figure file names.

    Returns:
        ``(y_pred, std_pred)`` exactly as returned by ``model``.
    """
    # Callers may build their inputs on the CPU while the model lives on `device`.
    observations = observations.to(device)
    x_target = x_target.to(device)

    with torch.no_grad():
        y_pred, std_pred = model(observations, x_target)

    if plot:
        os.makedirs('out', exist_ok=True)  # savefig does not create missing directories
        idx = 0  # only the first batch element is drawn
        # Plot against the inputs that were actually queried.
        x_plot = x_target[idx, :, 0].cpu()
        # Only the first and last observation rows are marked on the figure.
        obs_x = observations[idx, [0, -1], 0].cpu()

        for i in range(d_y):  # one figure per output dimension
            plt.figure(figsize=(5, 5))
            # Training trajectories as a faint background for reference.
            for j in range(num_trajs):
                plt.plot(x_train[j, :, 0].cpu(), y_train[j, :, i].cpu(), c="b", alpha=0.1)

            mean_i = y_pred[idx, :, i].cpu()
            plt.plot(x_plot, mean_i, color='black')
            plt.errorbar(x_plot, mean_i, yerr=std_pred[idx, :, i].cpu(), color='black', alpha=0.4)
            plt.scatter(obs_x, observations[idx, [0, -1], d_x+i].cpu(), marker="X", color='black')
            plt.savefig(f'out/{step}_{i}_val.png')
            plt.close()
    return y_pred, std_pred


def get_training_sample(batch_size=1, context_max=10, target_max=10):
    """Sample a batch of random context/target sets from the training trajectories.

    For each batch element one trajectory is picked uniformly from
    ``x_train`` / ``y_train`` and its time steps are shuffled. The first
    ``n_context`` shuffled points form the context and the first
    ``n_context + n_target`` points form the target, so the target set always
    contains the context set.

    Args:
        batch_size: number of (context, target) pairs to draw.
        context_max: exclusive upper bound on the number of context points;
            ``n_context`` is drawn from ``[1, context_max)``.
        target_max: exclusive upper bound on the number of extra target points;
            ``n_target`` is drawn from ``[1, target_max)``.

    Returns:
        ``(context_all, target_all, context_mask, target_mask)``. The first two
        are zero-padded tensors indexed ``[batch, point, x-and-y features]``;
        the masks are indexed ``[batch, point]`` and hold 1 for real points and
        0 for padding. Masks are created on the same device as the data.
    """
    contexts, targets = [], []
    context_masks, target_masks = [], []

    for _ in range(batch_size):
        # torch.randint excludes the upper bound.
        n_context = int(torch.randint(1, context_max, ()))
        n_target = int(torch.randint(1, target_max, ()))

        traj_id = np.random.choice(num_trajs)
        traj_xy = torch.cat([x_train[traj_id], y_train[traj_id]], dim=-1)

        order = torch.randperm(traj_xy.shape[0])
        context = traj_xy[order[:n_context]]
        target = traj_xy[order[:n_context + n_target]]

        contexts.append(context)
        targets.append(target)
        # Keep the masks on the data's device so they can be combined with it directly.
        context_masks.append(torch.ones(context.shape[0], device=context.device))
        target_masks.append(torch.ones(target.shape[0], device=target.device))

    context_all = pad_sequence(contexts, batch_first=True)
    target_all = pad_sequence(targets, batch_first=True)
    context_mask = pad_sequence(context_masks, batch_first=True)
    target_mask = pad_sequence(target_masks, batch_first=True)
    return context_all, target_all, context_mask, target_mask

In [None]:
# Sanity check before training: draw one batch and compare the untrained model's
# predictions with the sampled context/target points of the first batch element.
x_t, y_t, x_mask, y_mask = get_training_sample(5)
with torch.no_grad():
    x_m, x_s = model(x_t, y_t[..., :1])

# matplotlib cannot read CUDA tensors, so copy everything to the CPU first.
# The masks drop the zero-padding rows added by pad_sequence, which would
# otherwise show up as spurious points at (0, 0).
ctx_valid = x_mask[0].bool().cpu()
tgt_valid = y_mask[0].bool().cpu()
ctx_pts = x_t[0].cpu()[ctx_valid]
tgt_pts = y_t[0].cpu()[tgt_valid]
init_preds = x_m[0].cpu()[tgt_valid]

plt.scatter(ctx_pts[:, 0], ctx_pts[:, 1], marker="o", c="b", alpha=0.7, label="context points")
plt.scatter(tgt_pts[:, 0], tgt_pts[:, 1], marker="x", c="r", alpha=0.7, label="target points")
plt.scatter(tgt_pts[:, 0], init_preds[:, 0], marker="x", c="c", alpha=0.7, label="initial preds")
plt.legend()

In [None]:
# Training loop: one random sample per step, with periodic loss logging,
# validation-based checkpointing and diagnostic plots written to out/.
smooth_losses = [0]
losses = []
loss_checkpoint = 1000            # record the training loss every N steps
loss_inform_checkpoint = 100000   # print a progress line every N steps
plot_checkpoint = 10000           # save loss curves and prediction plots every N steps
validation_checkpoint = 1000      # evaluate on the validation split every N steps
validation_error = float('inf')   # best validation loss seen so far

# NOTE(review): val_observation and nof_val_trajs are not used in the visible cells.
val_observation = torch.zeros(2, 1, d_x+d_y, device=device)
nof_val_trajs = x_val.shape[0]

# The validation context never changes, so build it once: only the first and
# last time step of every validation trajectory are given to the model.
xy_val = torch.cat([x_val, y_val], dim=-1)
val_context = xy_val[:, [0, -1]]

# (y at x=0, y at x=1) endpoint pairs used for the diagnostic prediction plots.
# The two-column observation rows below assume d_x == d_y == 1.
endpoint_cases = [(1.0, 0.0), (0.0, 1.0), (1.0, 1.0), (0.0, 0.0)]

os.makedirs('out', exist_ok=True)  # savefig does not create missing directories

for step in range(5000000):
    x_t, y_t, x_m, y_m = get_training_sample(1)

    optimizer.zero_grad()
    loss = model.nll_loss(x_t, y_t[..., :1], y_t[..., 1:])
    loss.backward()
    optimizer.step()

    if step % loss_inform_checkpoint == 0:
        print(f'Step: {step}')

    if step % loss_checkpoint == 0:
        # Store plain floats so no tensors (or their graphs) are kept alive.
        loss_value = loss.item()
        losses.append(loss_value)
        smooth_losses[-1] += loss_value/(plot_checkpoint/loss_checkpoint)

    if step % validation_checkpoint == 0:
        # Validation needs no gradients; skip building the autograd graph.
        with torch.no_grad():
            current_error = model.nll_loss(val_context, x_val, y_val).item()
        if current_error < validation_error:
            validation_error = current_error
            torch.save(model.state_dict(), 'cnmp_best_validation.h5')
            print(f'Step: {step}. New validation best. Error: {current_error}')

    if step % plot_checkpoint == 0:
        fig, (ax_raw, ax_smooth) = plt.subplots(1, 2, figsize=(15, 5))
        ax_raw.set_title('Train Loss')
        ax_raw.plot(range(len(losses)), losses)
        ax_smooth.set_title('Train Loss (Smoothed)')
        ax_smooth.plot(range(len(smooth_losses)), smooth_losses)
        fig.savefig(f'out/{step}.png')
        plt.close(fig)

        # Plot validation cases; each case is saved under step + offset.
        for offset, (y_start, y_end) in enumerate(endpoint_cases):
            # Build the observations on the model's device to avoid a CPU/GPU mismatch.
            observation = torch.tensor([[[0.0, y_start], [1.0, y_end]]], device=device)
            predict_model(observation, x_val[:1], plot=True, step=step+offset)

        if step != 0:
            # Start a new bucket for the next window of smoothed losses.
            smooth_losses.append(0)
# print('Finished Training')