In [None]:
import random
import wandb
import torch
import torch.nn.functional as F
import torch_geometric.transforms as T 

from query_strategies import *
from augmentation import *
from model import *
from model_wrapper import *
from trainers import *
from util import *

from hivegraph.contrastive.grace import GRACE
from GRACE_new import GRACENew

from torch_geometric.seed import seed_everything




In [None]:
# Authenticate this session with Weights & Biases before any run is created.
wandb.login()

In [None]:
# Seed Python's built-in RNG. This does not seed torch or numpy; the
# experiment loop seeds everything per iteration via seed_everything.
random.seed(0)

# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [None]:
# dataset = Planetoid(root='/tmp/Cora', name='Cora')
# generate_balanced_data_splits(dataset,10,"data_splits\\cora_splits")

In [None]:
# --- Experiment configuration -------------------------------------------
# Value handed to DropEdge when the augmentations are built.
DROPOUT = 0.3
# Pass count handed to the augmentation-based query strategies.
NUM_PASSES = 100
# Number of query rounds in the active-learning loop (b runs 1..BUDGET).
BUDGET = 150
# NOTE(review): not referenced in the visible cells; the training call passes
# the literal 200 instead. Confirm which value is intended.
EPOCHS = 100
# Number of data splits each budget step is evaluated on.
SIGNIFICANCE_ITERATIONS = 10

# Arguments for the NoiseFeature / NoiseLatent augmentations.
NOISE_PROB, NOISE_LEVEL = 0.4, 0.5

In [None]:
# Augmentations
# Single building blocks. The second NoiseFeature argument is 1 for the "all"
# variant and NOISE_PROB for the "col"/"row" variants.
# NOTE(review): presumably that argument is a per-entry/column/row noise
# probability -- confirm against the NoiseFeature definition.
drop_edge = DropEdge(DROPOUT)
noise_feature_all = NoiseFeature(NOISE_LEVEL, 1)
noise_feature_col, noise_feature_row = (
    NoiseFeature(NOISE_LEVEL, NOISE_PROB, mode) for mode in ("col", "row")
)
noise_latent = NoiseLatent(NOISE_LEVEL)

# Composed pipelines: edge dropping followed by one feature-noise variant.
drop_edge_noise_all, drop_edge_noise_col, drop_edge_noise_row = (
    T.Compose([drop_edge, feature_noise])
    for feature_noise in (noise_feature_all, noise_feature_col, noise_feature_row)
)

In [None]:
# Strategies

# Baselines that take no augmentation.
random_query = RandomQuery()
entropy_query = EntropyQuery()

# Augmentation-based strategies; each receives an augmentation pipeline and
# the number of passes.
# NOTE(review): the trailing float (0.0 / 1.0) presumably weights the
# un-augmented prediction -- confirm against the query_strategies module.
augment_sum_entropy = AugmentGraphSumEntropyQuery(
    drop_edge_noise_all,
    NUM_PASSES,
    0.0,
)
augment_logit_change = AugmentGraphLogitChange(
    drop_edge_noise_all,
    NUM_PASSES,
    1.0,
)
augment_latent = AugmentGraphSumQueryLatent(
    noise_latent,
    NUM_PASSES,
)
augment_sum_entropy_with_original = AugmentGraphSumEntropyQuery(
    drop_edge_noise_all,
    NUM_PASSES,
    1.0,
)



In [None]:
# Load the ten pre-generated splits stored under data_splits/cora_ml/.
# NOTE(review): the experiment cell further down reassigns data_splits from a
# different directory (data_splits/split_{i}.pt) -- confirm which is intended.
data_splits = list(
    map(torch.load, (f"data_splits/cora_ml/split_{i}.pt" for i in range(10)))
)


In [None]:
from GRACE_new2 import GRACENew2

In [None]:
# Close any W&B run that is still active before the experiment cell below
# starts a new one (presumably left over from an interrupted earlier execution).
wandb.finish()

In [None]:
# Active-learning experiment.
#
# For every query strategy, run BUDGET query rounds. In each round, and for
# each of SIGNIFICANCE_ITERATIONS data splits, a fresh semi-supervised model
# (GRACENew2 encoder + MLP head) is trained on the currently labelled nodes,
# evaluated, and then used by the strategy to pick one more node from the
# pool. The per-round mean/std test accuracy across splits is logged to W&B
# and the raw accuracies are stored in final_accs[strategy, round, split].
STRATEGIES = [augment_sum_entropy_with_original]
# STRATEGIES = noises_latent
final_accs = torch.zeros((len(STRATEGIES), BUDGET, SIGNIFICANCE_ITERATIONS))
final_auc = torch.zeros((len(STRATEGIES), BUDGET, SIGNIFICANCE_ITERATIONS))

# data_splits = [generate_random_data_split(dataset,10,500) for _ in range(SIGNIFICANCE_ITERATIONS)]

# Both GRACE views use the same augmentation pipeline. It never changes inside
# the loops, so bind it once.
train_augmentor = drop_edge_noise_col

for strategy_ix, strategy in enumerate(STRATEGIES):
    print(f"Strategy: {strategy}")
    init_wandb(str(strategy), "SEMI_SUP_DROP_NOISE_TRAIN", "CORA")
    try:
        # Reload the splits for every strategy so each one starts from the
        # original train_mask / train_pool (both are mutated in place below).
        # One split is needed per significance iteration.
        # NOTE(review): an earlier cell loads from data_splits/cora_ml/ while
        # this path has no dataset sub-directory -- confirm which is intended.
        data_splits = [torch.load(f"data_splits/split_{i}.pt") for i in range(SIGNIFICANCE_ITERATIONS)]
        for dataset in data_splits:
            dataset.y_train = dataset.y.clone()

        for b in range(1, BUDGET + 1):

            budget_accuracies = []
            budget_aucs = []
            for si in range(SIGNIFICANCE_ITERATIONS):
                # Same seed per split in every round, so runs are repeatable.
                seed_everything(si)
                dataset = data_splits[si].to(device)
                num_features = dataset.num_features
                num_classes = dataset.y.max().item() + 1

                print(f"{b} - {si} - {strategy}")

                # model = GCN(num_features,num_classes).to(device)
                # loss_fn = F.nll_loss
                # optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
                # wrapped_model = ModelWrapper(model,optimizer,loss_fn)

                # model = GRACENew(num_features=data_splits[0].num_features,hidden=128, num_layers=2,
                #      drop_edge_rate_1=0.3,drop_edge_rate_2=0.3,
                #      drop_feature_rate_1=0.3,drop_feature_rate_2=0.3,
                #      ratio=0.5, lambda_=1.0).to(device)

                model = GRACENew2(num_features=num_features, hidden=128, num_layers=2,
                         augmentor1=train_augmentor, augmentor2=train_augmentor).to(device)
                # Classification head on top of the 128-d encoder output. The
                # output width follows the labels in the data instead of a
                # hard-coded class count.
                supervised_model = torch.nn.Sequential(
                            torch.nn.Linear(128, 32),
                            torch.nn.ReLU(),
                            torch.nn.Linear(32, num_classes)
                        ).to(device)
                # One optimizer updates encoder and head jointly.
                optimizer = torch.optim.Adam(
                    list(model.parameters()) + list(supervised_model.parameters()),
                    lr=0.001,
                    weight_decay=5e-4
                )

                # The weight passed to the wrapper switches from 0.5 to 0.95
                # once 15 rounds have elapsed.
                # NOTE(review): presumably this is the supervised-loss share --
                # confirm against SemiSupervisedModelWrapper.
                balancing_weight = 0.5 if b < 15 else 0.95
                wrapped_model = SemiSupervisedModelWrapper(supervised_model,model,optimizer, balancing_weight)

                trainer = Trainer()

                # NOTE(review): the literal 200 is used here although an EPOCHS
                # constant exists in the config cell -- confirm which applies.
                trainer.train(wrapped_model,dataset,200)
                acc = trainer.test(wrapped_model,dataset)

                budget_accuracies.append(acc)
                # budget_aucs.append(auc)

                # dataset = pool_tuning(wrapped_model, dataset)

                # Ask the strategy for the next node and move it from the
                # unlabeled pool into the training set. The masks live on the
                # split object, so the change carries over to the next round.
                query_node_idx = strategy(wrapped_model,dataset,dataset.train_pool)
                print(f'\tQuery node: {query_node_idx}')

                dataset.train_mask[query_node_idx] = True
                dataset.train_pool[query_node_idx] = False
                print(f"\tTrain mask: {dataset.train_mask.sum()}")
                print(f"\tTrain pool: {dataset.train_pool.sum()}")

            budget_accuracies = torch.tensor(budget_accuracies)
            budget_aucs = torch.tensor(budget_aucs)
            final_accs[strategy_ix, b-1, :] = budget_accuracies
            # final_auc[strategy_ix, b-1, :] = budget_aucs
            m = budget_accuracies.mean()
            std = budget_accuracies.std()
            # Log both statistics in one call so they share a single history
            # row for this budget step.
            wandb.log({"accuracy_mean": m.item(), "accuracy_std": std.item(), "step": b})
    finally:
        # Close this strategy's run even if training fails, so the next
        # strategy (or a re-run of the cell) starts its own run.
        wandb.finish()


In [None]:
# Use the first split for the embedding visualisation below. After the
# experiment cell this object carries the train_mask/train_pool as mutated by
# the last strategy that ran.
dataset =data_splits[0]

In [None]:
# Run the self-supervised model on the selected split's features and edges.
# NOTE(review): wrapped_model is whatever the experiment loop left behind
# (its final iteration), so this cell only works after that cell has run.
# NOTE(review): presumably returns one embedding row per node -- confirm.
out = wrapped_model.self_supervised_model(dataset.x,dataset.edge_index)

In [None]:
from sklearn.manifold import TSNE

# Convert the model output to a NumPy array only while it is still a tensor,
# so the cell can be re-run without failing on an already-converted `out`.
if torch.is_tensor(out):
    out = out.detach().cpu().numpy()

# Project to 2-D for plotting. A fixed random_state keeps the layout
# reproducible between executions.
tsne = TSNE(n_components=2, random_state=0)
tsne_out = tsne.fit_transform(out)


In [None]:
import matplotlib.pyplot as plt

# Scatter the two t-SNE coordinates, colouring every point by its label in
# dataset.y.
point_colors = dataset.y.cpu().numpy()
plt.scatter(tsne_out[:, 0], tsne_out[:, 1], c=point_colors)

In [None]:
# Make sure no W&B run is left open at the end of the notebook.
wandb.finish()