In [1]:
import sys
sys.path.append('../')

from src.models.modeling_sgn import embed, local, gcn_spa, compute_g_spa
from src.models.modeling_lxmert import LxmertConfig, LxmertXLayer

import torch 
from torch import nn, Tensor
from torch.nn import TransformerEncoder, TransformerEncoderLayer

import math

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Debugging aid: turn on autograd anomaly detection so a backward pass that produces
# NaN values raises and points at the forward op responsible.
# NOTE(review): this slows training noticeably — switch it off once the model is stable.
torch.autograd.set_detect_anomaly(True)

<torch.autograd.anomaly_mode.set_detect_anomaly at 0x216eb71a8e0>

In [14]:
class SGNEncoder(nn.Module):
    """Skeleton encoder assembled from the SGN building blocks imported at the top of the notebook.

    The flat skeleton input is reshaped to ``(batch, seg, num_joint, spatial_dim)``,
    joint positions and frame-to-frame differences are embedded and summed, a
    one-hot joint-index embedding is concatenated, three ``gcn_spa`` layers are
    applied, a one-hot frame-index embedding is added, and the ``local`` block
    produces the final per-frame features.

    Args:
        num_joint: number of skeleton joints per frame.
        seg: number of frames per sequence.
        hidden_size: base channel width; the ``local`` block is built with
            ``2 * hidden_size`` output channels.
        bs: batch size used to pre-build the one-hot index tensors when
            ``train`` is true. ``forward`` accepts any batch size.
        is_3d: ``True`` for (x, y, z) joints, ``False`` for (x, y).
        train: when false the one-hot tensors are pre-built for ``32 * 5``
            samples instead of ``bs`` (kept from the original code).
        bias: forwarded to every SGN sub-module.
        device: device on which the one-hot tensors are created.
    """

    def __init__(self, num_joint, seg, hidden_size=128, bs=32, is_3d=True, train=True, bias=True, device='cpu'):
        super(SGNEncoder, self).__init__()

        self.dim1 = hidden_size
        self.dim_unit = hidden_size // 4
        self.seg = seg
        self.num_joint = num_joint
        self.bs = bs
        self.spatial_dim = 3 if is_3d else 2

        # One-hot index tensors. They are registered as non-persistent buffers so
        # that ``model.to(device)`` / DataParallel move them together with the
        # weights while ``state_dict()`` keeps exactly the keys it had before.
        onehot_bs = bs if train else 32 * 5
        # (onehot_bs, seg, num_joint, num_joint) -> (onehot_bs, num_joint, num_joint, seg)
        spa = self.one_hot(onehot_bs, num_joint, self.seg).permute(0, 3, 2, 1).to(device)
        # (onehot_bs, num_joint, seg, seg) -> (onehot_bs, seg, num_joint, seg)
        tem = self.one_hot(onehot_bs, self.seg, num_joint).permute(0, 3, 1, 2).to(device)
        self.register_buffer('spa', spa, persistent=False)
        self.register_buffer('tem', tem, persistent=False)

        self.tem_embed = embed(self.seg, joint=self.num_joint, hidden_dim=self.dim_unit*4, norm=False, bias=bias)
        self.spa_embed = embed(num_joint, joint=self.num_joint, hidden_dim=self.dim_unit, norm=False, bias=bias)
        self.joint_embed = embed(self.spatial_dim, joint=self.num_joint, hidden_dim=self.dim_unit, norm=True, bias=bias)
        self.dif_embed = embed(self.spatial_dim, joint=self.num_joint, hidden_dim=self.dim_unit, norm=True, bias=bias)
        self.maxpool = nn.AdaptiveMaxPool2d([1, 1])
        self.cnn = local(self.dim1, self.dim1 * 2, bias=bias)
        self.compute_g1 = compute_g_spa(self.dim1 // 2, self.dim1, bias=bias)
        self.gcn1 = gcn_spa(self.dim1 // 2, self.dim1 // 2, bias=bias)
        self.gcn2 = gcn_spa(self.dim1 // 2, self.dim1, bias=bias)
        self.gcn3 = gcn_spa(self.dim1, self.dim1, bias=bias)

        # Normal init for every Conv2d, with std derived from kernel area * out_channels.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                nn.init.normal_(m.weight, mean=0, std=math.sqrt(2. / n))

        # These graph-convolution weights start at exactly zero.
        nn.init.constant_(self.gcn1.w.cnn.weight, 0)
        nn.init.constant_(self.gcn2.w.cnn.weight, 0)
        nn.init.constant_(self.gcn3.w.cnn.weight, 0)

    @staticmethod
    def _match_batch(onehot, bs):
        """Return ``onehot`` with ``bs`` entries along dim 0.

        Every batch slice of the stored tensor is identical (it was built with
        ``repeat``), so the first slice can be expanded to any batch size.
        """
        if onehot.size(0) == bs:
            return onehot
        return onehot[:1].expand(bs, -1, -1, -1)

    def forward(self, x):
        """Encode a batch of skeleton sequences.

        Args:
            x: tensor that can be reshaped to
                ``(batch, seg, num_joint, spatial_dim)``; the batch size is
                inferred from the input instead of being fixed to ``self.bs``.

        Returns:
            3-D tensor with the batch axis first. In this notebook the result is
            ``(batch, seg, 2 * hidden_size)``.
        """
        # Dynamic Representation
        x = x.reshape(-1, self.seg, self.num_joint, self.spatial_dim)
        bs = x.size(0)
        x = x.permute(0, 3, 2, 1).contiguous()  # (batch, spatial_dim, num_joint, seg)
        # Frame-to-frame differences, zero-padded at the first frame to keep `seg` steps.
        dif = x[:, :, :, 1:] - x[:, :, :, :-1]
        dif = torch.cat([dif.new_zeros(bs, dif.size(1), self.num_joint, 1), dif], dim=-1)
        pos = self.joint_embed(x)
        tem1 = self.tem_embed(self._match_batch(self.tem, bs))
        spa1 = self.spa_embed(self._match_batch(self.spa, bs))
        dif = self.dif_embed(dif)
        dy = torch.add(pos, dif)
        # Joint-level Module
        x = torch.cat([dy, spa1], 1)
        g = self.compute_g1(x)
        x = self.gcn1(x, g)
        x = self.gcn2(x, g)
        x = self.gcn3(x, g)
        # Frame-level Module
        x = self.cnn(torch.add(x, tem1))
        # Merge the two trailing axes instead of calling torch.squeeze with no dim:
        # a bare squeeze would also drop the batch axis when the batch holds one sample.
        output_feat = x.flatten(2).permute(0, 2, 1)

        return output_feat

    def one_hot(self, bs, spa, tem):
        """Return a ``(bs, tem, spa, spa)`` float32 tensor of stacked identity matrices."""
        y_onehot = torch.eye(spa, dtype=torch.float32)
        return y_onehot.unsqueeze(0).unsqueeze(0).repeat(bs, tem, 1, 1)

    
class BiLSTMDecoder(nn.Module):
    """LSTM decoder followed by a stack of linear layers.

    The incoming feature map is adaptively max-pooled to ``(seq_len, seq_len)``
    over its last two axes, run through a batch-first LSTM and projected by the
    linear stack down to ``input_size`` features per step.

    Args:
        seq_len: target size for both pooled axes; also the LSTM input width.
        input_size: number of features produced per step by the final linear layer.
        hidden_size: LSTM hidden width (per direction).
        linear_filters: widths of the intermediate linear layers; they are
            applied in reverse order. May be empty, in which case the LSTM
            output is projected straight to ``input_size``.
        embedding_size: stored on the instance; not used by this class.
        num_layers: number of stacked LSTM layers.
        bidirectional: whether the LSTM is bidirectional.
        device: stored on the instance; not used by this class.
    """

    def __init__(self,seq_len, input_size, hidden_size, linear_filters,embedding_size:int, num_layers = 1,bidirectional=True, device='cpu'):
        super(BiLSTMDecoder, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.device = device
        self.num_layers = num_layers
        self.linear_filters = linear_filters[::-1]
        self.embedding_size = embedding_size
        self.bidirectional = bidirectional
        self.seq_len = seq_len

        # The LSTM honours the `bidirectional` argument. It used to be hard-coded to
        # True while the first linear layer was sized from the argument, so
        # bidirectional=False failed with a shape mismatch.
        self.lstm = nn.LSTM(input_size=self.seq_len, hidden_size=self.hidden_size,
                            num_layers=self.num_layers, bidirectional=self.bidirectional,
                            batch_first=True)

        self.maxpool = nn.AdaptiveMaxPool2d([self.seq_len, self.seq_len])

        # Linear stack: lstm_out -> reversed(linear_filters) -> input_size.
        # NOTE(review): there is no non-linearity between these layers, so the stack
        # is mathematically a single affine map — confirm that is intended.
        lstm_out_dim = 2 * hidden_size if bidirectional else hidden_size
        widths = [lstm_out_dim, *self.linear_filters, self.input_size]
        self.layers = [nn.Linear(n_in, n_out) for n_in, n_out in zip(widths, widths[1:])]

        self.net = nn.Sequential(*self.layers)

    def forward(self,encoder_hidden):
        """Decode encoder features into a sequence.

        Args:
            encoder_hidden: 3-D feature tensor; in this notebook it is
                ``(batch, seg, hidden)``. Its last two axes are pooled to
                ``(seq_len, seq_len)`` before entering the batch-first LSTM.

        Returns:
            Tensor of shape ``(batch, seq_len, input_size)``. The LSTM's final
            ``(h_n, c_n)`` pair is kept on ``self.hidden``.
        """
        pooled = self.maxpool(encoder_hidden)
        lstm_out, self.hidden = self.lstm(pooled)
        return self.net(lstm_out)

In [15]:
# class BiLSTMDecoder(nn.Module):
#     def __init__(self, input_size, hidden_size, output_size, num_layers = 1, bidirectional=True, batch_size=32, device='cpu'):
#         super(BiLSTMDecoder, self).__init__()
#         self.input_size = input_size
#         self.hidden_size = hidden_size
#         self.num_layers = num_layers
#         self.bidirectional = bidirectional
#         self.batch_size = batch_size
#         self.device = device

#         # add lstm layer
#         self.lstm = nn.LSTM(input_size = input_size, hidden_size = self.hidden_size,
#                             num_layers = self.num_layers, bidirectional=self.bidirectional,
#                             batch_first=True, proj_size = output_size)

#     def forward(self, x):
#         '''
#         : param x_input:               input of shape (seq_len, # in batch, input_size)
#         : return lstm_out, hidden:     lstm_out gives all the hidden states in the sequence; hidden gives the hidden state and cell state for the last element in the sequence                         
#         '''
        
#         # x = self.net(x_input)
#         out, _ = self.lstm(x)
 
#         return out

In [26]:
class IMUTransformerEncoder(nn.Module):
    """Transformer encoder for IMU windows with a prepended learnable class token.

    The input channels are projected to ``transformer_dim`` with four 1x1
    ``Conv1d`` + GELU layers, a class token is put in front of the sequence, an
    optional learned position embedding is added, and the result is run through
    a ``TransformerEncoder``.
    """

    def __init__(self, config):
        """
        config: (dict) configuration of the model. Keys read here:
            "input_dim", "transformer_dim", "window_size", "encode_position",
            "nhead", "dim_feedforward", "transformer_dropout",
            "transformer_activation", "num_encoder_layers".
        """
        super().__init__()

        self.transformer_dim = config.get("transformer_dim")

        self.input_proj = nn.Sequential(nn.Conv1d(config.get("input_dim"), self.transformer_dim, 1), nn.GELU(),
                                        nn.Conv1d(self.transformer_dim, self.transformer_dim, 1), nn.GELU(),
                                        nn.Conv1d(self.transformer_dim, self.transformer_dim, 1), nn.GELU(),
                                        nn.Conv1d(self.transformer_dim, self.transformer_dim, 1), nn.GELU())

        self.window_size = config.get("window_size")
        self.encode_position = config.get("encode_position")
        encoder_layer = TransformerEncoderLayer(d_model = self.transformer_dim,
                                       nhead = config.get("nhead"),
                                       dim_feedforward = config.get("dim_feedforward"),
                                       dropout = config.get("transformer_dropout"),
                                       activation = config.get("transformer_activation"))

        self.transformer_encoder = TransformerEncoder(encoder_layer,
                                              num_layers = config.get("num_encoder_layers"),
                                              norm = nn.LayerNorm(self.transformer_dim))
        self.cls_token = nn.Parameter(torch.zeros((1, self.transformer_dim)), requires_grad=True)

        if self.encode_position:
            # One embedding per time step plus one for the class token.
            self.position_embed = nn.Parameter(torch.randn(self.window_size + 1, 1, self.transformer_dim))

        # Xavier-initialise every parameter with more than one dimension. This also
        # overwrites the zeros / randn values given to cls_token and position_embed above.
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, data):
        """Encode a batch of IMU windows.

        Args:
            data: tensor of shape ``(N, S, C)`` with N = batch size,
                S = sequence length, C = channels.

        Returns:
            Tensor of shape ``(N, S + 1, transformer_dim)``; index 0 along the
            sequence axis is the class token. The batch axis is always kept,
            including for N == 1.
        """
        # (N, S, C) -> (N, C, S) for Conv1d -> (S, N, D) for the sequence-first transformer
        src = self.input_proj(data.transpose(1, 2)).permute(2, 0, 1)

        # Prepend the class token to every sequence in the batch: (1, N, D)
        cls_token = self.cls_token.unsqueeze(1).repeat(1, src.shape[1], 1)
        src = torch.cat([cls_token, src])

        if self.encode_position:
            src = src + self.position_embed

        target = self.transformer_encoder(src)
        # Back to batch-first. No torch.squeeze here: it silently removed the batch
        # axis for single-sample batches and broke every downstream 3-D consumer.
        return target.permute(1, 0, 2)

def get_activation(activation):
    """Build the activation module named by ``activation`` ("relu" or "gelu").

    Raises:
        RuntimeError: if the name is not one of the supported activations.
    """
    supported = (
        ("relu", lambda: nn.ReLU(inplace=True)),
        ("gelu", lambda: nn.GELU()),
    )
    for label, build in supported:
        if activation == label:
            return build()
    raise RuntimeError("Activation {} not supported".format(activation))

In [44]:
class ClassifierHead(nn.Module):
    """Binary classification head: average-pool over the window, then a small MLP with sigmoid.

    Args:
        window_size: height of the average-pooling kernel, i.e. the number of
            sequence positions averaged together.
        embedding_size: feature width of the incoming tokens.
        hidden_size: width of the hidden linear layer.
    """

    def __init__(self, window_size, embedding_size, hidden_size):
        super().__init__()
        head_layers = [
            nn.AvgPool2d((window_size, 1)),
            nn.Linear(embedding_size, hidden_size),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size, 1),
            nn.Sigmoid(),
        ]
        self.imu_head = nn.Sequential(*head_layers)

    def forward(self, x):
        """Run ``x`` through the head and return the sigmoid output."""
        return self.imu_head(x)

In [45]:
class BaseModel(nn.Module):
    """Two-stream model: IMU transformer + skeleton encoder fused by an LXMERT cross layer.

    ``config`` is a dict holding one sub-config per component
    ('imu_config', 'sgn_config', 'dec_config', 'clf_config', 'xmert_config')
    and 'num_layers', the number of times the single cross layer is applied.
    """

    def __init__(self, config):
        super().__init__()

        # Sub-modules are created in this order so parameter listing order is stable.
        self.imu_model = IMUTransformerEncoder(config['imu_config'])
        self.skel_encoder = SGNEncoder(**config['sgn_config'])
        self.skel_decoder = BiLSTMDecoder(**config['dec_config'])
        self.fc_head = ClassifierHead(**config['clf_config'])
        self.lxmert_config = LxmertConfig(**config['xmert_config'])
        self.lxmert_xlayer = LxmertXLayer(self.lxmert_config)

        self.num_layers = config['num_layers']

    def forward(self, x_imu, x_skel):
        """Return ``(bin_output, skel_recon)`` for a batch of IMU and skeleton inputs."""
        imu_feats = self.imu_model(x_imu)
        skel_feats = self.skel_encoder(x_skel)
        print(f"imu_feats {imu_feats.shape} | skel_feats {skel_feats.shape}")

        # The same cross-modality layer (shared weights) is applied num_layers times;
        # the skeleton stream goes in the "lang" slot and the IMU stream in the "visual" slot.
        for _ in range(self.num_layers):
            fused = self.lxmert_xlayer(
                lang_feats=skel_feats,
                lang_attention_mask=None,
                visual_feats=imu_feats,
                visual_attention_mask=None,
                input_id=None,
                output_attentions=False,
            )
            skel_feats, imu_feats = fused[:2]

        print(f"imu_feats {imu_feats.shape} | skel_feats {skel_feats.shape}")
        skel_recon = self.skel_decoder(skel_feats)
        bin_output = self.fc_head(imu_feats)
        return bin_output, skel_recon

In [46]:
# --- IMU transformer encoder -------------------------------------------------
imu_config = {
    "input_dim": 54,            # channels per IMU sample
    "window_size": 24,          # time steps per window
    "encode_position": True,
    "transformer_dim": 512,
    "nhead": 8,
    "num_encoder_layers": 6,
    "dim_feedforward": 128,
    "transformer_dropout": 0.1,
    "transformer_activation": "gelu",
    "head_activation": "gelu",
    "baseline_dropout": 0.1,
    "batch_size": 32,
    "output_size": 512,
}

# --- Skeleton encoder --------------------------------------------------------
sgn_config = {
    "num_joint": 12,
    "seg": 60,                  # frames per skeleton sequence
    "hidden_size": 256,
    "train": True,
    "bs": 32,
    "is_3d": False,             # 2-D joints
}

# --- Skeleton decoder --------------------------------------------------------
dec_config = {
    "seq_len": 60,              # same as sgn_config["seg"]
    "input_size": 24,           # num_joint * 2 coordinates
    "hidden_size": 256,
    "num_layers": 2,
    "bidirectional": True,
    "embedding_size": 512,
    "linear_filters": [128, 256, 512, 1024],
}

# --- Classification head -----------------------------------------------------
clf_config = {
    "window_size": 25,          # IMU window (24) + the prepended class token
    "embedding_size": 512,
    "hidden_size": 256,
}

# --- Full model --------------------------------------------------------------
base_config = {
    "imu_config": imu_config,
    "sgn_config": sgn_config,
    "dec_config": dec_config,
    "clf_config": clf_config,
    "num_layers": 1,
    "xmert_config": {
        "vocab_size": 1024,
        "hidden_size": 512,
        "num_attention_heads": 2,
        "intermediate_size": 512,
    },
}
 
# Build the full two-stream model from the nested configuration defined above.
base_model = BaseModel(base_config)

In [47]:
# Random smoke-test batch sized to match the configs:
#   imu_input  -> (batch=32, window_size=24, input_dim=54)
#   skel_input -> (batch=32, seg=60, num_joint * 2 coords = 24)
#   y          -> binary labels as floats, the dtype BCELoss needs for its target
imu_input = torch.randn((32, 24, 54))
skel_input = torch.randn((32, 60, 24))
y = torch.randint(low=0, high=2, size=(32,)).float()

In [48]:
# One forward pass; BaseModel.forward prints the feature shapes before and after the
# cross-modality layer. Returns (classification output, skeleton reconstruction).
imu_output, skel_output = base_model(imu_input, skel_input)

imu_feats torch.Size([32, 25, 512]) | skel_feats torch.Size([32, 60, 512])
imu_feats torch.Size([32, 25, 512]) | skel_feats torch.Size([32, 60, 512])


In [49]:
# The reconstruction has to match skel_input's shape for the L1 reconstruction loss.
skel_output.shape

torch.Size([32, 60, 24])

In [52]:
# The classifier head keeps two trailing singleton axes; they are removed before the loss.
imu_output.shape

torch.Size([32, 1, 1])

In [50]:
# Classification loss: BCELoss takes probabilities, which ClassifierHead's final Sigmoid provides.
clossfunc = nn.BCELoss()
# Reconstruction loss between the decoded skeleton and the input skeleton.
rlossfunc = nn.L1Loss()

In [53]:
# Flatten the (batch, 1, 1) head output to (batch,) so it lines up with y. Unlike
# torch.squeeze, reshape(-1) keeps the batch axis when the batch holds one sample.
closs = clossfunc(imu_output.reshape(-1), y)
# L1Loss takes (prediction, target): the reconstruction comes first.
rloss = rlossfunc(skel_output, skel_input)
# Joint objective: classification + reconstruction, equally weighted.
loss = closs + rloss

In [54]:
# Back-propagate the joint loss through both streams. Anomaly detection was switched on
# earlier in the notebook, so a NaN produced during the backward pass would raise here.
loss.backward()

In [55]:
# Sanity check: every parameter that received a gradient must hold only finite values.
for name, param in base_model.named_parameters():
    if param.grad is None:
        # A parameter that did not take part in the loss has no gradient; report it
        # instead of letting torch.isfinite(None) raise.
        print(name, "grad is None")
        continue
    print(name, torch.isfinite(param.grad).all())

imu_model.cls_token tensor(True)
imu_model.position_embed tensor(True)
imu_model.input_proj.0.weight tensor(True)
imu_model.input_proj.0.bias tensor(True)
imu_model.input_proj.2.weight tensor(True)
imu_model.input_proj.2.bias tensor(True)
imu_model.input_proj.4.weight tensor(True)
imu_model.input_proj.4.bias tensor(True)
imu_model.input_proj.6.weight tensor(True)
imu_model.input_proj.6.bias tensor(True)
imu_model.transformer_encoder.layers.0.self_attn.in_proj_weight tensor(True)
imu_model.transformer_encoder.layers.0.self_attn.in_proj_bias tensor(True)
imu_model.transformer_encoder.layers.0.self_attn.out_proj.weight tensor(True)
imu_model.transformer_encoder.layers.0.self_attn.out_proj.bias tensor(True)
imu_model.transformer_encoder.layers.0.linear1.weight tensor(True)
imu_model.transformer_encoder.layers.0.linear1.bias tensor(True)
imu_model.transformer_encoder.layers.0.linear2.weight tensor(True)
imu_model.transformer_encoder.layers.0.linear2.bias tensor(True)
imu_model.transformer_e

In [None]:
# List every registered parameter name, in module registration order.
for name, param in base_model.named_parameters():
    print(name)

imu_model.cls_token
imu_model.position_embed
imu_model.input_proj.0.weight
imu_model.input_proj.0.bias
imu_model.input_proj.2.weight
imu_model.input_proj.2.bias
imu_model.input_proj.4.weight
imu_model.input_proj.4.bias
imu_model.input_proj.6.weight
imu_model.input_proj.6.bias
imu_model.transformer_encoder.layers.0.self_attn.in_proj_weight
imu_model.transformer_encoder.layers.0.self_attn.in_proj_bias
imu_model.transformer_encoder.layers.0.self_attn.out_proj.weight
imu_model.transformer_encoder.layers.0.self_attn.out_proj.bias
imu_model.transformer_encoder.layers.0.linear1.weight
imu_model.transformer_encoder.layers.0.linear1.bias
imu_model.transformer_encoder.layers.0.linear2.weight
imu_model.transformer_encoder.layers.0.linear2.bias
imu_model.transformer_encoder.layers.0.norm1.weight
imu_model.transformer_encoder.layers.0.norm1.bias
imu_model.transformer_encoder.layers.0.norm2.weight
imu_model.transformer_encoder.layers.0.norm2.bias
imu_model.transformer_encoder.layers.1.self_attn.in_p