In [1]:
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

In [2]:
len(text)

1115393

In [3]:
print(text[:300])

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.

All:
We know't, we know't.

First Citizen:
Let us


In [4]:
chars = sorted(list(set(text)))
vocab_size = len(chars)
print(''.join(chars))
print(vocab_size)


 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
65


In [5]:
token_to_idx = {ch: i for i, ch in enumerate(chars)}
idx_to_token = {i: ch for i, ch in enumerate(chars)}
def encode(text):
    return [token_to_idx.get(ch, 0) for ch in text]

def decode(text_encoded):
    return ''.join([idx_to_token.get(i, '') for i in text_encoded])

encode("Oh Lord")


[27, 46, 1, 24, 53, 56, 42]

In [6]:
decode([27, 46, 1, 24, 53, 56, 42])

'Oh Lord'

In [7]:
import torch
data = torch.tensor(encode(text), dtype=torch.long)
data[:300]

tensor([18, 47, 56, 57, 58,  1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 14, 43, 44,
        53, 56, 43,  1, 61, 43,  1, 54, 56, 53, 41, 43, 43, 42,  1, 39, 52, 63,
         1, 44, 59, 56, 58, 46, 43, 56,  6,  1, 46, 43, 39, 56,  1, 51, 43,  1,
        57, 54, 43, 39, 49,  8,  0,  0, 13, 50, 50, 10,  0, 31, 54, 43, 39, 49,
         6,  1, 57, 54, 43, 39, 49,  8,  0,  0, 18, 47, 56, 57, 58,  1, 15, 47,
        58, 47, 64, 43, 52, 10,  0, 37, 53, 59,  1, 39, 56, 43,  1, 39, 50, 50,
         1, 56, 43, 57, 53, 50, 60, 43, 42,  1, 56, 39, 58, 46, 43, 56,  1, 58,
        53,  1, 42, 47, 43,  1, 58, 46, 39, 52,  1, 58, 53,  1, 44, 39, 51, 47,
        57, 46, 12,  0,  0, 13, 50, 50, 10,  0, 30, 43, 57, 53, 50, 60, 43, 42,
         8,  1, 56, 43, 57, 53, 50, 60, 43, 42,  8,  0,  0, 18, 47, 56, 57, 58,
         1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 18, 47, 56, 57, 58,  6,  1, 63,
        53, 59,  1, 49, 52, 53, 61,  1, 15, 39, 47, 59, 57,  1, 25, 39, 56, 41,
        47, 59, 57,  1, 47, 57,  1, 41, 

In [8]:
n = int(0.9*len(data)) # first 90% will be train, rest val
train_data = data[:n]
val_data = data[n:]

In [9]:
device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')


class Dataset(torch.utils.data.Dataset):
    def __init__(self, data, seq_len):
        self.data = data.to(device)
        self.seq_len = seq_len

    def __len__(self):
        return len(self.data) - self.seq_len

    def __getitem__(self, i):
        x = self.data[i:i+self.seq_len]
        y = self.data[i+1:i+self.seq_len+1]  
        return (x, y)
    
seq_len = 65
train_ds = Dataset(train_data, seq_len)
val_ds = Dataset(val_data, seq_len)

train_ds[0]

(tensor([18, 47, 56, 57, 58,  1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 14, 43, 44,
         53, 56, 43,  1, 61, 43,  1, 54, 56, 53, 41, 43, 43, 42,  1, 39, 52, 63,
          1, 44, 59, 56, 58, 46, 43, 56,  6,  1, 46, 43, 39, 56,  1, 51, 43,  1,
         57, 54, 43, 39, 49,  8,  0,  0, 13, 50, 50], device='cuda:0'),
 tensor([47, 56, 57, 58,  1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 14, 43, 44, 53,
         56, 43,  1, 61, 43,  1, 54, 56, 53, 41, 43, 43, 42,  1, 39, 52, 63,  1,
         44, 59, 56, 58, 46, 43, 56,  6,  1, 46, 43, 39, 56,  1, 51, 43,  1, 57,
         54, 43, 39, 49,  8,  0,  0, 13, 50, 50, 10], device='cuda:0'))

In [18]:
class SimpleModel(torch.nn.Module):
    def __init__(self, vocab_size, embedding_dim):
        super(SimpleModel, self).__init__()
        self.embedding = torch.nn.Embedding(vocab_size, vocab_size)
        
    def forward(self, x):
        x = self.embedding(x)
        return x
    
model = SimpleModel(vocab_size, 128).to(device)
train_loader=torch.utils.data.DataLoader(train_ds, batch_size=32, shuffle=True)

for xb, yb in train_loader:
    print(xb.shape, yb.shape)
    out = model(xb)
    print(out.shape)
    break

torch.Size([32, 65]) torch.Size([32, 65])
torch.Size([32, 65, 65])


In [19]:
def generate_text(model, text, length):
    model.eval()
    with torch.no_grad():
        for i in range(length):
            x = torch.tensor(encode(text), dtype=torch.long).to(device)
            x = x.unsqueeze(0)
            y = model(x)[:,-1,:].argmax()
            text += idx_to_token[y.item()]
    return text

generate_text(model, "Oh Lord", 100)

'Oh LordOg,tlaSkhFC&SkhFC&SkhFC&SkhFC&SkhFC&SkhFC&SkhFC&SkhFC&SkhFC&SkhFC&SkhFC&SkhFC&SkhFC&SkhFC&SkhFC&SkhF'

In [20]:
loss_func = torch.nn.CrossEntropyLoss()
model = SimpleModel(vocab_size, 256)
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)

model.to(device)

def train_epoch(model, train_loader, loss_func, optimizer):
    model.train()
    total_loss = 0
    for xb, yb in train_loader:
        optimizer.zero_grad()
        y_pred = model(xb)
        loss = loss_func(y_pred.view(-1, vocab_size), yb.view(-1))
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss/len(train_loader)

def validate_epoch(model, val_loader, loss_func):
    model.eval()
    total_loss = 0
    with torch.no_grad():
        for xb, yb in val_loader:
            y_pred = model(xb)
            loss = loss_func(y_pred.view(-1, vocab_size), yb.view(-1))
            total_loss += loss.item()
    return total_loss/len(val_loader)

train_loader=torch.utils.data.DataLoader(train_ds, batch_size=512, shuffle=True)
val_loader=torch.utils.data.DataLoader(val_ds, batch_size=512, shuffle=False)

for i in range(3):
    train_loss = train_epoch(model, train_loader, loss_func, optimizer)
    val_loss = validate_epoch(model, val_loader, loss_func)
    print(f'Epoch {i}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}')

Epoch 0, Train Loss: 3.5560, Val Loss: 2.8951
Epoch 1, Train Loss: 2.6581, Val Loss: 2.5551
Epoch 2, Train Loss: 2.4977, Val Loss: 2.5030


In [21]:
generate_text(model, 'Thank you', 100)

'Thank your the the the the the the the the the the the the the the the the the the the the the the the the th'

In [22]:
generate_text(model, "This is a strange repose, to be asleep With eyes wide open; standing, speaking, moving,", 100)

'This is a strange repose, to be asleep With eyes wide open; standing, speaking, moving, the the the the the the the the the the the the the the the the the the the the the the the the the'

In [23]:
for i in range(10):
    train_loss = train_epoch(model, train_loader, loss_func, optimizer)
    val_loss = validate_epoch(model, val_loader, loss_func)
    print(f'Epoch {i}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}')

Epoch 0, Train Loss: 2.4681, Val Loss: 2.4896
Epoch 1, Train Loss: 2.4589, Val Loss: 2.4857
Epoch 2, Train Loss: 2.4552, Val Loss: 2.4842
Epoch 3, Train Loss: 2.4534, Val Loss: 2.4842
Epoch 4, Train Loss: 2.4527, Val Loss: 2.4847
Epoch 5, Train Loss: 2.4523, Val Loss: 2.4853
Epoch 6, Train Loss: 2.4521, Val Loss: 2.4866
Epoch 7, Train Loss: 2.4520, Val Loss: 2.4871
Epoch 8, Train Loss: 2.4520, Val Loss: 2.4881
Epoch 9, Train Loss: 2.4520, Val Loss: 2.4887


In [24]:
generate_text(model, "This is a strange repose, to be asleep With eyes wide open; standing, speaking, moving,", 100)

'This is a strange repose, to be asleep With eyes wide open; standing, speaking, moving, the the the the the the the the the the the the the the the the the the the the the the the the the'

In [39]:
generate_text(model, "a", 100)

'and the the the the the the the the the the the the the the the the the the the the the the the the t'