In [5]:
import numpy as np
import torch 

import matplotlib.pyplot as plt
%matplotlib inline

import torch, torch.nn as nn
import torch.utils.data
import time 
from IPython import  display

$w = h_\alpha(z)$

$p(w|\alpha) = \int_z p(z)p(w|z, \alpha) dz = \int_z p(z)\delta_{h_\alpha(z)-w_0} dz$

In [13]:
#Data generator

import torch.distributions as dist

class ShiftedExponential(dist.Exponential):
    def __init__(self, rate,shift, validate_args=None):
        self.shift = shift
        super(ShiftedExponential, self).__init__(rate, validate_args=None)
    def rsample(self, sample_shape=torch.Size()):
        s = super(ShiftedExponential, self).rsample(sample_shape)
        return s+self.shift

def get_dist(mean,var):
    l = dist.Laplace(mean,(var/2)**(1/2))
    u = dist.Uniform(mean-(3*var)**(1/2) ,mean+(3*var)**0.5)
    n = dist.Normal(mean,var**(1/2))
    e = ShiftedExponential((1/var)**(1/2),mean-var**(1/2))
    dists = [l,u,n,e]
    d_type = int(torch.randint(4,torch.Size([1])))
    return dists[d_type],d_type

def create_samples(num_samples,mean,var,cls = 0):
    l = dist.Laplace(mean,(var/2)**(1/2))
    u = dist.Uniform(mean-(3*var)**(1/2) ,mean+(3*var)**0.5)
    n = dist.Normal(mean,var**(1/2))
    e = ShiftedExponential((1/var)**(1/2),mean-var**(1/2))
    dists = [l,u,n,e]
    return dists[cls].sample(torch.Size([num_samples]))

def create_random_samples(num_dsets,num_samples):
    mean_dist = dist.Uniform(-0.1 ,0.1)
    var_dist = dist.Uniform(0.5 ,0.6)
    means = mean_dist.sample(torch.Size([num_dsets]))
    vars = var_dist.sample(torch.Size([num_dsets]))
    dsets = []
    dtypes = []
    for mean,var in zip(means,vars):
        current_dist, d_type = get_dist(mean,var)
        dset = current_dist.sample(torch.Size([num_samples]))
        dsets += dset,
        dtypes += d_type,
    return torch.stack(dsets), torch.tensor(dtypes)

num_dsets = 10000
num_samples = 200
train_size = num_dsets*num_samples
X,y = create_random_samples(num_dsets,num_samples)
X = X.unsqueeze(-1)
bs = 1

train_loader = torch.utils.data.DataLoader(\
            torch.utils.data.TensorDataset(\
            *(torch.Tensor(X),y)),\
            batch_size=bs,shuffle=False)



In [30]:
class fc_block(nn.Module):
    def __init__(self, dim_in, dim_out):
        
        super(fc_block, self).__init__()
        self.fc = nn.Sequential(nn.Linear(dim_in, dim_out), 
                                nn.LayerNorm(dim_out),
                                nn.ReLU())
                                
    def forward(self, C):

        return self.fc(C) 

class Decoder(nn.Module):
    def __init__(self, in_z, in_x, dim_middle, dim_out, n_layers):
        """ 
        nn that maps from latent space Z and the sample X to the target variable Y
        
        """
        super(Decoder, self).__init__()
        
#         in_z = 0
        self.fc0 = nn.Linear(in_z+in_x, dim_middle) 
        self.fc1 = fc_block(dim_middle, dim_middle)
                
        self.layer_stack = nn.ModuleList([
            fc_block(dim_middle, dim_middle)
            for _ in range(n_layers)])
        
        self.fc_last_mu = nn.Linear(dim_middle, dim_out)
        self.fc_last_sigma = nn.Linear(dim_middle, dim_out)

        self.sigmoid = nn.Sigmoid()
        self.relu = nn.ReLU()
    
 
    def forward(self, X, z):
        """
        takes x and latent z and produces target y
        """
  
        
        c = torch.cat([X, z], dim=-1)
#         c = X
        if verbose: print('c', c.shape)
        c = self.relu(self.fc0(c))
    
        c, c_prev = self.fc1(c), c
        if verbose: print('c', c.shape)

        for fc in self.layer_stack:
            c, c_prev = fc(c+c_prev), c

        if verbose: print('c', c.shape)
            
        self.mu_y = self.fc_last_mu(c)
        self.sigma_y = self.sigmoid(self.fc_last_sigma(c))*0.1+0.001


        return self.mu_y, self.sigma_y
    
class Encoder(nn.Module):
    def __init__(self, dim_in, dim_middle, dim_out, n_layers):
        """ 
        nn that maps dataset to the latent z that represent it      
        """
        super(Encoder, self).__init__()
        
        # пока самая простая 
        self.fc1 = nn.Linear(dim_in, dim_middle)
        self.relu = nn.ReLU()
        
        self.layer_stack = nn.ModuleList([
            fc_block(dim_middle, dim_middle)
            for _ in range(n_layers)])
            
        self.fc_last_mu = nn.Linear(dim_middle, dim_out)
        self.fc_last_sigma = nn.Linear(dim_middle, dim_out)
        
 
    def forward(self, S):
        """
        takes dataset S and produces mu and sigma of normal distribution z
        """

        S = self.relu(self.fc1(S))  
        S_prev = S
        
        for fc in self.layer_stack:
            S, S_prev = fc(S+S_prev), S
            
        if verbose: print('S', S.shape)
        S = S.mean(0)
        
        self.mu_z = self.fc_last_mu(S)
        self.s = self.fc_last_sigma(S)
        self.sigma_z = self.s.exp() # экспонента чтобы не было отрицательных значений дисперсии

        return self.mu_z, self.sigma_z, self.s
    
class DeepPrior(nn.Module):
    def __init__(self, in_x, in_z, dim_middle, dim_out, n_enc_layers, n_dec_layers):
        
        super(DeepPrior, self).__init__()
    
        self.dim_z = in_z
        
        in_y = 0
        self.in_z = in_z
        
        self.encoder = Encoder(in_x+in_y, self.in_z, self.in_z, n_enc_layers)
        self.decoder = Decoder(in_z, in_x, dim_middle, dim_out, n_dec_layers)
        
        self.ELBO = {}
 
    def forward(self, y, X):
        """
        takes dataset S and produces ELBO lower bound for p(D) where D - dataset
        """
        
        # now let compute stohastic part
        S = X
        if verbose: print('S', S.shape)
        mu_z, sigma_z, _  = self.encoder(S)
        if verbose: print('mu_z, sigma_z', mu_z.shape, sigma_z.shape)
        sigma_z = sigma_z.pow(1/2)
   
        # firstly let compute not stochastic component that corresponds to KL term
        KL_part = -1/2 * (sigma_z.sum()+(mu_z*mu_z).sum() - 2*self.dim_z - sigma_z.log().sum())
        
        if verbose: print('KL_part', KL_part)
            
        #sampling
        N_batch = S.shape[0]
        z = torch.normal(mean = torch.zeros((1, self.in_z)), 
                          std = torch.ones((1,  self.in_z)))
        
        z = z.repeat(N_batch, 1)
        
        if verbose: print('z', z.shape)
            
        z = mu_z+sigma_z*z
        if verbose: print('z', z.shape)
        mu_y, sigma_y = self.decoder(X, z)
        
        if verbose: print('mu_y, sigma_y', mu_y.shape, sigma_y.shape)
        
        
        log_likelihood = -1/2*(y - mu_y)*(y-mu_y)/ sigma_y.pow(2)
        if verbose: print('log_likelihood', log_likelihood.shape) 
            
        log_likelihood = log_likelihood.sum()-sigma_y.log().sum()
        if verbose: print('log_likelihood', log_likelihood.shape)

        ratio = float(np.max([l/N_batch, 1]))
        loss = ratio*log_likelihood+KL_part
        
        l_np = float(loss.cpu().data.numpy())

        try:
            self.ELBO[int(y.data.numpy().flatten())].append(l_np)
        except KeyError:
            self.ELBO[int(y.data.numpy().flatten())]=[l_np]

        
        return -loss
    
    def predict(self, X, y, X_pred):
        """
        S - dataset (tuple with two tensors with shape (Points, 1))
        X - points to predict y in (tensor with shape (N_points, ))
        """
        
        S = torch.cat([X, y], dim = -1)
        if verbose: print('S', S.shape)
        mu_z, sigma_z, _  = self.encoder(S)
        
#         sigma_z = sigma_z.pow(1/2)
#         N_batch = 1
#         z = torch.normal(mean = torch.zeros((N_batch, self.in_z)), 
#                           std = torch.ones((N_batch,  self.in_z)))      
#         if verbose: print('z', z.shape)
            
        z = mu_z.unsqueeze(0) #+sigma_z*z
        if verbose: print('z', z.shape)
            
        z = z.repeat(X_pred.shape[0], 1)
        if verbose: 
            print('z', z.shape)
            print('X', X.shape)
            
        mu_y, sigma_y = self.decoder(X_pred, z)
        
        return mu_y, sigma_y

        

In [31]:
def train_model(model, num_epochs=20, batchsize = 10, verbose=True, plot_every = 100):
    
    start_time = 0
    
    for epoch in range(num_epochs):
        model.train(True) # enable dropout / batch_norm training behavior
        for (X_batch, y_batch) in train_loader:
            
            print('X_batch', X_batch.shape)
            print('y_batch', y_batch.shape)
            
            loss = model(y_batch, X_batch[0])
            
            loss.backward()
            torch.nn.utils.clip_grad_norm(model.parameters(), 0.25)
            optimizer.step()
            optimizer.zero_grad()
            

        # Visualize
        # Then we print the results for this epoch:
        if verbose: # and epoch % plot_every == 0

            display.clear_output(wait=True)
            print("Epoch {} of {} took {:.3f}s".format(
                epoch + 1, num_epochs, time.time() - start_time)) 
            start_time = time.time()

            ELBO = np.array(list(model.ELBO.values()))
            print('current elbo: {}'.format(ELBO[0][-1]))
            

            plt.figure(figsize=(16, 6))
            plt.subplot(221)
            plt.title("ELBO")
            plt.xlabel("#iteration")
            plt.ylabel("elbo")
            plt.plot(ELBO.T, label = 'train_elbo')
            plt.show()
           
    return X_0, y_0

In [32]:
# оптимайзер
from torch import optim
from itertools import chain
model = DeepPrior(in_x=1, in_z=128, dim_middle=128, dim_out=1, n_enc_layers=4, n_dec_layers=12)
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [33]:
verbose = False
X_0, y_0 = train_model(model, num_epochs=10000, batchsize = 50, verbose=True)

X_batch torch.Size([1, 200, 1])
y_batch torch.Size([1])


RuntimeError: Expected object of type torch.LongTensor but found type torch.FloatTensor for argument #3 'other'