In [1]:
!curl -O https://gutenberg.org/files/1268/1268-0.txt

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100 1143k  100 1143k    0     0  1222k      0 --:--:-- --:--:-- --:--:-- 1229k


In [2]:
import numpy as np
with open('1268-0.txt', 'r', encoding="utf8") as f:
    text = f.read()
start_idx = text.find('THE MYSTERIOUS ISLAND')
end_idx = text.find('End of the Project Gutenberg')
text = text[start_idx:end_idx]
char_set = set(text)

print(f"Total length: {len(text)}")

print(f"Unique Characters: {len(char_set)}")

    

Total length: 1130711
Unique Characters: 85


In [3]:
chars_sorted = sorted(char_set)
char2int = {ch:i for i, ch in enumerate(chars_sorted)}
int2char = np.array(chars_sorted)

text_encoded = np.array(
    [char2int[ch] for ch in text],
    dtype=np.int32
)

print(text[:15])
print(text_encoded[:15])

THE MYSTERIOUS 
[48 36 33  1 41 53 47 48 33 46 37 43 49 47  1]


In [4]:
import torch
from torch.utils.data import Dataset

seq_length = 40
chunk_size = seq_length + 1
text_chunks = [text_encoded[i:i+chunk_size] for i in range(len(text_encoded) - chunk_size+1)]

class TextDataset(Dataset):
    def __init__(self, text_chunks):
        self.text_chunks = text_chunks

    def __len__(self):
        return len(self.text_chunks)

    def __getitem__(self, idx):
        text_chunk = self.text_chunks[idx]
        return text_chunk[:-1].long(), text_chunk[1:].long()

seq_dataset = TextDataset(torch.tensor(text_chunks))

  seq_dataset = TextDataset(torch.tensor(text_chunks))


In [6]:
for i, (seq, target) in enumerate(seq_dataset):
    print(repr(''.join(int2char[seq])))
    print(repr(''.join(int2char[target])))
    if i==1:
        break

'THE MYSTERIOUS ISLAND ***\n\n\n\n\nTHE MYSTER'
'HE MYSTERIOUS ISLAND ***\n\n\n\n\nTHE MYSTERI'
'HE MYSTERIOUS ISLAND ***\n\n\n\n\nTHE MYSTERI'
'E MYSTERIOUS ISLAND ***\n\n\n\n\nTHE MYSTERIO'


In [7]:
## transform Dataset into mini-batches
from torch.utils.data import DataLoader
batch_size = 64
torch.manual_seed(1)
seq_dl = DataLoader(seq_dataset, batch_size=batch_size,
                    shuffle=True, drop_last=True)

## Build a character-level RNN model

In [8]:
import torch.nn as nn
class RNN(nn.Module):
    def __init__(self, vocab_size, embed_dim, rnn_hidden_size):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.rnn_hidden_size = rnn_hidden_size
        self.rnn = nn.LSTM(embed_dim, rnn_hidden_size, batch_first=True)
        self.fc = nn.Linear(rnn_hidden_size, vocab_size)

    def forward(self, x, hidden, cell):
        out = self.embedding(x).unsqueeze(1)
        out, (hidden, cell) = self.rnn(out, (hidden, cell))
        out = self.fc(out).reshape(out.size(0), -1)
        return out, hidden, cell

    def init_hidden(self, batch_size):
        hidden = torch.zeros(1, batch_size, self.rnn_hidden_size)
        cell = torch.zeros(1, batch_size, self.rnn_hidden_size)
        return hidden, cell

In [9]:
vocab_size = len(int2char)
embed_dim = 256
rnn_hidden_size = 512
torch.manual_seed(1)
model = RNN(vocab_size, embed_dim, rnn_hidden_size)
model

RNN(
  (embedding): Embedding(85, 256)
  (rnn): LSTM(256, 512, batch_first=True)
  (fc): Linear(in_features=512, out_features=85, bias=True)
)

In [10]:
## create loss function and optimizer
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.005)

In [None]:
num_epochs = 10_000
torch.manual_seed(1)
for epoch in range(num_epochs):
    hidden, cell = model.init_hidden(batch_size)
    seq_batch, target_batch = next(iter(seq_dl))
    optimizer.zero_grad()
    loss = 0
    for c in range(seq_length):
        pred, hidden, cel = model(seq_batch[:, c], hidden, cell)
        loss += loss_fn(pred, target_batch[:, c])
    loss.backward()
    optimizer.step()
    loss = loss.item()/seq_length
    if epoch % 500 == 0:
        print(f"Epoch {epoch} loss: {loss:.4f}")

In [12]:
from torch.distributions.categorical import Categorical
def sample(model, starting_str, len_generated_text=500, scale_factor=1.0):
    encoded_input = torch.tensor(
        [char2int[s] for s in starting_str]
    )

    encoded_input = torch.reshape(
        encoded_input, (1, -1)
    )

    generated_str = starting_str

    model.eval()
    hidden, cell = model.init_hidden(1)
    for c in range(len(starting_str)-1):
        _, hidden, cell = model(
            encoded_input[:, c].view(1), hidden, cell
        )

    last_char = encoded_input[:, -1]
    for i in range(len_generated_text):
        logits, hidden, cell = model(
            last_char.view(1), hidden, cell
        )
        logits = torch.squeeze(logits, 0)
        scaled_logits = logits * scale_factor
        m = Categorical(logits=scaled_logits)
        last_char = m.sample()
        generated_str += str(int2char[last_char])

    return generated_str

In [13]:
torch.manual_seed(1)
print(sample(model, starting_str="The island"))

The islandmiauts,
our isents ofiaust o atteerly stuumes tially astanty forate, on unable. 

“It oilln, to roaten timear ssital our iste enen our fatilianes ofails. , howerars ofatine, sear antatin, sea, turalition.  spearetes fires ofual iten atmpous matinees, usilae stion, ous. seasu offialty onles secatte, of att hourpaty tien amiutes, woodme.
 soortea.  anstaitanly furoousa state ourn ofatts. --armety.



“Ih anlesteen satity orate orooos, miatusa, antitus.  itselas.  halle, woolle, letaw, ancter elar 


In [14]:
torch.manual_seed(1)
print(sample(model, starting_str="The island", scale_factor=2.0))

The islandmiatt ooous, antity oura tiouse speate ourate tin aloate, seare one, atitusa time, searitus, tear least oust ouste enty atimes ofures ofient our atiatus.  lootsiate our antits feorest isteer inles ofear oasts ourate time laste or ous tiluus tilia, sea.  also almoonsa antity furene ests satiousiate sea, furole, antity oures ofiants ofuries ofal atioust itates foroogen tione ourste story, antiounate stants ofures ofation ents ofures ofear antiousus firent ouranes ofar one offer antity anity, altit
