# Lab 10: Word Embeddings

## Introduction
In this lab you'll learn how [Skip-Gram](https://papers.nips.cc/paper/5021-distributed-representations-of-words-and-phrases-and-their-compositionality.pdf) is implemented. The skip-gram model trains a single-layer neural network to predict the surrounding words given a center word. The goal is a network that learns which words are likely to appear in the context of a given word. The training data consists of (center, context) word pairs, where the context words are those that appear within a fixed window around the center word. The example below shows the word pairs created using different center words and a window size of 2.

![trainingWordPairs](http://mccormickml.com/assets/word2vec/training_data.png)
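
As a rough illustration of the pairing scheme in the figure, the sketch below builds (center, context) pairs for one toy sentence. The function name `skipgram_pairs` and the example sentence are made up for this illustration and are not part of the lab code.

```python
def skipgram_pairs(tokens, window_size=2):
    """Return (center, context) pairs for every token in `tokens`."""
    pairs = []
    for center_pos, center in enumerate(tokens):
        # clip the window so it never runs past either end of the sentence
        lo = max(0, center_pos - window_size)
        hi = min(len(tokens), center_pos + window_size + 1)
        for context_pos in range(lo, hi):
            if context_pos != center_pos:  # a word is not its own context
                pairs.append((center, tokens[context_pos]))
    return pairs

# hypothetical toy sentence, already tokenized
sentence = "the quick brown fox jumps".split()
pairs = skipgram_pairs(sentence, window_size=2)
print(pairs[:2])
```

Words near the edges of a sentence simply get fewer pairs, since part of their window falls outside the sentence.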


[Cool](https://www.youtube.com/watch?v=A8q8PXoJwVk). So how does this model actually work? The model has two main moving parts: two weight matrices, $V$ and $V^{\prime}$, which hold the center word and context word embeddings respectively. The matrices have separate weights, and each is in $\mathbb{R}^{v \times e}$, where $v$ is the size of the vocabulary and $e$ is the embedding dimension (a hyperparameter you choose).

The model learns to maximize the following objective. In practice we minimize its negative, which is what the `forward` method of the model returns.

$$L = \log(\sigma(v^{\prime}_{c_o}v_{c_e}^{T})) + \sum_{c_o,c_e \in \bar{D}} \log(\sigma(-v^{\prime}_{c_o}v_{c_e}^{T}))$$

Here $c_o$ and $c_e$ are the context and center words respectively, $v$ and $v^{\prime}$ are the center and context embeddings respectively, and $\bar{D}$ is the set of word pairs in which $c_o$ is a negatively sampled context word.
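
To make the formula concrete, here is a minimal sketch that evaluates $L$ for a single (center, context) pair with two negative samples, using plain Python lists as embeddings. The helper names (`sigmoid`, `dot`, `pair_objective`) and the toy vectors are assumptions made for this example only.

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def pair_objective(center_vec, context_vec, negative_vecs):
    """L for one pair: log-sigmoid of the positive score plus the
    log-sigmoid of the negated score of every negative sample."""
    positive_term = math.log(sigmoid(dot(context_vec, center_vec)))
    negative_term = sum(math.log(sigmoid(-dot(neg, center_vec)))
                        for neg in negative_vecs)
    return positive_term + negative_term

# toy 2-dimensional embeddings (made-up numbers)
center = [1.0, 0.0]                    # row of V for the center word
context = [2.0, 0.0]                   # row of V' for the true context word
negatives = [[-2.0, 0.0], [0.0, 1.0]]  # rows of V' for the negative samples

value = pair_objective(center, context, negatives)
print(value)
```

The objective is always negative and moves toward 0 as the true context scores higher and the negative samples score lower against the center word.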

## Negative Sampling
Please refer to [this tutorial](http://mccormickml.com/2017/01/11/word2vec-tutorial-part-2-negative-sampling/) to learn more about negative sampling. You don't have to build the unigram table, but you'll need to know how it's used.
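
For intuition about how the table is used, the sketch below builds a small unigram table in which each word index is repeated in proportion to its count raised to the 0.75 power, then draws negative samples by picking random entries. The function name `build_unigram_table`, the table length, and the toy counts are assumptions for illustration; the lab's own table is built for you.

```python
import random

def build_unigram_table(word_counts, table_length=1000, power=0.75):
    """Repeat each word index in proportion to count ** power."""
    weights = [(idx, count ** power)
               for idx, (word, count) in enumerate(word_counts)]
    total = sum(weight for _, weight in weights)
    table = []
    for idx, weight in weights:
        # int() truncates, so the table may be slightly shorter than table_length
        table.extend([idx] * int(weight / total * table_length))
    return table

# hypothetical (word, count) pairs, most frequent first
counts = [("the", 100), ("cat", 10), ("sat", 1)]
table = build_unigram_table(counts)

rng = random.Random(0)  # seeded only to make the example repeatable
negatives = [rng.choice(table) for _ in range(5)]
print(negatives)
```

Raising counts to the 0.75 power flattens the distribution: the most frequent word still dominates the table, but it takes up a smaller share than its raw frequency would give it, so rarer words are sampled a bit more often.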

## Data
Two datasets have been extracted for you: one from AP news data and one pulled from PubMed. We'll train two sets of word embeddings and compare them at the end. There is also a test corpus that you can use for debugging and getting the model to run. Extra points ($\leq 0$) if you can [guess](www.google.com) where the corpus comes from.

## Installs
tqdm is a nice wrapper for loops to check your progress as you go

conda install -c conda-forge tqdm

ipywidgets makes tqdm look pretty

conda install -c conda-forge ipywidgets

Tokenization and NLP toolkit

conda install -c anaconda nltk 


## Janitorial Work
All of the data cleaning is handled for you, but please familiarize yourself with the objects created by extractVocabMappers, as you'll be using them in the code.
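
As a quick preview of what those objects look like, here is a toy sketch that builds a word-to-index mapping and its inverse from a tiny tokenized corpus. The corpus and the variable names `word_to_idx` and `idx_to_word` are invented for this example; the real mappers are produced by extractVocabMappers.

```python
from collections import Counter

# hypothetical tokenized corpus: a list of documents, each a list of tokens
docs = [["gut", "microbiome", "study"], ["microbiome", "study"], ["study"]]

# (word, count) tuples, most frequent first
counts = Counter(tok for doc in docs for tok in doc).most_common()
vocab = [word for word, count in counts]

# index 0 goes to the most frequent word
word_to_idx = {word: idx for idx, word in enumerate(vocab)}
idx_to_word = {idx: word for word, idx in word_to_idx.items()}

print(counts)
print(word_to_idx)
```

Because the vocabulary is ordered by frequency, low indices correspond to common words, and the two dictionaries are exact inverses of each other.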

In [47]:
testCorpus = ["First of all, quit grinnin’ like an idiot. Indians ain’t supposed to smile like that. Get stoic.",
              "No. Like this. You gotta look mean, or people won’t respect you.",
              " people will run all over you if you don’t look mean.",
              "You gotta look like a warrior. You gotta look like you just came back from killing a buffalo.",
              "But our tribe never hunted buffalo. We were fishermen.",
              "What? You wanna look like you just came back from catching a fish?",
              "This ain’t dances with salmon, you know. Thomas, you gotta look like a warrior."]

# NOTE: reduce this number if you can't get things to run quickly.
maxDocs = 2000

In [48]:
import nltk
from nltk.corpus import stopwords
nltk.download('stopwords')

[nltk_data] Downloading package stopwords to /home/ob2285/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


True

In [49]:
# Read the PubMed corpus in from its text files (one abstract per file)

import glob
pubMedDataFolderPath = "data/pubMed_corpus/"
pubMedDataFiles = glob.glob(pubMedDataFolderPath + "*.txt")
pubMedCorpus = [""]*len(pubMedDataFiles)
for idx, pubMedDataPath in enumerate(pubMedDataFiles):
    with open(pubMedDataPath, "r") as pubMedFile:
        text = pubMedFile.read().strip()
        pubMedCorpus[idx] = text
pubMedCorpus = pubMedCorpus[0:maxDocs]
print("{} pub med abstracts".format(len(pubMedCorpus)))

1767 pub med abstracts


In [50]:
# Read in the ap corpus
apTextFile = "data/ap.txt"
apCorpus = []
readText = False
with open(apTextFile) as apDataFile:
    for line in apDataFile:
        if readText:
            apCorpus.append(line.strip())
            readText = False
        if line == "<TEXT>\n":
            readText = True
apCorpus = apCorpus[0:maxDocs]
print("{} ap articles".format(len(apCorpus)))

2000 ap articles


In [79]:
import string
import nltk
from nltk.tokenize import word_tokenize 
nltk.download('stopwords')
nltk.download('punkt')
from nltk.corpus import stopwords
import re
def removePunctuation(myStr):
    excludedCharacters = string.punctuation + "’"  # string.punctuation already covers "%"
    newStr = "".join(char for char in myStr if char not in excludedCharacters)
    return(newStr)
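def removeStopWords(tokenList):
    # build the stop word set once instead of reloading the list for every token
    myStopWords = set(stopwords.words('english'))
    newTokenList = [tok for tok in tokenList if tok not in myStopWords]
    return(newTokenList)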
def cleanDocStr(docStr):
    docStr = docStr.lower()
    docStr = removePunctuation(docStr)
    docStr = re.sub(r'\d', '%d%', docStr)  # replace every digit with the placeholder %d%
    docStrTokenized = nltk.tokenize.word_tokenize(docStr)
    myStopWords = set(stopwords.words('english'))
    docStrTokenized = [tok for tok in docStrTokenized if tok not in myStopWords]
    return(docStrTokenized)
def tokenize_corpus(corpus):
    tokens = [cleanDocStr(x) for x in corpus]
    return tokens

apCorpusTokenized = tokenize_corpus(apCorpus)
pubMedCorpusTokenized = tokenize_corpus(pubMedCorpus)
testCorpusTokenized = tokenize_corpus(testCorpus)

[nltk_data] Downloading package stopwords to /home/ob2285/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!
[nltk_data] Downloading package punkt to /home/ob2285/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [80]:
import time
from tqdm import tqdm, tqdm_notebook
from collections import Counter

minVocabOccurence = 5

def extractVocabMappers(tokenizedCorpus, minVocabOccurence = 0):
    """
    Description: Builds the vocabulary and the word <-> integer ID mappings for a
        tokenized corpus, and removes out-of-vocabulary tokens from the corpus.
    Input:
        tokenizedCorpus (list(list(str))): A list where each index is a document from the corpus.
            Each document is further tokenized into a list of tokens. 
            [doc1, doc2,...] where doc1 = [tok1, tok2, ...]
        minVocabOccurence (int): A word must show up more than this many times to be
            included in the vocabulary
    Output:
        word2Idx (dict): A dictionary mapping each word to its integer ID
        idx2Word (dict): A dictionary mapping each integer ID to its word
        wordCounts (list(tuples)): A list of (word, count) tuples for each vocabulary word,
            sorted from most to least frequent
        newTokenizedCorpus (list(list(str))): Same as tokenizedCorpus but with out-of-vocabulary
            terms removed
        
    """
    flattenedCorpus = [item for sublist in tokenizedCorpus for item in sublist]
    wordCounts = Counter(flattenedCorpus).most_common()
    # keep only the words that occur more than minVocabOccurence times
    wordCounts = [(w, c) for w, c in wordCounts if c > minVocabOccurence]
    vocabulary = [word for word, count in wordCounts]
    print("Vocab size: {}".format(len(vocabulary)))
    word2idx = {w: idx for (idx, w) in enumerate(vocabulary)}
    idx2word = {idx: w for (idx, w) in enumerate(vocabulary)}
    # drop out-of-vocabulary words from every document (rather than mapping them to <UNK>)
    newTokenizedCorpus = []
    for doc in tokenizedCorpus:
        newDoc = [word for word in doc if word in word2idx]
        newTokenizedCorpus.append(newDoc)
    return(word2idx, idx2word, wordCounts, newTokenizedCorpus)

start = time.time()
print("Building ap corpus vocabulary")
word2Idx_ap, idx2Word_ap, vocabCount_ap, finalTokenizedCorpus_ap = extractVocabMappers(apCorpusTokenized,
                                                                                      minVocabOccurence = minVocabOccurence)
print("ap data tokenized in {} seconds\n".format(time.time() - start))
start = time.time()
print("Building pubMed corpus vocabulary")
word2Idx_pubMed, idx2Word_pubMed, vocabCount_pubMed, finalTokenizedCorpus_pubMed = extractVocabMappers(pubMedCorpusTokenized,
                                                                                                      minVocabOccurence = minVocabOccurence)
print("pubmed data tokenized in {} seconds\n".format(time.time() - start))
start = time.time()
print("Building test corpus vocabulary")
word2Idx_test, idx2Word_test, vocabCount_test, finalTokenizedCorpus_test = extractVocabMappers(testCorpusTokenized,
                                                                                              minVocabOccurence = 0)
print("test data tokenized in {} seconds".format(time.time() - start))

Building ap corpus vocabulary
Vocab size: 9932
ap data tokenized in 0.23006987571716309 seconds

Building pubMed corpus vocabulary
Vocab size: 4993
pubmed data tokenized in 0.1014561653137207 seconds

Building test corpus vocabulary
Vocab size: 37
test data tokenized in 0.00024819374084472656 seconds


## Word2Vec Implementation

In [81]:
import numpy as np
import torch
from torch import nn
import random

In [89]:
def generateObservations(tokenizedCorpus, word2Idx):
    """
    Description: Iterates through every token in the corpus and creates a (center, context)
        pair for each context word in the window on either side of the center word. Please
        refer to the first figure to understand how word pairs are created.
    Input:
        tokenizedCorpus (list(list(str))): A list where each index is a document from the corpus.
            Each document is further tokenized into a list of tokens. 
            [doc1, doc2,...] where doc1 = [tok1, tok2, ...]
        word2Idx (dict): A dictionary mapping words to their integer IDs
    Output:
        idxPairs (numpy array): An array of (center word, context word) pairs, one row per pair
    """
    window_size = 3
    idxPairs = []
    for sentence in tokenizedCorpus:
        for center_word_pos in range(len(sentence)):
            # Your code here
            # for each window position
            for w in range(-window_size, window_size + 1):
                context_word_pos = center_word_pos + w
                # skip positions outside the sentence, and the center word itself
                if context_word_pos < 0 or context_word_pos >= len(sentence) or center_word_pos == context_word_pos:
                    continue
                idxPairs.append((sentence[center_word_pos], sentence[context_word_pos]))
            # End your code
    idxPairs = np.array(idxPairs)
    return(idxPairs)


def generateWordSamplingUnigramTable(vocabCount, word2Idx):
    """
    Description: Generates a unigram table to sample data from. The unigram table
        should contain every vocab index multiple times. The number
        of times an index appears is dictated by its sample probability. The unigram
        table can then be sampled.
    Input:
        vocabCount (list(tuples)): A list of tuples mapping each vocab to its count in the
            corpus
        word2Idx (dict): A dictionary mapping words to their integer IDs
    Output:
        unigram_table (list(int)): A list of integers as described above. For example,
        in a 3 word vocabulary it might look something like [0,0,1,1,1,1,1,1,1,2].
        Sampling from the previous example will mean that 0 is sampled 2/10 times,
        1 is sampled 7/10 times, and 2 is sampled 1/10 times.
    """
    unigram_table = []
#     numWords = np.sum([count for word, count in vocabCount])
    numWords = np.sum([count**0.75 for word, count in vocabCount])
    tableLength = 10000
    for w,c in vocabCount:
        unigram_table.extend([word2Idx[w]] * int((((c**0.75)/numWords))*tableLength))
#         unigram_table.extend([word2Idx[w]] * int(((c/numWords)**0.75)/0.001))
    return(unigram_table)
    
class SkipGram(nn.Module):
    """
    Description: Instantiates and implements the forward pass of the skip-gram
        algorithm with negative sampling.
    Input:
        vocabSize (int): Number of words to create embeddings for
        embedSize (int): Dimension of word embeddings
        vocabCount (list(tuples)): A list of tuples mapping each vocab to its count in the
            corpus
        word2Idx (dict): A dictionary mapping words to their integer IDs
    Output:
    """
    def __init__(self, vocabSize, embedSize, vocabCount, word2Idx):
        super(SkipGram, self).__init__()
        self.vocabSize = vocabSize
        self.word2Idx = word2Idx
        # Your code here
        # Init the center and context embedding matrices. These are learnable parameters
        self.centerEmbeddings = nn.Parameter(torch.randn(vocabSize,
                                                     embedSize).float(), requires_grad=True)
        self.contextEmbeddings = nn.Parameter(torch.randn(vocabSize,
                                                      embedSize).float(), requires_grad=True)
        # End your code
        nn.init.xavier_uniform_(self.contextEmbeddings)
        nn.init.xavier_uniform_(self.centerEmbeddings)
        
        self.unigram_table = generateWordSamplingUnigramTable(vocabCount, word2Idx)
        self.logSigmoid = nn.LogSigmoid()
    def getNegSample(self, k, centerWords):
        """
        Description: Randomly selects negative samples from the vocabulary. Uses
            self.unigram_table in order to sample words. 
        Input:
            k (int): Number of negative samples to select
            centerWords (list(str)): A list of the string center words. There should
                be batchSize of these.
        Output:
            negSamples (list(list(int))): A list of lists where each inner list
                contains the indices of k negative samples. There are batchSize inner lists
        """
        negSamples = []
        for centerWord in centerWords:
            # Your code here
            # Using self.unigram_table, sample indices to use as your negative samples.
            # Be sure that for each center word you return negative samples that
            # don't contain the center word. Shouldn't happen often, but just to be sure.
            negSample = random.sample(self.unigram_table, k)
            while self.word2Idx[centerWord] in negSample:
                negSample = random.sample(self.unigram_table, k)
            negSamples.append(negSample)
        # End your code
        return(negSamples)
    def forward(self, center, context, negSampleIndices):
        """
        Description: Forward pass for the skip-gram model. 
        Input:
            center (list(int)): A list of word integer IDs indicating all
                batchSize center words. Matches one to one with context
            context (list(int)): A list of word integer IDs indicating all
                batchSize context words. Matches one to one with center
            negSampleIndices (list(list(int))): A list of lists where
                each inner list contains the indices of negative samples.
                There are batchSize inner lists. Returned by getNegSample()
        Output:
            negLogProb (tensor): The negative log probability (loss) averaged over the batch.
        """
        # Your Code
        # implement a forward pass of the model. Be sure to allow for varying batch sizes
        embedCenter = self.centerEmbeddings[center]       # (batch, embed)
        embedContext = self.contextEmbeddings[context]    # (batch, embed)
        # score of each true (center, context) pair
        posVal = self.logSigmoid(torch.sum(embedContext * embedCenter, dim = 1))  # (batch,)
        negSampleIndices = torch.LongTensor(negSampleIndices)  # (batch, k)
        # squeeze only the last dimension so a batch of size 1 keeps its shape
        negVal = torch.bmm(self.contextEmbeddings[negSampleIndices], embedCenter.unsqueeze(2)).squeeze(2)  # (batch, k)
        negVal = torch.sum(self.logSigmoid(-negVal), dim = 1)
        negLogProb = -(posVal + negVal).mean()
        # End your code
        return(negLogProb)


def train_skipgram(embeddingSize, trainingData, vocabCount, word2Idx, idx2Word,
                   k, referenceWords, batchSize = 1024):
    """
    Description: Instantiates and trains a skip-gram model. The SkipGram model
        handles the forward pass, so all you have to do here is handle the loss and
        update the weights.
    Input:
        embeddingSize (int): Size of each word embedding
        trainingData (list(tuples)): A list of tuples generated by generateObservations()
            where each tuple is a center and context word
        vocabCount (list(tuples)): A list of tuples mapping each vocab to its count in the
            corpus
        word2Idx (dict): A dictionary mapping each word to its integer ID
        idx2Word (dict): A dictionary mapping each integer ID to its word
        k (int): Dictates the number of samples used during negative sampling
        referenceWords (list(str)): A list of words to compare word embeddings for
        batchSize (int): The number of (center, context) words to run through each forward
            pass of the skip-gram model.
    Output:
        model (SkipGram): The final trained SkipGram model
    """
    print("training on {} observations".format(len(trainingData)))
    model = SkipGram(vocabSize = len(word2Idx), embedSize = embeddingSize,
                     vocabCount = vocabCount, word2Idx = word2Idx)
    # NOTE: learning_rate and n_epoch are read from the notebook's global scope
    optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
    # nearest words before any training, as a baseline
    listNearestWords(model = model, idx2Word = idx2Word,
                     referenceWords = referenceWords, topN = 5)
    for epoch in tqdm_notebook(range(n_epoch), position = 0):
        total_loss = 0.0
        avgLoss = 0.0
        iteration = 0
        for step in tqdm_notebook(range(0, len(trainingData), batchSize), position = 1):
            # slicing past the end of the data is safe, so the last batch may be smaller
            myBatch = trainingData[step:(step+batchSize)]
            centerWords = [elem[0] for elem in myBatch]
            contextWords = [elem[1] for elem in myBatch]
            negSamples = model.getNegSample(k = k, centerWords = centerWords)
            centerIDs = [word2Idx[word] for word in centerWords]
            contextIDs = [word2Idx[word] for word in contextWords]
            model.zero_grad()
            loss = model(centerIDs, contextIDs, negSampleIndices = negSamples)
            
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            avgLoss += loss.item()
            iteration += 1
            if iteration % 500 == 0:
                print("avg loss: {}".format(avgLoss/500))
                avgLoss = 0.0  # reset so each report covers only the last 500 batches
        print("Loss at epoch {}: {}".format(epoch, total_loss/iteration))
        listNearestWords(model = model, idx2Word = idx2Word,
                         referenceWords = referenceWords, topN = 5)
    return(model)

In [90]:
from scipy.spatial.distance import cdist
def listNearestWords(model, idx2Word, referenceWords, topN):
    """
    Description: Lists the topN closest words by cosine distance to each word in referenceWords.
        The reference word itself is included in the list, at a distance of (approximately) 0.
    Input:
        model (SkipGram): The final trained SkipGram model
        idx2Word (dict): A dictionary mapping each integer ID to its word
        referenceWords (list(str)): A list of words in the vocabulary of the model
        topN (int): The number of closest words to print
    Output:
        None: Just prints
    """
    assert len(idx2Word) == len(model.word2Idx), "Possibly passed in two different vocabularies"
    embeddings = model.centerEmbeddings.data.numpy()
    distMat = cdist(embeddings, embeddings, metric = "cosine")
    # Your code here
    # print the topN closest words to each word in referenceWords
    for word in referenceWords:
        wordIdx = model.word2Idx[word]
        closestIndices = np.argsort(distMat[wordIdx,:])[0:topN]
        closestWords = [(idx2Word[idx], distMat[wordIdx, idx]) for idx in closestIndices]
        for elem in closestWords:
            print(elem)
        print("*"*50 + "\n")
    # End your code

In [91]:
# embd_size = 100
# learning_rate = 0.001
# n_epoch = 60
# idxPairsTest = generateObservations(tokenizedCorpus = finalTokenizedCorpus_test, word2Idx = word2Idx_test)
# sg_model = train_skipgram(embeddingSize = 5, trainingData = idxPairsTest, vocabCount = vocabCount_test,
#                                      word2Idx = word2Idx_test, idx2Word = idx2Word_test, k = 10,
#                                     referenceWords = ["thomas", "salmon"])

In [None]:
embeddingSize = 50
learning_rate = 0.1
n_epoch = 10
idxPairsAP = generateObservations(tokenizedCorpus = finalTokenizedCorpus_ap, word2Idx = word2Idx_ap)
sg_model_ap = train_skipgram(embeddingSize = embeddingSize, trainingData = idxPairsAP,
                                     vocabCount = vocabCount_ap,
                                     word2Idx = word2Idx_ap, idx2Word = idx2Word_ap, k = 20,
                                          referenceWords = ["bush", "soviet", "president", "economy", "american"])

training on 3222584 observations
('bush', 2.220446049250313e-16)
('sec', 0.5033803326733285)
('rioting', 0.5067662979566089)
('intimidation', 0.5135166104535183)
('magazines', 0.5163957768156204)
**************************************************

('soviet', 0.0)
('crocodile', 0.4442133180955036)
('vargas', 0.5114584888316253)
('governing', 0.5308375738131564)
('associate', 0.5429682997449355)
**************************************************

('president', 0.0)
('conventions', 0.4247066995169423)
('clayton', 0.5261356510370664)
('asked', 0.5466800235361514)
('purchases', 0.5558109662720915)
**************************************************

('economy', 1.1102230246251565e-16)
('hubble', 0.4727646238501564)
('onefifth', 0.5193020374628007)
('mountains', 0.5299678951447082)
('elaine', 0.5440051248585133)
**************************************************

('american', 0.0)
('volvo', 0.47128798438542185)
('grigoryants', 0.49699213337556314)
('broadway', 0.5160413740406908)
('accepting'


avg loss: 14.441185665130615
avg loss: 13.031047929191589
avg loss: 11.637958070017623
avg loss: 11.421434345285848
avg loss: 11.209687608127675
avg loss: 10.79255691031147
Loss at epoch 0: 12.008686547130914
('bush', 0.0)
('police', 0.008125627370117905)
('time', 0.008860428975394852)
('officials', 0.010291807219225846)
('president', 0.01038304812295332)
**************************************************

('soviet', 0.0)
('united', 0.011812184060174702)
('today', 0.012273612522199495)
('time', 0.012456200145919416)
('state', 0.012552599276933418)
**************************************************

('president', 0.0)
('also', 0.0055103221614903886)
('first', 0.005928066864382631)
('two', 0.006124949142827352)
('would', 0.0062991929313442885)
**************************************************

('economy', 0.0)
('company', 0.054801306915847836)
('west', 0.057515728045184344)
('aid', 0.061104225778775945)
('army', 0.06132582333665881)
**************************************************

('


avg loss: 10.45420519733429
avg loss: 10.004259208047868
avg loss: 9.425600674349477
avg loss: 9.25088370614568
avg loss: 8.978148099199402
avg loss: 8.5882323477197
Loss at epoch 1: 9.391745881439013
('bush', 0.0)
('officials', 0.0021495637996122863)
('could', 0.002362294834993728)
('time', 0.002532638214236327)
('made', 0.002544349356465636)
**************************************************

('soviet', 2.220446049250313e-16)
('today', 0.0022685270593599816)
('united', 0.002293674430478232)
('state', 0.0022997045000529015)
('three', 0.0025253103627129736)
**************************************************

('president', 1.1102230246251565e-16)
('government', 0.001670245720551078)
('two', 0.0018294762866963854)
('time', 0.002098749152494528)
('also', 0.0021345506478709364)
**************************************************

('economy', 0.0)
('aid', 0.009202815134043085)
('good', 0.009603304259339973)
('companies', 0.009944138744628406)
('got', 0.009945848162657978)
*******************


avg loss: 8.23786749458313
avg loss: 7.830213976661682
avg loss: 7.340277103568635
avg loss: 7.1898030939658195
avg loss: 6.977045694053761
avg loss: 6.705270832616211
Loss at epoch 2: 7.3350870226814
('bush', 0.0)
('officials', 0.0015741275738776483)
('monday', 0.0016392242081517683)
('made', 0.0016502463005277912)
('could', 0.001707056668013962)
**************************************************

('soviet', 0.0)
('today', 0.0015219491648776895)
('state', 0.0015623499624791704)
('american', 0.0015839152230978648)
('united', 0.001610364951571408)
**************************************************

('president', 0.0)
('also', 0.0012022502927377943)
('would', 0.0012220096954603799)
('first', 0.0012225480595149385)
('government', 0.001322685314541494)
**************************************************

('economy', 0.0)
('good', 0.004368336274991225)
('aid', 0.004448516692103666)
('companies', 0.0044603338385375)
('economic', 0.004487299531662159)
******************************************


In [43]:
embeddingSize = 50
learning_rate = 0.1
n_epoch = 10
idxPairsPubMed = generateObservations(tokenizedCorpus = finalTokenizedCorpus_pubMed, word2Idx = word2Idx_pubMed)
sg_model_pubMed = train_skipgram(embeddingSize = embeddingSize, trainingData = idxPairsPubMed,
                                     vocabCount = vocabCount_pubMed,
                                     word2Idx = word2Idx_pubMed, idx2Word = idx2Word_pubMed, k = 20,
                                                  referenceWords = ["clinical", "obesity", "microbial", "microbiome"])

training on 1732584 observations
('clinical', 0.0)
('emergence', 0.4616474509367934)
('excessive', 0.46306756316983244)
('performs', 0.46741533353563347)
('increased', 0.5455670205429406)
**************************************************

('obesity', 0.0)
('frequency', 0.5168982691254621)
('symbols', 0.5511356324926966)
('exclusively', 0.5717091317033198)
('decline', 0.5765558141966722)
**************************************************

('microbial', 0.0)
('lcfa', 0.49922018838200666)
('proportional', 0.5456135834035312)
('appeared', 0.5615560018197324)
('ldlcholesterol', 0.570285427428456)
**************************************************

('microbiome', 0.0)
('rational', 0.5546636618615921)
('resolved', 0.5739339776798444)
('dose', 0.5846940592534022)
('prove', 0.6016624000160318)
**************************************************




avg loss: 1.260457257628441
avg loss: 1.0720387010776997
avg loss: 0.9904536416945863
Loss at epoch 0: 1.0871013186995302
('clinical', 1.1102230246251565e-16)
('using', 0.07755337021509479)
('study', 0.08444943202264732)
('human', 0.09007409896682372)
('system', 0.09154468442436803)
**************************************************

('obesity', 0.0)
('system', 0.14218763142405288)
('health', 0.14616193104902797)
('microbiome', 0.14672249733496334)
('sedentary', 0.14833099824442253)
**************************************************

('microbial', 0.0)
('clinical', 0.1101586736337904)
('human', 0.1175735292982708)
('using', 0.11797479406554334)
('exercise', 0.12135022976846255)
**************************************************

('microbiome', 0.0)
('system', 0.12527159435144575)
('activity', 0.13453813531724257)
('high', 0.1458028603156406)
('health', 0.14586079997581736)
**************************************************




avg loss: 0.892559583067894
avg loss: 0.85271362008667
avg loss: 0.821068516920845
Loss at epoch 1: 0.8482917993820447
('clinical', 0.0)
('system', 0.06847316284926142)
('human', 0.06976770035775548)
('information', 0.07009410668572436)
('associated', 0.07198495005043215)
**************************************************

('obesity', 0.0)
('microbiome', 0.0982244949528086)
('used', 0.10183675015158833)
('system', 0.10680113902933586)
('text', 0.10735598205269115)
**************************************************

('microbial', 1.1102230246251565e-16)
('human', 0.07980026690815079)
('exercise', 0.0826131681455543)
('clinical', 0.08276219488625292)
('natural', 0.08397793916199148)
**************************************************

('microbiome', 1.1102230246251565e-16)
('system', 0.09507151039885287)
('disease', 0.09712612695593648)
('microbial', 0.09765670397357851)
('high', 0.09791788219955255)
**************************************************




avg loss: 0.7838031015396119
avg loss: 0.7693898970408439
avg loss: 0.7557271985400762
Loss at epoch 2: 0.7662153598229372
('clinical', 1.1102230246251565e-16)
('system', 0.06844509062176896)
('associated', 0.07236351267351615)
('health', 0.07606088212339357)
('information', 0.07745466775289767)
**************************************************

('obesity', 0.0)
('microbiome', 0.09756338024805988)
('used', 0.09936838508744006)
('text', 0.10099147992038371)
('based', 0.10550319876636749)
**************************************************

('microbial', 0.0)
('human', 0.07363076776851629)
('natural', 0.07862238015672496)
('used', 0.0789489403363125)
('two', 0.08234572868066892)
**************************************************

('microbiome', 1.1102230246251565e-16)
('disease', 0.08450013477062279)
('microbial', 0.08839752543655877)
('development', 0.08859445755644135)
('microbiota', 0.08906992472186959)
**************************************************




avg loss: 0.7375486475229264
avg loss: 0.7299022284710408
avg loss: 0.7218356183285641
Loss at epoch 3: 0.7275256528375983
('clinical', 0.0)
('system', 0.07018826538244649)
('associated', 0.07378492245924073)
('health', 0.08366193985961445)
('analysis', 0.0865225681720686)
**************************************************

('obesity', 1.1102230246251565e-16)
('used', 0.1021481474955227)
('based', 0.10388180175650408)
('also', 0.10701152278681514)
('microbiome', 0.11221727767282574)
**************************************************

('microbial', 0.0)
('human', 0.07662806166846314)
('natural', 0.08047610033515262)
('two', 0.08551829923211196)
('used', 0.08578694963507338)
**************************************************

('microbiome', 1.1102230246251565e-16)
('disease', 0.0815838929906213)
('development', 0.08195025629710162)
('research', 0.08649439669650205)
('related', 0.08685067256016099)
**************************************************




avg loss: 0.7108010559082031
avg loss: 0.7055509817419051
avg loss: 0.7005188082420464
Loss at epoch 4: 0.7039014068707365
('clinical', 0.0)
('system', 0.06745904250023704)
('associated', 0.07025648390607375)
('analysis', 0.08226133722463447)
('time', 0.08245953924195815)
**************************************************

('obesity', 0.0)
('used', 0.1013200015084531)
('based', 0.10362699358812044)
('also', 0.11060965006688539)
('found', 0.11193813520338913)
**************************************************

('microbial', 0.0)
('human', 0.08376819721334572)
('natural', 0.08729252879233651)
('model', 0.08781971739397176)
('metabolic', 0.09014354013724635)
**************************************************

('microbiome', 0.0)
('development', 0.08002391893299288)
('disease', 0.08345069300001529)
('related', 0.08645054125928053)
('research', 0.08947236005830606)
**************************************************




avg loss: 0.6926482183337211
avg loss: 0.688471260609746
avg loss: 0.6849923662005712
Loss at epoch 5: 0.6873036064559828
('clinical', 0.0)
('system', 0.062311423236911634)
('associated', 0.06488713832293247)
('using', 0.07182015301050826)
('time', 0.07208046861377448)
**************************************************

('obesity', 0.0)
('used', 0.09516055998326056)
('based', 0.10006083556298373)
('found', 0.10346164038542616)
('exercise', 0.10548265446950966)
**************************************************

('microbial', 1.1102230246251565e-16)
('model', 0.09152331883947495)
('human', 0.09358792942244654)
('measured', 0.0960778711593373)
('natural', 0.09622363945954893)
**************************************************

('microbiome', 0.0)
('development', 0.08140992080265275)
('related', 0.0874091950024588)
('disease', 0.08915677354281859)
('research', 0.0961601641817117)
**************************************************




avg loss: 0.6790315974354744
avg loss: 0.6753104136623145
avg loss: 0.672961702725537
Loss at epoch 6: 0.6745492295839826
('clinical', 1.1102230246251565e-16)
('system', 0.0571856132652615)
('associated', 0.05959568745461907)
('using', 0.062366061988063715)
('levels', 0.06527402395343351)
**************************************************

('obesity', 0.0)
('used', 0.08712061193944176)
('found', 0.09284691639509868)
('exercise', 0.09290605784775963)
('based', 0.09489279061182909)
**************************************************

('microbial', 1.1102230246251565e-16)
('model', 0.09396505280806389)
('measured', 0.09520280023611216)
('developed', 0.09938703566738238)
('studies', 0.10137429432297307)
**************************************************

('microbiome', 0.0)
('development', 0.08510365041363255)
('related', 0.0903251434960679)
('disease', 0.09646336769227415)
('bacteria', 0.10068631533988337)
**************************************************




avg loss: 0.6682857267856598
avg loss: 0.6651215077624321
avg loss: 0.6631840304973097
Loss at epoch 7: 0.6644701669014655
('clinical', 2.220446049250313e-16)
('system', 0.05371995664686191)
('associated', 0.05579978593153734)
('using', 0.056988948932359174)
('levels', 0.057803781265404974)
**************************************************

('obesity', 0.0)
('used', 0.07931993848121144)
('exercise', 0.08201938754858451)
('found', 0.08226940667147742)
('system', 0.0844880355771076)
**************************************************

('microbial', 0.0)
('measured', 0.09461480926148269)
('model', 0.09529304689836471)
('developed', 0.1010002425872043)
('treatment', 0.10312758930402444)
**************************************************

('microbiome', 0.0)
('development', 0.09017060316051506)
('related', 0.09480050724591771)
('disease', 0.10564313538052916)
('bacteria', 0.10624001769886793)
**************************************************




avg loss: 0.6595424230694771
avg loss: 0.6563488040200471
avg loss: 0.6551584362574083
Loss at epoch 8: 0.6560354224685799
('clinical', 0.0)
('system', 0.05181363945835327)
('levels', 0.05270484171458789)
('associated', 0.053974403128247195)
('using', 0.054550370548890115)
**************************************************

('obesity', 0.0)
('used', 0.07281812440294977)
('exercise', 0.07363538893497801)
('found', 0.07397059120665406)
('system', 0.07611192861563765)
**************************************************

('microbial', 0.0)
('measured', 0.09369128200595767)
('model', 0.09539098151502601)
('developed', 0.1017448349842881)
('treatment', 0.1019804228484309)
**************************************************

('microbiome', 0.0)
('development', 0.09461887603369346)
('related', 0.09825283972695331)
('bacteria', 0.1107673578971895)
('presence', 0.1130225863509925)
**************************************************




avg loss: 0.6519760429859162
avg loss: 0.6494013966798782
avg loss: 0.6483636779788494
Loss at epoch 9: 0.6490321379269409
('clinical', 0.0)
('levels', 0.04996075189232774)
('system', 0.051451933483990686)
('intervention', 0.05362811530038458)
('associated', 0.05376958487453165)
**************************************************

('obesity', 0.0)
('found', 0.06714204327470585)
('exercise', 0.06774134688817934)
('used', 0.06808356454641806)
('system', 0.07064630543517536)
**************************************************

('microbial', 0.0)
('measured', 0.09355283673321513)
('model', 0.09531950986369808)
('treatment', 0.10067138530362008)
('developed', 0.10220933174474156)
**************************************************

('microbiome', 0.0)
('development', 0.09752403594969239)
('related', 0.10110348095411026)
('presence', 0.1119909948011244)
('bacteria', 0.11463774876900945)
**************************************************



## How Domains Affect Word Embeddings
Choose two words that appear in both the PubMed and AP vocabularies, then compare each word's nearest neighbors in the PubMed embeddings with its nearest neighbors in the AP embeddings[.](https://www.youtube.com/watch?v=Tr-WrGcexlY) **Why might the two words you chose have different representations in the two domains? How might this affect downstream NLP tasks?**
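
One way to set up this comparison is sketched below. It is a minimal illustration, not the lab's reference solution: the helper `nearest_neighbors`, the toy vocabularies, and the random matrices standing in for trained embeddings are all hypothetical, and cosine distance is assumed as the similarity measure because the training output above lists each query word at a distance of roughly 0.

```python
import numpy as np


def nearest_neighbors(word, embeddings, vocab, k=5):
    """Return the k (word, cosine distance) pairs closest to `word`.

    embeddings: array of shape (vocab_size, embed_dim)
    vocab: list of words, where vocab[i] labels row i of `embeddings`
    The query word itself comes back first with a distance near 0.
    """
    idx = vocab.index(word)
    # Normalize each row so a dot product equals cosine similarity.
    norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
    unit = embeddings / np.clip(norms, 1e-12, None)
    dist = 1.0 - unit @ unit[idx]
    order = np.argsort(dist)[:k]
    return [(vocab[i], float(dist[i])) for i in order]


# Hypothetical stand-ins for the two trained models (random, for illustration only).
rng = np.random.default_rng(0)
pubmed_vocab = ["clinical", "system", "disease", "model", "health"]
ap_vocab = ["system", "model", "president", "market", "health"]
pubmed_emb = rng.normal(size=(len(pubmed_vocab), 8))
ap_emb = rng.normal(size=(len(ap_vocab), 8))

# Only words present in both vocabularies can be compared across domains.
shared = sorted(set(pubmed_vocab) & set(ap_vocab))
for word in shared[:2]:
    print(word, "pubmed:", nearest_neighbors(word, pubmed_emb, pubmed_vocab, k=3))
    print(word, "ap:    ", nearest_neighbors(word, ap_emb, ap_vocab, k=3))
```

With real embeddings, you would replace the random matrices with the center-word embedding weights from each trained model and the toy lists with each corpus's vocabulary, keeping rows and words aligned.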

In [None]:
# Your code here