In [1]:
import torch
import torch.nn.functional as F
from torch_geometric.datasets import Planetoid
from torch_geometric.data.lightning import LightningDataset
import random
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch_geometric.transforms as T 

from query_strategies import *
from augmentation import *
from model import *

import wandb

import lightning as L
from torch_geometric.data.lightning import LightningDataset
import logging
import warnings
from util import *
from torch_geometric.data import Data, DataLoader
from trainers import *
from torch_geometric.utils import to_dense_adj
from sklearn.decomposition import PCA



In [2]:
warnings.filterwarnings("ignore")
logging.getLogger("lightning.pytorch").setLevel(logging.ERROR)


In [3]:
wandb.login()

[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mbanfizsombor1999[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [4]:
random.seed(0)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [5]:
# dataset = Planetoid(root='/tmp/Cora', name='Cora')
# generate_balanced_data_splits(dataset,10,"data_splits\\cora_splits")

In [6]:
DROPOUT= 0.3
NUM_PASSES = 10
BUDGET = 150
EPOCHS = 100
SIGNIFICANCE_ITERATIONS = 10

NOISE_PROB = 0.4
NOISE_LEVEL = 0.5

In [7]:
# Augmentations
drop_edge = DropEdge(DROPOUT)
noise_feature_all = NoiseFeature(NOISE_LEVEL, 1)
noise_feature_col = NoiseFeature(NOISE_LEVEL, NOISE_PROB, "col")
noise_feature_row = NoiseFeature(NOISE_LEVEL, NOISE_PROB, "row")
noise_latent = NoiseLatent(NOISE_LEVEL)

drop_edge_noise_all = T.Compose([drop_edge, noise_feature_all])
drop_edge_noise_col = T.Compose([drop_edge, noise_feature_col])
drop_edge_noise_row = T.Compose([drop_edge, noise_feature_row])

In [8]:
# Strategies

random_query = RandomQuery()
entropy_query = EntropyQuery()

augment_sum_entropy = AugmentGraphSumEntropyQuery(drop_edge_noise_all, NUM_PASSES,0.0)
augment_logit_change = AugmentGraphLogitChange(drop_edge_noise_all, NUM_PASSES,1.0)
augment_latent = AugmentGraphSumQueryLatent(noise_latent, NUM_PASSES)
augment_sum_entropy_with_original = AugmentGraphSumEntropyQuery(drop_edge_noise_all, NUM_PASSES, 1/NUM_PASSES)
augment_sum_entropy_with_original_contrastive = AugmentGraphSumEntropyQueryContrastive(drop_edge_noise_all, NUM_PASSES,1/NUM_PASSES)

In [9]:
data_splits = [torch.load(f"data_splits\\cora_splits\\split_{i}.pt") for i in range(10)]


In [10]:
def init_wandb(query_strategy,run):
  config={
    "learning_rate": 0.01,
    "architecture": "GCN",
    "dataset": "CORA",
    "epochs": 200,
    "strategy": str(query_strategy)
    }
  
  if hasattr(query_strategy, "augmentation_fn"):
    augmentation = query_strategy.augmentation_fn
    config["augmentations"] = []
    if isinstance(augmentation, T.Compose):
      transforms = augmentation.transforms
    else:
      transforms = [augmentation]
      
    for t in transforms:
      aug_config = {}
      aug_config["name"] = str(t)
      aug_config["hyperparameters"] = t.__hyperparameters__()
      config["augmentations"].append(aug_config)
  
  wandb.init(
    # Set the project where this run will be logged
    project="graph-active-learning",
    # We pass a run name (otherwise it’ll be randomly assigned, like sunshine-lollypop-10)
    name=f"{query_strategy}_{run}",
    # Track hyperparameters and run metadata
    config=config)
  return config

In [11]:
# config = {
#     "model": GCN_Contrastive(num_features, 7).to(device),
#     "loss_fn": nt_xent_loss,
#     "optimizer": torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4),
#     "trainer": ContrastiveTrainer(model, optimizer,loss_fn,drop_edge),
#     "strategy": EntropyQueryContrastive()
# }
# configs = [config]

In [12]:
dataset = data_splits[0].to(device)
num_features = dataset.num_features
num_classes = dataset.y.max().item() + 1
model = GCN(num_features,num_classes).to(device)

loss_fn = F.nll_loss
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
trainer = Trainer(model, optimizer, loss_fn)
trainer.model.to(device)


GCN(
  (convs): ModuleList(
    (0): GCNConv(1433, 16)
    (1): GCNConv(16, 7)
  )
  (projection_head): Identity()
)

In [15]:
STRATEGIES = [ augment_sum_entropy_with_original]
# STRATEGIES = noises_latent
final_accs = torch.zeros((len(STRATEGIES), BUDGET, SIGNIFICANCE_ITERATIONS))
final_auc = torch.zeros((len(STRATEGIES), BUDGET, SIGNIFICANCE_ITERATIONS))

# data_splits = [generate_random_data_split(dataset,10,500) for _ in range(SIGNIFICANCE_ITERATIONS)]
data_splits = [torch.load(f"data_splits\\cora_splits\\split_{i}.pt") for i in range(10)]

for strategy_ix, strategy in enumerate(STRATEGIES):
    print(f"Strategy: {strategy}")
    init_wandb(str(strategy), 4)
    
    
    
    for b in range(1,BUDGET+1):
        
        budget_accuracies = []
        budget_aucs = []
        
        
        
        for si in range(SIGNIFICANCE_ITERATIONS):
            dataset = data_splits[si].to(device)
            
            num_features = dataset.num_features
            num_classes = dataset.y.max().item() + 1
            
            print(f"{b} - {si} - {strategy}")
            model = GCN(num_features,num_classes).to(device)
            loss_fn = F.nll_loss
            optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
            trainer = Trainer(model, optimizer, loss_fn)
            trainer.model.to(device)
            trainer.train(dataset,EPOCHS)
            
            
            acc, auc = trainer.test(dataset)
            
            budget_accuracies.append(acc)
            # budget_aucs.append(auc)
            
            query_node_idx = strategy(model,dataset,dataset.train_pool)
            print(f'\tQuery node: {query_node_idx}')
            
            dataset.train_mask[query_node_idx] = True
            dataset.train_pool[query_node_idx] = False

        budget_accuracies = torch.tensor(budget_accuracies)
        budget_aucs = torch.tensor(budget_aucs)
        final_accs[strategy_ix, b-1, :] = budget_accuracies
        # final_auc[strategy_ix, b-1, :] = budget_aucs
        m = budget_accuracies.mean()
        std = budget_accuracies.std()
        wandb.log({"accuracy_mean": m.item(), "step":b})
        wandb.log({"accuracy_std": std.item(), "step": b})
wandb.finish()


Strategy: AugmentGraphSumEntropyQuery([DropEdge, NoiseFeature], original_weight=0.1)


1 - 0 - AugmentGraphSumEntropyQuery([DropEdge, NoiseFeature], original_weight=0.1)
	Query node: 1070
1 - 1 - AugmentGraphSumEntropyQuery([DropEdge, NoiseFeature], original_weight=0.1)
	Query node: 823
1 - 2 - AugmentGraphSumEntropyQuery([DropEdge, NoiseFeature], original_weight=0.1)
	Query node: 1453
1 - 3 - AugmentGraphSumEntropyQuery([DropEdge, NoiseFeature], original_weight=0.1)
	Query node: 219
1 - 4 - AugmentGraphSumEntropyQuery([DropEdge, NoiseFeature], original_weight=0.1)
	Query node: 1127
1 - 5 - AugmentGraphSumEntropyQuery([DropEdge, NoiseFeature], original_weight=0.1)
	Query node: 219
1 - 6 - AugmentGraphSumEntropyQuery([DropEdge, NoiseFeature], original_weight=0.1)
	Query node: 1614
1 - 7 - AugmentGraphSumEntropyQuery([DropEdge, NoiseFeature], original_weight=0.1)
	Query node: 1525
1 - 8 - AugmentGraphSumEntropyQuery([DropEdge, NoiseFeature], original_weight=0.1)
	Query node: 815
1 - 9 - AugmentGraphSumEntropyQuery([DropEdge, NoiseFeature], original_weight=0.1)
	Query node:

KeyboardInterrupt: 

In [16]:
wandb.finish()

0,1
accuracy_mean,▁▂▂▃▄▄▄▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇█▇█████████████
accuracy_std,▅▆█▆▅▆▆▄▃▅▅▄▃▃▂▃▃▃▃▃▃▂▁▂▁▁▂▁▂▁▂▁▂▁▁▁▁
step,▁▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▃▄▄▄▄▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███

0,1
accuracy_mean,0.7396
accuracy_std,0.01262
step,37.0
