# Use Model

In [17]:
#!aws s3 ls


In [18]:
#!aws s3 ls s3://project17-bucket-alex/stories-and-books-nlp/processed-data/


In [19]:
if True is False: # set to true only for the first un
    # Setup - Run only once per Kernel App
    %conda install openjdk -y

    # install PySpark
    %pip install s3fs pyarrow torch==1.13.0

    # restart kernel
    from IPython.core.display import HTML
    HTML("<script>Jupyter.notebook.kernel.restart()</script>")

In [20]:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

#import sys

#import os
import torch
import torch.nn as nn
#from torch.utils.data import DataLoader, TensorDataset, random_split
import torch.nn.functional as F
#import string
#from tqdm import tqdm
import pickle
#from io import BytesIO


# Let text columns show up to 150 characters before pandas truncates them.
pd.set_option('display.max_colwidth', 150) 
#pd.set_option('display.width', None)
# Show every column of a DataFrame instead of eliding the middle ones.
pd.set_option('display.max_columns', None)


In [21]:
# Pickled character -> index vocabulary, addressed relative to this notebook.
mapping_path = '../../data/nlp-data/char2idx.pkl'

# NOTE: pickle.load can execute arbitrary code while deserializing, so this
# file must come from a trusted source.
with open(mapping_path, mode='rb') as file:
    char2idx = pickle.load(file)

# Reverse lookup (index -> character), used to decode sampled indices.
idx2char = dict(zip(char2idx.values(), char2idx.keys()))




In [22]:

class RNN(nn.Module):
    """Sequence model over token indices: embedding -> LSTM -> linear layer.

    Args:
        vocab_size: number of distinct token indices; also the size of the
            output (logit) dimension.
        embedding_dim: width of each embedding vector.
        hidden_dim: LSTM hidden state size.
        n_layers: number of stacked LSTM layers.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_dim, n_layers):
        super().__init__()

        # The attribute names below end up in state_dict keys and in pickled
        # modules, so they are kept exactly as they were.
        self.embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=embedding_dim,
        )
        self.rnn = nn.LSTM(
            input_size=embedding_dim,
            hidden_size=hidden_dim,
            num_layers=n_layers,
            batch_first=True,
        )
        # Projects each LSTM output vector onto vocabulary logits.
        self.fc = nn.Linear(in_features=hidden_dim, out_features=vocab_size)

    def forward(self, x):
        """Return logits for every position of the index tensor `x`.

        With ``batch_first=True`` the LSTM treats the first axis of its input
        as the batch; the LSTM's final hidden/cell state is discarded.
        """
        embedded = self.embedding(x)
        lstm_out, _final_state = self.rnn(embedded)
        return self.fc(lstm_out)


In [23]:
def generate_text(model, start_string, length, temperature=1):
    """Sample `length` characters from `model`, continuing `start_string`.

    At every step the whole sequence generated so far is fed through the
    model, the logits at the last position are divided by `temperature`,
    and the next character index is drawn from the resulting softmax
    distribution.

    Args:
        model: torch module called as ``model(indices)`` with a
            ``(1, seq_len)`` long tensor; its output is indexed as
            ``[:, -1, :]`` to obtain the logits for the next character.
            The model is switched to eval mode and left in that mode.
        start_string: non-empty prompt. Characters that are not in the
            notebook-level ``char2idx`` mapping use its ``'UNK'`` entry.
        length: number of characters to sample; values <= 0 sample nothing.
        temperature: positive divisor applied to the logits before the
            softmax. Values below 1 sharpen the distribution, values above
            1 flatten it.

    Returns:
        ``start_string`` followed by the sampled characters.

    Raises:
        ValueError: if `start_string` is empty or `temperature` is not
            positive.
    """
    # Fail early with a clear message instead of an opaque error from deep
    # inside the model (empty sequence) or from multinomial (NaN/inf
    # probabilities caused by dividing the logits by zero).
    if not start_string:
        raise ValueError("start_string must contain at least one character")
    if temperature <= 0:
        raise ValueError(f"temperature must be positive, got {temperature}")

    # put the model in evaluation mode
    model.eval()

    # Build the input on the model's own device so this also works when the
    # model is not on the CPU; fall back to CPU for parameter-less modules.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    # convert start_string characters to indices, shape (1, len(start_string))
    unk_idx = char2idx['UNK']
    input_eval = torch.tensor(
        [[char2idx.get(ch, unk_idx) for ch in start_string]],
        dtype=torch.long,
        device=device,
    )

    generated_text = []

    with torch.no_grad():
        for _ in range(length):
            # logits for the position that follows the current sequence
            logits = model(input_eval)[:, -1, :] / temperature
            probabilities = F.softmax(logits, dim=-1)
            predicted_id = torch.multinomial(probabilities, num_samples=1)

            # append to the input for the next character prediction
            input_eval = torch.cat([input_eval, predicted_id], dim=1)

            generated_text.append(idx2char[predicted_id.item()])

    return start_string + ''.join(generated_text)




In [24]:
# Name (file stem) of the saved model to load from ../../data/ml-data/.
model_choice = 'rnn_model_500_0.001'

# map_location remaps the stored tensors onto the CPU while loading.
# NOTE(review): the loaded object is called directly as a model further
# down, so the file presumably holds a whole pickled module rather than a
# state_dict -- in that case the RNN class must already be defined.
model_saved = torch.load(
    f'../../data/ml-data/{model_choice}.pth',
    map_location=torch.device('cpu'),
)


In [None]:
# Sample 1000 characters that continue the prompt (default temperature).
prompt = "Once upon a time"
generated = generate_text(model_saved, start_string=prompt, length=1000)



In [29]:
# Print the generated text, replacing every `spaces_to_break`-th space with
# a line break so the output wraps at a roughly constant number of words.
spaces_to_break = 15
space_count = 0
pieces = []
for char in generated:
    if char == ' ':
        space_count += 1
        if space_count % spaces_to_break == 0:
            # this space becomes the line break itself
            pieces.append('\n')
            continue
    pieces.append(char)

# Emit everything at once; no trailing newline, as with per-character printing.
print(''.join(pieces), end='')


Once upon a time  e earfbda eauna   uafvar iwnhdt adsttheoiesyorcta dpaytgmydeh 
ntt  oeity itdincaldmo l  hnfis. et l  aag ds naueednpet ah icw.aton
nf aoghyetgswoeh htddis poooentm stutaoin  etfti miecrgiairira.l   heitm n tsim m mmu
aemeii eoeantd ylvstiui.efptit dirfclsrmrail0tod o eatd sribewctesieetgio aty rne . roeoe eetiie toeae t 
m psggotr e mt fmthgyhenelshe hnu?ee i oaopoiob ft yenle li gyifmiarynsayf f  t
oo ta  sfoet.e toan lidhrieleeerasi do s s.rveldfyemgioardtegeidoahgralwruinpaacc mspretra   eeo aeatopneegdpltesbe. amgk
gaef?a n sw ilas turtnmeo htye  u wasoa eurcmn eemsetyauleeathnpns ietgd n v hnedaieieemnhso
ahnmy ol emecoedt .koeeiie eitecsicta e i ibeiwohonh rinoprokwn ihbdoiah tdeedmncepo uteheiodf hka n oeundorti.eanio
diy etees dt iam eayaenero a dd.sttectndidhew cleh.ai namieng ac to0l t  yenn m
sswetopc oet mao eusebydat sllha rh dnpe e ae  piairitioi  .th. slawh rfrlytkeees
ige scltsr pr dederftcre cbadfteba vhcco.ayis.aenawidneeedstd tdaheytdsi  ede  fadas