In [1]:
import random
import wandb
import torch
import torch.nn.functional as F
import torch_geometric.transforms as T 

from query_strategies import *
from augmentation import *
from model import *
from model_wrapper import *
from trainers import *
from util import *

from hivegraph.contrastive.grace import GRACE






In [2]:
wandb.login()

[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mbanfizsombor1999[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [3]:
random.seed(0)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [4]:
# dataset = Planetoid(root='/tmp/Cora', name='Cora')
# generate_balanced_data_splits(dataset,10,"data_splits\\cora_splits")

In [5]:
DROPOUT= 0.3
NUM_PASSES = 10
BUDGET = 150
EPOCHS = 100
SIGNIFICANCE_ITERATIONS = 10

NOISE_PROB = 0.4
NOISE_LEVEL = 0.5

In [6]:
# Augmentations
drop_edge = DropEdge(DROPOUT)
noise_feature_all = NoiseFeature(NOISE_LEVEL, 1)
noise_feature_col = NoiseFeature(NOISE_LEVEL, NOISE_PROB, "col")
noise_feature_row = NoiseFeature(NOISE_LEVEL, NOISE_PROB, "row")
noise_latent = NoiseLatent(NOISE_LEVEL)

drop_edge_noise_all = T.Compose([drop_edge, noise_feature_all])
drop_edge_noise_col = T.Compose([drop_edge, noise_feature_col])
drop_edge_noise_row = T.Compose([drop_edge, noise_feature_row])

In [7]:
# Strategies

random_query = RandomQuery()
entropy_query = EntropyQuery()

augment_sum_entropy = AugmentGraphSumEntropyQuery(drop_edge_noise_all, NUM_PASSES,0.0)
augment_logit_change = AugmentGraphLogitChange(drop_edge_noise_all, NUM_PASSES,1.0)
augment_latent = AugmentGraphSumQueryLatent(noise_latent, NUM_PASSES)
augment_sum_entropy_with_original = AugmentGraphSumEntropyQuery(drop_edge_noise_all, NUM_PASSES, 1/NUM_PASSES)


In [8]:
data_splits = [torch.load(f"data_splits\\cora_splits\\split_{i}.pt") for i in range(10)]


  data_splits = [torch.load(f"data_splits\\cora_splits\\split_{i}.pt") for i in range(10)]


In [9]:
def init_wandb(query_strategy,run):
  config={
    "learning_rate": 0.01,
    "architecture": "GCN",
    "dataset": "CORA",
    "epochs": 200,
    "strategy": str(query_strategy)
    }
  
  if hasattr(query_strategy, "augmentation_fn"):
    augmentation = query_strategy.augmentation_fn
    config["augmentations"] = []
    if isinstance(augmentation, T.Compose):
      transforms = augmentation.transforms
    else:
      transforms = [augmentation]
      
    for t in transforms:
      aug_config = {}
      aug_config["name"] = str(t)
      aug_config["hyperparameters"] = t.__hyperparameters__()
      config["augmentations"].append(aug_config)
  
  wandb.init(
    # Set the project where this run will be logged
    project="graph-active-learning",
    # We pass a run name (otherwise it’ll be randomly assigned, like sunshine-lollypop-10)
    name=f"{query_strategy}_{run}",
    # Track hyperparameters and run metadata
    config=config)
  return config

In [10]:
dataset = data_splits[0].to(device)
num_features = dataset.num_features
num_classes = dataset.y.max().item() + 1
model = GRACE(num_features=dataset.num_features,hidden=128, num_layers=2, drop_edge_rate_1=0.3,drop_edge_rate_2=0.3,drop_feature_rate_1=0.3,drop_feature_rate_2=0.3).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

wrapped_model = GRACEModelWrapper(model,optimizer)
trainer = Trainer()

trainer.train(wrapped_model,dataset,200)

8.56876277923584
8.561138153076172
7.802226543426514
8.038267135620117
7.767679691314697
7.613288402557373
7.5262370109558105
7.4568939208984375
7.4195756912231445
7.442411422729492
7.400008201599121
7.364684581756592
7.37878942489624
7.387082099914551
7.334733486175537
7.378671169281006
7.342468738555908
7.341065883636475
7.346648693084717
7.290910243988037
7.271796703338623
7.265132904052734
7.250162124633789
7.233466148376465
7.2478108406066895
7.260340213775635
7.2337565422058105
7.2294745445251465
7.209474086761475
7.23295783996582
7.226389408111572
7.201923847198486
7.217424392700195
7.208724498748779
7.202087879180908
7.191298007965088
7.197394371032715
7.194523334503174
7.173399448394775
7.149122714996338
7.206968307495117
7.138422012329102
7.134242057800293
7.161535263061523
7.136783599853516
7.146125793457031
7.174375534057617
7.1468729972839355
7.144186496734619
7.106014728546143
7.138978958129883
7.135040760040283
7.104761600494385
7.115084171295166
7.089709281921387
7.1160

In [10]:
# config = {
#     "model": GCN_Contrastive(num_features, 7).to(device),
#     "loss_fn": nt_xent_loss,
#     "optimizer": torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4),
#     "trainer": ContrastiveTrainer(model, optimizer,loss_fn,drop_edge),
#     "strategy": EntropyQueryContrastive()
# }
# configs = [config]

In [11]:
wandb.finish()

In [None]:
STRATEGIES = [ noise_latent]
# STRATEGIES = noises_latent
final_accs = torch.zeros((len(STRATEGIES), BUDGET, SIGNIFICANCE_ITERATIONS))
final_auc = torch.zeros((len(STRATEGIES), BUDGET, SIGNIFICANCE_ITERATIONS))

# data_splits = [generate_random_data_split(dataset,10,500) for _ in range(SIGNIFICANCE_ITERATIONS)]
data_splits = [torch.load(f"data_splits\\cora_splits\\split_{i}.pt") for i in range(10)]

for strategy_ix, strategy in enumerate(STRATEGIES):
    print(f"Strategy: {strategy}")
    init_wandb(str(strategy), 4)
    
    dataset = data_splits[0].to(device)
    num_features = dataset.num_features
    num_classes = dataset.y.max().item() + 1
    model = GRACE(num_features=dataset.num_features,hidden=128, num_layers=2, drop_edge_rate_1=0.3,drop_edge_rate_2=0.3,drop_feature_rate_1=0.3,drop_feature_rate_2=0.3).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

    wrapped_model = GRACEModelWrapper(model,optimizer)
    trainer = Trainer()
    
    trainer.train(wrapped_model,dataset,200)
    
    for b in range(1,BUDGET+1):
        
        budget_accuracies = []
        budget_aucs = []
        
        
        
        for si in range(SIGNIFICANCE_ITERATIONS):
            wrapped_model.reset_predictor()
            dataset = data_splits[si].to(device)
            
            num_features = dataset.num_features
            num_classes = dataset.y.max().item() + 1
            
            print(f"{b} - {si} - {strategy}")
            # model = GCN(num_features,num_classes).to(device)
            
            # loss_fn = F.nll_loss
            # wrapped_model = ModelWrapper(model,optimizer,loss_fn)
            
            acc = trainer.test(wrapped_model,dataset)
            
            budget_accuracies.append(acc)
            # budget_aucs.append(auc)
            
            query_node_idx = strategy(wrapped_model,dataset,dataset.train_pool)
            print(f'\tQuery node: {query_node_idx}')
            
            dataset.train_mask[query_node_idx] = True
            dataset.train_pool[query_node_idx] = False

        budget_accuracies = torch.tensor(budget_accuracies)
        budget_aucs = torch.tensor(budget_aucs)
        final_accs[strategy_ix, b-1, :] = budget_accuracies
        # final_auc[strategy_ix, b-1, :] = budget_aucs
        m = budget_accuracies.mean()
        std = budget_accuracies.std()
        wandb.log({"accuracy_mean": m.item(), "step":b})
        wandb.log({"accuracy_std": std.item(), "step": b})
wandb.finish()


  data_splits = [torch.load(f"data_splits\\cora_splits\\split_{i}.pt") for i in range(10)]


Strategy: AugmentGraphSumEntropyQuery([DropEdge, NoiseFeature], original_weight=0.1)


1 - 0 - AugmentGraphSumEntropyQuery([DropEdge, NoiseFeature], original_weight=0.1)
	Query node: 563
1 - 1 - AugmentGraphSumEntropyQuery([DropEdge, NoiseFeature], original_weight=0.1)
	Query node: 1502
1 - 2 - AugmentGraphSumEntropyQuery([DropEdge, NoiseFeature], original_weight=0.1)
	Query node: 851
1 - 3 - AugmentGraphSumEntropyQuery([DropEdge, NoiseFeature], original_weight=0.1)
	Query node: 682
1 - 4 - AugmentGraphSumEntropyQuery([DropEdge, NoiseFeature], original_weight=0.1)
	Query node: 1503
1 - 5 - AugmentGraphSumEntropyQuery([DropEdge, NoiseFeature], original_weight=0.1)
	Query node: 215
1 - 6 - AugmentGraphSumEntropyQuery([DropEdge, NoiseFeature], original_weight=0.1)
	Query node: 158
1 - 7 - AugmentGraphSumEntropyQuery([DropEdge, NoiseFeature], original_weight=0.1)
	Query node: 1338
1 - 8 - AugmentGraphSumEntropyQuery([DropEdge, NoiseFeature], original_weight=0.1)
	Query node: 483
1 - 9 - AugmentGraphSumEntropyQuery([DropEdge, NoiseFeature], original_weight=0.1)
	Query node: 1

KeyboardInterrupt: 

In [13]:
wandb.finish()

0,1
accuracy_mean,▄▁▂▃▄▆▆▆▆▇▇▇▇▇▇▇▇▇▇█████████████████████
accuracy_std,▆▆█▇▆▃▅▄▃▄▄▄▅▃▃▃▄▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
step,▁▁▁▁▁▁▁▂▂▂▂▂▂▂▃▃▃▄▄▄▄▄▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇██

0,1
accuracy_mean,0.7683
accuracy_std,0.02162
step,142.0
