<a href="https://colab.research.google.com/github/pearpare/sherlock-lstm/blob/main/lstm_project.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Elisabeth Kam (etk45) 

In [14]:
import copy

import torch 
import torch.nn as nn
import torch.optim as optim

In [15]:
import numpy as np

In [6]:
# Read the whole corpus into memory. The context manager closes the file handle
# as soon as the read finishes instead of leaving it open for the kernel's lifetime.
filename = "/sherlock.txt"
with open(filename, 'r', encoding='utf-8') as corpus_file:
    raw_txt = corpus_file.read()

In [7]:
# Lower-case the corpus and keep only its first 50,000 characters.
raw_txt = raw_txt.lower()[:50000]
# Sorting the distinct characters makes the character -> integer mapping
# deterministic for a given text.
chars = sorted(set(raw_txt))
char_to_int = {ch: idx for idx, ch in enumerate(chars)}

In [9]:
# Corpus length and number of distinct characters; later cells reuse both.
n_chars, n_vocab = len(raw_txt), len(chars)
print("Total characters: ", n_chars)
print("Total vocab: ", n_vocab)

Total characters:  50000
Total vocab:  44


In [12]:
# Build (input window, next character) pairs, encoded as integers.
char_seq_len = 50

# Encode the corpus once, then slide a window of char_seq_len characters over it:
# each window is an input and the character right after it is the target.
encoded_txt = [char_to_int[ch] for ch in raw_txt]
window_starts = range(n_chars - char_seq_len)
X_data = [encoded_txt[s:s + char_seq_len] for s in window_starts]
y_data = [encoded_txt[s + char_seq_len] for s in window_starts]

n_patterns = len(X_data)
print("Total patterns: ", n_patterns)

Total patterns:  49950


In [16]:
# Lay the integer windows out as (n_patterns, char_seq_len, 1) float tensors and
# divide by the vocabulary size so every character code falls in [0, 1).
X = (
    torch.tensor(X_data, dtype=torch.float32).reshape(n_patterns, char_seq_len, 1)
    / float(n_vocab)
)
# Targets are left as integer class indices.
y = torch.tensor(y_data)
print(X.shape, y.shape)

torch.Size([49950, 50, 1]) torch.Size([49950])


In [17]:
import torch.utils.data as data 

In [18]:
# Display whether PyTorch can see a CUDA device in this runtime.
torch.cuda.is_available()

True

In [19]:
class bookModel(nn.Module):
    """Character-level next-character predictor.

    A two-layer LSTM reads a sequence with one feature per time step, and a
    linear layer maps the output at the final time step to one logit per
    vocabulary entry.
    """

    def __init__(self):
        super().__init__()
        # batch_first=True: inputs are laid out as (batch, seq_len, 1).
        self.lstm = nn.LSTM(
            input_size=1,
            hidden_size=256,
            num_layers=2,
            batch_first=True,
            dropout=0.2,
        )
        # The dropout rate is a knob worth experimenting with.
        self.dropout = nn.Dropout(p=0.2)
        # n_vocab is read from the notebook's global scope.
        self.linear = nn.Linear(in_features=256, out_features=n_vocab)

    def forward(self, x):
        """Return next-character logits of shape (batch, n_vocab)."""
        lstm_out, _ = self.lstm(x)
        # Only the output at the last time step feeds the classifier.
        last_step = lstm_out[:, -1, :]
        return self.linear(self.dropout(last_step))

In [20]:
# Train the model and checkpoint the weights from the epoch with the lowest loss.
n_epochs = 50
batch_size = 128
model = bookModel()
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model.to(device)

optimizer = optim.Adam(model.parameters())
# reduction="sum" so per-batch losses add up to a total over the whole dataset.
loss_fn = nn.CrossEntropyLoss(reduction="sum")
loader = data.DataLoader(data.TensorDataset(X, y), shuffle=True, batch_size=batch_size)

best_model = None
best_loss = np.inf

for epoch in range(n_epochs):
    model.train()
    for X_batch, y_batch in loader:
        y_pred = model(X_batch.to(device))
        loss = loss_fn(y_pred, y_batch.to(device))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Evaluation pass with dropout disabled.
    # NOTE(review): this iterates the same loader used for training, so the
    # reported figure is a training loss, not a held-out validation loss.
    model.eval()
    total_loss = 0
    with torch.no_grad():
        for X_batch, y_batch in loader:
            y_pred = model(X_batch.to(device))
            total_loss += loss_fn(y_pred, y_batch.to(device))
    # Convert once per epoch so best_loss is a plain float, not a device tensor.
    epoch_loss = float(total_loss)
    if epoch_loss < best_loss:
        best_loss = epoch_loss
        # state_dict() returns references to the live parameter tensors, which
        # later optimizer steps update in place. Deep-copy to keep a real
        # snapshot of this epoch's weights.
        best_model = copy.deepcopy(model.state_dict())
    print("Epoch %d: Cross-entropy: %.3f" % (epoch, epoch_loss))

# Persist the best weights together with the vocabulary needed to decode them.
torch.save([best_model, char_to_int], "single-char.pth")

Epoch 0: Cross-entropy: 146138.828
Epoch 1: Cross-entropy: 134399.094
Epoch 2: Cross-entropy: 129455.500
Epoch 3: Cross-entropy: 125558.992
Epoch 4: Cross-entropy: 121896.844
Epoch 5: Cross-entropy: 118509.797
Epoch 6: Cross-entropy: 114714.719
Epoch 7: Cross-entropy: 111921.219
Epoch 8: Cross-entropy: 109119.609
Epoch 9: Cross-entropy: 106658.750
Epoch 10: Cross-entropy: 105133.664
Epoch 11: Cross-entropy: 103828.242
Epoch 12: Cross-entropy: 99264.562
Epoch 13: Cross-entropy: 98614.281
Epoch 14: Cross-entropy: 94654.641
Epoch 15: Cross-entropy: 93194.883
Epoch 16: Cross-entropy: 90537.344
Epoch 17: Cross-entropy: 88246.297
Epoch 18: Cross-entropy: 86732.422
Epoch 19: Cross-entropy: 84632.992
Epoch 20: Cross-entropy: 83034.867
Epoch 21: Cross-entropy: 80697.750
Epoch 22: Cross-entropy: 80874.211
Epoch 23: Cross-entropy: 79120.297
Epoch 24: Cross-entropy: 75374.023
Epoch 25: Cross-entropy: 74269.664
Epoch 26: Cross-entropy: 73040.000
Epoch 27: Cross-entropy: 71859.820
Epoch 28: Cross-en

In [43]:
# Restore the checkpoint written by the training cell. It is a two-item list,
# [state_dict, char_to_int], so it must be unpacked by assignment; a bare
# comma-separated expression would load the file and discard the result.
best_model, char_to_int = torch.load("single-char.pth")
n_vocab = len(char_to_int)
# Inverse mapping, used to decode predicted class indices back into characters.
int_to_char = {i: c for c, i in char_to_int.items()}
model.load_state_dict(best_model)

<All keys matched successfully>

In [44]:
# Pick a random seed prompt from the corpus.
file = "/sherlock.txt"
with open(file, 'r', encoding='utf-8') as corpus_file:
    raw_txt2 = corpus_file.read()
raw_txt2 = raw_txt2.lower()
# Sample only from the 50,000-character prefix the vocabulary was built from.
# Text beyond it may contain characters missing from char_to_int, which would
# raise KeyError when the prompt is encoded below.
raw_txt2 = raw_txt2[:50000]
seq_len = 50
start = np.random.randint(0, len(raw_txt2) - seq_len)
prompt = raw_txt2[start:start + seq_len]
pattern = [char_to_int[c] for c in prompt]

In [45]:
# Greedy generation: repeatedly predict the most likely next character and
# slide the input window forward by one position.
model.eval()
print("Prompt:")
print(prompt)
print("Prompt ends here.")
print("\n")
print("Result:")
with torch.no_grad():
    for _ in range(1000):
        # Scale the integer window by the vocabulary size and shape it as a
        # single-sample batch of (1, window length, 1).
        scaled_window = np.reshape(pattern, (1, len(pattern), 1)) / float(n_vocab)
        x = torch.tensor(scaled_window, dtype=torch.float32).to(device)
        # The model returns logits; take the highest-scoring class.
        index = int(model(x).argmax())
        print(int_to_char[index], end="")
        # Drop the oldest character and append the prediction.
        pattern = pattern[1:] + [index]

print()
print("Done.")

Prompt:

written in a hurry and dipped her pen too deep. i
Prompt ends here.


Result:
olmes which he had a pat which witl the rteption wiich i should rur at the eall. he was sapa in the coune to said it an insersence that the woole oas enueh that i was adrertarion to hore teat it on dirertalce."

"i whilk in the lady the house of the mat who was tuitted into the room which he had a pat woon the steption wiich i saw see hoom the room whth he was soon at ar elende of the simtleraph which he was satsing saper which had been aben of the stier, she was hou the soom with the shotograph in the matker was stickciny tign the paid in the steption wiich was a poall pooked to suos the soom which ie his welled upon the rtoeet. 
"i co not with you will be good in the lat whin some cound have alenaele. it was a lanyer with the siotograph which he had a comatitned an immeriant wou know the room wiich i suselfeny of the soalet of the simtleraph which he has eeen and oattered to miss the stoeet. 
it wa