<a href="https://colab.research.google.com/github/hardik-vala/unicorn-namegen/blob/main/namegen_gpu_test.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install wandb

Collecting wandb
  Downloading wandb-0.16.6-py3-none-any.whl (2.2 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.2/2.2 MB[0m [31m7.6 MB/s[0m eta [36m0:00:00[0m
Collecting GitPython!=3.1.29,>=1.0.0 (from wandb)
  Downloading GitPython-3.1.43-py3-none-any.whl (207 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m207.3/207.3 kB[0m [31m9.3 MB/s[0m eta [36m0:00:00[0m
Collecting sentry-sdk>=1.0.0 (from wandb)
  Downloading sentry_sdk-2.1.1-py2.py3-none-any.whl (277 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m277.3/277.3 kB[0m [31m7.4 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting docker-pycreds>=0.4.0 (from wandb)
  Downloading docker_pycreds-0.4.0-py2.py3-none-any.whl (9.0 kB)
Collecting setproctitle (from wandb)
  Downloading setproctitle-1.3.3-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (30 kB)
Collecting gitdb<5,>=4.0.1 (from GitPython!=3.1.29,>=1.0.0->wandb

In [2]:
from dataclasses import dataclass

import torch
import torch.nn as nn
from torch.nn import functional as F


class Head(nn.Module):
    """one head of self-attention"""

    def __init__(self, config):
        super().__init__()
        head_size = config.n_embd // config.n_head
        self.key = nn.Linear(config.n_embd, head_size, bias=False)
        self.query = nn.Linear(config.n_embd, head_size, bias=False)
        self.value = nn.Linear(config.n_embd, head_size, bias=False)
        self.register_buffer(
            "tril", torch.tril(torch.ones(config.block_size, config.block_size))
        )
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        B, T, C = x.shape
        k = self.key(x)  # (B,T,C)
        q = self.query(x)  # (B,T,C)
        # compute attention scores ("affinities")
        wei = q @ k.transpose(-2, -1) * C**-0.5  # (B,T,C) @ (B,C,T) -> (B,T,T)
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float("-inf"))  # (B,T,T)
        wei = F.softmax(wei, dim=-1)  # (B,T,T)
        wei = self.dropout(wei)
        # perform the weighted aggregation of the values
        v = self.value(x)  # (B,T,C)
        out = wei @ v  # (B,T,T) @ (B,T,C) -> (B,T,C)
        return out


class MultiHeadAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.heads = nn.ModuleList([Head(config) for _ in range(config.n_head)])
        self.proj = nn.Linear(config.n_embd, config.n_embd)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)  # (B, T, C)
        return self.dropout(self.proj(out))


class FeedForward(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(config.n_embd, 4 * config.n_embd),
            nn.ReLU(),
            nn.Linear(4 * config.n_embd, config.n_embd),
            nn.Dropout(config.dropout),
        )

    def forward(self, x):
        return self.net(x)


class LayerNorm:

    def __init__(self, dim, eps=1e-5, momentum=0.1):
        self.eps = eps
        self.gamma = torch.ones(dim)
        self.beta = torch.zeros(dim)

    def __call__(self, x):
        # calculate the forward pass
        xmean = x.mean(1, keepdim=True)
        xvar = x.var(1, keepdim=True)
        xhat = (x - xmean) / torch.sqrt(xvar + self.eps)  # normalize to unit variance
        self.out = self.gamma * xhat + self.beta
        return self.out

    def parameters(self):
        return [self.gamma, self.beta]


class Block(nn.Module):
    """Transformer Block: Communication followed by computation"""

    def __init__(self, config):
        super().__init__()
        self.sa = MultiHeadAttention(config)
        self.ffwd = FeedForward(config)
        self.ln1 = nn.LayerNorm(config.n_embd)
        self.ln2 = nn.LayerNorm(config.n_embd)

    def forward(self, x):
        x = x + self.sa(self.ln1(x))
        x = x + self.ffwd(self.ln2(x))
        return x


@dataclass
class ModelConfig:
    block_size: int
    vocab_size: int
    n_layer: int
    n_head: int
    n_embd: int
    dropout: float = 0.0
    bias: bool = True  # True: bias in Linears and LayerNorms, like GPT-2.


class Namegen(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.token_embedding_table = nn.Embedding(config.vocab_size, config.n_embd)
        self.position_embedding_table = nn.Embedding(config.block_size, config.n_embd)
        modules = [Block(config) for _ in range(config.n_layer)] + [
            nn.LayerNorm(config.n_embd)
        ]
        self.blocks = nn.Sequential(*modules)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size)
        self.ffwd = FeedForward(config)
        self.block_size = config.block_size

    def forward(self, idx, targets=None):
        device = idx.device
        B, T = idx.shape
        tok_emb = self.token_embedding_table(idx)  # (B, T, C)
        pos = torch.arange(T, device=device)
        pos_emb = self.position_embedding_table(pos)  # (T, C)
        x = tok_emb + pos_emb  # (B, T, C)
        x = self.blocks(x)  # (B, T, C)
        logits = self.lm_head(x)  # (B, T, vocab_size)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B * T, C)
            targets = targets.view(B * T)
            loss = F.cross_entropy(logits, targets)
        return logits, loss

    def configure_optimizers(self, weight_decay, learning_rate, betas, device_type):
        # start with all of the candidate parameters
        param_dict = {pn: p for pn, p in self.named_parameters()}
        # filter out those that do not require grad
        param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad}
        # create optim groups. Any parameters that is 2D will be weight decayed, otherwise no.
        # i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't.
        decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
        nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
        optim_groups = [
            {"params": decay_params, "weight_decay": weight_decay},
            {"params": nodecay_params, "weight_decay": 0.0}
        ]
        num_decay_params = sum(p.numel() for p in decay_params)
        num_nodecay_params = sum(p.numel() for p in nodecay_params)
        print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters")
        print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")
        # Create AdamW optimizer and use the fused version if it is available
        # fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
        # use_fused = fused_available and device_type == 'cuda'
        # extra_args = dict(fused=True) if use_fused else dict()
        extra_args = dict()
        optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args)
        # print(f"using fused AdamW: {use_fused}")

        return optimizer

    def generate(self, idx, max_new_tokens):
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -self.block_size :]
            logits, loss = self(idx_cond)
            logits = logits[:, -1, :]
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
        return idx


In [3]:
from contextlib import nullcontext
import numpy as np
import os
import pickle
import time
import torch

# ------------------------------------------------------------------------------
# I/O
out_dir = "out"
eval_iters = 200
log_interval = 1000
# wandb logging
wandb_log = True
wandb_project = "namegen"
wandb_run_name = "506-80V-gpu"
# data
dataset = "names"
batch_size = 64
block_size = 20  # context length
vocab_size = 80
# model
n_layer = 6
n_head = 6
n_embd = 84
dropout = 0.2  # for pretraining 0 is good, for finetuning try 0.1+
bias = False  # do we use bias inside LayerNorm and Linear layers?
write_checkpoint = True
# adamw optimizer
learning_rate = 1e-4  # max learning rate
max_iters = 40000
weight_decay=1e-1
beta1 = 0.9
beta2 = 0.99
# system
device = (
    "cuda"  # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1' etc., or try 'mps' on macbooks
)
# ------------------------------------------------------------------------------
config_keys = [
    k
    for k, v in globals().items()
    if not k.startswith("_") and isinstance(v, (int, float, bool, str))
]
config = {k: globals()[k] for k in config_keys}  # for logging
# ------------------------------------------------------------------------------

tokens_per_iter = batch_size * block_size
print(f"tokens per iteration will be: {tokens_per_iter:,}")

torch.manual_seed(24)
device_type = "cuda" if "cuda" in device else "cpu" # for later use in torch.autocast
ctx = nullcontext()

# data loader
data_dir = os.path.join("data", dataset)


def get_batch(split):
    if split == "train":
        data = np.memmap(os.path.join("data", "train.bin"), dtype=np.uint16, mode="r")
    else:
        data = np.memmap(os.path.join("data", "val.bin"), dtype=np.uint16, mode="r")
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack(
        [torch.from_numpy((data[i : i + block_size]).astype(np.int64)) for i in ix]
    )
    y = torch.stack(
        [
            torch.from_numpy((data[i + 1 : i + block_size + 1]).astype(np.int64))
            for i in ix
        ]
    )
    if device_type == "cuda":
        # pin arrays x,y, which allows us to move them to GPU asynchronously (non_blocking=True)
        x, y = x.pin_memory().to(device, non_blocking=True), y.pin_memory().to(device, non_blocking=True)
    else:
      x, y = x.to(device), y.to(device)
    return x, y


# attempt to derive vocab_size from the dataset
meta_path = os.path.join("data", "meta.pkl")
meta_vocab_size = None
if os.path.exists(meta_path):
    with open(meta_path, "rb") as f:
        meta = pickle.load(f)
    meta_vocab_size = meta["vocab_size"]
    print(f"found vocab_size = {meta_vocab_size} (inside {meta_path})")

# model init
model_args = dict(
    n_layer=n_layer,
    n_head=n_head,
    n_embd=n_embd,
    block_size=block_size,
    bias=bias,
    vocab_size=meta_vocab_size,
    dropout=dropout,
)

modelconf = ModelConfig(**model_args)
model = Namegen(modelconf)
model.to(device)

# optimizer
optimizer = model.configure_optimizers(weight_decay, learning_rate, (beta1, beta2), device_type)


# helps estimate an arbitrarily accurate loss over either split using many batches
@torch.no_grad()
def estimate_loss(model):
    out = {}
    model.eval()
    for split in ["train", "val"]:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            with ctx:
              logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out


# logging
if wandb_log:
    import wandb
    wandb.init(project=wandb_project, name=wandb_run_name, config=config)


# training loop
t0 = time.time()
for step in range(max_iters):
    if step % log_interval == 0:
        losses = estimate_loss(model)
        print(
            f"step {step}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}"
        )
        if wandb_log:
            wandb.log(
                {"step": step, "train/loss": losses["train"], "val/loss": losses["val"]}
            )
    xb, yb = get_batch("train")
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

# timing
t1 = time.time()
dt = t1 - t0
print(f"time: {dt:.2f}s")

# write checkpoint
if write_checkpoint:
    checkpoint = {
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "model_args": model_args,
        "iter_num": step,
        "final_val_loss": losses["val"],
        "config": config,
    }
    print(f"saving checkpoint to {out_dir}")
    torch.save(checkpoint, os.path.join(out_dir, "ckpt.pt"))


tokens per iteration will be: 1,280
found vocab_size = 80 (inside data/meta.pkl)
num decayed parameter tensors: 131, with 579,600 parameters
num non-decayed parameter tensors: 47, with 5,708 parameters


<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter, or press ctrl+c to quit:

 ··········


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


step 0: train loss 4.4896, val loss 4.4908
step 1000: train loss 2.9118, val loss 2.9180
step 2000: train loss 2.8097, val loss 2.8125
step 3000: train loss 2.7186, val loss 2.7196
step 4000: train loss 2.6344, val loss 2.6404
step 5000: train loss 2.5595, val loss 2.5786
step 6000: train loss 2.5011, val loss 2.5197
step 7000: train loss 2.4515, val loss 2.4698
step 8000: train loss 2.4206, val loss 2.4421
step 9000: train loss 2.3817, val loss 2.4063
step 10000: train loss 2.3488, val loss 2.3805
step 11000: train loss 2.3344, val loss 2.3620
step 12000: train loss 2.3190, val loss 2.3421
step 13000: train loss 2.2970, val loss 2.3299
step 14000: train loss 2.2739, val loss 2.3088
step 15000: train loss 2.2665, val loss 2.2964
step 16000: train loss 2.2531, val loss 2.2826
step 17000: train loss 2.2351, val loss 2.2611
step 18000: train loss 2.2312, val loss 2.2600
step 19000: train loss 2.2145, val loss 2.2479
step 20000: train loss 2.1993, val loss 2.2480
step 21000: train loss 2.1

In [8]:
from contextlib import nullcontext
import os
import pickle
import torch

# -----------------------------------------------------------------------------
out_dir = "out"  # ignored if init_from is not 'resume'
num_samples = 10  # number of samples to draw
max_new_tokens = 100  # number of tokens generated in each sample
temperature = (
    0.8  # 1.0 = no change, < 1.0 = less random, > 1.0 = more random, in predictions
)
top_k = (
    200  # retain only the top_k most likely tokens, clamp others to have 0 probability
)
seed = 24
device = "cuda"  # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1', etc.
# -----------------------------------------------------------------------------

# torch.manual_seed(seed)
device_type = "cuda" if "cuda" in device else "cpu" # for later use in torch.autocast
ctx = nullcontext()

ckpt_path = os.path.join(out_dir, "ckpt.pt")
checkpoint = torch.load(ckpt_path, map_location=device)
modelconf = ModelConfig(**checkpoint["model_args"])
model = Namegen(modelconf)
state_dict = checkpoint["model"]
unwanted_prefix = "_orig_mod."
for k, v in list(state_dict.items()):
    if k.startswith(unwanted_prefix):
        state_dict[k[len(unwanted_prefix) :]] = state_dict.pop(k)
model.load_state_dict(state_dict)

model.eval()
model.to(device)

meta_path = os.path.join("data", "meta.pkl")
print(f"Loading meta from {meta_path}...")
with open(meta_path, "rb") as f:
    meta = pickle.load(f)
merges = meta["merges"]
stoi, itos = meta["stoi"], meta["itos"]
encode1 = lambda s: [stoi[c] for c in s]
decode1 = lambda l: "".join([itos[i] for i in l])

def unmerge(ids, pair, idx):
    newids = []
    for i in ids:
        if i == idx:
            newids.append(pair[0])
            newids.append(pair[1])
        else:
            newids.append(i)
    return newids


def decode(ids):
    tokens = list(ids)
    for pair, idx in reversed(merges.items()):
        tokens = unmerge(tokens, pair, idx)
    return decode1(tokens)

sample_cnt = 0
with torch.no_grad():
    while True:
        x = torch.full((1, 1), stoi["!"], dtype=torch.long, device=device)
        y = model.generate(x, max_new_tokens)
        raw = decode(y[0].tolist())
        parts = raw.split("!")
        for i in range(1, len(parts) - 1):
            print(parts[i])
            sample_cnt += 1
            if sample_cnt >= num_samples:
                break
        if sample_cnt >= num_samples:
            break


Loading meta from data/meta.pkl...
Surfish Diagnostics
Upucom
Pavix Jewell
Gorge Hands OverWakers
CarmaCental LLS
VID
dy Lee
Basey Lea Holding
PartnerApartments Partners
TriLabs
