In [1]:
import sys
import os
sys.path.append(os.path.join('../..'))

################################################
# Arguments
################################################

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--seed', type=int, default=123, help='Random seed for model')
parser.add_argument('--dataset', type=str, default='cora', help='dataset')

args = parser.parse_args("")

################################################
# Environment
################################################

import torch
import numpy as np

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

np.random.seed(args.seed)
torch.manual_seed(args.seed)

if device != 'cpu':
    torch.cuda.manual_seed(args.seed)

print('==== Environment ====')
print(f'  torch version: {torch.__version__}')
print(f'  device: {device}')
print(f'  torch seed: {args.seed}')

################################################
# Dataset
################################################

from Utils import GraphData

print(f'==== Dataset: {args.dataset} ====')

graph = GraphData.getGraph("../../Datasets", args.dataset, "gcn", args.seed, device)
graph.summarize()

==== Environment ====
  torch version: 1.10.2
  device: cpu
  torch seed: 123
==== Dataset: cora ====
Loading cora dataset...

[i] Dataset Summary: 
	adj shape: [2708, 2708]
	feature shape: [2708, 1433]
	num labels: 7
	split seed: 123
	train|val|test: 140|500|1000


In [2]:
def generate_adj_view(adj, p_drop):
    # Randomly select XX% of edges
    p = torch.full_like(adj, p_drop).float()
    modifications = torch.bernoulli(p)

    modifications = modifications * adj # Mask it to only modify when edges are present (removal only)

    return adj + modifications - (2 * adj * modifications) # XOR

generate_adj_view(graph.adj, 0.8)

tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])

In [3]:
def generate_feat_view(feat: torch.tensor, p_drop):
    # Randomly select XX% of features
    p = torch.full([feat.shape[1]], p_drop).float()
    target_feats = (torch.bernoulli(p) == 1).nonzero().squeeze()

    return feat.index_fill(0, target_feats, 0)

generate_feat_view(graph.features, 0.5).sum()

tensor(36456.)

In [40]:
def generate_view(g, p_edge, p_feat):
    return {
        "adj": generate_adj_view(g.adj, p_edge),
        "features": generate_feat_view(g.features, p_feat)
    }

In [41]:
import torch.nn.functional as F

def sim(z1: torch.Tensor, z2: torch.Tensor):
    z1 = F.normalize(z1)
    z2 = F.normalize(z2)
    return torch.mm(z1, z2.t())

def s_loss(z1: torch.Tensor, z2: torch.Tensor):
    f = lambda x: torch.exp(x / 0.4)
    refl_sim = f(sim(z1, z1))
    between_sim = f(sim(z1, z2))

    return -torch.log(
            between_sim.diag()
            / (refl_sim.sum(1) + between_sim.sum(1) - refl_sim.diag()))

In [67]:
from Models.GCN import GCN

model = GCN(
    input_features=graph.features.shape[1],
    output_classes=graph.labels.max().item()+1,
    hidden_layers=128,
    device=device,
    lr=0.05,
    dropout=0.4,
    weight_decay=0.00001,
    name=f"baseline"
).to(device)

########################
#region GRACE
optimizer = torch.optim.Adam(model.parameters(), lr=model.lr, weight_decay=model.weight_decay)

for epoch in range(200):
    optimizer.zero_grad()

    # Generate views
    view_A = generate_view(graph, 0.2, 0.2)
    view_B = generate_view(graph, 0.2, 0.2)

    # Obtain embeddings
    emb_A = model(view_A["features"], view_A["adj"])
    emb_B = model(view_B["features"], view_B["adj"])

    # Contrastive objective
    obj = (s_loss(emb_A, emb_B) + s_loss(emb_B, emb_A)).mean() * 0.5
    
    # Update params
    obj.backward()
    optimizer.step()

    print(obj.item())

#endregion
########################


8.59684944152832
8.578678131103516
8.525972366333008
8.434473991394043
8.363480567932129
8.30350112915039
8.25337028503418
8.222042083740234
8.182612419128418
8.148009300231934
8.123971939086914
8.098261833190918
8.079007148742676
8.061497688293457
8.036906242370605
8.029257774353027
8.006060600280762
7.982293605804443
7.970405101776123
7.964942455291748
7.930498123168945
7.9111456871032715
7.899277210235596
7.879704475402832
7.86409854888916
7.840716361999512
7.820860862731934
7.807607173919678
7.804925441741943
7.781295299530029
7.763625144958496
7.748591899871826
7.734733581542969
7.726871967315674
7.707786560058594
7.695126533508301
7.677942752838135
7.667926788330078
7.658021450042725
7.646245002746582
7.640690803527832
7.630488872528076
7.614439487457275
7.6083784103393555
7.594449520111084
7.582249164581299
7.569818496704102
7.5689826011657715
7.5586347579956055
7.548610210418701
7.540104866027832
7.527491569519043
7.523734092712402
7.512577533721924
7.501918315887451
7.49370479

In [76]:
#eval
import Utils.Metrics as Metrics

acc = Metrics.acc(model(graph.features, graph.adj).log_softmax(), graph.labels)
print(acc)

0.10709010064601898
