In [1]:
import sys  
sys.path.append('../')

## Archi-1

In [23]:
import torch 
from torch import nn, Tensor
from torch.nn import functional as F
from torch.nn.modules import MultiheadAttention, Linear, Dropout, BatchNorm1d, TransformerEncoderLayer
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam
from torch.nn import MSELoss,L1Loss

from model.ts_transformer import RelativeGlobalAttention
from model.modeling_lxmert import LxmertConfig, LxmertXLayer

In [4]:
class IMUEncoder(nn.Module):
    def __init__(self, in_ft, d_model, ft_size, n_classes, num_heads=1, max_len=1024, dropout=0.1):
        super(IMUEncoder, self).__init__()
        self.in_ft = in_ft
        self.max_len = max_len
        self.d_model = d_model
        self.num_heads = num_heads
        self.ft_size = ft_size 
        self.n_classes = n_classes

        self.lstm = nn.LSTM(input_size=self.in_ft,
                            hidden_size=self.d_model,
                            num_layers=self.num_heads,
                            batch_first=True,
                            bidirectional=True)
        self.drop = nn.Dropout(p=0.1)
        self.act = nn.ReLU()
        self.fcLayer1 = nn.Linear(2*self.d_model, self.ft_size)
        # self.fcLayer2 = nn.Linear(self.ft_size, self.ft_size)

    def forward(self, x):
        out, _ = self.lstm(x)
        out_forward = out[:, self.max_len - 1, :self.d_model]
        out_reverse = out[:, 0, self.d_model:]
        out_reduced = torch.cat((out_forward, out_reverse), 1)
        out = self.drop(out_reduced)
        out = self.act(out)
        out = self.fcLayer1(out)
        # out = self.fcLayer2(out)
        return out

In [18]:
class BiLSTMEncoder(nn.Module):
    def __init__(self, seq_len, input_size, hidden_size, linear_filters, embedding_size:int, num_layers = 1, bidirectional=True, batch_size=32, device='cpu'):
        super(BiLSTMEncoder, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.linear_filters = linear_filters
        self.embedding_size = embedding_size
        self.bidirectional = bidirectional
        self.batch_size = batch_size
        self.seq_len = seq_len
        self.device = device

        # define LSTM layer
        self.layers = []

        # add linear layers 
        for __id,layer_out in enumerate(self.linear_filters):
            if __id == 0:
                self.layers.append(nn.Linear(self.input_size, layer_out))
            else:
                self.layers.append(nn.Linear(self.linear_filters[__id-1], layer_out))

        # add lstm layer
        self.lstm = nn.LSTM(input_size = layer_out, hidden_size = self.hidden_size,
                            num_layers = self.num_layers, bidirectional=self.bidirectional,
                            batch_first=True)
        self.net = nn.Sequential(*self.layers)

        #add embedding out
        if bidirectional:
            self.out_linear = nn.Linear(self.hidden_size*4, self.embedding_size)
        else:
            self.out_linear = nn.Linear(self.hidden_size*2, self.embedding_size)

        
    def forward(self, x_input):
        '''
        : param x_input:               input of shape (seq_len, # in batch, input_size)
        : return lstm_out, hidden:     lstm_out gives all the hidden states in the sequence; hidden gives the hidden state and cell state for the last element in the sequence                         
        '''
        
        x = self.net(x_input)
        lstm_out, self.hidden = self.lstm(x)
        hidden_transformed = torch.concat(self.hidden,0)
        hidden_transformed = torch.transpose(hidden_transformed,0,1)
        hidden_transformed = torch.flatten(hidden_transformed,start_dim=1)
        hidden_transformed = self.out_linear(hidden_transformed)
        
        return hidden_transformed

    
class BiLSTMDecoder(nn.Module):
    def __init__(self,seq_len, input_size, hidden_size, linear_filters, embedding_size:int, num_layers = 1,bidirectional=True,batch_size=32, device='cpu'):
        super(BiLSTMDecoder, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.linear_filters = linear_filters[::-1]
        self.embedding_size = embedding_size
        self.bidirectional = bidirectional
        self.batch_size = batch_size
        self.seq_len = seq_len
        self.device = device

        if bidirectional:
            self.x_linear = nn.Linear(self.embedding_size, 4*self.hidden_size)
        else:
            self.input_linear = nn.Linear(self.embedding_size, 2*self.hidden_size)

        # define LSTM layer
        self.layers = []
        # add lstm
        self.lstm = nn.LSTM(input_size = self.linear_filters[0], hidden_size = self.hidden_size,
                            num_layers = self.num_layers, bidirectional=True,
                            batch_first=bidirectional)
              
        # add linear layers 
        if bidirectional:
            self.layers.append(nn.Linear(2*hidden_size, self.linear_filters[0]))
        else:
            self.layers.append(nn.Linear(hidden_size, self.linear_filters[0]))

        for __id,layer_in in enumerate(self.linear_filters):
            if __id == len(linear_filters)-1:
                self.layers.append(nn.Linear(layer_in, self.input_size))
            else:
                self.layers.append(nn.Linear(layer_in, self.linear_filters[__id+1]))

        self.net = nn.Sequential(*self.layers)

    def forward(self,encoder_hidden):
        '''
        : param x_input:               input of shape (seq_len, # in batch, input_size)
        : return lstm_out, hidden:     lstm_out gives all the hidden states in the sequence; hidden gives the hidden state and cell state for the last element in the sequence                         
        '''
        
        
        hidden_shape = encoder_hidden.shape
        encoder_hidden = self.input_linear(encoder_hidden)
        
        if self.bidirectional:
            hidden = encoder_hidden.view((self.batch_size, 4, self.hidden_size))
            # print(hidden.shape)
            hidden = torch.transpose(hidden,1,0)
            h1,h2,c1,c2 = torch.unbind(hidden,0)
            h,c = torch.stack((h1,h2)),torch.stack((c1,c2))
        else:
            hidden = encoder_hidden.view((self.batch_size, 2, self.hidden_size))
            hidden = torch.transpose(hidden,1,0)
            h,c = torch.unbind(hidden,0)
        
        dummy_input = torch.rand((self.batch_size, self.seq_len, self.hidden_size), requires_grad=True)
        dummy_input = dummy_input.to(self.device)
        
        lstm_out, self.hidden = self.lstm(dummy_input,(h,c))
        x = self.net(lstm_out)
        return x

class SkeletonAE(nn.Module):
    def __init__(self,seq_len, input_size, hidden_size, linear_filters=[128,256,512],embedding_size:int=256, num_layers = 1,bidirectional=True, batch_size=32, device='cpu'):
        super(SkeletonAE, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.linear_filters = linear_filters[::-1]
        self.embedding_size = embedding_size
        self.bidirectional = bidirectional
        self.batch_size = batch_size
        self.seq_len = seq_len
        
        self.encoder = BiLSTMEncoder(seq_len, input_size, hidden_size, linear_filters,embedding_size, num_layers = 1,bidirectional=True,batch_size=32)
        self.decoder = BiLSTMDecoder(seq_len, input_size, hidden_size, linear_filters,embedding_size, num_layers = 1,bidirectional=True,batch_size=32, device=device)
        
    def forward(self,x):
        lstm_out, embedding = self.encoder(x)
        decoder_out = self.decoder(embedding)
        
        return decoder_out, embedding  
        

In [19]:
imu_config = {
    'in_ft': 128,
    'd_model': 256,
    'num_heads': 2,
    'ft_size': 400,
    'n_classes': 18,
    'max_len': 60
}

imu_model = IMUEncoder(**imu_config)
imu_model

IMUEncoder(
  (lstm): LSTM(128, 256, num_layers=2, batch_first=True, bidirectional=True)
  (drop): Dropout(p=0.1, inplace=False)
  (act): ReLU()
  (fcLayer1): Linear(in_features=512, out_features=400, bias=True)
)

In [20]:
imu_input = torch.randn((32, 60, 128))
imu_model(imu_input).shape

torch.Size([32, 400])

In [21]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [27]:
skel_config = {
    'seq_len': 50,
    'input_size': 36,
    'hidden_size': 512,
    'linear_filters': [128,256,512],
    'embedding_size': 256,
    'num_layers': 1,
    'bidirectional': True,
    'batch_size': 32,
    'device': 'cpu'
}

skel_encoder = BiLSTMEncoder(**skel_config)
skel_encoder

BiLSTMEncoder(
  (lstm): LSTM(512, 512, batch_first=True, bidirectional=True)
  (net): Sequential(
    (0): Linear(in_features=36, out_features=128, bias=True)
    (1): Linear(in_features=128, out_features=256, bias=True)
    (2): Linear(in_features=256, out_features=512, bias=True)
  )
  (out_linear): Linear(in_features=2048, out_features=256, bias=True)
)

In [23]:
skel_input = torch.randn((32, 50, 36))
skel_encoder(skel_input).shape

torch.Size([32, 256])

In [29]:
skel_decoder = BiLSTMDecoder(**skel_config)
skel_decoder

BiLSTMDecoder(
  (input_linear): Linear(in_features=256, out_features=2048, bias=True)
  (lstm): LSTM(512, 512, batch_first=True, bidirectional=True)
  (net): Sequential(
    (0): Linear(in_features=1024, out_features=512, bias=True)
    (1): Linear(in_features=512, out_features=256, bias=True)
    (2): Linear(in_features=256, out_features=128, bias=True)
    (3): Linear(in_features=128, out_features=36, bias=True)
  )
)

In [30]:
skel_input = torch.randn((32, 256))
# skel_decoder = skel_decoder.to(device)
skel_decoder(skel_input).shape

torch.Size([32, 50, 36])

In [43]:
class BaseModel(nn.Module):
    def __init__(self, config):
        super(BaseModel, self).__init__()
        
        self.imu_model = IMUEncoder(**config['imu_config'])
        self.skel_encoder = BiLSTMEncoder(**config['skel_config'])
        self.skel_decoder = BiLSTMDecoder(**config['skel_config'])
        self.lxmert_config = LxmertConfig(**config['xmert_config'])
        self.lxmert_xlayer = LxmertXLayer(self.lxmert_config)

        self.num_layers = config['num_layers']

    def forward(self, x_imu, x_skel):
        imu_feats = self.imu_model(x_imu)
        skel_feats = self.skel_encoder(x_skel)
        print(f"imu_feats {imu_feats.shape} | skel_feats {skel_feats.shape}")
        for i in range(self.num_layers):
            x_outputs = self.lxmert_xlayer(
                lang_feats = skel_feats,
                lang_attention_mask = None,  
                visual_feats = imu_feats,
                visual_attention_mask = None,
                input_id = None,
                output_attentions=False,
            )
            skel_feats, imu_feats = x_outputs[:2]

        skel_recon = self.skel_decoder(skel_feats)
        return imu_feats, skel_recon


In [None]:
imu_config = {
    'in_ft': 128,
    'd_model': 256,
    'num_heads': 2,
    'ft_size': 400,
    'n_classes': 18,
    'max_len': 60
}

skel_config = {
    'seq_len': 50,
    'input_size': 36,
    'hidden_size': 512,
    'linear_filters': [128,256,512],
    'embedding_size': 400,
    'num_layers': 1,
    'bidirectional': True,
    'batch_size': 32,
    'device': 'cpu'
}

base_config = {
    'imu_config': imu_config,
    'skel_config': skel_config,
    'num_layers': 1,
    'xmert_config': {
        'vocab_size': 1024,
        'hidden_size': 400,
        'num_attention_heads': 2,
        'intermediate_size': 512
    }
}

base_model = BaseModel(base_config)
base_model

In [None]:
imu_input = torch.randn((32, 60, 128))
skel_input = torch.randn((32, 50, 36))

imu_output, skel_recon = base_model(imu_input, skel_input)

## Archi-2

In [86]:
import torch 
from torch import nn, Tensor
from torch.nn import functional as F
from torch.nn.modules import MultiheadAttention, Linear, Dropout, BatchNorm1d, TransformerEncoderLayer
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam
from torch.nn import MSELoss,L1Loss

import torchinfo

from model.ts_transformer import RelativeGlobalAttention
from model.modeling_lxmert import LxmertConfig, LxmertXLayer

In [87]:
class IMUEncoder(nn.Module):
    def __init__(self, in_ft, d_model, num_heads, ft_size, n_classes, max_len=1024, dropout=0.1):
        super(IMUEncoder, self).__init__()
        self.in_ft = in_ft
        self.max_len = max_len
        self.d_model = d_model
        self.num_heads = num_heads
        self.ft_size = ft_size 
        self.n_classes = n_classes

        # feature prep layer
        self.DenseL = nn.Linear(in_ft, d_model)
        # relative global attention layer
        self.AttnL = RelativeGlobalAttention(self.d_model, self.num_heads, self.max_len)
        # positional encoding concat <-> 1DConv 
        self.Act = F.relu  # _get_activation_fn(activation)
        self.AvgPoolL = nn.AvgPool2d((self.max_len,1))
        self.DenseL2 = nn.Linear(self.d_model, self.ft_size)
        self.dropout = nn.Dropout1d(dropout)
        self.logist = nn.Linear(self.ft_size, self.n_classes)
        self.DenseL3 = nn.Linear(self.ft_size, self.ft_size)

    def forward(self, x):
        out = self.DenseL(x)
        out = self.AttnL(out)
        out = self.Act(out)
        out = self.AvgPoolL(out)
        out = torch.squeeze(out)
        out = self.dropout(out)
        out = self.DenseL2(out)   
        # out = self.DenseL3(out)
        # out = self.logist(out1)     
        return out

In [88]:
imu_config = {
        'in_ft':24, 
        'd_model':256, 
        'num_heads':2, 
        'ft_size':512, 
        'max_len':60, 
        'n_classes':18
    }

model = IMUEncoder(**imu_config)
# model
torchinfo.summary(model, input_size=(32, 60, 24), col_names = ("input_size", "output_size", "num_params", "kernel_size", "mult_adds"))

  action_fn=lambda data: sys.getsizeof(data.storage()),
  return super().__sizeof__() + self.nbytes()


Layer (type:depth-idx)                   Input Shape               Output Shape              Param #                   Kernel Shape              Mult-Adds
IMUEncoder                               [32, 60, 24]              [32, 512]                 271,890                   --                        --
├─Linear: 1-1                            [32, 60, 24]              [32, 60, 256]             6,400                     --                        204,800
├─RelativeGlobalAttention: 1-2           [32, 60, 256]             [32, 60, 256]             7,680                     --                        --
│    └─Linear: 2-1                       [32, 60, 256]             [32, 60, 256]             65,792                    --                        2,105,344
│    └─Linear: 2-2                       [32, 60, 256]             [32, 60, 256]             65,792                    --                        2,105,344
│    └─Linear: 2-3                       [32, 60, 256]             [32, 60, 256]      

In [89]:
from model.sgn_model import embed, local, gcn_spa, compute_g_spa 
import math

In [90]:
class SGN(nn.Module):
    def __init__(self, num_joint, seg, hidden_size=128, bs=32, train=True, bias=True, device='cpu'):
        super(SGN, self).__init__()

        self.dim1 = hidden_size
        self.seg = seg
        self.num_joint = num_joint
        self.bs = bs
        if train:
            self.spa = self.one_hot(bs, num_joint, self.seg)
            self.spa = self.spa.permute(0, 3, 2, 1).to(device)
            self.tem = self.one_hot(bs, self.seg, num_joint)
            self.tem = self.tem.permute(0, 3, 1, 2).to(device)
        else:
            self.spa = self.one_hot(32 * 5, num_joint, self.seg)
            self.spa = self.spa.permute(0, 3, 2, 1).to(device)
            self.tem = self.one_hot(32 * 5, self.seg, num_joint)
            self.tem = self.tem.permute(0, 3, 1, 2).to(device)

        self.tem_embed = embed(self.seg, joint=num_joint, hidden_dim=64*4, norm=False, bias=bias)
        self.spa_embed = embed(num_joint, joint=num_joint, hidden_dim=64, norm=False, bias=bias)
        self.joint_embed = embed(3, joint=num_joint, hidden_dim=64, norm=True, bias=bias)
        self.dif_embed = embed(3, joint=num_joint, hidden_dim=64, norm=True, bias=bias)
        self.maxpool = nn.AdaptiveMaxPool2d([1, 1])
        self.cnn = local(self.dim1, self.dim1 * 2, bias=bias)
        self.compute_g1 = compute_g_spa(self.dim1 // 2, self.dim1, bias=bias)
        self.gcn1 = gcn_spa(self.dim1 // 2, self.dim1 // 2, bias=bias)
        self.gcn2 = gcn_spa(self.dim1 // 2, self.dim1, bias=bias)
        self.gcn3 = gcn_spa(self.dim1, self.dim1, bias=bias)
        # self.fc = nn.Linear(self.dim1 * 2, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))

        nn.init.constant_(self.gcn1.w.cnn.weight, 0)
        nn.init.constant_(self.gcn2.w.cnn.weight, 0)
        nn.init.constant_(self.gcn3.w.cnn.weight, 0)


    def forward(self, input):
        
        # Dynamic Representation
        # bs, step, dim = input.size()
        # num_joints = dim //3
        input = input.view((self.bs, self.seg, self.num_joint, 3))
        input = input.permute(0, 3, 2, 1).contiguous()
        dif = input[:, :, :, 1:] - input[:, :, :, 0:-1]
        dif = torch.cat([dif.new(self.bs, dif.size(1), self.num_joint, 1).zero_(), dif], dim=-1)
        # print(input.shape)
        pos = self.joint_embed(input)
        tem1 = self.tem_embed(self.tem)
        spa1 = self.spa_embed(self.spa)
        dif = self.dif_embed(dif)
        dy = pos + dif
        # Joint-level Module
        input= torch.cat([dy, spa1], 1)
        g = self.compute_g1(input)
        input = self.gcn1(input, g)
        input = self.gcn2(input, g)
        input = self.gcn3(input, g)
        # Frame-level Module
        input = input + tem1
        input = self.cnn(input)
        output = torch.squeeze(input)
        output = output.permute(0, 2, 1).contiguous()
        # Classification
        # output = self.maxpool(input)
        # output = torch.flatten(output, 1)
        # output = self.fc(output)

        return output

    def one_hot(self, bs, spa, tem):

        y = torch.arange(spa).unsqueeze(-1)
        y_onehot = torch.FloatTensor(spa, spa)

        y_onehot.zero_()
        y_onehot.scatter_(1, y, 1)

        y_onehot = y_onehot.unsqueeze(0).unsqueeze(0)
        y_onehot = y_onehot.repeat(bs, tem, 1, 1)

        return y_onehot

In [91]:
sgn = SGN(num_joint=12, seg=60, hidden_size=256, bs=32, train=True)
sample = torch.randn((32, 60, 36))
sgn(sample).shape
# torchinfo.summary(sgn, input_size=(32, 60, 36), col_names = ("input_size", "output_size", "num_params", "kernel_size", "mult_adds"))
# sgn

torch.Size([32, 60, 512])

In [92]:
# torchinfo.summary(sgn, input_size=(32, 60, 36), col_names = ("input_size", "output_size", "num_params", "kernel_size", "mult_adds"))


In [93]:
class BiLSTMDecoder(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, num_layers = 1, bidirectional=True, batch_size=32, device='cpu'):
        super(BiLSTMDecoder, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.bidirectional = bidirectional
        self.batch_size = batch_size
        self.device = device

        # # define LSTM layer
        # self.layers = []

        # # add linear layers 
        # for __id,layer_out in enumerate(self.linear_filters):
        #     if __id == 0:
        #         self.layers.append(nn.Linear(self.input_size, layer_out))
        #     else:
        #         self.layers.append(nn.Linear(self.linear_filters[__id-1], layer_out))

        # add lstm layer
        self.lstm = nn.LSTM(input_size = input_size, hidden_size = self.hidden_size,
                            num_layers = self.num_layers, bidirectional=self.bidirectional,
                            batch_first=True, proj_size = output_size)
        # self.net = nn.Sequential(*self.layers)

        # #add embedding out
        # if bidirectional:
        #     self.out_linear = nn.Linear(self.hidden_size*4, self.embedding_size)
        # else:
        #     self.out_linear = nn.Linear(self.hidden_size*2, self.embedding_size)

        
    def forward(self, x):
        '''
        : param x_input:               input of shape (seq_len, # in batch, input_size)
        : return lstm_out, hidden:     lstm_out gives all the hidden states in the sequence; hidden gives the hidden state and cell state for the last element in the sequence                         
        '''
        
        # x = self.net(x_input)
        out, _ = self.lstm(x)
        # hidden_transformed = torch.concat(self.hidden,0)
        # hidden_transformed = torch.transpose(hidden_transformed,0,1)
        # hidden_transformed = torch.flatten(hidden_transformed,start_dim=1)
        # hidden_transformed = self.out_linear(hidden_transformed)
        
        return out

In [94]:
dec = BiLSTMDecoder(input_size=512, hidden_size=128, num_layers=2, output_size=36, bidirectional=False)

In [95]:
torchinfo.summary(dec, input_size=(32, 60, 512), col_names = ("input_size", "output_size", "num_params", "kernel_size", "mult_adds"))

Layer (type:depth-idx)                   Input Shape               Output Shape              Param #                   Kernel Shape              Mult-Adds
BiLSTMDecoder                            [32, 60, 512]             [32, 60, 36]              --                        --                        --
├─LSTM: 1-1                              [32, 60, 512]             [32, 60, 36]              328,704                   --                        631,111,680
Total params: 328,704
Trainable params: 328,704
Non-trainable params: 0
Total mult-adds (M): 631.11
Input size (MB): 3.93
Forward/backward pass size (MB): 0.55
Params size (MB): 1.31
Estimated Total Size (MB): 5.80

In [98]:
class BaseModel(nn.Module):
    def __init__(self, config):
        super(BaseModel, self).__init__()
        
        self.imu_model = IMUEncoder(**config['imu_config'])
        self.skel_encoder = SGN(**config['sgn_config'])
        self.skel_decoder = BiLSTMDecoder(**config['dec_config'])
        self.lxmert_config = LxmertConfig(**config['xmert_config'])
        self.lxmert_xlayer = LxmertXLayer(self.lxmert_config)

        self.num_layers = config['num_layers']

    def forward(self, x_imu, x_skel):
        imu_feats = self.imu_model(x_imu).unsqueeze(1)
        skel_feats = self.skel_encoder(x_skel)
        print(f"imu_feats {imu_feats.shape} | skel_feats {skel_feats.shape}")
        for i in range(self.num_layers):
            x_outputs = self.lxmert_xlayer(
                lang_feats = skel_feats,
                lang_attention_mask = None,  
                visual_feats = imu_feats,
                visual_attention_mask = None,
                input_id = None,
                output_attentions=False,
            )
            skel_feats, imu_feats = x_outputs[:2]

        skel_recon = self.skel_decoder(skel_feats)
        return imu_feats, skel_recon

In [99]:
imu_config = {
    'in_ft': 54,
    'd_model': 256,
    'num_heads': 2,
    'ft_size': 512,
    'n_classes': 18,
    'max_len': 60
}

sgn_config = {
    'num_joint': 12,
    'seg': 60,
    'hidden_size': 256,
    'train': True,
    'bs': 32,
}

dec_config = {
    'input_size': 512,
    'hidden_size': 256,
    'num_layers': 2,
    'output_size': 36,
    'bidirectional': False
}



base_config = {
    'imu_config': imu_config,
    'sgn_config': sgn_config,
    'dec_config': dec_config,
    'num_layers': 1,
    'xmert_config': {
        'vocab_size': 1024,
        'hidden_size': 512,
        'num_attention_heads': 2,
        'intermediate_size': 512
    }
}
 
base_model = BaseModel(base_config)

In [100]:
imu_input = torch.randn((32, 60, 54))
skel_input = torch.randn((32, 60, 36))

imu_output, skel_recon = base_model(imu_input, skel_input)

imu_feats torch.Size([32, 1, 512]) | skel_feats torch.Size([32, 60, 512])


  result = _VF.lstm(input, hx, self._flat_weights, self.bias, self.num_layers,


In [101]:
print(f"imu_output : {imu_output.shape} | skel_output : {skel_recon.shape}")

imu_output : torch.Size([32, 1, 512]) | skel_output : torch.Size([32, 60, 36])


## Archi3

In [2]:
import torch 
from torch import nn, Tensor
from torch.nn import functional as F
from torch.nn.modules import MultiheadAttention, Linear, Dropout, BatchNorm1d, TransformerEncoderLayer
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam
from torch.nn import MSELoss,L1Loss

import torchinfo

# from model.sgn_model import embed, local, gcn_spa, compute_g_spa 
import math

In [3]:
class norm_data(nn.Module):
    def __init__(self, dim=3, joints=20):
        super(norm_data, self).__init__()

        self.bn = nn.BatchNorm1d(dim*joints)

    def forward(self, x):
        bs, c, num_joints, step = x.size()
        x = x.view(bs, -1, step)
        x = self.bn(x)
        x = x.view(bs, -1, num_joints, step).contiguous()
        return x

class embed(nn.Module):
    def __init__(self, dim=3, joint=20, hidden_dim=128, norm=True, bias=False):
        super(embed, self).__init__()

        if norm:
            self.cnn = nn.Sequential(
                norm_data(dim, joint),
                cnn1x1(dim, 64, bias=bias),
                nn.ReLU(),
                cnn1x1(64, hidden_dim, bias=bias),
                nn.ReLU(),
            )
        else:
            self.cnn = nn.Sequential(
                cnn1x1(dim, 64, bias=bias),
                nn.ReLU(),
                cnn1x1(64, hidden_dim, bias=bias),
                nn.ReLU(),
            )

    def forward(self, x):
        x = self.cnn(x)
        return x

class cnn1x1(nn.Module):
    def __init__(self, dim1 = 3, dim2 =3, bias = True):
        super(cnn1x1, self).__init__()
        self.cnn = nn.Conv2d(dim1, dim2, kernel_size=1, bias=bias)

    def forward(self, x):
        x = self.cnn(x)
        return x

class local(nn.Module):
    def __init__(self, dim1 = 3, dim2 = 3, bias = False):
        super(local, self).__init__()
        self.maxpool = nn.AdaptiveMaxPool2d((1, None))
        self.cnn1 = nn.Conv2d(dim1, dim1, kernel_size=(1, 3), padding=(0, 1), bias=bias)
        self.bn1 = nn.BatchNorm2d(dim1)
        self.relu = nn.ReLU()
        self.cnn2 = nn.Conv2d(dim1, dim2, kernel_size=1, bias=bias)
        self.bn2 = nn.BatchNorm2d(dim2)
        self.dropout = nn.Dropout2d(0.2)

    def forward(self, x1):
        x1 = self.maxpool(x1)
        x = self.cnn1(x1)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.dropout(x)
        x = self.cnn2(x)
        x = self.bn2(x)
        x = self.relu(x)

        return x

class gcn_spa(nn.Module):
    def __init__(self, in_feature, out_feature, bias = False):
        super(gcn_spa, self).__init__()
        self.bn = nn.BatchNorm2d(out_feature)
        self.relu = nn.ReLU()
        self.w = cnn1x1(in_feature, out_feature, bias=False)
        self.w1 = cnn1x1(in_feature, out_feature, bias=bias)


    def forward(self, x1, g):
        x = x1.permute(0, 3, 2, 1).contiguous()
        x = g.matmul(x)
        x = x.permute(0, 3, 2, 1).contiguous()
        x = self.w(x) + self.w1(x1)
        x = self.relu(self.bn(x))
        return x

class compute_g_spa(nn.Module):
    def __init__(self, dim1 = 64 *3, dim2 = 64*3, bias = False):
        super(compute_g_spa, self).__init__()
        self.dim1 = dim1
        self.dim2 = dim2
        self.g1 = cnn1x1(self.dim1, self.dim2, bias=bias)
        self.g2 = cnn1x1(self.dim1, self.dim2, bias=bias)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x1):

        g1 = self.g1(x1).permute(0, 3, 2, 1).contiguous()
        g2 = self.g2(x1).permute(0, 3, 1, 2).contiguous()
        g3 = g1.matmul(g2)
        g = self.softmax(g3)
        return g
    

class SGNEncoder(nn.Module):
    def __init__(self, num_joint, seg, hidden_size=128, bs=32, is_3d=True, train=True, bias=True, device='cpu'):
        super(SGNEncoder, self).__init__()

        self.dim1 = hidden_size
        self.dim_unit = hidden_size // 4 
        self.seg = seg
        self.num_joint = num_joint
        self.bs = bs

        if is_3d:
          self.spatial_dim = 3
        else:
          self.spatial_dim = 2

        if train:
            self.spa = self.one_hot(bs, num_joint, self.seg)
            self.spa = self.spa.permute(0, 3, 2, 1).to(device)
            self.tem = self.one_hot(bs, self.seg, num_joint)
            self.tem = self.tem.permute(0, 3, 1, 2).to(device)
        else:
            self.spa = self.one_hot(32 * 5, num_joint, self.seg)
            self.spa = self.spa.permute(0, 3, 2, 1).to(device)
            self.tem = self.one_hot(32 * 5, self.seg, num_joint)
            self.tem = self.tem.permute(0, 3, 1, 2).to(device)

        self.tem_embed = embed(self.seg, joint=self.num_joint, hidden_dim=self.dim_unit*4, norm=False, bias=bias)
        self.spa_embed = embed(num_joint, joint=self.num_joint, hidden_dim=self.dim_unit, norm=False, bias=bias)
        self.joint_embed = embed(self.spatial_dim, joint=self.num_joint, hidden_dim=self.dim_unit, norm=True, bias=bias)
        self.dif_embed = embed(self.spatial_dim, joint=self.num_joint, hidden_dim=self.dim_unit, norm=True, bias=bias)
        self.maxpool = nn.AdaptiveMaxPool2d([1, 1])
        self.cnn = local(self.dim1, self.dim1 * 2, bias=bias)
        self.compute_g1 = compute_g_spa(self.dim1 // 2, self.dim1, bias=bias)
        self.gcn1 = gcn_spa(self.dim1 // 2, self.dim1 // 2, bias=bias)
        self.gcn2 = gcn_spa(self.dim1 // 2, self.dim1, bias=bias)
        self.gcn3 = gcn_spa(self.dim1, self.dim1, bias=bias)


        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))

        nn.init.constant_(self.gcn1.w.cnn.weight, 0)
        nn.init.constant_(self.gcn2.w.cnn.weight, 0)
        nn.init.constant_(self.gcn3.w.cnn.weight, 0)


    def forward(self, input):
        
        # Dynamic Representation
        input = input.view((self.bs, self.seg, self.num_joint, self.spatial_dim))
        input = input.permute(0, 3, 2, 1).contiguous().to()
        dif = input[:, :, :, 1:] - input[:, :, :, 0:-1]
        dif = torch.cat([dif.new(self.bs, dif.size(1), self.num_joint, 1).zero_(), dif], dim=-1)
        # print(input.shape)
        print("input :", input.get_device(), " dif : ", dif.get_device())
        pos = self.joint_embed(input)
        print("pos : ", pos.get_device(), " tem : ", self.tem.get_device(), " spa : ", self.spa.get_device())
        tem1 = self.tem_embed(self.tem)
        spa1 = self.spa_embed(self.spa)
        dif = self.dif_embed(dif)
        dy = pos + dif
        # Joint-level Module
        input= torch.cat([dy, spa1], 1)
        g = self.compute_g1(input)
        input = self.gcn1(input, g)
        input = self.gcn2(input, g)
        input = self.gcn3(input, g)
        # Frame-level Module
        input = input + tem1
        input = self.cnn(input)
        output_feat = torch.squeeze(input).permute(0,2,1)

        return output_feat

    def one_hot(self, bs, spa, tem):

        y = torch.arange(spa).unsqueeze(-1)
        y_onehot = torch.FloatTensor(spa, spa)

        y_onehot.zero_()
        y_onehot.scatter_(1, y, 1)

        y_onehot = y_onehot.unsqueeze(0).unsqueeze(0)
        y_onehot = y_onehot.repeat(bs, tem, 1, 1)

        return y_onehot

class SGNClassifier(nn.Module):
  def __init__(self,num_classes,embedding_size, *args, **kwargs) -> None:
      super().__init__(*args, **kwargs)
      self.num_classes = num_classes
      self.embedding_size = embedding_size
      self.maxpool = nn.AdaptiveMaxPool2d([embedding_size//2, 2])
      self.fc = nn.Linear(self.embedding_size, self.num_classes)

  def forward(self, input):
      output = self.maxpool(input)
      output = torch.flatten(output, 1)
      output = self.fc(output)
      return output
    
class BiLSTMDecoder(nn.Module):
    def __init__(self,seq_len, input_size, hidden_size, linear_filters, embedding_size:int, num_layers = 1, bidirectional=True, device='cpu'):
        super(BiLSTMDecoder, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.device = device
        self.num_layers = num_layers
        self.linear_filters = linear_filters[::-1]
        self.embedding_size = embedding_size
        self.bidirectional = bidirectional
        self.seq_len = seq_len

        if bidirectional:
            self.input_linear = nn.Linear(self.embedding_size,4*self.hidden_size)
        else:
            self.input_linear = nn.Linear(self.embedding_size,2*self.hidden_size)

        # define LSTM layer
        self.layers = []
        # add lstm
        self.lstm = nn.LSTM(input_size = self.seq_len, hidden_size = self.hidden_size,
                            num_layers = self.num_layers, bidirectional=True,
                            batch_first=True)
        
        self.maxpool = nn.AdaptiveMaxPool2d([self.seq_len, self.seq_len])
        # add linear layers 
        if bidirectional:
            self.layers.append(nn.Linear(2*hidden_size,self.linear_filters[0]))
        else:
            self.layers.append(nn.Linear(hidden_size,self.linear_filters[0]))

        for __id,layer_in in enumerate(self.linear_filters):
            if __id == len(linear_filters)-1:
                self.layers.append(nn.Linear(layer_in,self.input_size))
            else:
                self.layers.append(nn.Linear(layer_in,self.linear_filters[__id+1]))

        self.net = nn.Sequential(*self.layers)

    def forward(self,encoder_hidden):
        """
        : param x_input:               input of shape (seq_len, # in batch, input_size)
        : return lstm_out, hidden:     lstm_out gives all the hidden states in the sequence; hidden gives the hidden state and cell state for the last element in the sequence
        """
        
        output = self.maxpool(encoder_hidden)
        lstm_out, self.hidden = self.lstm(output)
        x = self.net(lstm_out)
        # x = x.permute(0,2,1)
        return x

class EncDecModel(nn.Module):
    def __init__(self,encoder,decoder,classifier):
        super(EncDecModel, self).__init__()
        
        self.encoder = encoder
        self.decoder = decoder
        self.classifier = classifier
        
    def forward(self,x):
        embedding = self.encoder(x)
        classifier_out = self.classifier(embedding)
        decoder_out = self.decoder(embedding)
        
        return decoder_out, embedding, classifier_out
        

In [4]:
sgn_config = {
    'num_joint': 12,
    'seg': 60,
    'hidden_size': 512,
    'train': True,
    'bs': 32,
    'is_3d': False
}

In [5]:
sgn_model = SGNEncoder(**sgn_config)
sgn_input = torch.randn((32, 60, 24))
sgn_output = sgn_model(sgn_input)
sgn_output.shape

input : -1  dif :  -1
pos :  -1  tem :  -1  spa :  -1


torch.Size([32, 60, 1024])

In [6]:
dec_config = {
    'seq_len': 60,
    'input_size': 512,
    'hidden_size': 256,
    'num_layers': 2,
    'bidirectional': True,
    "embedding_size": 512,
    "linear_filters":[128,256,512,1024],
}

# seq_len, input_size, hidden_size, linear_filters, embedding_size:int, num_layers = 1,bidirectional=True, device='cpu'

In [7]:
skel_dec = BiLSTMDecoder(**dec_config)
skel_input = torch.randn((32, 60, 512))
skel_output = skel_dec(skel_input)
skel_output.shape

torch.Size([32, 60, 512])

In [24]:
import torch
from torch import nn
from torch.nn import TransformerEncoder, TransformerEncoderLayer


class IMUTransformerEncoder(nn.Module):

    def __init__(self, config):
        """
        config: (dict) configuration of the model
        """
        super().__init__()

        self.transformer_dim = config.get("transformer_dim")

        self.input_proj = nn.Sequential(nn.Conv1d(config.get("input_dim"), self.transformer_dim, 1), nn.GELU(),
                                        nn.Conv1d(self.transformer_dim, self.transformer_dim, 1), nn.GELU(),
                                        nn.Conv1d(self.transformer_dim, self.transformer_dim, 1), nn.GELU(),
                                        nn.Conv1d(self.transformer_dim, self.transformer_dim, 1), nn.GELU())

        self.window_size = config.get("window_size")
        self.encode_position = config.get("encode_position")
        encoder_layer = TransformerEncoderLayer(d_model = self.transformer_dim,
                                       nhead = config.get("nhead"),
                                       dim_feedforward = config.get("dim_feedforward"),
                                       dropout = config.get("transformer_dropout"),
                                       activation = config.get("transformer_activation"))

        self.transformer_encoder = TransformerEncoder(encoder_layer,
                                              num_layers = config.get("num_encoder_layers"),
                                              norm = nn.LayerNorm(self.transformer_dim))
        self.cls_token = nn.Parameter(torch.zeros((1, self.transformer_dim)), requires_grad=True)

        if self.encode_position:
            self.position_embed = nn.Parameter(torch.randn(self.window_size + 1, 1, self.transformer_dim))

        # num_classes =  config.get("num_classes")
        output_size = config.get("output_size")
        self.imu_head = nn.Sequential(
            nn.AvgPool2d((self.window_size,1)),
            nn.Linear(self.transformer_dim,  output_size),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(output_size, output_size)
        )
        self.sigmoid = nn.Sigmoid()

        # init
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, data):
        src = data  # Shape N x S x C with S = sequence length, N = batch size, C = channels

        # Embed in a high dimensional space and reshape to Transformer's expected shape
        src = self.input_proj(src.transpose(1, 2)).permute(2, 0, 1)

        # Prepend class token
        cls_token = self.cls_token.unsqueeze(1).repeat(1, src.shape[1], 1)
        src = torch.cat([cls_token, src])

        # Add the position embedding
        if self.encode_position:
            src += self.position_embed

        # Transformer Encoder pass
        target = self.transformer_encoder(src)

        # Class probability
        target = torch.squeeze(target.permute(1,0,2))
        # target = self.imu_head(target)
        return target

def get_activation(activation):
    """Return an activation function given a string"""
    if activation == "relu":
        return nn.ReLU(inplace=True)
    if activation == "gelu":
        return nn.GELU()
    raise RuntimeError("Activation {} not supported".format(activation))

In [25]:
imu_config = {
	"input_dim": 6,
    "window_size":50,
	"encode_position":True,
	"transformer_dim": 256,
	"nhead": 8,
	"num_encoder_layers": 6, 
	"dim_feedforward": 128, 
	"transformer_dropout": 0.1, 
	"transformer_activation": "gelu",
	"head_activation": "gelu",
    "baseline_dropout": 0.1,
	"batch_size": 32,
	"output_size": 512
}

In [26]:
imu_model = IMUTransformerEncoder(imu_config)
imu_input = torch.randn((32, 50, 6))
imu_output = imu_model(imu_input)
imu_output.shape

torch.Size([32, 51, 256])

In [27]:
torchinfo.summary(imu_model, input_size=(32, 50, 6), col_names = ("input_size", "output_size", "num_params", "kernel_size", "mult_adds"))

Layer (type:depth-idx)                        Input Shape               Output Shape              Param #                   Kernel Shape              Mult-Adds
IMUTransformerEncoder                         [32, 50, 6]               [32, 51, 256]             407,552                   --                        --
├─Sequential: 1-1                             [32, 6, 50]               [32, 256, 50]             --                        --                        --
│    └─Conv1d: 2-1                            [32, 6, 50]               [32, 256, 50]             1,792                     [1]                       2,867,200
│    └─GELU: 2-2                              [32, 256, 50]             [32, 256, 50]             --                        --                        --
│    └─Conv1d: 2-3                            [32, 256, 50]             [32, 256, 50]             65,792                    [1]                       105,267,200
│    └─GELU: 2-4                              [32, 256, 50]

In [11]:
import sys 
sys.path.append('../')


# from model.modeling_lxmert import LxmertConfig, LxmertXLayer

  from .autonotebook import tqdm as notebook_tqdm


In [12]:
class CrossAttention(nn.Module):
    def __init__(self, input1_dim, input2_dim, hidden_dim):
        super(CrossAttention, self).__init__()
        self.W1 = nn.Linear(input1_dim, hidden_dim, bias=False)
        self.W2 = nn.Linear(input2_dim, hidden_dim, bias=False)

    def forward(self, input1, input2):
        # Compute attention weights for input1 with respect to input2
        a1 = torch.matmul(self.W1(input1), self.W2(input2).transpose(1, 2))
        a1 = torch.softmax(a1, dim=2)
        v1 = torch.matmul(a1, input2)

        # Compute attention weights for input2 with respect to input1
        a2 = torch.matmul(self.W2(input2), self.W1(input1).transpose(1, 2))
        a2 = torch.softmax(a2, dim=2)
        v2 = torch.matmul(a2, input1)

        return v1, v2

In [17]:
class BaseModel(nn.Module):
    def __init__(self, config):
        super(BaseModel, self).__init__()
        
        self.imu_model = IMUTransformerEncoder(config['imu_config'])
        self.skel_encoder = SGNEncoder(**config['sgn_config'])
        self.skel_decoder = BiLSTMDecoder(**config['dec_config'])
        # self.lxmert_config = LxmertConfig(**config['xmert_config'])
        self.cross_attn = CrossAttention(**config['xmert_config'])

        self.num_layers = config['num_layers']

    def forward(self, x_imu, x_skel):
        imu_feats = self.imu_model(x_imu).unsqueeze(1)
        skel_feats = self.skel_encoder(x_skel)
        print(f"imu_feats {imu_feats.shape} | skel_feats {skel_feats.shape}")
        skel_feats, imu_feats = self.cross_attn(skel_feats, imu_feats)

        skel_recon = self.skel_decoder(skel_feats)
        return imu_feats, skel_recon

In [18]:
imu_config = {
	"input_dim": 54,
    "window_size":60,
	"encode_position":True,
	"transformer_dim": 256,
	"nhead": 8,
	"num_encoder_layers": 6, 
	"dim_feedforward": 128, 
	"transformer_dropout": 0.1, 
	"transformer_activation": "gelu",
	"head_activation": "gelu",
    "baseline_dropout": 0.1,
	"batch_size": 32,
	"output_size": 512
}

sgn_config = {
    'num_joint': 12,
    'seg': 60,
    'hidden_size': 256,
    'train': True,
    'bs': 32,
    'is_3d': False
}

dec_config = {
    'seq_len': 60,
    'input_size': 512,
    'hidden_size': 256,
    'num_layers': 2,
    'bidirectional': True,
    "embedding_size": 512,
    "linear_filters":[128,256,512,1024],
}


base_config = {
    'imu_config': imu_config,
    'sgn_config': sgn_config,
    'dec_config': dec_config,
    'num_layers': 1,
    'xmert_config': {
        'input1_dim': 60,
        'input2_dim': 1,
        'hidden_dim': 512,
    }
}
 

base_model = BaseModel(base_config)

In [None]:
imu_input = torch.randn((32, 60, 54))
skel_input = torch.randn((32, 60, 24))

imu_output, skel_recon = base_model(imu_input, skel_input)

In [30]:
print(f"imu_output : {imu_output.shape} | skel_output : {skel_recon.shape}")

imu_output : torch.Size([32, 1, 512]) | skel_output : torch.Size([32, 60, 512])


In [31]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
base_model.to(device)
imu_input = imu_input.to(device)
skel_input = skel_input.to(device)

In [None]:
imu_output, skel_recon = base_model(imu_input, skel_input)

## Archi-4

In [2]:
import torch 
from torch import nn, Tensor
from torch.nn import functional as F
from torch.nn.modules import MultiheadAttention, Linear, Dropout, BatchNorm1d, TransformerEncoderLayer
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam
from torch.nn import MSELoss,L1Loss

import torchinfo

# from model.sgn_model import embed, local, gcn_spa, compute_g_spa 
import math

In [3]:
class norm_data(nn.Module):
    def __init__(self, dim=3, joints=20):
        super(norm_data, self).__init__()

        self.bn = nn.BatchNorm1d(dim*joints)

    def forward(self, x):
        bs, c, num_joints, step = x.size()
        x = x.view(bs, -1, step)
        x = self.bn(x)
        x = x.view(bs, -1, num_joints, step).contiguous()
        return x

class embed(nn.Module):
    def __init__(self, dim=3, joint=20, hidden_dim=128, norm=True, bias=False):
        super(embed, self).__init__()

        if norm:
            self.cnn = nn.Sequential(
                norm_data(dim, joint),
                cnn1x1(dim, 64, bias=bias),
                nn.ReLU(),
                cnn1x1(64, hidden_dim, bias=bias),
                nn.ReLU(),
            )
        else:
            self.cnn = nn.Sequential(
                cnn1x1(dim, 64, bias=bias),
                nn.ReLU(),
                cnn1x1(64, hidden_dim, bias=bias),
                nn.ReLU(),
            )

    def forward(self, x):
        x = self.cnn(x)
        return x

class cnn1x1(nn.Module):
    def __init__(self, dim1 = 3, dim2 =3, bias = True):
        super(cnn1x1, self).__init__()
        self.cnn = nn.Conv2d(dim1, dim2, kernel_size=1, bias=bias)

    def forward(self, x):
        x = self.cnn(x)
        return x

class local(nn.Module):
    def __init__(self, dim1 = 3, dim2 = 3, bias = False):
        super(local, self).__init__()
        self.maxpool = nn.AdaptiveMaxPool2d((1, None))
        self.cnn1 = nn.Conv2d(dim1, dim1, kernel_size=(1, 3), padding=(0, 1), bias=bias)
        self.bn1 = nn.BatchNorm2d(dim1)
        self.relu = nn.ReLU()
        self.cnn2 = nn.Conv2d(dim1, dim2, kernel_size=1, bias=bias)
        self.bn2 = nn.BatchNorm2d(dim2)
        self.dropout = nn.Dropout2d(0.2)

    def forward(self, x1):
        x1 = self.maxpool(x1)
        x = self.cnn1(x1)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.dropout(x)
        x = self.cnn2(x)
        x = self.bn2(x)
        x = self.relu(x)

        return x

class gcn_spa(nn.Module):
    def __init__(self, in_feature, out_feature, bias = False):
        super(gcn_spa, self).__init__()
        self.bn = nn.BatchNorm2d(out_feature)
        self.relu = nn.ReLU()
        self.w = cnn1x1(in_feature, out_feature, bias=False)
        self.w1 = cnn1x1(in_feature, out_feature, bias=bias)


    def forward(self, x1, g):
        x = x1.permute(0, 3, 2, 1).contiguous()
        x = g.matmul(x)
        x = x.permute(0, 3, 2, 1).contiguous()
        x = self.w(x) + self.w1(x1)
        x = self.relu(self.bn(x))
        return x

class compute_g_spa(nn.Module):
    def __init__(self, dim1 = 64 *3, dim2 = 64*3, bias = False):
        super(compute_g_spa, self).__init__()
        self.dim1 = dim1
        self.dim2 = dim2
        self.g1 = cnn1x1(self.dim1, self.dim2, bias=bias)
        self.g2 = cnn1x1(self.dim1, self.dim2, bias=bias)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x1):

        g1 = self.g1(x1).permute(0, 3, 2, 1).contiguous()
        g2 = self.g2(x1).permute(0, 3, 1, 2).contiguous()
        g3 = g1.matmul(g2)
        g = self.softmax(g3)
        return g
    

class SGNEncoder(nn.Module):
    def __init__(self, num_joint, seg, hidden_size=128, bs=32, is_3d=True, train=True, bias=True, device='cpu'):
        super(SGNEncoder, self).__init__()

        self.dim1 = hidden_size
        self.dim_unit = hidden_size // 4 
        self.seg = seg
        self.num_joint = num_joint
        self.bs = bs

        if is_3d:
          self.spatial_dim = 3
        else:
          self.spatial_dim = 2

        if train:
            self.spa = self.one_hot(bs, num_joint, self.seg)
            self.spa = self.spa.permute(0, 3, 2, 1).to(device)
            self.tem = self.one_hot(bs, self.seg, num_joint)
            self.tem = self.tem.permute(0, 3, 1, 2).to(device)
        else:
            self.spa = self.one_hot(32 * 5, num_joint, self.seg)
            self.spa = self.spa.permute(0, 3, 2, 1).to(device)
            self.tem = self.one_hot(32 * 5, self.seg, num_joint)
            self.tem = self.tem.permute(0, 3, 1, 2).to(device)

        self.tem_embed = embed(self.seg, joint=self.num_joint, hidden_dim=self.dim_unit*4, norm=False, bias=bias)
        self.spa_embed = embed(num_joint, joint=self.num_joint, hidden_dim=self.dim_unit, norm=False, bias=bias)
        self.joint_embed = embed(self.spatial_dim, joint=self.num_joint, hidden_dim=self.dim_unit, norm=True, bias=bias)
        self.dif_embed = embed(self.spatial_dim, joint=self.num_joint, hidden_dim=self.dim_unit, norm=True, bias=bias)
        self.maxpool = nn.AdaptiveMaxPool2d([1, 1])
        self.cnn = local(self.dim1, self.dim1 * 2, bias=bias)
        self.compute_g1 = compute_g_spa(self.dim1 // 2, self.dim1, bias=bias)
        self.gcn1 = gcn_spa(self.dim1 // 2, self.dim1 // 2, bias=bias)
        self.gcn2 = gcn_spa(self.dim1 // 2, self.dim1, bias=bias)
        self.gcn3 = gcn_spa(self.dim1, self.dim1, bias=bias)


        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))

        nn.init.constant_(self.gcn1.w.cnn.weight, 0)
        nn.init.constant_(self.gcn2.w.cnn.weight, 0)
        nn.init.constant_(self.gcn3.w.cnn.weight, 0)


    def forward(self, input):
        
        # Dynamic Representation
        input = input.view((self.bs, self.seg, self.num_joint, self.spatial_dim))
        input = input.permute(0, 3, 2, 1).contiguous().to()
        dif = input[:, :, :, 1:] - input[:, :, :, 0:-1]
        dif = torch.cat([dif.new(self.bs, dif.size(1), self.num_joint, 1).zero_(), dif], dim=-1)
        # print(input.shape)
        # print("input :", input.get_device(), " dif : ", dif.get_device())
        pos = self.joint_embed(input)
        # print("pos : ", pos.get_device(), " tem : ", self.tem.get_device(), " spa : ", self.spa.get_device())
        tem1 = self.tem_embed(self.tem)
        spa1 = self.spa_embed(self.spa)
        dif = self.dif_embed(dif)
        dy = pos + dif
        # Joint-level Module
        input= torch.cat([dy, spa1], 1)
        g = self.compute_g1(input)
        input = self.gcn1(input, g)
        input = self.gcn2(input, g)
        input = self.gcn3(input, g)
        # Frame-level Module
        input = input + tem1
        input = self.cnn(input)
        output_feat = torch.squeeze(input)
        output_feat = output_feat.permute(0,2,1).contiguous()
        # output_feat = input
        return output_feat

    def one_hot(self, bs, spa, tem):

        y = torch.arange(spa).unsqueeze(-1)
        y_onehot = torch.FloatTensor(spa, spa)

        y_onehot.zero_()
        y_onehot.scatter_(1, y, 1)

        y_onehot = y_onehot.unsqueeze(0).unsqueeze(0)
        y_onehot = y_onehot.repeat(bs, tem, 1, 1)

        return y_onehot

        

In [4]:
sgn_config = {
    'num_joint': 12,
    'seg': 60,
    'hidden_size': 512,
    'train': True,
    'bs': 32,
    'is_3d': False
}

In [5]:
from torch.autograd import gradcheck

In [6]:
sgn_model = SGNEncoder(**sgn_config)
sgn_input = torch.randn((32, 60, 24),requires_grad=True)
sgn_output = sgn_model(sgn_input)
sgn_output.shape

torch.Size([32, 60, 1024])

In [7]:
dec_config = {
    'seq_len': 60,
    'input_size': 512,
    'hidden_size': 256,
    'num_layers': 2,
    'bidirectional': True,
    "embedding_size": 512,
    "linear_filters":[128,256,512,1024],
}

In [16]:
skel_dec = BiLSTMDecoder(**dec_config)
skel_input = torch.randn((32, 60, 512))
skel_output = skel_dec(skel_input)
skel_output.shape

NameError: name 'BiLSTMDecoder' is not defined

In [8]:
import torch
from torch import nn
from torch.nn import TransformerEncoder, TransformerEncoderLayer


class IMUTransformerEncoder(nn.Module):

    def __init__(self, config):
        """
        config: (dict) configuration of the model
        """
        super().__init__()

        self.transformer_dim = config.get("transformer_dim")

        self.input_proj = nn.Sequential(nn.Conv1d(config.get("input_dim"), self.transformer_dim, 1), nn.GELU(),
                                        nn.Conv1d(self.transformer_dim, self.transformer_dim, 1), nn.GELU(),
                                        nn.Conv1d(self.transformer_dim, self.transformer_dim, 1), nn.GELU(),
                                        nn.Conv1d(self.transformer_dim, self.transformer_dim, 1), nn.GELU())

        self.window_size = config.get("window_size")
        self.encode_position = config.get("encode_position")
        encoder_layer = TransformerEncoderLayer(d_model = self.transformer_dim,
                                       nhead = config.get("nhead"),
                                       dim_feedforward = config.get("dim_feedforward"),
                                       dropout = config.get("transformer_dropout"),
                                       activation = config.get("transformer_activation"))

        self.transformer_encoder = TransformerEncoder(encoder_layer,
                                              num_layers = config.get("num_encoder_layers"),
                                              norm = nn.LayerNorm(self.transformer_dim))
        self.cls_token = nn.Parameter(torch.zeros((1, self.transformer_dim)), requires_grad=True)

        if self.encode_position:
            self.position_embed = nn.Parameter(torch.randn(self.window_size + 1, 1, self.transformer_dim))

        # num_classes =  config.get("num_classes")
        output_size = config.get("output_size")
        self.imu_head = nn.Sequential(
            nn.AvgPool2d((self.window_size,1)),
            nn.Linear(self.transformer_dim,  output_size),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(output_size, output_size)
        )
        self.sigmoid = nn.Sigmoid()

        # init
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, data):
        src = data  # Shape N x S x C with S = sequence length, N = batch size, C = channels

        # Embed in a high dimensional space and reshape to Transformer's expected shape
        src = self.input_proj(src.transpose(1, 2)).permute(2, 0, 1)

        # Prepend class token
        cls_token = self.cls_token.unsqueeze(1).repeat(1, src.shape[1], 1)
        src = torch.cat([cls_token, src])

        # Add the position embedding
        if self.encode_position:
            src += self.position_embed

        # Transformer Encoder pass
        target = self.transformer_encoder(src)

        # Class probability
        target = torch.squeeze(target.permute(1,0,2))
        # target = self.imu_head(target)
        return target

def get_activation(activation):
    """Return an activation function given a string"""
    if activation == "relu":
        return nn.ReLU(inplace=True)
    if activation == "gelu":
        return nn.GELU()
    raise RuntimeError("Activation {} not supported".format(activation))

In [9]:
imu_config = {
	"input_dim": 6,
    "window_size":50,
	"encode_position":True,
	"transformer_dim": 256,
	"nhead": 8,
	"num_encoder_layers": 6, 
	"dim_feedforward": 128, 
	"transformer_dropout": 0.1, 
	"transformer_activation": "gelu",
	"head_activation": "gelu",
    "baseline_dropout": 0.1,
	"batch_size": 32,
	"output_size": 512
}

In [10]:
imu_model = IMUTransformerEncoder(imu_config)
imu_input = torch.randn((32, 50, 6))
imu_output = imu_model(imu_input)
imu_output.shape

torch.Size([32, 51, 256])

In [11]:
torchinfo.summary(imu_model, input_size=(32, 50, 6), col_names = ("input_size", "output_size", "num_params", "kernel_size", "mult_adds"))

  action_fn=lambda data: sys.getsizeof(data.storage()),
  return super().__sizeof__() + self.nbytes()


Layer (type:depth-idx)                        Input Shape               Output Shape              Param #                   Kernel Shape              Mult-Adds
IMUTransformerEncoder                         [32, 50, 6]               [32, 51, 256]             407,552                   --                        --
├─Sequential: 1-1                             [32, 6, 50]               [32, 256, 50]             --                        --                        --
│    └─Conv1d: 2-1                            [32, 6, 50]               [32, 256, 50]             1,792                     [1]                       2,867,200
│    └─GELU: 2-2                              [32, 256, 50]             [32, 256, 50]             --                        --                        --
│    └─Conv1d: 2-3                            [32, 256, 50]             [32, 256, 50]             65,792                    [1]                       105,267,200
│    └─GELU: 2-4                              [32, 256, 50]

In [12]:
from src.models.bidirectional_cross_attention import BidirectionalCrossAttention

In [16]:
video = torch.randn(32, 1, 512)
audio = torch.randn(32, 60, 386)

video_mask = torch.ones((32, 1)).bool()
audio_mask = torch.ones((32, 60)).bool()

joint_cross_attn = BidirectionalCrossAttention(
    dim = 512,
    heads = 8,
    dim_head = 64,
    context_dim = 386
)

video_out, audio_out = joint_cross_attn(
    video,
    audio,
    mask = video_mask,
    context_mask = audio_mask
)

In [18]:
video_out.shape

torch.Size([32, 1, 512])

In [27]:
audio_out.size()[:2]

torch.Size([32, 60])

---

## Archi-5

In [2]:
import torch 
from torch import nn, Tensor
from torch.nn import functional as F
from torch.nn.modules import MultiheadAttention, Linear, Dropout, BatchNorm1d, TransformerEncoderLayer
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam
from torch.nn import MSELoss,L1Loss

import math
import torchinfo

In [3]:
class norm_data(nn.Module):
    def __init__(self, dim=3, joints=20):
        super(norm_data, self).__init__()

        self.bn = nn.BatchNorm1d(dim*joints)

    def forward(self, x):
        bs, c, num_joints, step = x.size()
        x = x.view(bs, -1, step)
        x = self.bn(x)
        x = x.view(bs, -1, num_joints, step).contiguous()
        return x

class embed(nn.Module):
    def __init__(self, dim=3, joint=20, hidden_dim=128, norm=True, bias=False):
        super(embed, self).__init__()

        if norm:
            self.cnn = nn.Sequential(
                norm_data(dim, joint),
                cnn1x1(dim, 64, bias=bias),
                nn.ReLU(),
                cnn1x1(64, hidden_dim, bias=bias),
                nn.ReLU(),
            )
        else:
            self.cnn = nn.Sequential(
                cnn1x1(dim, 64, bias=bias),
                nn.ReLU(),
                cnn1x1(64, hidden_dim, bias=bias),
                nn.ReLU(),
            )

    def forward(self, x):
        x = self.cnn(x)
        return x

class cnn1x1(nn.Module):
    def __init__(self, dim1 = 3, dim2 =3, bias = True):
        super(cnn1x1, self).__init__()
        self.cnn = nn.Conv2d(dim1, dim2, kernel_size=1, bias=bias)

    def forward(self, x):
        x = self.cnn(x)
        return x

class local(nn.Module):
    def __init__(self, dim1 = 3, dim2 = 3, bias = False):
        super(local, self).__init__()
        self.maxpool = nn.AdaptiveMaxPool2d((1, None))
        self.cnn1 = nn.Conv2d(dim1, dim1, kernel_size=(1, 3), padding=(0, 1), bias=bias)
        self.bn1 = nn.BatchNorm2d(dim1)
        self.relu = nn.ReLU()
        self.cnn2 = nn.Conv2d(dim1, dim2, kernel_size=1, bias=bias)
        self.bn2 = nn.BatchNorm2d(dim2)
        self.dropout = nn.Dropout2d(0.2)

    def forward(self, x1):
        x1 = self.maxpool(x1)
        x = self.cnn1(x1)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.dropout(x)
        x = self.cnn2(x)
        x = self.bn2(x)
        x = self.relu(x)

        return x

class gcn_spa(nn.Module):
    def __init__(self, in_feature, out_feature, bias = False):
        super(gcn_spa, self).__init__()
        self.bn = nn.BatchNorm2d(out_feature)
        self.relu = nn.ReLU()
        self.w = cnn1x1(in_feature, out_feature, bias=False)
        self.w1 = cnn1x1(in_feature, out_feature, bias=bias)


    def forward(self, x1, g):
        x = x1.permute(0, 3, 2, 1).contiguous()
        x = g.matmul(x)
        x = x.permute(0, 3, 2, 1).contiguous()
        x = self.w(x) + self.w1(x1)
        x = self.relu(self.bn(x))
        return x

class compute_g_spa(nn.Module):
    def __init__(self, dim1 = 64 *3, dim2 = 64*3, bias = False):
        super(compute_g_spa, self).__init__()
        self.dim1 = dim1
        self.dim2 = dim2
        self.g1 = cnn1x1(self.dim1, self.dim2, bias=bias)
        self.g2 = cnn1x1(self.dim1, self.dim2, bias=bias)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x1):

        g1 = self.g1(x1).permute(0, 3, 2, 1).contiguous()
        g2 = self.g2(x1).permute(0, 3, 1, 2).contiguous()
        g3 = g1.matmul(g2)
        g = self.softmax(g3)
        return g
    

class SGNEncoder(nn.Module):
    def __init__(self, num_joint, seg, hidden_size=128, output_size=512, bs=32, is_3d=True, train=True, bias=True, device='cpu'):
        super(SGNEncoder, self).__init__()

        self.dim1 = hidden_size
        self.dim_unit = hidden_size // 4 
        self.seg = seg
        self.num_joint = num_joint
        self.bs = bs

        if is_3d:
          self.spatial_dim = 3
        else:
          self.spatial_dim = 2

        if train:
            self.spa = self.one_hot(bs, num_joint, self.seg)
            self.spa = self.spa.permute(0, 3, 2, 1).to(device)
            self.tem = self.one_hot(bs, self.seg, num_joint)
            self.tem = self.tem.permute(0, 3, 1, 2).to(device)
        else:
            self.spa = self.one_hot(32 * 5, num_joint, self.seg)
            self.spa = self.spa.permute(0, 3, 2, 1).to(device)
            self.tem = self.one_hot(32 * 5, self.seg, num_joint)
            self.tem = self.tem.permute(0, 3, 1, 2).to(device)

        self.tem_embed = embed(self.seg, joint=self.num_joint, hidden_dim=self.dim_unit*4, norm=False, bias=bias)
        self.spa_embed = embed(num_joint, joint=self.num_joint, hidden_dim=self.dim_unit, norm=False, bias=bias)
        self.joint_embed = embed(self.spatial_dim, joint=self.num_joint, hidden_dim=self.dim_unit, norm=True, bias=bias)
        self.dif_embed = embed(self.spatial_dim, joint=self.num_joint, hidden_dim=self.dim_unit, norm=True, bias=bias)
        self.maxpool = nn.AdaptiveMaxPool2d([1, 1])
        self.cnn = local(self.dim1, self.dim1 * 2, bias=bias)
        self.compute_g1 = compute_g_spa(self.dim1 // 2, self.dim1, bias=bias)
        self.gcn1 = gcn_spa(self.dim1 // 2, self.dim1 // 2, bias=bias)
        self.gcn2 = gcn_spa(self.dim1 // 2, self.dim1, bias=bias)
        self.gcn3 = gcn_spa(self.dim1, self.dim1, bias=bias)
        self.fc = nn.Linear(self.dim1 * 2, output_size)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))

        nn.init.constant_(self.gcn1.w.cnn.weight, 0)
        nn.init.constant_(self.gcn2.w.cnn.weight, 0)
        nn.init.constant_(self.gcn3.w.cnn.weight, 0)


    def forward(self, input):
        
        # Dynamic Representation
        input = input.view((self.bs, self.seg, self.num_joint, self.spatial_dim))
        input = input.permute(0, 3, 2, 1).contiguous().to()
        dif = input[:, :, :, 1:] - input[:, :, :, 0:-1]
        dif = torch.cat([dif.new(self.bs, dif.size(1), self.num_joint, 1).zero_(), dif], dim=-1)
        # print(input.shape)
        # print("input :", input.get_device(), " dif : ", dif.get_device())
        pos = self.joint_embed(input)
        # print("pos : ", pos.get_device(), " tem : ", self.tem.get_device(), " spa : ", self.spa.get_device())
        tem1 = self.tem_embed(self.tem)
        spa1 = self.spa_embed(self.spa)
        dif = self.dif_embed(dif)
        dy = pos + dif
        # Joint-level Module
        input= torch.cat([dy, spa1], 1)
        g = self.compute_g1(input)
        input = self.gcn1(input, g)
        input = self.gcn2(input, g)
        input = self.gcn3(input, g)
        # Frame-level Module
        input = input + tem1
        input = self.cnn(input)
        output = self.maxpool(input)
        output = torch.flatten(output, 1)
        output = self.fc(output)
        # output_feat = torch.squeeze(input)
        # output_feat = output_feat.permute(0,2,1).contiguous()
        # output_feat = input
        return output

    def one_hot(self, bs, spa, tem):

        y = torch.arange(spa).unsqueeze(-1)
        y_onehot = torch.FloatTensor(spa, spa)

        y_onehot.zero_()
        y_onehot.scatter_(1, y, 1)

        y_onehot = y_onehot.unsqueeze(0).unsqueeze(0)
        y_onehot = y_onehot.repeat(bs, tem, 1, 1)

        return y_onehot

        

In [4]:
sgn_config = {
    'num_joint': 12,
    'seg': 60,
    'hidden_size': 128,
    'output_size': 512,
    'train': True,
    'bs': 32,
    'is_3d': False
}

In [5]:
sgn_model = SGNEncoder(**sgn_config)
sgn_input = torch.randn((32, 60, 24),requires_grad=True)
sgn_output = sgn_model(sgn_input)
sgn_output.shape

torch.Size([32, 512])

In [6]:
import torch
from torch import nn
from torch.nn import TransformerEncoder, TransformerEncoderLayer


class IMUTransformerEncoder(nn.Module):

    def __init__(self, config):
        """
        config: (dict) configuration of the model
        """
        super().__init__()

        self.transformer_dim = config.get("transformer_dim")

        self.input_proj = nn.Sequential(nn.Conv1d(config.get("input_dim"), self.transformer_dim, 1), nn.GELU(),
                                        nn.Conv1d(self.transformer_dim, self.transformer_dim, 1), nn.GELU(),
                                        nn.Conv1d(self.transformer_dim, self.transformer_dim, 1), nn.GELU(),
                                        nn.Conv1d(self.transformer_dim, self.transformer_dim, 1), nn.GELU())

        self.window_size = config.get("window_size")
        self.encode_position = config.get("encode_position")
        encoder_layer = TransformerEncoderLayer(d_model = self.transformer_dim,
                                       nhead = config.get("nhead"),
                                       dim_feedforward = config.get("dim_feedforward"),
                                       dropout = config.get("transformer_dropout"),
                                       activation = config.get("transformer_activation"))

        self.transformer_encoder = TransformerEncoder(encoder_layer,
                                              num_layers = config.get("num_encoder_layers"),
                                              norm = nn.LayerNorm(self.transformer_dim))
        self.cls_token = nn.Parameter(torch.zeros((1, self.transformer_dim)), requires_grad=True)

        if self.encode_position:
            self.position_embed = nn.Parameter(torch.randn(self.window_size + 1, 1, self.transformer_dim))

        # num_classes =  config.get("num_classes")
        output_size = config.get("output_size")
        self.imu_head = nn.Sequential(
            nn.AvgPool2d((self.window_size,1)),
            nn.Linear(self.transformer_dim,  output_size),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(output_size, output_size)
        )
        self.sigmoid = nn.Sigmoid()

        # init
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, data):
        src = data  # Shape N x S x C with S = sequence length, N = batch size, C = channels

        # Embed in a high dimensional space and reshape to Transformer's expected shape
        src = self.input_proj(src.transpose(1, 2)).permute(2, 0, 1)

        # Prepend class token
        cls_token = self.cls_token.unsqueeze(1).repeat(1, src.shape[1], 1)
        src = torch.cat([cls_token, src])

        # Add the position embedding
        if self.encode_position:
            src += self.position_embed

        # Transformer Encoder pass
        target = self.transformer_encoder(src)

        # Class probability
        target = torch.squeeze(target.permute(1,0,2))
        # target = self.imu_head(target)
        target = torch.squeeze(self.imu_head(target))
        return target

def get_activation(activation):
    """Return an activation function given a string"""
    if activation == "relu":
        return nn.ReLU(inplace=True)
    if activation == "gelu":
        return nn.GELU()
    raise RuntimeError("Activation {} not supported".format(activation))

In [7]:
imu_config = {
	"input_dim": 6,
    "window_size":50,
	"encode_position":True,
	"transformer_dim": 256,
	"nhead": 8,
	"num_encoder_layers": 6, 
	"dim_feedforward": 128, 
	"transformer_dropout": 0.1, 
	"transformer_activation": "gelu",
	"head_activation": "gelu",
    "baseline_dropout": 0.1,
	"batch_size": 32,
	"output_size": 512
}

In [8]:
imu_model = IMUTransformerEncoder(imu_config)
imu_input = torch.randn((32, 50, 6))
imu_output = imu_model(imu_input)
imu_output.shape

torch.Size([32, 512])

In [9]:
class ContrastHead(nn.Module):
  def __init__(self, embedding_size):
        super(ContrastHead, self).__init__()
        self.imu_head = nn.Sequential(
            # nn.AvgPool2d((window_size,1)),
            # nn.Linear(embedding_size,  hidden_size),
            # nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(embedding_size, 1),
            nn.Sigmoid()
        )

  def forward(self, x):
        output = self.imu_head(x)
        return output

In [10]:
class SuperHead(nn.Module):
  def __init__(self, embedding_size, num_classes):
        super(SuperHead, self).__init__()
        self.imu_head = nn.Sequential(
            # nn.AvgPool2d((window_size,1)),
            # nn.Linear(embedding_size,  hidden_size),
            # nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(embedding_size, num_classes),
            nn.Softmax()
        )

  def forward(self, x):
        output = self.imu_head(x)
        return output

In [11]:
class BiLSTMDecoder(nn.Module):
    def __init__(self, seq_len, input_size, hidden_size, linear_filters, embedding_size:int, num_layers = 1, bidirectional=True, device='cpu'):
        super(BiLSTMDecoder, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.device = device
        self.num_layers = num_layers
        self.linear_filters = linear_filters[::-1]
        self.embedding_size = embedding_size
        self.bidirectional = bidirectional
        self.seq_len = seq_len

        if bidirectional:
            self.input_linear = nn.Linear(self.embedding_size,4*self.hidden_size)
        else:
            self.input_linear = nn.Linear(self.embedding_size,2*self.hidden_size)

        # define LSTM layer
        self.layers = []
        # add lstm
        self.lstm = nn.LSTM(input_size=self.hidden_size, hidden_size=self.hidden_size,
                            num_layers=self.num_layers, bidirectional=True,
                            batch_first=bidirectional)

                        
        # add linear layers 
        if bidirectional:
            self.layers.append(nn.Linear(2*hidden_size,self.linear_filters[0]))
        else:
            self.layers.append(nn.Linear(hidden_size,self.linear_filters[0]))

        for __id,layer_in in enumerate(self.linear_filters):
            if __id == len(linear_filters)-1:
                self.layers.append(nn.Linear(layer_in,self.input_size))
            else:
                self.layers.append(nn.Linear(layer_in,self.linear_filters[__id+1]))

        self.net = nn.Sequential(*self.layers)

        
        

    def forward(self,encoder_hidden):
        """
        : param x_input:               input of shape (seq_len, # in batch, input_size)
        : return lstm_out, hidden:     lstm_out gives all the hidden states in the sequence; hidden gives the hidden state and cell state for the last element in the sequence
        """
        
        
        hidden_shape = encoder_hidden.shape
        encoder_hidden = self.input_linear(encoder_hidden)
        
        if self.bidirectional:
            hidden = encoder_hidden.view((-1,4,self.hidden_size))
            hidden = torch.transpose(hidden,1,0)
            h1,h2,c1,c2 = torch.unbind(hidden,0)
            h,c = torch.stack((h1,h2,h1,h2)),torch.stack((c1,c2,c1,c2))
            bs = h.size()[1]
        else:
            hidden = encoder_hidden.view((-1,2,self.hidden_size))
            hidden = torch.transpose(hidden,1,0)
            h,c = torch.unbind(hidden,0)
            bs = h.size()[1]
        
        dummy_input = torch.rand((bs,self.seq_len,self.hidden_size), requires_grad=True).to(self.device)
        print(dummy_input.shape, h.shape, c.shape)
        lstm_out, self.hidden = self.lstm(dummy_input,(h,c))
        x = self.net(lstm_out)
        
        return x

In [40]:
class BiLSTMDecoder(nn.Module):
    def __init__(self,seq_len, input_size, hidden_size, linear_filters,embedding_size:int, num_layers = 1,bidirectional=True,device="cpu"):
        super(BiLSTMDecoder, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.device = device
        self.num_layers = num_layers
        self.linear_filters = linear_filters[::-1]
        self.embedding_size = embedding_size
        self.bidirectional = bidirectional
        self.seq_len = seq_len

        if bidirectional:
            self.input_linear = nn.Linear(self.embedding_size,4*self.hidden_size)
        else:
            self.input_linear = nn.Linear(self.embedding_size,2*self.hidden_size)

        # define LSTM layer
        self.layers = []
        # add lstm
        self.lstm = nn.LSTM(input_size = self.linear_filters[0], hidden_size = self.hidden_size,
                            num_layers = self.num_layers, bidirectional=True,
                            batch_first=bidirectional)

                        
        # add linear layers 
        if bidirectional:
            self.layers.append(nn.Linear(2*hidden_size,self.linear_filters[0]))
        else:
            self.layers.append(nn.Linear(hidden_size,self.linear_filters[0]))

        for __id,layer_in in enumerate(self.linear_filters):
            if __id == len(linear_filters)-1:
                self.layers.append(nn.Linear(layer_in,self.input_size))
            else:
                self.layers.append(nn.Linear(layer_in,self.linear_filters[__id+1]))

        self.net = nn.Sequential(*self.layers)

        
        

    def forward(self,encoder_hidden):
        """
        : param x_input:               input of shape (seq_len, # in batch, input_size)
        : return lstm_out, hidden:     lstm_out gives all the hidden states in the sequence; hidden gives the hidden state and cell state for the last element in the sequence
        """
        
        
        hidden_shape = encoder_hidden.shape
        encoder_hidden = self.input_linear(encoder_hidden)
        
        if self.bidirectional:
            hidden = encoder_hidden.view((-1,4,self.hidden_size))
            hidden = torch.transpose(hidden,1,0)
            h1,h2,c1,c2 = torch.unbind(hidden,0)
            h,c = torch.stack((h1,h2)),torch.stack((c1,c2))
            bs = h.size()[1]
        else:
            hidden = encoder_hidden.view((-1,2,self.hidden_size))
            hidden = torch.transpose(hidden,1,0)
            h,c = torch.unbind(hidden,0)
            bs = h.size()[1]
        
        dummy_input = torch.rand((bs,self.seq_len,self.hidden_size), requires_grad=True).to(self.device)
        
        lstm_out, self.hidden = self.lstm(dummy_input,(h,c))
        x = self.net(lstm_out)
        
        return x

In [41]:
dec_config = {
    'seq_len': 60,
    'input_size': 24,
    'hidden_size': 256,
    'linear_filters': [128, 256],
    'embedding_size': 512,
    'num_layers': 1,
    'bidirectional': True,
    'device': 'cpu'
}

dec_model = BiLSTMDecoder(**dec_config)

In [42]:
dec_input = torch.randn((32, 512))
dec_out = dec_model(dec_input)
dec_out.shape

torch.Size([32, 60, 24])

In [14]:
from src.models.bidirectional_cross_attention import BidirectionalCrossAttention


In [15]:
class BaseModel(nn.Module):
    def __init__(self, config):
        super(BaseModel, self).__init__()
        self.device = config['device']
        self.imu_model = IMUTransformerEncoder(config['imu_config'])
        self.skel_encoder = SGNEncoder(**config['sgn_config'])
        self.skel_decoder = BiLSTMDecoder(**config['dec_config'])
        self.contrast_head = ContrastHead(**config['contrast_config'])
        self.super_head = SuperHead(**config['super_config'])
        self.lxmert_xlayer = BidirectionalCrossAttention(**config['xmert_config'])

        self.imu_mask = torch.ones((config['bs'], config['imu_len']), requires_grad=True).bool().to(self.device)
        self.skel_mask = torch.ones((config['bs'], config['skel_len']), requires_grad=True).bool().to(self.device)
        

    def forward(self, x_imu, x_skel):
        imu_feats = self.imu_model(x_imu)
        skel_feats = self.skel_encoder(x_skel)
        imu_feats = torch.unsqueeze(imu_feats, dim=1)
        skel_feats = torch.unsqueeze(skel_feats, dim=1)
        print(f"imu_feats {imu_feats.shape} | skel_feats {skel_feats.shape}")
        imu_feats, skel_feats = self.lxmert_xlayer(imu_feats, skel_feats, mask=self.imu_mask, context_mask=self.skel_mask)
        # print(f"imu_feats {imu_feats.shape} | skel_feats {skel_feats.shape}")
        # skel_recon = self.skel_decoder(skel_feats)
        # imu_feats = torch.squeeze(imu_feats)
        # bin_output = torch.squeeze(self.fc_head(imu_feats))
        imu_feats = torch.squeeze(imu_feats)
        skel_feats = torch.squeeze(skel_feats)
        print(f"imu_feats {imu_feats.shape} | skel_feats {skel_feats.shape}")
        contrast_out = self.contrast_head(imu_feats)
        super_out = self.super_head(imu_feats)

        skel_recon = self.skel_decoder(imu_feats)
        return contrast_out, super_out, skel_recon

In [22]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [24]:
imu_config = {
	"input_dim": 54,
    "window_size":24,
	"encode_position":True,
	"transformer_dim": 512,
	"nhead": 8,
	"num_encoder_layers": 6, 
	"dim_feedforward": 128, 
	"transformer_dropout": 0.1, 
	"transformer_activation": "gelu",
	"head_activation": "gelu",
    "baseline_dropout": 0.1,
	"batch_size": 32,
	"output_size": 512
}

sgn_config = {
    'num_joint': 12,
    'seg': 60,
    'hidden_size': 128,
    'output_size': 512,
    'train': True,
    'bs': 32,
    'is_3d': False
}

# dec_config = {
#     'seq_len': 60,
#     'input_size': 24,
#     'hidden_size': 256,
#     'num_layers': 2,
#     'bidirectional': True,
#     "embedding_size": 128,
#     "linear_filters":[128,256,512,1024],
# }

dec_config = {
    'seq_len': 60,
    'input_size': 24,
    'hidden_size': 512,
    'linear_filters': [128, 256, 512, 1024],
    'embedding_size': 512,
    'num_layers': 2,
    # 'output_size': 24,
    'bidirectional': True
}

# seq_len, input_size, hidden_size, linear_filters, embedding_size:int, num_layers = 1, bidirectional=True, device='cpu'

contrast_config = {
    'embedding_size': 512
}

super_config = {
    'embedding_size': 512,
    'num_classes': 5
}

base_config = {
    'imu_config': imu_config,
    'sgn_config': sgn_config,
    'dec_config': dec_config,
    'contrast_config': contrast_config,
    'super_config': super_config,
    'device': device,
    'bs': 32,
    'imu_len': 1,
    'skel_len': 60,
    'ft_size': 512,
    'xmert_config': {
            'dim': 512,
            'heads': 8,
            'dim_head': 64,
            'context_dim': 512,   
    }
}
 
base_model = BaseModel(base_config).to(device)

In [25]:
imu_input = torch.randn([32, 24, 54]).to(device)
skel_input = torch.randn([32, 60, 24]).to(device)
con_out, sup_out, skel_out = base_model(imu_input, skel_input)

RuntimeError: Input type (torch.FloatTensor) and weight type (torch.cuda.FloatTensor) should be the same or input should be a MKLDNN tensor and weight is a dense tensor

In [18]:
con_out.shape

torch.Size([32, 1])

In [19]:
sup_out.shape

torch.Size([32, 5])

In [20]:
skel_out.shape

torch.Size([32, 60, 24])