    Scaled dot product

In [1]:
## tqdm for loading bars
from tqdm.notebook import tqdm
from sliding_chunks import sliding_chunks_matmul_pv, sliding_chunks_matmul_qk

# native
import os
import math

## PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import torch.optim as optim

## Torchvision
import torchvision
from torchvision.datasets import CIFAR100
from torchvision import transforms

import pytorch_lightning as pl


def scaled_dot_product(q, k, v, mask=None, w=16):
    d_k = q.size()[-1]
    # attn_logits = torch.matmul(q, k.transpose(-2, -1))
    # [bsz, seqlen, num_heads, head_dim]
    attn_logits = sliding_chunks_matmul_qk(q=q, k=k, w=w, padding_value=0)
    attn_logits = attn_logits / math.sqrt(d_k)
    if mask is not None:
        attn_logits = attn_logits.masked_fill(mask == 0, -9e15)
    attention = F.softmax(attn_logits, dim=-1)
    # TODO add dropout
    # values = torch.matmul(attention, v)
    values = sliding_chunks_matmul_pv(prob=attention, v=v, w=w)
    return values, attention

  Referenced from: '/Users/derekleung/miniforge3/envs/xerini/lib/python3.10/site-packages/torchvision/image.so'
  Expected in: '/Users/derekleung/miniforge3/envs/xerini/lib/python3.10/site-packages/torch/lib/libtorch_cpu.dylib'
  warn(f"Failed to load image Python extension: {e}")


In [2]:
# seq_len, d_k = 1024, 16
# pl.seed_everything(42)
# q = torch.randn(seq_len, d_k)
# k = torch.randn(seq_len, d_k)
# v = torch.randn(seq_len, d_k)k
# values, attention = scaled_dot_product(q, k, v)
# print("Q shape", q.shape, "\n", q)
# print("K shape", k.shape, "\n", k)
# print("V shape", v.shape, "\n", v)
# print("Values\n", values)
# print("Attention\n", attention)

In [3]:
class MultiheadAttention(nn.Module):

    def __init__(self, input_dim, embed_dim, num_heads):
        super().__init__()
        assert embed_dim % num_heads == 0, "Embedding dimension must be 0 modulo number of heads."

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        # Stack all weight matrices 1...h together for efficiency
        # Note that in many implementations you see "bias=False" which is optional
        self.qkv_proj = nn.Linear(input_dim, 3 * embed_dim)
        self.o_proj = nn.Linear(embed_dim, embed_dim)
        self._reset_parameters()

    def _reset_parameters(self):
        # Original Transformer initialization, see PyTorch documentation
        nn.init.xavier_uniform_(self.qkv_proj.weight)
        self.qkv_proj.bias.data.fill_(0)
        nn.init.xavier_uniform_(self.o_proj.weight)
        self.o_proj.bias.data.fill_(0)

    def forward(self, x, mask=None, return_attention=False):
        batch_size, seq_length, _ = x.size()
        # project to [batch,seqlen->3*emb_dim(n_heads*head_dim)]
        qkv = self.qkv_proj(x)
        # project to [bz, seqlen n_head, head_dim]
        qkv = qkv.reshape(batch_size, seq_length, self.num_heads, 3 * self.head_dim)
        # [Batch, Head, SeqLen, Dims]
        # qkv = qkv.permute(0, 2, 1, 3)
        q, k, v = qkv.chunk(3, dim=-1)
        values, attention = scaled_dot_product(q, k, v, mask=mask)
        # values = values.permute(0, 2, 1, 3)  # [Batch, SeqLen, Head, Dims]
        # reshape to (bsz, seqlen, emb_dim(n_head*head_dim))
        values = values.reshape(batch_size, seq_length, self.embed_dim)
        o = self.o_proj(values)
        if return_attention:
            return o, attention
        else:
            return o

Examples

In [4]:
import math

import torch
import torch.nn as nn
from sliding_chunks import sliding_chunks_matmul_pv, sliding_chunks_matmul_qk

# Calculate the attention matrix
bsz = 64
seqlen = 1024
input_dim = 1024
embed_dim = 128
num_heads = 8
head_dim = embed_dim // num_heads

# 1. input sequence (BxSxI)
input_seq = torch.randn((bsz, seqlen, input_dim))
assert input_seq.shape == (bsz, seqlen, input_dim)
# 2. QKV projection (Ix3Emb)
qkv_proj = nn.Linear(input_dim, 3 * embed_dim)
assert qkv_proj(input_seq).shape == (bsz, seqlen, 3 * embed_dim)
# 3. QKV projection (BxSx3Emb)
qkv = qkv_proj(input_seq).reshape(bsz, seqlen, num_heads, 3 * head_dim)
assert qkv.shape == (bsz, seqlen, num_heads, 3 * head_dim)
# qkv = qkv.permute(0, 2, 1, 3)
# 4. Split the chunk (BxSxHx3HD->HD)
q, k, v = qkv.chunk(3, dim=-1)
assert q.shape == (bsz, seqlen, num_heads, head_dim)
# 5. Scaled dot product (BxSxHx257
w = 16
attn_logits = sliding_chunks_matmul_qk(q=q, k=k, w=w, padding_value=0)
assert attn_logits.shape == (bsz, seqlen, num_heads, 2 * w + 1), attn_logits.shape
# attn_logits = torch.matmul(q, k.transpose(-2, -1))
# 6. divide by number of heads
attn_logits = attn_logits / math.sqrt(head_dim)
# if mask is not None:
#     attn_logits = attn_logits.masked_fill(mask == 0, -9e15)1
# 7. QKV multiplications
values = sliding_chunks_matmul_pv(attn_logits, v, w=w)
assert values.shape == (bsz, seqlen, num_heads, head_dim)

Stack the Transformer encoder

In [5]:
class EncoderBlock(nn.Module):

    def __init__(self, input_dim, num_heads, dim_feedforward, dropout=0.0):
        super().__init__()

        # Attention layer
        self.self_attn = MultiheadAttention(input_dim, input_dim, num_heads)

        # Two-layer MLP
        self.linear_net = nn.Sequential(
            nn.Linear(input_dim, dim_feedforward),
            nn.Dropout(dropout),
            nn.ReLU(inplace=True),
            nn.Linear(dim_feedforward, input_dim)
        )

        # Layers to apply in between the main layers
        self.norm1 = nn.LayerNorm(input_dim)
        self.norm2 = nn.LayerNorm(input_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        # Attention part
        attn_out = self.self_attn(x, mask=mask)
        x = x + self.dropout(attn_out)
        x = self.norm1(x)

        # MLP part
        linear_out = self.linear_net(x)
        x = x + self.dropout(linear_out)
        x = self.norm2(x)

        return x



inp = torch.randn(bsz, seqlen, input_dim)
enc = EncoderBlock(input_dim, num_heads, 128)
assert enc(inp).shape == (bsz, input_dim, input_dim)

In [6]:
class TransformerEncoder(nn.Module):

    def __init__(self, num_layers, **block_args):
        super().__init__()
        self.layers = nn.ModuleList([EncoderBlock(**block_args) for _ in range(num_layers)])

    def forward(self, x, mask=None):
        for l in self.layers:
            x = l(x, mask=mask)
        return x

    def get_attention_maps(self, x, mask=None):
        attention_maps = []
        for l in self.layers:
            _, attn_map = l.self_attn(x, mask=mask, return_attention=True)
            attention_maps.append(attn_map)
            x = l(x)
        return attention_maps

transformer = TransformerEncoder(num_layers=2,
                                      input_dim=input_dim,
                                      dim_feedforward=128,
                                      num_heads=8,
                                      dropout=0.3)

assert transformer(inp).shape == (bsz, input_dim, input_dim)

In [21]:
import numpy as np

class CosineWarmupScheduler(optim.lr_scheduler._LRScheduler):
    def __init__(self, optimizer, warmup, max_iters):
        self.warmup = warmup
        self.max_num_iters = max_iters
        super().__init__(optimizer)

    def get_lr(self):
        lr_factor = self.get_lr_factor(epoch=self.last_epoch)
        return [base_lr * lr_factor for base_lr in self.base_lrs]

    def get_lr_factor(self, epoch):
        lr_factor = 0.5 * (1 + np.cos(np.pi * epoch / self.max_num_iters))
        if epoch <= self.warmup:
            lr_factor *= epoch * 1.0 / self.warmup
        return lr_factor

class PositionalEncoding(nn.Module):

    def __init__(self, d_model, max_len=5000):
        super().__init__()
        # Create matrix of [SeqLen, HiddenDim] representing the positional encoding for max_len inputs
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        # register_buffer => Tensor which is not a parameter, but should be part of the modules state.
        # Used for tensors that need to be on the same device as the module.
        # persistent=False tells PyTorch to not add the buffer to the state dict (e.g. when we save the model)
        self.register_buffer('pe', pe, persistent=False)

    def forward(self, x):
        x = x + self.pe[:, :x.size(1)]
        return x

class TransformerPredictor(pl.LightningModule):

    def __init__(self, input_dim, model_dim, num_classes, num_heads, num_layers, lr, warmup, max_iters, dropout=0.0, input_dropout=0.0):
        super().__init__()
        self.save_hyperparameters()
        self._create_model()

    def _create_model(self):
        self.emb = nn.Embedding(self.hparams.input_dim, 128)
        self.input_net = nn.Sequential(
            nn.Dropout(self.hparams.input_dropout),
            nn.Linear(128, self.hparams.model_dim)
        )
        # Positional encoding for sequences
        self.positional_encoding = PositionalEncoding(d_model=self.hparams.model_dim)
        # Transformer
        self.transformer = TransformerEncoder(num_layers=self.hparams.num_layers,
                                              input_dim=self.hparams.model_dim,
                                              dim_feedforward=2*self.hparams.model_dim,
                                              num_heads=self.hparams.num_heads,
                                              dropout=self.hparams.dropout)
        # Output classifier per sequence element
        self.output_net = nn.Sequential(
            nn.Linear(self.hparams.model_dim, self.hparams.model_dim),
            nn.LayerNorm(self.hparams.model_dim),
            nn.ReLU(inplace=True),
            nn.Dropout(self.hparams.dropout),
            nn.Linear(self.hparams.model_dim, self.hparams.num_classes)
        )

    def forward(self, x, mask=None, add_positional_encoding=True):
        # print(self.emb.weight.shape)
        print(x.shape)
        x = self.emb(x)
        print(x.shape)
        x = self.input_net(x)
        # if add_positional_encoding:
        #     x = self.positional_encoding(x)
        x = self.transformer(x, mask=mask)
        x = self.output_net(x)
        return x

    @torch.no_grad()
    def get_attention_maps(self, x, mask=None, add_positional_encoding=True):
        x = self.input_net(x)
        if add_positional_encoding:
            x = self.positional_encoding(x)
        attention_maps = self.transformer.get_attention_maps(x, mask=mask)
        return attention_maps

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=self.hparams.lr)
        lr_scheduler = CosineWarmupScheduler(optimizer,
                                             warmup=self.hparams.warmup,
                                             max_iters=self.hparams.max_iters)
        return [optimizer], [{'scheduler': lr_scheduler, 'interval': 'step'}]

    def training_step(self, batch, batch_idx):
        raise NotImplementedError

    def validation_step(self, batch, batch_idx):
        raise NotImplementedError

    def test_step(self, batch, batch_idx):
        raise NotImplementedError


class ReverseDataset(data.Dataset):

    def __init__(self, num_categories, seq_len, size):
        super().__init__()
        self.num_categories = num_categories
        self.seq_len = seq_len
        self.size = size

        self.data = torch.randint(self.num_categories, size=(self.size, self.seq_len))

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        inp_data = self.data[idx]
        labels = torch.flip(inp_data, dims=(0,))
        return inp_data, labels

class ReversePredictor(TransformerPredictor):
    def _calculate_loss(self, batch, mode="t`rain"):
        inp_data, labels = batch
        # inp_data = F.one_hot(inp_data, num_classes=self.hparams.num_classes).float()

        preds = self.forward(inp_data, add_positional_encoding=True)
        loss = F.cross_entropy(preds.view(-1,preds.size(-1)), labels.view(-1))
        acc = (preds.argmax(dim=-1) == labels).float().mean()

        self.log(f"{mode}_loss", loss)
        self.log(f"{mode}_acc", acc)
        return loss, acc

    def training_step(self, batch, batch_idx):
        loss, _ = self._calculate_loss(batch, mode="train")
        return loss

    def validation_step(self, batch, batch_idx):
        _ = self._calculate_loss(batch, mode="val")

    def test_step(self, batch, batch_idx):
        _ = self._calculate_loss(batch, mode="test")

In [16]:
from functools import partial

dataset = partial(ReverseDataset, 10, 32) # cat, len, size
train_loader = data.DataLoader(dataset(50000), batch_size=128, shuffle=True, drop_last=True, pin_memory=True)
val_loader   = data.DataLoader(dataset(1000), batch_size=128)
test_loader  = data.DataLoader(dataset(10000), batch_size=128)
inp_data, labels = train_loader.dataset[0]
print("Input data:", inp_data)
print("Labels:    ", labels)
print(next(iter(val_loader))[0].dtype)

Input data: tensor([9, 2, 4, 1, 1, 1, 4, 1, 4, 3, 4, 6, 6, 1, 2, 7, 8, 1, 3, 1, 0, 7, 5, 0,
        0, 4, 3, 0, 5, 2, 7, 4])
Labels:     tensor([4, 7, 2, 5, 0, 3, 4, 0, 0, 5, 7, 0, 1, 3, 1, 8, 7, 2, 1, 6, 6, 4, 3, 4,
        1, 4, 1, 1, 1, 4, 2, 9])
torch.int64


In [22]:
from pytorch_lightning.callbacks import ModelCheckpoint

DATASET_PATH = "../data"
CHECKPOINT_PATH = "../saved_models/tutorial6"

def train_reverse(**kwargs):
    root_dir = os.path.join(CHECKPOINT_PATH, "ReverseTask")
    os.makedirs(root_dir, exist_ok=True)
    trainer = pl.Trainer(default_root_dir=root_dir,
                         callbacks=[ModelCheckpoint(save_weights_only=True, mode="max", monitor="val_acc")],
                         devices=1,
                         max_epochs=10,
                         gradient_clip_val=5)
    trainer.logger._default_hp_metric = None

    model = ReversePredictor(max_iters=trainer.max_epochs*len(train_loader), **kwargs)
    trainer.fit(model, train_loader, val_loader)

    val_result = trainer.test(model, val_loader, verbose=False)
    test_result = trainer.test(model, test_loader, verbose=False)
    result = {"test_acc": test_result[0]["test_acc"], "val_acc": val_result[0]["test_acc"]}

    model = model.to("cpu")
    return model, result


reverse_model, reverse_result = train_reverse(input_dim=train_loader.dataset.num_categories,
                                              model_dim=32,
                                              num_heads=1,
                                              num_classes=train_loader.dataset.num_categories,
                                              num_layers=1,
                                              dropout=0.0,
                                              lr=5e-4,
                                              warmup=50)

GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name                | Type               | Params
-----------------------------------------------------------
0 | emb                 | Embedding          | 1.3 K 
1 | input_net           | Sequential         | 4.1 K 
2 | positional_encoding | PositionalEncoding | 0     
3 | transformer         | TransformerEncoder | 8.5 K 
4 | output_net          | Sequential         | 1.4 K 
-----------------------------------------------------------
15.4 K    Trainable params
0         Non-trainable params
15.4 K    Total params
0.062     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])


  rank_zero_warn(


Training: 0it [00:00, ?it/s]

torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32]

Validation: 0it [00:00, ?it/s]

torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([104, 32])
torch.Size([104, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32]

Validation: 0it [00:00, ?it/s]

torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([104, 32])
torch.Size([104, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32]

Validation: 0it [00:00, ?it/s]

torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([104, 32])
torch.Size([104, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32]

Validation: 0it [00:00, ?it/s]

torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([104, 32])
torch.Size([104, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32]

Validation: 0it [00:00, ?it/s]

torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([104, 32])
torch.Size([104, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32]

Validation: 0it [00:00, ?it/s]

torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([104, 32])
torch.Size([104, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32]

Validation: 0it [00:00, ?it/s]

torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([104, 32])
torch.Size([104, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32]

Validation: 0it [00:00, ?it/s]

torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([104, 32])
torch.Size([104, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32]

Validation: 0it [00:00, ?it/s]

torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([104, 32])
torch.Size([104, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32]

Validation: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=10` reached.


torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([104, 32])
torch.Size([104, 32, 128])


  rank_zero_warn(


Testing: 0it [00:00, ?it/s]

torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([104, 32])
torch.Size([104, 32, 128])


Testing: 0it [00:00, ?it/s]

torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32])
torch.Size([128, 32, 128])
torch.Size([128, 32]

In [10]:
next(iter(val_loader))[0].shape # batch x seqlen

torch.Size([128, 32, 32]) # before linear
torch.Size([128, 32, 10]) # after linear

next(iter(val_loader))[0].shape

torch.Size([128, 32])

In [11]:
inp = torch.randint(10, size=(128, 32)).float()
lin = nn.Linear(32, 32)
lin(inp).shape

torch.Size([128, 32])

In [12]:
inp = torch.randint(128, size=(128, 32)) # batch x seqlen
emb = nn.Embedding(128, 32) # batch x seqlen x input dimension
inp.shape, emb(inp).shape

(torch.Size([128, 32]), torch.Size([128, 32, 32]))