Names: Jose Mazariegos & Cameron Knopp

In [1]:
# Make sure that execution of this cell doesn't return any errors. If it does, go the class repository and follow the environment setup instructions
import time
from collections import defaultdict, Counter
import string

import matplotlib.pyplot as plt
import numpy as np
import torch
from gensim.models import KeyedVectors
from torch.utils.data import Dataset, DataLoader
from nltk import word_tokenize

%matplotlib inline
plt.style.use('seaborn-paper')

In [2]:
def preprocess(data):
    """
    Args:
        data (str):
    Returns: a list of tokens

    """
    ### YOUR CODE BELOW ###
    tokens_with_punct = word_tokenize(data)
    tokens = [token for token in tokens_with_punct if token.isalpha()] # remove punctuation from tokens
    ### YOUR CODE ABOVE ###

    return tokens

In [3]:
class Vocabulary:
    def __init__(self, special_tokens=None):
        self.w2idx = {}
        self.idx2w = {}
        self.w2cnt = defaultdict(int)
        self.special_tokens = special_tokens
        if self.special_tokens is not None:
            self.add_tokens(special_tokens)

    def add_tokens(self, tokens):
        for token in tokens:
            ### YOUR CODE BELOW ###
            self.w2cnt[token]+=1   # increment count of given token
            if token not in self.w2idx:
                new_index = self.__len__()
                self.idx2w[new_index] = token
                self.w2idx[token] = new_index
            
            ### YOUR CODE ABOVE ###

    def add_token(self, token):
        ### YOUR CODE BELOW ###
        self.w2cnt[token]+=1   # increment count of given token
        if token not in self.w2idx:
            new_index = self.__len__()
            self.idx2w[new_index] = token
            self.w2idx[token] = new_index
        ### YOUR CODE ABOVE ###

    def prune(self, min_cnt=2):
        # do not forget to update the self.w2idx and self.idx2w dictionaries
        ### YOUR CODE BELOW ###
        for index in list(self.w2cnt):
            if self.w2cnt[index]<=2:
                del self.w2cnt[index]
                tok = self.w2idx[index] # get corresponding token for given word
                del self.w2idx[index]
                del self.idx2w[tok]
        ### YOUR CODE ABOVE ###

    def __contains__(self, item):
        return item in self.w2idx
    
    def __getitem__(self, item):
        if isinstance(item, str):
            return self.w2idx[item]
        elif isinstance(item , int):
            return self.idx2w[item]
        else:
            raise TypeError("Supported indices are int and str")
    
    def __len__(self):
        return(len(self.w2idx))

In [4]:
"""
Need to fix _generate_pairs
"""


class SkipGramDataset(Dataset):
    def __init__(self, data, vocab, skip_window=3):
        super().__init__()

        self.vocab = vocab
        self.data = data
        self.skip_window = skip_window

        self.pairs = self._generate_pairs(data, skip_window)

    def _generate_pairs(self, data, skip_window):
        """
        Args: input data (a list of tokens) and the window size
        Returns: all possible pairs for the SkipGram mode
        """
        pairs = []

        # do not forget to filter out pairs with out-of-vocabulary tokens 
        ### YOUR CODE BELOW ###
        
        # used this site as a reference for this part https://towardsdatascience.com/implementing-word2vec-in-pytorch-skip-gram-model-e6bae040d2fb
        
        indices = []
        for word in data:
            if word in self.vocab.w2idx:
                indices.append(self.vocab.w2idx[word])
        maxIdx = len(indices) - 1
        
        
        for centerIdx in range(1, len(indices)):
            
            for windowIdx in range(-skip_window, skip_window+1):
                contextIdx = centerIdx + windowIdx
                
                if contextIdx < 0 or contextIdx > maxIdx or contextIdx == centerIdx:
                    continue
                if self.vocab.w2cnt[self.vocab.idx2w[indices[centerIdx]]] == 0:
                    continue
                if self.vocab.w2cnt[self.vocab.idx2w[indices[contextIdx]]] == 0:
                    continue
                    
                    
                contextWord = self.vocab.idx2w[indices[centerIdx]]
                centerWord = self.vocab.idx2w[indices[contextIdx]]
                
                pairs.append((centerWord, contextWord))
                
        return pairs


    def __getitem__(self, idx):
        """
        Args:
            idx
        Returns:

        """
        ### YOUR CODE BELOW ###
        return self.pairs[idx]
        ### YOUR CODE ABOVE ###

        return pair

    def __len__(self):
        """
        Returns
        """
        return len(self.pairs)

In [6]:
class SkipGramModel(torch.nn.Module):
    def __init__(self, vocab_size, embedding_dim):
        super().__init__()
        """
        Args:
            vocab_size (int): vocabulary size
            embedding_dim (int): the dimension of word embeddings
        """
        ### INSERT YOUR CODE BELOW ###
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.linear = nn.Linear(1, vocab_size)
        
        ### INSERT YOUR CODE ABOVE ###

    def forward(self, inputs):
        """
        Perform the forward pass of the skip-gram model.
        
        Args:
            inputs (torch.LongTensor): input tensor containing batches of word ids [Bx1]
        Returns:
            outputs (torch.FloatTensor): output tensor with unnormalized probabilities over the vocabulary [BxV]
        """
        ### INSERT YOUR CODE BELOW ###
        embeds = self.embedding(inputs).view(1,-1)
       
        output = self.linear(embeds)
        
        #output = F.log_softmax(self.linear(embeds), dim=1)
        ### INSERT YOUR CODE ABOVE ###
        return outputs
    
    def save_embeddings(self, voc, path):
        """
        Save the embedding matrix to a specified path.
        
        Args:
            voc (Vocabulary): the Vocabulary object for id-to-token mapping
            path (str): the location of the target file
        """
        ### INSERT YOUR CODE BELOW ###
        embeddings = self.embeddings
        with open(path, 'w') as f:
            embeddings.save(f)
        ### INSERT YOUR CODE ABOVE ###
        print("Successfuly saved to {}".format(path))

In [7]:
# DATA PROCESSING #
with open('text8.txt') as f:
    data = f.read()
tokens = preprocess(data[:1000000])

# CONSTRUCTING VOCABULARY #
voc = Vocabulary()
voc.add_tokens(tokens)
voc.prune(5)
vocab_size = len(voc)

# TRAINING PARAMETERS #
embedding_dim = 128
skip_window = 2
batch_size = 512
lr = 0.1
num_epochs = 100
report_every = 5

# DATASET
dataset = SkipGramDataset(tokens, voc, skip_window=skip_window)
data_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

# MODEL
model = SkipGramModel(vocab_size=vocab_size, embedding_dim=embedding_dim)
if torch.cuda.is_available():
    model = model.cuda()
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adagrad(model.parameters(), lr=lr)

FileNotFoundError: [Errno 2] No such file or directory: 'text8.txt'

In [8]:
# TRAINING #
tick = time.time()
epoch_losses = []
for epoch_num in range(1, num_epochs + 1):
    batch_losses = []
    for i, batch in enumerate(data_loader):
        ### YOUR CODE BELOW ###
        # Zero the gradients
        model.zero_grad()
        # Extract the inputs and the targets
        inputs, targets = 
        
        # Transfer the inputs and the targets to GPUs, if available
        if torch.cuda.is_available():
            pass

        # Run the model
        outputs = None

        # Compute the loss
        loss = None
        
        # Backpropagate the error
        
        # Update the parameters

        # Append the loss
        batch_losses.append(None)
        ### YOUR CODE ABOVE ###
        
    epoch_loss = np.mean(np.array(batch_losses))
    epoch_losses.append(epoch_loss)

    if epoch_num % report_every == 0:
        tock = time.time()
        print("Epoch {}. Loss {:.4f}. Elapsed {:.0f} seconds".format(epoch_num, epoch_loss, tock-tick))

print("Total time elapsed: {:.0f} minutes".format((tock-tick)/60))

SyntaxError: invalid syntax (<ipython-input-8-9803fc646d92>, line 11)