In [1]:
import glob
import os
import unicodedata
import string
import torch
import random
import torch.nn.functional as F
import torch.nn as nn
import matplotlib.pyplot as plt
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter

In [2]:
class LSTM(nn.Module):
    """Stacked LSTM followed by a two-layer feed-forward head.

    The recurrent layer is time-major (``batch_first=False``), so inputs are
    laid out as ``(seq_len, batch, input_size)`` and outputs as
    ``(seq_len, batch, output_size)``.

    Args:
        input_size: Number of features per time step.
        hidden_size: Width of the LSTM hidden and cell states.
        output_size: Number of values produced per time step.
        num_layers: Number of stacked LSTM layers.
    """

    def __init__(self, input_size, hidden_size, output_size, num_layers = 1):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        # NOTE: attribute names (lstm / linear1 / linear2) are the state_dict
        # keys of saved checkpoints, so they must not be renamed.
        self.lstm = nn.LSTM(input_size=input_size,
                            hidden_size=hidden_size,
                            num_layers = num_layers,
                            batch_first=False)

        self.linear1 = nn.Linear(hidden_size, output_size)
        self.linear2 = nn.Linear(output_size, output_size)

        nn.init.kaiming_uniform_(self.linear1.weight, nonlinearity='relu')
        nn.init.kaiming_uniform_(self.linear2.weight, nonlinearity='relu')

    def forward(self, x, hidden):
        """Run the LSTM and the head over ``x``.

        Args:
            x: Input tensor laid out as ``(seq_len, batch, input_size)``.
            hidden: ``(h, c)`` tuple as returned by ``init_hidden`` or by a
                previous call to ``forward``.

        Returns:
            ``(out, hidden)`` where ``out`` holds the un-normalised outputs of
            ``linear2`` for every time step and ``hidden`` is the updated
            ``(h, c)`` tuple.
        """
        out, hidden = self.lstm(x, hidden)
        out = F.relu(self.linear1(out))
        out = self.linear2(out)
        return out, hidden

    def init_hidden(self, batch_size=1):
        """Return zero-filled ``(h0, c0)`` states for a new sequence.

        The states are created with the dtype and device of the model's own
        parameters so they stay compatible after ``model.to(device)`` or
        ``model.double()``.
        """
        reference = next(self.parameters())
        state_shape = (self.num_layers, batch_size, self.hidden_size)
        h0 = torch.zeros(state_shape, dtype=reference.dtype, device=reference.device)
        c0 = torch.zeros(state_shape, dtype=reference.dtype, device=reference.device)
        return h0, c0

In [3]:
def build_vocab(chars):
    """Build a deterministic character vocabulary.

    Args:
        chars: Iterable of single-character strings.

    Returns:
        ``(vocab, char_to_idx, idx_to_char)`` where ``vocab`` is the sorted
        list of distinct characters and the two dicts map between a character
        and its position in ``vocab``.
    """
    # Sorting matters: iterating a set of strings follows their hashes, which
    # Python randomises per process. Without a fixed order the index assigned
    # to each character changes between kernel sessions, and a checkpoint
    # trained in one session is decoded with the wrong characters in another.
    vocab = sorted(set(chars))
    char_to_idx = {ch: i for i, ch in enumerate(vocab)}
    idx_to_char = dict(enumerate(vocab))
    return vocab, char_to_idx, idx_to_char


# First, retrieve all the book names and collect the contents in a list
files = sorted(glob.glob('books/*.txt'))  # sorted: glob order is arbitrary
book_names = []
all_book_lines = []

for filename in files:
    book_names.append(os.path.splitext(os.path.basename(filename))[0])
    # The corpus contains non-ASCII characters, so do not rely on the
    # platform's default encoding.
    with open(filename, 'r', encoding='utf-8') as f:
        all_book_lines += f.readlines()

# Put all the chars into a single list
chars = [char for line in all_book_lines for char in line]
# v: vocab, v_idx: char -> index, idx_v: index -> char
v, v_idx, idx_v = build_vocab(chars)
n_vocab = len(v)

In [4]:
def str_to_oh(word):
    """One-hot encode a string, one row per character.

    Args:
        word: String whose characters are all keys of ``v_idx``.

    Returns:
        Integer tensor with one row per character of ``word`` and ``n_vocab``
        columns; an empty string yields a tensor with zero rows.

    Raises:
        KeyError: If ``word`` contains a character that is not in ``v_idx``.
    """
    # Force an integer dtype: an empty list would otherwise be inferred as a
    # float tensor, which F.one_hot refuses to index with.
    idxs = torch.tensor([v_idx[l] for l in word], dtype=torch.long)
    oh = F.one_hot(idxs, n_vocab)
    return oh

def oh_to_char(oh):
    """Decode a one-hot (or score) row back into its vocabulary character.

    ``oh`` must be 2-D with exactly one row: the arg-max is taken along
    dimension 1 and the result is converted with ``.item()``, which only
    accepts a single-element tensor.
    """
    hot_position = oh.argmax(dim=1).item()
    return idx_v[hot_position]
    

In [18]:
# n_hidden = 256
# batch_size = 128
# model = LSTM(n_vocab, n_hidden, n_vocab, num_layers=10)
# model.load_state_dict(torch.load('models/model_31.pt'))

# h, c = model.init_hidden(batch_size=batch_size)
# X = str_to_oh('A').type(torch.float32).unsqueeze(0)
# print(oh_to_char(X.squeeze(0)), end='')
# for i in range(400):
#     output, (h, c) = model(X, (h, c))
#     probs = F.softmax(output[0], dim=0)
#     char_idx = torch.multinomial(probs, 1).item() 
#     X = F.one_hot(torch.tensor([char_idx]), num_classes=n_vocab).type(torch.float32)
#     print(idx_v[char_idx], end='')

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# Evaluate: sample 1000 characters from a trained checkpoint, one at a time.
n_hidden = 256
batch_size = 1
model = LSTM(n_vocab, n_hidden, n_vocab, num_layers=10)
# NOTE: torch.load unpickles the file - only load checkpoints you trust.
model.load_state_dict(torch.load('models/model_15.pt', map_location=device))
model.to(device)  # load_state_dict copies into the existing (CPU) parameters
model.eval()

h, c = model.init_hidden(batch_size=batch_size)
h, c = h.to(device), c.to(device)
# Shape (seq_len=1, batch=1, n_vocab): the LSTM is time-major.
X = str_to_oh('a').type(torch.float32).unsqueeze(0).to(device)
print(oh_to_char(X[0]), end='')

# No gradients are needed for sampling. Without no_grad the recurrent state
# would keep the autograd graph of every previous step alive.
with torch.no_grad():
    for i in range(1000):
        output, (h, c) = model(X, (h, c))
        # output is (seq_len, batch, n_vocab): last time step, single batch item
        probs = F.softmax(output[-1, 0], dim=0)
        char_idx = torch.multinomial(probs, 1).item()  # Sample from the distribution
        X = F.one_hot(torch.tensor([char_idx], device=device), num_classes=n_vocab).type(torch.float32).unsqueeze(0)
        print(idx_v[char_idx], end='')



aVOı-lP
tlç!lç.lkGl"çxıÇbıGGıbaVp;g-lçLüyıkyıLrç pMAGl.ç-lklA!
çfP!MaVVÖ;gLı!lkçtMPMçElH
Eç!ü-!ĞçEüyĞPaVplŞlç.
-G!lgPlkçüPĞ!Ğgç.ıPklk!
rVmltĞçLı-ĞkPıkç-lAP
çıgoıkçş"BıVNP;";gçıkPı-ĞkıçĞAıgBıaVÖıkçfoçGıgĞçxıo-Ğk!ıçXıyGı"çLlçGlgrVğı-Ğgıçl-l.çişkGı!ĞçXıkçtıPf"MVÜıgĞGçfbMk!MçGıEĞbıbçtl-
!!laVcıkkıbç.fBlŞ
g!lçyıL!ĞLç.
-
gB
"aVjıGlgıçblHblktlçlEĞbĞç;"MkaV?ı-ıçtıgıçĞAtĞgçĞtıkçtly
rVTtbıçıP.Ğgçolk!
çfglçtĞ!bĞrVhilkçşBıGıgçEü-Algç.ı-Gçü-!ĞGrV:ç-ükç-lAG!lkçXıkĞE
gç!lkbltl-aVƏĞkçHı"ĞGçEü--ı-ĞGtıgçi;gıAçüP!ĞaVzıgĞgçfŞ!MçEĞkç!;XBıçElŞlçs!ı!ĞaVvfPPMç;tMgolçAşkıktıkç-lPGlçĞiıkrVQk!ĞkPıkçıtklg
gçĞPGĞ!lç-MHMaVp;g-lglçGıgırçtlP!
Ş
G!lP
.raaVöşk;AG;AçPfgıkĞgĞgçEĞkçtıgĞgçĞbbı-ırV?Ğ-PıkĞgĞçxıgoıLçfPlkçAĞÇçBlylgçoMk!MaVQAbıkçElkPlk
gçfçl!
!çoükıktĞg!ıgZVjfgklçxlXçşPi;";çüyı!ĞçEĞkçLıAl!ıgaVƏĞkçl-PlgçMkMgMçHıy!ıgçE;XilglrVR;ÇĞç-übĞkıçylPGlPlgç.ıPEĞgıaVRlgĞŞĞç.lb
A!
ç!ıkAlL
gĞçyıGrVTtıL!ıgçÇıPkĞ-lkçLĞtĞ-ĞgiĞgçfPMkaVuıkçLıbçLĞGĞç!üGı"ç !ıç.
P
gBI.
!
PlçlPrVƏlAtlç.lblkçiş-PıkĞrçıgPıkçfPGl"rVhŞ
"G
Arçü-ç;tb;çĞAEı