<a href="https://colab.research.google.com/github/ligemlp/mylesson/blob/master/%E7%AC%AC%E4%BA%8C%E8%AF%BEpytorch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!git clone https://github.com/ligemlp/mylesson.git

Cloning into 'mylesson'...
remote: Enumerating objects: 9, done.[K
remote: Total 9 (delta 0), reused 0 (delta 0), pack-reused 9[K
Unpacking objects: 100% (9/9), done.


In [2]:
!ls -R

.:
mylesson  sample_data

./mylesson:
text8  text8.zip

./mylesson/text8:
text8.dev.txt  text8.test.txt  text8.train.txt

./sample_data:
anscombe.json		      mnist_test.csv
california_housing_test.csv   mnist_train_small.csv
california_housing_train.csv  README.md


In [0]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils
from collections import Counter
import numpy as np
import random
import math
import pandas as pd
import scipy
import sklearn
from sklearn.metrics.pairwise import cosine_similarity

In [0]:
USE_CUDA = torch.cuda.is_available()
random.seed(1)
np.random.seed(1)
torch.manual_seed(1)
if USE_CUDA:
    torch.cuda.manual_seed(1)
#hyper paramters
C = 3 #context window
K = 100 #number of negative samples
NUM_EPOCHS = 2
MAX_VOCAB_SIZE = 30000
BATCH_SIZE = 128
LEARNING_RATE = 0.2
EMBEDDING_SIZE = 100

In [0]:
def word_tokenize(text):
    return text.split()

In [6]:
with open("./mylesson/text8/text8.train.txt","r") as f:
    data = f.read()
text = data.split()
vocab = dict(Counter(text).most_common(MAX_VOCAB_SIZE - 1))
vocab["<unk>"] = len(text) - np.sum(list(vocab.values()))

idx_to_word = [word for word in vocab.keys()]
word_to_idx = {word:i for i,word in enumerate(idx_to_word)}

word_counts = np.array([count for count in vocab.values()], dtype=np.float32)
word_freq = word_counts / np.sum(word_counts)
word_freq = word_freq ** (3./4.)
word_freq = word_counts / np.sum(word_counts)
VOCAB_SIZE = len(idx_to_word)
VOCAB_SIZE

30000

In [0]:
class WordEmbedingDataset(torch.utils.data.Dataset):
    def __init__(self, text, word_to_idx, idx_to_word, word_freq, word_counts):
        super(WordEmbedingDataset, self).__init__()
        self.text_encoded = [word_to_idx.get(word, word_to_idx["<unk>"]) for word in text]
        self.text_encoded = torch.LongTensor(self.text_encoded)
        self.word_to_idx = word_to_idx
        self.idx_to_word = idx_to_word
        self.word_freq = torch.Tensor(word_freq)
        self.word_counts = torch.Tensor(word_counts)

    def __len__(self):
        return len(self.text_encoded)

    def __getitem__(self, idx):
        center_word = self.text_encoded[idx]
        pos_indices = list(range(idx - C,idx)) + list(range(idx+1, idx+C+1))
        pos_indices = [i % len(self.text_encoded) for i in pos_indices]
        pos_words = self.text_encoded[pos_indices]
        neg_words = torch.multinomial(self.word_freq,  K * pos_words.shape[0], True)
        return center_word, pos_words, neg_words

In [8]:
dataset = WordEmbedingDataset(text, word_to_idx, idx_to_word, word_freq, word_counts)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
center_word, pos_words, neg_words = dataset.__getitem__(100)
pos_words

tensor([  58,   25, 6525,    1,  152,   32])

In [0]:
class EmbeddingModel(nn.Module):
    def __init__(self, vocab_size, embed_size):
        super(EmbeddingModel, self).__init__()
        self.vocab_size = vocab_size
        self.embed_size = embed_size
        self.in_embed = nn.Embedding(self.vocab_size, self.embed_size)
        self.out_embed = nn.Embedding(self.vocab_size, self.embed_size)

    def forward(self, input_labels, pos_labels, neg_labels):
        input_embedding = self.in_embed(input_labels)#[batch_size, embed_size]
        pos_embedding = self.out_embed(pos_labels)#[batch_size,(window_size*2), embed_size]
        neg_embedding = self.out_embed(neg_labels)#[batch_size,(window_size*2*K),embed_size]

        input_embedding = input_embedding.unsqueeze(2)#[batch_size,embed_size,1]
        pos_dot = torch.bmm(pos_embedding, input_embedding).squeeze(2)#[batch_size, (window_size*2)]
        neg_dot = torch.bmm(neg_embedding, -input_embedding).squeeze(2)#[batch_size, (window_size*2*K)]

        log_pos = F.logsigmoid(pos_dot).sum(1)
        log_neg = F.logsigmoid(neg_dot).sum(1)

        loss = log_pos + log_neg

        return -loss
    
    def input_embeddings(self):
        return self.in_embed.weight.data.cpu().numpy()

In [0]:
model = EmbeddingModel(VOCAB_SIZE, EMBEDDING_SIZE)
if USE_CUDA:
    model = model.cuda()

In [15]:
optimizer = torch.optim.SGD(model.parameters(),lr=LEARNING_RATE)
for e in range(NUM_EPOCHS):
    for i , (input_labels, pos_labels, neg_labels) in enumerate(dataloader):
        if USE_CUDA:
            input_labels = input_labels.long()
            pos_labels = pos_labels.long()
            neg_labels = neg_labels.long()

            input_labels = input_labels.cuda()
            pos_labels = pos_labels.cuda()
            neg_labels = neg_labels.cuda()
            optimizer.zero_grad()
            
            loss = model(input_labels, pos_labels, neg_labels).mean()
            loss.backward()
            optimizer.step()

            if i % 100 == 0:
                print("Epoch",e,"iteration",i,loss.item())

Epoch 0 iteration 0 2515.32373046875
Epoch 0 iteration 100 869.727294921875
Epoch 0 iteration 200 574.3082275390625
Epoch 0 iteration 300 616.73486328125
Epoch 0 iteration 400 438.3895568847656
Epoch 0 iteration 500 465.9903869628906
Epoch 0 iteration 600 350.5777893066406
Epoch 0 iteration 700 434.18865966796875
Epoch 0 iteration 800 306.4776611328125
Epoch 0 iteration 900 272.8446960449219
Epoch 0 iteration 1000 257.8622741699219
Epoch 0 iteration 1100 283.87042236328125
Epoch 0 iteration 1200 215.76821899414062
Epoch 0 iteration 1300 172.3601837158203
Epoch 0 iteration 1400 233.6432342529297
Epoch 0 iteration 1500 236.304443359375
Epoch 0 iteration 1600 195.4098358154297
Epoch 0 iteration 1700 162.46824645996094
Epoch 0 iteration 1800 189.8173370361328
Epoch 0 iteration 1900 172.59027099609375
Epoch 0 iteration 2000 166.48712158203125
Epoch 0 iteration 2100 203.97393798828125
Epoch 0 iteration 2200 140.853271484375
Epoch 0 iteration 2300 182.30224609375
Epoch 0 iteration 2400 142.02

KeyboardInterrupt: ignored