In [1]:
import collections
import math
import random
import sys
import time
import os
import numpy as np
import torch
from torch import nn
import torch.utils.data as Data
import json
from tqdm import tqdm

In [2]:
path = './graph_series/'
# Sort the listing: os.listdir order is arbitrary, and the order of the series
# determines the first-seen order of node ids, i.e. the vocabulary indices.
files = sorted(os.listdir(path))

In [3]:
def getTokens():
    """Load the token mapping stored in ``./nodes-SG.json``.

    Returns:
        The decoded JSON object (a dict when the file holds a JSON object).
    """
    with open('./nodes-SG.json') as token_file:
        return json.load(token_file)

def getInputSeries():
    """Parse every series file listed in the global ``files`` under ``path``.

    Each file is read with ``np.loadtxt`` as whitespace-separated int32 node ids.

    Returns:
        list: one list of node ids per file, in ``files`` order.
    """
    inputSeries = []
    for file in tqdm(files):
        # ndmin=1: a file holding a single id would otherwise load as a 0-d
        # array whose .tolist() is a bare int, breaking code that iterates
        # over each series.
        vec = np.loadtxt(os.path.join(path, file), dtype=np.int32, ndmin=1)
        inputSeries.append(vec.tolist())
    return inputSeries

In [4]:
# Load all node-id series from disk (slow: ~26 minutes for ~171k files in the recorded run).
inputSeries = getInputSeries()

100%|██████████| 171019/171019 [26:23<00:00, 107.98it/s]


In [29]:
# Minimum number of occurrences for a node id to enter the vocabulary.
# Every id that occurs has a count >= 1, so 1 keeps all ids (equivalent to the
# former `count >= 0` filter, which was a silent no-op); raise it to prune rare ids.
MIN_COUNT = 1
counter = collections.Counter(tk for st in inputSeries for tk in st)
counter = {tk: cnt for tk, cnt in counter.items() if cnt >= MIN_COUNT}

# Vocabulary in first-seen order: idx_to_token[i] is the raw node id of index i.
idx_to_token = list(counter)
token_to_idx = {tk: idx for idx, tk in enumerate(idx_to_token)}
# Re-encode every series with dense vocabulary indices, dropping pruned ids.
dataset = [[token_to_idx[tk] for tk in st if tk in token_to_idx] for st in inputSeries]
num_tokens = sum(len(st) for st in dataset)

print(num_tokens)

25683649


In [30]:
# Vocabulary size: number of distinct node ids kept in `counter`.
len(counter.keys())

517

In [31]:
# Which node ids in [0, NUM_NODE_IDS) never occur in any series?
# 801 matches the table size of the scratch embedding cell; presumably the
# total number of graph nodes — TODO confirm against nodes-SG.json.
NUM_NODE_IDS = 801
observed_ids = np.fromiter(counter.keys(), dtype=np.int32, count=len(counter))
all_ids = np.arange(NUM_NODE_IDS)
# setdiff1d returns sorted unique values, so no explicit sort is required.
np.setdiff1d(all_ids, observed_ids)

array([517, 518, 519, 520, 521, 522, 523, 524, 525, 526, 527, 528, 529,
       530, 531, 532, 533, 534, 535, 536, 537, 538, 539, 540, 541, 542,
       543, 544, 545, 546, 547, 548, 549, 550, 551, 552, 553, 554, 555,
       556, 557, 558, 559, 560, 561, 562, 563, 564, 565, 566, 567, 568,
       569, 570, 571, 572, 573, 574, 575, 576, 577, 578, 579, 580, 581,
       582, 583, 584, 585, 586, 587, 588, 589, 590, 591, 592, 593, 594,
       595, 596, 597, 598, 599, 600, 601, 602, 603, 604, 605, 606, 607,
       608, 609, 610, 611, 612, 613, 614, 615, 616, 617, 618, 619, 620,
       621, 622, 623, 624, 625, 626, 627, 628, 629, 630, 631, 632, 633,
       634, 635, 636, 637, 638, 639, 640, 641, 642, 643, 644, 645, 646,
       647, 648, 649, 650, 651, 652, 653, 654, 655, 656, 657, 658, 659,
       660, 661, 662, 663, 664, 665, 666, 667, 668, 669, 670, 671, 672,
       673, 674, 675, 676, 677, 678, 679, 680, 681, 682, 683, 684, 685,
       686, 687, 688, 689, 690, 691, 692, 693, 694, 695, 696, 69

In [32]:
def discard(idx):
    """Decide whether to drop one occurrence of vocabulary index ``idx``.

    Subsampling of frequent tokens: an occurrence is dropped with probability
    ``1 - sqrt(1e-4 * num_tokens / count)``. When that value is <= 0 (rare
    tokens) the occurrence is always kept. Reads the notebook globals
    ``counter``, ``idx_to_token`` and ``num_tokens``.
    """
    draw = random.uniform(0, 1)
    token_count = counter[idx_to_token[idx]]
    drop_probability = 1 - math.sqrt(1e-4 / token_count * num_tokens)
    return draw < drop_probability
# Subsample each sentence independently, keeping the tokens that survive.
subsampled_dataset = [list(filter(lambda tk: not discard(tk), sentence)) for sentence in dataset]

print(sum(map(len, subsampled_dataset)))

2149103


In [33]:
def get_centers_and_contexts(dataset, max_window_size):
    """Collect skip-gram (center, context) examples.

    Every token of every sentence with at least two tokens becomes a center.
    Its context is the tokens within ``randint(1, max_window_size)`` positions
    on each side, with the center itself excluded.

    Returns:
        tuple: (centers, contexts), where contexts[i] is the list of context
        tokens for centers[i].
    """
    centers, contexts = [], []
    for sentence in dataset:
        length = len(sentence)
        if length < 2:  # a center needs at least one other token as context
            continue
        centers.extend(sentence)
        for pos in range(length):
            half_width = random.randint(1, max_window_size)
            lo = max(0, pos - half_width)
            hi = min(length, pos + 1 + half_width)
            # exclude the center word itself
            contexts.append([sentence[j] for j in range(lo, hi) if j != pos])
    return centers, contexts

In [34]:
# Build skip-gram examples with a random context window of up to 5 on each side.
all_centers, all_contexts = get_centers_and_contexts(subsampled_dataset, max_window_size=5)

In [35]:
def get_negatives(all_contexts, sampling_weights, K):
    """Draw ``K`` negative samples per context word for every center.

    Candidates are indices in ``range(len(sampling_weights))``, drawn with
    replacement using ``sampling_weights`` in pre-sampled chunks of 100,000.
    A candidate that equals one of the current center's context words is
    rejected and the next one is used.

    Args:
        all_contexts: list of context lists (vocabulary indices).
        sampling_weights: unnormalized sampling weight per vocabulary index.
        K: number of negatives per context word.

    Returns:
        list: for each entry of ``all_contexts``, a list of
        ``len(contexts) * K`` negative indices.

    Note: this loops forever if every index with non-zero weight appears in
    some center's contexts.
    """
    all_negatives, neg_candidates, i = [], [], 0
    population = list(range(len(sampling_weights)))
    for contexts in all_contexts:
        # Build the exclusion set once per center, not once per candidate.
        excluded = set(contexts)
        target = len(contexts) * K
        negatives = []
        while len(negatives) < target:
            if i == len(neg_candidates):
                # Refill the candidate buffer with one bulk draw.
                i, neg_candidates = 0, random.choices(population, sampling_weights, k=int(1e5))
            neg, i = neg_candidates[i], i + 1
            if neg not in excluded:
                negatives.append(neg)
        all_negatives.append(negatives)
    return all_negatives

# Negative-sampling distribution: unigram counts raised to the 3/4 power,
# listed in vocabulary-index order.
sampling_weights = [pow(counter[token], 0.75) for token in idx_to_token]
all_negatives = get_negatives(all_contexts, sampling_weights, K=5)

In [36]:
class MyDataset(torch.utils.data.Dataset):
    """Map-style dataset of skip-gram examples.

    Item ``i`` is the tuple ``(centers[i], contexts[i], negatives[i])``; the
    three sequences must have the same length.
    """

    def __init__(self, centers, contexts, negatives):
        assert len(centers) == len(contexts) == len(negatives)
        self.centers, self.contexts, self.negatives = centers, contexts, negatives

    def __getitem__(self, index):
        return self.centers[index], self.contexts[index], self.negatives[index]

    def __len__(self):
        return len(self.centers)

In [37]:
def batchify(data):
    """Collate skip-gram examples into padded tensors (DataLoader collate_fn).

    Each example is ``(center, contexts, negatives)``. Contexts and negatives
    are concatenated and right-padded with 0 to the longest example in the
    batch.

    Returns:
        tuple: (centers, contexts_negatives, masks, labels). ``centers`` has
        shape (batch, 1); the other three have shape (batch, max_len).
        ``masks`` is 1 on real positions and 0 on padding; ``labels`` is 1 on
        context positions and 0 on negatives and padding.
    """
    max_len = max(len(ctx) + len(neg) for _, ctx, neg in data)
    centers, contexts_negatives, masks, labels = [], [], [], []
    for center, ctx, neg in data:
        real_len = len(ctx) + len(neg)
        pad = max_len - real_len
        centers.append(center)
        contexts_negatives.append(ctx + neg + [0] * pad)
        masks.append([1] * real_len + [0] * pad)
        labels.append([1] * len(ctx) + [0] * (max_len - len(ctx)))
    return (
        torch.tensor(centers).view(-1, 1),
        torch.tensor(contexts_negatives),
        torch.tensor(masks),
        torch.tensor(labels),
    )

In [38]:
# Sanity check: centers, contexts and negatives must be index-aligned.
tuple(len(seq) for seq in (all_centers, all_contexts, all_negatives))

(2135328, 2135328, 2135328)

In [39]:
batch_size = 512
# Worker processes are disabled on Windows — presumably because spawn-based
# workers cannot import functions defined in a notebook (e.g. batchify).
num_workers = 0 if sys.platform.startswith('win32') else 4
# Use a distinct name: `dataset` still holds the token-index lists that the
# subsampling cell iterates over; rebinding it to a MyDataset would break a
# re-run of that cell.
sg_dataset = MyDataset(all_centers, all_contexts, all_negatives)
data_iter = Data.DataLoader(sg_dataset, batch_size, shuffle=True,
                            collate_fn=batchify, num_workers=num_workers)

# Peek at one batch to check the tensor shapes produced by batchify.
for batch in data_iter:
    for name, data in zip(['centers', 'contexts_negatives', 'masks', 'labels'], batch):
        print(name, 'shape:', data.shape)
    break

centers shape: torch.Size([512, 1])
contexts_negatives shape: torch.Size([512, 60])
masks shape: torch.Size([512, 60])
labels shape: torch.Size([512, 60])


In [40]:
# Scratch check: a freshly initialized embedding table with 801 rows of width 100.
embed = nn.Embedding(801, 100)
embed.weight

Parameter containing:
tensor([[ 0.4525,  1.9874, -0.9452,  ..., -0.7464,  0.9073,  0.1377],
        [-0.9171, -0.8009, -0.6175,  ..., -1.0223,  1.0705, -0.4404],
        [ 0.9239, -0.3658,  0.2486,  ..., -0.7920,  0.8907,  0.3150],
        ...,
        [ 0.0952, -1.0623,  0.0595,  ...,  0.0861, -1.1065,  0.7058],
        [-0.7767,  1.0824, -0.0778,  ..., -0.7923, -0.2387,  0.3701],
        [-0.5217,  0.6557,  0.1799,  ..., -0.1496,  0.8289,  0.7726]],
       requires_grad=True)

In [41]:
def skip_gram(center, contexts_and_negatives, embed_v, embed_u):
    """Dot-product scores between each center and its candidate words.

    Args:
        center: index tensor, (batch, 1) as produced by batchify.
        contexts_and_negatives: index tensor, (batch, L).
        embed_v: embedding applied to the centers.
        embed_u: embedding applied to contexts/negatives.

    Returns:
        Tensor of shape (batch, 1, L) holding the raw logits.
    """
    center_vecs = embed_v(center)
    candidate_vecs = embed_u(contexts_and_negatives)
    return torch.bmm(center_vecs, candidate_vecs.transpose(1, 2))

In [42]:
class SigmoidBinaryCrossEntropyLoss(nn.Module):
    """Masked sigmoid binary cross-entropy, averaged per row.

    forward(inputs, targets, mask=None):
        inputs: logits, Tensor of shape (batch_size, len).
        targets: Tensor of the same shape as inputs (1 = positive, 0 = negative).
        mask: optional Tensor of the same shape; used as a per-element weight,
            so positions where it is 0 contribute 0 loss. If omitted, every
            position has weight 1.

    Returns a (batch_size,) tensor: each row's weighted element losses summed
    and divided by ``len`` (the full row length, masked positions included).
    """

    def __init__(self):
        super().__init__()

    def forward(self, inputs, targets, mask=None):
        inputs, targets = inputs.float(), targets.float()
        # mask is optional: without it no per-element weighting is applied.
        weight = None if mask is None else mask.float()
        res = nn.functional.binary_cross_entropy_with_logits(
            inputs, targets, reduction="none", weight=weight)
        return res.mean(dim=1)


loss = SigmoidBinaryCrossEntropyLoss()

In [43]:
def sigmd(x):
    """Return ``-log(sigmoid(x))``, i.e. ``log(1 + exp(-x))`` (softplus of -x).

    Evaluated in a numerically stable way. The direct form
    ``-log(1 / (1 + exp(-x)))`` raises OverflowError in ``math.exp`` for
    large negative ``x``.
    """
    if x >= 0:
        return math.log1p(math.exp(-x))
    # log(1 + e^{-x}) = -x + log(1 + e^{x}); e^{x} cannot overflow for x < 0.
    return -x + math.log1p(math.exp(x))

print('%.4f' % ((sigmd(1.5) + sigmd(-0.3) + sigmd(1) + sigmd(-2)) / 4)) 
print('%.4f' % ((sigmd(1.1) + sigmd(-0.6) + sigmd(-2.2)) / 3))

0.8740
1.2100


In [44]:
# Embedding width: change this to alter the feature size of each node vector.
embed_size = 100
# Two embedding tables over the vocabulary: train() passes net[0] to skip_gram
# as the center embedding and net[1] as the context/negative embedding.
net = nn.Sequential(*(nn.Embedding(len(idx_to_token), embed_size) for _ in range(2)))

In [45]:
def train(net, lr, num_epochs):
    """Train the skip-gram embeddings in ``net`` with Adam.

    Uses the notebook globals ``data_iter``, ``skip_gram`` and ``loss``.
    Prints the device, then for each epoch the mean batch loss and wall time.

    Args:
        net: indexable module; net[0] embeds centers, net[1] embeds
            contexts/negatives.
        lr: Adam learning rate.
        num_epochs: number of passes over ``data_iter``.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print("train on", device)
    net = net.to(device)
    optimizer = torch.optim.Adam(net.parameters(), lr=lr)
    for epoch_no in range(1, num_epochs + 1):
        epoch_start, running_loss, num_batches = time.time(), 0.0, 0
        for batch in data_iter:
            center, context_negative, mask, label = (t.to(device) for t in batch)
            scores = skip_gram(center, context_negative, net[0], net[1])
            # `loss` averages each row over the full padded length; scaling by
            # that length and dividing by the row's mask sum turns it into an
            # average over the unpadded positions only.
            per_row = loss(scores.view(label.shape), label, mask)
            batch_loss = (per_row * mask.shape[1] / mask.float().sum(dim=1)).mean()
            optimizer.zero_grad()
            batch_loss.backward()
            optimizer.step()
            running_loss += batch_loss.item()
            num_batches += 1
        print('epoch %d, loss %.2f, time %.2fs'
              % (epoch_no, running_loss / num_batches, time.time() - epoch_start))

In [46]:
train(net, lr=0.01, num_epochs=15)

train on cuda
epoch 1, loss 0.42, time 11.00s
epoch 2, loss 0.39, time 10.36s
epoch 3, loss 0.39, time 10.02s
epoch 4, loss 0.39, time 9.81s
epoch 5, loss 0.39, time 9.81s
epoch 6, loss 0.39, time 9.88s
epoch 7, loss 0.39, time 9.86s
epoch 8, loss 0.39, time 9.85s
epoch 9, loss 0.39, time 9.94s
epoch 10, loss 0.39, time 10.16s
epoch 11, loss 0.39, time 10.23s
epoch 12, loss 0.39, time 10.23s
epoch 13, loss 0.39, time 10.26s
epoch 14, loss 0.39, time 10.12s
epoch 15, loss 0.39, time 10.38s


In [47]:
# Center-word (net[0]) embedding table as nested Python lists, one row per index.
weight = net[0].weight.detach().tolist()

In [48]:
# Convert the exported rows to a float32 matrix for saving.
weight = np.asarray(weight, dtype=np.float32)

In [49]:
# Row i of the file is the embedding of vocabulary index i.
np.savetxt(fname='./word_vectors_SG.txt', X=weight)

In [50]:
# Raw node id for each vocabulary index (first-seen order).
np.asarray(idx_to_token)

array([  0,   1,   2,   3,  10,   4,   5,   6,   7,   8,   9,  17,  18,
        13,  14,  20,  21,  16,  82,  23,  24,  45,  46,  86,  11,  12,
        19,  15,  85,  89,  88,  47,  35,  25,  28, 101,  90,  27,  29,
        75,  34, 112, 106,  91,  26,  36,  48,  52,  50,  53,  80,  81,
        84,  92,  93,  94,  95,  98,  96, 100,  30,  31,  79, 102,  22,
       103,  37,  83,  55,  87,  97, 138, 142,  38,  39,  99,  40, 135,
        41,  42,  43,  44, 121, 122, 123,  49,  51,  54,  56,  57,  58,
        59,  60, 115,  61,  32,  33, 114, 147,  62,  63,  64,  65,  66,
        67,  68,  69,  70,  71,  72,  73,  74,  76,  77,  78, 118, 119,
       116, 140, 139, 148, 150, 111, 134, 143, 283, 141, 104, 105, 107,
       108, 109, 110, 117, 124, 224, 144, 120, 125, 126, 128, 129, 151,
       152, 153, 154, 155, 156, 157, 158, 166, 167, 168, 127, 163, 188,
       189, 212, 213, 225, 164, 131, 146, 291, 214, 217, 218, 222, 223,
       227, 228, 229, 230, 231, 232, 233, 145, 136, 137, 113, 19

In [51]:
# Line i holds the raw node id whose embedding is row i of word_vectors_SG.txt.
node_ids = np.array(idx_to_token, dtype=np.int32)
np.savetxt('./word_index_SG.txt', node_ids, fmt="%d")

In [52]:
# Number of vocabulary entries (should equal len(counter)).
len(token_to_idx.keys())

517