In [None]:
import numpy as np
import pandas as pd
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch
import pickle

# Seed every RNG the notebook draws from: torch, and numpy
# (generate_tweet picks among the top-k characters with np.random).
SEED = 1
torch.manual_seed(SEED)
np.random.seed(SEED)

In [None]:
# Character <-> index vocabularies used to encode prompts and decode output.
# NOTE(security): pickle.load can execute arbitrary code; only load these
# files from a trusted source.
with open('ix_to_char_gospels (1).pkl', 'rb') as ix_to_char_file:
    ix_to_char = pickle.load(ix_to_char_file)
with open('char_to_ix_gospels (1).pkl', 'rb') as char_to_ix_file:
    char_to_ix = pickle.load(char_to_ix_file)
vocab_size = len(ix_to_char)

In [None]:
class RNN(nn.Module):
    """Character-level LSTM language model.

    Embeds character indices, runs them through a (possibly stacked) LSTM
    with a batch size of 1, and projects each LSTM output to
    log-probabilities over the vocabulary. The LSTM state lives on
    ``self.hidden`` and is carried across ``forward`` calls, so reset it with
    ``init_hidden()`` before starting a new sequence.

    Args:
        embed_dim: Size of the character embeddings.
        hidden_dim: Number of LSTM hidden units per layer.
        vocab_size: Number of distinct characters.
        n_layers: Number of stacked LSTM layers (default 1).
        dropout: Dropout probability applied to the output scores
            (default 0.2; dropout is only active in training mode).
    """

    def __init__(self, embed_dim, hidden_dim, vocab_size, n_layers=1, dropout=0.2):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.n_layers = n_layers
        self.embed = nn.Embedding(vocab_size, embed_dim)

        self.lstm = nn.LSTM(embed_dim, hidden_dim, num_layers=n_layers)

        self.hidden2char = nn.Linear(hidden_dim, vocab_size)
        self.dropout = nn.Dropout(dropout)
        self.hidden = self.init_hidden()

    def init_hidden(self):
        """Return a zeroed ``(h_0, c_0)`` pair for a single sequence.

        Both tensors have shape ``(n_layers, 1, hidden_dim)``. They are
        created on the same device and with the same dtype as the model's
        weights, so the state still matches the weights after
        ``model.to(device)`` or ``model.double()``.
        """
        weight = self.hidden2char.weight
        shape = (self.n_layers, 1, self.hidden_dim)
        return (weight.new_zeros(shape), weight.new_zeros(shape))

    def forward(self, tweet):
        """Score every position of one character sequence.

        Args:
            tweet: 1-D tensor of ``L`` character indices.

        Returns:
            Tensor of shape ``(L, vocab_size)`` with next-character
            log-probabilities at each position. As a side effect,
            ``self.hidden`` is replaced by the LSTM state after the last
            character.
        """
        seq_len = len(tweet)
        embeds = self.embed(tweet)
        # nn.LSTM expects (seq_len, batch, features); batch is fixed at 1.
        lstm_out, self.hidden = self.lstm(embeds.view(seq_len, 1, -1),
                                          self.hidden)
        scores = self.hidden2char(lstm_out.view(seq_len, -1))
        # NOTE: ReLU clamps every negative score to 0 before log_softmax.
        # Kept as-is: the loaded checkpoint was presumably trained with this
        # architecture, so removing it would change the model's predictions.
        output = F.relu(scores)
        output = self.dropout(output)
        return F.log_softmax(output, dim=1)

In [None]:
def generate_tweet(inputs, model, max_new_chars=280, top_k=4):
    """Generate text from ``model`` that continues the prompt ``inputs``.

    The prompt is fed through the model in one pass. Characters are then
    generated one at a time, each prediction fed back as the next input.
    Generation stops at a newline or after ``max_new_chars`` generated
    characters, so the result can be up to ``len(inputs) + max_new_chars``
    characters long.

    So that output varies between calls, the character after a space (the
    start of each word) is drawn uniformly from the ``top_k`` most likely
    characters. Every other character is the single most likely one.

    Args:
        inputs: Non-empty prompt string. Every character must be a key of the
            module-level ``char_to_ix`` mapping.
        model: An ``RNN``-like model exposing ``eval()``, ``init_hidden()``
            and a ``hidden`` attribute. Called on a 1-D index tensor, it must
            return log-probabilities of shape ``(len(input), vocab_size)``.
        max_new_chars: Maximum number of characters to generate (default 280).
        top_k: Number of candidates to sample from after a space (default 4).

    Returns:
        The prompt followed by the generated characters (newline excluded).

    Raises:
        ValueError: If ``inputs`` is empty.
        KeyError: If a prompt character is not in the vocabulary.
    """
    if not inputs:
        raise ValueError("inputs must be a non-empty prompt string")

    model.eval()  # disables dropout; the model is left in eval mode
    with torch.no_grad():
        char = torch.tensor([char_to_ix[c] for c in inputs], dtype=torch.long)
        model.hidden = model.init_hidden()
        output_tweet = inputs
        letter = inputs[-1]  # character whose successor is predicted next

        for _ in range(max_new_chars):
            log_probs = model(char.view(-1))
            # Only the last row predicts the character after everything fed so
            # far; earlier rows are predictions for positions inside the prompt.
            k = min(top_k, log_probs.size(-1))
            _, top_ix = log_probs[-1].topk(k)
            if letter == ' ':
                # Randomize the start of each word so tweets differ per call.
                next_ix = top_ix[np.random.randint(k)].item()
            else:
                next_ix = top_ix[0].item()
            letter = ix_to_char[next_ix]
            if letter == '\n':
                break
            output_tweet += letter
            char = torch.tensor([char_to_ix[letter]], dtype=torch.long)

    return output_tweet

In [None]:
# Architecture hyperparameters; these must match the saved checkpoint.
hidden_dim = 256
embed_size = 256

model = RNN(embed_size, hidden_dim, vocab_size, n_layers = 3)
# NOTE(security): torch.load unpickles the file and can run arbitrary code;
# only load checkpoints from a trusted source.
checkpoint = torch.load('trump_gospel1.tar', map_location='cpu')
# Strict (default) loading raises on any missing/unexpected key; the trailing
# `;` stops the notebook from echoing the returned key report.
model.load_state_dict(checkpoint['model_state_dict']);
model.eval();  # inference only: disables dropout
# Alternative weights file, stored as a bare state dict rather than a checkpoint:
# model.load_state_dict(torch.load('all_trump_model.pth', map_location='cpu'))

In [None]:
# Sample one tweet that starts with the letter "T".
generate_tweet(inputs='T', model=model)