In [15]:
! wget https://raw.githubusercontent.com/amoudgl/short-jokes-dataset/master/data/reddit-cleanjokes.csv

--2020-09-23 15:12:18--  https://raw.githubusercontent.com/amoudgl/short-jokes-dataset/master/data/reddit-cleanjokes.csv
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 151.101.0.133, 151.101.64.133, 151.101.128.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|151.101.0.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 141847 (139K) [text/plain]
Saving to: ‘reddit-cleanjokes.csv.2’


2020-09-23 15:12:18 (5.65 MB/s) - ‘reddit-cleanjokes.csv.2’ saved [141847/141847]



In [39]:
import numpy as np
import torch 
from torch import nn
# Run on the GPU when PyTorch can see one, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Seed the torch and NumPy RNGs (weight init, LSTM dropout, and np.random.choice
# during sampling) so Restart & Run All is repeatable; some GPU kernels may still
# be nondeterministic.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)


In [42]:
class Model(nn.Module):
  """Word-level LSTM language model over the vocabulary of ``dataset``.

  Pipeline: embedding -> multi-layer LSTM -> linear projection to one logit per
  vocabulary word.

  NOTE(review): the LSTM is built without ``batch_first=True``, so it treats
  dim 0 of its input as the time axis and dim 1 as the batch axis. ``train`` and
  ``predict`` in this notebook pass ``(batch, sequence_length)``-shaped index
  tensors and size the state with ``init_state(sequence_length)`` /
  ``init_state(len(words))``, i.e. the recurrence actually runs across dim 0.
  That layout is kept here because those callers depend on it; switching to
  ``batch_first=True`` would require changing them together.

  Args:
    dataset: any object exposing ``uniq_words`` (the vocabulary list).
    lstm_size: hidden size of every LSTM layer (default 128).
    embedding_dim: width of the word embeddings (default 128).
    num_layers: number of stacked LSTM layers (default 3).
    dropout: dropout between stacked LSTM layers (default 0.2).
  """
  def __init__(self, dataset, lstm_size=128, embedding_dim=128, num_layers=3, dropout=0.2):
    super(Model, self).__init__()
    self.lstm_size = lstm_size
    self.embedding_dim = embedding_dim
    self.num_layers = num_layers

    n_vocab = len(dataset.uniq_words)
    self.embedding = nn.Embedding(
        num_embeddings=n_vocab,
        embedding_dim=self.embedding_dim
    )
    # The LSTM consumes embeddings, so its input width must be embedding_dim
    # (it used to be lstm_size, which only worked because both were 128).
    self.lstm = nn.LSTM(
        input_size=self.embedding_dim,
        hidden_size=self.lstm_size,
        num_layers=self.num_layers,
        dropout=dropout
    )
    # No hard-coded .cuda(): that crashed on CPU-only machines. Device placement
    # is left to the caller (e.g. ``Model(dataset).to(device)``).
    self.fc = nn.Linear(self.lstm_size, n_vocab)

  def forward(self, x, prev_state):
    """Map token indices ``x`` and an ``(h, c)`` state to ``(logits, new_state)``.

    ``logits`` has the shape of ``x`` plus a trailing vocabulary axis.
    """
    embed = self.embedding(x)
    output, state = self.lstm(embed, prev_state)
    logits = self.fc(output)
    return logits, state

  def init_state(self, seq_length):
    """Return zero ``(h0, c0)`` tensors of shape ``(num_layers, seq_length, lstm_size)``.

    ``seq_length`` fills the LSTM's batch slot (dim 1 of the state); see the
    class note on layout. The tensors are created on the same device as the
    model's parameters.
    """
    param_device = self.fc.weight.device
    shape = (self.num_layers, seq_length, self.lstm_size)
    return (
        torch.zeros(*shape, device=param_device),
        torch.zeros(*shape, device=param_device)
    )

In [30]:
import pandas as pd
from collections import Counter

In [31]:
# Show the working directory to confirm the CSV download(s).
!ls

reddit-cleanjokes.csv	 reddit-cleanjokes.csv.2
reddit-cleanjokes.csv.1  sample_data


In [43]:
class Dataset(torch.utils.data.Dataset):
  """Sliding-window next-word dataset built from the ``Joke`` column of a CSV.

  All jokes are joined with single spaces and split on ``" "``. Item ``i`` is a
  pair ``(x, y)`` of index tensors of length ``args.sequence_length``, where
  ``y`` is ``x`` shifted one word ahead.

  Args:
    args: object with a ``sequence_length`` attribute.
    csv_path: CSV file with a ``Joke`` column (default ``'reddit-cleanjokes.csv'``).
  """
  def __init__(self, args, csv_path='reddit-cleanjokes.csv'):
    self.args = args
    self.csv_path = csv_path
    self.words = self.load_words()
    # Vocabulary ordered by descending frequency; a word's index is its rank.
    self.uniq_words = self.get_uniq_words()

    self.index_to_word = {index: word for index, word in enumerate(self.uniq_words)}
    self.word_to_index = {word: index for index, word in enumerate(self.uniq_words)}

    self.words_indexes = [self.word_to_index[w] for w in self.words]

  def load_words(self):
    """Read ``self.csv_path`` and return the space-split words of all jokes."""
    train_df = pd.read_csv(self.csv_path)
    text = train_df['Joke'].str.cat(sep=' ')
    return text.split(" ")

  def get_uniq_words(self):
    """Return the distinct words, most frequent first (ties keep first-seen order)."""
    word_counts = Counter(self.words)
    return sorted(word_counts, key=word_counts.get, reverse=True)

  def __len__(self):
    # One window per start position that still has a full shifted target.
    # Clamped at 0 so a corpus shorter than the window yields an empty dataset
    # instead of a negative length (which len() rejects).
    return max(0, len(self.words_indexes) - self.args.sequence_length)

  def __getitem__(self, index):
    """Return ``(x, y)``: the window at ``index`` and the same window shifted by one."""
    return (
        torch.tensor(self.words_indexes[index:index + self.args.sequence_length], device=device),
        torch.tensor(self.words_indexes[index + 1:index + self.args.sequence_length + 1], device=device)
    )


In [44]:
import numpy as np
from torch import optim
import argparse
from torch.utils.data import DataLoader

In [53]:
def train(dt, model, args):
  """Train ``model`` on ``dt`` with Adam and cross-entropy for ``args.max_epochs`` epochs.

  Args:
    dt: dataset yielding ``(x, y)`` pairs of input/target index tensors.
    model: module exposing ``init_state(n)`` and
      ``forward(x, (h, c)) -> (logits, (h, c))``.
    args: needs ``batch_size``, ``max_epochs`` and ``sequence_length``;
      an optional ``lr`` attribute sets the learning rate (default 0.001).

  Prints ``{'epoch', 'batch', 'loss'}`` every 100 optimizer steps, counted across
  epochs. Returns None.
  """
  model.train()
  dataloader = DataLoader(dt, batch_size=args.batch_size)
  criterion = nn.CrossEntropyLoss()
  optimizer = optim.Adam(model.parameters(), lr=getattr(args, 'lr', 0.001))
  print_step = 0
  for epoch in range(args.max_epochs):
    # Fresh recurrent state at the start of every epoch.
    state_h, state_c = model.init_state(args.sequence_length)

    for batch, (x, y) in enumerate(dataloader):
      optimizer.zero_grad()
      y_pred, (state_h, state_c) = model(x, (state_h, state_c))
      # Move the vocabulary axis to dim 1, where CrossEntropyLoss expects class scores.
      loss = criterion(y_pred.transpose(1, 2), y)

      # Carry the state into the next batch but cut the autograd graph, so
      # backward() only goes through the current batch.
      state_h = state_h.detach()
      state_c = state_c.detach()

      loss.backward()
      optimizer.step()
      print_step += 1
      if print_step % 100 == 0:
        print({'epoch': epoch, 'batch': batch, 'loss': loss.item()})


In [62]:
def predict(dt, model, text, next_words=100):
  """Sample ``next_words`` words that continue the space-separated prompt ``text``.

  Each step feeds the last ``len(prompt)`` words to ``model``, takes the logits at
  the final position, and samples the next word from their softmax. The
  recurrent state is carried from step to step.

  Args:
    dt: object with ``word_to_index`` / ``index_to_word`` mappings.
    model: module exposing ``init_state(n)`` and
      ``forward(x, (h, c)) -> (logits, (h, c))``.
    text: prompt; split on ``" "`` exactly like the training corpus.
    next_words: number of words to generate.

  Returns:
    List of str: the prompt words followed by the sampled words.

  Raises:
    KeyError: if any prompt word is not in ``dt.word_to_index``.
  """
  model.eval()
  words = text.split(" ")
  unknown = [w for w in words if w not in dt.word_to_index]
  if unknown:
    raise KeyError('prompt words not in the training vocabulary: {}'.format(unknown))

  # Build inputs on whatever device the model's weights live on.
  input_device = next(model.parameters()).device
  state_h, state_c = model.init_state(len(words))

  # Pure inference: no autograd graph is needed while sampling.
  with torch.no_grad():
    for i in range(next_words):
      # Sliding window: words[i:] always has the original prompt length.
      x = torch.tensor([[dt.word_to_index[w] for w in words[i:]]], device=input_device)
      y_pred, (state_h, state_c) = model(x, (state_h, state_c))
      last_word_logits = y_pred[0][-1]
      p = torch.nn.functional.softmax(last_word_logits, dim=0).cpu().numpy()
      word_index = np.random.choice(len(last_word_logits), p=p)
      words.append(dt.index_to_word[word_index])
  return words

In [58]:
class aargs:
  """Hyperparameter bundle used by ``Dataset`` and ``train`` (stand-in for parsed CLI args).

  Keyword arguments override the defaults, e.g. ``aargs(max_epochs=5)``.
  """
  def __init__(self, batch_size=256, sequence_length=4, max_epochs=50):
    self.batch_size = batch_size            # windows per DataLoader batch
    self.sequence_length = sequence_length  # words per input/target window
    self.max_epochs = max_epochs            # passes over the dataset in train()

In [56]:
# Build the vocabulary and sliding windows from the CSV, then a model sized to
# that vocabulary, moved to the selected device.
dataset = Dataset(aargs())
model = Model(dataset).to(device)

In [59]:
# Reuse the exact args the dataset was built with, so batch/window settings
# (notably sequence_length) cannot drift from the dataset's.
train(dataset, model, dataset.args)

{'epoch': 1, 'batch': 5, 'loss': 4.678015232086182}
{'epoch': 2, 'batch': 11, 'loss': 4.471127033233643}
{'epoch': 3, 'batch': 17, 'loss': 4.442663669586182}
{'epoch': 4, 'batch': 23, 'loss': 4.640810012817383}
{'epoch': 5, 'batch': 29, 'loss': 4.648361682891846}
{'epoch': 6, 'batch': 35, 'loss': 4.33594274520874}
{'epoch': 7, 'batch': 41, 'loss': 4.023390293121338}
{'epoch': 8, 'batch': 47, 'loss': 4.4371562004089355}
{'epoch': 9, 'batch': 53, 'loss': 3.9657838344573975}
{'epoch': 10, 'batch': 59, 'loss': 3.821251630783081}
{'epoch': 11, 'batch': 65, 'loss': 3.6824791431427}
{'epoch': 12, 'batch': 71, 'loss': 3.744523763656616}
{'epoch': 13, 'batch': 77, 'loss': 3.7176854610443115}
{'epoch': 14, 'batch': 83, 'loss': 3.6293230056762695}
{'epoch': 15, 'batch': 89, 'loss': 3.539944887161255}
{'epoch': 17, 'batch': 1, 'loss': 3.2996034622192383}
{'epoch': 18, 'batch': 7, 'loss': 3.52400541305542}
{'epoch': 19, 'batch': 13, 'loss': 3.4713447093963623}
{'epoch': 20, 'batch': 19, 'loss': 3.3

In [65]:
# Generate a continuation of the knock-knock prompt and print it as one line.
joke = predict(dataset, model, text='Knock knock. Whos there?')
print(*joke)

Knock knock. Whos there? scientist 7 mysterious do celebrity hear much coffee banana night? Dayton go? well. Why did the alphabet try his Holland? mistake. depressed? Because of art fun out What did the grape say to the lid You Mr. start. Because he was my Service. classical single that's my sitting cereal... A Clown in the seeing-eye reason. I don't have them down. Why did the egg cross the shark from the great deer A orphans? Hose pork Why did the Bicycle get a picture? V Sinatra" Because she had in salt check of paper. What did the Gogh say to the mass
