In [None]:
import os.path as osp

import torch
import torch.nn.functional as F
from sklearn.metrics import roc_auc_score

from torch_geometric.utils import (negative_sampling, remove_self_loops,
                                   add_self_loops)
from torch_geometric.datasets import Planetoid
import torch_geometric.transforms as T
from torch_geometric.nn import GCNConv, ChebConv, GINConv, GATConv  # noqa
from torch_geometric.utils import train_test_split_edges
import argparse
import numpy as np
import random
import os
from sklearn.metrics import roc_auc_score, f1_score
import json
from torch.nn import Sequential, ReLU, Linear

class GradReverse(torch.autograd.Function):
    """Gradient-reversal op used for adversarial (domain-adaptation style) training.

    Forward: passes the input through unchanged (as a view sharing its storage).
    Backward: flips the sign of the incoming gradient, so parameters upstream of
    this op are pushed to *increase* whatever loss is computed downstream of it.
    """

    @staticmethod
    def forward(ctx, tensor):
        # Identity on values; a view keeps the result a distinct tensor object.
        return tensor.view_as(tensor)

    @staticmethod
    def backward(ctx, grad_output):
        # Negate the upstream gradient.
        return -grad_output

class GradientReversalLayer(torch.nn.Module):
    """Parameter-free ``nn.Module`` wrapper around ``GradReverse``.

    Lets the gradient-reversal op be registered as a sub-module and called like
    any other layer: identity on the forward pass, negated gradient on backward.
    """

    def __init__(self):
        super().__init__()

    def forward(self, inputs):
        return GradReverse.apply(inputs)

In [None]:
def sim(lambda_reg, model_name='GATConv', seed=42):
    """Train a link-prediction GNN on Cora with an adversarial label "attacker",
    then fit a node-label head on top of the trained encoder.

    Stage 1 (``num_epochs`` epochs): the encoder (``conv1``/``conv2``) is trained
    with a binary cross-entropy link-prediction loss on the training edges of a
    random edge split. In the same step an adversary head (``attack``) is trained
    to predict node labels from the embedding through a gradient-reversal layer,
    so ``conv2`` receives the *negated* adversarial gradient. ``lambda_reg``
    scales the adversary optimiser's learning rate; ``lambda_reg=0`` makes its
    steps no-ops. Validation/test ROC-AUC is logged every epoch.

    Stage 2 (``finetune_epochs`` epochs): only the ``attr`` head is optimised to
    predict node labels on ``data.train_mask``; macro-F1 on the train/val/test
    masks is logged every epoch.

    Args:
        lambda_reg: multiplier on the adversary learning rate; also selects 100
            (``lambda_reg == 0``) or 175 stage-1 epochs.
        model_name: encoder type, one of 'GCNConv', 'ChebConv', 'GATConv',
            'GINConv'. Defaults to 'GATConv', the previously hard-coded choice.
        seed: seed applied to torch (CPU and CUDA), numpy and ``random``.

    Returns:
        The trained ``Net``. Its ``forward`` reads node features and training
        edges from the Cora split created inside this call (closure), not from
        its arguments.

    Raises:
        ValueError: if ``model_name`` is not one of the supported encoders.
    """

    class Net(torch.nn.Module):
        """Two-layer GNN encoder, dot-product link decoder and two node-label heads.

        ``attr`` predicts labels from the embedding directly; ``attack`` predicts
        them through a gradient-reversal layer.
        """

        def __init__(self, name='GCNConv'):
            super(Net, self).__init__()
            self.name = name
            if name == 'GCNConv':
                self.conv1 = GCNConv(dataset.num_features, 128)
                self.conv2 = GCNConv(128, 64)
            elif name == 'ChebConv':
                self.conv1 = ChebConv(dataset.num_features, 128, K=2)
                self.conv2 = ChebConv(128, 64, K=2)
            elif name == 'GATConv':
                self.conv1 = GATConv(dataset.num_features, 128)
                self.conv2 = GATConv(128, 64)
            elif name == 'GINConv':
                nn1 = Sequential(Linear(dataset.num_features, 128), ReLU(), Linear(128, 64))
                self.conv1 = GINConv(nn1)
                self.bn1 = torch.nn.BatchNorm1d(64)
                nn2 = Sequential(Linear(64, 64), ReLU(), Linear(64, 64))
                self.conv2 = GINConv(nn2)
                self.bn2 = torch.nn.BatchNorm1d(64)
            else:
                raise ValueError('Unknown model name: {}'.format(name))

            # NOTE(review): both label heads propagate over the full pre-split
            # `edge_index` captured from sim, i.e. they also see val/test edges.
            self.attr = GCNConv(64, dataset.num_classes, cached=True,
                                normalize=not use_gdc)
            self.attack = GCNConv(64, dataset.num_classes, cached=True,
                                  normalize=not use_gdc)
            self.reverse = GradientReversalLayer()

        def forward(self, pos_edge_index, neg_edge_index):
            """Return (edge logits, attr log-probs, attack log-probs, node embeddings).

            Embeddings are always computed on ``data.train_pos_edge_index``; the
            two edge-index arguments only select which node pairs get scored.
            """
            train_edges = data.train_pos_edge_index
            if self.name == 'GINConv':
                x = self.bn1(F.relu(self.conv1(data.x, train_edges)))
                x = self.bn2(F.relu(self.conv2(x, train_edges)))
            else:
                x = F.relu(self.conv1(data.x, train_edges))
                x = self.conv2(x, train_edges)

            attr = self.attr(x, edge_index, edge_weight)
            att = self.attack(self.reverse(x), edge_index, edge_weight)

            # Dot-product decoder over the concatenated positive + negative pairs.
            total_edge_index = torch.cat([pos_edge_index, neg_edge_index], dim=-1)
            x_j = torch.index_select(x, 0, total_edge_index[0])
            x_i = torch.index_select(x, 0, total_edge_index[1])
            res = torch.einsum("ef,ef->e", x_i, x_j)

            return res, F.log_softmax(attr, dim=1), F.log_softmax(att, dim=1), x

    lr = 0.01
    num_epochs = 100 if lambda_reg == 0 else 175
    finetune_epochs = 40

    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    os.environ['PYTHONHASHSEED'] = str(seed)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    path = osp.join('..', 'data', 'Cora')
    # `transform` must be passed by keyword: in recent PyG releases the third
    # positional parameter of Planetoid is `split`.
    dataset = Planetoid(path, 'Cora', transform=T.NormalizeFeatures())
    data = dataset[0]
    use_gdc = False
    if use_gdc:
        gdc = T.GDC(self_loop_weight=1, normalization_in='sym',
                    normalization_out='col',
                    diffusion_kwargs=dict(method='ppr', alpha=0.05),
                    sparsification_kwargs=dict(method='topk', k=128,
                                               dim=0), exact=True)
        data = gdc(data)

    labels = data.y.to(device)
    edge_index = data.edge_index.to(device)
    edge_weight = data.edge_attr.to(device) if data.edge_attr is not None else None

    # Random train/val/test edge split (drawn after seeding above).
    data = train_test_split_edges(data)

    model, data = Net(model_name).to(device), data.to(device)

    if model_name == 'GINConv':
        encoder_modules = [model.conv1, model.bn1, model.conv2, model.bn2]
    else:
        encoder_modules = [model.conv1, model.conv2]
    optimizer = torch.optim.Adam(
        [dict(params=module.parameters(), weight_decay=0) for module in encoder_modules],
        lr=lr)

    # Adversary optimiser: conv2 (plus bn2 for GIN) and the attack head, with the
    # learning rate scaled by lambda_reg.
    if model_name == 'GINConv':
        optimizer_att = torch.optim.Adam([
            dict(params=model.conv2.parameters(), weight_decay=5e-4),
            dict(params=model.bn2.parameters(), weight_decay=0),
            dict(params=model.attack.parameters(), weight_decay=5e-4),
        ], lr=lr * lambda_reg)
    else:
        optimizer_att = torch.optim.Adam([
            dict(params=model.conv2.parameters(), weight_decay=5e-4),
            dict(params=model.attack.parameters(), weight_decay=5e-4),
        ], lr=lr * lambda_reg)

    def get_link_labels(pos_edge_index, neg_edge_index):
        """Float vector with 1 for every positive edge followed by 0 for every negative."""
        link_labels = torch.zeros(pos_edge_index.size(1) + neg_edge_index.size(1),
                                  device=device)
        link_labels[:pos_edge_index.size(1)] = 1.
        return link_labels

    def sample_train_edges():
        """Return (training positive edges, fresh negative edges of the same count)."""
        x, pos_edge_index = data.x, data.train_pos_edge_index
        # Self-loops are added to the reference edge set so self-pairs count as
        # existing edges and are not drawn as negatives.
        _edge_index, _ = remove_self_loops(pos_edge_index)
        pos_edge_index_with_self_loops, _ = add_self_loops(_edge_index,
                                                           num_nodes=x.size(0))
        neg_edge_index = negative_sampling(
            edge_index=pos_edge_index_with_self_loops, num_nodes=x.size(0),
            num_neg_samples=pos_edge_index.size(1))
        return pos_edge_index, neg_edge_index

    def train():
        """One stage-1 step: link-prediction update, then adversary update."""
        model.train()
        pos_edge_index, neg_edge_index = sample_train_edges()

        link_logits, _, attack_prediction, _ = model(pos_edge_index, neg_edge_index)
        link_labels = get_link_labels(pos_edge_index, neg_edge_index)

        loss = F.binary_cross_entropy_with_logits(link_logits, link_labels)
        loss2 = F.nll_loss(attack_prediction, labels)

        # The adversarial gradients must be taken BEFORE optimizer.step():
        # the step updates conv2's weights in place, and back-propagating
        # through a graph that saved them afterwards makes autograd raise
        # "modified by an inplace operation".
        att_params = [p for group in optimizer_att.param_groups for p in group['params']]
        att_grads = torch.autograd.grad(loss2, att_params, retain_graph=True,
                                        allow_unused=True)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        optimizer_att.zero_grad()
        for param, grad in zip(att_params, att_grads):
            param.grad = grad if grad is not None else torch.zeros_like(param)
        optimizer_att.step()

        return loss

    @torch.no_grad()
    def test():
        """Return [val ROC-AUC, test ROC-AUC] for link prediction."""
        model.eval()
        perfs = []
        for prefix in ["val", "test"]:
            # Explicit lookups: Data.__call__ yields keys in storage order, not
            # in the order they are requested.
            pos_edge_index = getattr(data, "{}_pos_edge_index".format(prefix))
            neg_edge_index = getattr(data, "{}_neg_edge_index".format(prefix))
            link_probs = torch.sigmoid(model(pos_edge_index, neg_edge_index)[0])
            link_labels = get_link_labels(pos_edge_index, neg_edge_index)
            perfs.append(roc_auc_score(link_labels.cpu().numpy(),
                                       link_probs.cpu().numpy()))
        return perfs

    best_val_perf = test_perf = 0
    for epoch in range(1, num_epochs + 1):
        train_loss = train()
        val_perf, tmp_test_perf = test()
        if val_perf > best_val_perf:
            best_val_perf = val_perf
            test_perf = tmp_test_perf
        log = 'Epoch: {:03d}, Loss: {:.4f}, Val: {:.4f}, Test: {:.4f}'
        print(log.format(epoch, train_loss, val_perf, tmp_test_perf))

    # Stage 2: only the attr head is updated.
    optimizer_attr = torch.optim.Adam([
        dict(params=model.attr.parameters(), weight_decay=5e-4),
    ], lr=lr)

    def train_attr():
        """One stage-2 step on the attr head using the training-mask labels."""
        model.train()
        optimizer_attr.zero_grad()
        pos_edge_index, neg_edge_index = sample_train_edges()
        attr_prediction = model(pos_edge_index, neg_edge_index)[1]
        F.nll_loss(attr_prediction[data.train_mask], labels[data.train_mask]).backward()
        optimizer_attr.step()

    @torch.no_grad()
    def test_attr():
        """Return macro-F1 of the attr head on the [train, val, test] node masks."""
        model.eval()
        scores = []
        for split in ['train', 'val', 'test']:
            mask = getattr(data, '{}_mask'.format(split))
            if split == 'train':
                pos_edge_index, neg_edge_index = sample_train_edges()
            else:
                pos_edge_index = getattr(data, '{}_pos_edge_index'.format(split))
                neg_edge_index = getattr(data, '{}_neg_edge_index'.format(split))
            _, logits, _, _ = model(pos_edge_index, neg_edge_index)

            pred = logits[mask].max(1)[1]
            scores.append(f1_score(data.y[mask].cpu().numpy(), pred.cpu().numpy(),
                                   average='macro'))
        return scores

    best_val_acc = test_acc = 0
    for epoch in range(1, finetune_epochs + 1):
        train_attr()
        train_acc, val_acc, tmp_test_acc = test_attr()
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            test_acc = tmp_test_acc
        log = 'Epoch: {:03d}, Train: {:.4f}, Val: {:.4f}, Test: {:.4f}'
        print(log.format(epoch, train_acc, val_acc, tmp_test_acc))
    return model

In [None]:
# Train one model per adversary strength. sim() re-seeds all RNGs at its start,
# so every run begins from the same seed. Note `model_1` is the lambda = 1.2 run.
model_135 = sim(lambda_reg=1.3)
model_1 = sim(lambda_reg=1.2)
model_0 = sim(lambda_reg=0)

In [None]:
def feat(model):
    """Return the node-embedding tensor (last output of ``model.forward``) of a model from ``sim``.

    A fresh Cora edge split is drawn only to supply the positive/negative edge
    arguments. In ``sim``'s ``Net.forward`` the embeddings are computed from the
    training graph captured when the model was built. These edges therefore
    affect only the link-logit output, which is discarded here.

    Args:
        model: a ``Net`` instance returned by ``sim``.

    Returns:
        Node-embedding tensor without autograd history, on the model's device.
    """
    # Inference only: eval() makes BatchNorm (GINConv variant) use its running
    # statistics instead of per-batch statistics and stops it updating them.
    model.eval()
    device = next(model.parameters()).device

    path = osp.join('..', 'data', 'Cora')
    # `transform` by keyword: recent PyG releases take `split` as Planetoid's
    # third positional parameter.
    dataset = Planetoid(path, 'Cora', transform=T.NormalizeFeatures())
    data = train_test_split_edges(dataset[0])

    x, pos_edge_index = data.x, data.train_pos_edge_index
    _edge_index, _ = remove_self_loops(pos_edge_index)
    pos_edge_index_with_self_loops, _ = add_self_loops(_edge_index,
                                                       num_nodes=x.size(0))
    neg_edge_index = negative_sampling(
        edge_index=pos_edge_index_with_self_loops, num_nodes=x.size(0),
        num_neg_samples=pos_edge_index.size(1))

    with torch.no_grad():
        return model(pos_edge_index.to(device), neg_edge_index.to(device))[-1]

In [None]:
# Node embeddings of each trained model, for the t-SNE plots below.
feat0 = feat(model=model_0)
feat1 = feat(model=model_1)
feat135 = feat(model=model_135)

In [None]:
# Reload Cora to obtain ground-truth class ids for colouring the t-SNE plots.
# `transform` is passed by keyword: recent PyG releases take `split` as the
# third positional parameter of Planetoid.
dataset = "Cora"
path = osp.join('..', 'data', dataset)
dataset = Planetoid(path, dataset, transform=T.NormalizeFeatures())
data = dataset[0]
labels = data.y.detach().cpu().numpy()

In [None]:
import inspect

from sklearn.manifold import TSNE
import seaborn as sns

# scikit-learn 1.5 renamed TSNE's `n_iter` to `max_iter` and later dropped the
# old name; pick whichever keyword this installation accepts.
iter_kwarg = 'max_iter' if 'max_iter' in inspect.signature(TSNE).parameters else 'n_iter'
# random_state pins t-SNE's random number generation so the layouts (and the
# figures built from them below) are reproducible across re-runs.
tsne = TSNE(n_components=2, verbose=1, perplexity=100, random_state=42,
            **{iter_kwarg: 1300})
tsne_0 = tsne.fit_transform(feat0.cpu().detach().numpy())
tsne_1 = tsne.fit_transform(feat1.cpu().detach().numpy())
tsne_135 = tsne.fit_transform(feat135.cpu().detach().numpy())

# sns.set() applies the "darkgrid" style, so set_style() calls placed before it
# have no effect; grid and background are configured on the axes instead.
sns.set(rc={'figure.figsize': (8, 8)})
palette = sns.color_palette("bright", 7)
# Quick single view of the lambda = 0 embedding, coloured by Cora class
# (`labels` from the dataset cell above); swap in tsne_1 / tsne_135 as needed.
# x/y must be keywords: seaborn >= 0.12 accepts only `data` positionally.
ax = sns.scatterplot(x=tsne_0[:, 0], y=tsne_0[:, 1], hue=labels, legend='full',
                     palette=palette)
ax.grid(False)
ax.patch.set_facecolor('white')
ax.set_axis_off()

In [None]:
# Display a summary of the Cora `Data` object loaded above.
data

In [None]:
import networkx as nx
import seaborn as sns
import matplotlib.pyplot as plt

# Undirected networkx graph built from Cora's edge list; its edges are drawn
# between nodes placed at their t-SNE coordinates in the figure below.
path = osp.join('..', 'data', 'Cora')
dataset = Planetoid(path, 'Cora')
data = dataset[0]
el = data.edge_index.cpu().numpy()
G = nx.Graph(list(zip(el[0], el[1])))

In [None]:
def tsne_positions(coords):
    """Map node index -> [x, y] from a 2-column t-SNE array (one row per node).

    The node count comes from the array itself rather than a hard-coded Cora
    size, so the helper works for any graph.
    """
    return {i: [coords[i, 0], coords[i, 1]] for i in range(len(coords))}


pos0 = tsne_positions(tsne_0)
pos135 = tsne_positions(tsne_135)
pos1 = tsne_positions(tsne_1)

In [None]:
sns.set(rc={'figure.figsize': (20, 10)})
palette = sns.color_palette("tab20", 7)


def draw_embedding_panel(ax, pos, coords, title):
    """Draw one panel: graph edges at t-SNE positions plus nodes coloured by class.

    ax: matplotlib Axes to draw on.
    pos: dict node index -> [x, y], used by networkx for the edge endpoints.
    coords: 2-column array of the same positions, used for the node scatter.
    title: x-axis label (LaTeX allowed).
    """
    nx.draw_networkx_edges(G, pos, alpha=0.1, edge_color='b', style='solid', ax=ax)
    # x/y must be keywords: seaborn >= 0.12 accepts only `data` positionally.
    sns.scatterplot(x=coords[:, 0], y=coords[:, 1], hue=labels, legend=False,
                    palette=palette, ax=ax, s=70, edgecolor="black")
    ax.grid(False)
    ax.patch.set_facecolor('white')
    ax.set(xticks=[], yticks=[])
    ax.set_xlabel(title, fontsize=17)


panels = [
    (pos0, tsne_0, r'$\lambda = 0$'),
    # Uncomment to add the lambda = 1.2 panel (the figure width adapts):
    # (pos1, tsne_1, r'$\lambda = 1.2$'),
    (pos135, tsne_135, r'$\lambda = 1.3$'),
]
# squeeze=False keeps `axs` 2-D even when only one panel is drawn.
fig, axs = plt.subplots(ncols=len(panels), squeeze=False)
for ax, (pos, coords, title) in zip(axs[0], panels):
    draw_embedding_panel(ax, pos, coords, title)

# Save first, so the file is written regardless of what the backend does with
# the figure once it is displayed.
fig.savefig('cora.pdf', bbox_inches='tight')
plt.show()

In [None]:
# Reload Cora with row-normalised features. `transform` is passed by keyword:
# recent PyG releases take `split` as Planetoid's third positional parameter.
dataset = "Cora"
path = osp.join('..', 'data', dataset)
dataset = Planetoid(path, dataset, transform=T.NormalizeFeatures())
data = dataset[0]

In [None]:
# Inspect the edge index of the freshly reloaded (unsplit) Cora graph.
data.edge_index