In [None]:
#| default_exp models

In [None]:
#| export
import torch
import torch.nn.functional as F

In [None]:
#| hide
# Export this notebook's `#| export` cells with nbdev (the target module is set
# by the `#| default_exp models` directive). The cell is `#| hide`, so it is
# presumably not rendered in the generated docs — NOTE(review): confirm.
from nbdev import nbdev_export
nbdev_export()

# Useful operations

In [None]:
#| export
import torch.nn.init as init

def kaiming_init(m):
    '''
    Kaiming initialization of module m (typically applied to every submodule
    via `model.apply(kaiming_init)`).

    - Linear / Conv1d / Conv2d / Conv3d: weight ~ Kaiming normal, bias = 0
    - BatchNorm1d / BatchNorm2d / BatchNorm3d: weight = 1, bias = 0
    Any other module type is left untouched.

    References:
    - He, K., Zhang, X., Ren, S., & Sun, J. (2015). Delving deep into rectifiers: Surpassing human-level performance on imagenet classification
    '''
    # Fully qualified `torch.nn.*` names: the module never imports `nn` on its own.
    if isinstance(m, (torch.nn.Linear, torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d)):
        init.kaiming_normal_(m.weight)
        if m.bias is not None:
            init.zeros_(m.bias)
    elif isinstance(m, (torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d)):
        # Batch norms built with affine=False have weight/bias set to None.
        if m.weight is not None:
            init.ones_(m.weight)
        if m.bias is not None:
            init.zeros_(m.bias)

In [None]:
#| export
from torch.autograd import Variable

def reparametrize(mu, logvar):
    '''
    Reparameterization trick: sample z ~ N(mu, exp(logvar)) as
    z = mu + std * eps with eps ~ N(0, 1), so gradients flow to mu and logvar.

    Parameters:
    - mu: mean tensor
    - logvar: log-variance tensor (broadcastable against mu)

    Returns mu + std * eps.
    '''
    # std = exp(logvar / 2) = sqrt(var)
    std = torch.exp(0.5 * logvar)
    # randn_like draws N(0, 1) noise with std's shape, dtype and device. The noise
    # needs no gradient, so the deprecated Variable / `.data` wrapping is not needed.
    eps = torch.randn_like(std)
    return mu + std * eps

# VAIR

In [None]:
#| export
class VAIR(torch.nn.Module):

    '''
    Variational AIR (VAIR) architecture.
    Consists of:
    - Encoder E_x: from observation x to latent mu
    - Encoder E_a: from action a to latent logvar
    - Decoder: from latent z and action a to reconstruction y

    `forward` takes `inp` of shape (BS, dim_x + dim_a): the first dim_x columns
    are the observation x and the last dim_a columns are the action a.
    '''

    def __init__(self,
                 dim_x = 62, # Dimension input to E_x (observation dimension)
                 dim_a = 0, # Dimension input to E_a (action dimension)
                 dim_enc_h = (512, 256), # Dimension hidden layers E_x
                 dim_dec_h = (256, 512), # Dimension hidden layers decoder
                 dim_z = 2, # Dimension latent
                 dim_y = 62, # Dimension output
                 dim_ea_h = (256, 512), # Dimension hidden layers E_a
                ):
        super().__init__()

        # Latent and action dimensions; dim_a is also used to split the input in forward
        self.dim_z = dim_z
        self.dim_a = dim_a

        # E_x: x -> mu
        self.lin1 = torch.nn.Linear(dim_x, dim_enc_h[0])
        self.lin2 = torch.nn.Linear(dim_enc_h[0], dim_enc_h[1])
        self.lin3 = torch.nn.Linear(dim_enc_h[1], dim_z)

        # E_a: a -> logvar
        self.linEa1 = torch.nn.Linear(dim_a, dim_ea_h[0])
        self.linEa2 = torch.nn.Linear(dim_ea_h[0], dim_ea_h[1])
        self.linEa3 = torch.nn.Linear(dim_ea_h[1], dim_z)

        # Decoder: [z, a] -> y
        self.lin1_d = torch.nn.Linear(dim_z + dim_a, dim_dec_h[0])
        self.lin2_d = torch.nn.Linear(dim_dec_h[0], dim_dec_h[1])
        self.lin3_d = torch.nn.Linear(dim_dec_h[1], dim_y)

    def E_x(self,
            x = None, # dimension: (BS , dim_x)
           ):
        '''Observation encoder: two ReLU hidden layers, returns mu of shape (BS, dim_z).'''
        mu_p = F.relu(self.lin1(x))
        mu_p = F.relu(self.lin2(mu_p))
        return self.lin3(mu_p)

    def E_a(self,
            a = None, # dimension: (BS , dim_a)
           ):
        '''Action encoder: two ReLU hidden layers, returns logvar of shape (BS, dim_z).'''
        logvar_p = F.relu(self.linEa1(a))
        logvar_p = F.relu(self.linEa2(logvar_p))
        return self.linEa3(logvar_p)

    def decoder(self,
                za = None, # dimension: (BS , dim_z+dim_a)
                ):
        '''Decoder: [z, a] -> reconstruction y of shape (BS, dim_y).'''
        y_p = F.relu(self.lin1_d(za))
        y_p = F.relu(self.lin2_d(y_p))
        return self.lin3_d(y_p)

    def encoder(self, x):
        # Placeholder to mimic VAE encoder in certain analysis.
        # Unlike the VAE encoders, this returns only mu (BS, dim_z), not [mu, logvar].
        return self.E_x(x)

    def _split_input(self, inp):
        # Split inp (BS, dim_x + dim_a) into observation x and action a.
        # The split index is computed from the front: slicing with -self.dim_a
        # breaks for dim_a = 0 (since -0 == 0, x would be empty and a the whole input).
        n_obs = inp.shape[1] - self.dim_a
        return inp[:, :n_obs], inp[:, n_obs:]

    def forward(self, inp):
        '''
        inp: (BS, dim_x + dim_a), observation followed by action.
        Returns (y, mu, logvar): reconstruction (BS, dim_y), mean from E_x and
        log-variance from E_a (both (BS, dim_z)).
        '''
        x, a = self._split_input(inp)

        # Encoder passes and sampling of the latent space
        mu = self.E_x(x)
        logvar = self.E_a(a)
        z = reparametrize(mu, logvar)

        # Merge latent and action as decoder input
        merge = torch.cat((z, a), dim = 1)
        y = self.decoder(merge)

        return y, mu, logvar
        

In [None]:
# Smoke test: output shapes of VAIR on a small random batch (dim_z is now actually passed)
dim_x = 5; dim_a = 3; dim_z = 10; dim_y = 3
dataset = torch.rand((2, dim_x+dim_a))

vair = VAIR(dim_x = dim_x, dim_a = dim_a, dim_z = dim_z, dim_y = dim_y)

y, mu, logvar = vair(dataset)
assert y.shape == (dataset.shape[0], dim_y)
assert mu.shape == logvar.shape == (dataset.shape[0], dim_z)

# VAE$_{x,a}$: VAE with $x,a$ as encoder input and no action input to the decoder

In [None]:
#| export
class VAE_xa(torch.nn.Module):

    '''
    VAE with observation and action as input to single encoder. No action input to decoder.
    Consists of:
    - Encoder: from observation x and action a (concatenated) to latent mu and logvar
    - Decoder: from latent z to reconstruction y
    '''

    def __init__(self,
                 dim_x = 62, # Observation dimension
                 dim_a = 0, # Action dimension (concatenated to x at the encoder input)
                 dim_enc_h = (1000, 256), # Dimension hidden layers encoder
                 dim_dec_h = (256, 512), # Dimension hidden layers decoder
                 dim_z = 2, # Dimension latent
                 dim_y = 62, # Dimension output
                ):
        super().__init__()

        # Define latent and action dimensions
        self.dim_z = dim_z
        self.dim_a = dim_a

        # Encoder: [x, a] -> [mu, logvar] (2*dim_z outputs)
        self.lin1 = torch.nn.Linear(dim_x + dim_a, dim_enc_h[0])
        self.lin2 = torch.nn.Linear(dim_enc_h[0], dim_enc_h[1])
        self.lin3 = torch.nn.Linear(dim_enc_h[1], 2*dim_z)

        # Decoder: z -> y
        self.lin1_d = torch.nn.Linear(dim_z, dim_dec_h[0])
        self.lin2_d = torch.nn.Linear(dim_dec_h[0], dim_dec_h[1])
        self.lin3_d = torch.nn.Linear(dim_dec_h[1], dim_y)

    def encoder(self,
                xa = None, # dimension: (BS , dim_x+dim_a)
                ):
        '''Returns (BS, 2*dim_z): the first dim_z columns are mu, the rest logvar.'''
        state_part = F.relu(self.lin1(xa))
        state_part = F.relu(self.lin2(state_part))
        return self.lin3(state_part)

    def decoder(self,
                z = None, # dimension: (BS , dim_z)
                ):
        '''Decoder: z -> reconstruction y of shape (BS, dim_y).'''
        y_p = F.relu(self.lin1_d(z))
        y_p = F.relu(self.lin2_d(y_p))
        return self.lin3_d(y_p)

    def forward(self, inp):
        '''
        inp: (BS, dim_x + dim_a), observation and action concatenated.
        Returns (y, mu, logvar).
        '''
        # Encoder pass and sampling of the latent space
        distributions = self.encoder(xa = inp)
        mu = distributions[:, :self.dim_z]
        logvar = distributions[:, self.dim_z:]
        z = reparametrize(mu, logvar)

        # Decoder pass (no action input)
        y = self.decoder(z)

        return y, mu, logvar

In [None]:
# Smoke test: output shapes of VAE_xa on a small random batch (dim_z is now actually passed)
dim_x = 5; dim_a = 3; dim_z = 10; dim_y = 3
dataset = torch.rand((2, dim_x+dim_a))

vae_xa = VAE_xa(dim_x = dim_x, dim_a = dim_a, dim_z = dim_z, dim_y = dim_y)

y, mu, logvar = vae_xa(dataset)
assert y.shape == (dataset.shape[0], dim_y)
assert mu.shape == logvar.shape == (dataset.shape[0], dim_z)

# VAE$_{D_a}$: VAE with $x,a$ as encoder input and $a$ also fed to the decoder

In [None]:
#| export
class VAE_Da(torch.nn.Module):

    '''
    VAE with observation and action as input to single encoder, plus the action is also input to decoder.
    Consists of:
    - Encoder: from observation x and action a (concatenated) to latent mu and logvar
    - Decoder: from latent z and action a to reconstruction y

    `forward` takes `inp` of shape (BS, dim_x + dim_a) whose last dim_a columns are the action.
    '''

    def __init__(self,
                 dim_x = 62, # Observation dimension (encoder input is dim_x + dim_a)
                 dim_enc_h = (1000, 256), # Dimension hidden FC encoder
                 dim_dec_h = (256, 512), # Dimension hidden FC decoder
                 dim_z = 2, # Dimension latent
                 dim_a = 0, # Dimension action representation
                 dim_y = 62, # Dimension output (reconstruction)
                ):
        super().__init__()

        # Define latent and action dimensions
        self.dim_z = dim_z
        self.dim_a = dim_a

        # Encoder: [x, a] -> [mu, logvar] (2*dim_z outputs)
        self.lin1 = torch.nn.Linear(dim_x + dim_a, dim_enc_h[0])
        self.lin2 = torch.nn.Linear(dim_enc_h[0], dim_enc_h[1])
        self.lin3 = torch.nn.Linear(dim_enc_h[1], 2*dim_z)

        # Decoder: [z, a] -> y
        self.lin1_d = torch.nn.Linear(dim_z + dim_a, dim_dec_h[0])
        self.lin2_d = torch.nn.Linear(dim_dec_h[0], dim_dec_h[1])
        self.lin3_d = torch.nn.Linear(dim_dec_h[1], dim_y)

    def encoder(self,
                xa = None, # dimension: (BS , dim_x+dim_a)
                ):
        '''Returns (BS, 2*dim_z): the first dim_z columns are mu, the rest logvar.'''
        z_p = F.relu(self.lin1(xa))
        z_p = F.relu(self.lin2(z_p))
        return self.lin3(z_p)

    def decoder(self,
                za = None, # dimension: (BS , dim_z+dim_a)
                ):
        '''Decoder: [z, a] -> reconstruction y of shape (BS, dim_y).'''
        y_p = F.relu(self.lin1_d(za))
        y_p = F.relu(self.lin2_d(y_p))
        return self.lin3_d(y_p)

    def _action_part(self, inp):
        # Last dim_a columns of inp. The index is computed from the front: slicing
        # with -self.dim_a would return the whole input for dim_a = 0 (since -0 == 0).
        return inp[:, inp.shape[1] - self.dim_a:]

    def forward(self, inp):
        '''
        inp: (BS, dim_x + dim_a), observation followed by action.
        Returns (y, mu, logvar).
        '''
        # The action is seen by the encoder (inside inp) and again by the decoder
        a = self._action_part(inp)

        # Encoder pass and latent computations
        distributions = self.encoder(inp)
        mu = distributions[:, :self.dim_z]
        logvar = distributions[:, self.dim_z:]
        z = reparametrize(mu, logvar)

        # Merge latent and action, then decoder pass
        merge = torch.cat((z, a), dim = 1)
        y = self.decoder(merge)

        return y, mu, logvar

In [None]:
# Smoke test: output shapes of VAE_Da on a small random batch (dim_z is now actually passed)
dim_x = 5; dim_a = 3; dim_z = 10; dim_y = 3
dataset = torch.rand((2, dim_x+dim_a))

vae_da = VAE_Da(dim_x = dim_x, dim_a = dim_a, dim_z = dim_z, dim_y = dim_y)

y, mu, logvar = vae_da(dataset)
assert y.shape == (dataset.shape[0], dim_y)
assert mu.shape == logvar.shape == (dataset.shape[0], dim_z)

# VAE: vanilla VAE without any action input

In [None]:
#| export
class VAE(torch.nn.Module):

    '''
    Vanilla VAE architecture.
    Consists of:
    - Encoder: from observation x to latent mu and logvar
    - Decoder: from latent z to reconstruction of x (output dimension dim_x)
    '''

    def __init__(self,
                 dim_x = 62, # Dimension input to encoder (and output of decoder)
                 dim_enc_h = (1000, 256), # Dimension hidden FC encoder
                 dim_dec_h = (260, 512), # Dimension hidden FC decoder; NOTE(review): 260 differs from the 256 used by the other models — confirm intended
                 dim_z = 2, # Dimension latent
                ):
        super().__init__()

        # Define latent dimension
        self.dim_z = dim_z

        # Encoder: x -> [mu, logvar] (2*dim_z outputs)
        self.lin1 = torch.nn.Linear(dim_x, dim_enc_h[0])
        self.lin2 = torch.nn.Linear(dim_enc_h[0], dim_enc_h[1])
        self.lin3 = torch.nn.Linear(dim_enc_h[1], 2*dim_z)

        # Decoder: z -> reconstruction of x
        self.lin1_d = torch.nn.Linear(dim_z, dim_dec_h[0])
        self.lin2_d = torch.nn.Linear(dim_dec_h[0], dim_dec_h[1])
        self.lin3_d = torch.nn.Linear(dim_dec_h[1], dim_x)

    def encoder(self,
                x = None, # dimension: (BS , dim_x)
                ):
        '''Returns (BS, 2*dim_z): the first dim_z columns are mu, the rest logvar.'''
        z_p = F.relu(self.lin1(x))
        z_p = F.relu(self.lin2(z_p))
        return self.lin3(z_p)

    def decoder(self,
                z = None, # dimension: (BS , dim_z)
                ):
        '''Decoder: z -> reconstruction of shape (BS, dim_x).'''
        y_p = F.relu(self.lin1_d(z))
        y_p = F.relu(self.lin2_d(y_p))
        return self.lin3_d(y_p)

    def forward(self, inp):
        '''
        inp: (BS, dim_x). Returns (y, mu, logvar) with y of shape (BS, dim_x).
        '''
        # Encoder pass and latent operations
        distributions = self.encoder(x = inp)
        mu = distributions[:, :self.dim_z]
        logvar = distributions[:, self.dim_z:]
        z = reparametrize(mu, logvar)

        # Decoder pass
        y = self.decoder(z)

        return y, mu, logvar

In [None]:
# Smoke test: output shapes of the vanilla VAE on a small random batch (no action input)
dim_x = 5; dim_z = 10
dataset = torch.rand((2, dim_x))

vae = VAE(dim_x = dim_x, dim_z = dim_z)

y, mu, logvar = vae(dataset)
assert y.shape == (dataset.shape[0], dim_x)
assert mu.shape == logvar.shape == (dataset.shape[0], dim_z)