In [3]:
from gat_pred import DataBuilder

# Create the project's graph builder. The following cells call its load_*
# methods and inspect the heterogeneous graph it exposes as `builder.data`.
builder = DataBuilder()


In [4]:
# Load the three node types in order (the recorded output shows one progress
# bar per call), then display the graph assembled so far.
for load_nodes in (builder.load_papers, builder.load_methods, builder.load_tasks):
    load_nodes()
builder.data

100%|██████████| 5364/5364 [00:55<00:00, 96.34it/s] 
100%|██████████| 907/907 [00:06<00:00, 130.21it/s]
100%|██████████| 1417/1417 [00:07<00:00, 200.58it/s]


HeteroData(
  paper={ x=[5364, 768] },
  method={ x=[907, 768] },
  task={ x=[1417, 768] }
)

In [5]:
# Add the relations to the graph. Per the recorded output, after this call the
# graph also contains a `user` node type and the cites / applies / performs /
# likes edge types.
builder.load_relations()
builder.data

HeteroData(
  paper={ x=[5364, 768] },
  method={ x=[907, 768] },
  task={ x=[1417, 768] },
  user={ x=[700, 768] },
  (paper, cites, paper)={ edge_index=[2, 98970] },
  (paper, applies, method)={ edge_index=[2, 14916] },
  (paper, performs, task)={ edge_index=[2, 15072] },
  (user, likes, paper)={ edge_index=[2, 2334] }
)

In [6]:
import torch_geometric.transforms as T

# Move the graph to the GPU and make it undirected. In the recorded output this
# adds a rev_* edge type for each relation between two different node types and
# grows paper-cites-paper from 98970 to 196576 edges.
to_undirected = T.ToUndirected()
data = to_undirected(builder.data.to("cuda"))
data

HeteroData(
  paper={ x=[5364, 768] },
  method={ x=[907, 768] },
  task={ x=[1417, 768] },
  user={ x=[700, 768] },
  (paper, cites, paper)={ edge_index=[2, 196576] },
  (paper, applies, method)={ edge_index=[2, 14916] },
  (paper, performs, task)={ edge_index=[2, 15072] },
  (user, likes, paper)={ edge_index=[2, 2334] },
  (method, rev_applies, paper)={ edge_index=[2, 14916] },
  (task, rev_performs, paper)={ edge_index=[2, 15072] },
  (paper, rev_likes, user)={ edge_index=[2, 2334] }
)

In [7]:
import torch

# Index of the single paper that the synthetic test user "likes".
QUERY_PAPER_IDX = 2946

# Work on a copy. Assigning through the bare `builder.data` reference would
# overwrite the builder's real user features and likes edges, so re-running an
# earlier cell afterwards would silently pick up this one-user graph.
test_data = builder.data.clone()

# Replace the likes relation with one edge: user 0 -> QUERY_PAPER_IDX.
test_data["user", "likes", "paper"].edge_index = torch.tensor([
    [0],
    [QUERY_PAPER_IDX],
])
# The only user's feature vector is the liked paper's feature vector,
# reshaped from (768,) to (1, 768).
test_data["user"].x = test_data["paper"].x[QUERY_PAPER_IDX].unsqueeze(0)

test_data = T.ToUndirected()(test_data.to("cuda"))
test_data

HeteroData(
  paper={ x=[5364, 768] },
  method={ x=[907, 768] },
  task={ x=[1417, 768] },
  user={ x=[1, 768] },
  (paper, cites, paper)={ edge_index=[2, 196576] },
  (paper, applies, method)={ edge_index=[2, 14916] },
  (paper, performs, task)={ edge_index=[2, 15072] },
  (user, likes, paper)={ edge_index=[2, 1] },
  (method, rev_applies, paper)={ edge_index=[2, 14916] },
  (task, rev_performs, paper)={ edge_index=[2, 15072] },
  (paper, rev_likes, user)={ edge_index=[2, 1] }
)

In [8]:
# Relation whose links are split, and the matching reverse relation (named
# rev_likes in the undirected graph shown above).
likes_edge_type = ("user", "likes", "paper")
rev_likes_edge_type = ("paper", "rev_likes", "user")

transform = T.RandomLinkSplit(
    edge_types=likes_edge_type,
    rev_edge_types=rev_likes_edge_type,
    num_val=0.1,
    num_test=0.1,
    disjoint_train_ratio=0.3,
    # No negatives are added to the training split here; the training loader
    # is configured with its own neg_sampling_ratio.
    add_negative_train_samples=False,
    neg_sampling_ratio=2.0,
)
train, val, test = transform(data)

In [17]:
from torch_geometric.loader import LinkNeighborLoader

# Supervision edges and their labels for the likes relation of the train split.
train_likes = train["user", "likes", "paper"]
edge_label_index = train_likes.edge_label_index
edge_label = train_likes.edge_label

# One supervision edge per batch; neighbours are sampled over two hops
# (10, then 5) and negatives are drawn at a 2:1 ratio.
train_loader = LinkNeighborLoader(
    data=train,
    edge_label_index=(("user", "likes", "paper"), edge_label_index),
    edge_label=edge_label,
    num_neighbors=[10, 5],
    neg_sampling_ratio=2.0,
    batch_size=1,
    shuffle=True,
)

In [18]:
import torch
from torch_geometric.nn import GATConv, to_hetero
from torch_geometric.data import HeteroData
class GNN(torch.nn.Module):
    """Two-layer graph attention encoder.

    The default arguments reproduce the original hard-coded configuration:
    768-dim inputs, a first layer with 8 heads of 128 channels each, and a
    single-head second layer projecting back to 768 dimensions.

    Args:
        in_channels: Size of the input node features.
        hidden_channels: Per-head output size of the first attention layer.
        heads: Number of attention heads in the first layer.
        out_channels: Size of the node embeddings produced by the second layer.
    """

    def __init__(
        self,
        in_channels: int = 768,
        hidden_channels: int = 128,
        heads: int = 8,
        out_channels: int = 768,
    ):
        super().__init__()
        self.conv1 = GATConv(in_channels, hidden_channels, heads=heads, add_self_loops=False)
        self.relu = torch.nn.ReLU()
        # The first layer's heads are concatenated, so the second layer takes
        # hidden_channels * heads input features.
        self.conv2 = GATConv(hidden_channels * heads, out_channels, heads=1, add_self_loops=False)

    def forward(self, x, edge_index):
        """Encode node features `x` using the connectivity in `edge_index`."""
        x = self.conv1(x, edge_index)
        x = self.relu(x)
        x = self.conv2(x, edge_index)
        return x

class GATPred(torch.nn.Module):
    """Link predictor for the (user, likes, paper) relation.

    Node embeddings come from a heterogeneous version of `GNN`; each candidate
    edge is scored by the dot product of its user and paper embeddings. The
    score is a raw logit (training uses a with-logits loss).

    Args:
        metadata: `(node_types, edge_types)` tuple passed to `to_hetero`.
            When omitted, the metadata of the notebook-level `data` graph is
            used, which matches the original behaviour.
    """

    # Relation whose `edge_label_index` is scored in `forward`.
    EDGE_TYPE = ("user", "likes", "paper")

    def __init__(self, metadata=None):
        super().__init__()
        if metadata is None:
            # Backward-compatible fallback to the notebook global `data`.
            metadata = data.metadata()
        self.gnn = to_hetero(GNN(), metadata=metadata)

    def forward(self, x: HeteroData):
        """Return one logit per column of the likes `edge_label_index` in `x`."""
        x_dict = self.gnn(x.x_dict, x.edge_index_dict)
        # Use the full edge-type key so the lookup stays unambiguous.
        edge_label_index = x[self.EDGE_TYPE].edge_label_index
        user_edge_features = x_dict["user"][edge_label_index[0]]
        paper_edge_features = x_dict["paper"][edge_label_index[1]]
        return torch.sum(user_edge_features * paper_edge_features, dim=-1)
    
# Build the predictor and place its parameters on the GPU in one step.
model = GATPred().to("cuda")

In [19]:
from tqdm import tqdm

optim = torch.optim.SGD(model.parameters(), lr=0.001)
loss_fn = torch.nn.BCEWithLogitsLoss()

NUM_EPOCHS = 50

for epoch in range(NUM_EPOCHS):
    # A later cell switches the model to eval mode; set training mode here so
    # that re-running this cell afterwards still trains in the right mode.
    model.train()
    total_loss = total_examples = 0
    for batch in tqdm(train_loader):
        batch = batch.to("cuda")
        optim.zero_grad()
        pred = model(batch)
        gt = batch["user", "likes", "paper"].edge_label
        loss = loss_fn(pred, gt)
        loss.backward()
        optim.step()
        # Weight each batch loss by its number of scored edges so the epoch
        # figure is a per-example average.
        total_loss += float(loss) * pred.numel()
        total_examples += pred.numel()
    print(f"Epoch {epoch}: {total_loss / total_examples}")

100%|██████████| 560/560 [00:03<00:00, 170.71it/s]


Epoch 0: 2.0400003738701344


100%|██████████| 560/560 [00:03<00:00, 153.98it/s]


Epoch 1: 0.4495722593673106


100%|██████████| 560/560 [00:03<00:00, 162.61it/s]


Epoch 2: 0.43228127766134483


100%|██████████| 560/560 [00:03<00:00, 178.75it/s]


Epoch 3: 0.44291495019569993


100%|██████████| 560/560 [00:03<00:00, 179.98it/s]


Epoch 4: 0.42239161703203404


100%|██████████| 560/560 [00:03<00:00, 173.26it/s]


Epoch 5: 0.40811199080199


100%|██████████| 560/560 [00:03<00:00, 180.12it/s]


Epoch 6: 0.41380894535354207


100%|██████████| 560/560 [00:03<00:00, 166.30it/s]


Epoch 7: 0.396772932773456


100%|██████████| 560/560 [00:03<00:00, 160.70it/s]


Epoch 8: 0.39937817073161047


100%|██████████| 560/560 [00:03<00:00, 171.51it/s]


Epoch 9: 0.39632899493112095


100%|██████████| 560/560 [00:03<00:00, 168.08it/s]


Epoch 10: 0.3945737660064229


100%|██████████| 560/560 [00:03<00:00, 173.05it/s]


Epoch 11: 0.4048318943540965


100%|██████████| 560/560 [00:03<00:00, 162.04it/s]


Epoch 12: 0.4042321073955723


100%|██████████| 560/560 [00:03<00:00, 156.16it/s]


Epoch 13: 0.38998840616217684


100%|██████████| 560/560 [00:03<00:00, 161.77it/s]


Epoch 14: 0.38437084103269237


100%|██████████| 560/560 [00:03<00:00, 165.89it/s]


Epoch 15: 0.40355883038469725


100%|██████████| 560/560 [00:03<00:00, 164.12it/s]


Epoch 16: 0.3852798390734409


100%|██████████| 560/560 [00:03<00:00, 173.68it/s]


Epoch 17: 0.3863614734328751


100%|██████████| 560/560 [00:03<00:00, 169.12it/s]


Epoch 18: 0.3879443502718849


100%|██████████| 560/560 [00:03<00:00, 162.27it/s]


Epoch 19: 0.3837127708896462


100%|██████████| 560/560 [00:03<00:00, 168.97it/s]


Epoch 20: 0.3933171738737396


100%|██████████| 560/560 [00:03<00:00, 154.40it/s]


Epoch 21: 0.3751422557447638


100%|██████████| 560/560 [00:03<00:00, 168.25it/s]


Epoch 22: 0.38791513900671687


100%|██████████| 560/560 [00:03<00:00, 182.04it/s]


Epoch 23: 0.3881270735990256


100%|██████████| 560/560 [00:03<00:00, 165.28it/s]


Epoch 24: 0.38644676988146137


100%|██████████| 560/560 [00:03<00:00, 166.12it/s]


Epoch 25: 0.38936316324397924


100%|██████████| 560/560 [00:03<00:00, 177.54it/s]


Epoch 26: 0.39684246239651527


100%|██████████| 560/560 [00:03<00:00, 162.13it/s]


Epoch 27: 0.4004227779339999


100%|██████████| 560/560 [00:03<00:00, 168.16it/s]


Epoch 28: 0.4117867841518351


100%|██████████| 560/560 [00:03<00:00, 168.58it/s]


Epoch 29: 0.40188352216833406


100%|██████████| 560/560 [00:03<00:00, 175.04it/s]


Epoch 30: 0.3947745163592377


100%|██████████| 560/560 [00:03<00:00, 168.20it/s]


Epoch 31: 0.38658689545866637


100%|██████████| 560/560 [00:03<00:00, 166.54it/s]


Epoch 32: 0.39613239894887164


100%|██████████| 560/560 [00:03<00:00, 171.02it/s]


Epoch 33: 0.39824183017813736


100%|██████████| 560/560 [00:03<00:00, 175.26it/s]


Epoch 34: 0.3906564196305616


100%|██████████| 560/560 [00:03<00:00, 170.85it/s]


Epoch 35: 0.38814345193760735


100%|██████████| 560/560 [00:03<00:00, 174.23it/s]


Epoch 36: 0.3990499106102756


100%|██████████| 560/560 [00:03<00:00, 169.18it/s]


Epoch 37: 0.40237337179215893


100%|██████████| 560/560 [00:03<00:00, 178.82it/s]


Epoch 38: 0.39000896472217783


100%|██████████| 560/560 [00:03<00:00, 177.50it/s]


Epoch 39: 0.3886711768939027


100%|██████████| 560/560 [00:03<00:00, 159.94it/s]


Epoch 40: 0.39858546427318026


100%|██████████| 560/560 [00:03<00:00, 180.13it/s]


Epoch 41: 0.38759535597637296


100%|██████████| 560/560 [00:03<00:00, 174.73it/s]


Epoch 42: 0.39157305023899036


100%|██████████| 560/560 [00:03<00:00, 162.66it/s]


Epoch 43: 0.37457448722395514


100%|██████████| 560/560 [00:03<00:00, 171.99it/s]


Epoch 44: 0.39949063805730217


100%|██████████| 560/560 [00:03<00:00, 168.11it/s]


Epoch 45: 0.3973535255728556


100%|██████████| 560/560 [00:03<00:00, 177.63it/s]


Epoch 46: 0.3895810508568372


100%|██████████| 560/560 [00:03<00:00, 167.17it/s]


Epoch 47: 0.37086351539141366


100%|██████████| 560/560 [00:03<00:00, 173.40it/s]


Epoch 48: 0.3771397073553609


100%|██████████| 560/560 [00:03<00:00, 173.28it/s]

Epoch 49: 0.4050315668606865





In [21]:
# Persist only the model's state_dict (not the module object) to the relative
# path gat_pred.pth.
torch.save(model.state_dict(), "gat_pred.pth")

In [22]:
model.eval()
# Score the edge(s) already present in the test graph's likes relation.
test_data["likes"].edge_label_index = test_data["likes"].edge_index
# Inference only: without no_grad the result carries a grad_fn and keeps the
# autograd graph alive for no benefit.
with torch.no_grad():
    test_logits = model(test_data)
# Raw logit(s); apply torch.sigmoid to read them as probabilities.
test_logits

tensor([1.2291], device='cuda:0', grad_fn=<SumBackward1>)