<a href="https://colab.research.google.com/github/jingmingliu01/build-GPT-from-scratch/blob/main/bigram_build_GPT_from_scratch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Imports

In [None]:
import os
import urllib.request

import torch
import torch.nn as nn
from torch.nn import functional as F

# 1. Hyperparameters

In [None]:
# --- optimisation schedule ---
max_iters = 3000        # total number of training steps
eval_interval = 300     # report train/val loss every this many steps
eval_iters = 200        # batches averaged per split when estimating loss
learning_rate = 1e-2

# --- batching ---
batch_size = 32         # independent sequences per batch
block_size = 8          # context length (tokens) of each sequence

# Prefer the GPU when one is visible to PyTorch.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Seed PyTorch's RNG so batch sampling and generation are reproducible.
torch.manual_seed(1337)

<torch._C.Generator at 0x7b0f1052f210>

In [None]:
# Fetch the Tiny Shakespeare corpus (~1.1 MB) only if it is not already on disk.
# A bare `wget` never overwrites an existing file: every re-run saves another
# copy as input.txt.1, input.txt.2, ... while the next cell keeps reading
# input.txt, so re-running the notebook silently piles up duplicate files.
DATA_URL = 'https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt'
DATA_PATH = 'input.txt'
if not os.path.exists(DATA_PATH):
  urllib.request.urlretrieve(DATA_URL, DATA_PATH)

--2025-03-24 04:27:52--  https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1115394 (1.1M) [text/plain]
Saving to: ‘input.txt.2’


2025-03-24 04:27:52 (18.4 MB/s) - ‘input.txt.2’ saved [1115394/1115394]



In [None]:
# Load the whole corpus into memory as a single string.
with open('input.txt', mode='r', encoding='utf-8') as corpus_file:
  text = corpus_file.read()

# 2. Data preparation

In [None]:
# Character-level vocabulary: every distinct character in the corpus, sorted so
# the char <-> id mapping is deterministic across runs.
chars = sorted(set(text))
vocab_size = len(chars)
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for i, ch in enumerate(chars)}


def encode(s):
  """Map a string to a list of integer token ids.

  Raises KeyError for any character that does not occur in the corpus.
  """
  return [stoi[c] for c in s]


def decode(ids):
  """Map a sequence of integer token ids back to a string."""
  return ''.join(itos[i] for i in ids)


# Cheap sanity check that the two mappings are inverses of each other.
assert decode(encode(text[:1000])) == text[:1000]

# Encode the whole corpus as one tensor of token ids and hold out the last 10%
# (a contiguous tail, not a random sample) for validation.
data = torch.tensor(encode(text), dtype=torch.long)
n = int(0.9 * len(data))
train_data = data[:n]
val_data = data[n:]

# 3. Data loading

In [None]:
def get_batch(split):
  """Sample a random batch of (input, target) token sequences from one split.

  Args:
    split: 'train' to sample from train_data, 'val' to sample from val_data.

  Returns:
    (x, y): two tensors of shape (batch_size, block_size) moved to `device`.
    y is x shifted one position ahead, so y[b, t] is the token that follows
    x[b, t] in the data.

  Raises:
    ValueError: if split is neither 'train' nor 'val'.
  """
  if split == 'train':
    data = train_data
  elif split == 'val':
    data = val_data
  else:
    # Fail loudly instead of silently falling back to val_data, so a typo such
    # as 'valid' or 'test' cannot quietly evaluate on the wrong split.
    raise ValueError(f"split must be 'train' or 'val', got {split!r}")
  # Largest start index is len(data) - block_size - 1, which keeps the shifted
  # target slice data[i + 1 : i + block_size + 1] in bounds.
  ix = torch.randint(len(data) - block_size, (batch_size,))
  x = torch.stack([data[i:i + block_size] for i in ix])
  y = torch.stack([data[i + 1:i + block_size + 1] for i in ix])
  return x.to(device), y.to(device)

In [None]:
@torch.no_grad()
def estimate_loss():
  """Estimate the mean loss of `model` on the train and val splits.

  Averages the loss over `eval_iters` freshly sampled batches per split, which
  reduces the noise of any single batch. Runs without gradient tracking and
  with the model temporarily in eval mode.

  Returns:
    dict mapping 'train' and 'val' to a 0-dim tensor holding the mean loss.
  """
  was_training = model.training
  model.eval()
  try:
    out = {}
    for split in ['train', 'val']:
      losses = torch.zeros(eval_iters)
      for k in range(eval_iters):
        X, Y = get_batch(split)
        _, loss = model(X, Y)
        losses[k] = loss.item()
      out[split] = losses.mean()
  finally:
    # Restore whichever mode the model was in before, even if evaluation
    # raises. Unconditionally calling model.train() would flip a model that
    # the caller had deliberately put in eval mode.
    model.train(was_training)
  return out

# 4. Model definition

In [None]:
class BigramLanguageModel(nn.Module):
  """Character-level bigram language model.

  Each token id indexes a row of a (vocab_size, vocab_size) embedding table,
  and that row is used directly as the logits for the next token. The
  prediction therefore depends only on the current token, not on earlier
  context.
  """

  def __init__(self, vocab_size):
    super().__init__()
    # Row i holds the next-token logits for token i.
    self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)

  def forward(self, idx, targets=None):
    """Compute next-token logits and, optionally, the cross-entropy loss.

    Args:
      idx: tensor of token ids, shape (B, T).
      targets: optional tensor of next-token ids, shape (B, T).

    Returns:
      (logits, loss). Without targets, logits has shape (B, T, vocab_size) and
      loss is None. With targets, logits is flattened to (B*T, vocab_size),
      the layout F.cross_entropy expects, and loss is a scalar tensor.
    """
    logits = self.token_embedding_table(idx)  # (B, T, C)

    if targets is None:
      loss = None
    else:
      B, T, C = logits.shape
      logits = logits.reshape(B * T, C)
      # reshape (not view) so non-contiguous targets, e.g. a transposed
      # tensor, are accepted instead of raising a stride error.
      targets = targets.reshape(B * T)
      loss = F.cross_entropy(logits, targets)

    return logits, loss

  @torch.no_grad()
  def generate(self, idx, max_new_tokens):
    """Autoregressively sample max_new_tokens tokens and append them to idx.

    Sampling needs no gradients, so autograd tracking is disabled.

    Args:
      idx: tensor of token ids of shape (B, T) used as the starting context.
      max_new_tokens: number of tokens to sample.

    Returns:
      tensor of shape (B, T + max_new_tokens).
    """
    for _ in range(max_new_tokens):
      logits, _ = self(idx)
      logits = logits[:, -1, :]  # only the last position predicts the next token
      probs = F.softmax(logits, dim=-1)
      idx_next = torch.multinomial(probs, num_samples=1)  # (B, 1)
      idx = torch.cat((idx, idx_next), dim=1)
    return idx

# 5. Training and generation

In [None]:
# nn.Module.to() moves the parameters in place and returns the same module, so
# `m` and `model` are two names for one object; both names are kept because
# later cells refer to each of them.
m = BigramLanguageModel(vocab_size).to(device)
model = m

optimizer = torch.optim.AdamW(m.parameters(), lr=learning_rate)

In [None]:
# `step` rather than `iter`, so the loop does not shadow the builtin iter()
# for the rest of the notebook.
for step in range(max_iters):
  # Periodically report losses averaged over eval_iters batches per split.
  if step % eval_interval == 0:
    losses = estimate_loss()
    print(f"step {step}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

  # One optimisation step on a freshly sampled training batch.
  xb, yb = get_batch('train')
  logits, loss = model(xb, yb)
  optimizer.zero_grad(set_to_none=True)
  loss.backward()
  optimizer.step()

# The in-loop report runs before each update, so the last one above is taken
# before the remaining updates; evaluate once more so the printed numbers
# describe the fully trained model.
losses = estimate_loss()
print(f"step {max_iters}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

step 0: train loss 4.7305, val loss 4.7241
step 300: train loss 2.8110, val loss 2.8249
step 600: train loss 2.5434, val loss 2.5682
step 900: train loss 2.4932, val loss 2.5088
step 1200: train loss 2.4863, val loss 2.5035
step 1500: train loss 2.4665, val loss 2.4921
step 1800: train loss 2.4683, val loss 2.4936
step 2100: train loss 2.4696, val loss 2.4846
step 2400: train loss 2.4638, val loss 2.4879
step 2700: train loss 2.4738, val loss 2.4911


In [None]:
# Start from a single token with id 0 (the first character of the sorted
# vocabulary), sample 500 more, and decode the only sequence in the batch.
context = torch.zeros((1, 1), dtype=torch.long, device=device)
sample_ids = m.generate(context, max_new_tokens=500)[0].tolist()
print(decode(sample_ids))




CEThik brid owindakis b, bth

HAPet bobe d e.
S:
O:3 my d?
LUCous:
Wanthar u qur, t.
War dXENDoate awice my.

Hastarom oroup
Yowhthetof isth ble mil ndill, ath iree sengmin lat Heriliovets, and Win nghir.
Swanousel lind me l.
HAshe ce hiry:
Supr aisspllw y.
Hentofu n Boopetelaves
MPOLI s, d mothakleo Windo whth eisbyo the m dourive we higend t so mower; te

AN ad nterupt f s ar igr t m:

Thin maleronth,
Mad
RD:

WISo myrangoube!
KENob&y, wardsal thes ghesthinin couk ay aney IOUSts I&fr y ce.
J
