#QUAG

In [1]:
import torch
import torch.nn as nn
import math

In [2]:
class SelfAttentionWithQUAG(nn.Module):
  """Single-head self-attention with optional quadrant averaging (QUAG).

  The attention matrix is split into four quadrants using ``l_v`` (the first
  ``l_v`` positions, called "video" here) and ``l_t`` (the last ``l_t``
  positions, called "text" here). Every quadrant named in ``quads`` has each
  of its rows replaced by the mean of that row over the unmasked entries of
  the quadrant.
  """

  def __init__(self, dim_model):
    super().__init__()
    self.dim_model = dim_model
    self.q_lin = nn.Linear(in_features=dim_model, out_features=dim_model)
    self.k_lin = nn.Linear(in_features=dim_model, out_features=dim_model)
    self.v_lin = nn.Linear(in_features=dim_model, out_features=dim_model)
    self.out_lin = nn.Linear(in_features=dim_model, out_features=dim_model)
    self.softmax = torch.nn.Softmax(dim=-1)

  def apply_attention_scores(self, attention_scores, value):
    """Return ``attention_scores @ value``."""
    return torch.matmul(attention_scores, value)

  def get_rowwise_average(self, scores, mask):
    """Return ``scores`` with every row replaced by its masked mean.

    The row sum is divided by the number of non-zero mask entries in that
    row and broadcast back to ``scores.shape``. ``scores`` is summed as-is,
    so masked positions are expected to already hold 0.
    """
    rowwise_sum = scores.sum(-1)
    # Rows without any unmasked entry would otherwise compute 0/0 = NaN;
    # clamping the count makes their mean 0 instead.
    unmasked_per_row = mask.sum(-1).clamp(min=1)
    rowwise_mean = rowwise_sum / unmasked_per_row
    return rowwise_mean.unsqueeze(-1).expand(scores.shape)

  def apply_quag(self, attention_scores, mask, l_v, l_t, quads):
    """Average the requested quadrants of ``attention_scores`` in place.

    Args:
      attention_scores: (batch, seq, seq) attention matrix; modified in
        place and also returned.
      mask: (batch, seq, seq) pairwise mask, 0 marks masked positions.
      l_v: number of leading (video) positions.
      l_t: number of trailing (text) positions.
      quads: container with any of 'VV', 'VT', 'TV', 'TT' (row segment
        first, column segment second).

    Returns:
      The same tensor, with masked positions reset to 0.
    """
    n_rows, n_cols = attention_scores.shape[-2:]
    # Explicit start indices instead of ``-l_t:``: with l_t == 0 the slice
    # ``-0:`` would select the whole axis rather than nothing.
    row_span = {'V': slice(None, l_v), 'T': slice(max(n_rows - l_t, 0), None)}
    col_span = {'V': slice(None, l_v), 'T': slice(max(n_cols - l_t, 0), None)}
    for quad in ('VV', 'VT', 'TV', 'TT'):
      if quad not in quads:
        continue
      rows, cols = row_span[quad[0]], col_span[quad[1]]
      block = attention_scores[:, rows, cols]
      if block.numel() == 0:
        # Empty segment (e.g. l_t == 0): nothing to average.
        continue
      attention_scores[:, rows, cols] = self.get_rowwise_average(block, mask[:, rows, cols])
    attention_scores.masked_fill_(mask == 0, 0)
    return attention_scores

  def forward(self, inputs, mask, l_v, l_t, quads):
    """Run self-attention with QUAG applied to the attention matrix.

    Args:
      inputs: Tensor of shape (batch_size, sequence_length, dim_model)
      mask: Tensor of shape (batch_size, sequence_length), 0 marks padding
      l_v: int    maximum length of video tokens
      l_t: int    maximum length of question tokens
      quads: list containing elements from {'VV', 'VT', 'TV', 'TT'}

    Returns:
      Tensor of shape (batch_size, sequence_length, dim_model).
    """
    query = self.q_lin(inputs)
    key = self.k_lin(inputs)
    value = self.v_lin(inputs)
    scaled_dot_product = torch.matmul(query, key.transpose(1, 2)) / math.sqrt(self.dim_model)
    # Pairwise mask: entry (i, j) is non-zero only if tokens i and j are both unmasked.
    pair_mask = mask.unsqueeze(1) * mask.unsqueeze(2)
    scaled_dot_product = scaled_dot_product.masked_fill(pair_mask == 0, -float("inf"))
    # Out-of-place fill: softmax keeps its output for the backward pass, so
    # overwriting that tensor in place makes .backward() raise.
    # Fully masked rows come out of softmax as NaN and are zeroed here.
    attention_scores = self.softmax(scaled_dot_product).masked_fill(pair_mask == 0, 0.0)
    print(f"Attention scores before QUAG\n{attention_scores}")
    attention_scores = self.apply_quag(attention_scores, pair_mask, l_v, l_t, quads)
    print(f"Attention scores after QUAG {quads}\n{attention_scores}")
    attended_output = 0.5*(inputs + self.apply_attention_scores(attention_scores, value))
    attended_output = self.out_lin(attended_output)
    return attended_output

Example of QUAG: averaging the video-video (VV) and text-text (TT) attention quadrants

In [3]:
# Demo module with dim_model = 6.
quag_sa = SelfAttentionWithQUAG(6)

In [4]:
# Random demo batch: one sequence of 6 tokens with 6 features each.
# torch.randn already returns a fresh tensor; wrapping it in torch.tensor()
# only makes a redundant copy and emits a UserWarning.
input = torch.randn(1, 6, 6)
# Third token is masked out (0 = masked).
mask = torch.tensor([1, 1, 0, 1, 1, 1]).unsqueeze(0)
# l_v = 3 leading tokens, l_t = 3 trailing tokens; average the VV and TT quadrants.
output = quag_sa(input, mask, 3, 3, ['VV', 'TT'])

Attention scores before QUAG
tensor([[[0.1253, 0.2028, 0.0000, 0.2995, 0.1659, 0.2065],
         [0.1398, 0.2226, 0.0000, 0.2092, 0.1917, 0.2366],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.2685, 0.1900, 0.0000, 0.1402, 0.2514, 0.1499],
         [0.1633, 0.2167, 0.0000, 0.2429, 0.1852, 0.1919],
         [0.1754, 0.2061, 0.0000, 0.2052, 0.1943, 0.2190]]],
       grad_fn=<MaskedFillBackward0>)
Attention scores after QUAG ['VV', 'TT']
tensor([[[0.1641, 0.1641, 0.0000, 0.2995, 0.1659, 0.2065],
         [0.1812, 0.1812, 0.0000, 0.2092, 0.1917, 0.2366],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.2685, 0.1900, 0.0000, 0.1805, 0.1805, 0.1805],
         [0.1633, 0.2167, 0.0000, 0.2067, 0.2067, 0.2067],
         [0.1754, 0.2061, 0.0000, 0.2062, 0.2062, 0.2062]]],
       grad_fn=<MaskedFillBackward0>)


  input = torch.tensor(torch.randn(1,6,6))


#QUAG-Attention

In [5]:
import torch
import torch.nn as nn
import math

In [6]:
class QUAGAttention(nn.Module):
  """Single-head attention whose keys/values are segment-averaged inputs.

  The sequence is split into the first ``l_v`` positions ("video") and the
  last ``l_t`` positions ("text"). Depending on ``mode``, one or both
  segments are collapsed into a single token holding the mean of the
  segment's unmasked inputs before keys and values are computed. The logit
  of an averaged key is multiplied by the log of the number of unmasked
  tokens it stands for.
  """

  _MODES = ('vid-avg', 'text-avg', 'text-vid-avg')

  def __init__(self, dim_model):
    super().__init__()
    self.dim_model = dim_model
    self.q_lin = nn.Linear(in_features=dim_model, out_features=dim_model)
    self.k_lin = nn.Linear(in_features=dim_model, out_features=dim_model)
    self.v_lin = nn.Linear(in_features=dim_model, out_features=dim_model)
    self.out_lin = nn.Linear(in_features=dim_model, out_features=dim_model)
    self.softmax = torch.nn.Softmax(dim=-1)

  def _check_mode(self, mode):
    """Raise ValueError unless ``mode`` is one of the supported modes."""
    if mode not in self._MODES:
      raise ValueError(f"mode must be one of {self._MODES}, got {mode!r}")

  def apply_attention_scores(self, attention_scores, value):
    """Return ``attention_scores @ value``."""
    return torch.matmul(attention_scores, value)

  def get_avg_inputs(self, inputs, l_v, l_t, mask, mode):
    """Build the key/value inputs with the selected segment(s) averaged.

    Args:
      inputs: (batch, seq, dim) tensor.
      l_v: number of leading (video) positions.
      l_t: number of trailing (text) positions.
      mask: (batch, seq) mask, 0 marks padding.
      mode: 'text-vid-avg' -> (batch, 2, dim), 'vid-avg' -> (batch, 1+T, dim),
        'text-avg' -> (batch, V+1, dim). Any other value returns the
        zero-masked inputs unchanged.
    """
    # Explicit start index instead of ``-l_t:`` so that l_t == 0 selects
    # nothing rather than the whole sequence.
    text_start = max(inputs.shape[1] - l_t, 0)
    # Clamp the counts: a fully padded segment then averages to zeros
    # instead of 0/0 = NaN (which would poison every output via the values).
    unmasked_vid = mask[:, :l_v].sum(1).unsqueeze(1).clamp(min=1)
    unmasked_lang = mask[:, text_start:].sum(1).unsqueeze(1).clamp(min=1)
    ip = inputs * mask.unsqueeze(-1)  # zero all the values that are padded
    vid_tokens = ip[:, :l_v, :]
    text_tokens = ip[:, text_start:, :]
    # Average only over the unpadded tokens of each segment.
    vid_avg = (vid_tokens.sum(1) / unmasked_vid).unsqueeze(1)
    text_avg = (text_tokens.sum(1) / unmasked_lang).unsqueeze(1)
    if mode == 'text-vid-avg':
      ip = torch.cat((vid_avg, text_avg), 1)  # (bs, 2, dim)
    elif mode == 'vid-avg':
      ip = torch.cat((vid_avg, text_tokens), 1)  # (bs, 1+T, dim)
    elif mode == 'text-avg':
      ip = torch.cat((vid_tokens, text_avg), 1)  # (bs, V+1, dim)
    return ip

  def get_new_mask(self, mask, l_v, l_t, mode):
    """Collapse the pairwise mask columns to match the averaged keys.

    Args:
      mask: (batch, seq, seq) pairwise mask, 0 marks masked pairs.
      l_v, l_t: segment lengths as in ``forward``.
      mode: one of {'vid-avg', 'text-avg', 'text-vid-avg'}.

    Returns:
      Mask of shape (batch, seq, n_keys) in ``mask``'s dtype.

    Raises:
      ValueError: if ``mode`` is not supported.
    """
    self._check_mode(mode)
    text_start = max(mask.shape[-1] - l_t, 0)
    vid_cols = mask[:, :, :l_v]
    text_cols = mask[:, :, text_start:]
    # An averaged key is visible to a query if at least one token of the
    # segment is visible to it. Reading one fixed column (e.g. column 0)
    # would hide the whole segment whenever that single token is padded.
    vid_any = (vid_cols.sum(-1, keepdim=True) > 0).to(mask.dtype)
    text_any = (text_cols.sum(-1, keepdim=True) > 0).to(mask.dtype)
    if mode == 'text-vid-avg':
      return torch.cat((vid_any, text_any), -1)
    if mode == 'vid-avg':
      return torch.cat((vid_any, text_cols), -1)
    return torch.cat((vid_cols, text_any), -1)

  def apply_scaling(self, scaled_dot_product, mask, l_v, l_t, mode):
    """Multiply the averaged-key logits by log(number of unmasked tokens).

    Column 0 is scaled when ``mode`` contains "vid", column -1 when it
    contains "text". ``scaled_dot_product`` is modified in place and
    returned. Call this before filling masked logits with -inf, since
    -inf * log(1) would give NaN.

    Args:
      scaled_dot_product: (batch, seq, n_keys) logits.
      mask: (batch, seq, seq) pairwise mask.
      l_v, l_t: segment lengths as in ``forward``.
      mode: string containing "vid" and/or "text".
    """
    # The diagonal of the pairwise mask is the per-token mask (for 0/1 masks),
    # so the counts do not depend on whether token 0 happens to be padded.
    token_mask = torch.diagonal(mask, dim1=1, dim2=2)  # (bs, seq)
    text_start = max(token_mask.shape[-1] - l_t, 0)
    dtype = scaled_dot_product.dtype
    if "vid" in mode:
      # Per-sample counts (tensor log instead of math.log) keep this valid
      # for batch sizes > 1; clamp avoids log(0) for an empty segment.
      n_vid = token_mask[:, :l_v].sum(-1).clamp(min=1)
      vid_scaling = torch.log(n_vid.to(dtype)).unsqueeze(-1)  # (bs, 1)
      scaled_dot_product[:, :, 0] = scaled_dot_product[:, :, 0] * vid_scaling
    if "text" in mode:
      n_text = token_mask[:, text_start:].sum(-1).clamp(min=1)
      text_scaling = torch.log(n_text.to(dtype)).unsqueeze(-1)  # (bs, 1)
      scaled_dot_product[:, :, -1] = scaled_dot_product[:, :, -1] * text_scaling
    return scaled_dot_product

  def forward(self, inputs, mask, l_v, l_t, mode):
    """Run QUAG-attention.

    Args:
      inputs: Tensor of shape (batch_size, sequence_length, dim_model)
      mask: Tensor of shape (batch_size, sequence_length), 0 marks padding
      l_v: int    maximum length of video tokens
      l_t: int    maximum length of question tokens
      mode: one of {'vid-avg', 'text-avg', 'text-vid-avg'}

    Returns:
      Tensor of shape (batch_size, sequence_length, dim_model).

    Raises:
      ValueError: if ``mode`` is not supported.
    """
    self._check_mode(mode)
    average_inputs = self.get_avg_inputs(inputs, l_v, l_t, mask, mode)
    query = self.q_lin(inputs)
    key = self.k_lin(average_inputs)
    value = self.v_lin(average_inputs)

    scaled_dot_product = torch.matmul(query, key.transpose(1, 2)) / math.sqrt(self.dim_model)
    # Pairwise mask: entry (i, j) is non-zero only if tokens i and j are both unmasked.
    pair_mask = mask.unsqueeze(1) * mask.unsqueeze(2)
    avg_mask = self.get_new_mask(pair_mask, l_v, l_t, mode)

    # Scale while the logits are still finite, then mask.
    scaled_dot_product = self.apply_scaling(scaled_dot_product, pair_mask, l_v, l_t, mode)
    scaled_dot_product = scaled_dot_product.masked_fill(avg_mask == 0, -float("inf"))
    # Out-of-place fill: softmax keeps its output for the backward pass, so
    # overwriting that tensor in place makes .backward() raise.
    attention_scores = self.softmax(scaled_dot_product).masked_fill(avg_mask == 0, 0.0)
    print(f"{mode}-QUAG Attention Matrix:\n{attention_scores}")
    attended_output = 0.5*(inputs + self.apply_attention_scores(attention_scores, value))
    attended_output = self.out_lin(attended_output)
    return attended_output

Example of QUAG-Attention (one call per averaging mode)

In [7]:
# Demo module with dim_model = 6.
quag_attention = QUAGAttention(6)

In [8]:
# Random demo batch: one sequence of 6 tokens with 6 features each.
# torch.randn already returns a fresh tensor; wrapping it in torch.tensor()
# only makes a redundant copy and emits a UserWarning.
input = torch.randn(1, 6, 6)
# Third token is masked out (0 = masked).
mask = torch.tensor([1, 1, 0, 1, 1, 1]).unsqueeze(0)

  input = torch.tensor(torch.randn(1,6,6))


In [9]:
# Both segments averaged: keys/values are [video mean, text mean] (2 columns).
output = quag_attention(input, mask, l_v=4, l_t=2, mode='text-vid-avg')

text-vid-avg-QUAG Attention Matrix:
tensor([[[0.2407, 0.7593],
         [0.3579, 0.6421],
         [0.0000, 0.0000],
         [0.4135, 0.5865],
         [0.4455, 0.5545],
         [0.5516, 0.4484]]], grad_fn=<MaskedFillBackward0>)


In [10]:
# Only the video segment averaged: keys/values are [video mean, text tokens] (1 + l_t columns).
output = quag_attention(input, mask, l_v=4, l_t=2, mode='vid-avg')

vid-avg-QUAG Attention Matrix:
tensor([[[0.1110, 0.2332, 0.6558],
         [0.1663, 0.2122, 0.6215],
         [0.0000, 0.0000, 0.0000],
         [0.2564, 0.5068, 0.2368],
         [0.2822, 0.4076, 0.3102],
         [0.3757, 0.1897, 0.4346]]], grad_fn=<MaskedFillBackward0>)


In [11]:
# Only the text segment averaged: keys/values are [video tokens, text mean] (l_v + 1 columns).
output = quag_attention(input, mask, l_v=4, l_t=2, mode='text-avg')

text-avg-QUAG Attention Matrix:
tensor([[[0.2652, 0.0837, 0.0000, 0.1850, 0.4661],
         [0.4359, 0.0834, 0.0000, 0.1612, 0.3195],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0896, 0.2759, 0.0000, 0.3540, 0.2805],
         [0.1444, 0.2523, 0.0000, 0.3247, 0.2787],
         [0.3030, 0.3091, 0.0000, 0.1787, 0.2092]]],
       grad_fn=<MaskedFillBackward0>)
