In [2]:
import os
import math
import sys
import gzip
import io
import re
import pandas as pd
import time 
from nltk.tokenize import word_tokenize
import string

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset,DataLoader

dataset = "./data/wocka.json"
joke_df = pd.read_json(dataset)

def load_vectors(fname="./crawl-300d-2M-subword.vec"):
    fin = io.open(fname, 'r', encoding='utf-8', newline='\n', errors='ignore')
    n, d = map(int, fin.readline().split())
    data = {}
    for line in fin:
        tokens = line.rstrip().split(' ')
        data[tokens[0]] = map(float, tokens[1:])
    return data

fasttext = load_vectors()

In [3]:

def ret_embedding_weights(vocab):
    global fasttext
    embeddings = torch.rand(len(vocab), 300) * 0.5 - 0.25
    for word,idx in vocab.word2idx.items():
        print(word)
        if fasttext.get(word) != None:
            #print("got it")
            embeddings[idx] = torch.tensor(list(fasttext.get(word)))
    return embeddings

class Vocabulary(object):
    """Simple vocabulary wrapper."""
    def __init__(self):
        self.word2idx = {'<unk>' : 0}
        self.idx2word = {}
        self.idx = 1
 
    def add_word(self, word):
        if not word in self.word2idx:
            self.word2idx[word] = self.idx
            self.idx2word[self.idx] = word
            
            self.idx += 1
 
    def __call__(self, word):
        if not word in self.word2idx: 
            return self.word2idx['<unk>']
        return self.word2idx[word]
 
    def __len__(self):
        return len(self.word2idx)    

#TextData Class to load the sentences 
class JokeData(Dataset):
    def __init__(self,joke_df):
        self.vocab = Vocabulary()
        
        #joke_df = pd.read_json(dataset)
        joke_df = joke_df[joke_df.category == "Yo Momma"]
        joke_text = joke_df["body"].values
        
        joke_text_words = []
        
        for line in joke_text:
            tokenized_line = self.tokenize(line)
            tokenized_line.append("<e>")
            tokenized_line = ["<s>"] + tokenized_line
            
            for word in tokenized_line:
                self.vocab.add_word(word)
    
            joke_text_words.append(tokenized_line)
        
        for line in joke_text_words:
            for word in line:
                word = self.vocab(word)

    def get_vocab(self):
        return self.vocab
    
    def tokenize(self,line):
        tokens = word_tokenize(line)         #Clean and tokenize the joke data
        tokens = [w.lower() for w in tokens]  # convert to lower case
        
        table = str.maketrans('', '', string.punctuation)  # remove punctuation from each word
        stripped = [w.translate(table) for w in tokens]
        
        words = [word for word in tokens if word.isalpha()] # remove remaining tokens that are not alphabetic
        return words
            
    def __len__(self):
        return len(self.text_examples)
    
    def __getitem__(self,idx):
        text,label = self.text_examples[idx]
        return torch.tensor(text) , torch.tensor(int(label) - 1, dtype=torch.long)
        
    def get_word_dict(self):
        return self.word_net
    
class JokeGenerator(nn.Module):
    def __init__(self):
        super(JokeGenerator, self).__init__()
        self
        

def train_model(embeddings_file, train_text_file, train_label_file, model_file):
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(device)

    batch_size = 1
    
    text_dataset = TextData(embeddings_file, train_text_file, train_label_file)
    dataloader = DataLoader(text_dataset, batch_size=batch_size,shuffle=True)
    word_net = text_dataset.get_word_dict()
    vocab_size,embed_dim,embedding_w = load_embedding_file(embeddings_file,word_net)
    model = TextModel(vocab_size,embed_dim,embedding_w)
    model.to(device)
    
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(),lr=0.001)
    
    for epoch in range(20):
        total_loss = 0
        for i,data in enumerate(dataloader):
            input_text , label = data
            #print(input_text)
            #print(label)
            #print(len(input_text))
            input_text, label = input_text.to(device), label.to(device)

            #print(input_text)
            #print(label)
            y_pred = model(input_text)
            #print("label, y_pred ", label , y_pred)
            #print("y_red : ", y_pred)
            loss = criterion(y_pred, label)
            total_loss = total_loss + loss.item()
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        print('epoch: ', epoch,' loss: ', total_loss/1000)
    
    save_file = {"word_dict" : word_net , "model" : model.state_dict()}
    torch.save(save_file, model_file) #Save the model

if __name__ == "__main__":
    dataset = "data/wocka.json"
    jkd = JokeData(joke_df)
    vocab = jkd.get_vocab()
    embed_weights = ret_embedding_weights(vocab)
    print(embed_weights)
    

<unk>
<s>
yo
got it
momma
got it
is
got it
so
got it
stupid
got it
she
got it
took
got it
the
got it
pepsi
got it
challenge
got it
and
got it
chose
got it
jif
got it
<e>
fat
got it
that
got it
to
got it
get
got it
her
got it
out
got it
of
got it
a
got it
phone
got it
booth
got it
we
got it
had
got it
grease
got it
thighs
got it
throw
got it
twinkie
got it
into
got it
street
got it
your
got it
ugly
got it
not
got it
bald
got it
it
got it
hair
got it
running
got it
away
got it
from
got it
face
got it
mama
got it
big
got it
call
got it
paint
got it
toenails
got it
went
got it
haunted
got it
house
got it
came
got it
with
got it
job
got it
application
got it
got
got it
hit
got it
by
got it
parked
got it
car
got it
when
got it
joined
got it
an
got it
contest
got it
they
got it
said
got it
sorry
got it
no
got it
professionals
got it
just
got it
after
got it
was
got it
born
got it
mother
got it
what
got it
treasure
got it
dad
got it
yeah
got it
let
got it
go
got it
bury
got it
two
got it
guys


got it
beautician
got it
quote
got it
work
got it
each
got it
kiss
got it
goodbye
got it
marilyn
got it
manson
got it
snake
got it
mountain
got it
knew
got it
cuz
got it
stopped
got it
harry
got it
knowles
got it
refused
got it
date
got it
embalmed
got it
laxatives
got it
sold
got it
extra
got it
days
got it
dress
got it
tony
got it
blair
got it
papa
got it
throws
got it
stick
got it
fetches
got it
scared
got it
stitching
got it
frankenstein
got it
tie
got it
steak
got it
dogs
got it
father
got it
met
got it
shadow
got it
gave
got it
cash
got it
pissed
got it
drunk
got it
doctors
got it
incubator
got it
tinted
got it
windows
got it
managers
got it
rats
got it
ankles
got it
middle
got it
hitting
got it
slicky
got it
willy
got it
clinton
got it
sleep
got it
doc
got it
smacked
got it
very
got it
crazy
got it
pranks
got it
hilarious
got it
gags
got it
named
got it
parking
got it
fix
got it
breaking
got it
news
got it
miscellaneous
got it
shoes
got it
license
got it
plates
got it
aint
got i

got it
mice
got it
chairs
got it
pale
got it
leap
got it
mankind
got it
cakes
got it
thinner
got it
banned
got it
sewage
got it
facility
got it
sanitation
got it
worries
got it
copied
got it
exam
got it
less
got it
heavy
got it
titanic
got it
bush
got it
hitler
got it
abraham
got it
lincoln
got it
mistaken
got it
weak
got it
bar
got it
baby
got it
hangs
got it
dry
got it
pressed
got it
worldwide
got it
parents
got it
ta
got it
clifford
got it
doggystyle
got it
rated
got it
e
got it
larry
got it
jerry
got it
least
got it
person
got it
oldest
got it
joke
got it
river
got it
ohio
got it
related
got it
wondering
got it
married
got it
cheated
got it
stops
got it
earthquake
got it
visited
got it
starts
got it
universe
got it
marbles
got it
caught
got it
stealing
got it
masks
got it
america
got it
bled
got it
angelina
got it
jolie
got it
press
got it
plastic
got it
surgeon
got it
sued
got it
mime
got it
scream
got it
renamed
got it
today
got it
laps
got it
far
got it
shox
got it
shocked
got i

In [62]:
if fasttext.get("job"):
    print(1)
    
print(list(fasttext.get("job")))
print(list(fasttext["job"]))

1
[]
[]


In [65]:
list(fasttext["the"])

[0.012,
 -0.0268,
 0.1121,
 0.0277,
 0.0046,
 0.0338,
 0.1396,
 0.0112,
 -0.0631,
 0.063,
 0.0315,
 0.0135,
 0.0028,
 0.0145,
 -0.0233,
 -0.0001,
 -0.01,
 0.0161,
 -0.0915,
 0.2161,
 0.1492,
 -0.008,
 0.0027,
 0.0014,
 -0.0194,
 -0.0071,
 -0.1985,
 -0.0052,
 -0.0087,
 0.0071,
 -0.0058,
 0.0374,
 0.0041,
 -0.0014,
 -0.0344,
 0.0084,
 -0.0064,
 -0.0255,
 -0.0164,
 0.0063,
 -0.077,
 -0.2474,
 -0.0152,
 -0.0103,
 0.0446,
 0.0823,
 0.0314,
 -0.0163,
 0.0748,
 0.0186,
 0.0002,
 -0.0091,
 -0.0039,
 0.011,
 -0.1585,
 -0.0065,
 0.1579,
 0.0313,
 0.2417,
 -0.0011,
 0.0023,
 0.0633,
 -0.0794,
 -0.0042,
 -0.0245,
 -0.0084,
 0.0172,
 -0.1142,
 -0.0102,
 0.0266,
 0.0107,
 -0.0213,
 0.0241,
 0.0269,
 -0.0057,
 -0.0005,
 -0.0187,
 -0.0189,
 0.0416,
 0.0194,
 -0.0106,
 -0.0279,
 -0.1751,
 0.011,
 0.0023,
 0.0173,
 -0.0258,
 0.0249,
 0.0833,
 -0.0102,
 -0.0005,
 0.0171,
 0.0427,
 0.1354,
 0.0042,
 0.2789,
 -0.0311,
 0.0131,
 0.0069,
 0.022,
 -0.0192,
 -0.0181,
 -0.0155,
 -0.0164,
 0.0419,
 -0.0036,
 -0.