## Skip_gram_with_negative_sampling

In [88]:
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.optim as optim
import torch.nn.functional as F
import nltk
import random
import numpy as np
from collections import Counter
flatten = lambda l: [item for sublist in l for item in sublist]

In [89]:
print(torch.__version__)
print(nltk.__version__)

1.0.0
3.3


In [90]:
FloatTensor =  torch.FloatTensor
LongTensor = torch.LongTensor
ByteTensor =  torch.ByteTensor

In [91]:
def getBatch(batch_size, train_data):
    random.shuffle(train_data)
    sindex = 0
    eindex = batch_size
    while eindex < len(train_data):
        batch = train_data[sindex: eindex]
        temp = eindex
        eindex = eindex + batch_size
        sindex = temp
        yield batch
    
    if eindex >= len(train_data):
        batch = train_data[sindex:]
        yield batch

In [92]:
def prepare_sequence(seq, word2index):
    idxs = list(map(lambda w: word2index[w] if word2index.get(w) is not None else word2index["<UNK>"], seq))
    return Variable(LongTensor(idxs))

def prepare_word(word, word2index):
    return Variable(LongTensor([word2index[word]]) if word2index.get(word) is not None else LongTensor([word2index["<UNK>"]]))

## Data load and Preprocessing

In [93]:
corpus = list(nltk.corpus.gutenberg.sents('melville-moby_dick.txt'))[:500]
corpus = [[word.lower() for word in sent] for sent in corpus]

In [94]:
word_count = Counter(flatten(corpus))

In [95]:
MIN_COUNT = 3
exclude = []

In [96]:
for w, c in word_count.items(): # 3번 아래로 나타나는 단어들은 제외한다. 
    if c < MIN_COUNT:
        exclude.append(w)

In [97]:
sparse_words = []               # 3번 아래로 등장하는 단어들
for w,c in word_count.items():
    sparse_words.append(w)
sparse_words[:10]

['[',
 'moby',
 'dick',
 'by',
 'herman',
 'melville',
 '1851',
 ']',
 'etymology',
 '.']

## Prepare train data

In [98]:
vocab = list(set(flatten(corpus)) - set(exclude))

In [99]:
word2index = {}
for vo in vocab:
    if word2index.get(vo) is None:
        word2index[vo] = len(word2index)
        
index2word = {v:k for k, v in word2index.items()}

In [100]:
WINDOW_SIZE = 5
windows =  flatten([list(nltk.ngrams(['<DUMMY>'] * WINDOW_SIZE + c + ['<DUMMY>'] * WINDOW_SIZE, WINDOW_SIZE * 2 + 1)) for c in corpus])

train_data = []

for window in windows:
    for i in range(WINDOW_SIZE * 2 + 1):
        if window[i] in exclude or window[WINDOW_SIZE] in exclude: 
            continue # min_count
        if i == WINDOW_SIZE or window[i] == '<DUMMY>': 
            continue
        train_data.append((window[WINDOW_SIZE], window[i]))

X_p = []
y_p = []

for tr in train_data:
    X_p.append(prepare_word(tr[0], word2index).view(1, -1))
    y_p.append(prepare_word(tr[1], word2index).view(1, -1))
    
train_data = list(zip(X_p, y_p))

In [101]:
len(train_data) # 자주등장하는 단어들 포함시켜서 이전에 skip_gram모델때보다 더 많음

50242

## Build Unigram Distribution ** 0.75

![](https://user-images.githubusercontent.com/36406676/54072330-bfc2a480-42bc-11e9-8759-d27c561d28d9.jpg)

In [20]:
Z = 0.001
word_count = Counter(flatten(corpus))
num_total_words = sum([c for w, c in word_count.items() if w not in exclude])

In [36]:
num_total_words

7798

In [40]:
unigram_table= []
for vo in vocab:
    unigram_table.extend([vo]*int(((word_count[vo]/num_total_words)**0.75)/Z))

In [86]:
word_count['city']

4

In [87]:
unigram_table[:10]

## 의미를 해석해보면 전체 단어빈도가 7798개이고 해당단어 빈도수에 1000을 곱했을때 대략적으로 city는 3이나온다는 뜻..?
## (4/7798)**0.75*1000 = 3.40844....

['coffin',
 'coffin',
 'coffin',
 'voyages',
 'voyages',
 'voyages',
 '?',
 '?',
 '?',
 '?']

In [45]:
print(len(vocab), len(unigram_table)) # 어찌됐든 늘어남

478 3500


## Negative_Sampling

In [2]:
def negative_sampling(targets, unigram_table, k):
    batch_size = target.size(0)
    neg_samples = []
    for i in range(batch_size):
        nsample = []
        target_index = targets[i].data.tolist()[0]
        while len(nsample) < k: # num of sampling
            neg = ramdom.choice(unigram_table)
            if word2index[neg] == target_index:
                continue
            nsample.append(neg)
        neg_samples.append(prepare_sequence(nsample, word2index).view(1,-1))
        
    return torch.cat(neg_samples)            

## Modeling