In [1]:
from use_dataset import ProofDataset
import torch
import numpy as np
from torch_geometric.data import Batch
from torch_geometric.loader import DataLoader, NeighborLoader
import random
from sklearn.metrics import f1_score as f1
from sklearn.utils.class_weight import compute_class_weight
import torch.nn.functional as F
from torch.nn import Linear, ReLU, Dropout
from torch_geometric.nn import GCNConv

torch.manual_seed(0)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [2]:
# use file_limit=5000 to only load and verify the first 5000 graphs (~60 MB)
# if entire graph dataset is desired, use file_limit=None
file_limit = 5000    # desired number of graphs to work with
vocab_size = 1598   # number of characters in our vocabulary

pf_data = ProofDataset(root="data/",file_limit=file_limit)  

In [3]:
# make train/val/test for GCN
# set seed for random # generation
random.seed(10)
length = file_limit
total_indices = [i for i in range(file_limit)]

# create index vectors to filter dataset
train_indices = random.sample(total_indices, int(length*.8))
train_indices.sort()

val_index_options = [x for x in total_indices if x not in train_indices]
val_indices = random.sample(val_index_options, int(length*.1))
val_indices.sort()

test_index_options = [x for x in total_indices if x not in train_indices if x not in val_indices]
test_indices = random.sample(test_index_options, int(length*.1))
test_indices.sort()

# Create training, validation, and test sets
train_dataset = pf_data[train_indices]
val_dataset = pf_data[val_indices]
test_dataset = pf_data[test_indices]

print(f'Training set   = {len(train_dataset)} graphs')
print(f'Validation set = {len(val_dataset)} graphs')
print(f'Test set       = {len(test_dataset)} graphs')

Training set   = 4000 graphs
Validation set = 500 graphs
Test set       = 500 graphs


In [4]:
# Create mini-batches
# Shuffling for now; probably will remove shuffling later
train_loader = DataLoader(train_dataset, batch_size=1000, shuffle=True,num_workers=0)
val_loader   = DataLoader(val_dataset, batch_size=1000, shuffle=False)
test_loader  = DataLoader(test_dataset, batch_size=1000, shuffle=False)

In [5]:
print('\nTrain loader:')
for i, batch in enumerate(train_loader):
    print(f' - Batch {i}: {batch}')

print('\nValidation loader:')
for i, batch in enumerate(val_loader):
    print(f' - Batch {i}: {batch}')

print('\nTest loader:')
for i, batch in enumerate(test_loader):
    print(f' - Batch {i}: {batch}')


Train loader:
 - Batch 0: DataBatch(x=[5770, 512], edge_index=[2, 4770], y=[5770], batch=[5770], ptr=[1001])
 - Batch 1: DataBatch(x=[6208, 512], edge_index=[2, 5208], y=[6208], batch=[6208], ptr=[1001])
 - Batch 2: DataBatch(x=[6196, 512], edge_index=[2, 5196], y=[6196], batch=[6196], ptr=[1001])
 - Batch 3: DataBatch(x=[6376, 512], edge_index=[2, 5376], y=[6376], batch=[6376], ptr=[1001])

Validation loader:
 - Batch 0: DataBatch(x=[3207, 512], edge_index=[2, 2707], y=[3207], batch=[3207], ptr=[501])

Test loader:
 - Batch 0: DataBatch(x=[3037, 512], edge_index=[2, 2537], y=[3037], batch=[3037], ptr=[501])


In [6]:
# make a dictionary to record label frequency

# get max label used in pf_data
max_label = 0

for i in range(file_limit):
    for j in pf_data.get(i).y:
        if j > max_label:
            max_label = j.to(int).item()

# initialize histogram for labels used in pf_data            
label_count = {}

for i in range(file_limit):
    for j in range(max_label+1):        
        label_count[j] = 0

for i in range(file_limit):
    for j in pf_data.get(i).y:
        label_count[j.to(int).item()] += 1

step_count = 0
max = 0
max_freq_index = None   #find the most frequently used index
labels_never_used = 0
labels_used_once = 0
labels_used_twice = 0


for k,v in label_count.items():
    step_count += v
    if v > max:     
        max = v
        max_freq_index = k

    if v == 0:
        labels_never_used += 1
    if v == 1:
        labels_used_once += 1
    if v ==2:
        labels_used_twice += 1

In [7]:
print(f"total number of steps is:", step_count)
print(f"highest frequency label is {max_freq_index} and occurs {max} times")
print(f"final label used is {len(label_count)-1}")
print(label_count)
print(len(label_count),"unique labels are used")
print(labels_never_used,"unique labels never used")
print(labels_used_once, "unique labels used once")
print(labels_used_twice, "unique labels used twice")

total number of steps is: 30794
highest frequency label is 0 and occurs 5140 times
final label used is 3228
{0: 5140, 1: 545, 2: 52, 3: 10, 4: 3, 5: 23, 6: 7, 7: 269, 8: 2, 9: 13, 10: 39, 11: 41, 12: 389, 13: 25, 14: 3, 15: 27, 16: 2, 17: 244, 18: 22, 19: 56, 20: 5, 21: 11, 22: 7, 23: 20, 24: 107, 25: 5, 26: 46, 27: 74, 28: 13, 29: 9, 30: 8, 31: 9, 32: 5, 33: 2, 34: 17, 35: 4, 36: 6, 37: 4, 38: 2, 39: 25, 40: 1, 41: 4, 42: 16, 43: 2, 44: 4, 45: 6, 46: 1, 47: 12, 48: 3, 49: 3, 50: 4, 51: 5, 52: 17, 53: 7, 54: 16, 55: 1, 56: 6, 57: 6, 58: 2, 59: 3, 60: 8, 61: 3, 62: 3, 63: 4, 64: 11, 65: 5, 66: 6, 67: 14, 68: 49, 69: 17, 70: 6, 71: 9, 72: 6, 73: 7, 74: 5, 75: 7, 76: 2, 77: 4, 78: 2, 79: 3, 80: 4, 81: 3, 82: 3, 83: 2, 84: 5, 85: 1, 86: 2, 87: 4, 88: 6, 89: 12, 90: 1, 91: 2, 92: 13, 93: 11, 94: 41, 95: 7, 96: 2, 97: 1, 98: 5, 99: 3, 100: 1, 101: 2, 102: 8, 103: 1, 104: 3, 105: 11, 106: 7, 107: 2, 108: 3, 109: 19, 110: 8, 111: 2, 112: 12, 113: 7, 114: 13, 115: 2, 116: 17, 117: 5, 118: 1, 11

In [8]:
# make array of label frequencies for sklearn compute_class_weight using entire dataset
# really should be doing this for train set (otherwise, data leakage...)
# however, train set may not include certain labels, which leads to error in compute_class_weight

# make array of unique classes
class_num_arr = [i for i in range(len(label_count))]
class_num_arr = np.array(class_num_arr)

# make array of all data points with labels
lbl_arr = np.array([])
for i in range(file_limit):
    for y in pf_data.get(i).y:
        lbl_arr = np.append(lbl_arr,[y.numpy()],axis=0).astype(int)

class_weights = compute_class_weight(class_weight="balanced",classes = class_num_arr, y=lbl_arr)
class_weights = torch.from_numpy(class_weights).float().to(device)

In [9]:
# Make class for GCN model

class GCN(torch.nn.Module):
    """GCN"""
    def __init__(self, dim_h):
        super(GCN, self).__init__()
        self.conv1 = GCNConv(pf_data.num_features, dim_h)
        self.conv2 = GCNConv(dim_h, dim_h)
        self.conv3 = GCNConv(dim_h, dim_h)
        self.conv4 = GCNConv(dim_h, dim_h)
        self.lin = Linear(dim_h, len(class_num_arr))

    def forward(self, x, edge_index):
        h = self.conv1(x, edge_index)
        h = h.relu()
        h = F.dropout(h, p=0.1, training=self.training)
        h = self.conv2(h, edge_index)        
        h = h.relu()
        h = F.dropout(h, p=0.5, training=self.training)
        h = self.lin(h)
        
        return F.log_softmax(h, dim=1)

In [10]:
# GCN model training

def train(model, loader):
    criterion = torch.nn.CrossEntropyLoss()
    # commented out code is to use class weights to account for imbalanced dataset
    #criterion = torch.nn.CrossEntropyLoss(weight=class_weights)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    epochs = 1000

    model.train()
    for epoch in range(epochs+1):
        total_loss = 0
        acc = 0
        val_loss = 0
        val_acc = 0
        cur_graph = 0   # used to keep track of current statement to enforce preds of only PREVIOUS labels in training

        # Train on batches
        for data in loader:
            cur_graph_batch = data.batch + cur_graph            
            cur_graph += torch.max(data.batch) + 1
            data = data.to(device, non_blocking=True)
            data.y = data.y.to(torch.float).to(device, non_blocking=True)
            optimizer.zero_grad()
            length = len(loader)
            out = model(data.x, data.edge_index.long())
            data.y = data.y.type(torch.LongTensor).to(device, non_blocking=True)
            out = out.type(torch.float32).to(device, non_blocking=True)
            loss = criterion(out, data.y)
            total_loss += loss / length

            # commented out code below is meant to enforce predictions to only come from previous theorems
            # for dict to be properly created, you must delete and recreate test.pt by rerunning 
            # pf_data = ProofDataset(root="data/",file_limit=file_limit)
            
            #dict = pf_data.class_corr
            #dict_keys = [k for k in dict.keys()]            

            #def return_next_lowest_idx(num):
                #if num in dict_keys:
                    #return dict[num]
                #while num not in dict_keys:
                    #try:
                        #num -= 1
                        #return dict[num]
                    #except:
                        #pass

            #with torch.no_grad():
                #cur_graph_batch.apply_(return_next_lowest_idx)
                #masked_lbls = (torch.arange(out.size(1)) < (cur_graph_batch[..., None]+1))*(out.cpu())
                #try:
                    #masked_lbls = (torch.arange(out.size(1)) < (cur_graph_batch[..., None]+1))*(out.cpu())
                    #masked_lbls = torch.where(masked_lbls==0,np.nan,masked_lbls)
                    #masked_lbls = masked_lbls.detach().numpy()                
                    #pred = np.nanargmax(masked_lbls,axis=1)
                    #pred = torch.from_numpy(pred)
                    #acc += accuracy(pred, data.y.cpu()) / length
                #except Exception as e:
                    #print("a lil error")
                    #out = out + .00000001
                    #masked_lbls = (torch.arange(out.size(1)) < (cur_graph_batch[..., None]+1))*(out.cpu())
                    #masked_lbls = torch.where(masked_lbls==0,np.nan,masked_lbls)
                    #masked_lbls = masked_lbls.detach().numpy()
                    #pred = np.nanargmax(masked_lbls,axis=1)
                    #pred = torch.from_numpy(pred)
                    #acc += accuracy(pred, data.y.cpu()) / length
                    #out = out - .00000001

            #comment out the follow pred and acc lines if enforcing predictions as described above
            pred = out.argmax(dim=1)
            acc += accuracy(pred, data.y) / length

            loss.backward()
            optimizer.step()

            # run model on validation set
            val_loss, val_acc, val_f1 = test(model, val_loader)

        # Print metrics every epoch
        if(epoch % 10 == 0):
            print(f'Epoch {epoch:>3} | Train Loss: {total_loss:.2f} | Train Acc: {acc*100:>5.2f}% | Val Loss: {val_loss:.2f} | Val Acc: {val_acc*100:.2f}% | F Score: {val_f1:.2f}')
            
    return model

@torch.no_grad()
def test(model, loader):
    
    criterion = torch.nn.CrossEntropyLoss()
    model.eval()
    loss = 0
    acc = 0
    fscore = 0
    
    for data in loader:
        data = data.to(device, non_blocking=True)
        length = len(loader)
        out = model(data.x, data.edge_index.long())
        data.y = data.y.type(torch.LongTensor).to(device, non_blocking=True)
        loss += criterion(out, data.y) / length
        pred = out.argmax(dim=1)
        acc += accuracy(pred, data.y) / length
        fscore += f1(pred.cpu(), data.y.cpu(), average='macro')    # micro looks better, but macro prob more accurate

    return loss, acc, fscore

def accuracy(pred_y, y):
    """Calculate accuracy."""
    return ((pred_y == y).sum() / len(y)).item()

In [11]:
# initialize (and reset weights of) model
# dim_h is hyperparameter of number of hidden layers

gcn_trained = None
gcn = None
gcn = GCN(dim_h=800).to(device)
gcn

GCN(
  (conv1): GCNConv(512, 800)
  (conv2): GCNConv(800, 800)
  (conv3): GCNConv(800, 800)
  (conv4): GCNConv(800, 800)
  (lin): Linear(in_features=800, out_features=3229, bias=True)
)

In [12]:
# reset weights and train model
gcn_trained = None
gcn_trained = train(gcn, train_loader)

In [None]:
test_loss, test_acc, test_f1 = test(gcn_trained, test_loader)
print(f'Test Loss: {test_loss:.2f} | Test Acc: {test_acc*100:.2f}% | F Score: {test_f1:.2f}')
print()

Test Loss: 6.50 | Test Acc: 42.86% | F Score: 0.08

