In [12]:
import numpy as np

In [1]:
with open('./datasets/raw.txt', "r", encoding='utf-8') as file:
    text = file.read()

In [11]:
text

'CHESS FUNDAMENTALS\n\nPART I\n\nCHAPTER I\n\nFIRST PRINCIPLES: ENDINGS, MIDDLE-GAME AND OPENINGS\n\nThe first thing a student should do, is to familiarise himself with the\npower of the pieces. This can best be done by learning how to accomplish\nquickly some of the simple mates.\n\n1. SOME SIMPLE MATES\n\nEXAMPLE 1.--The ending Rook and King against King.\n\n_The principle is to drive the opposing King to the last line on any side\nof the board_.\n\n[Illustration] {4}\n\nIn this position the power of the Rook is demonstrated by the first move,\nR - R 7, which immediately confines the Black King to the last rank, and\nthe mate is quickly accomplished by: 1 R - R 7, K - Kt 1; 2 K - Kt 2.\n\nThe combined action of King and Rook is needed to arrive at a position in\nwhich mate can be forced. The general principle for a beginner to follow is\nto\n\n_keep his King as much as possible on the same rank, or, as in this case,\nfile, as the opposing King._\n\nWhen, in this case, the King has be

In [14]:
print(f"Dataset length: {len(text)}")

Dataset length: 228902


In [15]:
# Building BPE Tokenizer, Byte pair encoder

from collections import Counter, defaultdict
import re


In [38]:
# creating a char tokens list containing all the word in list form,
# containing each character in word, also after each word, adding </w> token.

char_tokens = []
for word in text.split():
    char_tokens.append(list(word) + ["</w>"])

In [20]:
# prefect
char_tokens[:10]

[['C', 'H', 'E', 'S', 'S', '</w>'],
 ['F', 'U', 'N', 'D', 'A', 'M', 'E', 'N', 'T', 'A', 'L', 'S', '</w>'],
 ['P', 'A', 'R', 'T', '</w>'],
 ['I', '</w>'],
 ['C', 'H', 'A', 'P', 'T', 'E', 'R', '</w>'],
 ['I', '</w>'],
 ['F', 'I', 'R', 'S', 'T', '</w>'],
 ['P', 'R', 'I', 'N', 'C', 'I', 'P', 'L', 'E', 'S', ':', '</w>'],
 ['E', 'N', 'D', 'I', 'N', 'G', 'S', ',', '</w>'],
 ['M', 'I', 'D', 'D', 'L', 'E', '-', 'G', 'A', 'M', 'E', '</w>']]

In [21]:
# now lets get the frequency of each word token

def get_stats(tokens: list):
    pairs = Counter()
    for word in tokens:
        for i in range(len(word) - 1):
            pairs[(word[i], word[i+1])] += 1
    return pairs

In [39]:
# another function to merge most frequently appearing pairs


def merge_pair(pair, tokens):
    new_tokens = []
    bigram = re.escape(' '.join(pair))
    pattern = re.compile(r'(?<!\S)' + bigram.replace(" ", r'\s') + r'(?!\S)')
    for word in tokens:
        new_word = []
        i = 0
        while i < len(word):
            if i < len(word) - 1 and (word[i], word[i+1]) == pair:
                new_word.append(word[i] + word[i+1])
                i += 2
            else:
                new_word.append(word[i])
                i += 1
        new_tokens.append(new_word)
    return new_tokens

In [40]:
# create BPE loop


vocab_size = 1024
vocab  = set(char for word in char_tokens for char in word)

In [41]:
len(vocab)

81

In [43]:
vocab

{'!',
 '"',
 "'",
 '(',
 ')',
 '*',
 ',',
 '-',
 '.',
 '/',
 '0',
 '1',
 '2',
 '3',
 '4',
 '5',
 '6',
 '7',
 '8',
 '9',
 ':',
 ';',
 '</w>',
 '?',
 'A',
 'B',
 'C',
 'D',
 'E',
 'F',
 'G',
 'H',
 'I',
 'J',
 'K',
 'L',
 'M',
 'N',
 'O',
 'P',
 'Q',
 'R',
 'S',
 'T',
 'U',
 'V',
 'W',
 'X',
 'Y',
 'Z',
 '[',
 ']',
 '_',
 'a',
 'b',
 'c',
 'd',
 'e',
 'f',
 'g',
 'h',
 'i',
 'j',
 'k',
 'l',
 'm',
 'n',
 'o',
 'p',
 'q',
 'r',
 's',
 't',
 'u',
 'v',
 'w',
 'x',
 'y',
 'z',
 '{',
 '}'}

In [44]:
for i in range(vocab_size - len(vocab)):
    pairs = get_stats(char_tokens)
    if not pairs:
        break
    best_pair = pairs.most_common(1)[0][0]
    char_tokens = merge_pair(best_pair, char_tokens)
    vocab.add(''.join(best_pair))
    print(f"Merge {i+1}: {best_pair}")

Merge 1: ('e', '</w>')
Merge 2: ('t', '</w>')
Merge 3: ('t', 'h')
Merge 4: ('s', '</w>')
Merge 5: ('.', '</w>')
Merge 6: ('-', '</w>')
Merge 7: ('i', 'n')
Merge 8: (',', '</w>')
Merge 9: ('th', 'e</w>')
Merge 10: ('d', '</w>')
Merge 11: ('a', 'n')
Merge 12: ('K', '</w>')
Merge 13: ('o', 'n')
Merge 14: ('R', '</w>')
Merge 15: ('e', 'r')
Merge 16: ('e', 'n')
Merge 17: ('B', '</w>')
Merge 18: ('o', '</w>')
Merge 19: ('K', 't</w>')
Merge 20: ('y', '</w>')
Merge 21: ('P', '</w>')
Merge 22: ('Q', '</w>')
Merge 23: ('h', 'i')
Merge 24: ('in', 'g')
Merge 25: ('t', 'o</w>')
Merge 26: ('t', 'i')
Merge 27: ('l', 'a')
Merge 28: ('f', '</w>')
Merge 29: ('.', '.')
Merge 30: ('o', 'u')
Merge 31: ('o', 'r')
Merge 32: ('l', 'l')
Merge 33: ('x', '</w>')
Merge 34: ('ing', '</w>')
Merge 35: ('i', 's</w>')
Merge 36: ('o', 'f</w>')
Merge 37: ('c', 'h')
Merge 38: ('r', 'e')
Merge 39: ('c', 'k')
Merge 40: ('s', 'i')
Merge 41: ('in', '</w>')
Merge 42: ('an', 'd</w>')
Merge 43: ('ti', 'on')
Merge 44: ('3', '</w

In [45]:
len(vocab)

1024

In [46]:
vocab

{'No',
 'igh',
 'game.</w>',
 'combin',
 'pt</w>',
 'v',
 'other</w>',
 'gives</w>',
 'ready</w>',
 'therefore</w>',
 'ation',
 '6,</w>',
 '.......',
 'thus</w>',
 'ending',
 'Black</w>',
 'Knight</w>',
 'gh',
 '.)</w>',
 '34.</w>',
 'pie',
 'EN',
 'even</w>',
 '9</w>',
 'lose</w>',
 'ter</w>',
 'at',
 'again',
 'action</w>',
 'previous</w>',
 "King's</w>",
 'simila',
 'dr',
 'pieces</w>',
 'following</w>',
 'see</w>',
 'i',
 'soon</w>',
 'idea</w>',
 '4}</w>',
 'GAM',
 'consider',
 '...',
 'also</w>',
 'bring</w>',
 'ke',
 'ag',
 'endings</w>',
 'accord',
 'ings</w>',
 'inter',
 'fac',
 '2;</w>',
 'ha',
 '8</w>',
 'better</w>',
 'kn',
 'forced</w>',
 'them',
 'for',
 'square',
 'ba',
 'part</w>',
 '6}</w>',
 '23.</w>',
 'obta',
 'm',
 '4</w>',
 'whi',
 'dd',
 'ers</w>',
 'and</w>',
 'de',
 'than</w>',
 'ving</w>',
 'ing',
 'sign',
 'obtain</w>',
 ']',
 'only</w>',
 'llu',
 'Queen</w>',
 'los',
 'T',
 'Bishop.</w>',
 'fi',
 'comes</w>',
 'O</w>',
 'n',
 'al,</w>',
 '19',
 'there</w>',


In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from torch.utils.data import Dataset, DataLoader
import torch_directml

from torchinfo import summary
import os

from tokenizers import Tokenizer, models, trainers, pre_tokenizers
import time


In [2]:
## check all available gpus

print("=== PyTorch Device Listing ===")

# CUDA devices (NVIDIA)
if torch.cuda.is_available():
    print(f"CUDA devices: {torch.cuda.device_count()}")
    for i in range(torch.cuda.device_count()):
        print(f"  - cuda:{i}: {torch.cuda.get_device_name(i)}")
else:
    print("No CUDA devices found")

# MPS (Apple Silicon)
if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
    print("MPS device available: mps")
else:
    print("No MPS device found")

# DirectML / privateuseone devices
try:
    dml_device = torch.device("privateuseone:0")
    torch.zeros(1).to(dml_device)  # test allocation
    print("DirectML / privateuseone device available: privateuseone:0")

    dml_device = torch.device("privateuseone:1")
    torch.zeros(1).to(dml_device)  # test allocation
    print("DirectML / privateuseone device available: privateuseone:1")
except Exception:
    print("No DirectML / privateuseone device found")

print("Default device:", torch.device("cpu"))

=== PyTorch Device Listing ===
No CUDA devices found
No MPS device found
DirectML / privateuseone device available: privateuseone:0
DirectML / privateuseone device available: privateuseone:1
Default device: cpu


In [3]:
# Universal device picker
def get_device():
    # Standard CUDA (NVIDIA)
    if torch.cuda.is_available():
        return torch.device("cuda:0")
    
    # Apple M1/M2
    if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        return torch.device("mps")

    # DirectML / privateuseone (AMD or other backends)
    try:
        dml = torch.device("privateuseone:0")
        # test by allocating a tiny tensor
        torch.zeros(1).to(dml)
        return str(dml)
    except Exception:
        pass

    # CPU fallback
    return torch.device("cpu")

device = get_device()
print(f"Using device: {device}")

Using device: privateuseone:0


In [4]:
# Using pre libs pytorch and tokenizer

# Train BPE tokenizer
tokenizer = Tokenizer(models.BPE())
tokenizer.pre_tokenizer = pre_tokenizers.Whitespace()

trainer = trainers.BpeTrainer(special_tokens=["<pad>", "<unk>", "<bos>", "<eos>"], vocab_size=5000)

files = ["./datasets/raw.txt"]  # your dataset
tokenizer.train(files, trainer)

tokenizer.save("./datasets/chess_tokenizer.json")


In [5]:
tokenizer

Tokenizer(version="1.0", truncation=None, padding=None, added_tokens=[{"id":0, "content":"<pad>", "single_word":False, "lstrip":False, "rstrip":False, "normalized":False, "special":True}, {"id":1, "content":"<unk>", "single_word":False, "lstrip":False, "rstrip":False, "normalized":False, "special":True}, {"id":2, "content":"<bos>", "single_word":False, "lstrip":False, "rstrip":False, "normalized":False, "special":True}, {"id":3, "content":"<eos>", "single_word":False, "lstrip":False, "rstrip":False, "normalized":False, "special":True}], normalizer=None, pre_tokenizer=Whitespace(), post_processor=None, decoder=None, model=BPE(dropout=None, unk_token=None, continuing_subword_prefix=None, end_of_word_suffix=None, fuse_unk=False, byte_fallback=False, ignore_merges=False, vocab={"<pad>":0, "<unk>":1, "<bos>":2, "<eos>":3, "!":4, """:5, "'":6, "(":7, ")":8, "*":9, ",":10, "-":11, ".":12, "/":13, "0":14, "1":15, "2":16, "3":17, "4":18, "5":19, "6":20, "7":21, "8":22, "9":23, ":":24, ";":25, "

In [11]:
with open('./datasets/raw.txt', "r", encoding='utf-8') as file:
    text = file.read()

# Split sizes
train_size = int(0.85 * len(text))
val_size = len(text) - train_size

# File paths
train_file = "./datasets/train_data.txt"
val_file = "./datasets/val_data.txt"

# Split data
train_data = text[:train_size]
val_data = text[train_size:]

# Save
with open(train_file, "w", encoding="utf-8") as f:
    f.write(train_data)

with open(val_file, "w", encoding="utf-8") as f:
    f.write(val_data)

print("Train size:", train_size)
print("Validation size:", val_size)


Train size: 194566
Validation size: 34336


In [12]:
## Dataset class


class ChessDataset(Dataset):
    def __init__(self, file_path, tokenizer, block_size=128):
        self.tokenizer = tokenizer
        with open(file_path, 'r', encoding='utf-8') as f:
            text = f.read()
        # Encode into IDs
        self.tokens = tokenizer.encode(text).ids
        self.block_size = block_size

    def __len__(self):
        return len(self.tokens) - self.block_size

    def __getitem__(self, idx):
        x = torch.tensor(self.tokens[idx: idx+self.block_size], dtype=torch.long)
        y = torch.tensor(self.tokens[idx+1: idx+self.block_size+1], dtype=torch.long)
        return x, y


In [13]:
class FastChessDataset(Dataset):
    def __init__(self, file_path, tokenizer, block_size=64):
        with open(file_path, 'r', encoding='utf-8') as f:
            text = f.read()
        tokens = tokenizer.encode(text).ids
        self.x = []
        self.y = []
        for i in range(len(tokens) - block_size):
            self.x.append(tokens[i:i+block_size])
            self.y.append(tokens[i+1:i+block_size+1])
        self.x = torch.tensor(self.x, dtype=torch.long)
        self.y = torch.tensor(self.y, dtype=torch.long)

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        return self.x[idx], self.y[idx]


In [14]:
## Self attention

class CausalSelfAttention(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        assert embed_dim % num_heads == 0, "Embedding dim must be divisible by num_heads"

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        # linear projections for q, k, v
        self.qkv = nn.Linear(embed_dim, embed_dim * 3)
        self.out = nn.Linear(embed_dim, embed_dim)

    def forward(self, x):
        B, T, C = x.size()  # Batch, Time, Channels
        qkv = self.qkv(x)  # (B, T, 3*embed_dim)
        qkv = qkv.reshape(B, T, 3, self.num_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)  # (3, B, heads, T, head_dim)
        q, k, v = qkv[0], qkv[1], qkv[2]

        # Compute attention scores
        att = (q @ k.transpose(-2, -1)) / (self.head_dim ** 0.5)  # (B, heads, T, T)
        att = F.softmax(att, dim=-1)

        out = att @ v  # (B, heads, T, head_dim)
        out = out.transpose(1, 2).contiguous().reshape(B, T, C)  # (B, T, embed_dim)

        return self.out(out)


In [15]:
# transformer block


class TransformerBlock(nn.Module):
    def __init__(self, embed_dim, num_heads, ff_hidden_mult=4, dropout=0.1):
        super().__init__()
        self.attn = CausalSelfAttention(embed_dim, num_heads)
        self.ln1 = nn.LayerNorm(embed_dim)
        self.ff = nn.Sequential(
            nn.Linear(embed_dim, ff_hidden_mult * embed_dim),
            nn.ReLU(),
            nn.Linear(ff_hidden_mult * embed_dim, embed_dim),
            nn.Dropout(dropout)
        )
        self.ln2 = nn.LayerNorm(embed_dim)

    def forward(self, x):
        x = x + self.attn(self.ln1(x))
        x = x + self.ff(self.ln2(x))
        return x


In [16]:
## Minitransformer


class MiniTransformerLM(nn.Module):
    def __init__(self, vocab_size, embed_dim=128, num_heads=4, num_layers=4, block_size=64):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, embed_dim)
        self.pos_emb = nn.Embedding(block_size, embed_dim)
        self.blocks = nn.Sequential(*[TransformerBlock(embed_dim, num_heads) for _ in range(num_layers)])
        self.ln_f = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, vocab_size)

    def forward(self, idx, targets=None):
        B, T = idx.size()
        
        pos = torch.arange(0, T, device=idx.device).unsqueeze(0)
        x = self.token_emb(idx) + self.pos_emb(pos)
        x = self.blocks(x)
        x = self.ln_f(x)
        logits = self.head(x)
    
        loss = None
        if targets is not None:
            targets = targets.to(logits.device)      # also make sure targets are aligned
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1),reduction="mean")
        return logits, loss


In [128]:

def get_model_summary(checkpoint_dir="checkpoints", batch_size=64, vocab_size=5000):
    ckpts = [f for f in os.listdir(checkpoint_dir) if f.endswith(".pt")]

    if ckpts:
        # Pick the latest checkpoint by step number
        ckpts.sort(key=lambda x: int(x.split("_")[-1].split(".")[0]))
        last_ckpt = ckpts[-1]
        ckpt_path = os.path.join(checkpoint_dir, last_ckpt)
        
        checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=False)
        # print(checkpoint)
        model.load_state_dict(checkpoint['model_state_dict'])

        # ✅ Auto-detect correct block_size from pos_emb
        block_size = model.pos_emb.num_embeddings
        # print(f"block_size: {block_size}")

        # Torchinfo summary with token IDs (long tensor)
        return summary(
            model,
            input_size=(batch_size, block_size),  # use detected block size
            dtypes=[torch.long],
            device="cpu",
            col_names=["input_size", "output_size", "num_params", "trainable"]
        )

    else:
        raise FileNotFoundError(f"No checkpoints found in {checkpoint_dir}")


In [100]:
tok = Tokenizer.from_file("./datasets/chess_tokenizer.json")


In [44]:
train_ds = FastChessDataset("./datasets/train_data.txt", tok)
train_dl = DataLoader(train_ds, batch_size=64, shuffle=True, pin_memory=True)

val_ds = FastChessDataset("./datasets/val_data.txt", tok)
val_dl = DataLoader(val_ds, batch_size=64, shuffle=True, pin_memory=True)


In [45]:
len(train_ds), len(val_ds)

(46056, 8081)

In [46]:
len(train_dl), len(val_dl)

(720, 127)

In [47]:
device

'privateuseone:0'

In [70]:
eval_steps = 100
model = MiniTransformerLM(vocab_size=tok.get_vocab_size(), embed_dim=256, block_size=128, num_heads=8, num_layers=8)
# model = torch.compile(model)
model.to(device)

MiniTransformerLM(
  (token_emb): Embedding(4292, 256)
  (pos_emb): Embedding(128, 256)
  (blocks): Sequential(
    (0): TransformerBlock(
      (attn): CausalSelfAttention(
        (qkv): Linear(in_features=256, out_features=768, bias=True)
        (out): Linear(in_features=256, out_features=256, bias=True)
      )
      (ln1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
      (ff): Sequential(
        (0): Linear(in_features=256, out_features=1024, bias=True)
        (1): ReLU()
        (2): Linear(in_features=1024, out_features=256, bias=True)
        (3): Dropout(p=0.1, inplace=False)
      )
      (ln2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
    )
    (1): TransformerBlock(
      (attn): CausalSelfAttention(
        (qkv): Linear(in_features=256, out_features=768, bias=True)
        (out): Linear(in_features=256, out_features=256, bias=True)
      )
      (ln1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
      (ff): Sequential(
        (0): Lin

In [71]:
optimizer = torch.optim.AdamW(
    model.parameters(), 
    lr=3e-4, 
    betas=(0.9, 0.95),
    eps=1e-8,
    weight_decay=0.01,
    foreach=False,   # <- key line for DirectML
    fused=False 
)


In [54]:
def save_checkpoint(model, optimizer, step, training_loss, val_loss, checkpoint_dir="checkpoints"):
    os.makedirs(checkpoint_dir, exist_ok=True)
    ckpt_path = os.path.join(checkpoint_dir, f"model_step_{step}.pt")
    torch.save({
        "step": step,
        "training_loss": training_loss,   # 🔥 store loss here
        "val_loss": val_loss,   # 🔥 store loss here
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
    }, ckpt_path)
    print(f"💾 Saved checkpoint: {ckpt_path}")


In [124]:

def load_latest_checkpoint(model, optimizer, checkpoint_dir="checkpoints", device="cpu"):
    if not os.path.exists(checkpoint_dir):
        return model, optimizer, 0
    
    ckpts = [f for f in os.listdir(checkpoint_dir) if f.endswith(".pt")]
    if not ckpts:
        return model, optimizer, 0

    ckpts.sort(key=lambda x: int(x.split("_")[-1].split(".")[0]))
    last_ckpt = ckpts[-1]
    ckpt_path = os.path.join(checkpoint_dir, last_ckpt)

    # 🛠 Load on CPU first to avoid device mismatch errors
    checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=False)

    model.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    step = checkpoint["step"]

    # Now move model + optimizer to target device
    model.to(device)
    for state in optimizer.state.values():
        for k, v in state.items():
            if isinstance(v, torch.Tensor):
                state[k] = v.to(device)

    print(f"✅ Resumed from {ckpt_path} step={step}) on {device}")
    return model, optimizer, step


In [74]:

def estimate_val_loss(model, val_loader, device):
    """
    Run model on validation set and return average loss.
    """
    model.eval()
    total_loss = 0.0
    total_batches = 0

    with torch.no_grad():
        for batch in val_loader:
            inputs, targets = [x.to(device) for x in batch]
            logits, loss = model(inputs, targets)  # model returns (logits, loss)

            total_loss += loss.item()
            total_batches += 1

    model.train()  # back to training mode
    return total_loss / total_batches if total_batches > 0 else float("inf")


def train(model, optimizer, train_loader, val_loader, device, train_steps, eval_steps=100, save_every=100, checkpoint_dir="checkpoints"):
    # Auto-resume
    model, optimizer, global_step = load_latest_checkpoint(model, optimizer, checkpoint_dir, device)

    model.to(device)
    model.train()

    start_time = time.time()  # ⏱️ start training timer
    
    for step in range(train_steps):
        for batch in train_loader:
            batch_start = time.time()
            inputs, targets = [x.to(device) for x in batch]

            logits, loss = model(inputs, targets)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            global_step += 1

            # 🔹 Validation check
            if global_step % eval_steps == 0:
                val_loss = estimate_val_loss(model, val_loader, device)

                elapsed = time.time() - start_time
                step_time = time.time() - batch_start
                print(
                    f"Step {global_step} | "
                    f"Train Loss: {loss.item():.4f} | "
                    f"Val Loss: {val_loss:.4f} | "
                    f"Step Time: {step_time:.2f}s | "
                    f"Elapsed: {elapsed:.2f}s"
                )

            # 🔹 Save checkpoint
            if global_step % save_every == 0:
                save_checkpoint(model, optimizer, global_step, loss.item(), val_loss, checkpoint_dir)

    total_time = time.time() - start_time
    print(f"🎉 Training complete in {total_time/60:.2f} minutes")


In [75]:
checkpoints = 'checkpoints/ed_256_block_size_128_head_8_layers_8'

In [125]:
train(model, optimizer, train_dl, val_loader=val_dl, device=device, train_steps=4000, checkpoint_dir=checkpoints)

✅ Resumed from checkpoints/ed_256_block_size_128_head_8_layers_8\model_step_2000.pt step=2000) on privateuseone:0


KeyboardInterrupt: 

In [129]:
get_model_summary(checkpoint_dir=checkpoints, batch_size=256, vocab_size=tok.get_vocab_size())

Layer (type:depth-idx)                        Input Shape               Output Shape              Param #                   Trainable
MiniTransformerLM                             [256, 128]                [256, 128, 4292]          --                        True
├─Embedding: 1-1                              [256, 128]                [256, 128, 256]           1,098,752                 True
├─Embedding: 1-2                              [1, 128]                  [1, 128, 256]             32,768                    True
├─Sequential: 1-3                             [256, 128, 256]           [256, 128, 256]           --                        True
│    └─TransformerBlock: 2-1                  [256, 128, 256]           [256, 128, 256]           --                        True
│    │    └─LayerNorm: 3-1                    [256, 128, 256]           [256, 128, 256]           512                       True
│    │    └─CausalSelfAttention: 3-2          [256, 128, 256]           [256, 128, 256]     

In [132]:
# del model
# del tensor   # any big tensors

import gc
gc.collect()


2113

In [115]:
## Testing model

In [130]:
def generate_stream(model, tokenizer, device, prompt="e4", max_new_tokens=50, block_size=128):
    model.eval()
    tokens = tokenizer.encode(prompt).ids
    tokens = torch.tensor(tokens, dtype=torch.long).unsqueeze(0).to(device)
    prev_text = tokenizer.decode(tokens[0].cpu().tolist())

    yield prompt
    
    for _ in range(max_new_tokens):
        logits, _ = model(tokens[:, -block_size:])   # no .to(device) here
        probs = torch.softmax(logits[:, -1, :], dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)
        tokens = torch.cat([tokens, next_token.to(tokens.device)], dim=1)

        new_text = tokenizer.decode(tokens[0].cpu().tolist())
        yield new_text[len(prev_text):]   # only the newly added chunk
        prev_text = new_text


In [131]:
for tok in generate_stream(model, tokenizer, device, prompt="e4 e5", max_new_tokens=128):
    print(tok, end="", flush=True)

e4 e5

RuntimeError: tensor.device().type() == at::DeviceType::PrivateUse1 INTERNAL ASSERT FAILED at "C:\\__w\\1\\s\\pytorch-directml-plugin\\torch_directml\\csrc\\dml\\DMLTensor.cpp":31, please report a bug to PyTorch. unbox expects Dml at::Tensor as inputs