# Language modeling with PyTorch

Import necessary modules:

In [1]:
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
from torch.distributions.categorical import Categorical
from torch.utils.data import Dataset, DataLoader

In [2]:
# Prefer the GPU when one is visible to PyTorch; otherwise run on the CPU.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'
device

'cuda'

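Text preprocessing: read the Project Gutenberg books, collapse all whitespace into single spaces, and keep only the text between the `START OF THE PROJECT GUTENBERG` and `END OF THE PROJECT GUTENBERG` markers: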

In [3]:
text = ''

files = [
    '1056-0.txt', 
    '1074-0.txt',
    '215-0.txt'
]

START_MARKER = 'START OF THE PROJECT GUTENBERG'
END_MARKER = 'END OF THE PROJECT GUTENBERG'

for file_path in files:
    with open(file_path, 'r', encoding='utf8') as f:
        # Collapse every whitespace run (newlines, indentation) into one space.
        file_text = ' '.join(f.read().split())

    # str.find returns -1 when a marker is absent. Slicing with -1 would silently
    # drop the whole book (missing START) or its last character (missing END),
    # so fall back to the beginning / end of the text instead.
    start_index = file_text.find(START_MARKER)
    if start_index == -1:
        start_index = 0
    # Only look for the END marker after the START marker.
    end_index = file_text.find(END_MARKER, start_index)
    if end_index == -1:
        end_index = len(file_text)

    text += file_text[start_index:end_index]

chars = set(text)

print(f'Text length: {len(text)}')
print(f'Unique characters: {len(chars)}')

Text length: 1340256
Unique characters: 97


Map characters to integers:

In [4]:
# Assign each character an integer id by its position in sorted order;
# char_array is the inverse lookup (id -> character).
chars_sorted = sorted(chars)
char_to_int = {ch: idx for idx, ch in enumerate(chars_sorted)}
char_array = np.array(chars_sorted)

# Encode the whole corpus as a 1-D int32 array of character ids.
text_encoded = np.fromiter(
    (char_to_int[ch] for ch in text), dtype=np.int32, count=len(text)
)
print(f'Encoded text shape: {text_encoded.shape}')

Encoded text shape: (1340256,)


In [5]:
message_to_encode = 'The Sea-Wolf'

# Round-trip a short message: characters -> ids -> characters.
encoded_message = [char_to_int[symbol] for symbol in message_to_encode]
print(f'{message_to_encode} ===> {encoded_message}')

decoded_message = ''.join(char_array[idx] for idx in encoded_message)
print(f'REVERSE\n{encoded_message} ===> {decoded_message}')

The Sea-Wolf ===> [44, 61, 58, 0, 43, 58, 54, 10, 47, 68, 65, 59]
REVERSE
[44, 61, 58, 0, 43, 58, 54, 10, 47, 68, 65, 59] ===> The Sea-Wolf


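Split the encoded text into overlapping chunks of `seq_len + 1` characters and wrap them in a `Dataset`. Each item's input is the first `seq_len` characters of a chunk, and its target is the same window shifted one character to the right: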

In [6]:
seq_len = 40
chunk_size = seq_len + 1
# One window of `chunk_size` consecutive ids per start index. The last valid
# start index is len(text_encoded) - chunk_size, hence the +1 in the range.
text_chunks = [text_encoded[i:i + chunk_size]
               for i in range(len(text_encoded) - chunk_size + 1)]

class TextDataset(Dataset):
    """Map-style dataset of (input, target) pairs for next-character prediction.

    Each stored chunk holds n consecutive character ids. The input is the first
    n - 1 ids and the target is the last n - 1 ids, i.e. the input shifted left
    by one position.
    """

    def __init__(self, text_chunks):
        """Store the chunks.

        Args:
            text_chunks: container supporting ``len()`` and integer indexing whose
                items are 1-D integer sequences (torch tensor rows, numpy arrays
                or lists).
        """
        self.text_chunks = text_chunks

    def __len__(self):
        return len(self.text_chunks)
    
    def __getitem__(self, item):
        """Return ``(input_ids, target_ids)`` as int64 tensors for chunk `item`."""
        # as_tensor accepts tensors, numpy arrays and lists alike and casts to int64;
        # CrossEntropyLoss requires int64 class-index targets.
        text_chunk = torch.as_tensor(self.text_chunks[item], dtype=torch.long)
        return text_chunk[:-1], text_chunk[1:]
    
# Stack the equal-length chunks into one 2-D array and copy it into a single
# tensor, so TextDataset indexes rows of one tensor.
seq_dataset = TextDataset(torch.tensor(np.array(text_chunks)))

Set up the `DataLoader`:

In [7]:
batch_size = 64
# Shuffle the chunks and drop the incomplete final batch, so every batch has
# exactly `batch_size` rows.
seq_loader = DataLoader(
    seq_dataset,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
)

Build the RNN model: an embedding layer, a single-layer LSTM and a linear layer that produces next-character logits:

In [8]:
class TextGenerationNet(nn.Module):
    """Character-level LSTM that predicts the next character one time step at a time.

    Args:
        vocab_size: number of distinct character ids (embedding rows and output logits).
        embed_dim: size of each character embedding.
        rnn_hidden_size: size of the LSTM hidden and cell states.
    """

    def __init__(self, vocab_size, embed_dim, rnn_hidden_size):
        super().__init__()

        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.rnn_hidden_size = rnn_hidden_size
        self.rnn = nn.LSTM(embed_dim, rnn_hidden_size,
                           batch_first=True)
        self.fc = nn.Linear(rnn_hidden_size, vocab_size)

    def forward(self, x, hidden, cell):
        """Advance the LSTM by a single time step.

        Args:
            x: 1-D tensor of character ids, one per sequence in the batch.
            hidden, cell: LSTM states of shape (num_layers, batch, rnn_hidden_size),
                e.g. from `init_hidden`.

        Returns:
            ``(logits, hidden, cell)``, where logits has shape (batch, vocab_size).
        """
        # (batch, embed_dim) -> (batch, 1, embed_dim): a length-1 batch-first sequence.
        out = self.embedding(x).unsqueeze(1)
        out, (hidden, cell) = self.rnn(out, (hidden, cell))
        # (batch, 1, vocab_size) -> (batch, vocab_size): drop the length-1 time axis.
        out = self.fc(out).reshape(out.size(0), -1)
        return out, hidden, cell

    def init_hidden(self, batch_size):
        """Return zero ``(hidden, cell)`` states for `batch_size` sequences.

        The states use the device and dtype of the model's parameters, so they
        match wherever the model has been moved.
        """
        reference = next(self.parameters())
        shape = (self.rnn.num_layers, batch_size, self.rnn_hidden_size)
        hidden = torch.zeros(shape, device=reference.device, dtype=reference.dtype)
        cell = torch.zeros(shape, device=reference.device, dtype=reference.dtype)
        return hidden, cell

In [9]:
embed_dim = 256
rnn_hidden_size = 512
vocab_size = len(char_array)  # one output logit per known character

model = TextGenerationNet(
    vocab_size=vocab_size,
    embed_dim=embed_dim,
    rnn_hidden_size=rnn_hidden_size,
).to(device)
model

TextGenerationNet(
  (embedding): Embedding(97, 256)
  (rnn): LSTM(256, 512, batch_first=True)
  (fc): Linear(in_features=512, out_features=97, bias=True)
)

In [10]:
loss = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.005)

# Number of optimisation steps. Each "epoch" here is one mini-batch,
# not a full pass over the dataset.
n_epochs = 10_000


def _endless_batches(loader):
    """Yield batches from `loader` forever, starting a new (reshuffled) pass when one ends."""
    while True:
        yield from loader


# Keep one iterator alive. Calling next(iter(seq_loader)) every step would build
# a new DataLoader iterator, and with shuffle=True a new random permutation of the
# whole dataset, only to take a single batch from it.
batch_iterator = _endless_batches(seq_loader)

# Make sure the model is in training mode (it may have been switched to eval
# mode for sampling).
model.train()

for epoch in range(n_epochs):
    hidden, cell = model.init_hidden(batch_size)
    seq_batch, target_batch = next(batch_iterator)

    seq_batch = seq_batch.to(device)
    target_batch = target_batch.to(device)

    optimizer.zero_grad()
    loss_val = 0.

    # Unroll the LSTM one character at a time, summing the per-step losses.
    for c in range(seq_len):
        pred, hidden, cell = model(
            seq_batch[:, c], hidden, cell
        )
        loss_val += loss(pred, target_batch[:, c])

    loss_val.backward()
    optimizer.step()

    # Report the mean loss per character position.
    loss_val = loss_val.item() / seq_len
    if epoch % 250 == 0:
        print(f'Epoch {epoch} loss: {loss_val:.4f}')

Epoch 0 loss: 4.5942
Epoch 250 loss: 1.6008
Epoch 500 loss: 1.4895
Epoch 750 loss: 1.4544
Epoch 1000 loss: 1.4028
Epoch 1250 loss: 1.4184
Epoch 1500 loss: 1.2880
Epoch 1750 loss: 1.2823
Epoch 2000 loss: 1.2993
Epoch 2250 loss: 1.3498
Epoch 2500 loss: 1.2856
Epoch 2750 loss: 1.3110
Epoch 3000 loss: 1.2731
Epoch 3250 loss: 1.2899
Epoch 3500 loss: 1.3153
Epoch 3750 loss: 1.2897
Epoch 4000 loss: 1.2916
Epoch 4250 loss: 1.2909
Epoch 4500 loss: 1.2816
Epoch 4750 loss: 1.2276
Epoch 5000 loss: 1.2772
Epoch 5250 loss: 1.2752
Epoch 5500 loss: 1.3032
Epoch 5750 loss: 1.3307
Epoch 6000 loss: 1.3222
Epoch 6250 loss: 1.2488
Epoch 6500 loss: 1.2474
Epoch 6750 loss: 1.2712
Epoch 7000 loss: 1.2367
Epoch 7250 loss: 1.2242
Epoch 7500 loss: 1.2726
Epoch 7750 loss: 1.2036
Epoch 8000 loss: 1.2346
Epoch 8250 loss: 1.2676
Epoch 8500 loss: 1.2737
Epoch 8750 loss: 1.2713
Epoch 9000 loss: 1.2578
Epoch 9250 loss: 1.2628
Epoch 9500 loss: 1.2869
Epoch 9750 loss: 1.2562


In [11]:
def sample(model, input_string, generate_len=500, scale_factor=1.0):
    """Generate text by continuing `input_string` one sampled character at a time.

    The prompt is first fed through the model to build up the recurrent state.
    Each new character is then drawn from a categorical distribution over the
    model's logits multiplied by `scale_factor`: values > 1 sharpen the
    distribution, values < 1 flatten it. The model's train/eval mode is restored
    afterwards.

    Args:
        model: module whose forward takes ``(ids, hidden, cell)`` and returns
            ``(logits, hidden, cell)``, and which provides ``init_hidden(batch_size)``.
        input_string: non-empty prompt; every character must be in `char_to_int`.
        generate_len: number of characters to append to the prompt.
        scale_factor: multiplier applied to the logits before sampling.

    Returns:
        The prompt followed by `generate_len` generated characters.

    Raises:
        ValueError: if `input_string` is empty.
        KeyError: if the prompt contains a character outside the vocabulary.
    """
    if not input_string:
        raise ValueError('input_string must contain at least one character')

    model_device = next(model.parameters()).device
    encoded_input = torch.tensor(
        [char_to_int[c] for c in input_string], device=model_device
    ).reshape(1, -1)

    pieces = [input_string]
    was_training = model.training
    model.eval()
    try:
        # No gradients are needed; without no_grad every step would extend the autograd graph.
        with torch.no_grad():
            hidden, cell = model.init_hidden(1)
            # Warm up the state on all prompt characters except the last one,
            # which becomes the first input of the generation loop.
            for c in range(len(input_string) - 1):
                _, hidden, cell = model(
                    encoded_input[:, c].view(1), hidden, cell
                )

            last_char = encoded_input[:, -1]
            for _ in range(generate_len):
                logits, hidden, cell = model(
                    last_char.view(1), hidden, cell
                )
                logits = torch.squeeze(logits, 0)
                m = Categorical(logits=logits * scale_factor)
                last_char = m.sample()
                pieces.append(str(char_array[last_char.item()]))
    finally:
        model.train(was_training)

    return ''.join(pieces)

In [16]:
# Normalise the prompt's whitespace (strip the surrounding newlines).
input_string = ' '.join('''
This sympathy comes to us
'''.split())

# Write with an explicit encoding. The corpus was read as UTF-8, and the
# generated text may contain characters the platform's default encoding
# cannot represent.
with open('output.txt', 'w', encoding='utf8') as out:
    print(sample(model, input_string, generate_len=1000), file=out)

In [17]:
model_path = Path('../model/model.pt')
# torch.save does not create missing directories; create the target folder
# first so the trained weights are not lost.
model_path.parent.mkdir(parents=True, exist_ok=True)
torch.save(model.state_dict(), model_path)