In [59]:
import torch
import torch.nn 
import numpy as np
import sys
import os
import collections
import random
import math 
import torch.utils.data 
import time
import torch.nn as nn
import torch.nn.functional as F

In [2]:
# Record the installed PyTorch version alongside the results below.
print(torch.__version__)

1.10.0+cu113


In [3]:
# Read the PTB training split and tokenize each line on whitespace.
# The assert fails early if the corpus file is not in the data directory.
assert 'ptb.train.txt' in os.listdir('/home/word2vec/data')
with open('/home/word2vec/data/ptb.train.txt', 'r') as f:
    lines = f.readlines()
# One list of string tokens per sentence (one sentence per line of the file).
raw_dataset = list(map(str.split, lines))
'#sentences:%d' % len(raw_dataset)

'#sentences:42068'

In [4]:
# Count every token in the corpus, then keep only tokens seen at least 5 times.
# `counter` ends up as a plain dict {token: frequency} in first-seen order.
counter = collections.Counter(tk for st in raw_dataset for tk in st)
counter = {tk: freq for tk, freq in counter.items() if freq >= 5}
counter['N']

32481

In [5]:
# Vocabulary: position in `idx_to_taken` is the integer id of a token.
idx_to_taken = list(counter)
token_to_idx = dict(zip(idx_to_taken, range(len(idx_to_taken))))
# Re-encode the corpus as ids, silently dropping tokens outside the vocabulary.
dataset = [
    [token_to_idx[tk] for tk in st if tk in token_to_idx]
    for st in raw_dataset
]
# Total number of in-vocabulary tokens across all sentences.
num_tokens = sum(map(len, dataset))

In [6]:
# Corpus frequency of the token stored at vocabulary index 2.
counter[idx_to_taken[2]]

32481

In [7]:
# Only the final expression of a cell is echoed, so the first lookup below
# produces no visible output; the cell displays the index of the token 'N'.
idx_to_taken[2]
token_to_idx['N']

2

In [8]:
def discard(idx):
    """Randomly decide whether one occurrence of token id ``idx`` is dropped.

    The drop probability is ``1 - sqrt(1e-4 / f)`` where ``f`` is the token's
    relative frequency (``counter[token] / num_tokens``). When that value is
    not positive the token is never dropped.

    Reads the module-level ``counter``, ``idx_to_taken`` and ``num_tokens``.
    Returns True when the occurrence should be discarded.
    """
    draw = random.uniform(0,1)
    drop_prob = 1-math.sqrt(1e-4/counter[idx_to_taken[idx]]*num_tokens)
    return draw < drop_prob
# Subsample every sentence independently, keeping the sentence structure.
sampled_data_set = [
    [tk for tk in st if not discard(tk)]
    for st in dataset
]
# Number of tokens that survived subsampling.
print(sum(map(len, sampled_data_set)))

375861


In [9]:
def compare_count(token):
    """Report how often ``token`` occurs before and after subsampling.

    Counts occurrences of the token's id in the module-level ``dataset`` and
    ``sampled_data_set`` and returns them in a formatted string.
    """
    idx = token_to_idx[token]
    before = sum(st.count(idx) for st in dataset)
    after = sum(st.count(idx) for st in sampled_data_set)
    return '#%s: before:%d,after:%d'%(token,before,after)
compare_count('the')

'#the: before:50770,after:2113'

In [10]:
def get_centers_and_contexts(dataset,max_window_size):
    """Extract (center, context-list) pairs from a corpus of id sentences.

    Args:
        dataset: iterable of sentences, each an indexable sequence of ids.
        max_window_size: upper bound (inclusive) of the random window radius.

    Returns:
        ``(centers, contexts)``: ``centers`` is a flat list of center ids and
        ``contexts[i]`` lists the ids within the sampled radius of
        ``centers[i]`` in the same sentence, excluding the center itself.
        Sentences with fewer than two tokens contribute nothing.
    """
    centers, contexts = [], []
    for sentence in dataset:
        length = len(sentence)
        if length >= 2:
            centers.extend(sentence)
            for pos in range(length):
                # A fresh radius in [1, max_window_size] is drawn per center.
                radius = random.randint(1, max_window_size)
                lo = max(0, pos - radius)
                hi = min(length, pos + radius + 1)
                contexts.append([sentence[j] for j in range(lo, hi) if j != pos])
    return centers, contexts
        

In [11]:
# Sanity check on a toy corpus: one 7-token sentence and one 3-token sentence.
tiny_dataset = [list(range(7)), list(range(7, 10))]
print('dataset', tiny_dataset)
tiny_centers, tiny_contexts = get_centers_and_contexts(tiny_dataset, 2)
# The loop variables are module-level names and stay bound after the loop.
for centers, contexts in zip(tiny_centers, tiny_contexts):
    print('center:', centers, 'has contexts:', contexts)

dataset [[0, 1, 2, 3, 4, 5, 6], [7, 8, 9]]
center: 0 has contexts: [1, 2]
center: 1 has contexts: [0, 2]
center: 2 has contexts: [0, 1, 3, 4]
center: 3 has contexts: [2, 4]
center: 4 has contexts: [2, 3, 5, 6]
center: 5 has contexts: [3, 4, 6]
center: 6 has contexts: [4, 5]
center: 7 has contexts: [8, 9]
center: 8 has contexts: [7, 9]
center: 9 has contexts: [7, 8]


In [12]:
# Extract center/context pairs from the subsampled corpus with a maximum
# window radius of 5, then report the pair counts and the vocabulary size.
all_centers,all_contexts = get_centers_and_contexts(sampled_data_set,5)
print(len(all_centers),len(all_contexts))
print(len(idx_to_taken))
def get_negatives(all_contexts,sampling_weight,K):
    """Draw K negative (noise) word ids for every context word.

    Args:
        all_contexts: list of context-id lists, one per center word.
        sampling_weight: unnormalized sampling weight per word id; the word id
            is the position in this sequence.
        K: number of negatives to draw per context word.

    Returns:
        A list parallel to ``all_contexts``; entry ``i`` holds
        ``len(all_contexts[i]) * K`` ids, none of which appear in
        ``all_contexts[i]``.

    Note:
        Candidates are drawn in bulk (1e5 at a time) and consumed sequentially
        across centers. If every word with non-zero weight is in some context
        list, the rejection loop for that center cannot terminate.
    """
    all_negatives,neg_candidates,i= [],[],0
    populations = list(range(len(sampling_weight)))
    print(len(populations))
    for contexts in all_contexts:
        # Build the exclusion set once per center rather than once per drawn
        # candidate; it does not change inside the rejection loop.
        excluded = set(contexts)
        target = len(contexts)*K
        negatives= []
        while len(negatives)<target:
            if i==len(neg_candidates):
                # Cache exhausted: refill with a fresh batch of weighted draws.
                i=0
                neg_candidates  = random.choices(populations,sampling_weight,k = int(1e5))
            neg =  neg_candidates[i]
            i = i+1
            # A noise word must not be one of this center's true context words.
            if neg not in excluded:
                negatives.append(neg)
        all_negatives.append(negatives)
    return all_negatives
# Unnormalized noise distribution: each token's corpus count raised to 0.75,
# listed in vocabulary-index order.
sampling_weight = [counter[w]**0.75 for w in idx_to_taken]
# Five negatives are drawn for every context word of every center.
all_negatives =  get_negatives(all_contexts,sampling_weight,K=5)

374933 374933
9858
9858


In [13]:
class MyDateset(torch.utils.data.Dataset):
    """Map-style dataset of aligned (center, contexts, negatives) triples."""

    def __init__(self,centers,contexts,negatives):
        # The three sequences are indexed in lockstep, so lengths must agree.
        assert len(centers)==len(contexts)==len(negatives)
        self.centers, self.contexts, self.negatives = centers, contexts, negatives

    def __len__(self):
        """Number of center words (and therefore of triples)."""
        return len(self.centers)

    def __getitem__(self,index):
        """Return the ``index``-th triple as a tuple."""
        return self.centers[index], self.contexts[index], self.negatives[index]

In [14]:
def batchify(data):
    """Collate (center, context, negative) triples into padded tensors.

    Args:
        data: non-empty list of ``(center, context, negative)`` where
            ``context`` and ``negative`` are lists of word ids.

    Returns:
        A 4-tuple of tensors, with ``max_len`` the longest
        ``len(context) + len(negative)`` in the batch:
          * centers, shape ``(batch, 1)``
          * contexts_negatives, shape ``(batch, max_len)``: context ids, then
            negative ids, then zero padding
          * masks, shape ``(batch, max_len)``: 1 for real entries, 0 for padding
          * labels, shape ``(batch, max_len)``: 1 for context (positive)
            entries, 0 for negatives and padding
    """
    max_len = max(len(c)+len(n) for _,c,n in data)
    centers,contexts_negatives,masks,labels =[],[],[],[]
    for center,context,negative in data:
        cur_len = len(context)+len(negative)
        pad = max_len-cur_len
        centers.append(center)
        contexts_negatives.append(context+negative+[0]*pad)
        masks.append([1]*cur_len+[0]*pad)
        # The positive count must come from this example's own `context`.
        # Reading a module-level `contexts` name here would mark the same
        # number of positions positive in every row (or raise NameError).
        num_pos = len(context)
        labels.append([1]*num_pos+[0]*(max_len-num_pos))
    return torch.tensor(centers).view(-1,1),torch.tensor(contexts_negatives),torch.tensor(masks),torch.tensor(labels)


In [15]:
batch_size = 512
num_workers = 0
# NOTE: this rebinds `dataset` (previously the id-encoded corpus) to the
# torch Dataset wrapper; cells that read the corpus under that name must have
# run before this one.
dataset = MyDateset(all_centers, all_contexts, all_negatives)
data_iter = torch.utils.data.DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
    collate_fn=batchify,
)
# Peek at the first mini-batch only, printing the shape of each tensor.
for batch in data_iter:
    for name, data in zip(['centers','contexts_negative','mask','labels'], batch):
        print(data.shape)
    break


torch.Size([512, 1])
torch.Size([512, 60])
torch.Size([512, 60])
torch.Size([512, 60])


In [16]:
# Embedding lookup demo: a (2, 3) tensor of ids into a 20 x 4 table yields a
# (2, 3, 4) tensor with one 4-dim vector per id.
embed = nn.Embedding(20, 4)
x = torch.arange(1, 7, dtype=torch.long).reshape(2, 3)
embed(x)

tensor([[[ 0.7334,  1.4351, -1.0346, -0.6599],
         [-0.2932, -1.2891, -0.1360, -0.2086],
         [-0.7851, -0.4692, -1.1269, -1.7478]],

        [[ 0.9005, -2.1161,  1.5750, -0.2610],
         [-0.8368,  1.4180,  0.9319,  1.8862],
         [-0.2481, -0.4609, -0.2558,  0.6948]]], grad_fn=<EmbeddingBackward0>)

In [17]:
# Batched matmul demo: (2, 1, 4) @ (2, 4, 6) -> (2, 1, 6); every entry is 4.
x = torch.ones(2, 1, 4)
y = torch.ones(2, 4, 6)
torch.bmm(x, y)

tensor([[[4., 4., 4., 4., 4., 4.]],

        [[4., 4., 4., 4., 4., 4.]]])

In [18]:
def skip_gram(center,contexts_and_negatives,embed_v,embed_u):
    """Score each center word against its context and noise candidates.

    Args:
        center: id tensor; ``batchify`` in this notebook yields ``(batch, 1)``.
        contexts_and_negatives: id tensor; ``batchify`` yields ``(batch, L)``.
        embed_v: embedding layer applied to the center ids.
        embed_u: embedding layer applied to the candidate ids.

    Returns:
        Batched dot products between the center vectors and the candidate
        vectors; ``(batch, 1, L)`` for the shapes above.
    """
    center_vecs = embed_v(center)
    candidate_vecs = embed_u(contexts_and_negatives)
    # Swap the last two axes so bmm contracts over the embedding dimension.
    return torch.bmm(center_vecs, candidate_vecs.transpose(1, 2))

In [62]:
class SigmoidBinaryCrossEntropyLoss(nn.Module):
    """Per-example sigmoid binary cross-entropy with an optional padding mask."""
    def __init__(self):
        super(SigmoidBinaryCrossEntropyLoss,self).__init__()
    def forward(self,inputs,targets,mask=None):
        """Compute the masked BCE-with-logits loss averaged over dim 1.

        Args:
            inputs: raw logits, shape ``(batch, L)``.
            targets: 0/1 labels with the same shape as ``inputs``.
            mask: optional element-wise weights with the same shape (1 keeps
                an entry, 0 zeroes its loss). ``None`` weights every entry
                equally.

        Returns:
            Tensor of shape ``(batch,)``: the weighted element-wise loss
            averaged over all ``L`` positions. Masked positions still count in
            the denominator, so callers rescale by ``L / mask.sum(dim=1)`` to
            average over valid positions only.
        """
        targets = targets.float()
        inputs = inputs.float()
        # `mask` defaults to None, so it must not be dereferenced
        # unconditionally; None is also what the functional API accepts for
        # "no weighting".
        weight = mask.float() if mask is not None else None
        res = torch.nn.functional.binary_cross_entropy_with_logits(inputs,targets,reduction="none",weight=weight)
        return res.mean(dim=1)

loss = SigmoidBinaryCrossEntropyLoss()

In [55]:
# Two toy rows; the last position of the second row is padding (mask == 0).
pred = torch.tensor([[1.5, 0.3, -1, 2], [1.1, -0.6, 2.2, 0.4]])
labels = torch.tensor([[1, 0, 0, 0], [1, 1, 0, 0]])
mask = torch.tensor([[1, 1, 1, 1], [1, 1, 1, 0]])
# The criterion averages over all positions; rescale so that each row is
# averaged over its unmasked positions only.
num_valid = mask.float().sum(dim=1)
loss(pred, labels, mask) * mask.shape[1] / num_valid

tensor([0.8740, 1.2100])

In [56]:
def sigmoid(x):
    """Logistic function ``1 / (1 + exp(-x))`` applied element-wise."""
    return 1/(1+np.exp(-x))
# Hand-written, unmasked binary cross-entropy over `pred`/`labels`, for
# comparison with the criterion. It is stored under its own name: binding it
# to `loss` would replace the criterion object that `train` calls and break a
# top-to-bottom run of the notebook.
manual_loss = -labels*np.log(sigmoid(pred)) - (1-labels)*np.log(1-sigmoid(pred))
manual_loss.mean()

tensor(1.0049)

In [22]:
def sigmd(x):
    """Return the negative log of the logistic sigmoid of ``x``."""
    prob = 1/(1+math.exp(-x))
    return -math.log(prob)
# Manual check of the first toy row: positives contribute sigmd(logit),
# negatives contribute sigmd(-logit); the mean is taken over the 4 positions.
row_mean = (sigmd(1.5)+sigmd(-0.3)+sigmd(1)+sigmd(-2))/4
print('%.4f'%(row_mean))

0.8740


In [23]:
# Vocabulary size; used below as the number of rows of both embedding tables.
print(len(idx_to_taken))

9858


In [24]:
embed_size = 100
vocab_size = len(idx_to_taken)
# Two independent embedding tables held in one container: net[0] is passed to
# skip_gram as the center-word table, net[1] as the context/noise-word table.
net = nn.Sequential(
    nn.Embedding(vocab_size, embed_size),
    nn.Embedding(vocab_size, embed_size),
)

In [25]:
def train(net,lr,num_epochs):
    """Train the two-table skip-gram model with Adam.

    Args:
        net: indexable container whose items 0 and 1 are the center and
            context embedding layers.
        lr: learning rate for ``torch.optim.Adam``.
        num_epochs: number of full passes over the module-level ``data_iter``.

    Uses the module-level ``data_iter``, ``skip_gram`` and ``loss``. Prints the
    mean batch loss and wall-clock time after every epoch; returns nothing.
    """
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    net = net.to(device)
    print('trian on:',device)
    optimizer = torch.optim.Adam(net.parameters(),lr=lr)
    for epoch in range(1, num_epochs+1):
        epoch_start = time.time()
        running_loss, num_batches = 0.0, 0
        for batch in data_iter:
            centers, contexts_negatives, mask, labels = (t.to(device) for t in batch)
            logits = skip_gram(centers, contexts_negatives, net[0], net[1]).view(labels.shape)
            # The criterion averages over every position, padding included;
            # rescale so each example is averaged over its valid positions.
            per_example = loss(logits, labels, mask)*mask.shape[1]/mask.float().sum(dim=1)
            batch_loss = per_example.mean()
            optimizer.zero_grad()
            batch_loss.backward()
            optimizer.step()
            running_loss += batch_loss.cpu().item()
            num_batches += 1
        print('epoch:%d,loss:%.4f,time:%.2f'%(epoch,running_loss/num_batches,time.time()-epoch_start))

In [None]:
# Train for 10 epochs with a learning rate of 0.01.
train(net,0.01,10)

In [72]:
def get_similar_tokens(query_token,k,embed):
    """Print the ``k + 1`` tokens whose embeddings are closest to the query.

    Similarity is the cosine between rows of ``embed.weight``. The query token
    itself is included in the printed list.

    Args:
        query_token: a token present in the module-level ``token_to_idx``.
        k: number of neighbours requested; ``k + 1`` rows are printed.
        embed: embedding layer whose weight matrix is searched.
    """
    weights = embed.weight.data
    query_vec = weights[token_to_idx[query_token]]
    # Product of squared norms; the small constant guards against a zero norm.
    sq_norms = torch.sum(weights*weights,dim=1)*torch.sum(query_vec*query_vec)
    cos = torch.matmul(weights,query_vec)/(sq_norms+1e-9).sqrt()
    _, nearest = torch.topk(cos,k=k+1)
    for i in nearest.cpu().numpy():
        print('cosine sim=%.3f:%s'%(cos[i],(idx_to_taken[i])))
get_similar_tokens('a',10,net[0])

cosine sim=1.000:a
cosine sim=0.375:suspected
cosine sim=0.356:mitsubishi
cosine sim=0.354:bugs
cosine sim=0.351:pa
cosine sim=0.335:disorders
cosine sim=0.330:resource
cosine sim=0.323:layer
cosine sim=0.319:other
cosine sim=0.319:estimates
cosine sim=0.313:picket


In [73]:
# Shell escape: ask the NVIDIA driver which GPUs are visible on this machine.
!nvidia-smi

No devices were found


In [30]:
# Reference check of the functional API with its default mean reduction.
logit = torch.tensor([5.0, 1.0, 3.0], dtype=torch.float32)
label = torch.tensor([1.0, 0.0, 1.0], dtype=torch.float32)
output = F.binary_cross_entropy_with_logits(logit, label)
print(output)

tensor(0.4562)
