In [1]:
import random
from itertools import combinations

import pandas as pd
import networkx as nx
import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

import torch_geometric.transforms as T
from torch_geometric.data import HeteroData
from torch_geometric.loader import LinkNeighborLoader
from torch_geometric.transforms import RandomLinkSplit, ToUndirected
from torch_geometric.loader import LinkNeighborLoader
from torch_geometric.nn import SAGEConv, to_hetero

from sklearn.metrics import roc_auc_score
from sklearn.preprocessing import LabelEncoder

In [2]:
import os
import sys
from pathlib import Path


project_dir = Path(os.getcwd()).parent
os.chdir(project_dir)
print('Working dir: ', os.getcwd())

Working dir:  c:\Users\aleksandr\PycharmProjects\development\holdings


In [3]:
train_graph = torch.load('data/processed/train_synth')
val_graph = torch.load('data/processed/val_synth')
test_graph = torch.load('data/processed/test_synth')

In [4]:
class GNN(torch.nn.Module):
    def __init__(self, hidden_channels, num_layers=4, dropout=0.01):
        super().__init__()
        conv_model = SAGEConv

        self.num_layers = num_layers
        self.dropout = dropout
        
        self.convs = torch.nn.ModuleList()
        self.convs.append(conv_model(hidden_channels, hidden_channels))
        
        assert (self.num_layers >= 1), 'Number of layers is not >=1'
        for l in range(self.num_layers - 1):
            self.convs.append(conv_model(hidden_channels, hidden_channels))
        
    def forward(self, x: Tensor, edge_index: Tensor) -> Tensor:
        for i in range(self.num_layers - 1):
            x = self.convs[i](x, edge_index)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.convs[-1](x, edge_index)
        x = F.dropout(x, p=self.dropout, training=self.training)
        return x


class Classifier(torch.nn.Module):
    def __init__(self, in_channels=64, hidden_channels=64, num_layers=10, dropout=0.1):
        super(Classifier, self).__init__()
        self.lins = torch.nn.ModuleList()
        self.lins.append(torch.nn.Linear(in_channels, hidden_channels))
        for _ in range(num_layers - 2):
            self.lins.append(torch.nn.Linear(hidden_channels, hidden_channels))
        self.lins.append(torch.nn.Linear(hidden_channels, 1))
        self.dropout = dropout

    def forward(self, x_user: Tensor, x_movie: Tensor, edge_label_index: Tensor) -> Tensor:
        edge_feat_user = x_user[edge_label_index[0]]
        edge_feat_movie = x_movie[edge_label_index[1]]
        x = edge_feat_user * edge_feat_movie
        for lin in self.lins[:-1]:
            x = lin(x)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.lins[-1](x)
        return x


class Model(nn.Module):
    def __init__(
        self, 
        graph_metadata, 
        x_i_feats_num, x_j_feats_num,
        hidden_channels=64, gnn_num_layers=4, gnn_dropout=0.3,
        ):
        super().__init__()

        self.x_i_linear = nn.Linear(in_features=x_i_feats_num, out_features=hidden_channels)
        self.x_j_linear = nn.Linear(in_features=x_j_feats_num, out_features=hidden_channels)

        self.gnn = GNN(hidden_channels, num_layers=gnn_num_layers, dropout=gnn_dropout)
        self.gnn = to_hetero(self.gnn, metadata=graph_metadata)
        
        self.classifier = Classifier()
        
    def forward(self, data: HeteroData) -> Tensor:
        x_dict = {
          "client": self.x_i_linear(data["client"].x),
        }

        x_dict = self.gnn(x_dict, data.edge_index_dict)
        pred = self.classifier(
            x_dict["client"],
            x_dict["client"],
            data['client', 'community', 'client'].edge_label_index
        )
        return pred

In [5]:
train_loader = LinkNeighborLoader(
    data=train_graph,
    num_neighbors=[10, 5],
    neg_sampling_ratio=2.0,
    edge_label_index=(('client', 'community', 'client'), train_graph['client', 'community', 'client'].edge_label_index),
    edge_label=train_graph['client', 'community', 'client'].edge_label,
    batch_size=1024,
    shuffle=True,
)

In [6]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = Model(graph_metadata=train_graph.metadata(), x_i_feats_num=train_graph['client'].x.shape[1], x_j_feats_num=train_graph['client'].x.shape[1],)
model = model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
loss_fn = torch.nn.BCEWithLogitsLoss().to(device)

for epoch in range(1, 2):
    total_loss = total_examples = 0
    for sampled_data in tqdm.tqdm(train_loader):
        optimizer.zero_grad()
        sampled_data.to(device)
        pred = model(sampled_data)
        ground_truth = sampled_data['client', 'community', 'client'].edge_label
        loss = loss_fn(pred, ground_truth.reshape(-1,1))
        # loss = F.binary_cross_entropy_with_logits(pred.flatten(), ground_truth)
        loss.backward()
        optimizer.step()
        total_loss += float(loss) * pred.numel()
        total_examples += pred.numel()
    print(f"Epoch: {epoch:03d}, Loss: {total_loss / total_examples:.4f}")

100%|██████████| 2/2 [00:00<00:00, 18.00it/s]

Epoch: 001, Loss: 0.6808





In [7]:
val_loader = LinkNeighborLoader(
    data=val_graph,
    num_neighbors=[10, 5],
    edge_label_index=(('client', 'community', 'client'), val_graph['client', 'community', 'client'].edge_label_index),
    edge_label=val_graph['client', 'community', 'client'].edge_label,
    batch_size=3 * 512,
    shuffle=False,
)

preds = []
ground_truths = []
for sampled_data in tqdm.tqdm(val_loader):
    with torch.no_grad():
        sampled_data.to(device)
        preds.append(torch.sigmoid(model(sampled_data)))
        ground_truths.append(sampled_data['client', 'community', 'client'].edge_label)
pred = torch.cat(preds, dim=0).cpu().numpy()
ground_truth = torch.cat(ground_truths, dim=0).cpu().numpy()
auc = roc_auc_score(ground_truth, pred)
print(f"Validation AUC: {auc:.4f}")

100%|██████████| 3/3 [00:00<00:00, 64.10it/s]

Validation AUC: 0.5005





In [8]:
train_graph.validate()

del train_graph['client', 'community', 'client']['edge_label']
del train_graph['client', 'community', 'client']['edge_label_index']

transform = T.RandomLinkSplit(
    num_val=0.1,
    num_test=0.1,
    disjoint_train_ratio=0.3,
    neg_sampling_ratio=1.0,
    add_negative_train_samples=True,
    edge_types=('client', 'community', 'client'),
)

_, _, _ = transform(train_graph)