In [1]:
from datasets import load_dataset
from glob import glob
from random import shuffle

from random import Random

# Fixed seed so the training shard order is identical on every
# Restart Kernel -> Run All.
SHUFFLE_SEED = 42

# glob() returns paths in arbitrary (filesystem-dependent) order, so sort
# first; otherwise even a seeded shuffle would not be reproducible.
train_files = sorted(glob("data/train/docs_*.jsonl"))
# Use a private RNG instead of the module-level one so that seeding here
# does not disturb any other consumer of `random`.
Random(SHUFFLE_SEED).shuffle(train_files)

# Test shards are not shuffled; sorting just makes the order deterministic.
test_files = sorted(glob("data/test/docs_*.jsonl"))

# Split name -> list of JSONL shard paths, in the form load_dataset expects.
data_files = {
    "train": train_files,
    "test": test_files
}

# Load both splits with the "json" builder in streaming mode and keep a
# handle on the training split for the DataLoader built later.
dataset = load_dataset("json", data_files=data_files, streaming=True)
train_dataset = dataset["train"]

In [2]:
from streaming_dataset import StreamingTokenDataset
from transformers import AutoTokenizer
from torch.utils.data import DataLoader

# from transformers import AutoModel
# model = AutoModel.from_pretrained("allegro/herbert-base-cased")

# Hyperparameters of the input pipeline: CONTEXT_LENGTH is forwarded to
# StreamingTokenDataset as `context_size`, BATCH_SIZE to the DataLoader.
CONTEXT_LENGTH = 128
BATCH_SIZE = 4

# Pretrained HerBERT tokenizer; it is handed to the streaming dataset below.
tokenizer = AutoTokenizer.from_pretrained("allegro/herbert-base-cased")

# Batches are drawn from the streaming training split wrapped in
# StreamingTokenDataset.
loader = DataLoader(
    StreamingTokenDataset(train_dataset, tokenizer, context_size=CONTEXT_LENGTH),
    batch_size=BATCH_SIZE,
)

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class SinusoidalPositionalEncoding(nn.Module):
    """Adds fixed sinusoidal position encodings to a batch of embeddings.

    A table of shape ``(1, context_length, embed_dim)`` is precomputed once and
    stored as the buffer ``pe``. Even feature indices hold ``sin`` values and
    odd feature indices hold ``cos`` values of ``position * div_term``.

    Args:
        embed_dim: Size of the feature (last) dimension of the inputs.
        context_length: Largest sequence length the table covers.
    """

    def __init__(self, embed_dim, context_length=128):
        super().__init__()
        # Local import: use the stdlib module directly rather than relying on
        # the incidental ``torch.math`` attribute.
        import math

        self.embed_dim = embed_dim
        pe = torch.zeros(context_length, embed_dim)

        position = torch.arange(0, context_length, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, embed_dim, 2).float() * (-math.log(10000.0) / embed_dim))
        angles = position * div_term  # (context_length, ceil(embed_dim / 2))

        pe[:, 0::2] = torch.sin(angles)
        # For an odd embed_dim there is one fewer odd column than even columns,
        # so trim the angles to the number of cosine slots actually available.
        pe[:, 1::2] = torch.cos(angles[:, : embed_dim // 2])
        self.register_buffer("pe", pe.unsqueeze(0))

    def forward(self, x):
        """Return ``x`` plus the encodings for its first ``x.size(1)`` positions.

        Args:
            x: Tensor of shape ``(batch, seq_len, embed_dim)`` with
                ``seq_len <= context_length``.

        Raises:
            ValueError: If ``seq_len`` exceeds the precomputed table length.
        """
        seq_len = x.size(1)
        max_len = self.pe.size(1)
        if seq_len > max_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds the positional table length {max_len}"
            )
        # Slice the table to the actual sequence length. Adding the whole table
        # would fail for 1 < seq_len < context_length and would silently
        # broadcast a single-token input up to context_length positions.
        return x + self.pe[:, :seq_len]

In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class CausalSelfAttention(nn.Module):
    """Multi-head self-attention where a position only sees itself and the past.

    Args:
        embed_dim: Width of the input and output features.
        num_heads: Number of attention heads; must divide ``embed_dim``.
        dropout: Dropout probability applied to the attention weights.
    """

    def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.1):
        super().__init__()
        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"

        self.embed_dim, self.num_heads = embed_dim, num_heads
        self.head_dim = embed_dim // num_heads

        # One fused projection yields queries, keys and values side by side.
        self.qkv_proj = nn.Linear(embed_dim, 3 * embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)

    def _split_heads(self, t: torch.Tensor) -> torch.Tensor:
        """Reshape ``(B, T, embed_dim)`` into ``(B, heads, T, head_dim)``."""
        batch, seq_len, _ = t.shape
        return t.view(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply causal multi-head attention.

        x: (batch_size, seq_len, embed_dim)
        returns: (batch_size, seq_len, embed_dim)
        """
        batch, seq_len, width = x.shape

        # Fused projection, then cut into the three equal parts and split heads.
        queries, keys, values = (
            self._split_heads(part) for part in self.qkv_proj(x).chunk(3, dim=-1)
        )

        # Scaled dot-product similarity: (B, heads, T, T).
        scores = (queries @ keys.transpose(-2, -1)) / (self.head_dim ** 0.5)

        # Lower-triangular visibility: row i may look at columns 0..i only.
        visible = torch.ones(seq_len, seq_len, device=x.device).tril().bool()
        scores = scores.masked_fill(~visible, float("-inf"))

        weights = self.dropout(F.softmax(scores, dim=-1))

        # Weighted sum of values, then merge the heads back to (B, T, embed_dim).
        merged = (weights @ values).transpose(1, 2).contiguous().view(batch, seq_len, width)
        return self.out_proj(merged)

In [5]:
class FeedForward(nn.Module):
    """Position-wise MLP: Linear -> GELU -> Linear -> Dropout.

    Args:
        embed_dim: Width of the input and output features.
        hidden_dim: Width of the intermediate layer.
        dropout: Dropout probability applied after the second linear layer.
    """

    def __init__(self, embed_dim: int, hidden_dim: int, dropout: float = 0.1):
        super().__init__()
        # Built in the same order they are applied, so the sub-modules keep
        # their positions 0..3 inside ``self.net``.
        expand = nn.Linear(embed_dim, hidden_dim)
        contract = nn.Linear(hidden_dim, embed_dim)
        self.net = nn.Sequential(expand, nn.GELU(), contract, nn.Dropout(dropout))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through the MLP; the last dimension stays ``embed_dim``."""
        transformed = self.net(x)
        return transformed

In [6]:
class DecoderBlock(nn.Module):
    """Pre-norm decoder block: causal self-attention then an MLP, each residual.

    Args:
        embed_dim: Width of the features flowing through the block.
        num_heads: Number of heads handed to the attention sub-layer.
        ff_hidden_dim: Hidden width handed to the feed-forward sub-layer.
        dropout: Dropout probability handed to both sub-layers.
    """

    def __init__(self, embed_dim: int, num_heads: int, ff_hidden_dim: int, dropout: float = 0.1):
        super().__init__()
        # Attention sub-layer preceded by its own LayerNorm.
        self.ln1 = nn.LayerNorm(embed_dim)
        self.attn = CausalSelfAttention(embed_dim, num_heads, dropout)
        # Feed-forward sub-layer preceded by its own LayerNorm.
        self.ln2 = nn.LayerNorm(embed_dim)
        self.mlp = FeedForward(embed_dim, ff_hidden_dim, dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply both residual sub-layers; the output has the same shape as ``x``."""
        after_attention = x + self.attn(self.ln1(x))
        return after_attention + self.mlp(self.ln2(after_attention))

In [16]:
class GPTDecoder(nn.Module):
    """Decoder-only transformer language model.

    Pipeline: token embedding -> sinusoidal positional encoding -> a stack of
    ``num_layers`` DecoderBlocks -> final LayerNorm -> linear head over the
    vocabulary. The head has no bias and shares its weight matrix with the
    token embedding.

    Args:
        vocab_size: Number of rows in the embedding table / output logits.
        embed_dim: Width of the hidden representation.
        num_heads: Attention heads per decoder block.
        ff_hidden_dim: Hidden width of each block's feed-forward network.
        num_layers: Number of decoder blocks.
        context_length: Length passed to the positional-encoding table.
        dropout: Dropout probability passed to every decoder block.
    """

    def __init__(
        self,
        vocab_size: int,
        embed_dim: int,
        num_heads: int,
        ff_hidden_dim: int,
        num_layers: int,
        context_length: int,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, embed_dim)
        self.pos_encoding = SinusoidalPositionalEncoding(embed_dim, context_length)

        blocks = []
        for _ in range(num_layers):
            blocks.append(DecoderBlock(embed_dim, num_heads, ff_hidden_dim, dropout))
        self.layers = nn.ModuleList(blocks)

        self.ln_f = nn.LayerNorm(embed_dim)

        # Output projection onto the vocabulary; its weight is then replaced by
        # the embedding matrix so both modules share one parameter.
        self.lm_head = nn.Linear(embed_dim, vocab_size, bias=False)
        self.lm_head.weight = self.embed_tokens.weight

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Compute next-token logits for every position.

        input_ids: (batch_size, seq_len)
        returns logits: (batch_size, seq_len, vocab_size)
        """
        # (B, T) -> (B, T, embed_dim), with position information added.
        hidden = self.pos_encoding(self.embed_tokens(input_ids))

        for block in self.layers:
            hidden = block(hidden)

        # Final normalisation, then projection to vocabulary logits.
        return self.lm_head(self.ln_f(hidden))


In [17]:
# Size the embedding table from the tokenizer instead of a hard-coded number,
# so every id the tokenizer can emit has a row in the embedding / logits.
# len(tokenizer) counts the base vocabulary plus any added tokens.
vocab_size = len(tokenizer)
embed_dim = 512
num_heads = 8
ff_hidden_dim = 2048
num_layers = 6
# Reuse the window length the data pipeline was built with, so the model and
# the loader cannot drift apart when one of them is changed.
context_length = CONTEXT_LENGTH
dropout = 0.1

gpt = GPTDecoder(
    vocab_size=vocab_size,
    embed_dim=embed_dim,
    num_heads=num_heads,
    ff_hidden_dim=ff_hidden_dim,
    num_layers=num_layers,
    context_length=context_length,
    dropout=dropout
)

In [20]:
def choose_device() -> str:
    """Return the name of the preferred available torch device.

    Preference order is ``"cuda"``, then ``"mps"``, then ``"cpu"``.

    Returns:
        One of ``"cuda"``, ``"mps"`` or ``"cpu"``.
    """
    if torch.cuda.is_available():
        return "cuda"

    # Not every torch build exposes ``torch.backends.mps``; look it up
    # defensively so such builds fall through to CPU instead of raising
    # AttributeError.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return "mps"

    return "cpu"

In [None]:
import torch
import torch.nn as nn
from tqdm import tqdm

# ---- Training hyperparameters ----
epochs = 10
learning_rate = 1e-3
weight_decay = 1e-2
grad_clip = 1.0  # max global gradient norm applied before each optimizer step
# Only used to size the progress bar; the loader is consumed until exhausted.
# NOTE(review): presumably the batch count observed in an earlier run — confirm
# and update if the data, CONTEXT_LENGTH or BATCH_SIZE change.
steps_per_epoch = 23642
device = torch.device(choose_device())

print(f"Training on device: {device}")

gpt.to(device)
# Positions whose target equals the tokenizer's pad id do not contribute to the loss.
criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)
optimizer = torch.optim.AdamW(gpt.parameters(), lr=learning_rate, weight_decay=weight_decay)

# scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs, eta_min=1e-5)

for epoch in range(1, epochs + 1):
    gpt.train()
    total_loss = 0.0
    # Defined up front so the end-of-epoch report cannot hit an unbound name
    # when the loader yields no batches.
    avg_loss = float("nan")

    progress = tqdm(enumerate(loader), total=steps_per_epoch, desc=f"Epoch {epoch}/{epochs}")

    for i, (batch_x, batch_y) in progress:
        batch_x = batch_x.to(device)
        batch_y = batch_y.to(device)

        optimizer.zero_grad()
        out = gpt(batch_x)

        # Flatten (B, T, vocab) logits and (B, T) targets for CrossEntropyLoss
        loss = criterion(out.view(-1, out.size(-1)), batch_y.view(-1))
        loss.backward()

        # Gradient clipping: cap the global grad norm at `grad_clip`. This was
        # previously commented out and pointed at a nonexistent `lstm` model,
        # leaving the `grad_clip` setting without effect.
        torch.nn.utils.clip_grad_norm_(gpt.parameters(), grad_clip)

        optimizer.step()

        # Running mean of the per-batch loss over the current epoch.
        total_loss += loss.item()
        avg_loss = total_loss / (i + 1)

        progress.set_postfix({"loss": f"{avg_loss:.4f}", "lr": optimizer.param_groups[0]["lr"]})

    # scheduler.step()

    print(f"Epoch {epoch} done | Average Loss: {avg_loss:.4f}")

# NOTE(review): the checkpoint name still says "lstm" although the model saved
# here is the GPT decoder; it is kept unchanged in case other code loads this
# path — rename in both places together.
torch.save(gpt.state_dict(), "lstm_next_token_model.pt")
print("Training complete. Model saved to lstm_next_token_model.pt")


In [None]:
def generate_text(model, tokenizer, prompt, max_new_tokens=20, device=None, context_length=None):
    """Greedily extend ``prompt`` with a causal language model.

    The model is called as ``model(input_ids)`` and must return logits of shape
    ``(batch, seq_len, vocab_size)``. It keeps no state between calls, so the
    whole running sequence (optionally cropped to the last ``context_length``
    tokens) is fed again at every step.

    Args:
        model: Causal language model taking a ``(1, seq_len)`` LongTensor.
        tokenizer: Object providing ``encode(text, add_special_tokens=False)``
            and ``decode(ids, skip_special_tokens=True)``.
        prompt: Text to continue.
        max_new_tokens: Number of tokens to append.
        device: Device to run on. Defaults to CUDA when available, else CPU.
            The model is moved to this device.
        context_length: If given, only the most recent ``context_length``
            tokens are passed to the model at each step.

    Returns:
        The decoded prompt followed by the generated continuation.

    Raises:
        ValueError: If the prompt encodes to no tokens, or ``context_length``
            is given but smaller than 1.
    """
    if context_length is not None and context_length < 1:
        raise ValueError("context_length must be a positive integer or None")

    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model.to(device)
    model.eval()

    # Encode prompt; copy into a fresh list that grows as tokens are generated.
    generated_tokens = list(tokenizer.encode(prompt, add_special_tokens=False))
    if not generated_tokens:
        # An empty sequence has no "last position" to predict from.
        raise ValueError("prompt must encode to at least one token")

    with torch.no_grad():
        for _ in range(max_new_tokens):
            # Feed the model the running sequence, cropped to its window if requested.
            if context_length is None:
                window = generated_tokens
            else:
                window = generated_tokens[-context_length:]
            input_ids = torch.tensor([window], dtype=torch.long, device=device)

            logits = model(input_ids)

            # Greedy choice at the last position. Softmax is monotonic, so the
            # argmax of the raw logits selects the same token.
            predicted_id = int(torch.argmax(logits[0, -1, :]).item())
            generated_tokens.append(predicted_id)

    # Decode full sequence (prompt + continuation)
    return tokenizer.decode(generated_tokens, skip_special_tokens=True)


Generated text: Odrzekł do nie centurion , a ja w tej chwili , gdy w głębi duszy , w którym się znajdował , w którym się znajdował
