In [None]:
# Show the attached GPU, driver version and CUDA version before anything else runs.
!nvidia-smi

Fri Aug 22 16:46:44 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA A100-SXM4-40GB          Off |   00000000:00:04.0 Off |                    0 |
| N/A   34C    P0             43W /  400W |       0MiB /  40960MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
                                                

In [None]:
# Install torchinfo (its `summary` helper is imported in the next cell); --quiet suppresses pip's output.
%pip install torchinfo --quiet

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchinfo import summary
import requests

from torch.nn.attention import SDPBackend, sdpa_kernel
from torch.cuda.amp import autocast # For GPU acceleration

import tiktoken
# Shared tokenizer (tiktoken "p50k_base" encoding): default tokenizer of GatorGPTDataset,
# source of the model's vocab size (tok.n_vocab), and the tokenizer passed to the text helpers.
tok = tiktoken.get_encoding("p50k_base")

In [None]:
import requests

# Books to download (local name -> Project Gutenberg plain-text URL)
books = {
    "moby_dick.txt": "https://www.gutenberg.org/files/2701/2701-0.txt",
    "pride_prejudice.txt": "https://www.gutenberg.org/files/1342/1342-0.txt",
    "frankenstein.txt": "https://www.gutenberg.org/files/84/84-0.txt",
    "alice_wonderland.txt": "https://www.gutenberg.org/files/11/11-0.txt",
    "christmas_carol.txt": "https://www.gutenberg.org/files/98/98-0.txt",
}

# Seconds to wait on each HTTP request; without a timeout a stalled connection hangs the cell forever.
DOWNLOAD_TIMEOUT_S = 30

# Collect the books separately and join once at the end instead of growing one big string in the loop.
downloaded_texts = []
failed_downloads = []

for name, url in books.items():
    print(f"Downloading {name}...")
    try:
        response = requests.get(url, timeout=DOWNLOAD_TIMEOUT_S)
    except requests.RequestException as exc:
        # Network-level failures are reported the same way as bad status codes so the
        # remaining books are still attempted.
        print(f"Failed to download {name}: {exc}")
        failed_downloads.append(name)
        continue

    if response.status_code == 200:
        downloaded_texts.append(response.text.strip())
    else:
        print(f"Failed to download {name} (HTTP {response.status_code})")
        failed_downloads.append(name)

# `data` is the combined corpus consumed by later cells; every book is followed by a blank line.
data = "".join(text + "\n\n" for text in downloaded_texts)

if failed_downloads:
    print(f"\n⚠️ Download finished with failures: {', '.join(failed_downloads)}. "
          f"`data` contains {len(downloaded_texts)} of {len(books)} books.")
else:
    print("\n✅ Download complete. `data` now contains the combined text of all books.")

Downloading moby_dick.txt...
Downloading pride_prejudice.txt...
Downloading frankenstein.txt...
Downloading alice_wonderland.txt...
Downloading christmas_carol.txt...

✅ Download complete. `data` now contains the combined text of all books.


In [None]:
class Rope(nn.Module):
    """Rotary positional encoding (RoPE) over the last dimension of a 3-D tensor.

    The last dimension is read as consecutive pairs (x_0, x_1), (x_2, x_3), ...;
    pair ``i`` at position ``pos`` is rotated by the angle
    ``pos * 10000 ** (-2 * i / d_model)``.

    Args:
        d_model: Size of the last dimension of the inputs; must be even.
        max_len: Largest sequence length the module accepts.

    Shapes:
        forward(x): ``x`` is (batch_size, seq_len, d_model); the result has the
        same shape and the same dtype as ``x``.

    Raises:
        AssertionError: If ``d_model`` is odd (at construction).
        ValueError: If the input's last dimension is not ``d_model`` or its
            sequence length exceeds ``max_len``.
    """

    def __init__(self, d_model, max_len=5000):
        super().__init__()

        # Rotations act on 2-D pairs, so the feature size has to be even.
        assert d_model % 2 == 0, "d_model must be even for proper RoPE implementation"
        self.d_model = d_model
        self.max_len = max_len

        # Positions [0, 1, ..., max_len-1] as a column so they broadcast against the frequencies.
        # Shape: (max_len, 1)
        self.register_buffer('position_ids', torch.arange(max_len).unsqueeze(1))

        # One frequency per pair: 10000 ** (-(2 * i) / d_model) for i in range(d_model // 2).
        # Low pair indices rotate fast, high pair indices rotate slowly.
        # Shape: (d_model // 2,)
        self.register_buffer('div_term', torch.exp(torch.arange(0, d_model, 2) * -(torch.log(torch.tensor(10000.0)) / d_model)))

    def forward(self, x):
        batch_size, seq_len, d_model = x.shape

        # Fail with a clear message instead of an opaque broadcasting error further down.
        if d_model != self.d_model:
            raise ValueError(
                f"Rope was built for d_model={self.d_model} but got input with last dimension {d_model}"
            )
        if seq_len > self.max_len:
            raise ValueError(
                f"Sequence length {seq_len} exceeds the max_len={self.max_len} this Rope was built for"
            )

        # Rotation angle for every (position, pair): shape (seq_len, d_model // 2).
        angles = self.position_ids[:seq_len] * self.div_term
        cos_vals = torch.cos(angles)
        sin_vals = torch.sin(angles)

        # Split the feature axis into pairs: (batch_size, seq_len, d_model // 2, 2).
        x_pairs = x.view(batch_size, seq_len, d_model // 2, 2)
        x_even = x_pairs[..., 0]  # first element of each pair
        x_odd = x_pairs[..., 1]   # second element of each pair

        # 2-D rotation [[cos, -sin], [sin, cos]] applied to every pair.
        rotated_even = x_even * cos_vals - x_odd * sin_vals
        rotated_odd = x_even * sin_vals + x_odd * cos_vals

        # Interleave the rotated halves back into the original layout.
        rotated_x = torch.stack([rotated_even, rotated_odd], dim=-1).view(batch_size, seq_len, d_model)

        # cos/sin are float32, so half-precision inputs get promoted during the rotation;
        # cast back so the module preserves the caller's dtype.
        return rotated_x.to(x.dtype)

In [None]:
class GQA(nn.Module):
    """Causal grouped-query self-attention with rotary position encoding.

    Queries use ``n_heads`` heads; keys and values use
    ``n_kv_heads = n_heads // gqa_groups`` heads, each shared by ``gqa_groups``
    consecutive query heads. All projections are bias-free.

    RoPE is applied per head, with the same ``head_dim``-sized frequency table
    for queries and keys. Using identical frequencies on both sides is what makes
    the attention score depend on the relative distance between two positions.

    Args:
        d_model: Model (input/output) dimension; must be divisible by ``n_heads``.
        n_heads: Number of query heads; must be divisible by ``gqa_groups``.
        gqa_groups: Number of query heads that share one key/value head.
        max_len: Maximum sequence length supported by the RoPE tables.

    Shapes:
        forward(x): ``x`` is (B, T, d_model); returns (B, T, d_model).
    """

    def __init__(self,
                 d_model: int = 384,
                 n_heads: int = 8,
                 gqa_groups: int = 2,
                 max_len: int = 1024,
                ):
        super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
        assert n_heads % gqa_groups == 0, "n_heads must be divisible by gqa_groups"

        self.d_model = d_model
        self.n_heads = n_heads
        self.gqa_groups = gqa_groups
        self.head_dim = d_model // n_heads
        self.n_kv_heads = n_heads // gqa_groups
        self.max_len = max_len
        assert self.head_dim % 2 == 0, "head_dim (d_model // n_heads) must be even for RoPE"

        # Bias-free linear projections
        self.q_proj = nn.Linear(d_model, n_heads * self.head_dim, bias=False)          # d_model -> H*D
        self.k_proj = nn.Linear(d_model, self.n_kv_heads * self.head_dim, bias=False)  # d_model -> H_kv*D
        self.v_proj = nn.Linear(d_model, self.n_kv_heads * self.head_dim, bias=False)  # d_model -> H_kv*D
        self.o_proj = nn.Linear(n_heads * self.head_dim, d_model, bias=False)          # H*D -> d_model

        # Per-head RoPE tables. Both are built for head_dim so that a query head and the key
        # head it attends to are rotated with the same frequencies.
        # NOTE: these buffers are (head_dim // 2)-sized; checkpoints that stored RoPE buffers
        # sized for the flattened H*D / H_kv*D layout will not match these shapes.
        self.rope_q = Rope(d_model=self.head_dim, max_len=max_len)
        self.rope_k = Rope(d_model=self.head_dim, max_len=max_len)

    def _apply_rope_per_head(self, rope: nn.Module, t: torch.Tensor, n_heads: int) -> torch.Tensor:
        """Split (B, T, n_heads*D) into heads, rotate each head with ``rope``, return (B, n_heads, T, D)."""
        B, T, _ = t.shape
        t = t.view(B, T, n_heads, self.head_dim).transpose(1, 2)      # (B, n_heads, T, D)
        # Fold heads into the batch axis so Rope sees its expected (batch, seq, dim) layout.
        t = rope(t.reshape(B * n_heads, T, self.head_dim))            # (B*n_heads, T, D)
        return t.reshape(B, n_heads, T, self.head_dim)

    def forward(self,
                x: torch.Tensor,  # (B, T, d_model)
                attention_mask: torch.Tensor | None = None  # accepted for API compatibility; not used
                ) -> torch.Tensor:  # returns (B, T, d_model)
        B, T, _ = x.shape

        # Project, split into heads and rotate Q and K; V is only split into heads (no RoPE).
        q = self._apply_rope_per_head(self.rope_q, self.q_proj(x), self.n_heads)     # (B, H, T, D)
        k = self._apply_rope_per_head(self.rope_k, self.k_proj(x), self.n_kv_heads)  # (B, H_kv, T, D)
        v = self.v_proj(x).view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)  # (B, H_kv, T, D)

        # Expand K and V from n_kv_heads to n_heads; repeat_interleave keeps copies adjacent,
        # e.g. heads [1,2,3,4] -> [1,1,2,2,3,3,4,4] for an expand factor of 2.
        expand_factor = self.n_heads // self.n_kv_heads
        k = k.repeat_interleave(expand_factor, dim=1)  # (B, H, T, D)
        v = v.repeat_interleave(expand_factor, dim=1)  # (B, H, T, D)

        # Purely causal attention. sdpa_kernel takes the backend(s) to enable as its `backends`
        # argument. Besides flash attention, the memory-efficient and math kernels stay enabled
        # so inputs flash attention cannot handle (e.g. float32 or CPU tensors) still run.
        with sdpa_kernel([
            SDPBackend.FLASH_ATTENTION,
            SDPBackend.EFFICIENT_ATTENTION,
            SDPBackend.MATH,
        ]):
            out = F.scaled_dot_product_attention(q, k, v, attn_mask=None, is_causal=True)  # (B, H, T, D)

        # Merge heads back to (B, T, H*D) and project to the model dimension.
        out = out.transpose(1, 2).reshape(B, T, self.n_heads * self.head_dim)
        return self.o_proj(out)  # (B, T, d_model)

In [None]:
class MLP(nn.Module):
    """
    Bias-free SwiGLU feed-forward layer for a decoder-only Transformer block.

    - d_model: model width (default 384)
    - d_ff: hidden width (default 768, i.e. 2 × the default d_model)
    - `w1` fuses the "up" and "gate" projections into one matmul (d_model -> 2*d_ff)
    - `w2` projects back down (d_ff -> d_model)
    - Normalisation is not applied here
    - Input/Output shape: (batch, seq_len, d_model)
    """
    def __init__(self, d_model: int = 384, d_ff: int = 768):
        super().__init__()
        self.w1 = nn.Linear(d_model, 2 * d_ff, bias=False)  # fused up + gate
        self.w2 = nn.Linear(d_ff, d_model, bias=False)      # down projection

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # One matmul yields both halves: first half is the content ("up") path,
        # second half is the gate deciding how much of that content passes through.
        fused = self.w1(x)                                  # (B, T, 2*d_ff)
        content, gate = fused.chunk(2, dim=-1)              # (B, T, d_ff) each
        gated = content * F.silu(gate)                      # SwiGLU: up ⊗ swish(gate)
        return self.w2(gated)                               # (B, T, d_model)

In [None]:
class Block(nn.Module):
    """Pre-norm Transformer decoder block: RMSNorm -> GQA and RMSNorm -> MLP, each with a residual.

    Args:
        d_model, n_heads, gqa_groups, max_len: forwarded to the GQA attention layer.
        d_ff: hidden width forwarded to the MLP.
        eps: epsilon used by both RMSNorm layers.
        dropout_p: dropout probability applied to each sub-layer output before the residual add.
    """

    def __init__(
        self,
        d_model: int = 384,
        n_heads: int = 8,
        gqa_groups: int = 2,
        max_len: int = 1024,
        d_ff: int = 768,
        eps: float = 1e-5,
        dropout_p: float = 0.0,  # keep 0.0 for pretrain
    ):
        super().__init__()
        # One norm in front of each sub-layer
        self.rms1 = nn.RMSNorm(d_model, eps=eps)
        self.rms2 = nn.RMSNorm(d_model, eps=eps)

        # Sub-layers: attention (with its own output projection) and the bias-free SwiGLU MLP
        self.attn = GQA(d_model=d_model, n_heads=n_heads, gqa_groups=gqa_groups, max_len=max_len)
        self.mlp = MLP(d_model=d_model, d_ff=d_ff)

        # Dropout on each residual branch
        self.drop_attn = nn.Dropout(p=dropout_p)
        self.drop_mlp = nn.Dropout(p=dropout_p)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Attention branch: normalise, attend, drop, add back.
        attn_branch = self.attn(self.rms1(x))
        x = x + self.drop_attn(attn_branch)

        # Feed-forward branch: normalise, MLP, drop, add back.
        mlp_branch = self.mlp(self.rms2(x))
        return x + self.drop_mlp(mlp_branch)

In [None]:
class GatorGPT(nn.Module):
    """Decoder-only language model: token embedding -> stack of Blocks -> RMSNorm -> vocabulary logits.

    The embedding and the output projection share a single weight tensor
    (``embed.weight`` is set to ``unembed.weight``).

    Args:
        vocab_size: number of token ids.
        d_model, n_heads, gqa_groups, max_len, d_ff, eps, dropout_p: forwarded to every Block
            (``eps`` is also used by the final RMSNorm).
        blocks: number of Blocks in the stack.
    """

    def __init__(
        self,
        vocab_size: int,
        d_model: int = 384,
        n_heads: int = 8,
        gqa_groups: int = 2,
        max_len: int = 1024,
        d_ff: int = 768,
        eps: float = 1e-5,
        dropout_p: float = 0.0,
        blocks: int = 10,
    ):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
        self.unembed = nn.Linear(d_model, vocab_size, bias=False)

        self.final_rms = nn.RMSNorm(d_model, eps=eps)
        # Weight tying: the embedding reuses the output projection's parameter.
        self.embed.weight = self.unembed.weight

        # Every block is built from the same hyper-parameters.
        block_config = dict(
            d_model=d_model,
            n_heads=n_heads,
            gqa_groups=gqa_groups,
            max_len=max_len,
            d_ff=d_ff,
            eps=eps,
            dropout_p=dropout_p,
        )
        self.blocks = nn.ModuleList([Block(**block_config) for _ in range(blocks)])

    def forward(self, x):
        """
        Map token ids of shape (batch, seq_len) to logits of shape (batch, seq_len, vocab_size).
        """
        hidden = self.embed(x)                  # (batch, seq_len, d_model)
        for layer in self.blocks:
            hidden = layer(hidden)
        return self.unembed(self.final_rms(hidden))

In [None]:
class GatorGPTDataset(Dataset):
    """Next-token-prediction dataset built from one text with a sliding window.

    The text is tokenized once; every window of ``max_length`` tokens becomes an
    input, and the same window shifted right by one token becomes its target.
    Consecutive windows start ``stride`` tokens apart.
    """

    def __init__(self, txt, tokenizer=tok, max_length=1024, stride=512):
        # Tokenize the whole text in one go
        token_ids = tokenizer.encode(txt, allowed_special={"<|endoftext|>"})

        # Start offsets of every window that still has a full shifted-by-one target
        window_starts = range(0, len(token_ids) - max_length, stride)

        self.input_ids = [
            torch.tensor(token_ids[start:start + max_length]) for start in window_starts
        ]
        self.target_ids = [
            torch.tensor(token_ids[start + 1:start + max_length + 1]) for start in window_starts
        ]

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, idx):
        inputs, targets = self.input_ids[idx], self.target_ids[idx]
        return inputs, targets

In [None]:
###################################### TURNS TEXTS TO TOKENS
def text_to_token_ids(text, tokenizer):
    """Encode `text` with `tokenizer` and return the ids as a tensor with a leading batch axis of 1."""
    ids = tokenizer.encode(text, allowed_special={'<|endoftext|>'})
    return torch.tensor(ids).unsqueeze(0)
############################################################################

###################################### TURN TOKENS TO TEXT
def token_ids_to_text(token_ids, tokenizer):
    """Drop the leading batch axis (if it has size 1) and decode the ids back to a string."""
    flat_ids = token_ids.squeeze(0).tolist()
    return tokenizer.decode(flat_ids)
############################################################################

###################################### CREATES DATA LOADERS
def create_dataloader(txt, batch_size=4, max_length=256,
                         stride=128, shuffle=True, drop_last=True,
                         num_workers=0):
    """Tokenize `txt` with the "p50k_base" encoding and wrap its sliding windows in a DataLoader.

    `max_length` and `stride` configure the GatorGPTDataset windows; the remaining
    arguments are handed to the DataLoader unchanged.
    """
    encoding = tiktoken.get_encoding("p50k_base")
    windows = GatorGPTDataset(txt, encoding, max_length, stride)
    return DataLoader(
        windows,
        batch_size=batch_size,
        shuffle=shuffle,
        drop_last=drop_last,
        num_workers=num_workers,
    )
############################################################################

In [None]:
###################################### LOSS FOR ONE BATCH
# Calculates the loss for a single batch
def calc_loss_batch(input_batch, target_batch, model, device):
    """Return the cross-entropy loss of `model` on one batch.

    Args:
        input_batch: token ids fed to the model.
        target_batch: target token ids, one per input position.
        model: callable mapping `input_batch` to logits; the first two logit axes are
            flattened to line up with the flattened targets.
        device: device (string or torch.device) the batch is moved to.

    Returns:
        Scalar loss tensor (still attached to the autograd graph).
    """
    from contextlib import nullcontext

    input_batch, target_batch = input_batch.to(device), target_batch.to(device)

    # bfloat16 autocast is only used when the batch actually lives on a CUDA device;
    # on any other device the forward pass runs in the model's own precision. Deriving this
    # from `device` avoids the "CUDA is not available" autocast warning on the CPU fallback.
    if torch.device(device).type == "cuda":
        precision_ctx = torch.autocast(device_type="cuda", dtype=torch.bfloat16)
    else:
        precision_ctx = nullcontext()

    with precision_ctx:
        logits = model(input_batch)
        loss = torch.nn.functional.cross_entropy(
            logits.flatten(0, 1),
            target_batch.flatten()
        )

    return loss
############################################################################

###################################### LOSS FOR ENTIRE LOADER
# Calculates the loss over an ENTIRE data_loader by calling calc_loss_batch on each batch
def calc_loss_loader(data_loader, model, device, num_batches=None):
    """Average the per-batch loss over the first `num_batches` batches of `data_loader`.

    Returns NaN for an empty loader. `num_batches=None` means "all batches"; a larger
    value than the loader length is clamped to the loader length.
    """
    running_total = 0.
    available = len(data_loader)
    if available == 0:
        return float("nan")

    # Never ask for more batches than the loader can provide.
    num_batches = available if num_batches is None else min(num_batches, available)

    for step, (inputs, targets) in enumerate(data_loader):
        if step >= num_batches:
            break
        running_total += calc_loss_batch(inputs, targets, model, device).item()

    return running_total / num_batches
############################################################################

###################################### USE THIS TO EVALUATE MODEL DIRECTLY
def evaluate_model(model, train_loader, val_loader, device, eval_iter):
    """Return (train_loss, val_loss), each averaged over at most `eval_iter` batches.

    The model is put in eval mode and gradients are disabled for the measurement;
    afterwards the model is returned to whatever mode (train/eval) it was in, so a
    caller in the middle of training does not silently keep training in eval mode.
    """
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            train_loss = calc_loss_loader(train_loader, model, device, eval_iter)
            val_loss = calc_loss_loader(val_loader, model, device, eval_iter)
    finally:
        # Restore the caller's mode even if evaluation raised.
        model.train(was_training)
    return train_loss, val_loss
############################################################################

####################################### GENERATING NEW TOKENS
def generate_text_simple(model, idx, max_new_tokens, context_size):
    """Greedy decoding: append `max_new_tokens` tokens to `idx`, one at a time.

    Args:
        model: callable mapping (batch, n_tokens) ids to (batch, n_tokens, vocab_size) logits.
        idx: (batch, n_tokens) tensor with the current context.
        max_new_tokens: number of tokens to append.
        context_size: only the last `context_size` tokens are fed to the model at each step.

    Returns:
        (batch, n_tokens + max_new_tokens) tensor of ids.
    """
    for _ in range(max_new_tokens):
        # Crop the context to what the model supports
        window = idx[:, -context_size:]

        with torch.no_grad():
            logits = model(window)                          # (batch, num_tokens, vocab_size)

        # Only the prediction for the last position matters
        last_logits = logits[:, -1, :]                      # (batch, vocab_size)
        probs = torch.softmax(last_logits, dim=-1)          # (batch, vocab_size)

        # Greedy pick: id with the highest probability
        next_token = torch.argmax(probs, dim=-1, keepdim=True)  # (batch, 1)

        # Grow the running sequence by the chosen token
        idx = torch.cat((idx, next_token), dim=1)           # (batch, num_tokens + 1)

    return idx
############################################################################

###################################### GENERATING AND PRINTING SAMPLES
def generate_and_print_sample(model, tokenizer, device, start_context, context_size):
    """Print 50 greedily generated tokens continuing `start_context` on a single line.

    The model is switched to eval mode for generation and to train mode afterwards.
    """
    model.eval()
    prompt_ids = text_to_token_ids(start_context, tokenizer).to(device)

    with torch.no_grad():
        generated_ids = generate_text_simple(
            model=model,
            idx=prompt_ids,
            max_new_tokens=50,
            context_size=context_size,
        )

    sample_text = token_ids_to_text(generated_ids, tokenizer)
    # Newlines are flattened so the sample stays on one log line
    print(sample_text.replace("\n", " "))
    model.train()
############################################################################

In [None]:
def train_model_simple(model, train_loader, val_loader, optimizer, device, num_epochs,
                       eval_freq, eval_iter, start_context, tokenizer, Patience,
                       context_size=1024):
    """Train `model` with periodic evaluation and patience-based early stopping.

    Args:
        model: model to train; called through calc_loss_batch.
        train_loader, val_loader: loaders yielding (input_batch, target_batch).
        optimizer: optimizer stepping the model's parameters.
        device: device the batches are moved to.
        num_epochs: maximum number of passes over `train_loader`.
        eval_freq: evaluate every `eval_freq` optimizer steps (including step 0).
        eval_iter: number of batches per loader used for each evaluation.
        start_context: prompt used for the sample printed after every epoch.
        tokenizer: tokenizer used to encode the prompt and decode the sample.
        Patience: number of consecutive evaluations without a new best validation
            loss after which training stops.
        context_size: context window passed to the sample generator (default 1024).

    Returns:
        (train_losses, val_losses, track_tokens_seen): one entry per evaluation.
    """
    # Initialize tracking lists
    train_losses, val_losses, track_tokens_seen = [], [], []
    tokens_seen, global_step = 0, -1
    best_val_loss = float('inf')
    patience_limit = Patience
    patience_counter = 0

    steps_per_epoch = len(train_loader)
    progress_every = max(1, steps_per_epoch // 5)  # print progress ~5 times per epoch

    # Check dataset size
    print(f"📊 Dataset info:")
    print(f"   Training batches: {steps_per_epoch}")
    print(f"   Validation batches: {len(val_loader)}")
    print(f"   Total steps per epoch: {steps_per_epoch}")
    print(f"   Evaluation every {eval_freq} steps")
    print("="*50)

    for epoch in range(num_epochs):
        # Once patience is exhausted no further epoch is started; breaking only the batch
        # loop would still run every remaining epoch as an empty one.
        if patience_counter >= patience_limit:
            print(f"⏹️ Early stopping: no validation improvement in {patience_limit} consecutive evaluations.")
            break

        print(f"🔁 Starting epoch {epoch+1}/{num_epochs}")
        print("-"*30)
        model.train()

        epoch_steps = 0
        epoch_loss = 0.0

        for input_batch, target_batch in train_loader:

            if patience_counter >= patience_limit:
                break

            optimizer.zero_grad()
            loss = calc_loss_batch(input_batch, target_batch, model, device)
            loss.backward()
            optimizer.step()

            tokens_seen += input_batch.numel()
            global_step += 1
            epoch_steps += 1
            epoch_loss += loss.item()

            # Print progress within epoch
            if epoch_steps % progress_every == 0:
                avg_loss = epoch_loss / epoch_steps
                progress = epoch_steps / steps_per_epoch * 100
                print(f"   Step {epoch_steps}/{steps_per_epoch} ({progress:.1f}%) - "
                      f"Current loss: {loss.item():.3f}, Avg loss: {avg_loss:.3f}, "
                      f"patience level: {patience_counter}/{patience_limit}")

            # Evaluation during training
            if global_step % eval_freq == 0:
                train_loss, val_loss = evaluate_model(
                    model, train_loader, val_loader, device, eval_iter)
                # Evaluation runs in eval mode; make sure the next step trains in train mode.
                model.train()
                train_losses.append(train_loss)
                val_losses.append(val_loss)
                track_tokens_seen.append(tokens_seen)

                improvement = "✅" if val_loss < best_val_loss else "❌"
                print(f"   Eval at step {global_step}: Train {train_loss:.3f}, "
                      f"Val {val_loss:.3f} {improvement}, "
                      f"patience level: {patience_counter}/{patience_limit}")

                if val_loss < best_val_loss:
                    best_val_loss = val_loss
                    patience_counter = 0
                else:
                    patience_counter += 1

        # End of epoch summary
        avg_epoch_loss = epoch_loss / epoch_steps if epoch_steps > 0 else float('inf')
        print(f"✅ Epoch {epoch+1} complete: {epoch_steps} steps, avg loss: {avg_epoch_loss:.3f}")

        # Generate sample after each epoch
        print("🎯 Generated sample:")
        generate_and_print_sample(model, tokenizer, device, start_context, context_size)
        print("="*50)

    print(f"🏁 Training complete! Best validation loss: {best_val_loss:.3f}")
    return train_losses, val_losses, track_tokens_seen

In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"

model = GatorGPT(vocab_size=tok.n_vocab)

# Move to the target device first, then compile
model = model.to(device)
model = torch.compile(model)

optim = torch.optim.AdamW(model.parameters(), lr=0.001, eps=1e-08, weight_decay=0.01)
EPOCHS = 5
EVAL_FREQ = 543
EVAL_ITER = 180
CONTEXT = "We ran across a field that was"
PATIENCE = 5

# Character-level split of the corpus: first 90% for training, last 10% for validation
TRAIN_FRACTION = 0.9
split_idx = int(TRAIN_FRACTION * len(data))

# Settings shared by both loaders (kept in one place so they cannot drift apart)
LOADER_KWARGS = dict(
    batch_size=4,
    max_length=1024,
    stride=128,
    drop_last=True,
    num_workers=0,
)

train_loader = create_dataloader(data[:split_idx], shuffle=True, **LOADER_KWARGS)

# Validation is not shuffled: evaluation only reads the first EVAL_ITER batches, so a fixed
# order makes every evaluation score the same batches and keeps the losses comparable.
val_loader = create_dataloader(data[split_idx:], shuffle=False, **LOADER_KWARGS)

In [None]:
# Run training; the returned (train_losses, val_losses, tokens_seen) tuple is the cell's output.
train_model_simple(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    optimizer=optim,
    device=device,
    num_epochs=EPOCHS,
    eval_freq=EVAL_FREQ,
    eval_iter=EVAL_ITER,
    start_context=CONTEXT,
    tokenizer=tok,
    Patience=PATIENCE,
)

📊 Dataset info:
   Training batches: 1629
   Validation batches: 184
   Total steps per epoch: 1629
   Evaluation every 543 steps
🔁 Starting epoch 1/5
------------------------------
   Eval at step 0: Train 10.267, Val 10.223 ✅, patience level: 0/5
   Step 325/1629 (20.0%) - Current loss: 4.716, Avg loss: 5.713, patience level: 0/5
   Eval at step 543: Train 4.757, Val 4.910 ✅, patience level: 0/5
   Step 650/1629 (39.9%) - Current loss: 4.673, Avg loss: 5.293, patience level: 0/5
   Step 975/1629 (59.9%) - Current loss: 4.446, Avg loss: 5.052, patience level: 0/5
   Eval at step 1086: Train 4.374, Val 4.727 ✅, patience level: 0/5
   Step 1300/1629 (79.8%) - Current loss: 4.193, Avg loss: 4.876, patience level: 0/5
   Step 1625/1629 (99.8%) - Current loss: 4.448, Avg loss: 4.738, patience level: 0/5
✅ Epoch 1 complete: 1629 steps, avg loss: 4.736
🎯 Generated sample:
 the same time, and the same time, and the same, and
🔁 Starting epoch 2/5
------------------------------
   Eval at step 

([10.266812176174588,
  4.756670006116232,
  4.374165822399987,
  4.058816957473755,
  3.812153563234541,
  3.4509832514656913,
  3.1961449914508395,
  2.869881663057539,
  2.5189720312754313,
  2.147828308078978,
  1.8217950297726526],
 [10.223125457763672,
  4.909924348195394,
  4.727024560504489,
  4.627366911040412,
  4.621872064802382,
  4.608041501045227,
  4.7057045261065165,
  4.904332576857673,
  5.101007869508531,
  5.35958661503262,
  5.726420248879327],
 [4096,
  2228224,
  4452352,
  6676480,
  8900608,
  11124736,
  13348864,
  15572992,
  17797120,
  20021248,
  22245376])

In [None]:
# Greedily generate 50 tokens from a fixed prompt and show the text on a single line.
prompt_ids = text_to_token_ids("Fascinated by the black", tok).to(device)
generated_ids = generate_text_simple(model, prompt_ids, 50, 1024)
generated_text = token_ids_to_text(generated_ids, tok)
generated_text.replace("\n", " ").replace("\r\\", "").replace("\r", "")

'Fascinated by the blackium, and the coffin-ends of the Pequod’s company, and the nearest. The Chase made itself, and the seamen went by the purchasing house air, and the rolling ship, and the rolling'