In [1]:
!curl -O https://www.gutenberg.org/files/1268/1268-0.txt

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100 1124k  100 1124k    0     0  1642k      0 --:--:-- --:--:-- --:--:-- 1642k


In [2]:
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

Using device: cuda


In [3]:
import numpy as np

with open('1268-0.txt', 'r', encoding="utf8") as fp:
    text=fp.read()

start_indx = text.find('THE MYSTERIOUS ISLAND')
end_indx = text.find('End of the Project Gutenberg')
text = text[start_indx:end_indx]
char_set = set(text)

print('Total Length:', len(text))
print('Unique Characters:', len(char_set))

Total Length: 1112310
Unique Characters: 80


In [4]:
chars_sorted = sorted(char_set)

char2int = {ch:i for i,ch in enumerate(chars_sorted)} # e.g {"a": 2}
char_array = np.array(chars_sorted)
text_encoded = np.array([char2int[ch] for ch in text], dtype=np.int32)

print(text[:15], '== Encoding ==>', text_encoded[:15])
print(text_encoded[15:21], '== Reverse ==>',''.join(char_array[text_encoded[15:21]]))
for ex in text_encoded[:5]:
    print('{} -> {}'.format(ex, char_array[ex]))

THE MYSTERIOUS  == Encoding ==> [44 32 29  1 37 48 43 44 29 42 33 39 45 43  1]
[33 43 36 25 38 28] == Reverse ==> ISLAND
44 -> T
32 -> H
29 -> E
1 ->  
37 -> M


In [5]:
from torch.utils.data import Dataset

def create_text_chunks(text_encoded, chunk_size): 
    """ Create overlapping chunks from encoded text.
        Example
        text_encoded = [1, 2, 3, 4, 5, 6, 7], chunk_size = 4
        Chunk 0: [1, 2, 3, 4]  (positions 0-3)
        Chunk 1: [2, 3, 4, 5]  (positions 1-4)
        Chunk 2: [3, 4, 5, 6]  (positions 2-5)
        Chunk 3: [4, 5, 6, 7]  (positions 3-6)
    """
    chunks = []
    
    for i in range(len(text_encoded) - chunk_size + 1):
        chunk = text_encoded[i:i + chunk_size]
        chunks.append(chunk)
    return chunks

seq_length = 40
chunk_size = seq_length + 1
text_chunks = create_text_chunks(text_encoded, chunk_size)

In [6]:
from torch.utils.data import Dataset

class TextDataset(Dataset):
    def __init__(self, text_chunks):
        self.text_chunks = text_chunks
    
    def __len__(self):
        return len(self.text_chunks)

    def __getitem__(self, idx):
        text_chunk = self.text_chunks[idx]
        return text_chunk[:-1].long(), text_chunk[1:].long()

In [7]:
seq_dataset = TextDataset(torch.tensor(text_chunks))

for i, (input_token, output_token) in enumerate(seq_dataset):
    print('Input (x): ', repr(''.join(char_array[input_token])))
    print('Target (y): ', repr(''.join(char_array[output_token])))
    print()
    if i == 1:
        break

Input (x):  'THE MYSTERIOUS ISLAND\n\nby Jules Verne\n\n1'
Target (y):  'HE MYSTERIOUS ISLAND\n\nby Jules Verne\n\n18'

Input (x):  'HE MYSTERIOUS ISLAND\n\nby Jules Verne\n\n18'
Target (y):  'E MYSTERIOUS ISLAND\n\nby Jules Verne\n\n187'



  seq_dataset = TextDataset(torch.tensor(text_chunks))


In [8]:
from torch.utils.data import DataLoader
batch_size = 64
torch.manual_seed(1)
seq_dl = DataLoader(seq_dataset, batch_size=batch_size, shuffle=True, drop_last=True)

In [9]:
import torch.nn as nn
class RNN(nn.Module):
    def __init__(self, charset_size, embed_dim, rnn_hidden_size):
        super().__init__()
        self.rnn_hidden_size = rnn_hidden_size
        self.embedding = nn.Embedding(charset_size, embed_dim)
        self.rnn = nn.LSTM(embed_dim, rnn_hidden_size,batch_first=True)
        self.fc = nn.Linear(rnn_hidden_size, charset_size)
    
    def forward(self, x, hidden, cell):
        out = self.embedding(x).unsqueeze(1)  # [batch, embed] -> [batch, 1, embed] for RNN
        out, (hidden, cell) = self.rnn(out, (hidden, cell)) # each char is depended on prev, so we need to pass the hidden and cell here
        out = self.fc(out).reshape(out.size(0), -1) # we need to reshape because we add 1 more dim in the beginning
        return out, hidden, cell
    
    def init_hidden(self, batch_size):
        # (num_layers * num_directions(1 if uni, 2 if bi), batch_size, hidden_size)
        hidden = torch.zeros(1, batch_size, self.rnn_hidden_size)
        cell = torch.zeros(1, batch_size, self.rnn_hidden_size)
        return hidden, cell

In [10]:
charset_size = len(char_array)
embed_dim = 256
rnn_hidden_size = 512
torch.manual_seed(1)
model = RNN(charset_size, embed_dim, rnn_hidden_size)
model.to(device)
model

RNN(
  (embedding): Embedding(80, 256)
  (rnn): LSTM(256, 512, batch_first=True)
  (fc): Linear(in_features=512, out_features=80, bias=True)
)

In [11]:
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [12]:
import time

num_epochs = 10000
torch.manual_seed(1)
for epoch in range(num_epochs):
    for seq_batch, target_batch in seq_dl: # seq_batch.size() -> (64, 40)
        # Move data to GPU
        seq_batch = seq_batch.to(device)
        target_batch = target_batch.to(device)
        
        hidden, cell = model.init_hidden(batch_size)
        hidden = hidden.to(device)
        cell = cell.to(device)

        optimizer.zero_grad()
        loss = 0
        
        for c in range(seq_length):
            pred, hidden, cell = model(seq_batch[:, c], hidden, cell)
            loss += loss_fn(pred, target_batch[:, c])
        loss.backward()
        optimizer.step()
        loss = loss.item()/seq_length

        if epoch % 500 == 0:
            current_time = time.strftime("%H:%M:%S")
            print(f'[{current_time}] Epoch {epoch} loss: {loss:.4f}')

        break # we just do 1 iteration here because the sheer number of epoch

[20:09:41] Epoch 0 loss: 4.3712
[20:10:41] Epoch 500 loss: 1.5249
[20:11:40] Epoch 1000 loss: 1.4125
[20:12:40] Epoch 1500 loss: 1.3283
[20:13:39] Epoch 2000 loss: 1.2046
[20:14:38] Epoch 2500 loss: 1.1970
[20:15:38] Epoch 3000 loss: 1.1570
[20:16:38] Epoch 3500 loss: 1.1746
[20:17:37] Epoch 4000 loss: 1.1341
[20:18:37] Epoch 4500 loss: 1.0954
[20:19:36] Epoch 5000 loss: 1.1228
[20:20:36] Epoch 5500 loss: 1.0735
[20:21:35] Epoch 6000 loss: 1.0659
[20:22:35] Epoch 6500 loss: 1.0894
[20:23:35] Epoch 7000 loss: 1.0446
[20:24:34] Epoch 7500 loss: 1.1123
[20:25:34] Epoch 8000 loss: 1.1016
[20:26:34] Epoch 8500 loss: 1.0806
[20:27:33] Epoch 9000 loss: 0.9845
[20:28:33] Epoch 9500 loss: 1.0307


In [13]:
from torch.distributions.categorical import Categorical

def sample(model, starting_str, len_generated_text=500, scale_factor=1.0):
    encoded_input = torch.tensor([char2int[s] for s in starting_str])
    encoded_input = torch.reshape(encoded_input, (1, -1)) # add batch dim
    encoded_input =encoded_input.to(device)
    generated_str = starting_str

    model.eval()
    hidden, cell = model.init_hidden(1)
    hidden = hidden.to(device)
    cell = cell.to(device)
    
    for c in range(len(starting_str)-1): # we want to start the generation with last char, so we omit it here
        # _, hidden, cell = model(encoded_input[:, c].view(1), hidden, cell)
        _, hidden, cell = model(encoded_input[:, c], hidden, cell)

    last_char = encoded_input[:, -1]

    for i in range(len_generated_text):
        
        logits, hidden, cell = model(last_char.view(1).to(device), hidden, cell)
        logits = torch.squeeze(logits, 0)
        scaled_logits = logits * scale_factor
        m = Categorical(logits=scaled_logits)
        last_char = m.sample()
        generated_str += str(char_array[last_char])

    return generated_str

In [14]:
torch.manual_seed(1)
scale_factors = [.5, 1., 2.]
starting_str='The island'

for scale_factor in scale_factors:
    print(f'== Scale factor {scale_factor} ==')
    print(sample(model, starting_str = 'The island', scale_factor=scale_factor))
    print("\n")
    

== Scale factor 0.5 ==
The island
bore, Abold thwerm.
Or countened “FaY ismblu’s’ down, and,
already, if tremendou? Might clmativy,--W1
In a wry Nave!” “J*rnh Gzebbvellience came which was, I will, not
even’t xagumous animyie.”

ToLPocht wih numberlscat, but ojeoward Lown a
great.
It was, esseithard, a
Bible, surre!-Linally Turlacpkings aid.

Yef.
Hopels, TomwIred, planimation fell no
coverak.

This last penfect of
Ameyouse, ‘perhausiz was yellow unheadvelt,-, acquaignm, wxiden,
reign. Speparet unexpicially so unfortufe.

Othin


== Scale factor 1.0 ==
The island in the
open aix and the sailor’s idea was rounds, and opened by Top’s blood.

When Cyrus Harding and Spilett, Herbert, and Neb, having so complished without difficulty! Their engineer, formed by the winter grunts of notice.

They believed, the larder was still reply.

“Perhaps the passage of nine oyster; that, Mr. Herbert.”

“Judge so Shark youn form a desirious and two miles off the bank after the island. The Fer sailor snow.