In [69]:
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

In [70]:
len(text)

1115393

In [71]:
print(text[:300])

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.

All:
We know't, we know't.

First Citizen:
Let us


In [72]:
chars = sorted(list(set(text)))
vocab_size = len(chars)
print(''.join(chars))
print(vocab_size)


 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
65


In [73]:
token_to_idx = {ch: i for i, ch in enumerate(chars)}
idx_to_token = {i: ch for i, ch in enumerate(chars)}
def encode(text):
    return [token_to_idx.get(ch, 0) for ch in text]

def decode(text_encoded):
    return ''.join([idx_to_token.get(i, '') for i in text_encoded])

encode("Oh Lord")


[27, 46, 1, 24, 53, 56, 42]

In [74]:
decode([27, 46, 1, 24, 53, 56, 42])

'Oh Lord'

In [75]:
import torch
data = torch.tensor(encode(text), dtype=torch.long)
data[:300]

tensor([18, 47, 56, 57, 58,  1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 14, 43, 44,
        53, 56, 43,  1, 61, 43,  1, 54, 56, 53, 41, 43, 43, 42,  1, 39, 52, 63,
         1, 44, 59, 56, 58, 46, 43, 56,  6,  1, 46, 43, 39, 56,  1, 51, 43,  1,
        57, 54, 43, 39, 49,  8,  0,  0, 13, 50, 50, 10,  0, 31, 54, 43, 39, 49,
         6,  1, 57, 54, 43, 39, 49,  8,  0,  0, 18, 47, 56, 57, 58,  1, 15, 47,
        58, 47, 64, 43, 52, 10,  0, 37, 53, 59,  1, 39, 56, 43,  1, 39, 50, 50,
         1, 56, 43, 57, 53, 50, 60, 43, 42,  1, 56, 39, 58, 46, 43, 56,  1, 58,
        53,  1, 42, 47, 43,  1, 58, 46, 39, 52,  1, 58, 53,  1, 44, 39, 51, 47,
        57, 46, 12,  0,  0, 13, 50, 50, 10,  0, 30, 43, 57, 53, 50, 60, 43, 42,
         8,  1, 56, 43, 57, 53, 50, 60, 43, 42,  8,  0,  0, 18, 47, 56, 57, 58,
         1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 18, 47, 56, 57, 58,  6,  1, 63,
        53, 59,  1, 49, 52, 53, 61,  1, 15, 39, 47, 59, 57,  1, 25, 39, 56, 41,
        47, 59, 57,  1, 47, 57,  1, 41, 

In [76]:
n = int(0.9*len(data)) # first 90% will be train, rest val
train_data = data[:n]
val_data = data[n:]

In [80]:
device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')


class Dataset(torch.utils.data.Dataset):
    def __init__(self, data, seq_len):
        self.data = data.to(device)
        self.seq_len = seq_len

    def __len__(self):
        return len(self.data) - self.seq_len

    def __getitem__(self, i):
        x = self.data[i:i+self.seq_len]
        y= self.data[i+self.seq_len+1] if i!=len(self.data)-self.seq_len-1 else self.data[0]
        return (x, y)
    
seq_len = 65
train_ds = Dataset(train_data, seq_len)
val_ds = Dataset(val_data, seq_len)

train_ds[0]

(tensor([18, 47, 56, 57, 58,  1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 14, 43, 44,
         53, 56, 43,  1, 61, 43,  1, 54, 56, 53, 41, 43, 43, 42,  1, 39, 52, 63,
          1, 44, 59, 56, 58, 46, 43, 56,  6,  1, 46, 43, 39, 56,  1, 51, 43,  1,
         57, 54, 43, 39, 49,  8,  0,  0, 13, 50, 50], device='cuda:0'),
 tensor(0, device='cuda:0'))

In [102]:
class SimpleModel(torch.nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim):
        super(SimpleModel, self).__init__()
        self.embedding = torch.nn.Embedding(vocab_size, embedding_dim)
        self.rnn = torch.nn.LSTM(embedding_dim, hidden_dim, batch_first=True, num_layers=4)
        self.fc = torch.nn.Linear(hidden_dim, vocab_size)
        
    def forward(self, x):
        x = self.embedding(x)
        x, _ = self.rnn(x)
        x = self.fc(x) 
        # we want to return the last element of the sequence
        # batch_size, seq_len, vocab_size
        x = x[:, -1, :]
        # batch_size, vocab_size
        return x
    
model = SimpleModel(vocab_size, 64, 128).to(device)
train_loader=torch.utils.data.DataLoader(train_ds, batch_size=32, shuffle=True)

for xb, yb in train_loader:
    print(xb.shape, yb.shape)
    out = model(xb)
    print(out.shape)
    break

torch.Size([32, 65]) torch.Size([32])
torch.Size([32, 65])


In [103]:
def generate_text(model, text, length):
    model.eval()
    with torch.no_grad():
        for i in range(length):
            x = torch.tensor(encode(text), dtype=torch.long).to(device)
            x = x.unsqueeze(0)
            y = model(x).argmax()
            text += idx_to_token[y.item()]
    return text

generate_text(model, "Oh Lord", 100)

'Oh Lordllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllll'

In [105]:
loss_func = torch.nn.CrossEntropyLoss()
model = SimpleModel(vocab_size, 256, 128)
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)

model.to(device)

def train_epoch(model, train_loader, loss_func, optimizer):
    model.train()
    total_loss = 0
    for xb, yb in train_loader:
        optimizer.zero_grad()
        y_pred = model(xb)
        loss = loss_func(y_pred.view(-1, vocab_size), yb.view(-1))
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss/len(train_loader)

def validate_epoch(model, val_loader, loss_func):
    model.eval()
    total_loss = 0
    with torch.no_grad():
        for xb, yb in val_loader:
            y_pred = model(xb)
            loss = loss_func(y_pred.view(-1, vocab_size), yb.view(-1))
            total_loss += loss.item()
    return total_loss/len(val_loader)

train_loader=torch.utils.data.DataLoader(train_ds, batch_size=512, shuffle=True)
val_loader=torch.utils.data.DataLoader(val_ds, batch_size=512, shuffle=False)

for i in range(10):
    train_loss = train_epoch(model, train_loader, loss_func, optimizer)
    val_loss = validate_epoch(model, val_loader, loss_func)
    print(f'Epoch {i}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}')

Epoch 0, Train Loss: 2.7758, Val Loss: 2.6150
Epoch 1, Train Loss: 2.4351, Val Loss: 2.5075
Epoch 2, Train Loss: 2.3127, Val Loss: 2.4280
Epoch 3, Train Loss: 2.2369, Val Loss: 2.3836
Epoch 4, Train Loss: 2.1823, Val Loss: 2.3373
Epoch 5, Train Loss: 2.1393, Val Loss: 2.3187
Epoch 6, Train Loss: 2.1031, Val Loss: 2.2989
Epoch 7, Train Loss: 2.0746, Val Loss: 2.2899
Epoch 8, Train Loss: 2.0489, Val Loss: 2.2762
Epoch 9, Train Loss: 2.0270, Val Loss: 2.2567


In [106]:
generate_text(model, 'Thank you', 100)

'Thank you\nhv otete o od ntede o od ntede o oe o otete\no od ntede o od ntede o oe ntedy o oe\nnto o oe ntediis '

In [107]:
for i in range(10):
    train_loss = train_epoch(model, train_loader, loss_func, optimizer)
    val_loss = validate_epoch(model, val_loader, loss_func)
    print(f'Epoch {i}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}')

Epoch 0, Train Loss: 2.0083, Val Loss: 2.2508
Epoch 1, Train Loss: 1.9905, Val Loss: 2.2368
Epoch 2, Train Loss: 1.9749, Val Loss: 2.2300
Epoch 3, Train Loss: 1.9612, Val Loss: 2.2100
Epoch 4, Train Loss: 1.9484, Val Loss: 2.2154
Epoch 5, Train Loss: 1.9351, Val Loss: 2.2138
Epoch 6, Train Loss: 1.9241, Val Loss: 2.2224
Epoch 7, Train Loss: 1.9139, Val Loss: 2.2047
Epoch 8, Train Loss: 1.9029, Val Loss: 2.2016


KeyboardInterrupt: 

In [108]:
generate_text(model, 'Thank you', 100)

'Thank yous:Tnths\note o ode o ode o od o o o o ofin o ofin\notetie frmntin frmntin frmnt o o o o o o o o\nofet o'

In [None]:
for i in range(10):
    train_loss = train_epoch(model, train_loader, loss_func, optimizer)
    val_loss = validate_epoch(model, val_loader, loss_func)
    print(f'Epoch {i}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}')

'E  tord  tonntou haartte '

In [None]:
generate_text(model, 'Thank you', 100)