In [None]:
import random
import wandb
import torch
import torch.nn.functional as F
import torch_geometric.transforms as T 
from torch_geometric.seed import seed_everything
from query_strategies import *
from augmentation import *
from model import *
from model_wrapper import *
from trainers import *
from util import *

from hivegraph.contrastive.grace import GRACE

from GRACE_new import GRACENew

from GRACE_new2 import GRACENew2, GRACEEncoder
from copy import deepcopy

In [29]:
# Authenticate with Weights & Biases so the experiment cells below can log metrics.
wandb.login()

True

In [30]:
# Seed Python's built-in RNG only; the experiment loop further down calls
# seed_everything(si) before each run.
random.seed(0)

# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [31]:
# dataset = Planetoid(root='/tmp/Cora', name='Cora')
# generate_balanced_data_splits(dataset,10,"data_splits\\cora_splits")

In [32]:
# --- Experiment settings -------------------------------------------------
BUDGET = 150                  # number of query rounds; one node is labelled per round
SIGNIFICANCE_ITERATIONS = 10  # independent runs (one data split each) per strategy
NUM_PASSES = 100              # handed to the augmentation-based query strategies
EPOCHS = 100                  # not referenced by the cells in this section (training passes 200 directly)

# --- Augmentation settings -----------------------------------------------
DROPOUT = 0.3                 # used by DropEdge, MaskFeature and DropEdgeWeighted
NOISE_LEVEL = 0.5             # first argument of NoiseFeature
NOISE_PROB = 0.4              # second argument of the "col" / "row" NoiseFeature variants
LATENT_NOISE_LEVEL = 0.25     # argument of NoiseLatent

In [33]:
# Augmentations
# Single-step augmentations (constructed in the same order as before).
drop_edge = DropEdge(DROPOUT)
noise_feature_all = NoiseFeature(NOISE_LEVEL, 1)
noise_feature_col = NoiseFeature(NOISE_LEVEL, NOISE_PROB, "col")
noise_feature_row = NoiseFeature(NOISE_LEVEL, NOISE_PROB, "row")
noise_latent = NoiseLatent(LATENT_NOISE_LEVEL)

mask_feature_col = MaskFeature(DROPOUT, "col")

costa = COSTA()

# Two-step pipelines: every composite applies drop_edge first and then one
# feature-level augmentation (T.Compose runs its transforms in list order).
(
    drop_edge_noise_all,
    drop_edge_noise_col,
    drop_edge_noise_row,
    drop_edge_mask_col,
) = (
    T.Compose([drop_edge, feature_augmentation])
    for feature_augmentation in (
        noise_feature_all,
        noise_feature_col,
        noise_feature_row,
        mask_feature_col,
    )
)

In [34]:
# Weighted augmentations
# Load the precomputed edge / feature weights straight onto the active device.
#   * map_location=device lets files that were saved from a GPU session load
#     on a CPU-only machine (the notebook explicitly falls back to CPU).
#   * weights_only=True restricts unpickling to tensor data and avoids the
#     torch.load warning; it assumes both files hold plain tensors — TODO confirm.
edge_weights = torch.load("data/drop_weights.pt", map_location=device, weights_only=True)
feature_weights = torch.load("data/feature_weights.pt", map_location=device, weights_only=True)

drop_edge_weighted = DropEdgeWeighted(edge_weights, DROPOUT, threshold=0.7)
# NOTE: drop_feature_weighted is built here but is not part of the composed
# pipeline below.
drop_feature_weighted = MaskFeatureWeighted(feature_weights, 0.2)
noise_feature_weighted = NoiseFeatureWeighted(feature_weights, NOISE_LEVEL, th=0.7)

# Weighted edge dropping followed by weighted feature noise.
weighted_augmentation = T.Compose([drop_edge_weighted, noise_feature_weighted])


  edge_weights = torch.load("data/drop_weights.pt").to(device)
  feature_weights = torch.load("data/feature_weights.pt").to(device)


In [None]:
# Strategies
# --- Baselines ---
random_query = RandomQuery()
entropy_query = EntropyQuery()

# --- Augmentation-based strategies (drop-edge + feature noise unless noted) ---
augment_sum_entropy = AugmentGraphSumEntropyQuery(drop_edge_noise_all, NUM_PASSES, 0.0)
augment_logit_change = AugmentGraphLogitChange(drop_edge_noise_all, NUM_PASSES, 0.1)
augment_latent = AugmentGraphSumQueryLatent(noise_latent, NUM_PASSES)

# The third positional argument shows up as `original_weight` in the
# strategy's printed representation.
augment_sum_entropy_with_original = AugmentGraphSumEntropyQuery(
    drop_edge_noise_all, NUM_PASSES, 1.0
)
# augment_sum_entropy_with_original = AugmentGraphSumEntropyQuery(drop_edge_mask_col, NUM_PASSES, 1.0)

augment_sum_entropy_ratio = AugmentGraphRatioEntropyQuery(drop_edge_noise_all, NUM_PASSES)

# Same as augment_sum_entropy_with_original, but with column masking instead
# of feature noise.
augment_sum_entropy_with_original_mask = AugmentGraphSumEntropyQuery(
    drop_edge_mask_col, NUM_PASSES, 1.0
)

contrastive_minmax = ContrastiveMinMax()

augment_latent_logit_change = AugmentGraphLogitChangeLatent(costa, NUM_PASSES)

augment_expected_graph = AugmentGraphExpectedGraph(drop_edge_noise_all, NUM_PASSES)
augment_sum_entropy_neighbor = AugmentGraphSumEntropyQueryNeighborhood(
    drop_edge_noise_all, NUM_PASSES, 1, 1.0
)

acc_change = DropNodeAccChange()

In [36]:
# Query strategies driven by the weighted augmentation pipeline.
weighted_augment_sum_entropy_with_original = AugmentGraphSumEntropyQuery(
    weighted_augmentation, NUM_PASSES, 1.0
)
weighted_augment_logit_change = AugmentGraphLogitChange(
    weighted_augmentation, NUM_PASSES, 0.1
)

In [37]:
# Load the ten pre-generated data splits.
# NOTE(security): torch.load unpickles arbitrary Python objects; only load
# split files you produced yourself. weights_only=False is spelled out because
# the splits are full dataset objects rather than plain tensors (presumably
# PyG Data objects — TODO confirm), so loading must not depend on the
# installed torch version's default.
# The experiment loop reloads the splits for every strategy, because it
# mutates their masks in place.
data_splits = [
    torch.load(f"data_splits/split_{i}.pt", weights_only=False)
    for i in range(10)
]

  data_splits = [torch.load(f"data_splits/split_{i}.pt") for i in range(10)]


In [38]:
def reset_models(original_models, trained_model):
    """Restore each trained model (and its optimizer) from its saved counterpart.

    Args:
        original_models: Iterable of reference models whose ``state_dict()``
            and ``optimizer.state_dict()`` are used as the source.
        trained_model: Iterable of models to overwrite in place. Despite the
            singular name this is a collection; it is paired element-wise
            with ``original_models``.

    Returns:
        None. Pairing stops at the shorter of the two iterables.
    """
    pairs = zip(original_models, trained_model)
    for source, target in pairs:
        # Copy the parameters first, then the optimizer state.
        target.load_state_dict(source.state_dict())
        target.optimizer.load_state_dict(source.optimizer.state_dict())

In [39]:
# When True, the experiment loop restores models from their pre-training
# snapshots and retrains after every queried node; when False the pretrained
# models are kept as they are.
RETRAIN_MODELS = False

In [40]:
# Finish the active W&B run, if any — presumably to close a run left open by
# an earlier execution before the experiment below starts a new one.
wandb.finish()

In [None]:
# Active-learning experiment.
# For every strategy: pretrain one GRACE model per data split, then spend
# BUDGET query rounds, labelling one node per split per round, and log the
# mean / std test accuracy across the SIGNIFICANCE_ITERATIONS runs.
STRATEGIES = [augment_sum_entropy_with_original]
# STRATEGIES = noises_latent

# Accuracy per (strategy, budget step, run). final_auc has the same shape but
# is not filled in by this cell (the assignment further down is commented out).
final_accs = torch.zeros((len(STRATEGIES), BUDGET, SIGNIFICANCE_ITERATIONS))
final_auc = torch.zeros((len(STRATEGIES), BUDGET, SIGNIFICANCE_ITERATIONS))


for strategy_ix, strategy in enumerate(STRATEGIES):
    print(f"Strategy: {strategy}")
    init_wandb(strategy, "NOISE_COL_TRAIN", "CORA")

    # Reload the splits for every strategy: the query loop below flips
    # train_pool / train_mask in place, so each strategy needs fresh copies.
    # NOTE(security): torch.load unpickles arbitrary objects — only use split
    # files you generated yourself. weights_only=False is explicit because the
    # splits are dataset objects, not plain tensors.
    data_splits = [
        torch.load(f"data_splits/split_{i}.pt", weights_only=False)
        for i in range(10)
    ]

    wrapped_models = []
    original_wrapped_models = []
    # Pretraining
    for si in range(SIGNIFICANCE_ITERATIONS):
        dataset = data_splits[si].to(device)
        seed_everything(si)
        hidden = 128
        projection_dim = 128
        num_layers = 2  # currently unused: GRACEEncoder is built without it

        # Bind the training augmentation BEFORE building the model. It used to
        # be assigned after GRACENew2(...) had already read it, which raised
        # NameError on a fresh kernel (or silently reused a stale value).
        train_augmentor = drop_edge_noise_col

        encoder_module = GRACEEncoder(
            in_channels=dataset.num_features,
            out_channels=hidden
        )

        projection_head = torch.nn.Sequential(
            torch.nn.Linear(hidden, projection_dim),
            torch.nn.ELU(),
            torch.nn.Linear(projection_dim, projection_dim),
        )

        model = GRACENew2(encoder_module=encoder_module,
                          projection_head=projection_head,
                          augmentor1=train_augmentor,
                          augmentor2=train_augmentor).to(device)

        optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4)
        wrapped_model = GRACE2ModelWrapper(model, optimizer)
        # Snapshot of the untrained model + optimizer, used when retraining.
        original_wrapped_models.append(deepcopy(wrapped_model))

        trainer = Trainer()
        trainer.train(wrapped_model, dataset, 200)
        wrapped_models.append(wrapped_model)
        print(f"Model {si} trained")

    for b in range(1, BUDGET + 1):

        budget_accuracies = []
        budget_aucs = []  # placeholder: AUC is not computed yet
        for si in range(SIGNIFICANCE_ITERATIONS):
            seed_everything(si)
            dataset = data_splits[si].to(device)
            wrapped_model = wrapped_models[si]
            print(f"{b} - {si} - {strategy}")

            # ALWAYS CALL TEST BEFORE STRATEGY
            acc = trainer.test(wrapped_model, dataset)
            chosen_node = strategy(wrapped_model, dataset, dataset.train_pool)
            # Move the chosen node from the unlabelled pool to the training set.
            dataset.train_pool[chosen_node] = False
            dataset.train_mask[chosen_node] = True
            budget_accuracies.append(acc)

            if RETRAIN_MODELS:
                print(f"Model {si} retraining...")
                # Reset ONLY the current run's model. Resetting the whole list
                # here would wipe the retraining of every other run while only
                # this one gets trained again.
                reset_models([original_wrapped_models[si]], [wrapped_model])
                trainer.train(wrapped_model, dataset, 200)

            wrapped_model.reset_predictor()

        budget_accuracies = torch.tensor(budget_accuracies)
        budget_aucs = torch.tensor(budget_aucs)
        final_accs[strategy_ix, b-1, :] = budget_accuracies
        # final_auc[strategy_ix, b-1, :] = budget_aucs
        m = budget_accuracies.mean()
        std = budget_accuracies.std()
        wandb.log({"accuracy_mean": m.item(), "step":b})
        wandb.log({"accuracy_std": std.item(), "step": b})
wandb.finish()


Strategy: AugmentGraphSumEntropyQuery([DropEdge, MaskFeature], original_weight=1.0)


  data_splits = [torch.load(f"data_splits/split_{i}.pt") for i in range(10)]


Model 0 trained
Model 1 trained
Model 2 trained
Model 3 trained
Model 4 trained
Model 5 trained
Model 6 trained
Model 7 trained
Model 8 trained
Model 9 trained
1 - 0 - AugmentGraphSumEntropyQuery([DropEdge, MaskFeature], original_weight=1.0)
1 - 1 - AugmentGraphSumEntropyQuery([DropEdge, MaskFeature], original_weight=1.0)
1 - 2 - AugmentGraphSumEntropyQuery([DropEdge, MaskFeature], original_weight=1.0)
1 - 3 - AugmentGraphSumEntropyQuery([DropEdge, MaskFeature], original_weight=1.0)
1 - 4 - AugmentGraphSumEntropyQuery([DropEdge, MaskFeature], original_weight=1.0)
1 - 5 - AugmentGraphSumEntropyQuery([DropEdge, MaskFeature], original_weight=1.0)
1 - 6 - AugmentGraphSumEntropyQuery([DropEdge, MaskFeature], original_weight=1.0)
1 - 7 - AugmentGraphSumEntropyQuery([DropEdge, MaskFeature], original_weight=1.0)
1 - 8 - AugmentGraphSumEntropyQuery([DropEdge, MaskFeature], original_weight=1.0)
1 - 9 - AugmentGraphSumEntropyQuery([DropEdge, MaskFeature], original_weight=1.0)
2 - 0 - AugmentGraph

0,1
accuracy_mean,▁▂▄▅▅▅▅▅▅▅▅▆▆▆▆▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇██████
accuracy_std,█▄▄▃▃▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁
step,▁▁▁▂▂▂▂▂▃▃▃▃▃▃▃▄▄▄▄▄▄▅▅▆▆▆▆▆▆▆▇▇▇▇▇▇████

0,1
accuracy_mean,0.8102
accuracy_std,0.00827
step,150.0


In [None]:
# Finish the active W&B run, if one is still open (e.g. after interrupting the
# experiment cell above).
wandb.finish()

In [None]:
# Scratch: copy the training mask to a NumPy array on the host. `dataset` is
# whichever split the experiment loop above processed last.
train_mask = dataset.train_mask.detach().cpu().numpy()

In [None]:
# Scratch: shape of the mask repeated 11 times.
# NOTE(review): `np` is never imported explicitly in this notebook; presumably
# it comes in through one of the star imports — confirm or add
# `import numpy as np` to the import cell.
np.tile(train_mask,11).shape

In [None]:
# Scratch: run the wrapped model on the dataset 11 times and concatenate the
# outputs along the first dimension.
ol = torch.cat([wrapped_model(dataset) for _ in range(11)])

In [None]:
# Scratch: inspect the shape of the concatenated outputs.
ol.shape