In [1]:
!pip install datasets &> /dev/null
!pip install transformers &> /dev/null

In [2]:
import math

import datasets
import tokenizers 
import torch
import torch.nn as nn
import torch.nn.functional as F
import transformers


# Dataset & Tokenizer

In [3]:
raw_dataset = datasets.load_dataset('imdb', split=['train', 'test[:5000]'])

Downloading:   0%|          | 0.00/1.92k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.05k [00:00<?, ?B/s]

Downloading and preparing dataset imdb/plain_text (download: 80.23 MiB, generated: 127.02 MiB, post-processed: Unknown size, total: 207.25 MiB) to /root/.cache/huggingface/datasets/imdb/plain_text/1.0.0/e3c66f1788a67a89c7058d97ff62b6c30531e05b549de56d3ab91891f0561f9a...


Downloading:   0%|          | 0.00/84.1M [00:00<?, ?B/s]

0 examples [00:00, ? examples/s]

0 examples [00:00, ? examples/s]

0 examples [00:00, ? examples/s]

Dataset imdb downloaded and prepared to /root/.cache/huggingface/datasets/imdb/plain_text/1.0.0/e3c66f1788a67a89c7058d97ff62b6c30531e05b549de56d3ab91891f0561f9a. Subsequent calls will reuse this data.


  0%|          | 0/2 [00:00<?, ?it/s]

In [5]:
tokenizer = transformers.AutoTokenizer.from_pretrained("bert-base-cased")
print(tokenizer.vocab_size)

28996


In [6]:
def tokenize_function(examples):
    return tokenizer(examples["text"], padding="max_length", truncation=True)

train_dataset = raw_dataset[0].map(tokenize_function, batched=True)
test_dataset = raw_dataset[0].map(tokenize_function, batched=True)

  0%|          | 0/25 [00:00<?, ?ba/s]

  0%|          | 0/25 [00:00<?, ?ba/s]

In [7]:
test = next(iter(train_dataset))

In [11]:
test.keys()

dict_keys(['attention_mask', 'input_ids', 'label', 'text', 'token_type_ids'])

In [10]:
def dynamic_padding_collator(batch):
    """Dynamically pads a batch, creates mask and masks out padded labels.
    Given batch_size sequences of different lengths, computes
    k = log2(batch_size). All sequences are padded/truncated to the length 
    of the kth longest sequence. Additionally, masks are constructed based 
    on the new dynamic lengths and padded tokens are masked out within labels.
    Code Based on:
    https://gist.github.com/pommedeterresautee/1a334b665710bec9bb65965f662c94c8
    https://huggingface.co/transformers/_modules/transformers/data/data_collator.html#default_data_collator
    Args:
      batch: (list) Each element is a dict {'input_ids': tokens} where tokens
        is a list of token indices coreesponding to a single sample.
    Returns:
      batch: (dict) Contains 3 elements: 
        input_ids: (torch.tensor) Padded input indices.
        attention_mask: (torch.tensor) Mask that takes the value 1 for valid
          tokens and 0 for padded ones.
        labels: (torch.tensor) Same as input_ids but padded tokens are replaced
          with -100.
    """
    lengths = [sum(sample['attention_mask']) for sample in batch]
    max_len = max(lengths)
    input_ids = [sample['input_ids'][:max_len] for sample in batch]
    input_ids = torch.tensor(input_ids).long()
    attention_mask = [sample['attention_mask'][:max_len] for sample in batch]
    attention_mask = torch.tensor(attention_mask).long()
    labels = [sample['label'] for sample in batch]
    labels = torch.tensor(labels).long()
    lengths = torch.tensor(lengths).long()

    batch = {
        'input_ids': input_ids, 
        'attention_mask': attention_mask, 
        'lengths': lengths,
        'labels': labels
    }
    return batch


1

# Models

In [53]:
class MHA(nn.Module):
    """Heart of https://arxiv.org/pdf/2006.16236.pdf
    """

    def __init__(self, d_model, n_heads, use_cos, kernel, 
                 dropout, denom_eps, bias):
        super(MHA, self).__init__()
        assert d_model % n_heads == 0, 'd_model must be a multiple of n_heads'
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_head = self.d_model // self.n_heads
        self.denom_eps = denom_eps

        if kernel == 'relu':
            self.kernel = self.relu_kernel
        elif kernel == 'elu':
            self.kernel = self.elu_kernel
        else:
            raise NotImplementedError(
                "The only options for 'kernel' are 'relu and 'elu'.")
            
        if use_cos:
            self.attention_func = self.cos_linear_attention
        else:
            self.attention_func = self.linear_attention

        self.w_qkv = nn.Linear(d_model, 3 * d_model, bias=bias)
        self.w_o = nn.Linear(d_model, d_model, bias=bias)
        self.dropout = nn.Dropout(dropout)

    def apply_mask(self, x, mask):
        # x -> [batch_size, seq_len, _]
        # mask -> [batch_size, seq_len, 1] or None
        if not mask is None:
            #x.masked_fill_(~mask, 0)
            x = x.masked_fill(~mask, 0)
        return x

    def split_heads(self, x):
        batch_size, seq_len = x.shape[:2]
        # x -> [batch_size, seq_len, d_model]
        x = x.view(batch_size, seq_len, self.n_heads, self.d_head)
        # x -> [batch_size, seq_len, n_heads, d_head]
        return x

    def join_heads(self, x):
        batch_size, seq_len = x.shape[:2]
        # x -> [batch_size, seq_len, n_heads, d_head]
        x = x.view(batch_size, seq_len, self.d_model).contiguous()
        # x -> [batch_size, seq_len, d_model]
        return x

    def elu_kernel(self, x):
        return F.elu(x) + 1

    def relu_kernel(self, x):
        return F.relu(x)

    def linear_attention(self, q, k, v, weights=None):
        # stolen from 
        # https://github.com/tensorflow/models/blob/master/official/nlp/modeling/layers/kernel_attention.py
        # q, k, v -> [batch_size, seq_len, n_heads, d_head]
        kv = torch.einsum('bsnx,bsnz->bnxz', k, v)
        # kv -> [batch_size, n_heads, d_head, d_head]
        # add dropout here
        denominator = 1.0 / (torch.einsum('bsnd,bnd->bsn', q, k.sum(axis=1)) + self.denom_eps)
        # denominator -> [batch_size, seq_len, n_heads]

        output = torch.einsum('bsnx,bnxz,bsn->bsnz', q, kv, denominator).contiguous()
        # output -> [batch_size, seq_len, n_heads, d_head]

        return output

    def cos_linear_attention(self, q, k, v, weights):
        # q, k, v -> [batch_size, seq_len, n_heads, d_head]
        cos, sin = weights
        # cos, sin -> [batch_size, seq_len]
        q_cos = torch.einsum('bsnd,bs->bsnd', q, cos)
        q_sin = torch.einsum('bsnd,bs->bsnd', q, sin)
        k_cos = torch.einsum('bsnd,bs->bsnd', k, cos)
        k_sin = torch.einsum('bsnd,bs->bsnd', k, sin)
        # q_cos, q_sin, k_cos, k_sin -> [batch_size, seq_len, n_heads, d_head]

        kv_cos = torch.einsum('bsnx,bsnz->bnxz', k_cos, v)
        # kv_cos -> [batch_size, n_heads, d_head, d_head]
        qkv_cos = torch.einsum('bsnx,bnxz->bsnz', q_cos, kv_cos)
        # qkv_cos -> [batch_size, seq_len, n_heads, d_head]

        kv_sin = torch.einsum('bsnx,bsnz->bnxz', k_sin, v)
        # kv_sin -> [batch_size, n_heads, d_head, d_head]
        qkv_sin = torch.einsum('bsnx,bnxz->bsnz', q_sin, kv_sin)
        # qkv_sin -> [batch_size, seq_len, n_heads, d_head]

        # denominator
        denominator = 1.0 / (torch.einsum('bsnd,bnd->bsn', q_cos, k_cos.sum(axis=1)) 
            + torch.einsum('bsnd,bnd->bsn', q_sin, k_sin.sum(axis=1))
            + self.denom_eps)
        # denominator -> [batch_size, seq_len, n_heads]
        
        output = torch.einsum('bsnz,bsn->bsnz', qkv_cos + qkv_sin, denominator).contiguous()
        # output -> [batch_size, seq_len, n_heads, d_head]
        return output        

    def forward(self, x, mask, weights):
        # x -> [batch_size, seq_len, d_model]
        # mask -> [batch_size, seq_len, 1] or None
        q, k, v = torch.chunk(self.w_qkv(x), 3, -1) 
        # q, k, v -> [batch_size, seq_len, d_model]

        q = self.kernel(self.split_heads(q))
        k = self.kernel(self.split_heads(k))
        #v = self.apply_mask(self.split_heads(v), mask)
        v = self.split_heads(self.apply_mask(v, mask))
        # q, k, v -> [batch_size, seq_len, n_heads, d_head]

        x = self.attention_func(q, k, v, weights)
        # x -> [batch_size, seq_len, n_heads, d_head]
        x = self.join_heads(x)
        x = self.dropout(self.w_o(x))
        # x -> [batch_size, seq_len, d_model]

        return x


class FFN(nn.Module):

    def __init__(self, d_model, ffn_ratio, dropout, bias):
        super(FFN, self).__init__()
        self.layers = nn.Sequential(
            nn.Linear(d_model, ffn_ratio * d_model, bias=bias),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(ffn_ratio * d_model, d_model),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        # x -> [batch_size, seq_len, d_model]
        x = self.layers(x)
        # x -> [batch_size, seq_len, d_model]
        return x


class MHA_block(nn.Module):
    """
    Implements the Pre-LN Architecture as suggested here:
    https://arxiv.org/pdf/2002.04745.pdf
    """

    def __init__(self, d_model, n_heads, use_cos, kernel, dropout, 
                 ffn_ratio, ln_eps, denom_eps, bias, rezero):

        super(MHA_block, self).__init__()
        self.ln1 = nn.LayerNorm(d_model, eps=ln_eps)
        self.ln2 = nn.LayerNorm(d_model, eps=ln_eps)

        self.mha = MHA(
            d_model, n_heads, use_cos, kernel, dropout, denom_eps, bias)  
        self.ffn = FFN(d_model, ffn_ratio, dropout, bias)

        # ReZero is All You Need
        # https://arxiv.org/pdf/2003.04887.pdf
        # https://github.com/majumderb/rezero
        if rezero:
            self.alpha = nn.Parameter(torch.Tensor([0]))
        else:
            self.register_buffer('alpha', torch.Tensor([1]))

    def forward(self, x, mask, weights):

        # x -> [batch_size, seq_len, d_model]
        fx = self.alpha * self.mha(self.ln1(x), mask, weights)
        x = x + fx

        fx = self.alpha * self.ffn(self.ln2(x))
        x = x + fx

        return x


class Positional_embeddings(nn.Module):
    """
    Stolen from:
    https://github.com/lucidrains/linear-attention-transformer/blob/master/linear_attention_transformer/linear_attention_transformer.py
    """
    def __init__(self, d_model, max_len):
        super(Positional_embeddings, self).__init__()
        inv_freq = 1. / (10000 ** (torch.arange(0, d_model, 2).float() / d_model))
        position = torch.arange(0, max_len, dtype=torch.float)
        sinusoid_inp = torch.einsum("i,j->ij", position, inv_freq)
        emb = torch.cat((sinusoid_inp.sin(), sinusoid_inp.cos()), dim=-1)
        self.register_buffer('emb', emb)

    def forward(self, x):
        return self.emb[None, :x.shape[1], :].to(x)

class Segment_embeddings(nn.Module):
    """
    Stolen from:
    https://github.com/dreamgonfly/BERT-pytorch/blob/master/bert/train/model/embeddings.py
    """
    def __init__(self, d_model):
        super(Segment_embeddings, self).__init__()
        self.emb = nn.Embedding(2, d_model)

    def forward(self, x):
        return self.emb(x)


class Kernel_transformer(nn.Module):

    def __init__(self, d_model, n_heads, use_cos, kernel, dropout, 
                 ffn_ratio, n_layers, n_emb, tie_emb, ln_eps, denom_eps, 
                 bias, rezero, for_clf, n_classes=None, max_len=1024, xavier=True):
        super(Kernel_transformer, self).__init__()

        self.for_clf = for_clf
        self.use_cos = use_cos
        self.n_classes = n_classes

        self.emb_in = nn.Embedding(n_emb, d_model)
        if self.for_clf:
            if self.n_classes == 2: self.n_classes = 1
            self.emb_out = nn.Linear(d_model, self.n_classes)
        else:
            self.emb_out = nn.Linear(d_model, n_emb)
        # Tie input & output embeddings as in https://arxiv.org/abs/1608.05859
        if not self.for_clf and tie_emb:
            self.emb_out.weight = self.emb_in.weight

        if not self.use_cos:
            self.emb_pos = Positional_embeddings(d_model, max_len)
        #self.emb_seg = Segment_embeddings(d_model)

        self.mha_blocks = nn.ModuleList(
            [MHA_block(
                d_model, n_heads, use_cos, kernel, dropout, 
                ffn_ratio, ln_eps, denom_eps, bias, rezero
                ) for _ in range(n_layers)]
        )

        if self.for_clf and  self.n_classes == 1:
            self.loss_fn = F.binary_cross_entropy_with_logits
        else:
            self.loss_fn = F.cross_entropy
        
        # Trick to get model device. Stolen from:
        # https://stackoverflow.com/questions/58926054/how-to-get-the-device-type-of-a-pytorch-module-conveniently
        self.dummy_param = nn.Parameter(torch.empty(0))

        if xavier:
            self.init_xavier_uniform()

    def dev(self):
        """Returns the device where the model is stored"""
        return self.dummy_param.device

    def init_xavier_uniform(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                torch.nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    torch.nn.init.zeros_(m.bias)

    def get_mask(self, lens, max_len):
        # lens -> [batch_size]
        mask = torch.arange(max_len)[None, :].to(lens) < lens[:, None]
        # mask -> [batch_size, max_len]
        return mask

    def get_cos_weights(self, lengths, max_len=None):
        # lengths -> [batch_size]
        if max_len is None:
            max_len = lengths.max()
        # For each sample x in the batch, calculate M(x) = len(x)
        M = lengths
        # M -> [batch_size]
        idxs = math.pi / 2 * torch.arange(max_len).to(lengths)
        # idxs -> [max_len]
        idxs = torch.outer(M, idxs)#[..., None, None]
        # idxs -> [batch_size, max_len]

        cos = torch.cos(idxs).detach()
        sin = torch.sin(idxs).detach()
        # cos, sin -> [batch_size, max_len, 1, 1]

        return cos, sin

    def forward(self, input_ids, labels, attention_mask=None, lengths=None):
        # input_idxs -> [batch_size, seq_len]
        # labels -> [batch_size]
        # attention_mask -> [batch_size, max_len] or None
        # lengths -> [batch_size]
        # weights ->  (tuple 2 X [batch_size, seq_len]) or None

        if lengths is None:
            lengths = torch.full(
                [input_ids.shape[0]], input_ids.shape[0], device=input_ids.device)
            
        input_ids = input_ids[:, :lengths.max()]

        if not attention_mask is None:
            attention_mask = attention_mask[:, :lengths.max()]
            attention_mask = attention_mask[..., None]
        # attention_mask -> [batch_size, max_len, 1] or None
            
        if self.use_cos:
            cos_weights = self.get_cos_weights(lengths, lengths.max())
        else:
            cos_weights = None

        x = self.emb_in(input_ids)
        if not self.use_cos:
            x += self.emb_pos(x)
        # x += self.emb_seg(seg_idxs)

        for block in self.mha_blocks:
            x = block(x, attention_mask, cos_weights)

        if self.for_clf: x = x[:, 0, :]
        x = self.emb_out(x)
        if self.n_classes == 1: x = x.squeeze(-1)
        loss = self.loss_fn(x, labels)
        return loss, x





In [117]:
class Baseline_transformer(nn.Module):

    def __init__(self, **kwargs):
        super(Baseline_transformer, self).__init__()

        self.for_clf = kwargs['for_clf']
        self.n_classes = kwargs['n_classes']

        self.emb_in = nn.Embedding(kwargs['n_emb'], kwargs['d_model'])
        if self.for_clf:
            if self.n_classes == 2: self.n_classes = 1
            self.emb_out = nn.Linear(kwargs['d_model'], self.n_classes)
        else:
            self.emb_out = nn.Linear(kwargs['d_model'], kwargs['n_emb'])
        # Tie input & output embeddings as in https://arxiv.org/abs/1608.05859
        if not self.for_clf and kwargs['tie_emb']:
            self.emb_out.weight = self.emb_in.weight

        self.emb_pos = Positional_embeddings(kwargs['d_model'], kwargs['max_len'])

        self.mha_blocks = nn.ModuleList([])
        for _ in range(kwargs['n_layers']):
            block = nn.TransformerEncoderLayer(
                d_model=kwargs['d_model'], 
                nhead=kwargs['n_heads'], 
                dim_feedforward=kwargs['d_model'] * kwargs['ffn_ratio'], 
                dropout=kwargs['dropout'], 
                activation='gelu', 
                layer_norm_eps=kwargs['ln_eps'], 
                batch_first=True, 
                #norm_first=True,
            )
            self.mha_blocks.append(block)

        if self.for_clf and  self.n_classes == 1:
            self.loss_fn = F.binary_cross_entropy_with_logits
        else:
            self.loss_fn = F.cross_entropy

        if kwargs['xavier']:
            self.init_xavier_uniform()

    def init_xavier_uniform(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                torch.nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    torch.nn.init.zeros_(m.bias)

    def forward(self, input_ids, labels, attention_mask=None, lengths=None):
        # input_idxs -> [batch_size, seq_len]
        # labels -> [batch_size]
        # attention_mask -> [batch_size, max_len] or None
        # lengths -> [batch_size]
        # weights ->  (tuple 2 X [batch_size, seq_len]) or None

        if lengths is None:
            lengths = torch.full(
                [input_ids.shape[0]], input_ids.shape[0], device=input_ids.device)
            
        input_ids = input_ids[:, :lengths.max()]

        if not attention_mask is None:
            attention_mask = ~attention_mask[:, :lengths.max()]
            # attention_mask = attention_mask[..., None]
        # attention_mask -> [batch_size, max_len, 1] or None

        x = self.emb_in(input_ids)
        x += self.emb_pos(x)

        for block in self.mha_blocks:
            x = block(x, src_key_padding_mask=attention_mask)
        return x

        if self.for_clf: x = x[:, 0, :]
        x = self.emb_out(x)
        if self.n_classes == 1: x = x.squeeze(-1)
        loss = self.loss_fn(x, labels)
        return loss, x


In [84]:
args = {
    'd_model': 16,
    'n_heads': 2,
    'use_cos': True,
    'kernel': 'relu',
    'dropout': 0.2,
    'ffn_ratio': 4,
    'n_layers': 2,
    'n_emb': 1000,
    'tie_emb': True,
    'ln_eps': 1e-5,
    'denom_eps': 1e-6,
    'bias': False,
    'rezero': True,
    'for_clf': True,
    'n_classes': 1,
    'max_len': 1024,
    'xavier': True,
}

seq_len = 420
batch_size = 4





In [64]:
model = Kernel_transformer(**args)


In [65]:
input_ids = torch.randint(0, n_emb, [batch_size, seq_len])
labels = torch.randint(0, 2, [batch_size]).float()
lengths = torch.randint(1, seq_len, [batch_size])
attention_mask = model.get_mask(lengths, seq_len)
with torch.no_grad()
    output = model(input_ids=input_ids, labels=labels, 
                  attention_mask=attention_mask, lengths=lengths)

In [66]:
output

(tensor(0.5428, grad_fn=<BinaryCrossEntropyWithLogitsBackward>),
 tensor([2.8745, 1.0506, 1.4905, 2.0364], grad_fn=<SqueezeBackward1>))

In [118]:
baseline_model = Baseline_transformer(**args)

In [119]:
input_ids = torch.randint(0, n_emb, [batch_size, seq_len])
labels = torch.randint(0, 2, [batch_size]).float()
lengths = torch.randint(1, seq_len, [batch_size])
attention_mask = model.get_mask(lengths, seq_len)
with torch.no_grad():
    output = baseline_model(input_ids=input_ids, labels=labels, 
                  attention_mask=attention_mask, lengths=lengths)

In [120]:
output

tensor([[[ 1.0545, -1.8327, -0.9058,  ..., -1.2681,  0.2465,  0.9357],
         [-0.6543,  2.1857,  0.9933,  ..., -0.6056,  0.1224, -0.5667],
         [-0.8128,  1.1857, -0.2744,  ..., -1.7245,  1.0190, -0.7364],
         ...,
         [-1.6287,  1.0857, -0.3737,  ...,  0.1775, -0.1073,  0.7185],
         [-1.8943,  0.3794, -0.4432,  ...,  0.0985,  0.9207, -0.2958],
         [-1.6188,  0.3864, -0.2285,  ...,  0.4092,  1.2253,  1.5361]],

        [[ 0.0352, -2.1191, -1.1612,  ...,  0.1259, -0.2450,  0.2202],
         [-0.0933, -1.9386,  0.8000,  ..., -1.0996, -0.5588,  1.7262],
         [-0.0720,  0.2683, -2.6102,  ...,  0.5064,  1.2449, -1.0605],
         ...,
         [-2.7520,  0.4956,  1.0429,  ..., -1.0408,  0.5911,  0.2282],
         [-0.5053, -0.1233,  0.1593,  ..., -0.3396,  2.2025,  0.1256],
         [-1.2218,  0.6287, -0.4916,  ...,  0.0296,  0.6222,  0.4544]],

        [[-0.1918, -0.1028, -0.4285,  ...,  0.0071,  0.4986,  1.0209],
         [-0.2433, -2.0299,  1.1185,  ..., -0

In [116]:
print(attention_mask)
print(~attention_mask)

tensor([[ True,  True,  True,  ..., False, False, False],
        [ True,  True,  True,  ..., False, False, False],
        [ True,  True,  True,  ..., False, False, False],
        [ True,  True,  True,  ..., False, False, False]])
tensor([[False, False, False,  ...,  True,  True,  True],
        [False, False, False,  ...,  True,  True,  True],
        [False, False, False,  ...,  True,  True,  True],
        [False, False, False,  ...,  True,  True,  True]])


In [None]:
output.shape

torch.Size([4, 1000, 100])

In [None]:
batch = torch.rand((batch_size, seq_len, d_model))
print(batch.shape)
lens = torch.tensor([1, 3, 2, 1])
mask = model.get_mask(lens, seq_len)
print(mask.shape)


torch.Size([4, 3, 16])
torch.Size([4, 3, 1])


In [None]:
output = model(batch, mask)
print(output.shape)

torch.Size([4, 3, 16])


In [None]:
masked_batch = model.apply_mask(batch, mask)
print(masked_batch)

tensor([[[0.9638, 0.2398, 0.0241, 0.4604, 0.7244, 0.7315, 0.2613, 0.4878,
          0.9803, 0.0258, 0.0206, 0.0219, 0.2968, 0.4857, 0.6981, 0.2027],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]],

        [[0.7327, 0.2265, 0.6460, 0.1862, 0.9855, 0.8580, 0.3619, 0.7935,
          0.0060, 0.6024, 0.0805, 0.4093, 0.0561, 0.8490, 0.9586, 0.4716],
         [0.7183, 0.1173, 0.7593, 0.5132, 0.6751, 0.1955, 0.3137, 0.3406,
          0.7958, 0.2284, 0.6372, 0.4300, 0.2027, 0.9165, 0.8389, 0.8188],
         [0.9290, 0.8413, 0.5069, 0.0595, 0.0059, 0.0356, 0.4498, 0.0824,
          0.1026, 0.3432, 0.1345, 0.4369, 0.1351, 0.6864, 0.0215, 0.5408]],

        [[0.4006, 0.1905, 0.7883, 0.8001, 0.1851, 0.8474, 0.8908, 0.5399,
          0.1389, 0.4567, 0.

In [None]:
lens = torch.tensor([5, 8, 2])
torch.arange(lens.max(), device=self.dev())[None, :] < lens[:, None]

tensor([[ True,  True,  True,  True,  True, False, False, False],
        [ True,  True,  True,  True,  True,  True,  True,  True],
        [ True,  True, False, False, False, False, False, False]])

In [None]:
batch_size = 2
seq_len = 4
d_model = 5

x = torch.arange(batch_size * seq_len * 3 * d_model).reshape(batch_size, seq_len, 3 * d_model)

y = torch.chunk(x, 3, -1)

print(x)

print(y)



In [None]:
a, b = x.shape[:2]
type(a)

int