In [None]:
import os
import os.path as osp
import random

import numpy as np
import torch
import torch.nn.functional as F
import torch_geometric.transforms as T
from sklearn.metrics import roc_auc_score, accuracy_score
from torch.nn import Sequential, Linear, ReLU
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GCNConv
from torch_geometric.utils import negative_sampling
from torch_geometric.utils import train_test_split_edges

from utils import (
    get_link_labels,
    prediction_fairness,
)

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [None]:
class GCN(torch.nn.Module):
    """Two-layer GCN node encoder paired with a dot-product link decoder.

    Args:
        in_channels: number of input node features.
        out_channels: size of the node embeddings produced by ``encode``.
        hidden_channels: width of the intermediate GCN layer (default 128).
    """

    def __init__(self, in_channels, out_channels, hidden_channels=128):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)

    def encode(self, x, pos_edge_index):
        """Embed nodes by message passing over ``pos_edge_index``.

        Applies conv1 -> ReLU -> conv2 and returns the conv2 output.
        """
        x = F.relu(self.conv1(x, pos_edge_index))
        return self.conv2(x, pos_edge_index)

    def decode(self, z, pos_edge_index, neg_edge_index):
        """Score candidate edges by the dot product of their endpoint embeddings.

        Args:
            z: node embeddings, indexed by node id.
            pos_edge_index: positive edges, shape ``(2, E_pos)``.
            neg_edge_index: negative edges, shape ``(2, E_neg)``.

        Returns:
            ``(logits, edge_index)``, where ``edge_index`` is the positive
            edges followed by the negative edges (concatenated along the last
            dim). ``logits[i]`` is the raw (pre-sigmoid) score of column ``i``.
        """
        edge_index = torch.cat([pos_edge_index, neg_edge_index], dim=-1)
        src, dst = edge_index
        logits = (z[src] * z[dst]).sum(dim=-1)
        return logits, edge_index

In [None]:
# Planetoid citation graph to evaluate on; alternatives: "cora", "pubmed".
dataset_name = "citeseer"
# Notebooks have no __file__, so data is resolved relative to the kernel's
# working directory: <cwd>/../data/<dataset_name>.
path = osp.join(os.getcwd(), "..", "data", dataset_name)
dataset = Planetoid(path, dataset_name, transform=T.NormalizeFeatures())

In [None]:
# One full train/evaluate run per seed; per-seed rows are appended to the
# two result lists below and aggregated at the end of the notebook.
test_seeds = list(range(6))
acc_auc = []
fairness = []

In [None]:
# Fair link prediction with randomized edge dropout. For each seed: re-split
# the edges, train a GCN link predictor whose message-passing graph keeps
# inter-group training edges with probability 0.5 + delta and intra-group
# edges with probability 0.5 - delta, then record ACC/AUC and fairness
# metrics on the test split.
delta = 0.16  # bias of the per-edge keep probability towards inter-group edges
epochs = 101  # training runs for epochs 1 .. epochs - 1
cut = [0.45, 0.50, 0.55, 0.60, 0.65, 0.70, 0.75]  # candidate decision thresholds

for random_seed in test_seeds:
    # Seed every RNG the run may draw from so each seed is reproducible:
    # torch drives weight init and the edge-drop mask (and presumably
    # train_test_split_edges); NOTE(review): PyG's negative_sampling draws
    # from Python's `random` in the versions checked — confirm for the
    # installed version.
    random.seed(random_seed)
    np.random.seed(random_seed)
    torch.manual_seed(random_seed)

    data = dataset[0]
    protected_attribute = data.y
    # Strip node labels and masks; the labels are only used as the protected attribute.
    data.train_mask = data.val_mask = data.test_mask = data.y = None
    data = train_test_split_edges(data, val_ratio=0.1, test_ratio=0.2)
    data = data.to(device)

    model = GCN(data.num_features, 128).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.005)

    Y = torch.as_tensor(protected_attribute, dtype=torch.long, device=device)
    train_edges = data.train_pos_edge_index
    # True where a training edge joins nodes of different protected groups.
    Y_aux = Y[train_edges[0]] != Y[train_edges[1]]
    # randomization[epoch, e] is True with probability 0.5 + delta; where True
    # edge e is kept iff it is inter-group, otherwise iff it is intra-group.
    randomization = (torch.rand(epochs, Y_aux.size(0)) < 0.5 + delta).to(device)

    # -inf guarantees the first epoch's predictions are snapshotted.
    best_val_perf = float("-inf")
    test_perf = 0
    best_preds = None
    for epoch in range(1, epochs):
        # TRAINING
        neg_edges_tr = negative_sampling(
            edge_index=train_edges,
            num_nodes=data.num_nodes,
            num_neg_samples=train_edges.size(1) // 2,
        ).to(device)

        # Redraw the kept edge set on the first epoch and every 10th epoch.
        if epoch == 1 or epoch % 10 == 0:
            keep = torch.where(randomization[epoch], Y_aux, ~Y_aux)
        kept_edges = train_edges[:, keep]

        model.train()
        optimizer.zero_grad()
        z = model.encode(data.x, kept_edges)
        link_logits, _ = model.decode(z, kept_edges, neg_edges_tr)
        tr_labels = get_link_labels(kept_edges, neg_edges_tr).to(device)
        loss = F.binary_cross_entropy_with_logits(link_logits, tr_labels)
        loss.backward()
        optimizer.step()

        # EVALUATION: embed once over the full (undropped) training graph,
        # then score both splits; keep CPU copies for the metrics below.
        model.eval()
        preds = {}
        with torch.no_grad():
            z = model.encode(data.x, train_edges)
            for prefix in ("val", "test"):
                pos_edge_index = data[f"{prefix}_pos_edge_index"]
                neg_edge_index = data[f"{prefix}_neg_edge_index"]
                link_logits, edge_idx = model.decode(z, pos_edge_index, neg_edge_index)
                link_labels = get_link_labels(pos_edge_index, neg_edge_index)
                preds[prefix] = (
                    edge_idx.cpu(),
                    link_labels.cpu(),
                    link_logits.sigmoid().cpu(),
                )
        val_perf = roc_auc_score(preds["val"][1], preds["val"][2])
        tmp_test_perf = roc_auc_score(preds["test"][1], preds["test"][2])

        # Model selection on validation AUC; snapshot that epoch's predictions
        # so ACC and fairness are computed for the same model as the AUC.
        if val_perf > best_val_perf:
            best_val_perf = val_perf
            test_perf = tmp_test_perf
            best_preds = preds
        if epoch % 10 == 0:
            log = "Epoch: {:03d}, Loss: {:.4f}, Val: {:.4f}, Test: {:.4f}"
            print(log.format(epoch, loss.item(), best_val_perf, test_perf))

    # FAIRNESS: choose the decision threshold on the validation split and
    # score the test split, both from the best-validation epoch. Tuning the
    # cut on test labels would bias the reported accuracy upward.
    auc = test_perf
    _, val_labels, val_probs = best_preds["val"]
    # max() returns the first maximiser, i.e. ties go to the smaller cut.
    best_cut = max(cut, key=lambda c: accuracy_score(val_labels, val_probs >= c))
    test_edge_idx, test_labels, test_probs = best_preds["test"]
    test_pred = test_probs >= best_cut
    best_acc = accuracy_score(test_labels, test_pred)
    f = prediction_fairness(test_edge_idx, test_labels, test_pred, Y.cpu())
    acc_auc.append([best_acc * 100, auc * 100])
    fairness.append([x * 100 for x in f])

In [None]:
# Aggregate per-seed results: column-wise mean and population (ddof=0)
# standard deviation across seeds.
acc_auc_arr = np.asarray(acc_auc)
fairness_arr = np.asarray(fairness)

ma = acc_auc_arr.mean(axis=0)
sa = acc_auc_arr.std(axis=0)
mf = fairness_arr.mean(axis=0)
sf = fairness_arr.std(axis=0)

# acc_auc rows are [ACC, AUC]; fairness columns follow the output order of
# prediction_fairness and are labelled here.
fairness_labels = ["DP mix", "EoP mix", "DP group", "EoP group", "DP sub", "EoP sub"]
summary_rows = [("ACC", ma[0], sa[0]), ("AUC", ma[1], sa[1])]
summary_rows += [(label, mf[i], sf[i]) for i, label in enumerate(fairness_labels)]

for label, mean, std in summary_rows:
    # ".2f" = two decimal places (a bare "2f" only sets a minimum field width).
    print(f"{label}: {mean:.2f} +- {std:.2f}")