In [10]:
from use_dataset import ProofDataset, HiddenConcProofDataset
import torch
import numpy as np
from torch_geometric.data import Batch, Data, Dataset
from torch_geometric.loader import DataLoader, NeighborLoader
import random
from sklearn.metrics import f1_score as f1
from sklearn.utils.class_weight import compute_class_weight
import torch.nn.functional as F
from torch.nn import Linear, ReLU, Dropout, BatchNorm1d, Sequential
from torch_geometric.nn import GCNConv, global_mean_pool, GINConv, global_add_pool, PairNorm

torch.manual_seed(0)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [11]:
# use file_limit=5000 to only load and verify the first 5000 graphs (~60 MB)

file_limit = 10000    # desired number of graphs to work with
vocab_size = 1598   # number of characters in our vocabulary

pf_data = ProofDataset(root="data/",read_name="10000_relabeled_data_at_least_5.json" , write_name="10000_relabeled_data_at_least_5_w_stmts.pt" ,file_limit=file_limit)  

In [12]:
# make train/val/test for GIN

# set seed for random # generation
random.seed(10)
length = file_limit
total_indices = [i+file_limit for i in range(file_limit)]

# create index vectors to filter dataset
train_indices = [i for i in range(file_limit)]
train_indices = train_indices + random.sample(total_indices, int(length*.8))
train_indices.sort()

val_index_options = [x for x in total_indices if x not in train_indices]
val_indices = random.sample(val_index_options, int(length*.1))
val_indices.sort()

test_index_options = [x for x in total_indices if x not in train_indices if x not in val_indices]
test_indices = random.sample(test_index_options, int(length*.1))
test_indices.sort()


In [14]:
# our model is judged on how well it can predict the label for the final/conclusion node as seen in the val set created above
# to not give our model too much info, we zero out the features on the final/conclusion node for each graph in the val set

data_list = []

for idx, graph in enumerate(pf_data):
    # inherit features, labels, and edges
    x = graph.x.clone()
    y = graph.y.clone()
    edge_index = graph.edge_index.clone()
    x[-1] = torch.zeros(512) # zero out ALL conclusion nodes

    #if idx in val_indices:
        #x[-1] = torch.zeros(512)    # zero out features of final/conclusion node for each graph in val set

# replace features of conlusion/final nodes from val set with average of neighboring node embeddings
    #if idx in val_indices:

        #connected_nodes = []
        #sum = torch.zeros(512)

        #for i, edge in enumerate(edge_index[0]): # case of outgoing edge coming from conc node
            #if edge == (pf_data[idx].num_nodes-1):  
                #connected_nodes.append(int(edge_index[1][i].item()))
        #for i, edge in enumerate(edge_index[1]): # case of outgoing edge going to conc node
            #if edge == (pf_data[idx].num_nodes-1):
                #connected_nodes.append(int(edge_index[0][i].item()))

        #for i in connected_nodes:
            #if not torch.equal(pf_data[idx].x[i], pf_data[idx].x[-1]):
            #sum += pf_data[idx].x[i]
        #if len(connected_nodes) > 0:
            #x[-1] = sum/(len(connected_nodes))
        #else:
            #x[-1] = sum


    data_list.append(Data(x=x,y=y,edge_index=edge_index))

hidden_conc_pf_data = HiddenConcProofDataset(root="data/",read_name="5000_relabeled_data.json",write_name="overwritten_labels", data_list=data_list)

In [15]:
# Create training, validation, and test sets
train_dataset = hidden_conc_pf_data[train_indices]
val_dataset = hidden_conc_pf_data[val_indices]
test_dataset = hidden_conc_pf_data[test_indices]

print(f'Training set   = {len(train_dataset)} graphs')
print(f'Validation set = {len(val_dataset)} graphs')
print(f'Test set       = {len(test_dataset)} graphs')

Training set   = 9000 graphs
Validation set = 500 graphs
Test set       = 500 graphs


In [16]:
# Create mini-batches
# batch_size is number of graphs
train_loader = DataLoader(train_dataset, batch_size=1000, shuffle=True)
val_loader   = DataLoader(val_dataset, batch_size=500, shuffle=False)
test_loader  = DataLoader(test_dataset, batch_size=500, shuffle=False)

In [17]:
print('\nTrain loader:')
for i, batch in enumerate(train_loader):
    print(f' - Batch {i}: {batch}')

print('\nValidation loader:')
for i, batch in enumerate(val_loader):
    print(f' - Batch {i}: {batch}')

print('\nTest loader:')
for i, batch in enumerate(test_loader):
    print(f' - Batch {i}: {batch}')


Train loader:
 - Batch 0: DataBatch(x=[3984, 512], edge_index=[2, 2984], y=[3984], batch=[3984], ptr=[1001])
 - Batch 1: DataBatch(x=[3876, 512], edge_index=[2, 2876], y=[3876], batch=[3876], ptr=[1001])
 - Batch 2: DataBatch(x=[4065, 512], edge_index=[2, 3065], y=[4065], batch=[4065], ptr=[1001])
 - Batch 3: DataBatch(x=[3737, 512], edge_index=[2, 2737], y=[3737], batch=[3737], ptr=[1001])
 - Batch 4: DataBatch(x=[3754, 512], edge_index=[2, 2754], y=[3754], batch=[3754], ptr=[1001])
 - Batch 5: DataBatch(x=[4068, 512], edge_index=[2, 3068], y=[4068], batch=[4068], ptr=[1001])
 - Batch 6: DataBatch(x=[3743, 512], edge_index=[2, 2743], y=[3743], batch=[3743], ptr=[1001])
 - Batch 7: DataBatch(x=[3860, 512], edge_index=[2, 2860], y=[3860], batch=[3860], ptr=[1001])
 - Batch 8: DataBatch(x=[3603, 512], edge_index=[2, 2603], y=[3603], batch=[3603], ptr=[1001])

Validation loader:
 - Batch 0: DataBatch(x=[3207, 512], edge_index=[2, 2707], y=[3207], batch=[3207], ptr=[501])

Test loader:
 -

In [18]:
# make a dictionary to record label frequency

# get max label used in hidden_conc_pf_data
max_label = 0

for i in range(file_limit*2):
    for j in hidden_conc_pf_data.get(i).y:
        if j > max_label:
            max_label = j.to(int).item()

# initialize histogram for labels used in hidden_conc_pf_data            
label_count = {}

for i in range(file_limit*2):
    for j in range(max_label+1):        
        label_count[j] = 0

for i in range(file_limit*2):
    for j in hidden_conc_pf_data.get(i).y:
        label_count[j.to(int).item()] += 1

step_count = 0
max = 0
max_freq_index = None   #find the most frequently used index
labels_never_used = 0
labels_used_once = 0
labels_used_twice = 0


for k,v in label_count.items():
    step_count += v
    if v > max:     
        max = v
        max_freq_index = k

    if v == 0:
        labels_never_used += 1
    if v == 1:
        labels_used_once += 1
    if v ==2:
        labels_used_twice += 1

In [19]:
print(f"total number of steps is:", step_count)
print(f"highest frequency label is {max_freq_index} and occurs {max} times")
print(f"final label used is {len(label_count)-1}")
print(label_count)
print(len(label_count),"unique labels are used")
print(labels_never_used,"unique labels never used")
print(labels_used_once, "unique labels used once")
print(labels_used_twice, "unique labels used twice")

total number of steps is: 40934
highest frequency label is 554 and occurs 12143 times
final label used is 554
{0: 10280, 1: 545, 2: 52, 3: 24, 4: 270, 5: 14, 6: 40, 7: 42, 8: 390, 9: 26, 10: 28, 11: 245, 12: 23, 13: 57, 14: 12, 15: 21, 16: 108, 17: 47, 18: 75, 19: 14, 20: 18, 21: 26, 22: 17, 23: 13, 24: 18, 25: 17, 26: 12, 27: 15, 28: 50, 29: 18, 30: 13, 31: 14, 32: 12, 33: 42, 34: 12, 35: 20, 36: 13, 37: 14, 38: 18, 39: 22, 40: 29, 41: 83, 42: 25, 43: 12, 44: 127, 45: 60, 46: 27, 47: 83, 48: 130, 49: 100, 50: 24, 51: 24, 52: 54, 53: 115, 54: 29, 55: 25, 56: 38, 57: 59, 58: 78, 59: 59, 60: 23, 61: 29, 62: 125, 63: 41, 64: 44, 65: 11, 66: 17, 67: 12, 68: 32, 69: 15, 70: 19, 71: 21, 72: 25, 73: 19, 74: 20, 75: 33, 76: 52, 77: 25, 78: 13, 79: 13, 80: 26, 81: 16, 82: 17, 83: 474, 84: 39, 85: 74, 86: 172, 87: 73, 88: 20, 89: 20, 90: 67, 91: 21, 92: 61, 93: 24, 94: 21, 95: 28, 96: 45, 97: 107, 98: 15, 99: 16, 100: 39, 101: 242, 102: 37, 103: 12, 104: 113, 105: 24, 106: 11, 107: 60, 108: 14, 

In [20]:
# make array of label frequencies for sklearn compute_class_weight using entire dataset
# really should be doing this for train set (otherwise, data leakage...)
# however, train set may not include certain labels, which leads to error in compute_class_weight

# make array of unique classes
class_num_arr = [i for i in range(len(label_count))]
class_num_arr = np.array(class_num_arr)

# make array of all data points with labels
lbl_arr = np.array([])
for i in range(file_limit*2):
    for y in hidden_conc_pf_data.get(i).y:
        lbl_arr = np.append(lbl_arr,[y.numpy()],axis=0).astype(int)

class_weights = compute_class_weight(class_weight="balanced",classes = class_num_arr, y=lbl_arr)
class_weights = torch.from_numpy(class_weights).float().to(device)
class_weights[0:10]

tensor([0.0072, 0.1353, 1.4184, 3.0731, 0.2732, 5.2682, 1.8439, 1.7561, 0.1891,
        2.8367], device='cuda:0')

In [31]:
class GIN(torch.nn.Module):
    """GIN"""
    def __init__(self, dim_h, norm_mode='None', norm_scale=10):
        super(GIN, self).__init__()
        self.conv1 = GINConv(
            Sequential(Linear(hidden_conc_pf_data.num_node_features, dim_h),
                       BatchNorm1d(dim_h), ReLU(),
                       Linear(dim_h, dim_h), ReLU()))
        
        self.conv2 = GINConv(
            Sequential(Linear(dim_h, dim_h), BatchNorm1d(dim_h), ReLU(),
                       Linear(dim_h, dim_h), ReLU()))
        
        self.conv3 = GINConv(
            Sequential(Linear(dim_h, dim_h), BatchNorm1d(dim_h), ReLU(),
                       Linear(dim_h, dim_h), ReLU()))
        
        self.lin1 = Linear(dim_h*3, dim_h*3)
        self.lin2 = Linear(dim_h*3, hidden_conc_pf_data.num_classes)

    def forward(self, x, edge_index):
        h1 = self.conv1(x, edge_index)
        h2 = self.conv2(h1, edge_index)
        h3 = self.conv3(h2, edge_index)

        # Concatenate graph embeddings
        h = torch.cat((h1, h2, h3), dim=1)
        h = F.dropout(h, p=0.5, training=self.training)

        h = self.lin1(h)
        h = h.relu()
        h = F.dropout(h, p=0.5, training=self.training)

        h = self.lin2(h)        
        return F.log_softmax(h, dim=1)

In [124]:
# GIN model training

def train(model, loader, lr):
    criterion = torch.nn.CrossEntropyLoss()
    # commented out code is to use class weights to account for imbalanced dataset
    #criterion = torch.nn.CrossEntropyLoss(weight=class_weights)
    #optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    optimizer = torch.optim.RMSprop(model.parameters(), lr=lr)
    epochs = 1000

    model.train()
    for epoch in range(epochs+1):
        total_loss = 0
        acc = 0
        val_loss = 0
        val_acc = 0
        #cur_graph = 0   # used to keep track of current statement to enforce preds of only PREVIOUS labels in training

        # Train on batches
        for data in loader:       
            #cur_graph += torch.max(data.batch) + 1
            data = data.to(device, non_blocking=True)
            data.y = data.y.to(torch.float).to(device, non_blocking=True)
            optimizer.zero_grad()
            length = len(loader)
            out = model(data.x, data.edge_index.long())
            data.y = data.y.type(torch.LongTensor).to(device, non_blocking=True)
            out = out.type(torch.float32).to(device, non_blocking=True)
            loss = criterion(out, data.y)
            total_loss += loss / length

            # commented out code below is meant to enforce predictions to only come from previous theorems
            # for dict to be properly created, you must delete and recreate test.pt by rerunning 
            # pf_data = ProofDataset(root="data/",file_limit=file_limit)
            
            #dict = pf_data.class_corr
            #dict_keys = [k for k in dict.keys()]            

            #def return_next_lowest_idx(num):
                #if num in dict_keys:
                    #return dict[num]
                #while num not in dict_keys:
                    #try:
                        #num -= 1
                        #return dict[num]
                    #except:
                        #pass

            #with torch.no_grad():
                #cur_graph_batch.apply_(return_next_lowest_idx)
                #masked_lbls = (torch.arange(out.size(1)) < (cur_graph_batch[..., None]+1))*(out.cpu())
                #try:
                    #masked_lbls = (torch.arange(out.size(1)) < (cur_graph_batch[..., None]+1))*(out.cpu())
                    #masked_lbls = torch.where(masked_lbls==0,np.nan,masked_lbls)
                    #masked_lbls = masked_lbls.detach().numpy()                
                    #pred = np.nanargmax(masked_lbls,axis=1)
                    #pred = torch.from_numpy(pred)
                    #acc += accuracy(pred, data.y.cpu()) / length
                #except Exception as e:
                    #print("a lil error")
                    #out = out + .00000001
                    #masked_lbls = (torch.arange(out.size(1)) < (cur_graph_batch[..., None]+1))*(out.cpu())
                    #masked_lbls = torch.where(masked_lbls==0,np.nan,masked_lbls)
                    #masked_lbls = masked_lbls.detach().numpy()
                    #pred = np.nanargmax(masked_lbls,axis=1)
                    #pred = torch.from_numpy(pred)
                    #acc += accuracy(pred, data.y.cpu()) / length
                    #out = out - .00000001

            #comment out the follow pred and acc lines if enforcing predictions as described above
            pred = out.argmax(dim=1)
            acc += accuracy(pred, data.y) / length
            loss.backward()
            optimizer.step()

            # run model on validation set
            val_loss, val_acc, val_f1, top3_acc = test(model, val_loader)

        # Print metrics every epoch
        if(epoch % 10 == 0):
            print(f'Epoch {epoch:>3} | Train Loss: {total_loss:.2f} | Train Acc: {acc*100:>5.2f}% | Val Loss: {val_loss:.2f} | Val Acc: {val_acc*100:.2f}% | Top3 Val Acc: {top3_acc*100:.2f}%| F Score: {val_f1:.2f}')
            
    return model

@torch.no_grad()
def test(model, loader):
    
    criterion = torch.nn.CrossEntropyLoss()
    model.eval()
    loss = 0
    acc = 0
    top3_acc = 0
    fscore = 0
    
    
    for data in loader:
        data = data.to(device, non_blocking=True)
        length = len(loader)
        out = model(data.x, data.edge_index.long())
        prob = torch.exp(out)
        prob_sorted = torch.topk(prob,k=3).indices
        data.y = data.y.type(torch.LongTensor).to(device, non_blocking=True)
        loss += criterion(out, data.y) / length
        pred = out.argmax(dim=1)
        acc += accuracy(pred, data.y) / length
        top3_acc += torch.sum(torch.sum(prob_sorted==data.y.unsqueeze(1),dim=1),dim=0) / (length*(data.y.shape[0]))
        fscore += f1(pred.cpu(), data.y.cpu(), average='macro')    # micro looks better, but macro prob more accurate

    return loss, acc, fscore, top3_acc

def accuracy(pred_y, y):
    """Calculate accuracy."""
    return ((pred_y == y).sum() / len(y)).item()

In [125]:
# initialize (and reset weights of) model
# dim_h is hyperparameter of number of hidden layers

gin_trained = None
gin = None
gin = GIN(dim_h=800).to(device)
gin

GIN(
  (conv1): GINConv(nn=Sequential(
    (0): Linear(in_features=512, out_features=800, bias=True)
    (1): BatchNorm1d(800, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Linear(in_features=800, out_features=800, bias=True)
    (4): ReLU()
  ))
  (conv2): GINConv(nn=Sequential(
    (0): Linear(in_features=800, out_features=800, bias=True)
    (1): BatchNorm1d(800, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Linear(in_features=800, out_features=800, bias=True)
    (4): ReLU()
  ))
  (conv3): GINConv(nn=Sequential(
    (0): Linear(in_features=800, out_features=800, bias=True)
    (1): BatchNorm1d(800, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Linear(in_features=800, out_features=800, bias=True)
    (4): ReLU()
  ))
  (conv4): GINConv(nn=Sequential(
    (0): Linear(in_features=800, out_features=800, bias=True)
    (1): BatchNorm1d(800, eps=1e-05, momentu

In [126]:
# reset weights and train model
gin_trained = None
gin_trained = train(gin, train_loader,lr=.001)

Epoch   0 | Train Loss: 1093.11 | Train Acc: 24.38% | Val Loss: 4.67 | Val Acc: 17.15% | Top3 Val Acc: 0.4268786907196045%| F Score: 0.00
Epoch  10 | Train Loss: 2.60 | Train Acc: 56.86% | Val Loss: 3.55 | Val Acc: 41.00% | Top3 Val Acc: 0.4855004549026489%| F Score: 0.00
Epoch  20 | Train Loss: 2.27 | Train Acc: 57.74% | Val Loss: 3.09 | Val Acc: 41.69% | Top3 Val Acc: 0.5132522583007812%| F Score: 0.01
Epoch  30 | Train Loss: 1.99 | Train Acc: 58.82% | Val Loss: 2.83 | Val Acc: 42.81% | Top3 Val Acc: 0.547552227973938%| F Score: 0.01
Epoch  40 | Train Loss: 1.91 | Train Acc: 58.80% | Val Loss: 2.72 | Val Acc: 42.56% | Top3 Val Acc: 0.5768631100654602%| F Score: 0.02
Epoch  50 | Train Loss: 1.65 | Train Acc: 60.59% | Val Loss: 2.49 | Val Acc: 44.56% | Top3 Val Acc: 0.6092921495437622%| F Score: 0.05
Epoch  60 | Train Loss: 1.55 | Train Acc: 61.57% | Val Loss: 2.31 | Val Acc: 45.84% | Top3 Val Acc: 0.6345493793487549%| F Score: 0.09
Epoch  70 | Train Loss: 1.43 | Train Acc: 62.92% | Va

In [152]:
# get accuracy only for conclusion node from val set

total_conc_labels = len(val_dataset)
correct_conc_pred = 0

# dict of the form true_label:(predicted_labels)
incorrect_preds = {}

for graph in val_dataset:
    get_gin_predict = gin_trained(graph.x.to(device),graph.edge_index.long().to(device))
    get_conc_predict = np.argmax(get_gin_predict[-1].detach().cpu().numpy(),axis=0)
    if get_conc_predict == graph.y[-1].item():
        correct_conc_pred += 1
    if get_conc_predict != graph.y[-1].item():
        if graph.y[-1].item() in incorrect_preds:
            incorrect_preds[graph.y[-1].item()] = incorrect_preds[graph.y[-1].item()]+(get_conc_predict,)
        else:
            incorrect_preds[graph.y[-1].item()] = (get_conc_predict,)


def get_conc_gin_label_acc(i):
    correct_pred_i = 0
    total_num_i = 0

    for graph in val_dataset:
        get_gin_predict = gin_trained(graph.x.to(device),graph.edge_index.long().to(device))
        get_conc_predict = np.argmax(get_gin_predict[-1].detach().cpu().numpy(),axis=0)
        if graph.y[-1].item() == i:
            total_num_i += 1
            if get_conc_predict == i:
                correct_pred_i += 1
    if total_num_i == 0:
        return 0,0,0
    return (correct_pred_i / total_num_i), total_num_i, correct_pred_i

def get_topk_acc(k=3):

    top3_corr = 0

    for graph in val_dataset:
        get_gin_predict = gin_trained(graph.x.to(device),graph.edge_index.long().to(device))
        prob = torch.exp(get_gin_predict[-1])
        prob_sorted = torch.topk(prob,k=k).indices
        graph.y = graph.y.to(device)
        top3_corr += torch.sum(prob_sorted==graph.y[-1])

    return top3_corr / len(val_dataset)

k=10
print(f"restricting to conclusion nodes, our val_graph has accuracy being {correct_conc_pred/total_conc_labels*100:.2f}")
print(f"unk accuracy should be: {get_conc_gin_label_acc(554)[0]*100:.2f}")
print(f"top{k} accuracy is: {get_topk_acc(10).item()*100:.2f}%")

restricting to conclusion nodes, our val_graph has accuracy being 53.80
unk accuracy should be: 83.48
top10 accuracy is: 85.20%


In [22]:
# make a dictionary to record label frequency based on val set
# only consider conclusion/final nodes

# get max label used in pf_data
max_label = 0

for i in val_indices:
    if hidden_conc_pf_data.get(i).y[-1] > max_label:
        max_label = hidden_conc_pf_data.get(i).y[-1].to(int).item()

# initialize histogram for labels used in pf_data            
label_count = {}

for i in val_indices:
    for j in range(max_label+1):        
        label_count[j] = 0

for i in val_indices:
    label_count[hidden_conc_pf_data.get(i).y[-1].to(int).item()] += 1

step_count = 0
max = 0
max_freq_index = None   # find the most frequently used index
labels_never_used = 0
labels_used_once = 0
labels_used_twice = 0


for k,v in label_count.items():
    step_count += v
    if v > max:     
        max = v
        max_freq_index = k

    if v == 0:
        labels_never_used += 1
    if v == 1:
        labels_used_once += 1
    if v ==2:
        labels_used_twice += 1

In [23]:
print(f"total number of steps is:", step_count)
print(f"highest frequency label is {max_freq_index} and occurs {max} times")
print(f"final label used is {len(label_count)-1}")
print(label_count)
print(len(label_count),"unique labels are used")
print(labels_never_used,"unique labels never used")
print(labels_used_once, "unique labels used once")
print(labels_used_twice, "unique labels used twice")

total number of steps is: 500
highest frequency label is 554 and occurs 115 times
final label used is 554
{0: 2, 1: 22, 2: 0, 3: 1, 4: 3, 5: 1, 6: 2, 7: 0, 8: 19, 9: 0, 10: 2, 11: 1, 12: 0, 13: 0, 14: 0, 15: 0, 16: 4, 17: 1, 18: 2, 19: 0, 20: 0, 21: 0, 22: 1, 23: 1, 24: 0, 25: 3, 26: 1, 27: 0, 28: 0, 29: 0, 30: 1, 31: 1, 32: 0, 33: 0, 34: 0, 35: 2, 36: 0, 37: 0, 38: 0, 39: 0, 40: 1, 41: 10, 42: 1, 43: 0, 44: 15, 45: 7, 46: 0, 47: 1, 48: 4, 49: 3, 50: 0, 51: 0, 52: 2, 53: 6, 54: 2, 55: 3, 56: 0, 57: 0, 58: 7, 59: 2, 60: 0, 61: 0, 62: 8, 63: 0, 64: 2, 65: 0, 66: 1, 67: 0, 68: 0, 69: 0, 70: 0, 71: 1, 72: 1, 73: 0, 74: 3, 75: 3, 76: 0, 77: 0, 78: 0, 79: 0, 80: 2, 81: 2, 82: 0, 83: 28, 84: 2, 85: 1, 86: 7, 87: 7, 88: 1, 89: 1, 90: 1, 91: 1, 92: 1, 93: 0, 94: 1, 95: 0, 96: 0, 97: 6, 98: 2, 99: 0, 100: 1, 101: 18, 102: 0, 103: 0, 104: 5, 105: 0, 106: 1, 107: 0, 108: 0, 109: 0, 110: 2, 111: 2, 112: 0, 113: 0, 114: 0, 115: 0, 116: 0, 117: 0, 118: 0, 119: 1, 120: 0, 121: 0, 122: 0, 123: 1, 124: 

In [24]:
# do the same as above, but now only for val conclusion nodes

conc_label_acc_dict = {}

def get_conc_gin_label_acc(i):
    correct_pred_i = 0
    total_num_i = 0

    for graph in val_dataset:
        get_gin_predict = gin_trained(graph.x.to(device),graph.edge_index.long().to(device))
        get_conc_predict = np.argmax(get_gin_predict[-1].detach().cpu().numpy(),axis=0)
        if graph.y[-1].item() == i:
            total_num_i += 1
            if get_conc_predict == i:
                correct_pred_i += 1
    if total_num_i == 0:
        return 0,0,0
    return (correct_pred_i / total_num_i), total_num_i, correct_pred_i


total_conc_labels = 0
total_conc_correct_preds = 0
for i in range(hidden_conc_pf_data.num_classes):
    if i % 200 ==0:
        print("on label",i)
    try:
        container = get_conc_gin_label_acc(i)
        if container[1] != 0:
            conc_label_acc_dict[i] = (container[0],container[1])  # pair of accuracy and label count for each label
        total_conc_labels += container[1]
        total_conc_correct_preds += container[2]
    except Exception as e:
        print(e,get_conc_gin_label_acc(i))

print(conc_label_acc_dict)
print("if everything here is correct, our conc node val_acc should be:",total_conc_correct_preds/total_conc_labels)

on label 0
on label 200
on label 400
{0: (0.0, 2), 1: (0.7272727272727273, 22), 3: (0.0, 1), 4: (0.0, 3), 5: (0.0, 1), 6: (0.5, 2), 8: (0.5263157894736842, 19), 10: (0.5, 2), 11: (0.0, 1), 16: (0.25, 4), 17: (0.0, 1), 18: (0.5, 2), 22: (0.0, 1), 23: (0.0, 1), 25: (0.3333333333333333, 3), 26: (0.0, 1), 30: (0.0, 1), 31: (0.0, 1), 35: (0.0, 2), 40: (0.0, 1), 41: (0.2, 10), 42: (0.0, 1), 44: (0.2, 15), 45: (0.0, 7), 47: (0.0, 1), 48: (0.25, 4), 49: (0.0, 3), 52: (0.0, 2), 53: (0.0, 6), 54: (0.0, 2), 55: (0.3333333333333333, 3), 58: (0.42857142857142855, 7), 59: (0.0, 2), 62: (0.125, 8), 64: (0.0, 2), 66: (0.0, 1), 71: (0.0, 1), 72: (1.0, 1), 74: (0.0, 3), 75: (0.6666666666666666, 3), 80: (0.0, 2), 81: (1.0, 2), 83: (0.5357142857142857, 28), 84: (0.0, 2), 85: (0.0, 1), 86: (0.42857142857142855, 7), 87: (0.42857142857142855, 7), 88: (0.0, 1), 89: (0.0, 1), 90: (1.0, 1), 91: (0.0, 1), 92: (0.0, 1), 94: (0.0, 1), 97: (0.0, 6), 98: (0.0, 2), 100: (0.0, 1), 101: (0.6111111111111112, 18), 104: (

In [25]:
for k,v in conc_label_acc_dict.items():
    print(k,v[0])

0 0.0
1 0.7272727272727273
3 0.0
4 0.0
5 0.0
6 0.5
8 0.5263157894736842
10 0.5
11 0.0
16 0.25
17 0.0
18 0.5
22 0.0
23 0.0
25 0.3333333333333333
26 0.0
30 0.0
31 0.0
35 0.0
40 0.0
41 0.2
42 0.0
44 0.2
45 0.0
47 0.0
48 0.25
49 0.0
52 0.0
53 0.0
54 0.0
55 0.3333333333333333
58 0.42857142857142855
59 0.0
62 0.125
64 0.0
66 0.0
71 0.0
72 1.0
74 0.0
75 0.6666666666666666
80 0.0
81 1.0
83 0.5357142857142857
84 0.0
85 0.0
86 0.42857142857142855
87 0.42857142857142855
88 0.0
89 0.0
90 1.0
91 0.0
92 0.0
94 0.0
97 0.0
98 0.0
100 0.0
101 0.6111111111111112
104 0.6
106 1.0
110 0.0
111 0.0
119 0.0
123 0.0
126 0.0
127 0.0
130 1.0
134 0.4
135 0.0
136 0.0
137 0.0
140 0.0
142 0.0
144 0.0
151 0.0
157 0.0
158 0.0
159 0.0
167 1.0
169 0.0
170 0.0
171 0.3333333333333333
172 1.0
174 1.0
184 0.0
188 0.0
189 0.6
192 0.0
193 0.0
198 0.3333333333333333
202 0.0
203 0.0
204 0.0
208 0.0
210 0.0
228 0.0
230 0.0
231 0.0
232 0.3333333333333333
233 0.0
237 0.0
242 0.5
250 1.0
258 0.0
267 0.5
268 0.0
273 0.0
280 0.0
282 

In [26]:
sorted_label_acc_dict = {k: v for k, v in sorted(conc_label_acc_dict.items(), key=lambda item: item[1])}
sorted_label_acc_dict

{3: (0.0, 1),
 5: (0.0, 1),
 11: (0.0, 1),
 17: (0.0, 1),
 22: (0.0, 1),
 23: (0.0, 1),
 26: (0.0, 1),
 30: (0.0, 1),
 31: (0.0, 1),
 40: (0.0, 1),
 42: (0.0, 1),
 47: (0.0, 1),
 66: (0.0, 1),
 71: (0.0, 1),
 85: (0.0, 1),
 88: (0.0, 1),
 89: (0.0, 1),
 91: (0.0, 1),
 92: (0.0, 1),
 94: (0.0, 1),
 100: (0.0, 1),
 119: (0.0, 1),
 123: (0.0, 1),
 126: (0.0, 1),
 135: (0.0, 1),
 137: (0.0, 1),
 140: (0.0, 1),
 142: (0.0, 1),
 144: (0.0, 1),
 151: (0.0, 1),
 159: (0.0, 1),
 184: (0.0, 1),
 203: (0.0, 1),
 204: (0.0, 1),
 208: (0.0, 1),
 210: (0.0, 1),
 228: (0.0, 1),
 230: (0.0, 1),
 237: (0.0, 1),
 258: (0.0, 1),
 268: (0.0, 1),
 273: (0.0, 1),
 280: (0.0, 1),
 284: (0.0, 1),
 328: (0.0, 1),
 359: (0.0, 1),
 390: (0.0, 1),
 391: (0.0, 1),
 399: (0.0, 1),
 406: (0.0, 1),
 445: (0.0, 1),
 0: (0.0, 2),
 35: (0.0, 2),
 52: (0.0, 2),
 54: (0.0, 2),
 59: (0.0, 2),
 64: (0.0, 2),
 80: (0.0, 2),
 84: (0.0, 2),
 98: (0.0, 2),
 110: (0.0, 2),
 111: (0.0, 2),
 136: (0.0, 2),
 157: (0.0, 2),
 158: (0

In [27]:
# get accuracy only for conclusion node from val set

total_conc_labels = len(val_dataset)
correct_conc_pred = 0

# dict of the form true_label:(predicted_labels)
incorrect_preds = {}

for graph in val_dataset:
    get_gin_predict = gin_trained(graph.x.to(device),graph.edge_index.long().to(device))
    get_conc_predict = np.argmax(get_gin_predict[-1].detach().cpu().numpy(),axis=0)
    if get_conc_predict == graph.y[-1].item():
        correct_conc_pred += 1
    if get_conc_predict != graph.y[-1].item():
        if graph.y[-1].item() in incorrect_preds:
            incorrect_preds[graph.y[-1].item()] = incorrect_preds[graph.y[-1].item()]+(get_conc_predict,)
        else:
            incorrect_preds[graph.y[-1].item()] = (get_conc_predict,)

print("restricting to conclusion nodes, our val_graph has accuracy being",correct_conc_pred/total_conc_labels)

restricting to conclusion nodes, our val_graph has accuracy being 0.402


In [28]:
sorted_label_acc_dict = {k: v for k, v in sorted(label_acc_dict.items(), key=lambda item: item[1]).reverse()}

NameError: name 'label_acc_dict' is not defined

In [None]:
sorted_label_acc_dict

{423: (tensor(1.), 11),
 412: (tensor(1.), 9),
 470: (tensor(1.), 8),
 284: (tensor(1.), 7),
 334: (tensor(1.), 7),
 104: (tensor(1.), 6),
 480: (tensor(1.), 6),
 269: (tensor(1.), 5),
 358: (tensor(1.), 5),
 439: (tensor(1.), 5),
 484: (tensor(1.), 5),
 527: (tensor(1.), 5),
 242: (tensor(1.), 4),
 282: (tensor(1.), 4),
 326: (tensor(1.), 4),
 355: (tensor(1.), 4),
 382: (tensor(1.), 4),
 385: (tensor(1.), 4),
 469: (tensor(1.), 4),
 478: (tensor(1.), 4),
 177: (tensor(1.), 3),
 327: (tensor(1.), 3),
 349: (tensor(1.), 3),
 363: (tensor(1.), 3),
 375: (tensor(1.), 3),
 422: (tensor(1.), 3),
 424: (tensor(1.), 3),
 429: (tensor(1.), 3),
 434: (tensor(1.), 3),
 466: (tensor(1.), 3),
 467: (tensor(1.), 3),
 532: (tensor(1.), 3),
 548: (tensor(1.), 3),
 552: (tensor(1.), 3),
 14: (tensor(1.), 2),
 70: (tensor(1.), 2),
 82: (tensor(1.), 2),
 149: (tensor(1.), 2),
 234: (tensor(1.), 2),
 239: (tensor(1.), 2),
 243: (tensor(1.), 2),
 285: (tensor(1.), 2),
 286: (tensor(1.), 2),
 309: (tensor

In [None]:
# 8/24 EDA for final conclusion nodes from val set

# get 

In [None]:
# calculate true positive % for label i, i.e. (correct i predictions) / (total i labels)
# but now restrict to final conclusion node

label_acc_dict = {}

def get_gin_label_acc(i):
    correct_preds = 0
    total_labels = 0

    for graph in val_dataset:
        get_gin_predict = gin_trained(graph.x.to(device),graph.edge_index.long().to(device))
        get_predict = np.argmax(get_gin_predict.detach().cpu().numpy(),axis=1)
        if i>0:
            total_labels += sum(torch.div(graph.y[graph.y==i],i))
        elif i==0:
            total_labels += (graph.y[graph.y==0]).nelement()
        for idx, num in enumerate(get_predict):
            if num == i:
                if num==(graph.y.tolist())[idx]:
                    correct_preds += 1
    if total_labels == 0:
        return torch.tensor([0])
    return (correct_preds / total_labels), total_labels, correct_preds


total_labels = 0
total_correct_preds = 0
for i in range(hidden_conc_pf_data.num_classes):
    try:
        container = get_gin_label_acc(i)
        label_acc_dict[i] = (container[0],label_count[i])  # pair of accuracy and label count for each label
        total_labels += container[1]
        total_correct_preds += container[2]
    except Exception as e:
        pass
        #print(e,get_gin_label_acc(i))

print(label_acc_dict)
print("if everything here is correct, our conlusion node val_acc should be:",total_correct_preds/total_labels)

KeyboardInterrupt: 

In [None]:
%%capture cap

lr = [.001,.00005,.00001]
h = [400, 1600,3200,6400]

for rate in lr:
    for hidden in h:
        print(rate,hidden)
        gin_trained = None
        gin = None
        gin = GIN(dim_h=hidden).to(device)
        print(gin)

        # reset weights and train model
        gin_trained = None
        train(gin, train_loader,lr=lr)

with open('h_'+str(h)+'lr_'+str(lr)+'.txt', 'w') as f:
        f.write(cap.stdout)   

In [None]:
with open('h.txt', 'w') as f:
        f.write(cap.stdout)   

In [None]:
val_loader

In [None]:
val_loss, val_acc, val_f1 = test(gin_trained, val_loader)
print(f'Test Loss: {val_loss:.2f} | Test Acc: {val_acc*100:.2f}% | F Score: {val_f1:.2f}')
print()

Test Loss: 5.20 | Test Acc: 65.39% | F Score: 0.44

