In [17]:
from gensim.models import Word2Vec
import pandas as pd
from nltk.tokenize import WordPunctTokenizer
from gensim.models import KeyedVectors
import json

In [20]:
def load_glove_model(glove_file_path):
    glove_model = KeyedVectors.load_word2vec_format(glove_file_path, binary=False, no_header=True)
    return glove_model

def combine_columns(row):
    return row['dialogue'] + ' ' + row['summary']

In [21]:
train = pd.read_json('../data/raw/dialogsum/dialogsum.train.jsonl', lines = True)[['dialogue', 'summary']]

In [28]:
tokenizer = WordPunctTokenizer()

with open('../data/contractions.json', 'r') as f:
    contractions = json.load(f)

def fix_contractions(text):
    tokens = text.split()
    cleaned = []
    for token in tokens:
        cleaned.append(contractions.get(token, token))
    return ' '.join(cleaned)

def tokenize(text):
    text = fix_contractions(text)
    tokens = tokenizer.tokenize(text)
    text = ' '.join(tokens).lower()
    text = text.replace('# person1 #', '#person1#')
    text = text.replace('# person2 #', '#person2#')
    text = text.replace('# person3 #', '#person3#')
    text = text.replace('# person4 #', '#person4#')
    text = text.replace('# person5 #', '#person5#')
    text = text.replace('# person6 #', '#person6#')
    text = text.replace('# person7 #', '#person7#')
    text = text.replace(' ,', ',')
    text = text.replace(' .', '.')
    text = text.replace(' ?', '?')
    text = text.replace(' !', '!')
    text = text.replace(" ' ", "'")
    text = text.replace("< ", "<")
    text = text.replace(" >", ">")
    return text.split()

In [23]:
train['summary'] = train['summary'].apply(lambda x: '<SOS> ' + x + ' <EOS>')


In [25]:
concat = train.apply(combine_columns, axis=1)

In [29]:
tokens = list(concat.apply(lambda x : tokenize(x)))

In [31]:
base_model = Word2Vec(vector_size=300, min_count = 20, epochs=20)
base_model.build_vocab(tokens)
total_examples = base_model.corpus_count

In [32]:
base_model.wv.most_similar('#person1#', topn=10)

[('will,', 0.19918687641620636),
 ('big,', 0.18685618042945862),
 ('across', 0.1850905865430832),
 ('admires', 0.1844509094953537),
 ('april', 0.17851804196834564),
 ('break.', 0.17759983241558075),
 ('article', 0.17517882585525513),
 ('no', 0.17481420934200287),
 ('fever.', 0.17363373935222626),
 ('(', 0.17225077748298645)]

In [33]:
corpus_path = '../embeds/GloVe/glove.corpus.300d.txt'
corpus_model = load_glove_model(corpus_path)
base_model.build_vocab([list(corpus_model.key_to_index.keys())], update=True)

In [34]:
corpus_model.most_similar('#person1#', topn=10)

[('tells', 0.668127179145813),
 ('asks', 0.6325467228889465),
 ('#person2#', 0.5329562425613403),
 ('wants', 0.5291920304298401),
 ('suggests', 0.5121132135391235),
 ('thinks', 0.5016119480133057),
 ('recommends', 0.49962735176086426),
 ('<sos>', 0.4889325201511383),
 ('advises', 0.4694846272468567),
 ('complains', 0.4649205505847931)]

In [35]:
base_model.train(tokens, total_examples=total_examples, epochs=base_model.epochs)
base_model_wv = base_model.wv


In [36]:
base_model_wv.most_similar('#person1#', topn=10)

[('#person2#', 0.714061975479126),
 ('sam', 0.6278362274169922),
 ('mary', 0.6145195364952087),
 ('#person3#', 0.6067740321159363),
 ('lucy', 0.5954647064208984),
 ('mike', 0.595004141330719),
 ('jeff', 0.5943966507911682),
 ('jack', 0.5858264565467834),
 ('george', 0.5835283994674683),
 ('amy', 0.5760015845298767)]

In [37]:
base_model_wv.save_word2vec_format('../models/GloVe-Word2Vec/glove.bin')