In [1]:
import numpy as np 
import math

import torch 
from torch import nn, Tensor
from torch.nn import functional as F 

In [2]:
# Configuration: seed first so every random tensor / weight init below is reproducible
# under Restart Kernel -> Run All.
SEED = 0
torch.manual_seed(SEED)

bs = 32        # batch size
ft_dims = 6    # input features per time step
seq_len = 120  # time steps per window
sample_input = torch.randn((bs, seq_len, ft_dims))
sample_input.shape

torch.Size([32, 120, 6])

In [3]:
# Linear projection: lift each time step's ft_dims input features to the model width d_model.
d_model = 128
denseL = nn.Linear(in_features=ft_dims, out_features=d_model)

sample_out = denseL(sample_input)  # (bs, seq_len, d_model)
sample_out.shape

torch.Size([32, 120, 128])

In [4]:
# Baseline: PyTorch's stock Transformer encoder layer on the projected batch.
# batch_first=True is required because sample_out is (batch, seq_len, d_model);
# with the default batch_first=False the layer would read axis 0 (the 32-sample
# batch) as the sequence and attend across samples instead of across time steps.
n_heads = 8
inter_ft = 128  # feed-forward (dim_feedforward) width
attnL = nn.TransformerEncoderLayer(d_model, n_heads, dim_feedforward=inter_ft, batch_first=True)

# New name (instead of overwriting sample_out) keeps this cell idempotent on re-run.
encoder_out = attnL(sample_out)
encoder_out.shape

torch.Size([32, 120, 128])

In [5]:
# Display the encoder layer's sub-modules (self-attention, feed-forward linears, norms, dropouts).
attnL

TransformerEncoderLayer(
  (self_attn): MultiheadAttention(
    (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
  )
  (linear1): Linear(in_features=128, out_features=128, bias=True)
  (dropout): Dropout(p=0.1, inplace=False)
  (linear2): Linear(in_features=128, out_features=128, bias=True)
  (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
  (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
  (dropout1): Dropout(p=0.1, inplace=False)
  (dropout2): Dropout(p=0.1, inplace=False)
)

In [6]:
# relative global attention based transformer encoder implementation
class RelativeGlobalAttention(nn.Module):
    """Causal multi-head self-attention with learned relative-position logits.

    Scores are ``(Q K^T + S_rel) / sqrt(d_head)``, where ``S_rel`` is derived
    from ``Q Er^T`` by the pad-reshape-slice trick in :meth:`skew`. A
    lower-triangular mask stops every position from attending to later ones.

    Args:
        d_model: model width; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
        max_len: longest accepted sequence; sizes ``Er`` and the causal mask.
        dropout: dropout probability applied to the layer output.

    Raises:
        ValueError: if ``d_model`` is not a multiple of ``num_heads``.
    """

    def __init__(self, d_model, num_heads, max_len=1024, dropout=0.1):
        super().__init__()
        if d_model % num_heads != 0:
            raise ValueError(
                "incompatible `d_model` and `num_heads`"
            )
        head_dim = d_model // num_heads

        self.max_len = max_len
        self.d_model = d_model
        self.num_heads = num_heads

        # Q/K/V projections, created in this order so seeded weight init is unchanged.
        self.key = nn.Linear(d_model, d_model)
        self.value = nn.Linear(d_model, d_model)
        self.query = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

        # Relative-position table shared by all heads; after skewing, row
        # (max_len - 1 - d) scores a query against the key d steps behind it.
        self.Er = nn.Parameter(torch.randn(max_len, head_dim))

        # Causal mask of shape (1, 1, max_len, max_len), broadcast over batch and heads.
        lower_tri = torch.tril(torch.ones(max_len, max_len))
        self.register_buffer("mask", lower_tri[None, None])

    def forward(self, x):
        """Attend over ``x`` of shape (batch, seq_len, d_model); return the same shape.

        Raises:
            ValueError: if ``seq_len`` exceeds ``max_len``.
        """
        n_batch, n_tokens, _ = x.shape
        if n_tokens > self.max_len:
            raise ValueError(
                "sequence length exceeds model capacity"
            )

        def to_heads(t):
            # (batch, seq_len, d_model) -> (batch, num_heads, seq_len, d_head)
            return t.reshape(n_batch, n_tokens, self.num_heads, -1).transpose(1, 2)

        keys_t = to_heads(self.key(x)).transpose(-2, -1)  # (batch, heads, d_head, seq_len)
        values = to_heads(self.value(x))
        queries = to_heads(self.query(x))

        # Relative term: trailing n_tokens rows of Er, skewed into (query, key) alignment.
        rel_table = self.Er[self.max_len - n_tokens:, :].t()  # (d_head, seq_len)
        rel_scores = self.skew(queries @ rel_table)            # (batch, heads, seq_len, seq_len)

        content_scores = queries @ keys_t                       # (batch, heads, seq_len, seq_len)
        scores = (content_scores + rel_scores) / math.sqrt(queries.size(-1))

        visible = self.mask[:, :, :n_tokens, :n_tokens]
        scores = scores.masked_fill(visible == 0, float("-inf"))
        probs = F.softmax(scores, dim=-1)

        mixed = (probs @ values).transpose(1, 2)  # (batch, seq_len, heads, d_head)
        return self.dropout(mixed.reshape(n_batch, n_tokens, -1))

    def skew(self, QEr):
        """Realign absolute-column logits ``QEr`` (b, h, L, L) into relative form.

        Left-pad one zero column, reinterpret the (L, L+1) matrix as (L+1, L) and
        drop the first row. Afterwards entry (i, j) with j <= i equals
        ``QEr[i, L - 1 - (i - j)]``; entries above the diagonal are filler that
        the causal mask in ``forward`` hides.
        """
        padded = F.pad(QEr, (1, 0))
        b, h, rows, cols = padded.shape
        return padded.reshape(b, h, cols, rows)[..., 1:, :]

In [7]:
# Smoke test: project the raw batch, then run it through relative global attention.
relAttnL = RelativeGlobalAttention(d_model=d_model, num_heads=n_heads, max_len=seq_len)

sample_input1 = denseL(sample_input)  # (bs, seq_len, d_model)
sample_out = relAttnL(sample_input1)
sample_out.shape

torch.Size([32, 120, 128])

In [8]:
# build E2E ZSL HAR Model

class ZSLHARNet(nn.Module):
    """End-to-end zero-shot HAR network: projection -> relative attention -> LSTM -> tied SAE.

    Data flow:
      x (batch, seq_len, in_ft)
        -> DenseL : linear projection to (batch, seq_len, d_model)
        -> AttnL  : causal relative global attention, same shape
        -> lstmL  : LSTM; the last time step is taken as a (batch, ft_size) feature
        -> encode : attr_out = feature @ TransMet.T  -> (batch, attr_size)
        -> decode : ft_out   = attr_out @ TransMet   -> (batch, ft_size)

    Encoder and decoder share the single weight matrix ``TransMet`` (the decoder
    uses its transpose).

    Args:
        in_ft: number of input features per time step.
        d_model: attention width; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
        ft_size: LSTM hidden size, i.e. the feature size the decoder reconstructs.
        attr_size: size of the intermediate attribute space.
        max_len: longest sequence the attention layer accepts.
        dropout: dropout probability passed to the attention layer.
    """

    def __init__(self, in_ft, d_model, num_heads, ft_size, attr_size, max_len=1024, dropout=0.1):
        super().__init__()
        self.in_ft = in_ft
        self.max_len = max_len
        self.d_model = d_model
        self.num_heads = num_heads
        self.ft_size = ft_size      # semantic space size <-> output feature space size
        self.attr_size = attr_size  # intermediate attribute space size

        # Dense layer for feature projection
        self.DenseL = nn.Linear(in_ft, d_model)
        # attention encoder: relative global attention; forward the user's dropout
        # (previously accepted but silently ignored)
        self.AttnL = RelativeGlobalAttention(d_model, num_heads, max_len, dropout=dropout)
        # recurrent summariser over the attended sequence
        self.lstmL = nn.LSTM(input_size=d_model, hidden_size=ft_size, batch_first=True)
        # SAE shared weight (attr_size, ft_size). It is applied with F.linear in
        # forward() rather than patched into nn.Linear submodules at call time:
        # that patching mutated the module tree, so state_dict() gained an extra
        # "EncDenseL.weight" key after the first forward pass and checkpoints no
        # longer loaded strictly into a freshly built model.
        self.TransMet = nn.Parameter(torch.randn(attr_size, ft_size))

    def forward(self, x):
        """Map a batch of sequences to attribute embeddings and their reconstruction.

        Args:
            x: tensor of shape (batch, seq_len, in_ft) with seq_len <= max_len.

        Returns:
            (attr_out, ft_out): attr_out is (batch, attr_size); ft_out is
            (batch, ft_size), decoded from attr_out with the transposed shared weight.
        """
        out = self.DenseL(x)
        out = self.AttnL(out)
        lstm_out, _ = self.lstmL(out)
        feature = lstm_out[:, -1, :]  # last time step as the sequence summary
        # SAE operation with tied weights
        attr_out = F.linear(feature, self.TransMet)       # (batch, attr_size)
        ft_out = F.linear(attr_out, self.TransMet.t())    # (batch, ft_size)
        return attr_out, ft_out
  


In [9]:
# Reuse the config values from earlier cells so the model cannot drift from the data shape.
ft_size = 64    # LSTM hidden size / reconstructed feature size
attr_size = 16  # attribute-space size
model = ZSLHARNet(in_ft=ft_dims, d_model=d_model, num_heads=n_heads,
                  ft_size=ft_size, attr_size=attr_size, max_len=seq_len)

In [10]:
# Fresh random batch shaped like the configured windows: (batch, seq_len, features).
sample_input = torch.randn((bs, seq_len, ft_dims))

attr_out, feat_out = model(sample_input)
# Inline invariants: one attribute vector and one reconstructed feature per sample.
assert attr_out.shape == (bs, model.attr_size)
assert feat_out.shape == (bs, model.ft_size)
print(attr_out.shape, feat_out.shape)

torch.Size([32, 16]) torch.Size([32, 64])


In [11]:
# Preview only the first few rows; the full (bs, attr_size) tensor floods the notebook.
attr_out[:5]

tensor([[ 1.9738e-01,  1.3880e-01, -2.3393e-01, -1.4038e-01, -2.6661e-01,
          3.4279e-01, -2.5056e-02, -4.5227e-01, -4.0780e-01,  1.6693e-01,
         -6.1354e-02,  5.2765e-01, -2.0932e-01,  6.7647e-01, -2.9521e-01,
         -1.8451e-02],
        [-7.9600e-02,  1.4353e-01, -2.1818e-01, -1.2940e-01, -2.6362e-01,
          4.5588e-01, -2.4781e-01, -6.1644e-01, -6.3781e-01,  3.5662e-01,
          2.6989e-02,  4.4249e-01, -1.2734e-01,  3.8886e-01, -1.3834e-01,
         -1.6093e-01],
        [ 1.7725e-01,  9.2594e-02, -1.2674e-01, -2.9718e-01, -1.7856e-01,
          8.9125e-01, -1.8151e-01, -5.9856e-01, -4.5926e-01, -3.2906e-01,
         -1.5519e-01,  5.6246e-01, -5.2584e-01,  6.3432e-01,  1.2672e-01,
          1.8171e-02],
        [-1.8424e-01,  2.0375e-01, -2.9534e-01, -2.5335e-01, -2.3889e-03,
          2.9360e-01, -8.8648e-02, -5.7559e-01, -6.2367e-01,  2.3156e-01,
         -3.4893e-01,  1.1374e-01, -8.3828e-03,  7.6266e-01, -4.7523e-01,
         -3.7745e-01],
        [ 3.4628e-01