In [1]:
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from omegaconf import DictConfig, OmegaConf
from typing import Callable, Tuple

In [23]:
def bmm_input(b_weight, b_input):
    batch_size, feature_dim = b_input.shape
    bmm = torch.einsum('nfh, nf -> nh', b_weight, b_input) / feature_dim
    return bmm


def bmm_output(b_weight, b_input):
    batch_size, output_dim, shared_output_dim = b_weight.shape
    batch_size, shared_output_dim = b_input.shape
    # [batch_size, 6, 32], [batch_size, 32]
    bmm = torch.einsum('noh, nh -> no', b_weight, b_input)  / shared_output_dim
    return bmm


In [8]:
def get_activation(activation_name: str):
    if activation_name == 'tanh':
        activation = nn.Tanh
    elif activation_name == 'learnable':
        pass
    elif activation_name == 'relu':
        activation = nn.ReLU
    elif activation_name == 'leakyrelu':
        activation = nn.LeakyReLU
    elif activation_name == "prelu":
        activation = nn.PReLU
    elif activation_name == 'gelu':
        activation = nn.GELU
    elif activation_name == 'sigmoid':
        activation = nn.Sigmoid
    elif activation_name in [ None, 'id', 'identity', 'linear', 'none' ]:
        activation = nn.Identity
    elif activation_name == 'elu':
        activation = nn.ELU
    elif activation_name in ['swish', 'silu']:
        activation = nn.SiLU
    elif activation_name == 'softplus':
        activation = nn.Softplus
    else:
        raise NotImplementedError("hidden activation '{}' is not implemented".format(activation))
    return activation

In [19]:
class RNNEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.d_model = 256
        self.final_act = "identity"
        self.final_act_func = get_activation(self.final_act)()
        self.embedding = nn.Linear(1, self.d_model)
        self.rnn = nn.GRU(input_size=self.d_model,
                          hidden_size=self.d_model,
                          num_layers=1,
                          bias=True,
                          batch_first=True,
                          bidirectional=True
                          )
                                                
    def forward(self, x):
        #  # x: [batch_size, feature_dim]
        # print(x.storage())
        batch_size = x.size(0)
        feature_dim = x.size(1)
        ux = x.unsqueeze(-1) # x: [batch_size, feature_dim, 1]
        ux = self.embedding(ux) # ux: [batch_size, feature_dim, 32]
        weights, h_n = self.rnn(ux) # weights shape: [batch_size, feature_dim, 2*d_model]
        print(weights.shape)
        weights = weights.reshape(batch_size, feature_dim, 2, self.d_model)
        weights = weights.mean(dim=2, keepdim=False)
        weights = self.final_act_func(weights)
        return weights

In [22]:
class RNNDecoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.d_model = 256
        self.final_act = "identity"
        self.final_act_func = get_activation(self.final_act)()
        self.num_layers = 1
        self.bias = True
        # GRU input: (N,L,H_in) output: (N,L,H_out)
        self.rnn = nn.GRU(input_size=self.d_model, 
                          hidden_size=self.d_model,
                          num_layers=self.num_layers,
                          bias=self.bias,
                          batch_first=True,
                          bidirectional=True
                          )
        
    def forward(self, out_dim, embed_featrue):
        # shared_feature: [batch_size, shared_dim]
        batch_size, *dims = embed_featrue.shape
        x = embed_featrue.unsqueeze(1) 
        # embed_featrue: [batch_size, 1, shared_dim]
        # features = [embed_featrue for i in range(out_dim)]
        # embed_featrue = torch.cat(features, dim=1) 
        x = x.expand(batch_size, out_dim, *dims)
        x = x.reshape(batch_size, out_dim, -1)
         # embed_featrue: [batch_size, out_dim, d_model]
        weights, _ = self.rnn(x) 
        # weights shape: [batch_size, out_dim, 2*d_model]
        weights = weights.reshape(batch_size, out_dim, 2, self.d_model)
        weights = weights.mean(dim=2, keepdim=False)
        # weights shape: [batch_size, out_dim, d_model]
        weights = self.final_act_func(weights)
        return weights

In [33]:
encoder = RNNEncoder()
mean_decoder = RNNDecoder()
torch.manual_seed(0)
cov_mat_decoder1 = RNNDecoder()
torch.manual_seed(42)
cov_mat_decoder2 = RNNDecoder()

obs = torch.ones([32, 17])
enc_weights = encoder(obs)
hidden = bmm_input(enc_weights, obs)
print(hidden.shape)
mean_weights = mean_decoder(6, hidden)
dec_weights1 = cov_mat_decoder1(6, hidden)
dec_weights2 = cov_mat_decoder2(6, hidden)
print(dec_weights1.shape)
print(dec_weights1 == dec_weights2)

mean = bmm_output(mean_weights, hidden)
cov_mat = torch.relu(dec_weights1 @ dec_weights2.mT)
print(cov_mat.shape)

torch.Size([32, 17, 512])
torch.Size([32, 256])
torch.Size([32, 6, 256])
tensor([[[False, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False, False]],

        [[False, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False, False]],

        [[False, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False, False],
         [False, False, False,  ..., F

In [41]:
cov_mat2 = torch.relu(dec_weights1 @ dec_weights1.mT)
cov_mat2 = dec_weights2 @ dec_weights2.mT
cov_mat2.shape

torch.Size([32, 6, 6])

In [56]:
from torch.distributions import Normal, Categorical, MultivariateNormal
dist = MultivariateNormal(loc=mean, covariance_matrix=cov_mat2)


In [57]:
a = dist.sample()

In [58]:
a.shape

torch.Size([32, 6])

In [59]:
b = dist.log_prob(a)
b.shape

torch.Size([32])

In [52]:
mean = torch.randn([32, 6])
logstd = torch.randn([32, 6])
std = logstd.exp()
dist = Normal(mean,std)


In [68]:
i = torch.eye(6)
i.shape

torch.Size([6, 6])

In [73]:
i = torch.eye(6, 6)
bi = torch.stack([i for j in range(32)])
bi.shape

torch.Size([32, 6, 6])

In [54]:
a = dist.sample()
a.shape

torch.Size([32, 6])

In [55]:
b = dist.log_prob(a)
b.shape

torch.Size([32, 6])

In [61]:
e = dist.entropy()
e.shape


torch.Size([32])