Imports

In [1]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import spacy
import pickle

Load spacy, disable 'ner' and 'parser', keep tagger only. require GPU.

In [None]:
nlp = spacy.load("en_core_web_sm")
nlp.disable_pipes('ner', 'parser')
#spacy.require_gpu()

[('ner', <spacy.pipeline.pipes.EntityRecognizer at 0x7f50802c0130>),
 ('parser', <spacy.pipeline.pipes.DependencyParser at 0x7f50802c00c0>)]

In [None]:
print(nlp.pipe_names)

['tagger']


Mount Google Drive

In [2]:
#
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


# **Data Preprocessing**

NOUN VERB ADJECTIVE ADVERB - consider adding pronouns

Read data and save as a list where each element is a tuple of lists which are: (keywords, keywords_pos, template, sentence)

In [None]:
# fine grained pos tags
#pos_tags = ['JJ', 'JJR', 'JJS', 'NN', 'NNS', 'NNP', 'NNPS', 'RB', 'RBR', 'RBS', 
#           'VB', 'VBD', 'VBG', 'VBN', 'VBP', 'VBZ']
# coarse grained pos tags
keyword_tags = ['NOUN', 'VERB', 'ADJ', 'ADV']
def read_data(filepath):
  data = []

  with open(filepath) as f:
    # the sentences are the labels
    labels = [sentence[:-1] for sentence in f]
    for doc in nlp.pipe(labels, batch_size=2000, n_process=10): 
      # templates are the POS tags of the sentence
      template = list(map(lambda word : word.pos_, doc))
      # sentence as list of words (sentence is the gold standard reference)
      label = [str(token).lower() for token in doc]
      # extract (and lemmatize) keywords
      keywords = extract_keywords(doc)
      # get keywords pos tags (individually)
      keywords_pos = list(map(lambda word: extract_pos(nlp(word)), keywords))
      data.append((keywords, keywords_pos, template, label))
  return data

In [None]:
# Find words in a sentence whose POS tag is a noun, verb, adjective or adverb. Lemmatize and store as keyword.
def extract_keywords(doc):
  kw = []
  for word in doc:
    if word.pos_ in keyword_tags:
      kw.append(nlp(str(word).lower())[0].lemma_)
  return kw

In [None]:
# Extract pos of a single word (Individually, no context)
def extract_pos(word):
  return word[0].pos_

In [None]:
'''# Read data
train_pre = read_data('drive/MyDrive/TGP/train.txt')
#test_pre = read_data('drive/MyDrive/TGP/test.txt')
#valid_pre = read_data('drive/MyDrive/TGP/valid.txt')'''

In [None]:
'''# Save data file on Google Drive
with open('drive/MyDrive/TGP/training_data_v2.txt', 'wb') as f1:
  pickle.dump(train_pre, f1)'''

In [27]:
'''# Load data file from google drive (fixed version)
currentFile = open('drive/MyDrive/TGP/training_data_v2.txt', mode='rb')
data = pickle.load(currentFile)'''

In [2]:
currentFile = open('drive/MyDrive/TGP/prepared_train.txt', mode='rb')
data = pickle.load(currentFile)

In [28]:
# Define function for sorting that returns the length of the sentence as key
def sortKey(e):
  return len(e[3])

In [29]:
# Sort data according to the length of a sentence
data.sort(key=sortKey)

In [30]:
# Add start and end of sentence tokens - "<sos>", "<eos>"
for i, (kw, kw_pos, template, sentence) in enumerate(data):
  template = ["<sos>"] + template + ["<eos>"]
  # For sentence we only use the <sos> token for generation
  sentence = ["<sos>"] + sentence
  data[i] = (kw, kw_pos, template, sentence)

In [3]:
len(data)

250000

In [4]:
len(data[-1][3])

38



---



In [37]:
'''with open('drive/MyDrive/TGP/prepared_train.txt', 'wb') as f1:
  pickle.dump(data, f1)'''

In [None]:
'''currentFile = open('drive/MyDrive/TGP/prepared_train.txt', mode='rb')
data = pickle.load(currentFile)'''

In [25]:
data.reverse()

In [None]:
'''nlp.disable_pipes('tagger')
print(nlp.pipe_names)'''

"nlp.disable_pipes('tagger')\nprint(nlp.pipe_names)"

The following Vocab class is served as a dictionary that maps words and tags into Ids. The __unk__ token is used for words that are not part of the training data, while __pad__ token is used as padding value (0). <sos> and <eos> are start and end of sentence tokens respectively.

In [5]:
class Vocab:
    def __init__(self):
      self.word2id = {"__pad__": 0, "__unk__": 1, "<sos>": 2}
      self.id2word = {0: "__pad__", 1: "__unk__", 2: "<sos>"}
      self.n_words = 2
        
      self.tag2id = {"__pad__": 0, "<sos>": 1, "<eos>": 2, 'ADJ':3, 'ADP':4, 'ADV':5, 'AUX':6, 'CONJ':7, 
                     'CCONJ':8, 'DET':9, 'INTJ':10, 'NOUN':11, 'NUM':12, 'PART':13, 'PRON':14, 'PROPN':15, 
                     'PUNCT':16, 'SCONJ':17, 'SYM':18, 'VERB':19, 'X':20, 'SPACE':21}
      self.id2tag = {v:k for (k, v) in self.tag2id.items()}
        
    def index_words(self, words):
      word_indexes = [self.index_word(w) for w in words]
      return word_indexes

    def index_tags(self, tags):
      tag_indexes = [self.tag2id[t] for t in tags]
      return tag_indexes
    
    def index_word(self, w):
        if w not in self.word2id:
          self.n_words += 1
          self.word2id[w] = self.n_words
          self.id2word[self.n_words] = w
        return self.word2id[w]

In [6]:
# Function for creating a new vocabulary from the words in the training data
def create_vocabulary(data):
    vocab = Vocab()
    for (keywords, _, _, sentence) in data:
      for token in sentence:
        vocab.index_word(token)
      for keyword in keywords:
        vocab.index_word(keyword)
      
    return vocab

In [7]:
vocab = create_vocabulary(data)

In [8]:
vocab.n_words

83138

REMINDER: Check if replacing -PRON- with the actual word improves performance

In [9]:
# Use the Vocab object to convert the data from strings to integers
def convert_data(data, vocab):
  int_data = []
  for (keywords, keywords_pos, template, sentence) in data:
    int_kw = [vocab.word2id[keyword] for keyword in keywords]
    int_kw_pos = [vocab.tag2id[keyword_pos] for keyword_pos in keywords_pos]
    int_template = [vocab.tag2id[pos] for pos in template]
    int_sentence = [vocab.word2id[word] for word in sentence]
    int_data.append((int_kw, int_kw_pos, int_template, int_sentence))
  
  return int_data

In [10]:
data = convert_data(data, vocab)

In [11]:
# Create batches and pad relevant input data
# Templates are going into GRU so they will be packed instead of padded.
def create_batches(data, batch_size=64):
  # The list to store all instances and corresponding labels
  data_batches = []
  # Keyword batch to store in data_batches
  kw_batch = []
  # Keyword pos batch to store in data_batches
  kw_pos_batch = []
  # Template batch to store in data_batches 
  template_batch = []
  # Store lengths of each instance in template batch for packing
  template_len = []
  # Sentence batch to store in data_batches
  sentence_batch = []
  
  for i, (keywords, keywords_pos, template, sentence) in enumerate(data):
    kw_batch.append(torch.LongTensor(keywords))
    kw_pos_batch.append(torch.LongTensor(keywords_pos))
    template_batch.append(torch.LongTensor(template))
    template_len.append(len(template))
    sentence_batch.append(torch.LongTensor(sentence))

    if (i + 1) % batch_size == 0:
      # Pad batchs of size batch_size
      kw_batch = torch.nn.utils.rnn.pad_sequence(kw_batch, batch_first=True)
      kw_pos_batch = torch.nn.utils.rnn.pad_sequence(kw_pos_batch, batch_first=True)
      template_batch = torch.nn.utils.rnn.pad_sequence(template_batch, batch_first=True)
      sentence_batch = torch.nn.utils.rnn.pad_sequence(sentence_batch, batch_first=True)

      data_batches.append((kw_batch, kw_pos_batch, (template_batch, template_len), sentence_batch))
      
      # Reinitialize the batches
      kw_batch = []
      kw_pos_batch = []
      template_batch = []
      template_len = []
      sentence_batch = []
      label_batch = []
    
  return data_batches

In [12]:
data = create_batches(data)

# **Model Architecture**

In [13]:
class Generator(nn.Module):
  def __init__(self, w_embed_dim=500, tag_embed_dim=57):
    super(Generator, self).__init__()
    # Embedding layer for words (shared between keywords and sentences)
    self.word_embed = nn.Embedding(num_embeddings=vocab.n_words, embedding_dim=w_embed_dim, padding_idx=0)
    # Embedding layer for pos tags (shared between keyword tags and templates)
    self.tag_embed = nn.Embedding(num_embeddings=len(vocab.tag2id), embedding_dim=tag_embed_dim, padding_idx=0)
    
    # Keyword encoder
    self.kw_encoder = FFNN(w_embed_dim)
    # Template encoder
    self.template_encoder = TemplateEncoder(tag_embed_dim)
    # Max Tag Overlap (template and keyword tag matching)
    self.mto = MTO(1)
    # Attention layer
    self.attn = Attention()
    # Decoder
    self.decoder = Decoder()


  def forward(self, kw, kw_pos, template_pack, sentence):    
    # Keyword encoder
    kw_embed = self.word_embed(kw)
    encoded_kw = self.kw_encoder(kw_embed)

    # Template encoder
    template, template_len = template_pack
    template_embed = self.tag_embed(template)
    packed_template = torch.nn.utils.rnn.pack_padded_sequence(template_embed, template_len, 
                                                              batch_first=True, enforce_sorted=False)
    outputs, last_hidden = self.template_encoder(packed_template)
    # Embed keyword pos tags
    kw_pos_embed = self.tag_embed(kw_pos)

    # Get lambda weights from Max Tag Overlap
    lambdas, lambdas_c = self.mto(template_embed, kw_pos_embed)

    # Embed sentence
    sentence_embed = self.word_embed(sentence)

    # First input is the <sos> token
    input = sentence_embed[:, 0, :].unsqueeze(1)

    # Initialize hidden layer for decoder
    hidden = torch.zeros(8, sentence.size(0), 500).cuda() 

    # Initialize mask to ignore padded values in attention mechanism 
    mask = torch.any((kw_embed != 0), dim=2).cuda()

    # List to store predictions
    preds = []
    for t in range(1, sentence_embed.size(1)):
      # Use attention to calculate context vector
      a = self.attn(outputs[:,t,:], last_hidden, encoded_kw, mask)
      context = torch.bmm(a.unsqueeze(1), encoded_kw)
      # Multiply this time step's context by its weight lambda
      context *= lambdas[:,t,:].unsqueeze(1)
      # Multiply this time step's template encoding by its weight lambda c (1 minus lambda)
      htt =  outputs[:,t,:].unsqueeze(1) * lambdas[:,t,:].unsqueeze(1)
      # mt is tanh of the concatenation of context and current time step's encoded template
      mt = torch.tanh(torch.cat((context, htt), dim=2))
      # Input to the decoder is the concatenation of mt and the embedding of a word from current time step t
      decoder_input = torch.cat((input, mt), dim=2)
      # Decode
      pred, hidden, last_hidden = self.decoder(decoder_input, hidden)
      # Store the highest probability prediction
      preds.append(pred)
    
    return torch.stack(preds, dim=1)

In [14]:
# Keyword encoder - Fully connected neural network
class FFNN(nn.Module):
  def __init__(self, h_dim):
    super(FFNN, self).__init__()
    self.fc1 = nn.Linear(h_dim, h_dim)
    self.lrelu1 = nn.LeakyReLU(0.01)
    self.fc2 = nn.Linear(h_dim, h_dim)
    self.lrelu2 = nn.LeakyReLU(0.01)
    self.fc3 = nn.Linear(h_dim, h_dim)
  
  
  def forward(self, x):
    x = self.lrelu1(self.fc1(x))
    x = self.lrelu2(self.fc2(x))
    x = torch.tanh(self.fc3(x))
    return x

In [15]:
# Template encoder - stack of bidirectional GRU's and MLP's to reduce the dimension for the decoder.
class TemplateEncoder(nn.Module):
  def __init__(self, input_dim, h_size=100):
    super(TemplateEncoder, self).__init__()
    self.gru = nn.GRU(input_size=input_dim, hidden_size=h_size, num_layers=4, batch_first=True, 
                      dropout=0.5, bidirectional=True)
    self.fc = nn.Linear(2*h_size, 2*h_size)
  
  def forward(self, x):
    outputs, hidden = self.gru(x)
    outputs, _ = torch.nn.utils.rnn.pad_packed_sequence(outputs, batch_first=True)
    hidden = torch.tanh(self.fc(torch.cat((hidden[-2], hidden[-1]), dim=1)))

    return outputs, hidden

In [16]:
# Template and keyword tag matching (Max Tag Overlap)
class MTO(nn.Module):
  def __init__(self, input_size):
    super(MTO, self).__init__()
    self.fc = nn.Linear(input_size, input_size)

  def forward(self, template_embed, kw_pos_embed):
    # Calculate s - the max cosine similarity between each template tag and keyword tag
    s_batch = []
    for template_tags, keyword_tags in zip(template_embed, kw_pos_embed):
      s = []
      for t_pos in template_tags:
        s.append(torch.max(F.cosine_similarity(keyword_tags, t_pos.unsqueeze(0))).item())
      s_batch.append(torch.tensor(s))

    s_batch = torch.stack(s_batch).unsqueeze(2)
    s_batch = s_batch.cuda()

    # Lambdas are weights which are equal to s going into a sigmoid on top of a linear layer 
    lambdas = torch.sigmoid(self.fc(s_batch))
    # Lambda_c is 1-lambda for each weight lambda in lambdas
    lambdas_c = torch.add(torch.multiply(lambdas, -1), 1)

    return lambdas, lambdas_c

In [17]:
# Attention layer (additive attention)
class Attention(nn.Module):
  def __init__(self):
    super(Attention, self).__init__()
    self.w = torch.nn.Linear(900, 500)
    self.v = nn.Linear(500, 1, bias=False)

  def forward(self, enc_template_o, hidden, encoded_kw, mask):
    # Unsqueeze for repeat
    hidden = hidden.unsqueeze(1)
    enc_template_o = enc_template_o.unsqueeze(1)
    # Repeat for stacking
    hidden = hidden.repeat(1, encoded_kw.size(1), 1)
    enc_template_o = enc_template_o.repeat(1, encoded_kw.size(1), 1)
    # Conatenate last hidden layer, the encoded keywords and the encoded template
    energy = torch.tanh(self.w(torch.cat((hidden, encoded_kw, enc_template_o), dim=2)))
    # Calculate score
    attention = self.v(energy).squeeze(2)
    # Mask
    attention = attention.masked_fill(mask == 0, -1e10)

    return F.softmax(attention, dim=1)

In [18]:
# Decoder
class Decoder(nn.Module):
  def __init__(self, h_size=500, n_layers=4):
    super(Decoder, self).__init__()
    self.gru = nn.GRU(input_size=1200, hidden_size=500, num_layers=4, batch_first=True, 
                      dropout=0.5, bidirectional=True)
    self.fc = nn.Linear(1000, 200)
    self.fc_out = nn.Linear(2200, vocab.n_words)

  def forward(self, gru_input, hidden_input):
    outputs, hidden = self.gru(gru_input, hidden_input)
    # Last hidden layers for attention
    last_hidden = torch.tanh(self.fc(torch.cat((hidden[-2], hidden[-1]), dim=1)))
    prediction = self.fc_out(torch.cat((outputs.squeeze(1), gru_input.squeeze(1)), dim=1))
    return prediction, hidden, last_hidden

# **Training**

Notes to self: 
- Feed the decoder with the encoder last state
- Maybe try without MLP in the encoder.

- Read the article in NLP projects folder to decide num of layers of GRU in encoder and decoder

- Pack templates going into GRU encoder. Unpack them when going into MLP which reduces dimension

- Check how torch.nn.utils.rnn.pad_packed_sequence and torch.nn.utils.rnn.pack_sequence operate (pack a batch, unpack it and pack again and see if it remembers the paddings). probably need pack_padded_sequence for repacking (after unpacking)

- Add a dropout layer after embeddings

- Try using word2vec https://stackoverflow.com/questions/49710537/pytorch-gensim-how-to-load-pre-trained-word-embeddings/49802495#49802495

- Try teacher forcing

- For initialization of decoder hidden states, try random noise (or something else) instead of zeros

- Consider mask in attention mechanism

In [19]:
# Initialize the model
model = Generator().cuda()

# Loss function
criterion = torch.nn.CrossEntropyLoss()

# Optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [None]:
# Load saved model

#model.load_state_dict(torch.load("drive/MyDrive/model_0"))

In [None]:
'''def init_weights(m):
    for name, param in m.named_parameters():
        if 'weight' in name:
            nn.init.normal_(param.data, mean=0, std=0.01)
        else:
            nn.init.constant_(param.data, 0)
            
model.apply(init_weights)'''

In [20]:
# Training loop
def train_loop(model, n_epochs, train_set):
  epoch_loss = 0
  
  for e in range(1, n_epochs + 1):
    for i, (kw, kw_pos, (template, template_len), sentence) in enumerate(train_set, 1):
      kw = kw.cuda()
      kw_pos = kw_pos.cuda()
      template = template.cuda()
      sentence = sentence.cuda()

      optimizer.zero_grad()
      output = model(kw, kw_pos, (template, template_len), sentence)
      loss = criterion(output.reshape(-1, output.shape[-1]), sentence[:, 1:].reshape(-1))
      loss.backward()
      optimizer.step()
      epoch_loss += loss.item()

      #print(f"batch {i}")

    # Save model every epoch
    torch.save(model.state_dict(), f"drive/MyDrive/model") 
    # Print epoch loss
    print(f"Epoch {e} train loss: {epoch_loss / len(train_set)}")
    # Reset epoch loss
    epoch_loss = 0
    

In [None]:
train_loop(model, 40, data)

Epoch 1 train loss: 2.7816073292350376
Epoch 2 train loss: 1.9308604925214725
Epoch 3 train loss: 1.687490767063511


# **TESTS**

In [None]:
gen = Generator()
gen(data[0])



---



In [None]:
gen = Generator()

In [None]:
outputs, hidden, encoded_kw, sentence_embed, lambdas, lambdas_c, context, htt, mt, decoder_input, decoded = gen(data[0])

In [None]:
outputs.size()

torch.Size([256, 5, 200])

In [None]:
hidden.size()

torch.Size([256, 200])

In [None]:
encoded_kw.size()

torch.Size([256, 3, 500])

In [None]:
attn = Attention()
attn = attn(outputs[:,1,:], hidden, encoded_kw)

In [None]:
attn.size()

torch.Size([256, 3])

In [None]:
# Context - ct
ct = torch.bmm(attn.unsqueeze(1), encoded_kw)

In [None]:
ct.size()

torch.Size([256, 1, 500])

In [None]:
lambdas.size()

torch.Size([256, 5, 1])

In [None]:
l = lambdas[:,0,:].unsqueeze(1)
l.size()

torch.Size([256, 1, 1])

In [None]:
lambdas_c.size()

torch.Size([256, 5, 1])

In [None]:
lambdas_c[:, 1, :].size()

torch.Size([256, 1])

In [None]:
outputs[:,1,:].size()

torch.Size([256, 200])

In [None]:
(lambdas_c[:, 1, :]*outputs[:,1,:]).size()

torch.Size([256, 200])

In [None]:
outputs.size()

torch.Size([256, 5, 200])

In [None]:
# Template times 1-lambda
torch.bmm(lambdas_c.permute(0, 2, 1), outputs).size()

torch.Size([256, 1, 200])

In [None]:
kek = ct*l
kek.size()

torch.Size([256, 1, 500])

In [None]:
sentence_embed.size()

torch.Size([256, 4, 500])

In [None]:
lambdas_c[:, 1, :].size()

torch.Size([256, 1])

In [None]:
outputs[:,1,:].size()

torch.Size([256, 200])

In [None]:
kek = lambdas_c[:, 1, :]*outputs[:,1,:]



---



In [None]:
print(htt.size())
print(context.size())


torch.Size([256, 1, 200])
torch.Size([256, 1, 500])


In [None]:
mt.size()

torch.Size([256, 1, 700])

In [None]:
sentence_embed[:, 0, :].size()

torch.Size([256, 500])

In [None]:
g = nn.GRU(input_size=1200, hidden_size=500, num_layers=4, batch_first=True, 
                      dropout=0.5, bidirectional=True)

In [None]:
decoder_input.size()

torch.Size([256, 1, 1200])

In [None]:
o, h = g(decoder_input)

In [None]:
o.size()

torch.Size([256, 1, 1000])

In [None]:
decoded[0].size()

torch.Size([256, 1, 1000])

In [None]:
l = nn.Linear(2200, vocab.n_words)

In [None]:
a = decoded[0].squeeze(1)

In [None]:
a.size() == 

torch.Size([256, 1000])

In [None]:
b = mt.squeeze(1)
b.size()

torch.Size([256, 700])

In [None]:
c = torch.cat((a, b), dim=1)
c.size()

torch.Size([256, 1700])



---



In [None]:
gen = Generator()

In [None]:
preds = gen(data[0])

In [None]:
preds.size()

torch.Size([256, 3])

In [None]:
outputs, hidden, encoded_kw, sentence_embed, lambdas, lambdas_c, context, htt, mt, decoder_input = gen(data[0])

torch.Size([256, 200])
torch.Size([256, 1000])
torch.Size([256, 1000])
torch.Size([256, 1000])


In [None]:
outputs, hidden, encoded_kw, sentence_embed, lambdas, lambdas_c, context, htt, mt, decoder_input = gen(data[0])

torch.Size([8, 256, 100])


In [None]:
mt.size()

True

In [None]:
hidden.size()

torch.Size([256, 200])

In [None]:
preds = gen(data[0])

In [None]:
data[0][2][0].size()


torch.Size([256, 5])

In [None]:
data[0][3].size()

torch.Size([256, 4])

In [None]:
a = torch.tensor([1,2,3,10])
b = torch.tensor([4,5,6,11])
c = torch.tensor([7,8,9,12])
d = torch.stack((a, b, c), dim=1)

In [None]:
d[1]

tensor([2, 5, 8])

In [74]:
a = torch.tensor([[[1,2,3, 0, 0], [2,1, 0, 0, 0], [2,0,0,0,0]], [[1,2,3, 0, 0], [2,1, 0, 0, 0], [2,0,0,0,0]], [[1,2,3, 0, 0], [2,1, 0, 0, 0], [2,0,0,0,0]]])
b = (a != 0)
print(b)

tensor([[[ True,  True,  True, False, False],
         [ True,  True, False, False, False],
         [ True, False, False, False, False]],

        [[ True,  True,  True, False, False],
         [ True,  True, False, False, False],
         [ True, False, False, False, False]],

        [[ True,  True,  True, False, False],
         [ True,  True, False, False, False],
         [ True, False, False, False, False]]])


In [75]:
a.size()

torch.Size([3, 3, 5])

In [76]:
c = torch.tensor([[[1,2,3, 4, 5], [6,7, 8, 1, 2], [2,3,4,5,1]], [[1,2,3, 4, 5], [6,7, 8, 1, 2], [2,3,4,5,1]], [[1,2,3, 4, 5], [6,7, 8, 1, 2], [2,3,4,5,1]]])
c.masked_fill(b == 0, -1e10)

tensor([[[           1,            2,            3, -10000000000, -10000000000],
         [           6,            7, -10000000000, -10000000000, -10000000000],
         [           2, -10000000000, -10000000000, -10000000000, -10000000000]],

        [[           1,            2,            3, -10000000000, -10000000000],
         [           6,            7, -10000000000, -10000000000, -10000000000],
         [           2, -10000000000, -10000000000, -10000000000, -10000000000]],

        [[           1,            2,            3, -10000000000, -10000000000],
         [           6,            7, -10000000000, -10000000000, -10000000000],
         [           2, -10000000000, -10000000000, -10000000000, -10000000000]]])

In [80]:
embedder = nn.Embedding(num_embeddings=vocab.n_words, embedding_dim=500, padding_idx=0)

In [81]:
a = embedder(data[5][0])

In [82]:
a.size()

torch.Size([128, 4, 500])

In [88]:
b = a[:,:,0]

In [93]:
mask = (a != 0)

In [95]:
mask.size()

torch.Size([128, 4, 500])

In [100]:
mask[0]

tensor([[ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False]])

In [101]:
torch.any(mask, dim=2)[0]

tensor([ True,  True, False, False])

In [102]:
torch.any(mask, dim=2).size()

torch.Size([128, 4])

In [89]:
b.size()

torch.Size([128, 4])