<a href="https://colab.research.google.com/github/harryypham/MyMLPractice/blob/main/practice/rnn_shakespeare_draft.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
!pip install tqdm

--2024-06-25 10:37:49--  https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1115394 (1.1M) [text/plain]
Saving to: ‘input.txt.1’


2024-06-25 10:37:49 (127 MB/s) - ‘input.txt.1’ saved [1115394/1115394]



In [3]:
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torchvision import datasets
from torch.utils.data import DataLoader


In [37]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
  device = "cuda"
else:
  device = "cpu"

# Model architecture.
emb_dim = 256        # width of each character embedding
hidden_size = 1024   # LSTM hidden state width
num_layers = 3       # number of stacked LSTM layers

# Data and optimisation.
sequence_length = 512
batch_size = 64
lr = 3e-4
num_epochs = 5

In [5]:
# Load the whole corpus into memory as a single string.
with open('input.txt', 'r', encoding='utf-8') as f:
  text = f.read()

# The vocabulary is every distinct character of the corpus, in sorted order.
chars = sorted(set(text))
vocab_size = len(chars)

# Show the vocabulary on one line, followed by its size on the next.
print(''.join(chars), vocab_size, sep='\n')


 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
65


In [6]:
# Lookup tables between characters and their integer ids (position in `chars`).
itos = dict(enumerate(chars))
stoi = {ch: i for i, ch in itos.items()}


def encode(s):
  """Return the list of integer ids for the characters of string `s`."""
  return [stoi[c] for c in s]


def decode(l):
  """Return the string spelled by the sequence of integer ids `l`."""
  return ''.join(itos[i] for i in l)

In [7]:
# The entire corpus as one 1-D tensor of character ids.
data = torch.tensor(encode(text), dtype=torch.long)

# Window length used for sampling. It is taken from the configuration cell so
# that the hyperparameters live in one place (`batch_size` also comes from
# there instead of being redefined here).
block_size = sequence_length


def get_batch(num_sequences=None, window=None):
  """Sample a random batch of (input, target) windows from `data`.

  Args:
    num_sequences: number of windows to draw; defaults to `batch_size`.
    window: length of each window; defaults to `block_size`.

  Returns:
    A pair `(x, y)` of long tensors, each of shape `(num_sequences, window)`,
    where `y` is `x` shifted one position to the right in the corpus.

  Raises:
    ValueError: if the corpus is too short to hold one window plus its
      shifted target.
  """
  num_sequences = batch_size if num_sequences is None else num_sequences
  window = block_size if window is None else window
  if len(data) <= window:
    raise ValueError(
        f"corpus has {len(data)} tokens but a window of {window} needs at least {window + 1}")

  # Start offsets are drawn so that the shifted target window still fits.
  starts = torch.randint(len(data) - window, (num_sequences,)).tolist()
  x = torch.stack([data[i:i+window] for i in starts])
  y = torch.stack([data[i+1:i+window+1] for i in starts])
  return x, y

xb, yb = get_batch()

In [38]:
class LSTM(nn.Module):
  """Character-level language model: embedding -> stacked LSTM -> linear head.

  Args:
    sequence_length: accepted for call-site compatibility; not used by the model.
    hidden_size: width of the LSTM hidden state.
    emb_dim: width of the token embeddings.
    num_layers: number of stacked LSTM layers.
    vocab_size: number of distinct tokens (embedding rows and output logits).
  """

  def __init__(self, sequence_length, hidden_size, emb_dim, num_layers, vocab_size):
    super().__init__()
    self.embedding_table = nn.Embedding(vocab_size, emb_dim)
    self.hidden_size = hidden_size
    self.num_layers = num_layers
    self.lstm = nn.LSTM(emb_dim, hidden_size, num_layers, batch_first=True)
    self.fc = nn.Linear(hidden_size, vocab_size)

  def forward(self, x, targets=None):
    """Run the model on a batch of token ids.

    Args:
      x: long tensor of token ids, shape `(B, T)`.
      targets: optional long tensor with `B*T` elements (e.g. shape `(B, T)`)
        holding the next-token id for every position of `x`.

    Returns:
      `(logits, loss)`. With `targets`, `logits` has shape `(B*T, vocab)` and
      `loss` is the mean cross-entropy over all positions. Without `targets`,
      `logits` has shape `(B, vocab)` (last time step only) and `loss` is None.
    """
    x = self.embedding_table(x)
    # No explicit (h0, c0): nn.LSTM then starts from zero states created on the
    # input's own device, so the module does not depend on a global `device`.
    out, _ = self.lstm(x)
    if targets is not None:
      B, T, C = out.shape
      logits = self.fc(out.reshape(B*T, C))
      # reshape (not view) also accepts non-contiguous targets.
      loss = F.cross_entropy(logits, targets.reshape(-1))
    else:
      logits = self.fc(out[:, -1, :])
      loss = None
    return logits, loss

  @torch.no_grad()
  def generate(self, max_new_tokens=400):
    """Sample a sequence of token ids from the model.

    Sampling starts from the single token id 0 and draws each next token from
    the softmax distribution over the model's logits.

    Args:
      max_new_tokens: number of tokens to sample after the start token.

    Returns:
      Long tensor of shape `(1, max_new_tokens + 1)` on the model's device,
      including the start token.
    """
    model_device = next(self.parameters()).device
    idx = torch.zeros((1, 1), dtype=torch.long, device=model_device)

    was_training = self.training
    self.eval()
    try:
      # Feed one token at a time and carry the recurrent state forward; this
      # matches re-encoding the whole prefix from a zero state at every step,
      # but costs one LSTM step per token instead of a full pass.
      token, state = idx, None
      for _ in range(max_new_tokens):
        out, state = self.lstm(self.embedding_table(token), state)
        probs = F.softmax(self.fc(out[:, -1, :]), dim=-1)
        token = torch.multinomial(probs, num_samples=1)
        idx = torch.cat([idx, token], dim=1)
    finally:
      # Leave the module in whatever mode the caller had it in.
      self.train(was_training)
    return idx

In [11]:
def train(model, optimizer, data_loader, num_epochs, device):
  """Train `model` on `data_loader` for `num_epochs` epochs.

  Args:
    model: module called as `model(inputs, targets)` and returning
      `(logits, loss)`.
    optimizer: optimizer over the model's parameters.
    data_loader: iterable of `(inputs, targets)` batches.
    num_epochs: number of passes over `data_loader`.
    device: device the batches are moved to before the forward pass.

  Returns:
    List with the loss of every optimisation step, in order.
  """
  model.train()
  losses = []
  running_total = 0.0
  for epoch in range(1, num_epochs+1):
    print(f"Epoch {epoch}: ")
    pbar = tqdm(data_loader, leave=True)
    for inputs, targets in pbar:
      inputs = inputs.to(device)
      # Drop a singleton middle axis only when one is actually present;
      # squeezing a (B, 1) batch would remove its time axis.
      if inputs.dim() == 3:
        inputs = inputs.squeeze(1)
      targets = targets.to(device)

      # The model only computes a loss when it is given the targets.
      logits, loss = model(inputs, targets)

      optimizer.zero_grad()
      loss.backward()
      optimizer.step()

      step_loss = loss.item()
      losses.append(step_loss)
      running_total += step_loss

      # Progress bar shows the mean loss over every step taken so far.
      pbar.set_postfix({"Loss": round(running_total/len(losses), 4)})
  return losses

@torch.no_grad()
def check_accuracy(model, data_loader, device):
  """Print and return the model's next-token accuracy on `data_loader`.

  The model is called without targets, so it returns logits for the last
  time step only; those predictions are compared with the last target of
  each sequence.

  Args:
    model: module called as `model(inputs)` and returning `(logits, loss)`.
    data_loader: iterable of `(inputs, targets)` batches.
    device: device the batches are moved to before the forward pass.

  Returns:
    Accuracy as a percentage in [0, 100]; 0.0 when the loader yields nothing.
  """
  correct = total = 0

  model.eval()
  for inputs, targets in data_loader:
    inputs = inputs.to(device)
    # Drop a singleton middle axis only when one is actually present.
    if inputs.dim() == 3:
      inputs = inputs.squeeze(1)
    targets = targets.to(device)
    # assumes per-position targets are laid out as (B, T) — keep the final one
    if targets.dim() > 1:
      targets = targets[:, -1]

    # The model returns a (logits, loss) pair; only the logits are needed here.
    logits, _ = model(inputs)

    preds = logits.argmax(dim=1)
    correct += (preds == targets).sum().item()
    total += targets.size(0)

  # Guard against an empty loader instead of dividing by zero.
  accuracy = correct/total*100 if total else 0.0
  print(f"Accuracy: {accuracy:.2f}")
  model.train()
  return accuracy


In [39]:
model = LSTM(sequence_length, hidden_size, emb_dim, num_layers, vocab_size).to(device)
optimizer = optim.Adam(model.parameters(), lr=lr)

# Training schedule.
num_steps = 1500      # total optimisation steps
log_every = 50        # print the running loss every this many steps
decay_every = 1000    # drop the learning rate at non-zero multiples of this step
decayed_lr = lr / 10  # learning rate after the drop, derived from the base rate

losses = []
# `step` rather than `iter`, so the builtin iter() stays usable in later cells.
for step in range(num_steps):
  inputs, targets = get_batch()
  inputs = inputs.to(device)
  targets = targets.to(device)

  logits, loss = model(inputs, targets)
  optimizer.zero_grad()
  loss.backward()
  losses.append(loss.item())
  if step % log_every == 0:
    # Reported value is the mean loss over every step so far, not the last batch.
    print("iter", step, sum(losses)/len(losses))

  if step != 0 and step % decay_every == 0:
    for g in optimizer.param_groups:
      g['lr'] = decayed_lr

  optimizer.step()

iter 0 4.177000999450684
iter 50 3.320663475522808
iter 100 2.978285305570848
iter 150 2.718244893661398
iter 200 2.5318225100265805
iter 250 2.3888451421403314
iter 300 2.2735725494714276
iter 350 2.1782456979452713
iter 400 2.097607871243484
iter 450 2.0282790906677755
iter 500 1.9677050760882104
iter 550 1.9146102560842968
iter 600 1.867792470879642
iter 650 1.8251316135380125
iter 700 1.786496987186383
iter 750 1.7513975071684815
iter 800 1.719051100043917
iter 850 1.6888270036314126
iter 900 1.6607740118553846
iter 950 1.6340923724239682
iter 1000 1.609220661721625
iter 1050 1.5850331475913921
iter 1100 1.562580522356198
iter 1150 1.5417856807609314
iter 1200 1.522591334397747
iter 1250 1.50456757720807
iter 1300 1.487849196134211
iter 1350 1.472332255916186
iter 1400 1.4576667875668392
iter 1450 1.4438576870832174


In [40]:
# Sample from the trained model, take the single sequence in the batch as a
# plain list of ids, and print it as text.
sampled_ids = model.generate()[0].tolist()
print(decode(sampled_ids))


Be sits fair to touch the Clarence of his bent?

SICINIUS:
Dar'sness grows,
A beg of vilaging! I am teldow, my father.

VELERIA:
Prey, into thy clouds! what is it not?

First Gentleman:
Aming but it ears you ever since the word
Upon that broads 'long by the instrument:
They that lay hath something it as time
Swell commending pretition,
And ask along to me.

Second Murderer:
Speak no me, did I resi
