# HW05: Word Embeddings

Remember that these homework work as a completion grade. **You can <span style="color:red">not</span> skip one section this homework.**

**Essay Feedback**

Please provide feedback to two classmates' essays on Eduflow.

**Training word2vec**

In this section, we train a word2vec model using gensim. We train the model on text8 (which consists of the first 90M characters of a Wikipedia dump from 2006 and is considered one of the benchmarks for evaluating language models).

In [90]:
import gensim.downloader as api

api.info("text8")

{'num_records': 1701,
 'record_format': 'list of str (tokens)',
 'file_size': 33182058,
 'reader_code': 'https://github.com/RaRe-Technologies/gensim-data/releases/download/text8/__init__.py',
 'license': 'not found',
 'description': 'First 100,000,000 bytes of plain text from Wikipedia. Used for testing purposes; see wiki-english-* for proper full Wikipedia datasets.',
 'checksum': '68799af40b6bda07dfa47a32612e5364',
 'file_name': 'text8.gz',
 'read_more': ['http://mattmahoney.net/dc/textdata.html'],
 'parts': 1}

In [91]:
dataset = api.load("text8")

In [92]:
#data = [d for d in dataset]
#data

In [93]:
from gensim.models import Word2Vec

##TODO train a word2vec model on this dataset which appear at least 10 times in the corpus
model = Word2Vec(dataset, min_count=10)

**Word Similarities**

gensim models provide almost all the utility you might want to wish for to perform standard word similarity tasks. They are available in the .wv (wordvectors) attribute of the model, more details could be found [here](https://radimrehurek.com/gensim/models/keyedvectors.html).

In [94]:
word_vectors = model.wv

##TODO find the closest words to king
word_vectors.most_similar(positive='king')

[('prince', 0.756159245967865),
 ('queen', 0.7254481315612793),
 ('kings', 0.6992501020431519),
 ('emperor', 0.6969693303108215),
 ('regent', 0.6830177307128906),
 ('vii', 0.6797141432762146),
 ('constantine', 0.6772857904434204),
 ('throne', 0.6630407571792603),
 ('pope', 0.6626625061035156),
 ('viii', 0.6589317917823792)]

In [95]:
##TODO find the closest word for the vector "woman" + "king" - "man"

word_vectors.most_similar(positive=['king', 'women'], negative='man')

[('bishops', 0.5292506217956543),
 ('nobles', 0.5209300518035889),
 ('monarchs', 0.5037956833839417),
 ('kings', 0.4996193051338196),
 ('scots', 0.49918463826179504),
 ('catholics', 0.4922649562358856),
 ('popes', 0.48591870069503784),
 ('judges', 0.48078009486198425),
 ('priests', 0.47934818267822266),
 ('clergy', 0.47883933782577515)]

King is to man as woman is to X

**Evaluate Word Similarities** 

One common way to evaluate word2vec models are word analogy tasks. Let's check how good our model is on one of those. We consider the [WordSim353](http://alfonseca.org/eng/research/wordsim353.html) benchmark, the task is to determine how similar two words are.

In [96]:
!wget http://alfonseca.org/pubs/ws353simrel.tar.gz
!tar xf ws353simrel.tar.gz

path = "wordsim353_sim_rel/wordsim_similarity_goldstandard.txt"

def load_data(path):
    X, y = [], []
    with open(path) as f:
        for line in f:
            line = line.strip().split("\t")
            X.append((line[0], line[1])) # each entry in x contains two words, e.g. X[0] = (tiger, cat)
            y.append(float(line[-1])) # each entry in y is the annotation how similar two words are, e.g. Y[0] = 7.35
    return X, y

X, y = load_data(path)
print (X[:3], y[:3])

--2023-03-30 17:02:13--  http://alfonseca.org/pubs/ws353simrel.tar.gz
Resolving alfonseca.org (alfonseca.org)... 162.215.249.67
Connecting to alfonseca.org (alfonseca.org)|162.215.249.67|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 5460 (5.3K) [application/x-gzip]
Saving to: 'ws353simrel.tar.gz.4'


2023-03-30 17:02:15 (13.5 KB/s) - 'ws353simrel.tar.gz.4' saved [5460/5460]

[('tiger', 'cat'), ('tiger', 'tiger'), ('plane', 'car')] [7.35, 10.0, 5.77]


In [97]:
##TODO compute how similar the pairs in the WordSim353 are according to our model
# if a word is not present in our model, we assign similarity 0 for the respective text pair

similarity_scores = []

for x in X:
    try:
        similarity_scores.append(word_vectors.similarity(x[0],x[1]))
    except:
        similarity_scores.append(0)
        print(x)
print(len(similarity_scores))

('Arafat', 'Jackson')
('asylum', 'madhouse')
('cup', 'tableware')
('Japanese', 'American')
('Harvard', 'Yale')
('Mexico', 'Brazil')
('Mars', 'water')
('Wednesday', 'news')
('stock', 'CD')
203


In [98]:
from scipy.stats import spearmanr

##TODO compute spearman's rank correlation between our prediction and the human annotations
spearmanr(y, similarity_scores)

SpearmanrResult(correlation=0.6535176250837023, pvalue=4.166416672262106e-26)

In [99]:
import warnings
warnings.filterwarnings("ignore")

import spacy
nlp = spacy.load('en_core_web_sm')
##TODO compute word similarities in the WordSim353 dataset using spaCy word embeddings

similarity_list_2 = []

for x in X:
    d = nlp(x[0])
    similarity_list_2.append(d.similarity(nlp(x[1])))

##TODO compute spearman's rank correlation between these similarities and the human annotations
# Don't worry if results are not too convincing for this experiment
spearmanr(similarity_list_2, similarity_scores)

SpearmanrResult(correlation=0.05420386995172911, pvalue=0.44243837493459515)

**PyTorch Embeddings**

In [100]:
#Import the AG news dataset (same as hw01)
#Download them from here 
!wget https://raw.githubusercontent.com/mhjabreel/CharCnn_Keras/master/data/ag_news_csv/train.csv

import pandas as pd
import nltk
df = pd.read_csv('train.csv')

df.columns = ["label", "title", "lead"]
label_map = {1:"world", 2:"sport", 3:"business", 4:"sci/tech"}
def replace_label(x):
	return label_map[x]
df["label"] = df["label"].apply(replace_label) 
df["text"] = df["title"] + " " + df["lead"]
df = df.sample(n=10000) # # only use 10K datapoints
df.head()

--2023-03-30 17:02:19--  https://raw.githubusercontent.com/mhjabreel/CharCnn_Keras/master/data/ag_news_csv/train.csv
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.109.133, 185.199.111.133, 185.199.108.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.109.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 29470338 (28M) [text/plain]
Saving to: 'train.csv'


2023-03-30 17:02:32 (2.22 MB/s) - 'train.csv' saved [29470338/29470338]



Unnamed: 0,label,title,lead,text
117622,business,"Tension, fear after mall shooting",CAMBRIDGE -- In a quiet entryway of the Cambri...,"Tension, fear after mall shooting CAMBRIDGE --..."
49114,sci/tech,CNN Blog,Check this web log throughout the day as CNN A...,CNN Blog Check this web log throughout the day...
43708,world,"Typhoon Meari Hits South Japan, Triggers Floods",TOKYO (Reuters) - A record eighth typhoon swe...,"Typhoon Meari Hits South Japan, Triggers Flood..."
112859,sport,Three teams in the running to be winter champi...,BERLIN: Three teams have the chance to lift th...,Three teams in the running to be winter champi...
116500,sci/tech,"Samsung, Sony in cross-license deal",Two of the world #39;s consumer electronics gi...,"Samsung, Sony in cross-license deal Two of the..."


In [101]:
vocab = 200
##TODO tokenize the text, only keep 200 most frequent words 
def tokenize(text, return_list = False):
    tokens = nlp(text)
    if return_list == True:
        l_1 = []
        for w in tokens:
            if not w.is_stop and not w.is_punct and not w.is_digit:
                l_1.append(w.lemma_.lower())
        return l_1
    else:
        l = ''
        for w in tokens:
            if not w.is_stop and not w.is_punct and not w.is_digit:
                l += w.lemma_.lower() + ' '
        return l 

#tokenize the data in token col

df['temp_tok'] = df['text'].apply(lambda x: tokenize(x))
df['tokens'] = df['text'].apply(lambda x: tokenize(x, return_list=True))

df.head(5)

Unnamed: 0,label,title,lead,text,temp_tok,tokens
117622,business,"Tension, fear after mall shooting",CAMBRIDGE -- In a quiet entryway of the Cambri...,"Tension, fear after mall shooting CAMBRIDGE --...",tension fear mall shoot cambridge quiet entryw...,"[tension, fear, mall, shoot, cambridge, quiet,..."
49114,sci/tech,CNN Blog,Check this web log throughout the day as CNN A...,CNN Blog Check this web log throughout the day...,cnn blog check web log day cnn anchor space co...,"[cnn, blog, check, web, log, day, cnn, anchor,..."
43708,world,"Typhoon Meari Hits South Japan, Triggers Floods",TOKYO (Reuters) - A record eighth typhoon swe...,"Typhoon Meari Hits South Japan, Triggers Flood...",typhoon meari hits south japan triggers floods...,"[typhoon, meari, hits, south, japan, triggers,..."
112859,sport,Three teams in the running to be winter champi...,BERLIN: Three teams have the chance to lift th...,Three teams in the running to be winter champi...,team running winter champion germany berlin te...,"[team, running, winter, champion, germany, ber..."
116500,sci/tech,"Samsung, Sony in cross-license deal",Two of the world #39;s consumer electronics gi...,"Samsung, Sony in cross-license deal Two of the...",samsung sony cross license deal world 39;s con...,"[samsung, sony, cross, license, deal, world, 3..."


In [102]:
from collections import Counter

n=10000

txt = ''
for i in range(n):
    #change to tokens
    txt += df['temp_tok'].values[i]

split_it = txt.split()
Counter = Counter(split_it)  
most_occur = Counter.most_common(vocab)
  
print(most_occur)

[('39;s', 2533), ('say', 2113), ('new', 1789), ('reuters', 1463), ('ap', 1309), ('year', 1253), ('company', 848), ('win', 745), ('world', 739), ('wednesday', 640), ('report', 622), ('monday', 618), ('u.s.', 616), ('tuesday', 611), ('thursday', 610), ('oil', 603), ('game', 602), ('week', 595), ('inc.', 566), ('friday', 550), ('high', 535), ('million', 533), ('york', 529), ('day', 518), ('price', 514), ('iraq', 506), ('kill', 505), ('yesterday', 501), ('plan', 496), ('lead', 493), ('time', 481), ('president', 475), ('end', 470), ('united', 450), ('microsoft', 442), ('group', 438), ('team', 436), ('sunday', 408), ('security', 408), ('government', 406), ('second', 405), ('official', 405), ('afp', 401), ('market', 394), ('percent', 393), ('announce', 390), ('open', 389), ('month', 389), ('stock', 389), ('rise', 388), ('today', 388), ('sale', 386), ('state', 380), ('quot', 377), ('service', 377), ('people', 376), ('profit', 376), ('corp.', 375), ('quarter', 367), ('run', 366), ('set', 363), 

In [103]:
import numpy as np
#create dict to acces keys only
dict_mo ={}

for k, v in most_occur:
    dict_mo.setdefault(k, [v])
    

def keep_top_200(l1, l2):
    l1 = np.array(l1)
    l2 = np.array(l2)
    mask = np.in1d(l1, l2)
    return l1[mask].tolist()

df['most_freq_tokens'] = df['tokens'].apply(lambda x: keep_top_200(x, list(dict_mo.keys())))

In [104]:
len(most_occur)
word2id = {w[0] : i + 1 for i, w in enumerate(most_occur)}
word2id

word2id['say']

def freq_index(l):
    l_temp = [0]*200
    for t in l:
        l_temp[word2id[t]-1] = l_temp[word2id[t]-1] +1
    return l_temp

df['freq_index'] = df['most_freq_tokens'].apply(lambda x: freq_index(x))


In [105]:
df.head()

Unnamed: 0,label,title,lead,text,temp_tok,tokens,most_freq_tokens,freq_index
117622,business,"Tension, fear after mall shooting",CAMBRIDGE -- In a quiet entryway of the Cambri...,"Tension, fear after mall shooting CAMBRIDGE --...",tension fear mall shoot cambridge quiet entryw...,"[tension, fear, mall, shoot, cambridge, quiet,...","[year, old, say]","[0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
49114,sci/tech,CNN Blog,Check this web log throughout the day as CNN A...,CNN Blog Check this web log throughout the day...,cnn blog check web log day cnn anchor space co...,"[cnn, blog, check, web, log, day, cnn, anchor,...","[web, day, space, 39;s, second]","[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
43708,world,"Typhoon Meari Hits South Japan, Triggers Floods",TOKYO (Reuters) - A record eighth typhoon swe...,"Typhoon Meari Hits South Japan, Triggers Flood...",typhoon meari hits south japan triggers floods...,"[typhoon, meari, hits, south, japan, triggers,...","[south, japan, reuters, record, japan, wednesd...","[0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, ..."
112859,sport,Three teams in the running to be winter champi...,BERLIN: Three teams have the chance to lift th...,Three teams in the running to be winter champi...,team running winter champion germany berlin te...,"[team, running, winter, champion, germany, ber...","[team, team, head, final, take, week]","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
116500,sci/tech,"Samsung, Sony in cross-license deal",Two of the world #39;s consumer electronics gi...,"Samsung, Sony in cross-license deal Two of the...",samsung sony cross license deal world 39;s con...,"[samsung, sony, cross, license, deal, world, 3...","[deal, world, 39;s, giant, sign, major, company]","[1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, ..."


In [106]:
#feels like an unnecessary step... but didnt get it to work without it
def to_string(l):
    st = ''
    for i in l:
        st += i + ' '
    st = st[:-1]
    return st
df['freq'] = df['most_freq_tokens'].apply(lambda x: to_string(x))
df.head()

Unnamed: 0,label,title,lead,text,temp_tok,tokens,most_freq_tokens,freq_index,freq
117622,business,"Tension, fear after mall shooting",CAMBRIDGE -- In a quiet entryway of the Cambri...,"Tension, fear after mall shooting CAMBRIDGE --...",tension fear mall shoot cambridge quiet entryw...,"[tension, fear, mall, shoot, cambridge, quiet,...","[year, old, say]","[0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",year old say
49114,sci/tech,CNN Blog,Check this web log throughout the day as CNN A...,CNN Blog Check this web log throughout the day...,cnn blog check web log day cnn anchor space co...,"[cnn, blog, check, web, log, day, cnn, anchor,...","[web, day, space, 39;s, second]","[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",web day space 39;s second
43708,world,"Typhoon Meari Hits South Japan, Triggers Floods",TOKYO (Reuters) - A record eighth typhoon swe...,"Typhoon Meari Hits South Japan, Triggers Flood...",typhoon meari hits south japan triggers floods...,"[typhoon, meari, hits, south, japan, triggers,...","[south, japan, reuters, record, japan, wednesd...","[0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, ...",south japan reuters record japan wednesday kil...
112859,sport,Three teams in the running to be winter champi...,BERLIN: Three teams have the chance to lift th...,Three teams in the running to be winter champi...,team running winter champion germany berlin te...,"[team, running, winter, champion, germany, ber...","[team, team, head, final, take, week]","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",team team head final take week
116500,sci/tech,"Samsung, Sony in cross-license deal",Two of the world #39;s consumer electronics gi...,"Samsung, Sony in cross-license deal Two of the...",samsung sony cross license deal world 39;s con...,"[samsung, sony, cross, license, deal, world, 3...","[deal, world, 39;s, giant, sign, major, company]","[1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, ...",deal world 39;s giant sign major company


In [107]:
#create data for encoder

X = np.array([st for st in df.freq]).reshape(-1,1)
X.shape

(10000, 1)

In [108]:
from sklearn.preprocessing import OneHotEncoder
length = 100
#TODO create a one_hot representation for each word and truncate/pad the sequences such that they are all of the 
#same length (here we use 100)

#encoding
onehot_encoder = OneHotEncoder(sparse=False)

onehot_encoded = onehot_encoder.fit_transform(X)
print(onehot_encoded.shape)

#truncate
print(onehot_encoded[:,:100].shape)

(10000, 9730)
(10000, 100)


In [109]:
##TODO create your torch embedding like we did in notebook 5! (hint: predicting labels: world, sport, business, 
#and sci/tech)

In [110]:
#create an index with the most common words
df.head()

Unnamed: 0,label,title,lead,text,temp_tok,tokens,most_freq_tokens,freq_index,freq
117622,business,"Tension, fear after mall shooting",CAMBRIDGE -- In a quiet entryway of the Cambri...,"Tension, fear after mall shooting CAMBRIDGE --...",tension fear mall shoot cambridge quiet entryw...,"[tension, fear, mall, shoot, cambridge, quiet,...","[year, old, say]","[0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",year old say
49114,sci/tech,CNN Blog,Check this web log throughout the day as CNN A...,CNN Blog Check this web log throughout the day...,cnn blog check web log day cnn anchor space co...,"[cnn, blog, check, web, log, day, cnn, anchor,...","[web, day, space, 39;s, second]","[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",web day space 39;s second
43708,world,"Typhoon Meari Hits South Japan, Triggers Floods",TOKYO (Reuters) - A record eighth typhoon swe...,"Typhoon Meari Hits South Japan, Triggers Flood...",typhoon meari hits south japan triggers floods...,"[typhoon, meari, hits, south, japan, triggers,...","[south, japan, reuters, record, japan, wednesd...","[0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, ...",south japan reuters record japan wednesday kil...
112859,sport,Three teams in the running to be winter champi...,BERLIN: Three teams have the chance to lift th...,Three teams in the running to be winter champi...,team running winter champion germany berlin te...,"[team, running, winter, champion, germany, ber...","[team, team, head, final, take, week]","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",team team head final take week
116500,sci/tech,"Samsung, Sony in cross-license deal",Two of the world #39;s consumer electronics gi...,"Samsung, Sony in cross-license deal Two of the...",samsung sony cross license deal world 39;s con...,"[samsung, sony, cross, license, deal, world, 3...","[deal, world, 39;s, giant, sign, major, company]","[1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, ...",deal world 39;s giant sign major company


In [111]:
#X data and y data
features = np.array([l for l in df.freq_index])

def label_trans(l):
    if l == 'world':
        return 0
    if l == 'sport':
        return 1
    if l == 'business':
        return 2
    if l == 'sci/tech':
        return 3
    
    
df['label_nr'] = df['label'].apply(lambda x: label_trans(x))

y = np.array([i for i in df.label_nr])

print(features.shape, y.shape)

df.head()

(10000, 200) (10000,)


Unnamed: 0,label,title,lead,text,temp_tok,tokens,most_freq_tokens,freq_index,freq,label_nr
117622,business,"Tension, fear after mall shooting",CAMBRIDGE -- In a quiet entryway of the Cambri...,"Tension, fear after mall shooting CAMBRIDGE --...",tension fear mall shoot cambridge quiet entryw...,"[tension, fear, mall, shoot, cambridge, quiet,...","[year, old, say]","[0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",year old say,2
49114,sci/tech,CNN Blog,Check this web log throughout the day as CNN A...,CNN Blog Check this web log throughout the day...,cnn blog check web log day cnn anchor space co...,"[cnn, blog, check, web, log, day, cnn, anchor,...","[web, day, space, 39;s, second]","[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",web day space 39;s second,3
43708,world,"Typhoon Meari Hits South Japan, Triggers Floods",TOKYO (Reuters) - A record eighth typhoon swe...,"Typhoon Meari Hits South Japan, Triggers Flood...",typhoon meari hits south japan triggers floods...,"[typhoon, meari, hits, south, japan, triggers,...","[south, japan, reuters, record, japan, wednesd...","[0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, ...",south japan reuters record japan wednesday kil...,0
112859,sport,Three teams in the running to be winter champi...,BERLIN: Three teams have the chance to lift th...,Three teams in the running to be winter champi...,team running winter champion germany berlin te...,"[team, running, winter, champion, germany, ber...","[team, team, head, final, take, week]","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",team team head final take week,1
116500,sci/tech,"Samsung, Sony in cross-license deal",Two of the world #39;s consumer electronics gi...,"Samsung, Sony in cross-license deal Two of the...",samsung sony cross license deal world 39;s con...,"[samsung, sony, cross, license, deal, world, 3...","[deal, world, 39;s, giant, sign, major, company]","[1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, ...",deal world 39;s giant sign major company,3


In [112]:
class GenericDataset(Dataset):

  def __init__(self, X, y):
    self.X = X
    self.y = y

  def __len__(self):
    return len(self.y)

  def __getitem__(self, index):
    return self.X[index], self.y[index]

dataset = GenericDataset(features, y)
for d in dataset:
    print(d)
    break

(array([0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0]), 2)


In [113]:
#fattar inte riktigt embedd...

# Model setup

import torch
import torch.nn as nn

class EmbeddingNet(nn.Module):
  def __init__(self, num_words=200):
    super(EmbeddingNet, self).__init__()
    self.embedding = nn.Embedding(num_words + 1, 2)
    self.flatten = nn.Flatten()
    self.fc1 = nn.Linear(2 * 200, 50)
    self.fc2 = nn.Linear(50, 1)
    self.softmax = nn.Softmax()

  def forward(self, x):
    x = self.embedding(x)
    x = self.flatten(x)
    x = self.fc1(x)
    x = self.fc2(x)
    x = self.softmax(x)
    return x

num_words = 200
model = EmbeddingNet(num_words)

print(model)

EmbeddingNet(
  (embedding): Embedding(201, 2)
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (fc1): Linear(in_features=400, out_features=50, bias=True)
  (fc2): Linear(in_features=50, out_features=1, bias=True)
  (softmax): Softmax(dim=None)
)


In [114]:
# show the vectors
from torch.utils.data import Dataset, DataLoader

num_words = 200


model = EmbeddingNet(num_words)
#loss func
criterion = nn.BCELoss()
#standard optimizer

#initialize data 
loader = torch.utils.data.DataLoader(dataset, batch_size=100, shuffle=True)

for i in range(10):
#weird step but letx keep it
  if i > 0:
    for data, label in loader:
      optimizer.zero_grad()
      outputs = model(data)
      loss = criterion(outputs, label.float().unsqueeze(1))
      loss.backward()
      optimizer.step()

#could predict on new data now

In [115]:
print(loss)

tensor(-44., grad_fn=<BinaryCrossEntropyBackward0>)
