In [1]:
import random
import wandb
import torch
import torch.nn.functional as F
import torch_geometric.transforms as T 

from query_strategies import *
from augmentation import *
from model import *
from model_wrapper import *
from trainers import *
from util import *

from hivegraph.contrastive.grace import GRACE




In [2]:
wandb.login()

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mbanfizsombor1999[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [3]:
random.seed(0)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [4]:
# dataset = Planetoid(root='/tmp/Cora', name='Cora')
# generate_balanced_data_splits(dataset,10,"data_splits\\cora_splits")

In [5]:
DROPOUT= 0.3
NUM_PASSES = 10
BUDGET = 150
EPOCHS = 100
SIGNIFICANCE_ITERATIONS = 10

NOISE_PROB = 0.4
NOISE_LEVEL = 0.5

In [6]:
# Augmentations
drop_edge = DropEdge(DROPOUT)
noise_feature_all = NoiseFeature(NOISE_LEVEL, 1)
noise_feature_col = NoiseFeature(NOISE_LEVEL, NOISE_PROB, "col")
noise_feature_row = NoiseFeature(NOISE_LEVEL, NOISE_PROB, "row")
noise_latent = NoiseLatent(NOISE_LEVEL)

drop_edge_noise_all = T.Compose([drop_edge, noise_feature_all])
drop_edge_noise_col = T.Compose([drop_edge, noise_feature_col])
drop_edge_noise_row = T.Compose([drop_edge, noise_feature_row])

In [None]:
# Strategies

random_query = RandomQuery()
entropy_query = EntropyQuery()

augment_sum_entropy = AugmentGraphSumEntropyQuery(drop_edge_noise_all, NUM_PASSES,0.0)
augment_logit_change = AugmentGraphLogitChange(drop_edge_noise_all, NUM_PASSES,1.0)
augment_latent = AugmentGraphSumQueryLatent(noise_latent, NUM_PASSES)
augment_sum_entropy_with_original = AugmentGraphSumEntropyQuery(drop_edge_noise_all, NUM_PASSES, 0.5)



In [8]:
data_splits = [torch.load(f"data_splits/split_{i}.pt") for i in range(10)]


  data_splits = [torch.load(f"data_splits/split_{i}.pt") for i in range(10)]


In [24]:
dataset = data_splits[0].to(device)
model = GRACE(num_features=dataset.num_features,hidden=128, num_layers=2, drop_edge_rate_1=0.3,drop_edge_rate_2=0.3,drop_feature_rate_1=0.3,drop_feature_rate_2=0.3).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
modelWrapper = GRACEModelWrapper(model,optimizer)
trainer = Trainer()

In [25]:
trainer.train(modelWrapper, dataset, 200)

In [26]:
modelWrapper.train_step(dataset)

tensor(6.8818, device='cuda:0', grad_fn=<MeanBackward0>)

In [27]:
wrapped_model = modelWrapper

In [30]:
wandb.finish()

0,1
accuracy_mean,█▁▃
accuracy_std,▁█▁
step,▁▁▅▅██

0,1
accuracy_mean,0.2622
accuracy_std,0.0598
step,3.0


In [31]:
STRATEGIES = [ augment_latent]
# STRATEGIES = noises_latent
final_accs = torch.zeros((len(STRATEGIES), BUDGET, SIGNIFICANCE_ITERATIONS))
final_auc = torch.zeros((len(STRATEGIES), BUDGET, SIGNIFICANCE_ITERATIONS))

# data_splits = [generate_random_data_split(dataset,10,500) for _ in range(SIGNIFICANCE_ITERATIONS)]

for strategy_ix, strategy in enumerate(STRATEGIES):
    print(f"Strategy: {strategy}")
    init_wandb(str(strategy), 5)
    data_splits = [torch.load(f"data_splits/split_{i}.pt") for i in range(10)]

    for b in range(1,BUDGET+1):

        budget_accuracies = []
        budget_aucs = []
        for si in range(SIGNIFICANCE_ITERATIONS):
            wrapped_model.reset_predictor()
            
            dataset = data_splits[si].to(device)
            
            num_features = dataset.num_features
            num_classes = dataset.y.max().item() + 1
            
            print(f"{b} - {si} - {strategy}")
            # model = GCN(num_features,num_classes).to(device)
            
            # loss_fn = F.nll_loss
            # optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
            # wrapped_model = ModelWrapper(model,optimizer,loss_fn)
            # trainer = Trainer()
            
            # trainer.train(wrapped_model,dataset,200)
            acc = trainer.test(wrapped_model,dataset)
            
            budget_accuracies.append(acc)
            # budget_aucs.append(auc)
            print(wrapped_model.fitted)
            query_node_idx = strategy(wrapped_model,dataset,dataset.train_pool)
            print(f'\tQuery node: {query_node_idx}')
            
            dataset.train_mask[query_node_idx] = True
            dataset.train_pool[query_node_idx] = False

        budget_accuracies = torch.tensor(budget_accuracies)
        budget_aucs = torch.tensor(budget_aucs)
        final_accs[strategy_ix, b-1, :] = budget_accuracies
        # final_auc[strategy_ix, b-1, :] = budget_aucs
        m = budget_accuracies.mean()
        std = budget_accuracies.std()
        wandb.log({"accuracy_mean": m.item(), "step":b})
        wandb.log({"accuracy_std": std.item(), "step": b})
wandb.finish()


Strategy: AugmentGraphSumQueryLatent([NoiseLatent])


  data_splits = [torch.load(f"data_splits/split_{i}.pt") for i in range(10)]


1 - 0 - AugmentGraphSumQueryLatent([NoiseLatent])
True
	Query node: 341
1 - 1 - AugmentGraphSumQueryLatent([NoiseLatent])
True
	Query node: 696
1 - 2 - AugmentGraphSumQueryLatent([NoiseLatent])
True
	Query node: 1287
1 - 3 - AugmentGraphSumQueryLatent([NoiseLatent])
True
	Query node: 1040
1 - 4 - AugmentGraphSumQueryLatent([NoiseLatent])
True
	Query node: 391
1 - 5 - AugmentGraphSumQueryLatent([NoiseLatent])
True
	Query node: 285
1 - 6 - AugmentGraphSumQueryLatent([NoiseLatent])
True
	Query node: 1540
1 - 7 - AugmentGraphSumQueryLatent([NoiseLatent])
True
	Query node: 1103
1 - 8 - AugmentGraphSumQueryLatent([NoiseLatent])
True
	Query node: 1272
1 - 9 - AugmentGraphSumQueryLatent([NoiseLatent])
True
	Query node: 1004
2 - 0 - AugmentGraphSumQueryLatent([NoiseLatent])
True
	Query node: 605
2 - 1 - AugmentGraphSumQueryLatent([NoiseLatent])
True
	Query node: 512
2 - 2 - AugmentGraphSumQueryLatent([NoiseLatent])
True
	Query node: 1134
2 - 3 - AugmentGraphSumQueryLatent([NoiseLatent])
True
	Q

0,1
accuracy_mean,▁▂▂▂▁▁▂▃▃▄▅▅▆▆▆▆▆▇▇▇▇▇▇▇████████████████
accuracy_std,▆▇▇▃▄▅▅███▅▅▅▅▃▃▃▂▂▂▃▂▃▃▃▂▁▃▃▂▂▂▂▂▂▁▁▁▁▁
step,▁▁▁▁▁▂▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▅▆▆▆▆▆▇▇████

0,1
accuracy_mean,0.7341
accuracy_std,0.02622
step,150.0


In [11]:
torch.rand((2708, 7)).argmax(dim=1)

tensor([3, 2, 2,  ..., 0, 5, 4])

In [12]:
wandb.finish()