<a href="https://colab.research.google.com/github/fabiancpl/my-gpt/blob/master/my-gpt.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn as nn
from torch.nn import functional as F

In [2]:
torch.manual_seed(1337)

<torch._C.Generator at 0x7f00343b9870>

### Download and explore the data

In [3]:
!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

--2024-02-28 00:35:41--  https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1115394 (1.1M) [text/plain]
Saving to: ‘input.txt’


2024-02-28 00:35:41 (57.2 MB/s) - ‘input.txt’ saved [1115394/1115394]



In [4]:
with open("input.txt", "r", encoding="utf-8") as f:
    text = f.read()

In [5]:
print("Length of the dataset in characters:", len(text))

Length of the dataset in characters: 1115394


In [6]:
print(text[:1000])

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.

All:
We know't, we know't.

First Citizen:
Let us kill him, and we'll have corn at our own price.
Is't a verdict?

All:
No more talking on't; let it be done: away, away!

Second Citizen:
One word, good citizens.

First Citizen:
We are accounted poor citizens, the patricians good.
What authority surfeits on would relieve us: if they
would yield us but the superfluity, while it were
wholesome, we might guess they relieved us humanely;
but they think we are too dear: the leanness that
afflicts us, the object of our misery, is as an
inventory to particularise their abundance; our
sufferance is a gain to them Let us revenge this with
our pikes, ere we become rakes: for the gods know I
speak this in hunger for bread, not in thirst for revenge.



In [7]:
# All unique characters in the text, sorted, form the vocabulary
chars = sorted(set(text))
vocab_size = len(chars)
print("".join(chars))
print(vocab_size)


 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
65


### Define the tokenizer

In [8]:
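# Map each character to an integer id (stoi) and each id back to its character (itos)
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for i, ch in enumerate(chars)}

encode = lambda s: [stoi[c] for c in s]  # string -> list of integer ids
decode = lambda ids: "".join(itos[i] for i in ids)  # list of integer ids -> string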

In [9]:
print(encode("hi there"))
print(decode(encode("hi there")))

[46, 47, 1, 58, 46, 43, 56, 43]
hi there
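
The tokenizer above is a plain lookup in both directions, so `decode(encode(s))` returns `s` for any string built from known characters. The sketch below checks that round trip on a toy corpus (`toy_text` is a hypothetical stand-in for `input.txt`) and shows that a character missing from the vocabulary raises a `KeyError`.

```python
# Toy corpus standing in for input.txt (an assumption, for illustration only)
toy_text = "hi there, hello"

chars = sorted(set(toy_text))
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for i, ch in enumerate(chars)}

def encode(s):
    # string -> list of integer ids
    return [stoi[c] for c in s]

def decode(ids):
    # list of integer ids -> string
    return "".join(itos[i] for i in ids)

ids = encode("hi there")
round_trip = decode(ids)
print(ids, round_trip)

# Characters that never appear in the corpus have no id
try:
    encode("xyz")
    unknown_ok = True
except KeyError:
    unknown_ok = False
print("unknown characters encodable:", unknown_ok)
```

A character-level vocabulary built this way covers only the characters of the training text, so any input for generation has to be restricted to that same character set.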


In [10]:
data = torch.tensor(encode(text), dtype=torch.long)

In [11]:
print(data.shape, data.dtype)
print(data[:1000])

torch.Size([1115394]) torch.int64
tensor([18, 47, 56, 57, 58,  1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 14, 43, 44,
        53, 56, 43,  1, 61, 43,  1, 54, 56, 53, 41, 43, 43, 42,  1, 39, 52, 63,
         1, 44, 59, 56, 58, 46, 43, 56,  6,  1, 46, 43, 39, 56,  1, 51, 43,  1,
        57, 54, 43, 39, 49,  8,  0,  0, 13, 50, 50, 10,  0, 31, 54, 43, 39, 49,
         6,  1, 57, 54, 43, 39, 49,  8,  0,  0, 18, 47, 56, 57, 58,  1, 15, 47,
        58, 47, 64, 43, 52, 10,  0, 37, 53, 59,  1, 39, 56, 43,  1, 39, 50, 50,
         1, 56, 43, 57, 53, 50, 60, 43, 42,  1, 56, 39, 58, 46, 43, 56,  1, 58,
        53,  1, 42, 47, 43,  1, 58, 46, 39, 52,  1, 58, 53,  1, 44, 39, 51, 47,
        57, 46, 12,  0,  0, 13, 50, 50, 10,  0, 30, 43, 57, 53, 50, 60, 43, 42,
         8,  1, 56, 43, 57, 53, 50, 60, 43, 42,  8,  0,  0, 18, 47, 56, 57, 58,
         1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 18, 47, 56, 57, 58,  6,  1, 63,
        53, 59,  1, 49, 52, 53, 61,  1, 15, 39, 47, 59, 57,  1, 25, 39, 56, 41,
      

### Split the dataset

In [12]:
n = int(0.9 * len(data))
train_data = data[:n]
val_data = data[n:]

In [13]:
block_size = 8 # The maximum context length for predictions
train_data[:block_size+1]

tensor([18, 47, 56, 57, 58,  1, 15, 47, 58])

In [14]:
x = train_data[:block_size]
y = train_data[1:block_size+1]

for t in range(block_size):
    context = x[:t+1]
    target = y[t]
    print(f"When input is {context} the target is {target}")

When input is tensor([18]) the target is 47
When input is tensor([18, 47]) the target is 56
When input is tensor([18, 47, 56]) the target is 57
When input is tensor([18, 47, 56, 57]) the target is 58
When input is tensor([18, 47, 56, 57, 58]) the target is 1
When input is tensor([18, 47, 56, 57, 58,  1]) the target is 15
When input is tensor([18, 47, 56, 57, 58,  1, 15]) the target is 47
When input is tensor([18, 47, 56, 57, 58,  1, 15, 47]) the target is 58


### Define the batch function

In [15]:
batch_size = 4 # Number of independent sequences to be processed in parallel

In [16]:
def get_batch(split):
    """Generate a small batch of inputs x and targets y."""
    data = train_data if split == "train" else val_data
    # Random start offsets, one per sequence in the batch
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix])
    # Targets are the inputs shifted one position to the right in the data
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    return x, y

In [17]:
xb, yb = get_batch("train")

print("Inputs:")
print(xb.shape)
print(xb)

print("Targets:")
print(yb.shape)
print(yb)

Inputs:
torch.Size([4, 8])
tensor([[24, 43, 58,  5, 57,  1, 46, 43],
        [44, 53, 56,  1, 58, 46, 39, 58],
        [52, 58,  1, 58, 46, 39, 58,  1],
        [25, 17, 27, 10,  0, 21,  1, 54]])
Targets:
torch.Size([4, 8])
tensor([[43, 58,  5, 57,  1, 46, 43, 39],
        [53, 56,  1, 58, 46, 39, 58,  1],
        [58,  1, 58, 46, 39, 58,  1, 46],
        [17, 27, 10,  0, 21,  1, 54, 39]])
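
In the output above, each target row is the matching input row shifted by one position, with one new token at the end. The sketch below makes that relationship explicit on toy data: `toy_data` is a hypothetical tensor of consecutive integers (not the Shakespeare tensor), and `get_batch` takes the data as an argument here only to keep the example self-contained.

```python
import torch

torch.manual_seed(0)

# Toy data (assumption): consecutive integers, so every value equals its position
toy_data = torch.arange(100)
block_size = 8
batch_size = 4

def get_batch(data):
    # Random start offsets that leave room for block_size + 1 tokens
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i + block_size] for i in ix])
    y = torch.stack([data[i + 1:i + block_size + 1] for i in ix])
    return x, y

xb, yb = get_batch(toy_data)
print(xb)
print(yb)

# Targets overlap the inputs everywhere except the final column
shifted = torch.equal(xb[:, 1:], yb[:, :-1])
print("targets are inputs shifted by one:", shifted)
```

Because of this overlap, a single `(x, y)` pair of length `block_size` packs `block_size` training examples, one for each context length from 1 to `block_size`.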


### Create a bigram model

In [32]:
class BigramLanguageModel(nn.Module):

    def __init__(self, vocab_size):
        super().__init__()

        # Each token directly reads off the logits for the next token from a lookup table
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)

    def forward(self, idx, targets=None):
        # idx and targets are both (B,T) tensors of integer token ids
        logits = self.token_embedding_table(idx) # (B,T,C)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    def generate(self, idx, max_new_tokens):
        # idx is (B,T) array of indices in the current context
        for _ in range(max_new_tokens):
            # get the predictions
            logits, _ = self(idx)
            # focus only on the last time step
            logits = logits[:,-1,:] # becomes (B,C)
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1) # (B,C)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1) # (B,1)
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1) # (B,T+1)
        return idx
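
The `forward` method above flattens the logits to `(B*T, C)` and the targets to `(B*T)` because `F.cross_entropy` expects the class dimension in second position. As a hedged illustration on random tensors (not the notebook's data), the sketch below shows that moving the class dimension with `transpose` gives the same loss as flattening.

```python
import torch
from torch.nn import functional as F

torch.manual_seed(0)

# Random stand-ins for the model's logits and targets (illustrative only)
B, T, C = 4, 8, 65
logits = torch.randn(B, T, C)
targets = torch.randint(C, (B, T))

# Option 1: flatten batch and time into a single dimension, as in forward()
loss_flat = F.cross_entropy(logits.view(B * T, C), targets.view(B * T))

# Option 2: keep (B, T) but move the class dimension to position 1 -> (B, C, T)
loss_transposed = F.cross_entropy(logits.transpose(1, 2), targets)

print(loss_flat.item(), loss_transposed.item())
```

Both calls average the per-token loss over all `B*T` positions, so the choice between them is a matter of readability.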

In [33]:
m = BigramLanguageModel(vocab_size)

In [34]:
logits, loss = m(xb, yb)

In [35]:
print(logits.shape)
print(loss)

torch.Size([32, 65])
tensor(4.5262, grad_fn=<NllLossBackward0>)
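
A useful reference point for the loss printed above is a model that assigns equal probability to all 65 characters. The sketch below computes that baseline; the untrained embedding table holds random values rather than uniform logits, which would explain why the measured 4.5262 sits somewhat above it.

```python
import math

vocab_size = 65  # size of the vocabulary printed earlier

# Cross-entropy of a model that gives every character probability 1/vocab_size
uniform_loss = -math.log(1.0 / vocab_size)
print(f"{uniform_loss:.4f}")  # 4.1744
```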


In [36]:
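# Start from a single token with id 0 (the newline character) and sample 100 new tokens
context = torch.zeros((1, 1), dtype=torch.long)
print(decode(m.generate(context, max_new_tokens=100)[0].tolist()))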


'JgC.JZWqUkpdtkSpmzjM-,RqzgaN?vC:hgjnAnBZDga-APqGUH!WdCbIb;$DefOYbEvcaKGMmnO'q$KdS-'ZH
.YSqr'X!Q! d;


In [37]:
# create a PyTorch optimizer
optimizer = torch.optim.AdamW(m.parameters(), lr=1e-3)

In [38]:
batch_size = 32

In [45]:
for step in range(1000):
    # sample a batch of data
    xb, yb = get_batch("train")

    # forward pass: compute the loss on this batch
    logits, loss = m(xb, yb)
    # backward pass: clear old gradients, backpropagate, update the parameters
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

    print(loss.item())

2.5532820224761963
2.6917624473571777
2.567890167236328
2.688652992248535
2.6838221549987793
2.567260265350342
2.6408329010009766
2.6442699432373047
2.6120669841766357
2.696449041366577
2.7054357528686523
2.614013433456421
2.549588203430176
2.594839572906494
2.5640158653259277
2.644991159439087
2.689906597137451
2.554537296295166
2.496067762374878
2.6992199420928955
2.6552038192749023
2.7516117095947266
2.798856019973755
2.5983567237854004
2.5689709186553955
2.729954719543457
2.6315274238586426
2.6256508827209473
2.664034366607666
2.640119791030884
2.5638206005096436
2.6674554347991943
2.6244919300079346
2.613607168197632
2.573763847351074
2.6594362258911133
2.5143492221832275
2.63219952583313
2.5816938877105713
2.72271466255188
2.6074681282043457
2.7085769176483154
2.6815378665924072
2.688018560409546
2.524186849594116
2.755908966064453
2.490719795227051
2.6219358444213867
2.5467312335968018
2.6935389041900635
2.6546945571899414
2.5186445713043213
2.5875463485717773
2.643221855163574
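
The per-step losses above move between roughly 2.49 and 2.80 because each value comes from a single random batch of 32 sequences. Averaging over many batches, on both splits, gives a steadier number. The sketch below is one possible way to do that: `estimate_loss`, `eval_iters`, `TinyBigram`, and the random toy data are all hypothetical names and stand-ins, not part of the notebook.

```python
import torch
import torch.nn as nn
from torch.nn import functional as F

torch.manual_seed(0)

# Toy stand-ins (assumptions): random token ids instead of the Shakespeare tensors
vocab_size, block_size, batch_size = 10, 8, 32
toy = torch.randint(vocab_size, (2000,))
splits = {"train": toy[:1800], "val": toy[1800:]}

def get_batch(split):
    data = splits[split]
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i + block_size] for i in ix])
    y = torch.stack([data[i + 1:i + block_size + 1] for i in ix])
    return x, y

class TinyBigram(nn.Module):
    """Reduced bigram model that returns (logits, loss) like the one above."""

    def __init__(self, vocab_size):
        super().__init__()
        self.table = nn.Embedding(vocab_size, vocab_size)

    def forward(self, idx, targets):
        logits = self.table(idx)  # (B, T, C)
        B, T, C = logits.shape
        loss = F.cross_entropy(logits.view(B * T, C), targets.view(B * T))
        return logits, loss

@torch.no_grad()  # no gradients are needed for evaluation
def estimate_loss(model, eval_iters=50):
    """Average the loss over eval_iters random batches for each split."""
    model.eval()
    out = {}
    for split in ("train", "val"):
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            xb, yb = get_batch(split)
            _, loss = model(xb, yb)
            losses[k] = loss.item()
        out[split] = losses.mean().item()
    model.train()
    return out

model = TinyBigram(vocab_size)
losses = estimate_loss(model)
print(losses)
```

Calling such a helper every few hundred steps, instead of printing every batch loss, would also show whether the validation loss tracks the training loss.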


In [47]:
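# Sample 500 new tokens from the trained model, starting from id 0 (the newline character)
context = torch.zeros((1, 1), dtype=torch.long)
print(decode(m.generate(context, max_new_tokens=500)[0].tolist()))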



DUEveninOfo.
P.
QTH:  spises'mu ixaco's:ha'thber d twesthaWh nfanWh &d nqf yp; Mqonal3fowat hthispreldGLI ce oicun akerst:y LI'le;hazthRM:
Y:
MNonowine sat Frme sPat d:
G s d.
FFins,
MYin o oby, cri&oote.forar XR:
V:
Toryhangr-ke w s.
UMPs inPeORET:; bueAnththe,Ru, houlAsld ddf nPUHofo eimery:
we FBetor br ay,

HARDEmpamzur.p wover tos nS:

huLOurknoum l-
TY:

Tom; Coau woE pu-
HELLOfou hivas
Honomor, h E:
Y$DWhou iovCELI s mes ve3herePWLound
Foveeend f s;
QX

TTEwi&JUurd, thlloth Vbomy avPULLc
