In [1]:
# Load the dataset into PyTactician's visualizer.
from pytact import data_reader, graph_visualize_browse
import pathlib
from typing import Optional, List, DefaultDict
from pytact.data_reader import Node
from pytact.graph_api_capnp_cython import EdgeClassification
from pytact.graph_api_capnp_cython import Graph_Node_Label_Which
from collections import defaultdict
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import random
import numpy as np
from sklearn.metrics import classification_report



In [2]:
class BasicCSRNN(nn.Module):
    def __init__(self, embedding_size, hidden_size, nodes_number, edges_number): #Nodes_number is an input tokens size
        super(BasicCSRNN, self).__init__()
        self.embedding = nn.Embedding(nodes_number, embedding_size)
        self.Wx = torch.randn(embedding_size, hidden_size) # n_inputs X n_neurons
        self.We = torch.randn(edges_number, hidden_size, hidden_size) # n_edges X 1 X n_neurons
        self.hidden_size = hidden_size
        self.b = torch.zeros(1, hidden_size) # 1 X n_neurons

    def forward(self, node):
        return self.node_forward(node)

    def node_forward(self, node):
        emb = self.embedding(torch.tensor(node.label.which.value))
        emb = emb.view(1, -1)
        x = torch.mm(emb, self.Wx)
        if node.children and not node.label.which.name == 'REL':
            hidden = torch.mean(torch.stack([x  + torch.mm(self.node_forward(child), self.We[edge_type.value]) for edge_type, child in list(node.children)]), dim=0) 
        else:
            # Ensure that the zero tensor is of the correct shape [batch size, hidden size]
            hidden = torch.zeros(1, self.hidden_size, dtype=torch.float, device=x.device)
        return torch.tanh(x + hidden + self.b)
    
    
class RNNLabelDecode(nn.Module):
    def __init__(self, hidden_size, output_size, edges_number):
        super(RNNLabelDecode, self).__init__()
        self.hidden_size = hidden_size
        self.We = nn.Parameter(torch.randn(edges_number, hidden_size, hidden_size))
        self.Wdc = nn.Linear(hidden_size, output_size, bias=True)
        self.be = torch.zeros(1, hidden_size) # 1 X n_neurons
        
        # Keep track of edges if needed
        self.decoded_edges = []
    def forward(self, embedding, node, max_depth):
        self.decoded_edges = []
        self.node_decode_forward(embedding, node, depth=1, max_depth=max_depth)
        return self.decoded_edges
    
    def node_decode_forward(self, embedding, node, depth, max_depth):
        # Decode label 
        logits = self.Wdc(embedding)
        probabilities = F.softmax(logits)
        self.decoded_edges.append(probabilities)
        if node.children and not node.label.which.name == 'REL' and depth < max_depth:
            for edge_type, child in node.children:                 
                new_embedding = torch.mm(embedding, self.We[edge_type.value]) + self.be #Calculate new hidden state
                self.node_decode_forward(new_embedding, child, depth=depth+1, max_depth=max_depth) # Decode child
                
                
class DecoderRNNClasifier(nn.Module):
    def __init__(self, embedding_size, hidden_size, nodes_number, edges_number):
        super(DecoderRNNClasifier, self).__init__()    
        self.dec = RNNLabelDecode(hidden_size, nodes_number, edges_number)
        self.enc = BasicCSRNN(embedding_size, hidden_size, nodes_number, edges_number) 
    
    def forward(self, node, max_depth): 
        emb = self.enc(node)
        dec = self.dec(emb, node, max_depth)
         
        return dec


class LabelGetter: 
    def __init__(self): 
        self.labels = []
    def get_labels(self, graph, max_depth):
        self.labels = []
        self.get_labels_helper(graph, 1, max_depth)
        return self.labels
    def get_labels_helper(self, graph, depth, max_depth):
        self.labels.append(graph.label.which.value)
        if graph.children and not graph.label.which.name == 'REL' and depth < max_depth: 
            for _, child in list(graph.children):
                self.get_labels_helper(child, depth+1, max_depth)
                
def get_file_size(reader, dataset_pointer): 
        pdl = dataset_pointer.lowlevel
        size = len(pdl.graph.nodes)
        return size

In [3]:
# Constants and configurations
DATASET_PATH = '../../../../v15-stdlib-coq8.11/dataset'
FILE_PATH = "coq-tactician-stdlib.8.11.dev/theories/Init/Logic.bin"
DATASET_PATH = pathlib.Path(DATASET_PATH)
FILE_PATH = pathlib.Path(FILE_PATH)

In [4]:
# Randomness
RANDOM_SEED = 42
random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)

# Model Parameters 
NODES_NUMBER = 30
EMBEDDING_SIZE = 8
HIDDEN_SIZE = 16
EDGES_NUMBER = 50

# Model Introduction
model = DecoderRNNClasifier(EMBEDDING_SIZE, HIDDEN_SIZE, NODES_NUMBER, EDGES_NUMBER)
lg = LabelGetter() #graph node_labels extractor

# Model Training Details
LEARNING_RATE = 0.001
BATCH_SIZE = 20
MAX_DECODING_DEPTH = 3
EPOCHS = 3
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

In [5]:
with data_reader.data_reader(DATASET_PATH) as reader:
    dataset_pointer = reader[FILE_PATH]      
    grpahs_number = get_file_size(reader, dataset_pointer)
    shuffled_indexes = list(range(grpahs_number)) # change indexes to random_shuffle
    random.shuffle(shuffled_indexes)
    train_indexes = shuffled_indexes[:grpahs_number*7//10]
    test_indexes = shuffled_indexes[grpahs_number*7//10:]
    for max_depth in range(1, MAX_DECODING_DEPTH+1):
        for epoch in range(EPOCHS):
            # Training Loop
            correct = 0
            total = 0
            total_loss = 0
            for i in train_indexes:
                graph = dataset_pointer.node_by_id(i)
                labels = lg.get_labels(graph, max_depth)
                optimizer.zero_grad()
                output_whole = model(graph, max_depth=max_depth)
                loss = criterion(torch.stack(output_whole).squeeze(1), torch.tensor(labels))/len(labels)
                loss.backward()
                if (i + 1) % BATCH_SIZE == 0:
                    optimizer.step()
                    optimizer.zero_grad()
                total_loss += loss.item()
                predictions = torch.argmax(torch.stack(output_whole).squeeze(1), dim=1)
                correct += (predictions == torch.tensor(labels)).sum().item()
                total += len(labels)
            trainig_accuracy = correct / total if total > 0 else 0
            
            # Testing Loop
            correct = 0
            total = 0
            for i in test_indexes:
                graph = dataset_pointer.node_by_id(i)
                labels = lg.get_labels(graph, max_depth)
                with torch.no_grad():
                    output_whole = model(graph, max_depth=max_depth)
                    predictions = torch.argmax(torch.stack(output_whole).squeeze(1), dim=1)
                    correct += (predictions == torch.tensor(labels)).sum().item()
                    total += len(labels)
            
            accuracy = correct / total if total > 0 else 0
            print(f'Max decoding depth: {max_depth}, Epoch {epoch+1}/{EPOCHS}, Training Loss: {total_loss / len(train_indexes)}, TrainingAccuracy: {trainig_accuracy* 100:.2f}%, Test Accuracy: {accuracy * 100:.2f}%')

  probabilities = F.softmax(logits)


Max decoding depth: 1, Epoch 1/3, Training Loss: 3.052128731485556, TrainingAccuracy: 67.51%, Test Accuracy: 89.13%
Max decoding depth: 1, Epoch 2/3, Training Loss: 2.6578493127871217, TrainingAccuracy: 91.94%, Test Accuracy: 92.03%
Max decoding depth: 1, Epoch 3/3, Training Loss: 2.562999010732011, TrainingAccuracy: 93.79%, Test Accuracy: 94.66%
Max decoding depth: 2, Epoch 1/3, Training Loss: 1.1934545814126256, TrainingAccuracy: 48.01%, Test Accuracy: 60.79%
Max decoding depth: 2, Epoch 2/3, Training Loss: 1.1457983314641975, TrainingAccuracy: 66.84%, Test Accuracy: 71.08%
Max decoding depth: 2, Epoch 3/3, Training Loss: 1.132676905871813, TrainingAccuracy: 76.18%, Test Accuracy: 77.36%
Max decoding depth: 3, Epoch 1/3, Training Loss: 0.8206769919992821, TrainingAccuracy: 47.33%, Test Accuracy: 48.19%
Max decoding depth: 3, Epoch 2/3, Training Loss: 0.8182403788558541, TrainingAccuracy: 49.05%, Test Accuracy: 51.46%
Max decoding depth: 3, Epoch 3/3, Training Loss: 0.8166042187206709