In [None]:
# Enable IPython's autoreload extension; mode 2 re-imports modules before each
# cell runs, so edits to the project's `src` package are picked up without a
# kernel restart.
%load_ext autoreload
%autoreload 2

In [None]:
import torch

In [None]:
from torch_geometric.data import Data, DataLoader
from torch_geometric.datasets import TUDataset, Planetoid
from torch_geometric.nn import GCNConv, Set2Set, GNNExplainer
import torch_geometric.transforms as T
import torch
import torch.nn.functional as F
import os

import matplotlib.pyplot as plt
from dgl.data import BACommunityDataset
import networkx as nx
import torch_geometric
import dgl


In [None]:
from src.model import *
from src.data import *
from src.explainer import *
from src.plot import *
from src.modify import *
from src.protgnn import *

In [None]:
# Fetch the dataset identified by the name "BAShapes" through the project
# helper `get_dataset` (comes from one of the star-imports above; presumably
# src.data — TODO confirm).
ba_dataset = get_dataset("BAShapes")

In [None]:
# Number of distinct integer labels found in ba_dataset.y.
num_classes = len({int(label) for label in ba_dataset.y})

# Size of the last axis of ba_dataset.x; passed to the models below as their
# input feature count.
num_features = ba_dataset.x.shape[-1]

In [None]:
# Shared hyper-parameters for the sections below.
dim = 20         # passed as `dim=` to the Net and GCN constructors
epochs = 3_000   # first argument of train_model for the GCExplainer and CDM runs

## GCExplainer

In [None]:
# Use CUDA when PyTorch reports it as available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Build the project's `Net` model and move it to the selected device.
model = Net(
    num_features=num_features,
    dim=dim,
    num_classes=num_classes,
).to(device)

# Adam over all model parameters.
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=0.01,
    weight_decay=5e-3,
)


In [None]:
# Hand everything to the project's `train_model` helper and rebind `model` to
# its return value. `test_model` is passed through as an object, not called here.
# NOTE(review): re-running this cell feeds the already-updated `model` back in,
# so training presumably continues rather than restarting — confirm.
model = train_model(
    epochs,
    model,
    device,
    ba_dataset,
    optimizer,
    test_model,
)

In [None]:
# Create a GCExplainer with its default constructor arguments (class comes
# from the project's star-imports; presumably src.explainer — TODO confirm).
gce_explainer = GCExplainer()

In [None]:
# Let the explainer learn its prototypes from the trained model and the dataset.
# NOTE(review): presumably this is what populates `gce_explainer.kmeans` and
# `gce_explainer.initial_activations`, which a later plotting cell reads — confirm.
gce_explainer.learn_prototypes(model, ba_dataset)

In [None]:
# Bare expression: the notebook shows whatever get_prediction returns for the
# unmodified dataset.
gce_explainer.get_prediction(model, ba_dataset)

In [None]:
# Plot the explainer's k-means object together with its stored initial
# activations using the project's plotting helper.
plot_kmeans_clusters(
    gce_explainer.kmeans,
    gce_explainer.initial_activations,
)

In [None]:
# Same prediction call, but on the dataset passed through `identity(...)`.
# NOTE(review): `identity` presumably is the no-op modification from
# src.modify, used as a baseline — confirm.
gce_explainer.get_prediction(model, identity(ba_dataset))

In [None]:
# Display the completeness value the explainer reports for the model on the
# `identity(...)`-wrapped dataset.
gce_explainer.get_completeness(model, identity(ba_dataset))

## ProtGNN

In [None]:
# Replace `model` with a ProtGNN node-classification network on the same device.
# NOTE(review): `model_args` is not defined anywhere in this notebook; it
# presumably arrives via one of the star-imports (e.g. src.protgnn) — confirm,
# otherwise this cell raises NameError on a fresh kernel.
model = GCNNet_NC(num_features, num_classes, model_args).to(device)

In [None]:
# Load the saved parameters from the checkpoint file into the ProtGNN model.
# `map_location=device` remaps tensors that were saved from a CUDA device onto
# the device selected earlier, so the checkpoint also loads on CPU-only machines
# (without it, torch.load tries to restore tensors to their original device).
# NOTE(review): torch.load unpickles the file — only load checkpoints from a
# trusted source.
model.load_state_dict(
    torch.load("../models/protgnn_bashapes.pt", map_location=device)
)

In [None]:
# Fresh Adam optimizer bound to the ProtGNN model's parameters.
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=0.005,
    weight_decay=5e-3,
)

In [None]:
# Train the loaded ProtGNN model with a hard-coded 300 as the first argument
# (not the shared `epochs` value). `get_outputs` selects element 1 of whatever
# object train_model hands to it (presumably the model's forward output — confirm).
model = train_model(
    300,
    model,
    device,
    ba_dataset,
    optimizer,
    test_model,
    get_outputs=lambda a: a[1],
)

In [None]:
# Create a ProtGNNExplainer with its default constructor arguments (class is
# provided by the project's star-imports).
prot_explainer = ProtGNNExplainer()

In [None]:
# Let the ProtGNN explainer learn its prototypes from the model and the dataset.
prot_explainer.learn_prototypes(model, ba_dataset)

In [None]:
# Bare expression: display what the ProtGNN explainer's get_prediction returns
# for the unmodified dataset.
prot_explainer.get_prediction(model, ba_dataset)

## CDM

In [None]:
# Replace `model` with the project's `GCN`, built from the shared sizes.
# It is not moved to a device here; the next cell does that.
model = GCN(
    num_features=num_features,
    dim=dim,
    num_classes=num_classes,
)

In [None]:
# Re-select the device (same rule as in the GCExplainer section: CUDA if
# available, else CPU) and move the CDM model onto it.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model = model.to(device)

# Adam optimizer for the CDM model's parameters.
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=0.005,
    weight_decay=5e-3,
)

In [None]:
# Train the CDM model with the shared `epochs` value. `get_outputs` selects
# element 1 of whatever object train_model hands to it (presumably the model's
# forward output — confirm).
model = train_model(
    epochs,
    model,
    device,
    ba_dataset,
    optimizer,
    test_model,
    get_outputs=lambda a: a[1],
)

In [None]:
# Create a CDMExplainer with its default constructor arguments (class is
# provided by the project's star-imports).
cdm_explainer = CDMExplainer()

In [None]:
# Let the CDM explainer learn its prototypes from the trained model and the dataset.
cdm_explainer.learn_prototypes(model, ba_dataset)

In [None]:
# Bare expression: display what the CDM explainer's get_prediction returns for
# the unmodified dataset.
cdm_explainer.get_prediction(model, ba_dataset)

In [None]:
# Bare expression: display the completeness value reported by the CDM explainer.
# Unlike the GCExplainer section, the dataset is passed directly rather than
# through identity(...).
cdm_explainer.get_completeness(model, ba_dataset)

In [None]:
# Bare expression: display what the CDM explainer's get_concepts returns for
# the model on the dataset.
cdm_explainer.get_concepts(model, ba_dataset)

## Adversary Methods

In [None]:
# Build a modified copy/view of the dataset with the project's aggressive adversary.
# NOTE(review): the 0.1 argument presumably is the perturbation rate — confirm
# against the helper's signature (likely in src.modify).
ba_aggressive = aggressive_adversary(ba_dataset, 0.1)

In [None]:
# Inspect the adversary's edge index; `.int()` converts it for display
# (assumes edge_index is a torch tensor, where .int() yields int32 — TODO confirm).
ba_aggressive.edge_index.int()

In [None]:
# Build the conservative-adversary version of the dataset. Unlike the
# aggressive variant, this helper also receives the dataset name 'BAShapes'.
# NOTE(review): the 0.1 argument presumably is the perturbation rate — confirm.
ba_conservative = conservative_adversary(ba_dataset, 'BAShapes', 0.1)

In [None]:
# Inspect the conservative adversary's edge index; `.int()` converts it for
# display (assumes edge_index is a torch tensor — TODO confirm).
ba_conservative.edge_index.int()

## Plot Preliminary Results

In [None]:
# The two per-metric plots use the same results folder and dataset tag.
for metric_name in ('fidelity_plus', 'completeness'):
    plot_metric('results', 'bashapes', metric_name)

# Kept as the final bare expression so any return value is still displayed.
plot_difference_metric('results/concepts', 'bashapes', 'concepts')