In [1]:
import os.path as osp

import torch
from torch_geometric.datasets import AMiner
from torch_geometric.nn import MetaPath2Vec

In [2]:
# Dataset root: an "AMiner" folder placed next to (not inside) the resolved
# current working directory.
path = osp.join(
    osp.dirname(osp.realpath('.')),
    'AMiner',
)

# Build the AMiner dataset rooted at `path` and take its first entry as the
# graph used in the rest of the notebook.
dataset = AMiner(path)
data = dataset[0]

In [7]:
# Prefer the first CUDA GPU, but fall back to the CPU so that the later
# `.to(device)` calls do not fail on machines without a usable CUDA device.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

In [8]:
import pickle

In [9]:
# SECURITY NOTE: pickle.load can execute code embedded in the file, and
# torch.load is pickle-based as well. Only load files from a trusted source.
graph = torch.load('input/pyg_graph.torch').to(device)

# Open the pickle through a context manager so the file handle is closed as
# soon as loading finishes instead of being left open.
with open('input/nodes_by_type.pickle', 'rb') as nodes_file:
    node_idxs = pickle.load(nodes_file)
# gene_name_proteins = pickle.load(open('input/gene_name_proteins.pickle','rb'))

In [10]:
# Display the loaded graph's summary (its edge types and their sizes).
graph

HeteroData(
  [1m(tad, overlaps, atac_region)[0m={
    edge_index=[2, 116232],
    num_nodes=118729
  },
  [1m(tad, overlaps, gene)[0m={
    edge_index=[2, 59732],
    num_nodes=62343
  },
  [1m(gene, associated, protein)[0m={
    edge_index=[2, 117680],
    num_nodes=142017
  },
  [1m(atac_region, overlaps, gene)[0m={
    edge_index=[2, 27233],
    num_nodes=51117
  },
  [1m(protein, coexpressed, protein)[0m={
    edge_index=[2, 6525628],
    num_nodes=19026
  },
  [1m(protein, tf_interacts, gene)[0m={
    edge_index=[2, 7899248],
    num_nodes=30524
  },
  [1m(protein, is_named, gene_name)[0m={
    edge_index=[2, 317549],
    num_nodes=158258
  }
)

In [11]:
# Display the AMiner graph's summary (node types, edge types and their sizes).
data

HeteroData(
  [1mauthor[0m={
    y=[246678],
    y_index=[246678],
    num_nodes=1693531
  },
  [1mvenue[0m={
    y=[134],
    y_index=[134],
    num_nodes=3883
  },
  [1mpaper[0m={ num_nodes=3194405 },
  [1m(paper, written_by, author)[0m={ edge_index=[2, 9323605] },
  [1m(author, writes, paper)[0m={ edge_index=[2, 9323605] },
  [1m(paper, published_in, venue)[0m={ edge_index=[2, 3194405] },
  [1m(venue, publishes, paper)[0m={ edge_index=[2, 3194405] }
)

In [12]:
from torch_geometric.nn.models import MetaPath2Vec

In [13]:
# Walk schema as a list of (source type, relation, destination type) triples.
# Each step's destination type is the next step's source type, and the path
# starts and ends at 'author': author -> paper -> venue -> paper -> author.
metapath = [
    ('author', 'writes', 'paper'),
    ('paper', 'published_in', 'venue'),
    ('venue', 'publishes', 'paper'),
    ('paper', 'written_by', 'author'),
]

In [15]:
# Move the AMiner graph onto the selected device; the model below is built
# from its edge_index_dict.
data = data.to(device)

In [16]:
# MetaPath2Vec embedding model built from the AMiner edge indices, walking
# along `metapath`.
model = MetaPath2Vec(
    data.edge_index_dict,
    embedding_dim=128,
    metapath=metapath,
    walk_length=50,
    context_size=7,
    walks_per_node=5,
    num_negative_samples=5,
    sparse=True,
)
model = model.to(device)

# Iterating the loader yields (positive, negative) random-walk batches.
loader = model.loader(batch_size=213, shuffle=True, num_workers=6)

# The model is created with sparse=True, so it is paired with SparseAdam
# (the optimizer intended for sparse gradients).
optimizer = torch.optim.SparseAdam(list(model.parameters()), lr=0.01)

In [17]:
def train(epoch, log_steps=100, eval_steps=2000):
    """Run one pass over ``loader``, updating ``model`` with ``optimizer``.

    Relies on the notebook-level ``model``, ``loader``, ``optimizer`` and
    ``device`` objects.

    Args:
        epoch: Epoch number; only used in the progress message.
        log_steps: Number of steps between progress messages. The printed
            loss is the running sum since the last message divided by
            ``log_steps``.
        eval_steps: Accepted but not used by this function.

    Returns:
        None. Progress is reported via ``print`` only.
    """
    model.train()

    window_loss = 0
    # Count steps from 1 so the step number can be printed and tested directly.
    for step, (pos_rw, neg_rw) in enumerate(loader, start=1):
        optimizer.zero_grad()
        batch_loss = model.loss(pos_rw.to(device), neg_rw.to(device))
        batch_loss.backward()
        optimizer.step()
        window_loss += batch_loss.item()

        # Only report on every `log_steps`-th step.
        if step % log_steps != 0:
            continue

        print((f'Epoch: {epoch}, Step: {step:05d}/{len(loader)}, '
               f'Loss: {window_loss / log_steps:.4f}'))
        # Start a fresh averaging window after each report.
        window_loss = 0

In [18]:
# Train for five epochs, numbered 1 through 5; train() prints its own progress.
for epoch in range(1, 6):
    train(epoch)

Epoch: 1, Step: 00100/7951, Loss: 8.9345
Epoch: 1, Step: 00200/7951, Loss: 7.1462


KeyboardInterrupt: 