In [1]:
import pickle
import numpy as np
from collections import Counter
from model import SkipGramModel
import torch.optim as optim
from torch import nn
import torch

In [2]:
dataset = "NCI1"

In [3]:
with open('./'+dataset+'/graph_voc_3.json', 'rb') as f:
    graph_enc = pickle.load(f)

sub_graph_voc = []
for g in list(graph_enc.keys()):
    sub_graph_voc.extend(graph_enc[g])

min_cnt = 5

sub_graph_vocab = dict(Counter(sub_graph_voc))
sub_graph_vocab = {i:sub_graph_vocab[i] for i in list(sub_graph_vocab.keys()) if sub_graph_vocab[i]>=min_cnt}

In [4]:
for g in list(graph_enc.keys()):
    graph_enc[g] = [x for x in graph_enc[g] if x in list(sub_graph_vocab.keys())]

id_to_sub_graph = {i:list(sub_graph_vocab.keys())[i] for i in range(len(sub_graph_vocab))}
sub_graph_to_id = {id_to_sub_graph[i]:i for i in list(id_to_sub_graph.keys())}

model_1 = SkipGramModel(len(graph_enc), len(id_to_sub_graph), 1024)

In [5]:
def init_sample_table():
    sample_table = []
    sample_table_size = 1e8
    pow_frequency = np.array(list(sub_graph_vocab.values())) ** 0.75
    words_pow = sum(pow_frequency)
    ratio = pow_frequency / words_pow
    count = np.round(ratio * sample_table_size)
    for wid, c in enumerate(count):
        sample_table += [sub_graph_to_id[list(sub_graph_vocab.keys())[wid]]] * int(c)
    sample_table = np.array(sample_table)
    return sample_table

In [6]:
sample_table = init_sample_table()
neg_count = 2
epoch = 20

opt = optim.SparseAdam(model_1.parameters(), lr=0.0001)
model_1.train()

cuda = False
if torch.cuda.is_available():
    cuda = True
    model_1.cuda()

In [7]:
loss_g = {}

In [8]:
for i in range(epoch):
    for j in range(len(graph_enc)):
        opt.zero_grad()

        # doc_id = np.random.randint(1, len(graph_enc))
        doc_id = j
        if len(graph_enc[doc_id + 1]) == 0:
            continue
        doc_u = torch.tensor([doc_id], dtype=torch.long, requires_grad=False)

        pos_v = [sub_graph_to_id[x] for x in graph_enc[doc_id + 1]]
        loss = []
        for p in pos_v:

            while (True):
                neg_v = np.random.choice(sample_table, size=(neg_count)).tolist()
                if p not in neg_v:
                    break

            pos = torch.tensor([p], dtype=torch.long, requires_grad=False)
            neg_v = torch.tensor(neg_v, dtype=torch.long, requires_grad=False)

            if cuda:
                doc_u = doc_u.cuda()
                pos = pos.cuda()
                neg_v = neg_v.cuda()

            loss_val = model_1(doc_u, pos, neg_v)

            # print(str(i)+'   '+str(loss_val))
            loss.append(loss_val.data.cpu().numpy())
            loss_val.backward()
            opt.step()

        if doc_id not in list(loss_g.keys()):
            loss_g[doc_id] = [np.mean(loss)]
        else:
            loss_g[doc_id].append(np.mean(loss))

    l = np.mean([loss_g[k][i] for k in list(loss_g.keys())])

    print('epoch - ' + str(i) + '\tloss - ' + str(l))

print('Completed')

epoch - 0	loss - 2.0138826
epoch - 1	loss - 1.6570077
epoch - 2	loss - 1.2493625
epoch - 3	loss - 0.97931015
epoch - 4	loss - 0.81560194
epoch - 5	loss - 0.70842934
epoch - 6	loss - 0.6298165
epoch - 7	loss - 0.57592624
epoch - 8	loss - 0.5366204
epoch - 9	loss - 0.50597227
epoch - 10	loss - 0.4805988
epoch - 11	loss - 0.4631515
epoch - 12	loss - 0.4480258
epoch - 13	loss - 0.43521816
epoch - 14	loss - 0.4254279
epoch - 15	loss - 0.41492504
epoch - 16	loss - 0.4112482
epoch - 17	loss - 0.40363556
epoch - 18	loss - 0.39864045
epoch - 19	loss - 0.3947498
Completed


In [9]:
iter_loss = [np.mean([loss_g[x][i] for x in list(loss_g.keys())]) for i in range(epoch)]
print(iter_loss)

with open('./' + dataset + '/loss.json', 'wb') as f:
    pickle.dump(loss_g, f)

with open('./' + dataset + '/iter_loss.json', 'wb') as f:
    pickle.dump(iter_loss, f)

model_1.save_embedding(cuda, dataset)

[2.0138826, 1.6570077, 1.2493625, 0.97931015, 0.81560194, 0.70842934, 0.6298165, 0.57592624, 0.5366204, 0.50597227, 0.4805988, 0.4631515, 0.4480258, 0.43521816, 0.4254279, 0.41492504, 0.4112482, 0.40363556, 0.39864045, 0.3947498]
