In [10]:
import torch
import torch.nn as nn
from torchinfo import summary
from tqdm import tqdm
from nn_zoo.models.components import SelfAttention

In [11]:
with open("shakespeare.txt", "r") as f:
    text = f.read()

vocab = sorted(set(text))
char_to_idx = {char: idx for idx, char in enumerate(vocab)}
idx_to_char = {idx: char for char, idx in char_to_idx.items()}
text = torch.tensor([char_to_idx[char] for char in text], dtype=torch.long)
print(f"Vocab size: {len(vocab):,}")

seq_len = 64
split = 0.2

train_len = int((1 - split) * len(text))

train_text = text[:train_len]
val_text = text[train_len:]
print(f"Train size: {len(train_text):,}")
print(f"Val size: {len(val_text):,}")

train_text = train_text[:len(train_text) // seq_len * seq_len].view(-1, seq_len)
val_text = val_text[:len(val_text) // seq_len * seq_len].view(-1, seq_len)
print(f"Train size: {train_text.shape}")
print(f"Val size: {val_text.shape}")

train_loader = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(train_text),
    batch_size=128,
    shuffle=True,
    num_workers=2,
    persistent_workers=True,
    pin_memory=True,

)
val_loader = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(val_text),
    batch_size=128,
    shuffle=False,
    num_workers=2,
    persistent_workers=True,
    pin_memory=True,
)

Vocab size: 65
Train size: 892,314
Val size: 223,079
Train size: torch.Size([13942, 64])
Val size: torch.Size([3485, 64])


In [12]:
def encode(text: str) -> torch.Tensor:
    return torch.tensor([char_to_idx[char] for char in text], dtype=torch.long)

def decode(tensor: torch.Tensor) -> str:
    tensor = tensor.tolist()
    return "".join([idx_to_char[idx] for idx in tensor])

encoded = encode("Hello, World!")
decoded = decode(encoded)
assert decoded == "Hello, World!"

In [13]:
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"

print(f"Using device: {device}")

@torch.no_grad()
def evaluate(model, data_loader):
    model.eval()
    total_loss = 0
    for inputs, in data_loader:
        inputs = inputs.to(device)

        x = inputs[:, :-1]
        y = inputs[:, 1:]

        y_pred = model(x)
        loss = model.loss(y_pred, y)
        total_loss += loss
       
    return total_loss.item() / len(data_loader)

Using device: mps


In [14]:
from einops.layers.torch import Rearrange
import torch.nn.functional as F

class MultiHeadAttention(nn.Module):
    def __init__(self, num_heads, emb_dim):
        super().__init__()
        self.num_heads = num_heads
        self.emb_dim = emb_dim

        self.heads = nn.ModuleList([
            SelfAttention(emb_dim, num_heads) for _ in range(num_heads)
        ])

        self.fc = nn.Linear(emb_dim * num_heads, emb_dim)

    def forward(self, x):
        heads = [head(x) for head in self.heads]
        x = torch.cat(heads, dim=-1)
        x = self.fc(x)
        return x

class TransformerBlock(nn.Module):
    def __init__(self, emb_dim, num_heads):
        super().__init__()
        # self.attention = MultiHeadAttention(num_heads, emb_dim)
        self.attention = nn.Sequential(
            Rearrange("b s e -> b e s"),
            nn.Conv1d(emb_dim, emb_dim, 3, padding=1),
            nn.GELU(),
            nn.Conv1d(emb_dim, emb_dim, 3, padding=1),
            nn.GELU(),
            nn.Conv1d(emb_dim, emb_dim, 3, padding=1),
            nn.GELU(),
            nn.Conv1d(emb_dim, emb_dim, 3, padding=1),
            Rearrange("b e s -> b s e"),
        )
        self.ln1 = nn.LayerNorm(emb_dim)
        self.ln2 = nn.LayerNorm(emb_dim)

        self.fc = nn.Sequential(
            nn.Linear(emb_dim, 1 * emb_dim),
            nn.GELU(),
            nn.Linear(1 * emb_dim, emb_dim),
        )

    def forward(self, x):
        x = x + self.attention(self.ln1(x))
        x = x + self.fc(self.ln2(x))
        return x

class Model(nn.Module):
    def __init__(self, emb_dim):
        super().__init__()
        self.tok_emb = nn.Embedding(len(vocab), emb_dim)
        self.pos_emb = nn.Embedding(seq_len, emb_dim)

        self.blocks = nn.ModuleList([
            TransformerBlock(emb_dim, 1) for _ in range(1)
        ])

        self.lm_head = nn.Linear(emb_dim, len(vocab))

    def forward(self, x):
        tok_emb = self.tok_emb(x)
        pos_emb = self.pos_emb(torch.arange(x.shape[1], device=x.device))[None, :, :]
        
        x = tok_emb + pos_emb

        for block in self.blocks:
            x = block(x)

        x = self.lm_head(x)
        return x
    
    def loss(self, y_pred, y):
        return torch.nn.functional.cross_entropy(y_pred.reshape(-1, y_pred.shape[-1]), y.reshape(-1), ignore_index=-1)
    
    @torch.no_grad()
    def generate(self, start: list[int] | torch.Tensor | None = None, max_len: int = 100, temperature: float = 1.0, top_k: int = 0):
        if start is None:
            start = torch.randint(len(vocab), (1, 1), device=device)
        elif isinstance(start, list):
            start = torch.tensor(start, dtype=torch.long, device=device).unsqueeze(0)
        x = start

        for _ in tqdm(range(max_len)):
            y_pred = self(x)
            y_pred = y_pred[:, -1, :] / temperature
            if top_k > 0:
                y_pred = torch.topk(y_pred, top_k, dim=-1).values
            next_char = torch.multinomial(torch.nn.functional.softmax(y_pred, dim=-1), 1)
            x = torch.cat([x, next_char], dim=1)
        
        return x
    
model = Model(128).to(device)
print(decode(model.generate(max_len=100).squeeze()))
summary(model, (64, 65), dtypes=[torch.long], device=device)

  0%|          | 0/100 [00:00<?, ?it/s]

100%|██████████| 100/100 [00:00<00:00, 255.27it/s]

rxbRQWgZO;Jf-hqQb,.KHdd
NA.-,dz'Y,UwOSI,'viP 3vfICBCig;aXwGfdCmQ.e3TEfJ:!3HbqpyoK!p.U&M,MS!xed?&Kq,:,





Layer (type:depth-idx)                   Output Shape              Param #
Model                                    [64, 65, 65]              --
├─Embedding: 1-1                         [64, 65, 128]             8,320
├─Embedding: 1-2                         [65, 128]                 8,192
├─ModuleList: 1-3                        --                        --
│    └─TransformerBlock: 2-1             [64, 65, 128]             --
│    │    └─LayerNorm: 3-1               [64, 65, 128]             256
│    │    └─Sequential: 3-2              [64, 65, 128]             197,120
│    │    └─LayerNorm: 3-3               [64, 65, 128]             256
│    │    └─Sequential: 3-4              [64, 65, 128]             33,024
├─Linear: 1-4                            [64, 65, 65]              8,385
Total params: 255,553
Trainable params: 255,553
Non-trainable params: 0
Total mult-adds (Units.MEGABYTES): 823.77
Input size (MB): 0.03
Forward/backward pass size (MB): 40.57
Params size (MB): 1.02
Estimat

In [16]:
# b, c, h, w -> b, t, c
rearrage = Rearrange("b c h w -> b (h w) c")
x = torch.randn(1, 3, 64, 64)
rearrage(x).shape

torch.Size([1, 4096, 3])

In [6]:
optimizer = torch.optim.Adam(model.parameters(), lr=0.0004)

In [9]:
if "val_loss" in locals():
    pass
else:
    val_loss = float("inf")


for epoch in range(100):
    model.train()
    pbar = tqdm(train_loader, desc=f"Epoch {epoch}")
    for inputs, in pbar:
        inputs = inputs.to(device)

        x = inputs[:, :-1]
        y = inputs[:, 1:]

        optimizer.zero_grad()
        
        y_pred = model(x)
        loss = model.loss(y_pred, y)

        loss.backward()
        optimizer.step()
        pbar.set_postfix(loss=loss.item(), val_loss=val_loss)

        # clip gradients
        clamp_gradients(model, 1e-3)

    val_loss = evaluate(model, val_loader)

Epoch 0: 100%|██████████| 109/109 [00:02<00:00, 47.59it/s, loss=316, val_loss=inf]
Epoch 1: 100%|██████████| 109/109 [00:02<00:00, 51.40it/s, loss=8.03e+3, val_loss=315]
Epoch 2: 100%|██████████| 109/109 [00:02<00:00, 52.25it/s, loss=3.59e+4, val_loss=7.7e+3]
Epoch 3: 100%|██████████| 109/109 [00:02<00:00, 52.58it/s, loss=1.31e+5, val_loss=3.43e+4]
Epoch 4:   6%|▋         | 7/109 [00:00<00:02, 49.69it/s, loss=1.32e+5, val_loss=1.34e+5]


KeyboardInterrupt: 

In [None]:
print(decode(model.generate(max_len=128).squeeze()))

In [None]:
# for named parameters, print gradient statistics
print("Gradient statistics")
print(f"{'Name':30} {'Mean':<7} {'Std':<7} {'Norm':<7}")
for name, param in model.named_parameters():
    print(f"{name:30} {param.grad.mean():.5f} {param.grad.std():.5f} {param.grad.norm():.5f}")

    # clip gradient norm
    

In [8]:
@torch.no_grad()
def clamp_gradients(model, max_value: float):
    for param in model.parameters():
        param.grad.clamp_(-max_value, max_value)

In [None]:
clamp_gradients(model, 1e-3)