In [1]:
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

In [2]:
len(text)

1115393

In [3]:
print(text[:300])

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.

All:
We know't, we know't.

First Citizen:
Let us


In [4]:
chars = sorted(list(set(text)))
vocab_size = len(chars)
print(''.join(chars))
print(vocab_size)


 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
65


In [5]:
token_to_idx = {ch: i for i, ch in enumerate(chars)}
idx_to_token = {i: ch for i, ch in enumerate(chars)}

def encode(text):
    return [token_to_idx.get(ch, 0) for ch in text]

def decode(text_encoded):
    return ''.join([idx_to_token.get(i, '') for i in text_encoded])

encode("Oh Lord")


[27, 46, 1, 24, 53, 56, 42]

In [6]:
decode([27, 46, 1, 24, 53, 56, 42])

'Oh Lord'

In [7]:
import torch
data = torch.tensor(encode(text), dtype=torch.long)
data[:300]

tensor([18, 47, 56, 57, 58,  1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 14, 43, 44,
        53, 56, 43,  1, 61, 43,  1, 54, 56, 53, 41, 43, 43, 42,  1, 39, 52, 63,
         1, 44, 59, 56, 58, 46, 43, 56,  6,  1, 46, 43, 39, 56,  1, 51, 43,  1,
        57, 54, 43, 39, 49,  8,  0,  0, 13, 50, 50, 10,  0, 31, 54, 43, 39, 49,
         6,  1, 57, 54, 43, 39, 49,  8,  0,  0, 18, 47, 56, 57, 58,  1, 15, 47,
        58, 47, 64, 43, 52, 10,  0, 37, 53, 59,  1, 39, 56, 43,  1, 39, 50, 50,
         1, 56, 43, 57, 53, 50, 60, 43, 42,  1, 56, 39, 58, 46, 43, 56,  1, 58,
        53,  1, 42, 47, 43,  1, 58, 46, 39, 52,  1, 58, 53,  1, 44, 39, 51, 47,
        57, 46, 12,  0,  0, 13, 50, 50, 10,  0, 30, 43, 57, 53, 50, 60, 43, 42,
         8,  1, 56, 43, 57, 53, 50, 60, 43, 42,  8,  0,  0, 18, 47, 56, 57, 58,
         1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 18, 47, 56, 57, 58,  6,  1, 63,
        53, 59,  1, 49, 52, 53, 61,  1, 15, 39, 47, 59, 57,  1, 25, 39, 56, 41,
        47, 59, 57,  1, 47, 57,  1, 41, 

In [8]:
n = int(0.9*len(data)) # first 90% will be train, rest val
train_data = data[:n]
val_data = data[n:]

In [9]:
device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')

class Dataset(torch.utils.data.Dataset):
    def __init__(self, data, seq_len, block_size=128):
        self.data = data.to(device)
        self.seq_len = seq_len
        self.block_size = block_size
        self.sos_token = torch.tensor([token_to_idx[" "]], dtype=torch.long).to(device)

    def __len__(self):
        return (len(self.data) - self.seq_len) // self.block_size

    def __getitem__(self, i):
        x = self.data[i * self.block_size: (i + 1) * self.block_size]
        x = torch.cat([self.sos_token, x])
        y = self.data[i * self.block_size + 1: (i + 1) * self.block_size + 1]
        y = torch.cat([self.sos_token, y])
        return x, y
    
seq_len = 768
train_ds = Dataset(train_data, seq_len)
val_ds = Dataset(val_data, seq_len)

len(train_ds), len(val_ds)

(7836, 865)

In [10]:
import torch
import math

class PositionalEncoding(torch.nn.Module):
    def __init__(self, d_model, max_len=5000):
        super(PositionalEncoding, self).__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(0), :]
        return x

class SimpleModel(torch.nn.Module):
    def __init__(self, vocab_size, embedding_dim):
        super(SimpleModel, self).__init__()
        self.embedding_dim = embedding_dim
        self.embedding = torch.nn.Embedding(vocab_size, embedding_dim)
        self.pos_encoder = PositionalEncoding(embedding_dim)
        self.encoder_layer = torch.nn.TransformerEncoderLayer(d_model=embedding_dim, nhead=8, batch_first=True)
        self.encoder = torch.nn.TransformerEncoder(self.encoder_layer, num_layers=1)
        self.decoder_layer = torch.nn.TransformerDecoderLayer(d_model=embedding_dim, nhead=8, batch_first=True)
        self.decoder = torch.nn.TransformerDecoder(self.decoder_layer, num_layers=1)
        self.fc = torch.nn.Linear(embedding_dim, vocab_size)
        self.dropout = torch.nn.Dropout(0.2)

    def forward(self, src, tgt=None):
        src = self.embedding(src) * math.sqrt(self.embedding_dim)
        src = self.pos_encoder(src) 
        memory = self.encoder(src)
 
        if tgt is None:
            out = self.decoder(src, memory)
            out=torch.nn.functional.relu(out)
            out = self.fc(out)
            return out
        tgt = self.embedding(tgt) * math.sqrt(self.embedding_dim)
        out = self.decoder(tgt, memory)
        out=torch.nn.functional.relu(out)
        out=self.dropout(out)
        out=self.fc(out)
        return out

model = SimpleModel(vocab_size, 128).to(device)
train_loader=torch.utils.data.DataLoader(train_ds, batch_size=32, shuffle=True)

for xb, yb in train_loader:
    print(xb.shape, yb.shape)
    print("xb end:", xb[-1], "yb end:", yb[-1])
    out = model(xb, yb)
    print(out.shape)
    break

torch.Size([32, 129]) torch.Size([32, 129])
xb end: tensor([ 1,  1, 54, 50, 39, 41, 43,  1, 61, 47, 58, 46,  1, 61, 47, 57, 46, 47,
        52, 45,  6,  0, 27, 56,  1, 58, 46, 39, 58,  1, 58, 46, 43,  1, 56, 43,
        57, 53, 50, 59, 58, 43,  1, 39, 41, 58, 47, 52, 45,  1, 53, 44,  1, 63,
        53, 59, 56,  1, 40, 50, 53, 53, 42,  0, 15, 53, 59, 50, 42,  1, 46, 39,
        60, 43,  1, 39, 58, 58, 39, 47, 52,  5, 42,  1, 58, 46, 43,  1, 43, 44,
        44, 43, 41, 58,  1, 53, 44,  1, 63, 53, 59, 56,  1, 53, 61, 52,  1, 54,
        59, 56, 54, 53, 57, 43,  6,  0, 35, 46, 43, 58, 46, 43, 56,  1, 63, 53,
        59,  1, 46], device='cuda:0') yb end: tensor([ 1, 54, 50, 39, 41, 43,  1, 61, 47, 58, 46,  1, 61, 47, 57, 46, 47, 52,
        45,  6,  0, 27, 56,  1, 58, 46, 39, 58,  1, 58, 46, 43,  1, 56, 43, 57,
        53, 50, 59, 58, 43,  1, 39, 41, 58, 47, 52, 45,  1, 53, 44,  1, 63, 53,
        59, 56,  1, 40, 50, 53, 53, 42,  0, 15, 53, 59, 50, 42,  1, 46, 39, 60,
        43,  1, 39, 58

In [11]:
def generate_text(model, text, length):
    model.eval()
    with torch.no_grad():
        # Initialisieren Sie die Eingabe mit einem Start-Token oder der kodierten Eingabe
        input_data = torch.tensor(encode(text), dtype=torch.long, device=device).unsqueeze(0)
        
        # Bereiten Sie eine leere Sequenz für die Decoder-Eingabe vor (oder verwenden Sie ein SOS-Token)
        decoder_input = torch.tensor([[token_to_idx[" "]]], dtype=torch.long, device=device)  # SOS_TOKEN ersetzen durch den tatsächlichen Wert

        for i in range(length):
            # Verwenden Sie die aktuelle Decoder-Eingabe für die Vorhersage
            out = model(input_data, decoder_input)
            last_token_logits = out[:, -1, :]
            last_token_prob = torch.softmax(last_token_logits, dim=-1)
            
            # print top 5 tokens
            values, indices = torch.topk(last_token_prob, 5)
            for v, i in zip(values[0], indices[0]):
                print(f"{decode([i.item()])} ({v.item():.4f})", end=", ")
            print()
            predicted_token = torch.argmax(last_token_prob, dim=1).unsqueeze(1)
            
            # Aktualisieren Sie die Decoder-Eingabe mit dem neu vorhergesagten Token
            decoder_input = torch.cat([decoder_input, predicted_token], dim=1)
            input_data = torch.cat([input_data, predicted_token], dim=1)
            
    print(decoder_input )
    return decode(decoder_input.cpu().numpy()[0])  # Ignorieren Sie das SOS-Token bei der Decodierung

generate_text(model, "Oh Lord", 3)


x (0.0455), Y (0.0328), g (0.0305), j (0.0304), A (0.0271), 
s (0.0461), a (0.0382), i (0.0336), u (0.0322), V (0.0294), 
Y (0.0323), g (0.0311), ! (0.0304), a (0.0268), x (0.0257), 
tensor([[ 1, 62, 57, 37]], device='cuda:0')


' xsY'

In [15]:
from tqdm.notebook import tqdm

loss_func = torch.nn.CrossEntropyLoss()
model = SimpleModel(vocab_size, 1024)
optimizer = torch.optim.AdamW(model.parameters(), lr=0.00001)

model.to(device)

def train_epoch(model, train_loader, loss_func, optimizer):
    model.train()
    total_loss = 0
    for xb, yb in train_loader:
        optimizer.zero_grad()
        y_pred = model(xb, yb)
        loss = loss_func(y_pred.view(-1, vocab_size), yb.view(-1))
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss/len(train_loader)

def validate_epoch(model, val_loader, loss_func):
    model.eval()
    total_loss = 0

    with torch.no_grad():
        for xb, yb in val_loader:
            y_pred = model(xb, yb)
            loss = loss_func(y_pred.view(-1, vocab_size), yb.view(-1))
            total_loss += loss.item()
    return total_loss/len(val_loader)

train_loader=torch.utils.data.DataLoader(train_ds, batch_size=128, shuffle=True)
val_loader=torch.utils.data.DataLoader(val_ds, batch_size=128, shuffle=False)

train_losses = []
val_losses = []

for i in range(5000):
    train_loss = train_epoch(model, train_loader, loss_func, optimizer)
    train_losses.append(train_loss)
    if i % 10 == 0:
        val_loss = validate_epoch(model, val_loader, loss_func)
        val_losses.append(val_loss)
        print(f'Epoch {i}, train_loss: {train_loss}, val_loss: {val_loss} sample_output:', generate_text(model, "Oh Lor", 15))
    else:
        val_losses.append(val_losses[-1])

import matplotlib.pyplot as plt
plt.plot(train_losses, label='train loss')
plt.plot(val_losses, label='val loss')
plt.legend()
plt.show()

  (0.1110), e (0.0334), s (0.0306), t (0.0279), , (0.0265), 
  (0.1124), e (0.0346), s (0.0308), t (0.0283), , (0.0271), 
  (0.1111), e (0.0353), s (0.0307), t (0.0285), , (0.0276), 
  (0.1094), e (0.0356), s (0.0306), t (0.0287), , (0.0279), 
  (0.1079), e (0.0357), s (0.0305), t (0.0289), , (0.0280), 
  (0.1067), e (0.0358), s (0.0304), t (0.0290), , (0.0281), 
  (0.1055), e (0.0359), s (0.0303), t (0.0290), , (0.0282), 
  (0.1045), e (0.0359), s (0.0302), t (0.0291), d (0.0283), 
  (0.1036), e (0.0359), s (0.0301), t (0.0291), d (0.0285), 
  (0.1029), e (0.0359), s (0.0301), t (0.0291), d (0.0287), 
  (0.1022), e (0.0359), s (0.0300), t (0.0292), d (0.0289), 
  (0.1016), e (0.0358), s (0.0299), t (0.0292), d (0.0291), 
  (0.1011), e (0.0358), s (0.0298), t (0.0292), d (0.0292), 
  (0.1006), e (0.0358), s (0.0298), d (0.0293), t (0.0293), 
  (0.1001), e (0.0358), s (0.0297), d (0.0294), t (0.0293), 
tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], device='cuda:0')
Epoch 0, 

In [None]:
generate_text(model, 'Thank you', 5)

'Thank youuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuu'

In [None]:
generate_text(model, "This is a strange repose, to be asleep With eyes wide open; standing, speaking, moving,", 5)

'This is a strange repose, to be asleep With eyes wide open; standing, speaking, moving,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,'

In [None]:
for i in range(100):
    train_loss = train_epoch(model, train_loader, loss_func, optimizer)
    train_losses.append(train_loss)
    if i % 10 == 0:
        val_loss = validate_epoch(model, val_loader, loss_func)
        val_losses.append(val_loss)
        print(f'Epoch {i}, train_loss: {train_loss}, val_loss: {val_loss}')
    else:
        val_losses.append(val_losses[-1])

import matplotlib.pyplot as plt
plt.plot(train_losses, label='train loss')
plt.plot(val_losses, label='val loss')
plt.legend()
plt.show()

Epoch 0, Train Loss: 0.0042, Val Loss: 0.0023
Epoch 1, Train Loss: 0.0031, Val Loss: 0.0018
Epoch 2, Train Loss: 0.0026, Val Loss: 0.0015


In [None]:
generate_text(model, 'Thank you', 100)

In [None]:
generate_text(model, "This is a strange repose, to be asleep With eyes wide open; standing, speaking, moving,", 100)

'This is a strange repose, to be asleep With eyes wide open; standing, speaking, moving,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,'