In [1]:
%%capture
!pip install utm
!pip install openpyxl

In [1]:
import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib import cm
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.gaussian_process.kernels import RBF
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import StandardScaler, MinMaxScaler
from sklearn.gaussian_process import GaussianProcessRegressor
import matplotlib.pyplot as plt
import utm

import warnings
warnings.filterwarnings("ignore")

In [2]:
data = pd.read_excel("../Dataset/Maharashtra_Soil_Nutrients_Data.xlsx")
data.head()

Unnamed: 0,lon,lat,OC,N,P,K
0,73.401111,17.894722,1.08,756.0,9.43,834.37
1,73.401389,17.894722,1.12,781.2,9.21,265.1
2,73.402222,17.894722,0.68,478.8,8.99,318.96
3,73.403056,17.894722,1.76,1234.8,9.65,954.77
4,73.403333,17.894722,1.78,1247.4,8.77,371.77


In [3]:
def scaled_coord(x,y):
    """
    parameters
    ----------
    x : numpy array, float64
        list of longitude cordinates
    y : numpy array, float64
        list of latitude cordinates
        
    return
    ------
    scaled(0-1) x and y
    """
    x = (x-x.min())/(x.max()-x.min())
    y = (y-y.min())/(y.max()-y.min())
    return x,y

In [4]:
%%time

val_col = ['OC','N','P','K']
values = data[val_col]
coordinates = data[['lon','lat']]
#lat,lon to utm projection

x,y,zone,ut = utm.from_latlon(coordinates['lat'].values,coordinates['lon'].values)

lon,lat = y/1000,x/1000 #in km

# lon, lat = scaled_coord(lon,lat)
#normalize values of OC, N, K, P

#standardise lon and lat
lon = (lon-np.mean(lon))/np.std(lon)
lat = (lat-np.mean(lat))/np.std(lat)

test_k = MinMaxScaler().fit_transform(values)
values = test_k

Wall time: 14 ms


In [5]:
data['lon'] = lon
data['lat'] = lat
for i,col in enumerate(val_col):
    data[col] = values[:,i]

In [6]:
data.head()

Unnamed: 0,lon,lat,OC,N,P,K
0,-1.425536,-2.402803,0.011632,0.079661,0.001163,0.012262
1,-1.42554,-2.402303,0.012067,0.082316,0.001136,0.003893
2,-1.42555,-2.400805,0.007283,0.05045,0.001109,0.004684
3,-1.425561,-2.399307,0.019024,0.130115,0.00119,0.014032
4,-1.425565,-2.398808,0.019241,0.131443,0.001082,0.005461


In [7]:
#split dataset into train and test
# split the dataset into train and test dataset
ix = np.random.choice(data.shape[0],int(data.shape[0]*0.2),replace = False)
data_train = data.iloc[[int(i) for i in range(data.shape[0]) if i not in ix]].reset_index(drop = True)
data_test = data.iloc[ix].reset_index(drop = True)

In [8]:
data_train.shape, data_test.shape

((20837, 6), (5209, 6))

In [9]:
data_train.head()

Unnamed: 0,lon,lat,OC,N,P,K
0,-1.42555,-2.400805,0.007283,0.05045,0.001109,0.004684
1,-1.425561,-2.399307,0.019024,0.130115,0.00119,0.014032
2,-1.425565,-2.398808,0.019241,0.131443,0.001082,0.005461
3,-1.425575,-2.39731,0.016306,0.111527,0.001136,0.007821
4,-1.425579,-2.396811,0.019241,0.131443,0.001082,0.005461


In [10]:
data_test.head()

Unnamed: 0,lon,lat,OC,N,P,K
0,0.320061,0.954405,0.005109,0.012552,0.001154,0.006153
1,-0.368445,0.777145,0.00837,0.011376,0.000635,0.001833
2,2.277226,-1.176953,0.000326,0.063433,0.003633,0.010698
3,-0.953466,0.131978,0.005979,0.016392,0.000775,0.005824
4,-1.182705,0.294285,0.007066,0.04868,0.001227,0.006582


## GP Data Generator

In [11]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import collections

In [12]:
# The (A)NP takes as input a `NPRegressionDescription` namedtuple with fields:
#   `query`: a tuple containing ((context_x, context_y), target_x)
#   `target_y`: a tensor containing the ground truth for the targets to be
#     predicted
#   `num_total_points`: A vector containing a scalar that describes the total
#     number of datapoints used (context + target)
#   `num_context_points`: A vector containing a scalar that describes the number
#     of datapoints used as context
# The GPCurvesReader returns the newly sampled data in this format at each
# iteration

NPRegressionDescription = collections.namedtuple(
    "NPRegressionDescription",
    ("query", "target_y", "num_total_points", "num_context_points"))


class GPCurvesReader(object):
  """Generates curves using a Gaussian Process (GP).

  Supports vector inputs (x) and vector outputs (y). Kernel is
  mean-squared exponential, using the x-value l2 coordinate distance scaled by
  some factor chosen randomly in a range. Outputs are independent gaussian
  processes.
  """

  def __init__(self,
               batch_size,
               max_num_context,
               x_size=1,
               y_size=1,
               l1_scale=0.6,
               sigma_scale=1.0,
               random_kernel_parameters=True,
               testing=False):
    """Creates a regression dataset of functions sampled from a GP.

    Args:
      batch_size: An integer.
      max_num_context: The max number of observations in the context.
      x_size: Integer >= 1 for length of "x values" vector.
      y_size: Integer >= 1 for length of "y values" vector.
      l1_scale: Float; typical scale for kernel distance function.
      sigma_scale: Float; typical scale for variance.
      random_kernel_parameters: If `True`, the kernel parameters (l1 and sigma) 
          will be sampled uniformly within [0.1, l1_scale] and [0.1, sigma_scale].
      testing: Boolean that indicates whether we are testing. If so there are
          more targets for visualization.
    """
    self._batch_size = batch_size
    self._max_num_context = max_num_context
    self._x_size = x_size
    self._y_size = y_size
    self._l1_scale = l1_scale
    self._sigma_scale = sigma_scale
    self._random_kernel_parameters = random_kernel_parameters
    self._testing = testing

  def _gaussian_kernel(self, xdata, l1, sigma_f, sigma_noise=2e-2):
    """Applies the Gaussian kernel to generate curve data.

    Args:
      xdata: Tensor of shape [B, num_total_points, x_size] with
          the values of the x-axis data.
      l1: Tensor of shape [B, y_size, x_size], the scale
          parameter of the Gaussian kernel.
      sigma_f: Tensor of shape [B, y_size], the magnitude
          of the std.
      sigma_noise: Float, std of the noise that we add for stability.

    Returns:
      The kernel, a float tensor of shape
      [B, y_size, num_total_points, num_total_points].
    """
    num_total_points = tf.shape(xdata)[1]

    # Expand and take the difference
    xdata1 = tf.expand_dims(xdata, axis=1)  # [B, 1, num_total_points, x_size]
    xdata2 = tf.expand_dims(xdata, axis=2)  # [B, num_total_points, 1, x_size]
    diff = xdata1 - xdata2  # [B, num_total_points, num_total_points, x_size]

    # [B, y_size, num_total_points, num_total_points, x_size]
    norm = tf.square(diff[:, None, :, :, :] / l1[:, :, None, None, :])

    norm = tf.reduce_sum(
        norm, -1)  # [B, data_size, num_total_points, num_total_points]

    # [B, y_size, num_total_points, num_total_points]
    kernel = tf.square(sigma_f)[:, :, None, None] * tf.exp(-0.5 * norm)

    # Add some noise to the diagonal to make the cholesky work.
    kernel += (sigma_noise**2) * tf.eye(num_total_points)

    return kernel

  def generate_curves(self):
    """Builds the op delivering the data.

    Generated functions are `float32` with x values between -2 and 2.
    
    Returns:
      A `NPRegressionDescription` namedtuple.
    """
    num_context = tf.random.uniform(
        shape=[], minval=3, maxval=self._max_num_context, dtype=tf.int32)

    # If we are testing we want to have more targets and have them evenly
    # distributed in order to plot the function.
    if self._testing:
      num_target = 400
      num_total_points = num_target
      x_values = tf.tile(
          tf.expand_dims(tf.range(-2., 2., 1. / 100, dtype=tf.float32), axis=0),
          [self._batch_size, 1])
      x_values = tf.expand_dims(x_values, axis=-1)
    # During training the number of target points and their x-positions are
    # selected at random
    else:
      num_target = tf.random.uniform(shape=(), minval=0, 
                                     maxval=self._max_num_context - num_context,
                                     dtype=tf.int32)
      num_total_points = num_context + num_target
      x_values = tf.random.uniform(
          [self._batch_size, num_total_points, self._x_size], -2, 2)

    # Set kernel parameters
    # Either choose a set of random parameters for the mini-batch
    if self._random_kernel_parameters:
      l1 = tf.random.uniform([self._batch_size, self._y_size,
                              self._x_size], 0.1, self._l1_scale)
      sigma_f = tf.random.uniform([self._batch_size, self._y_size],
                                  0.1, self._sigma_scale)
    # Or use the same fixed parameters for all mini-batches
    else:
      l1 = tf.ones(shape=[self._batch_size, self._y_size,
                          self._x_size]) * self._l1_scale
      sigma_f = tf.ones(shape=[self._batch_size,
                               self._y_size]) * self._sigma_scale

    # Pass the x_values through the Gaussian kernel
    # [batch_size, y_size, num_total_points, num_total_points]
    kernel = self._gaussian_kernel(x_values, l1, sigma_f)

    # Calculate Cholesky, using double precision for better stability:
    cholesky = tf.cast(tf.linalg.cholesky(tf.cast(kernel, tf.float64)), tf.float32)

    # Sample a curve
    # [batch_size, y_size, num_total_points, 1]
    y_values = tf.matmul(
        cholesky,
        tf.random.normal([self._batch_size, self._y_size, num_total_points, 1]))

    # [batch_size, num_total_points, y_size]
    y_values = tf.transpose(tf.squeeze(y_values, 3), [0, 2, 1])

    if self._testing:
      # Select the targets
      target_x = x_values
      target_y = y_values

      # Select the observations
      idx = tf.random.shuffle(tf.range(num_target))
      context_x = tf.gather(x_values, idx[:num_context], axis=1)
      context_y = tf.gather(y_values, idx[:num_context], axis=1)

    else:
      # Select the targets which will consist of the context points as well as
      # some new target points
      target_x = x_values[:, :num_target + num_context, :]
      target_y = y_values[:, :num_target + num_context, :]

      # Select the observations
      context_x = x_values[:, :num_context, :]
      context_y = y_values[:, :num_context, :]
#       print(context_x)
#       print(context_y.shape)
#       print(target_x.shape)
    query = ((context_x, context_y), target_x)
#     print(query.shape)
    return NPRegressionDescription(
        query=query,
        target_y=target_y,
        num_total_points=tf.shape(target_x)[1],
        num_context_points=num_context)

In [116]:
TRAINING_ITERATIONS = 100000 #@param {type:"number"}
MAX_CONTEXT_POINTS = 50 #@param {type:"number"}
PLOT_AFTER = 10000 #@param {type:"number"}
HIDDEN_SIZE = 128 #@param {type:"number"}
MODEL_TYPE = 'NP' #@param ['NP','ANP']
ATTENTION_TYPE = 'uniform' #@param ['uniform','laplace','dot_product','multihead']
random_kernel_parameters=True #@param {type:"boolean"}

# tf.reset_default_graph()
# Train dataset
dataset_train = GPCurvesReader(
    batch_size=16, max_num_context=MAX_CONTEXT_POINTS, random_kernel_parameters=random_kernel_parameters)
data_train = dataset_train.generate_curves()

# Test dataset
dataset_test = GPCurvesReader(
    batch_size=1, max_num_context=MAX_CONTEXT_POINTS, testing=True, random_kernel_parameters=random_kernel_parameters)
data_test = dataset_test.generate_curves()

In [117]:
#training data generator
(c_x,c_y),t_x = data_train.query
t_y = data_train.target_y

In [118]:
c_x = c_x.numpy()
c_x.shape

(16, 11, 1)

In [119]:
c_y = c_y.numpy()
c_y.shape

(16, 11, 1)

In [120]:
t_x = t_x.numpy()
t_x.shape

(16, 21, 1)

In [121]:
t_y = t_y.numpy()
t_y.shape

(16, 21, 1)

In [122]:
# t_y.shape

In [123]:
x = torch.from_numpy(c_x)
y = torch.from_numpy(c_y)
z = torch.cat([x, y], dim= -1)
z.shape

torch.Size([16, 11, 2])

## Data loading in torch.Dataloader

In [124]:
class NutrientsDataset(Dataset):
    def __init__(self, df, num_context=40, num_extra_target=10):
        self.df = df
        self.num_context = num_context
        self.num_extra_target = num_extra_target

    def get_rows(self, i):
        rows = self.df.iloc[i : i + (self.num_context + self.num_extra_target)].copy()
        x = rows.iloc[:,:2].copy()
        y = rows.iloc[:,2:].copy()
        return x, y


    def __getitem__(self, i):
        x, y = self.get_rows(i)
        return x.values, y.values
        
    def __len__(self):
        return len(self.df) - (self.num_context + self.num_extra_target)

In [125]:
def npsample_batch(x, y, size=None, sort=False):
    
    """Sample from numpy arrays along 2nd dim."""
    inds = np.random.choice(range(x.shape[1]), size=size, replace=False)
    return x[:, inds], y[:, inds]

def collate_fns(max_num_context, max_num_extra_target, sample, sort=True, context_in_target=True):
    def collate_fn(batch, sample=sample):
        # Collate
        x = np.stack([x for x, y in batch], 0)
        y = np.stack([y for x, y in batch], 0)

        # Sample a subset of random size
        num_context = np.random.randint(4, max_num_context)
        num_extra_target = np.random.randint(4, max_num_extra_target)

        x = torch.from_numpy(x).float()
        y = torch.from_numpy(y).float()

        
        x_context = x[:, :max_num_context]
        y_context = y[:, :max_num_context]
    
        x_target_extra = x[:, max_num_context:]
        y_target_extra = y[:, max_num_context:]
        
        if sample:

            x_context, y_context = npsample_batch(
                x_context, y_context, size=num_context
            )

            x_target_extra, y_target_extra = npsample_batch(
                x_target_extra, y_target_extra, size=num_extra_target, sort=sort
            )

        # do we want to compute loss over context+target_extra, or focus in on only target_extra?
        if context_in_target:
            x_target = torch.cat([x_context, x_target_extra], 1)
            y_target = torch.cat([y_context, y_target_extra], 1)
        else:
            x_target = x_target_extra
            y_target = y_target_extra

        
        return x_context, y_context, x_target, y_target

    return collate_fn

In [126]:
hparamas = dict(num_context = 15,
               num_extra_target = 16,
               batch_size = 40,
               context_in_target = False)
train_df = NutrientsDataset(data_train,hparamas['num_context'],hparamas['num_extra_target'])

train_loader = DataLoader(data_train,
                          batch_size=hparamas['batch_size'],
                         shuffle = True,
                         collate_fn=collate_fns(
                             hparamas['num_context'],hparamas['num_extra_target'], True,hparamas['context_in_target']))

 ## NP Model

In [127]:
class baseNPBlock(nn.Module):
    """relu non-linearities for NP block"""
    def __init__(self, inp_size,op_size, norm, bias = False, p = 0):
        """init function for linear2d class
        
        parameters
        ----------
        inp_size : int
                input dimension for the Encoder part (d_in)
        op_size : int
                output dimension for Encoder part(d_out)
        norm : str
                normalization to be applied on linear output
                pass norm == 'batch' to apply batch normalization
                else dropout normalization is applied
        bias : bool
                if True, bias is included for linear layer else discarded
        p : float
                probality to be considered while applying Dropout regularization
                
        """
        super().__init__()
        self.norm = norm
        self.linear = nn.Linear(inp_size,op_size,bias = bias)
        self.relu  = nn.ReLU()
        self.batch_norm = nn.BatchNorm2d(op_size)
        self.dropout = nn.Dropout2d(p)
        
    def forward(self,x):
        x = self.linear(x)
        x = self.batch_norm(x.permute(0,2,1)[:,:,:,None]) if self.norm == 'batch' else self.dropout(x.permute(0,2,1)[:,:,:,None])
        
        x = self.relu(x[:,:,:,0].permute(0,2,1))
        return x

In [128]:
class batch_MLP(nn.Module):
    """ Batch MLP layer for NP-Encoder"""
    def __init__(self, in_size, op_size, num_layers, norm, p = 0):
        """init function for linear2d class
        
        parameters
        ----------
        inp_size : int
                input dimension for the Encoder part (d_in)
        op_size : int
                output dimension for Encoder part(d_out)
        norm : str
                normalization to be applied on linear output
                pass norm == 'batch' to apply batch normalization
                else dropout normalization is applied
                
        return torch.tensor of size (B,num_context_points,d_out)
        """
        super().__init__()
        self.in_size = in_size
        self.op_size = op_size
        self.num_layers = num_layers
        self.norm  = norm
        
        self.first_layer = baseNPBlock(in_size, op_size, self.norm, False,p)
        self.encoder = nn.Sequential(*[batch_MLP(op_size, op_size, self.norm, False, p) for layer in range(self.num_layers-2)])
        self.last_layer = nn.ReLU()
        
    def forward(self, x):
        x = self.first_layer(x)
        x = self.encoder(x)
        x = self.last_layer(x)
        
        return x

In [129]:
class LinearAttention(nn.Module):
    def __init__(self,in_ch, out_ch):
        super().__init__()
        self.linear = nn.Linear(in_ch, out_ch, bias = False)
        torch.nn.init.normal_(self.linear.weight,std = in_ch**0.5) #initilize weight matrix
        
    def forward(self,x):
        return self.linear(x)
    
    
class AttentionModule(nn.Module):
    def __init__(
        self,
        hidden_dim, 
        attn_type , 
        attn_layers,
        x_dim, 
        rep='mlp',
        n_multiheads = 8,
        norm = 'dropout',
        p = 0):
        
        super().__init__()
        self.rep = rep
#         self.n_multiheads = n_multiheads
        # rep determines whether raw input given to the model would be used as key and query or
        # it's output through MLP. 
        if self.rep =='mlp':
            
            #Both Key and Value needs to have same dimension
            self.batch_mlpk = batch_MLP(x_dim, hidden_dim, attn_layers, norm ,p)
            self.batch_mlpq = batch_MLP(x_dim, hidden_dim, attn_layers, norm, p)
        
        
        if attn_type == 'uniform':
            self.attn_func = self.uniform_attn
        if attn_type=='laplace':
            self.attn_func = self.laplace_attn
        if attn_type == 'dot':
            self.attn_func = self.dot_attn
        elif attn_type == 'multihead':
            self.w_k = nn.ModuleList([LinearAttention(hidden_dim,hidden_dim) for head in range(n_multiheads)])
            self.w_v = nn.ModuleList([LinearAttention(hidden_dim,hidden_dim) for head in range(n_multiheads)])
            self.w_q = nn.ModuleList([LinearAttention(hidden_dim,hidden_dim) for head in range(n_multiheads)])
            
            self.w = LinearAttention(hidden_dim*n_multiheads,hidden_dim)
            self.attn_func = self.multihead_attn
            self.num_heads = n_multiheads
            
            
            
    def forward(self, k, q, v):
        if self.rep =='mlp':
            k = self.batch_mlpk(k) #(B, n, H)
            q = self.batch_mlpq(q) #(B, m, H)
        
        rep = self.attn_func(k,q,v)
        
        return rep
    
    
    def uniform_attn(self, k, q, v):
        num_points = q.shape[1]
        rep = torch.mean(v, axis = 1, keepdim = True)
        rep = rep.repeat(1,num_points,1)
        
        return rep
    
    def laplace_attn(self, k, q, v, scale = 0.5):
        k = k.unsqueeze(1)
        v = v.unsqueeze(2)
        
        w = torch.abs((k-v)*scale)
        w = w.sum(dim = -1)
        weight = torch.softmax(w, dim = -1)
        
        #batch matrix multiplication (einstein summation convention for tensor)
        rep = torch.einsum("bik, bkj -> bij",weight, v)
        
        return rep
    
    
    def dot_product_attn(self, k, q, v):
#         print("k =",k.shape)
#         print("q =",q.shape)
#         print("v =",v.shape)    
        β = q.shape[-1]**0.5
        w_unnorm = torch.einsum('bjk,bik->bij', k, q)/β
#         print("w_unnorm =",w_unnorm.shape)
        
        weight = torch.softmax(w_unnorm, dim = -1)
        rep = torch.einsum("bik, bkj -> bij",weight, v)
#         print("rep =",rep.shape)
        return rep
    
    def multihead_attn(self, k , q, v):
        outs = []
        
        for i in range(self.num_heads):
            k = self.w_k[i](k) #(B, n, H)
#             print("k =",k.shape)
            q = self.w_q[i](q) #(B, m, H)
#             print("q =",q.shape)
            v = self.w_v[i](v) #(B, n, H)
#             print("v =",v.shape)
            out = self.dot_product_attn(k, q, v)
            outs.append(out)
            
        outs = torch.stack(outs, dim = -1) #(B, m, H, n_heads)
#         print("outs dim =", outs.shape)
        outs = outs.view(outs.shape[0], outs.shape[1], -1) #(B, m, n_heads*H)
#         print("outs shape =",outs.shape)
        rep = self.w(outs) #(B, m, H)
        
        return rep
    
    

In [130]:
# AttentionModule?

In [131]:
class DeterministicEncoder(nn.Module):
    def __init__(
                self,
                in_dim,
                x_dim,
                norm = 'dropout',
                hidden_dim = 32,
                encoder_layer = 2,
                self_attn_type ='dot',
                cross_attn_type ='dot',
                p_encoder = 0,
                p_attention = 0,
                attn_layers = 2,
                use_self_attn = False
                ):
        super().__init__()
        
        self.use_self_attn = use_self_attn
        
        self.encoder = batch_MLP(in_dim, hidden_dim, encoder_layer,norm, p_encoder)
        
        if self.use_self_attn:
            self.self_attn = AttentionModule(hidden_dim, self_attn_type, attn_layers,x_dim, rep = 'mlp',norm = norm, p = p_attention)
            
        self.cross_attn = AttentionModule(hidden_dim, cross_attn_type, attn_layers, x_dim)
        
    
    def forward(self, context_x, context_y, target_x):
        #concatenate context_x, context_y along the last dim.
        det_enc_in = torch.cat([context_x, context_y], dim = -1)
        
        det_encoded = self.encoder(det_enc_in) #(B, n, hd)
        
        if self.use_self_attn:
            det_encoded = self.self_attn(det_encoded, det_encoded, det_encoded)
            
        h = self.cross_attn(context_x, target_x, det_encoded)
        
        return h
        
        
    
        
        

In [132]:
class LatentEncoder(nn.Module):
    def __init__(self,
                in_dim,
                hidden_dim = 32,
                latent_dim = 32,
                self_attn_type = 'dot',
                encoder_layer = 3,
                min_std = 0.01,
                norm = 'dropout',
                p_encoder = 0,
                p_attn = 0,
                use_self_attn = False,
                attn_layers = 2,
                ):
        
        super().__init__()
        
        self._use_attn = use_self_attn
        
        self.encoder = batch_MLP(in_dim, hidden_dim, encoder_layer,norm, p_encoder)
        
        if self._use_attn:
            self.self_attn = AttentionModule(hidden_dim, self_attn_type, attn_layers,x_dim, rep = 'identity',norm = norm, p = p_attention)
        
        self.secondlast_layer = nn.Linear(hidden_dim, hidden_dim)
        self.mean = nn.Linear(hidden_dim, latent_dim)
        self.l_sigma = nn.Linear(hidden_dim, hidden_dim) 
        self.min_std = min_std
#         self.use_lvar = use_lvar
        self.use_attn = use_self_attn
        
        
        
    def forward(self,x,y):
        encoder_inp = torch.cat([x,y], dim = -1) 
        
        encoded_op = self.encoder(encoder_inp)#(B, n, hd)
#         print("encoder_op shape = ",encoded_op.shape)
        if self.use_attn:
            encoded_op = self.self_attn(encoded_op, encoded_op, encoded_op) #(B, n, hd)
            
        
        mean_val = torch.mean(encoded_op, dim = 1) #mean aggregation (B, hd)
        
        #further MLP layer that maps parameters to gaussian latent
        mean_repr = torch.relu(self.secondlast_layer(mean_val)) #(B, hd)
        
        μ = self.mean(mean_repr) # (B, ld)
#         print("mean = ", μ.shape)
        log_scale = self.l_sigma(mean_repr) #(B, ld)
        
        #to avoid mode collapse
        σ = self.min_std + (1-self.min_std)*torch.sigmoid(log_scale*0.5) #(b, ld)
#         print(σ)
        dist = torch.distributions.Normal(μ, σ)
        
        return dist
        
        
            

In [142]:
class Decoder(nn.Module):
    def __init__(self,
                 x_dim,
                 y_dim,
                 hidden_dim = 32,
                 latent_dim = 32,
                 n_decoder_layer = 3,
                 use_deterministic_path = True,
                 min_std = 0.01,
                 norm = 'dropout',
                 dropout_p = 0,
                ):
        super().__init__()
        
        self.norm = norm
        self.target_transform = nn.Linear(x_dim, hidden_dim)
        
        if use_deterministic_path:
            hidden_dim_2 = 2 * hidden_dim + latent_dim
        else:
            hidden_dim_2 = hidden_dim + latent_dim
            
        self.decoder = batch_MLP(hidden_dim_2, hidden_dim_2, n_decoder_layer, norm, dropout_p)
        
        self.mean = nn.Linear(hidden_dim_2, y_dim)
        self.std = nn.Linear(hidden_dim_2, y_dim)
        self.deterministic_path = use_deterministic_path
        self.min_std = min_std
        
        
    def forward(self, r, z, t_x):
        x = self.target_transform(t_x)
        
        if self.deterministic_path:
            z = torch.cat([r,z], dim = -1)
#             print("z.shape =", z.shape)
        r = torch.cat([z,x], dim = -1)
        
        r = self.decoder(r)
        
        mean = self.mean(r)
        log_sigma = self.std(r)
        
        #clamp sigmad
        sigma = self.min_std + (1 - self.min_std) * F.softplus(log_sigma)
        
        dist = torch.distributions.Normal(mean,sigma)
        
        return dist

In [153]:
class LatentModel(nn.Module):
    def __init__(self,
               x_dim,
               y_dim,
               hidden_dim = 32,
               latent_dim = 32,
               latent_self_attn_type = 'multihead',
                det_self_attn_type = 'multihead',
                det_cross_attn_type = 'multihead',
               n_lat_enc_layer = 2,
               n_det_enc_layer = 2,
               n_decoder_layer = 2,
               use_deterministic_enc = False,
               min_std = 0.01,
               p_drop = 0,
               norm = 'dropout',
               p_attn_drop = 0,
               attn_layers = 2,
               use_self_attn = False,
               context_in_target = True,
                training = False):
        
        super().__init__()
        self.laten_encoder = LatentEncoder(x_dim+y_dim,
                                           hidden_dim=hidden_dim,
                                           latent_dim=latent_dim,
                                           self_attn_type=latent_self_attn_type,
                                           encoder_layer=n_lat_enc_layer,
                                           min_std=min_std,
                                           norm = norm,
                                           p_encoder=p_drop,
                                           p_attn=p_attn_drop,
                                           use_self_attn=use_self_attn,
                                           attn_layers=attn_layers 
                                          )
        self.deterministic_encoder = DeterministicEncoder(x_dim+y_dim,
                                                          x_dim,
                                                          norm = norm,
                                                          hidden_dim=hidden_dim,
                                                          encoder_layer=n_det_enc_layer,
                                                          self_attn_type=det_self_attn_type,
                                                          cross_attn_type=det_cross_attn_type,
                                                          p_encoder=p_drop,
                                                          p_attention=p_attn_drop,
                                                          attn_layers=attn_layers,
                                                          use_self_attn=use_self_attn
                                                         )
        self.decoder = Decoder(x_dim,
                              y_dim,
                              hidden_dim  = hidden_dim,
                              latent_dim=latent_dim,
                              n_decoder_layer=n_decoder_layer,
                              use_deterministic_path=use_deterministic_enc,
                              min_std=min_std,
                              norm=norm,
                              dropout_p=p_drop
                              )
        self.use_deterministic_enc = use_deterministic_enc
        self.context_in_target = context_in_target
        self.training = training
        
        
    def forward(self, c_x, c_y, t_x, t_y = None):
        dist_prior = self.laten_encoder(c_x, c_y)

        if t_y is not None:
            dist_posterior = self.laten_encoder(t_x, t_y)
            z = dist_posterior.loc
        else:
            z = dist_prior.loc
            
        n_target = t_x.shape[1]
        z = z.unsqueeze(1).repeat(1, n_target,1) #(B, n_target, L)
        
        if self.use_deterministic_enc:
            r = self.deterministic_encoder(c_x, c_y, t_x) #(B, n_target=m, H)
#             print(r.shape)
        else:
            r = None
            
        dist = self.decoder(r, z, t_x)
        
        #at test time, target y is not Known so we return None
        if t_y is not None:
            log_p = dist.log_prob(t_y).mean(-1)
            kl_loss = torch.distributions.kl_divergence(dist_posterior, dist_prior).mean(-1)
            kl_loss = kl_loss[:,None].expand(log_p.shape)
            loss = (kl_loss-log_p).mean()
            mse_loss = F.mse_loss(dist.loc, t_y, reduction = 'none')[:,:c_x.size(1)].mean()
        else:
            kl_loss  =None
            log_p = None
            mse_loss = None
            loss = None
            
        y_pred = dist.rsample() if self.training else dist.loc
            
        return y_pred,  dict(loss = loss, loss_p = log_p, loss_kl = kl_loss, loss_mse = mse_loss), dist



In [156]:
cx = torch.from_numpy(c_x)
cy = torch.from_numpy(c_y)
tx = torch.from_numpy(t_x)
ty = torch.from_numpy(t_y)


In [146]:
As = torch.randn(3,2,5)
Bs = torch.randn(3,5,4)
