In [28]:
from use_dataset import ProofDataset, HiddenConcProofDataset
import torch
import numpy as np
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
import random
from sklearn.metrics import f1_score as f1
#from sklearn.utils.class_weight import compute_class_weight
import torch.nn.functional as F
from torch.nn import Linear, ReLU, BatchNorm1d, Sequential
from torch_geometric.nn import GINConv

torch.manual_seed(0)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [None]:
# use file_limit=10000 to only load and verify the first 10000 graphs (creates file of size ~420 MB)

file_limit = 10000    # desired number of graphs to work with
vocab_size = 1598   # number of characters in our vocabulary

pf_data = ProofDataset(root="data/",read_name="10000_relabeled_data_at_least_10.json" , write_name="10000_relabeled_data_at_least_10.pt" ,file_limit=file_limit)  

In [5]:
# make train/val/test for GIN

# set seed for random # generation
random.seed(10)
length = file_limit
total_indices = [i+file_limit for i in range(file_limit)]

# create index vectors to filter dataset
train_indices = [i for i in range(file_limit)]
train_indices = train_indices + random.sample(total_indices, int(length*.8))
train_indices.sort()

val_index_options = [x for x in total_indices if x not in train_indices]
val_indices = random.sample(val_index_options, int(length*.1))
val_indices.sort()

test_index_options = [x for x in total_indices if x not in train_indices if x not in val_indices]
test_indices = random.sample(test_index_options, int(length*.1))
test_indices.sort()

In [6]:
# our model is judged on how well it can predict the label for the final/conclusion node as seen in the val set created above
# to not give our model too much info, we zero out the features on the final/conclusion node for each graph in the val set

data_list = []

for idx, graph in enumerate(pf_data):
    # inherit features, labels, and edges
    x = graph.x.clone()
    y = graph.y.clone()
    edge_index = graph.edge_index.clone()
    x[-1] = torch.zeros(512) # zero out ALL conclusion nodes

    #if idx in val_indices:
        #x[-1] = torch.zeros(512)    # zero out features of final/conclusion node for each graph in val set

# replace features of conlusion/final nodes from val set with average of neighboring node embeddings
    #if idx in val_indices:

        #connected_nodes = []
        #sum = torch.zeros(512)

        #for i, edge in enumerate(edge_index[0]): # case of outgoing edge coming from conc node
            #if edge == (pf_data[idx].num_nodes-1):  
                #connected_nodes.append(int(edge_index[1][i].item()))
        #for i, edge in enumerate(edge_index[1]): # case of outgoing edge going to conc node
            #if edge == (pf_data[idx].num_nodes-1):
                #connected_nodes.append(int(edge_index[0][i].item()))

        #for i in connected_nodes:
            #if not torch.equal(pf_data[idx].x[i], pf_data[idx].x[-1]):
            #sum += pf_data[idx].x[i]
        #if len(connected_nodes) > 0:
            #x[-1] = sum/(len(connected_nodes))
        #else:
            #x[-1] = sum

    data_list.append(Data(x=x,y=y,edge_index=edge_index))

# read_name not used, but needed to instantiate class
hidden_conc_pf_data = HiddenConcProofDataset(root="data/",read_name="10000_relabeled_data_at_least_10_w_stmts.json",write_name="overwritten_labels.pt", data_list=data_list)

Processing...
Done!


In [7]:
# Create training, validation, and test sets
train_dataset = hidden_conc_pf_data[train_indices]
val_dataset = hidden_conc_pf_data[val_indices]
test_dataset = hidden_conc_pf_data[test_indices]

print(f'Training set   = {len(train_dataset)} graphs')
print(f'Validation set = {len(val_dataset)} graphs')
print(f'Test set       = {len(test_dataset)} graphs')

Training set   = 18000 graphs
Validation set = 1000 graphs
Test set       = 1000 graphs


In [8]:
# Create mini-batches
# batch_size is number of graphs
train_loader = DataLoader(train_dataset, batch_size=750, shuffle=True)
val_loader   = DataLoader(val_dataset, batch_size=500, shuffle=False)
test_loader  = DataLoader(test_dataset, batch_size=500, shuffle=False)

In [9]:
print('\nTrain loader:')
for i, batch in enumerate(train_loader):
    print(f' - Batch {i}: {batch}')

print('\nValidation loader:')
for i, batch in enumerate(val_loader):
    print(f' - Batch {i}: {batch}')

print('\nTest loader:')
for i, batch in enumerate(test_loader):
    print(f' - Batch {i}: {batch}')


Train loader:
 - Batch 0: DataBatch(x=[7562, 512], edge_index=[2, 6812], y=[7562], batch=[7562], ptr=[751])
 - Batch 1: DataBatch(x=[7012, 512], edge_index=[2, 6262], y=[7012], batch=[7012], ptr=[751])
 - Batch 2: DataBatch(x=[6329, 512], edge_index=[2, 5579], y=[6329], batch=[6329], ptr=[751])
 - Batch 3: DataBatch(x=[6330, 512], edge_index=[2, 5580], y=[6330], batch=[6330], ptr=[751])
 - Batch 4: DataBatch(x=[9286, 512], edge_index=[2, 8536], y=[9286], batch=[9286], ptr=[751])
 - Batch 5: DataBatch(x=[7190, 512], edge_index=[2, 6440], y=[7190], batch=[7190], ptr=[751])
 - Batch 6: DataBatch(x=[7955, 512], edge_index=[2, 7205], y=[7955], batch=[7955], ptr=[751])
 - Batch 7: DataBatch(x=[11743, 512], edge_index=[2, 10993], y=[11743], batch=[11743], ptr=[751])
 - Batch 8: DataBatch(x=[8107, 512], edge_index=[2, 7357], y=[8107], batch=[8107], ptr=[751])
 - Batch 9: DataBatch(x=[8510, 512], edge_index=[2, 7760], y=[8510], batch=[8510], ptr=[751])
 - Batch 10: DataBatch(x=[13047, 512], ed

In [10]:
# make a dictionary to record label frequency

# get max label used in hidden_conc_pf_data
max_label = 0

for i in range(file_limit*2):
    for j in hidden_conc_pf_data.get(i).y:
        if j > max_label:
            max_label = j.to(int).item()

# initialize histogram for labels used in hidden_conc_pf_data            
label_count = {}

for i in range(file_limit*2):
    for j in range(max_label+1):        
        label_count[j] = 0

for i in range(file_limit*2):
    for j in hidden_conc_pf_data.get(i).y:
        label_count[j.to(int).item()] += 1

step_count = 0
max = 0
max_freq_index = None   #find the most frequently used index
labels_never_used = 0
labels_used_once = 0
labels_used_twice = 0

for k,v in label_count.items():
    step_count += v
    if v > max:     
        max = v
        max_freq_index = k

    if v == 0:
        labels_never_used += 1
    if v == 1:
        labels_used_once += 1
    if v ==2:
        labels_used_twice += 1

In [11]:
print(f"total number of steps is:", step_count)
print(f"highest frequency label is {max_freq_index} and occurs {max} times")
print(f"final label used is {len(label_count)-1}")
print(label_count)
print(len(label_count),"unique labels are used")
print(labels_never_used,"unique labels never used")
print(labels_used_once, "unique labels used once")
print(labels_used_twice, "unique labels used twice")

total number of steps is: 211694
highest frequency label is 0 and occurs 44150 times
final label used is 2276
{0: 44150, 1: 1863, 2: 66, 3: 77, 4: 111, 5: 1285, 6: 18, 7: 127, 8: 28, 9: 357, 10: 55, 11: 4640, 12: 697, 13: 72, 14: 145, 15: 81, 16: 762, 17: 47, 18: 184, 19: 15, 20: 69, 21: 27, 22: 93, 23: 323, 24: 26, 25: 243, 26: 318, 27: 30, 28: 40, 29: 100, 30: 23, 31: 37, 32: 11, 33: 51, 34: 29, 35: 13, 36: 14, 37: 144, 38: 26, 39: 42, 40: 27, 41: 14, 42: 15, 43: 22, 44: 15, 45: 26, 46: 12, 47: 37, 48: 227, 49: 92, 50: 11, 51: 11, 52: 11, 53: 24, 54: 23, 55: 28, 56: 50, 57: 152, 58: 52, 59: 47, 60: 45, 61: 39, 62: 15, 63: 25, 64: 26, 65: 11, 66: 30, 67: 39, 68: 45, 69: 60, 70: 51, 71: 14, 72: 12, 73: 50, 74: 15, 75: 21, 76: 46, 77: 66, 78: 97, 79: 13, 80: 34, 81: 26, 82: 22, 83: 62, 84: 70, 85: 14, 86: 32, 87: 33, 88: 21, 89: 25, 90: 236, 91: 12, 92: 27, 93: 13, 94: 13, 95: 52, 96: 77, 97: 46, 98: 300, 99: 192, 100: 29, 101: 221, 102: 672, 103: 888, 104: 15, 105: 28, 106: 26, 107: 11

In [12]:
# make array of label frequencies for sklearn compute_class_weight using entire dataset
# really should be doing this for train set (otherwise, data leakage...)
# however, train set may not include certain labels, which leads to error in compute_class_weight

# make array of unique classes
#class_num_arr = [i for i in range(len(label_count))]
#class_num_arr = np.array(class_num_arr)

# make array of all data points with labels
#lbl_arr = np.array([])
#for i in range(file_limit*2):
    #for y in hidden_conc_pf_data.get(i).y:
        #lbl_arr = np.append(lbl_arr,[y.numpy()],axis=0).astype(int)

#class_weights = compute_class_weight(class_weight="balanced",classes = class_num_arr, y=lbl_arr)
#class_weights = torch.from_numpy(class_weights).float().to(device)
#class_weights[0:10]

In [53]:
class GIN(torch.nn.Module):
    """GIN"""
    def __init__(self, dim_h, norm_mode='None', norm_scale=10):
        super(GIN, self).__init__()
        self.conv1 = GINConv(
            Sequential(Linear(hidden_conc_pf_data.num_node_features, dim_h),
                       BatchNorm1d(dim_h), ReLU(),
                       Linear(dim_h, dim_h), ReLU()))
        
        self.conv2 = GINConv(
            Sequential(Linear(dim_h, dim_h), BatchNorm1d(dim_h), ReLU(),
                       Linear(dim_h, dim_h), ReLU()))
        
        self.conv3 = GINConv(
            Sequential(Linear(dim_h, dim_h), BatchNorm1d(dim_h), ReLU(),
                       Linear(dim_h, dim_h), ReLU()))
        
        self.lin1 = Linear(dim_h*3, dim_h*3)
        self.lin2 = Linear(dim_h*3, hidden_conc_pf_data.num_classes)

    def forward(self, x, edge_index):
        
        h1 = self.conv1(x, edge_index)
        h2 = self.conv2(h1, edge_index)
        h3 = self.conv3(h2, edge_index)

        # Concatenate graph embeddings
        h = torch.cat((h1, h2, h3), dim=1)
        h = F.dropout(h, p=0.5, training=self.training)

        h = self.lin1(h)
        h = h.relu()
        h = F.dropout(h, p=0.5, training=self.training)

        h = self.lin2(h)        
        return F.log_softmax(h, dim=1)

In [54]:
# GIN model training

def train(model, loader, lr):
    criterion = torch.nn.CrossEntropyLoss()
    # commented out code is to use class weights to account for imbalanced dataset
    #criterion = torch.nn.CrossEntropyLoss(weight=class_weights)
    #optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    optimizer = torch.optim.RMSprop(model.parameters(), lr=lr)
    epochs = 200

    model.train()
    for epoch in range(epochs+1):
        total_loss = 0
        acc = 0
        val_loss = 0
        val_acc = 0
        #cur_graph = 0   # used to keep track of current statement to enforce preds of only PREVIOUS labels in training

        # train on batches
        for data in loader:       
            #cur_graph += torch.max(data.batch) + 1
            data = data.to(device, non_blocking=True)
            data.y = data.y.to(torch.float).to(device, non_blocking=True)
            optimizer.zero_grad()
            length = len(loader)
            out = model(data.x, data.edge_index.long())
            data.y = data.y.type(torch.LongTensor).to(device, non_blocking=True)
            out = out.type(torch.float32).to(device, non_blocking=True)

            # find conclusion/final nodes in each graph to later scale
            #conc_node_indices = []
            #cur_graph_counter = data.batch[0].clone()
            #for idx, graph in enumerate(data.batch.clone()):
                #if graph > cur_graph_counter:
                    #conc_node_indices.append(idx-1)
                    #cur_graph_counter += 1
            #conc_node_indices.append(len(data.batch)-1)              
            #weighted_out = out.clone()
            #weighted_out[conc_node_indices,:] = weighted_out[conc_node_indices,:]*.25    # scale output from final nodes

            #loss = criterion(weighted_out, data.y)
            loss = criterion(out, data.y)

            total_loss += loss / length

            # commented out code below is meant to enforce predictions to only come from previous theorems
            # for dict to be properly created, you must delete and recreate test.pt by rerunning 
            # pf_data = ProofDataset(root="data/",file_limit=file_limit)
            
            #dict = pf_data.class_corr
            #dict_keys = [k for k in dict.keys()]            

            #def return_next_lowest_idx(num):
                #if num in dict_keys:
                    #return dict[num]
                #while num not in dict_keys:
                    #try:
                        #num -= 1
                        #return dict[num]
                    #except:
                        #pass

            #with torch.no_grad():
                #cur_graph_batch.apply_(return_next_lowest_idx)
                #masked_lbls = (torch.arange(out.size(1)) < (cur_graph_batch[..., None]+1))*(out.cpu())
                #try:
                    #masked_lbls = (torch.arange(out.size(1)) < (cur_graph_batch[..., None]+1))*(out.cpu())
                    #masked_lbls = torch.where(masked_lbls==0,np.nan,masked_lbls)
                    #masked_lbls = masked_lbls.detach().numpy()                
                    #pred = np.nanargmax(masked_lbls,axis=1)
                    #pred = torch.from_numpy(pred)
                    #acc += accuracy(pred, data.y.cpu()) / length
                #except Exception as e:
                    #print("a lil error")
                    #out = out + .00000001
                    #masked_lbls = (torch.arange(out.size(1)) < (cur_graph_batch[..., None]+1))*(out.cpu())
                    #masked_lbls = torch.where(masked_lbls==0,np.nan,masked_lbls)
                    #masked_lbls = masked_lbls.detach().numpy()
                    #pred = np.nanargmax(masked_lbls,axis=1)
                    #pred = torch.from_numpy(pred)
                    #acc += accuracy(pred, data.y.cpu()) / length
                    #out = out - .00000001

            #comment out the following pred and acc lines if enforcing predictions as described above
            pred = out.argmax(dim=1)
            #use the following pred line instead if weighting final/conclusion node differently
            #pred = weighted_out.argmax(dim=1)

            acc += accuracy(pred, data.y) / length
            loss.backward()
            optimizer.step()

            # run model on validation set
            val_loss, val_acc, val_f1, top5_acc = test(model, val_loader)

        # Print metrics every epoch
        if(epoch % 10 == 0):
            print(f'Epoch {epoch:>3} | Train Loss: {total_loss:.2f} | Train Acc: {acc*100:>5.2f}% | Val Loss: {val_loss:.2f} | Val Acc: {val_acc*100:.2f}% | Top5 Val Acc: {top5_acc*100:.2f}%| F Score: {val_f1:.2f}')
            
    return model

@torch.no_grad()
def test(model, loader):
    
    criterion = torch.nn.CrossEntropyLoss()
    model.eval()
    loss = 0
    acc = 0
    top5_acc = 0
    fscore = 0    
    
    for data in loader:
        data = data.to(device, non_blocking=True)
        length = len(loader)
        out = model(data.x, data.edge_index.long())
        prob = torch.exp(out)
        prob_sorted = torch.topk(prob,k=5).indices
        data.y = data.y.type(torch.LongTensor).to(device, non_blocking=True)
        loss += criterion(out, data.y) / length
        pred = out.argmax(dim=1)
        acc += accuracy(pred, data.y) / length
        top5_acc += torch.sum(torch.sum(prob_sorted==data.y.unsqueeze(1),dim=1),dim=0) / (length*(data.y.shape[0]))
        fscore += f1(pred.cpu(), data.y.cpu(), average='macro')    # micro looks better, but macro prob more accurate

    return loss, acc, fscore, top5_acc

def accuracy(pred_y, y):
    """Calculate accuracy."""
    return ((pred_y == y).sum() / len(y)).item()

In [55]:
# initialize (and reset weights of) model
# dim_h is hyperparameter of number of hidden layers

gin_trained = None
gin = None
gin = GIN(dim_h=400).to(device)
gin

GIN(
  (conv1): GINConv(nn=Sequential(
    (0): Linear(in_features=512, out_features=400, bias=True)
    (1): BatchNorm1d(400, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Linear(in_features=400, out_features=400, bias=True)
    (4): ReLU()
  ))
  (conv2): GINConv(nn=Sequential(
    (0): Linear(in_features=400, out_features=400, bias=True)
    (1): BatchNorm1d(400, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Linear(in_features=400, out_features=400, bias=True)
    (4): ReLU()
  ))
  (conv3): GINConv(nn=Sequential(
    (0): Linear(in_features=400, out_features=400, bias=True)
    (1): BatchNorm1d(400, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Linear(in_features=400, out_features=400, bias=True)
    (4): ReLU()
  ))
  (lin1): Linear(in_features=1200, out_features=1200, bias=True)
  (lin2): Linear(in_features=1200, out_features=2277, bias=True)
)

In [56]:
# reset weights and train model
gin_trained = None
gin_trained = train(gin, train_loader,lr=.001)

Epoch   0 | Train Loss: 7.00 | Train Acc: 16.14% | Val Loss: 5.78 | Val Acc: 17.41% | Top5 Val Acc: 30.02%| F Score: 0.00
Epoch  10 | Train Loss: 3.46 | Train Acc: 36.80% | Val Loss: 4.02 | Val Acc: 26.29% | Top5 Val Acc: 44.22%| F Score: 0.01
Epoch  20 | Train Loss: 2.56 | Train Acc: 44.15% | Val Loss: 3.13 | Val Acc: 31.80% | Top5 Val Acc: 55.83%| F Score: 0.09
Epoch  30 | Train Loss: 2.17 | Train Acc: 49.55% | Val Loss: 2.65 | Val Acc: 38.58% | Top5 Val Acc: 65.62%| F Score: 0.21
Epoch  40 | Train Loss: 1.67 | Train Acc: 57.84% | Val Loss: 2.34 | Val Acc: 44.73% | Top5 Val Acc: 71.74%| F Score: 0.39
Epoch  50 | Train Loss: 1.42 | Train Acc: 63.07% | Val Loss: 2.32 | Val Acc: 45.10% | Top5 Val Acc: 72.93%| F Score: 0.41
Epoch  60 | Train Loss: 1.34 | Train Acc: 65.02% | Val Loss: 2.51 | Val Acc: 43.74% | Top5 Val Acc: 70.73%| F Score: 0.40
Epoch  70 | Train Loss: 1.14 | Train Acc: 69.26% | Val Loss: 2.27 | Val Acc: 48.69% | Top5 Val Acc: 75.89%| F Score: 0.52
Epoch  80 | Train Loss: 

In [57]:
# get accuracy only for conclusion node from val set

total_conc_labels = len(val_dataset)
correct_conc_pred = 0

# dict of the form true_label:(predicted_labels)
incorrect_preds = {}

for graph in val_dataset:
    get_gin_predict = gin_trained(graph.x.to(device),graph.edge_index.long().to(device))
    get_conc_predict = np.argmax(get_gin_predict[-1].detach().cpu().numpy(),axis=0)
    if get_conc_predict == graph.y[-1].item():
        correct_conc_pred += 1
    if get_conc_predict != graph.y[-1].item():
        if graph.y[-1].item() in incorrect_preds:
            incorrect_preds[graph.y[-1].item()] = incorrect_preds[graph.y[-1].item()]+(get_conc_predict,)
        else:
            incorrect_preds[graph.y[-1].item()] = (get_conc_predict,)


def get_conc_gin_label_acc(i):
    correct_pred_i = 0
    total_num_i = 0

    for graph in val_dataset:
        get_gin_predict = gin_trained(graph.x.to(device),graph.edge_index.long().to(device))
        get_conc_predict = np.argmax(get_gin_predict[-1].detach().cpu().numpy(),axis=0)
        if graph.y[-1].item() == i:
            total_num_i += 1
            if get_conc_predict == i:
                correct_pred_i += 1
    if total_num_i == 0:
        return 0,0,0
    return (correct_pred_i / total_num_i), total_num_i, correct_pred_i

def get_topk_acc(k):

    top3_corr = 0

    for graph in val_dataset:
        get_gin_predict = gin_trained(graph.x.to(device),graph.edge_index.long().to(device))
        prob = torch.exp(get_gin_predict[-1])
        prob_sorted = torch.topk(prob,k=k).indices
        graph.y = graph.y.to(device)
        top3_corr += torch.sum(prob_sorted==graph.y[-1])

    return top3_corr / len(val_dataset)

k=5
print(f"restricting to conclusion nodes, our val_graph has accuracy being {correct_conc_pred/total_conc_labels*100:.2f}%")
print(f"'unk' accuracy should be: {get_conc_gin_label_acc(2276)[0]*100:.2f}%")
print(f"top{k} accuracy is: {get_topk_acc(k).item()*100:.2f}%")

restricting to conclusion nodes, our val_graph has accuracy being 41.40%
'unk' accuracy should be: 68.55%
top5 accuracy is: 69.60%


In [20]:
# make a dictionary to record label frequency based on val set
# only consider conclusion/final nodes

# get max label used in pf_data
max_label = 0

for i in val_indices:
    if hidden_conc_pf_data.get(i).y[-1] > max_label:
        max_label = hidden_conc_pf_data.get(i).y[-1].to(int).item()

# initialize histogram for labels used in pf_data            
label_count = {}

for i in val_indices:
    for j in range(max_label+1):        
        label_count[j] = 0

for i in val_indices:
    label_count[hidden_conc_pf_data.get(i).y[-1].to(int).item()] += 1

step_count = 0
max = 0
max_freq_index = None   # find the most frequently used index
labels_never_used = 0
labels_used_once = 0
labels_used_twice = 0


for k,v in label_count.items():
    step_count += v
    if v > max:     
        max = v
        max_freq_index = k

    if v == 0:
        labels_never_used += 1
    if v == 1:
        labels_used_once += 1
    if v ==2:
        labels_used_twice += 1

In [21]:
print(f"total number of steps is:", step_count)
print(f"highest frequency label is {max_freq_index} and occurs {max} times")
print(f"final label used is {len(label_count)-1}")
print(label_count)
print(len(label_count),"unique labels are used")
print(labels_never_used,"unique labels never used")
print(labels_used_once, "unique labels used once")
print(labels_used_twice, "unique labels used twice")

total number of steps is: 1000
highest frequency label is 2276 and occurs 124 times
final label used is 2276
{0: 2, 1: 31, 2: 0, 3: 2, 4: 1, 5: 5, 6: 0, 7: 0, 8: 0, 9: 7, 10: 0, 11: 34, 12: 1, 13: 0, 14: 4, 15: 0, 16: 0, 17: 0, 18: 0, 19: 0, 20: 0, 21: 0, 22: 2, 23: 3, 24: 0, 25: 4, 26: 5, 27: 0, 28: 0, 29: 2, 30: 0, 31: 0, 32: 0, 33: 0, 34: 0, 35: 0, 36: 1, 37: 1, 38: 0, 39: 1, 40: 0, 41: 0, 42: 0, 43: 0, 44: 0, 45: 4, 46: 0, 47: 0, 48: 2, 49: 0, 50: 0, 51: 0, 52: 0, 53: 0, 54: 0, 55: 0, 56: 0, 57: 1, 58: 0, 59: 0, 60: 0, 61: 0, 62: 0, 63: 0, 64: 0, 65: 0, 66: 0, 67: 0, 68: 0, 69: 1, 70: 0, 71: 0, 72: 0, 73: 0, 74: 1, 75: 1, 76: 2, 77: 1, 78: 0, 79: 0, 80: 2, 81: 0, 82: 0, 83: 0, 84: 1, 85: 0, 86: 0, 87: 1, 88: 0, 89: 1, 90: 17, 91: 2, 92: 1, 93: 1, 94: 0, 95: 0, 96: 1, 97: 0, 98: 25, 99: 14, 100: 0, 101: 3, 102: 12, 103: 6, 104: 0, 105: 0, 106: 0, 107: 1, 108: 6, 109: 8, 110: 5, 111: 0, 112: 2, 113: 8, 114: 17, 115: 5, 116: 1, 117: 17, 118: 1, 119: 0, 120: 0, 121: 1, 122: 1, 123: 1, 

In [None]:
# make acc dict for val conclusion nodes

conc_label_acc_dict = {}

def get_conc_gin_label_acc(i):
    correct_pred_i = 0
    total_num_i = 0

    for graph in val_dataset:
        get_gin_predict = gin_trained(graph.x.to(device),graph.edge_index.long().to(device))
        get_conc_predict = np.argmax(get_gin_predict[-1].detach().cpu().numpy(),axis=0)
        if graph.y[-1].item() == i:
            total_num_i += 1
            if get_conc_predict == i:
                correct_pred_i += 1
    if total_num_i == 0:
        return 0,0,0
    return (correct_pred_i / total_num_i), total_num_i, correct_pred_i


total_conc_labels = 0
total_conc_correct_preds = 0
for i in range(hidden_conc_pf_data.num_classes):
    if i % 200 ==0:
        print("on label",i)
    try:
        container = get_conc_gin_label_acc(i)
        if container[1] != 0:
            conc_label_acc_dict[i] = (container[0],container[1])  # pair of accuracy and label count for each label
        total_conc_labels += container[1]
        total_conc_correct_preds += container[2]
    except Exception as e:
        print(e,get_conc_gin_label_acc(i))

print(conc_label_acc_dict)
print("if everything here is correct, our conc node val_acc should be:",total_conc_correct_preds/total_conc_labels)

In [None]:
# sort the above dict by accuracy

sorted_conc_label_acc_dict = {k: v for k, v in sorted(conc_label_acc_dict.items(), key=lambda item: item[1])}
sorted_conc_label_acc_dict