In [1]:
import torch
import pandas as pd
from collections import Counter
import argparse
import numpy as np
from torch import nn, optim
from torch.utils.data import DataLoader
from os import listdir
from os.path import isfile, join
import nltk
from nltk.tokenize import sent_tokenize, word_tokenize
from nltk.stem import WordNetLemmatizer
import string

In [2]:
nltk_stop_words = nltk.corpus.stopwords.words('english')

In [3]:
if torch.cuda.is_available():  
    dev = "cuda:0" 
else:  
    dev = "cpu"  
print(dev)

cpu


In [4]:
def remove_gutenberg_preamble(text):
    try:
        return text[text.index("***", text.index("START OF THIS PROJECT GUTENBERG"))+3:]
    except:
        return text[text.index("***", text.index("START OF THE PROJECT GUTENBERG"))+3:]
    
def remove_gutenberg_postscript(text):
    try:
        return text[:text.index("***", text.index("END OF THE PROJECT GUTENBERG"))-3]
    except:
        return text[:text.index("***", text.index("END OF THIS PROJECT GUTENBERG"))-3]
    
def get_label(file_loc: str) -> int:
    science_fiction = 0
    horror = 1
    adventure = 2
    humor = 3
    western = 4
    mystery = 5
    gothic = 6
    if 'gothic' in file_loc:
        return gothic
    if 'western' in file_loc:
        return western
    if 'mystery' in file_loc:
        return mystery
    if 'humor' in file_loc:
        return humor
    if 'adventure' in file_loc:
        return adventure
    if 'horror' in file_loc:
        return horror
    if 'scifi' in file_loc:
        return science_fiction
    
def create_by_newline(file_loc: str, df):
    new_df = pd.DataFrame()
    try:
        f = open(file_loc, "r")
        text = f.read()
        text = remove_gutenberg_preamble(text)
        text = remove_gutenberg_postscript(text)
        arr = text.split('\n\n')
        arr = [a for a in arr if len(a) > 2]
        label = get_label(file_loc)
        new_df['text'] = arr
        new_df['label'] = label
        df = df.append(new_df[5:-3])
        return df
    except:
        return df

In [5]:
train_on = [join('./set0',f) for f in listdir('./set0') if (isfile(join('./set0', f)) and 'western' in f)]
train_on

['./set0/western18.txt',
 './set0/western19.txt',
 './set0/western09.txt',
 './set0/western08.txt',
 './set0/western05.txt',
 './set0/western11.txt',
 './set0/western10.txt',
 './set0/western04.txt',
 './set0/western12.txt',
 './set0/western06.txt',
 './set0/western07.txt',
 './set0/western13.txt',
 './set0/western17.txt',
 './set0/western03.txt',
 './set0/western02.txt',
 './set0/western16.txt',
 './set0/western00.txt',
 './set0/western14.txt',
 './set0/western15.txt',
 './set0/western01.txt']

In [6]:
new_df = pd.DataFrame()
for file in train_on:
    new_df = create_by_newline(file, new_df)

In [7]:
new_df.reset_index(drop=True, inplace=True)

In [8]:
wordnet_lemmatizer = WordNetLemmatizer()
# vectorizing function to able to call on list of tokens
lemmatize_words = np.vectorize(wordnet_lemmatizer.lemmatize)


for index, i in enumerate(new_df['text']):
    tokens = word_tokenize(i)
    tokens = [t for t in tokens if t not in string.punctuation]
    tokens = [t for t in tokens if t not in nltk_stop_words]
    
    if(len(tokens) < 1):
        continue
    lemmatized_text = ' '.join(lemmatize_words(tokens))
    new_df.at[index,'text'] = lemmatized_text.lower()

In [11]:
# new_df = new_df[:len(new_df) // 4]

In [12]:
new_df.shape

(40953, 2)

In [33]:

class Model(nn.Module):
    def __init__(self, dataset):
        super(Model, self).__init__()
        self.lstm_size = 128
        self.embedding_dim = 128
        self.num_layers = 7

        n_vocab = len(dataset.uniq_words)
        self.embedding = nn.Embedding(
            num_embeddings=n_vocab,
            embedding_dim=self.embedding_dim,
        )
        self.lstm = nn.LSTM(
            input_size=self.lstm_size,
            hidden_size=self.lstm_size,
            num_layers=self.num_layers,
            dropout=0.2,
        )
        self.fc = nn.Linear(self.lstm_size, n_vocab)

    def forward(self, x, prev_state):
        embed = self.embedding(x)
        output, state = self.lstm(embed, prev_state)
        logits = self.fc(output)
        return logits, state

    def init_state(self, sequence_length):
        return (torch.zeros(self.num_layers, sequence_length, self.lstm_size),
                torch.zeros(self.num_layers, sequence_length, self.lstm_size))

In [14]:

class Dataset(torch.utils.data.Dataset):
    def __init__(
        self,
        epochs, batch_size, seq_length
    ):
        self.epochs = epochs
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.words = self.load_words()
        self.uniq_words = self.get_uniq_words()

        self.index_to_word = {index: word for index, word in enumerate(self.uniq_words)}
        self.word_to_index = {word: index for index, word in enumerate(self.uniq_words)}

        self.words_indexes = [self.word_to_index[w] for w in self.words]

    def load_words(self):
        train_df = new_df
        text = train_df['text'].str.cat(sep=' ')
        return text.split(' ')

    def get_uniq_words(self):
        word_counts = Counter(self.words)
        return sorted(word_counts, key=word_counts.get, reverse=True)

    def __len__(self):
        return len(self.words_indexes) - self.seq_length

    def __getitem__(self, index):
        return (
            torch.tensor(self.words_indexes[index:index+self.seq_length]),
            torch.tensor(self.words_indexes[index+1:index+self.seq_length+1]),
        )

In [15]:

def train(dataset, model, epochs, batch_size, seq_length):
    model.train()

    dataloader = DataLoader(dataset, batch_size=batch_size)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    for epoch in range(epochs):
        state_h, state_c = model.init_state(seq_length)

        for batch, (x, y) in enumerate(dataloader):
            optimizer.zero_grad()

            y_pred, (state_h, state_c) = model(x, (state_h, state_c))
            loss = criterion(y_pred.transpose(1, 2), y)

            state_h = state_h.detach()
            state_c = state_c.detach()

            loss.backward()
            optimizer.step()

            print({ 'epoch': epoch, 'batch': batch, 'loss': loss.item() })

In [16]:
def predict(dataset, model, text, next_words=100):
    model.eval()

    words = text.split(' ')
    state_h, state_c = model.init_state(len(words))
    
    for i in range(0, next_words):
        x = torch.tensor([[dataset.word_to_index[w] for w in words[i:]]])
        y_pred, (state_h, state_c) = model(x, (state_h, state_c))

        last_word_logits = y_pred[0][-1]
        p = torch.nn.functional.softmax(last_word_logits, dim=0).detach().numpy()
        word_index = np.random.choice(len(last_word_logits), p=p)
        words.append(dataset.index_to_word[word_index])

    return words

In [59]:
# parser = argparse.ArgumentParser()
# parser.add_argument('--max-epochs', type=int, default=10)
# parser.add_argument('--batch-size', type=int, default=256)
# parser.add_argument('--sequence-length', type=int, default=4)
# args = parser.parse_args()

epochs = 20
batch_size = 256
seq_length = 4

dataset = Dataset(epochs, batch_size, seq_length)
model = Model(dataset)

train(dataset, model, epochs, batch_size, seq_length)
print('DONE')

{'epoch': 0, 'batch': 0, 'loss': 10.45074462890625}
{'epoch': 0, 'batch': 1, 'loss': 10.443279266357422}
{'epoch': 0, 'batch': 2, 'loss': 10.438493728637695}
{'epoch': 0, 'batch': 3, 'loss': 10.430967330932617}
{'epoch': 0, 'batch': 4, 'loss': 10.420955657958984}
{'epoch': 0, 'batch': 5, 'loss': 10.408860206604004}
{'epoch': 0, 'batch': 6, 'loss': 10.407341003417969}
{'epoch': 0, 'batch': 7, 'loss': 10.3621826171875}


KeyboardInterrupt: 

In [128]:
print(predict(dataset, model, text='howdy all the name is ken'))

['howdy', 'all', 'the', 'name', 'is', 'ken', 'meet', 'mine', 'range', 'but', 'probably', 'became', 'exceedingly', 'bitter', 'sinister', 'undeveloped', 'his', 'mind', 'code', 'marry', 'routine', 'indignation', 'punished', 'killing', 'absence', 'order', 'the', 'got', 'oregon', 'excite', 'wipe', 'guilty', 'climax', 'to', 'sheep', 'avidly', 'ruined', '``', 'ellen', 'stole', 'gift', 'dad', 'listen', "'round", '--', 'bear', 'old', 'young', 'red', 'he', 'ever', '--', 'died', '``', "'wal", 'jean', 'that', 'gunman', "''", 'said', '``', 'reckon', 'i', "n't", 'love', 'he', 'goin', 'n', '--', 'grievance', "''", '``', 'wal', 'rocky', 'ta', 'nez', 'meadow', "''", '``', 'how', '--', '--', 'stand', 'thet', 'u', "''", 'ejaculated', 'rancher', 'jean', '``', 'my', 'name', "'s", 'got', 'got', 'father', 'maskin', 'callin', "''", '``', 'where', "''", '``', 'bah', 'color', "''"]


In [129]:
torch.save(model.state_dict(), './model-dict-slow')

In [24]:
epochs = 20
batch_size = 256
seq_length = 4

dataset = Dataset(epochs, batch_size, seq_length)
model2 = Model(dataset)
model2.load_state_dict(torch.load('./model-dict'))

RuntimeError: Error(s) in loading state_dict for Model:
	Missing key(s) in state_dict: "lstm.weight_ih_l3", "lstm.weight_hh_l3", "lstm.bias_ih_l3", "lstm.bias_hh_l3". 

In [19]:
checkpoint = torch.load('./model-dict.pt',map_location=torch.device('cpu'))

In [20]:
for key in checkpoint.keys():
    print(checkpoint[key].shape)

torch.Size([34434, 128])
torch.Size([512, 128])
torch.Size([512, 128])
torch.Size([512])
torch.Size([512])
torch.Size([512, 128])
torch.Size([512, 128])
torch.Size([512])
torch.Size([512])
torch.Size([512, 128])
torch.Size([512, 128])
torch.Size([512])
torch.Size([512])
torch.Size([512, 128])
torch.Size([512, 128])
torch.Size([512])
torch.Size([512])
torch.Size([512, 128])
torch.Size([512, 128])
torch.Size([512])
torch.Size([512])
torch.Size([34434, 128])
torch.Size([34434])


In [40]:
model = Model(dataset)

In [41]:
model.state_dict

<bound method Module.state_dict of Model(
  (embedding): Embedding(34434, 128)
  (lstm): LSTM(128, 128, num_layers=3, dropout=0.2)
  (fc): Linear(in_features=128, out_features=34434, bias=True)
)>

In [34]:
z = torch.load('./model-dict12.pt',map_location=torch.device('cpu'))
model2 = Model(dataset)
model2.load_state_dict(z)

<All keys matched successfully>

In [58]:
for i in range(10):
    print(predict(dataset, model2, text='shoot the dang ol rabbit'))

['shoot', 'the', 'dang', 'ol', 'rabbit', 'three', "b'gosh", 'danced', 'agony', 'cent', 'shone', 'motion', 'saddle', 'pulled', 'taken', 'flight', 'practicing', 'nurse', 'portion', 'clear', '...', 'all', 'from', 'ridge', 'earnest', 'established', 'somethin', 'requires', 'passionate', 'passage', 'dusty', 'convinced', 'another', 'weird', 'meet', 'snarled', 'ranch', 'spirit', 'strolled', 'gray', 'dying', 'told', 'left', 'quoted', 'till', 'i', 'felt', 'four', 'open', 'open', 'reflection', 'heart', 'some', 'i', 'astor', 'naab', 'money', 'sash', 'tickled', 'leaping', 'i', 'public', 'tribe', 'year', 'thank', 'around', 'darkness', 'turkey', 'without', '”', 'gang', 'circle', 'mouth', 'creamy', 'woman', 'vivacity', 'newcomer', 'along', 'upon', 'sorry', 'presently', 'big', 'hunt', 'something', 'believe', 'where', 'afternoon', 'indorsement', 'again', 'windfall', 'courage', 'brave', 'countenance', 'slope', 'thar', 'footprint', 'fur', 'refused', 'wind', 'tree', 'running', 'sure', 'hunter', 'anger', 'p

['shoot', 'the', 'dang', 'ol', 'rabbit', 'gauged', 'eye', 'time', 'whip', 'gentleman', 'leaving', 'all', 'lingered', 'fell', 'evidently', "'s", 'corpse', 'way', 'sight', 'appeared', 'see', 'edge', 'jones', 'constantly', 'could', 'several', 'another', '“', 'cut', 'track', 'gaining', 'case', "'s", 'circumstance', 'fast', 'pale', 'unreal', 'giant', 'disarmed', 'like', 'expects', 'whispered', 'pistol', 'continued', 'mounting', 'trotted', 'withers', 'aftermath', 'repair', '``', 'managed', 'see', 'first', 'colder', 'progress', '--', 'big', '....', 'header', 'wreck', 'justice', 'ai', 'frigid', 'race', 'hand', 'space', 'crossed', 'but', 'wave', 'concerning', 'every', 'spectre', 'saw', 'irresistibly', 'dust', 'heat', 'inheritor', 'long', 'upon', 'stroke', 'close', 'shouted', 'escape', 'herbage', 'hunter', 'spot', 'running', 'existed', 'expect', 'labored', 'but', 'quiet', 'states', 'screen', 'slipped', "'ll", 'slightest', 'life', 'discharged', 'go', 'soft', 'mustang', 'rock', 'left', 'sixty']
