In [1]:
# Load the dataset into PyTactician's visualizer.
from pytact import data_reader, graph_visualize_browse
import pathlib
from typing import Optional, List, DefaultDict
from pytact.data_reader import Node
from pytact.graph_api_capnp_cython import EdgeClassification
from pytact.graph_api_capnp_cython import Graph_Node_Label_Which
from collections import defaultdict
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import random
import numpy as np
#from sklearn.metrics import classification_report
from typing import List



In [2]:

class TreeLSTMCell(nn.Module):
    def __init__(self, hidden_size: int, edges_number: int):
        super(TreeLSTMCell, self).__init__()
        self.hidden_size = hidden_size
        self.edges_number = edges_number

        # Parameters for the LSTM cell
        # Note: Each edge type has its own set of LSTM parameters
        self.W_i = nn.Parameter(torch.randn(hidden_size, hidden_size))
        self.W_f = nn.Parameter(torch.randn(edges_number, hidden_size, hidden_size))
        self.W_o = nn.Parameter(torch.randn(hidden_size, hidden_size))
        self.W_c = nn.Parameter(torch.randn(hidden_size, hidden_size))
        
        self.U_i = nn.Parameter(torch.randn(hidden_size, hidden_size))
        self.U_f = nn.Parameter(torch.randn(hidden_size, hidden_size))
        self.U_o = nn.Parameter(torch.randn(hidden_size, hidden_size))
        self.U_c = nn.Parameter(torch.randn(hidden_size, hidden_size))
        
        self.b_i = nn.Parameter(torch.randn(hidden_size))
        self.b_f = nn.Parameter(torch.randn(hidden_size))
        self.b_o = nn.Parameter(torch.randn(hidden_size))
        self.b_c = nn.Parameter(torch.randn(hidden_size))

    def forward(self, h_, c_, e, x: torch.Tensor):
        # h_ and c_ are lists of hidden states and cell states of child nodes
        # e is a list of edge types

        H, C = torch.stack(h_), torch.stack(c_)
        # Gates calculations
        i_t = torch.sigmoid(torch.mean(torch.stack([
            torch.mm(h, self.W_i) for h, _ in zip(H, e)
        ]), dim=0) + torch.mm(x, self.U_i) + self.b_i)

        f_t = torch.sigmoid(torch.mean(torch.stack([
            torch.mm(h, self.W_f[edge_idx]) for h, edge_idx in zip(H, e)
        ]), dim=0) + torch.mm(x, self.U_f) + self.b_f)

        o_t = torch.sigmoid(torch.mean(torch.stack([
            torch.mm(h, self.W_o) for h, _ in zip(H, e)
        ]), dim=0) + torch.mm(x, self.U_o) + self.b_o)

        c_hat_t = torch.tanh(torch.mean(torch.stack([
            torch.mm(h, self.W_c) for h, edge_idx in zip(H, e)
        ]), dim=0) + torch.mm(x, self.U_c) + self.b_c)

        # Calculate cell state
        C_t = torch.sum(f_t * C, dim=0) + i_t * c_hat_t

        # Calculate hidden state
        h_t = o_t * torch.tanh(C_t)
        return h_t.squeeze(1), C_t.squeeze(1)


class DecodeEmbedding(nn.Module):
    def __init__(self, hidden_size: int, edges_number: int):
        super(DecodeEmbedding, self).__init__()
        # Initialize the Ep matrix as a trainable parameter
        self.Ep = nn.Parameter(torch.randn(edges_number, hidden_size, hidden_size))
        # Initialize cp as a trainable parameter 
        self.cp = nn.Parameter(torch.zeros(1, hidden_size))

    def forward(self, h: torch.Tensor, edge_idx: int) -> torch.Tensor:
        # Ensure cp is broadcasted correctly over the batches
        transformed_h = torch.mm(h, self.Ep[edge_idx])
        # Add cp (broadcasted) to the result of the matrix multiplication
        # Note: If cp is intended to be a fixed bias, it should be initialized outside the parameter list
        return torch.sigmoid(transformed_h + self.cp)  # squeeze cp to match dimensions
    

class LSTMEncoderDecoderClasifier(nn.Module): 
    def __init__(self, hidden_size: int, edges_number: int, nodes_number: int): 
        super(LSTMEncoderDecoderClasifier, self).__init__()
        self.emb = nn.Embedding(nodes_number, hidden_size)
        self.dec = DecodeEmbedding(hidden_size, edges_number)
        self.enc = TreeLSTMCell(hidden_size, edges_number)
        self.R = nn.Linear(hidden_size, nodes_number, bias=True) #output
      
    def encode_dag(self, dag):
        def encode_node(node): 
            if node.children and node.label.which.name != 'REL': 
                h_list = [] 
                c_list = []
                e_list = []
                for edge_label, child in node.children:
                    h, c = encode_node(child)
                    h_list.append(h)
                    c_list.append(c)
                    e_list.append(edge_label.value)
                return self.enc(h_list, c_list, e_list, self.emb(torch.tensor(node.label.which.value)).unsqueeze(0))
            else: 
                h = self.emb(torch.tensor(node.label.which.value)).unsqueeze(0)
                c = self.emb(torch.tensor(node.label.which.value)).unsqueeze(0)
                return h, c
        h, _ = encode_node(dag)
        return h
    
    def decode_dag(self, dag, h, max_depth):
        decoded_graph = []
        def decode_node(node, h, depth, max_depth):
            logits = self.R(h)
            decoded_graph.append((torch.softmax(logits, dim=-1), depth))
            if depth < max_depth and node.children and not node.label.which.name == 'REL':
                for edge_label, child in node.children: 
                    h = self.dec(h, edge_label.value)
                    decode_node(child, h, depth+1, max_depth)

        decode_node(dag, h, 1, max_depth)
        return decoded_graph
    
    def forward(self, dag, max_depth):
        enc = self.encode_dag(dag)
        return self.decode_dag(dag, enc, max_depth)

    
class LabelGetter: 
    def __init__(self): 
        self.labels = []
    def get_labels(self, graph, max_depth):
        self.labels = []
        self.get_labels_helper(graph, 1, max_depth)
        return self.labels
    def get_labels_helper(self, graph, depth, max_depth):
        self.labels.append(graph.label.which.value)
        if graph.children and not graph.label.which.name == 'REL' and depth < max_depth: 
            for _, child in list(graph.children):
                self.get_labels_helper(child, depth+1, max_depth)
                
def get_file_size(reader, dataset_pointer): 
        pdl = dataset_pointer.lowlevel
        size = len(pdl.graph.nodes)
        return size

In [15]:
# Constants and configurations
DATASET_PATH = '../../../v15-stdlib-coq8.11/dataset'
FILE_PATH = "coq-tactician-stdlib.8.11.dev/theories/Init/Logic.bin"
DATASET_PATH = pathlib.Path(DATASET_PATH)
FILE_PATH = pathlib.Path(FILE_PATH)

In [16]:
# Randomness
RANDOM_SEED = 42
random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)

# Model Parameters 
NODES_NUMBER = 30
EMBEDDING_SIZE = 8
HIDDEN_SIZE = 16
EDGES_NUMBER = 50

# Model Introduction
model = LSTMEncoderDecoderClasifier(HIDDEN_SIZE, EDGES_NUMBER, NODES_NUMBER)
lg = LabelGetter() #graph node_labels extractor

# Model Training Details
LEARNING_RATE = 0.001
BATCH_SIZE = 20
MAX_DECODING_DEPTH = 3
EPOCHS = 3
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

In [17]:
with data_reader.data_reader(DATASET_PATH) as reader:
    dataset_pointer = reader[FILE_PATH]      
    grpahs_number = get_file_size(reader, dataset_pointer)
    shuffled_indexes = list(range(grpahs_number)) # change indexes to random_shuffle
    random.shuffle(shuffled_indexes)
    train_indexes = shuffled_indexes[:grpahs_number*7//10]
    test_indexes = shuffled_indexes[grpahs_number*7//10:]
    for max_depth in range(1, MAX_DECODING_DEPTH+1):
        for epoch in range(EPOCHS):
            # Training Loop
            correct = 0
            total = 0
            total_loss = 0
            for i in train_indexes:
                graph = dataset_pointer.node_by_id(i)
                labels = lg.get_labels(graph, max_depth)
                optimizer.zero_grad()
                output_whole = model(graph, max_depth)    
                output_whole = [j for j,_ in output_whole]
                loss = criterion(torch.stack(output_whole).squeeze(1), torch.tensor(labels))/len(labels)
                loss.backward()
                if (i + 1) % BATCH_SIZE == 0:
                    optimizer.step()
                    optimizer.zero_grad()
                total_loss += loss.item()
                predictions = torch.argmax(torch.stack(output_whole).squeeze(1), dim=1)
                correct += (predictions == torch.tensor(labels)).sum().item()
                total += len(labels)
            trainig_accuracy = correct / total if total > 0 else 0
            
            # Testing Loop
            correct = 0
            total = 0
            for i in test_indexes:
                graph = dataset_pointer.node_by_id(i)
                labels = lg.get_labels(graph, max_depth)
                with torch.no_grad():
                    output_whole = model(graph, max_depth=max_depth)

                    output_whole = [j for j,_ in output_whole]
                    predictions = torch.argmax(torch.stack(output_whole).squeeze(1), dim=1)
                    correct += (predictions == torch.tensor(labels)).sum().item()
                    total += len(labels)
            
            accuracy = correct / total if total > 0 else 0
            print(f'Max decoding depth: {max_depth}, Epoch {epoch+1}/{EPOCHS}, Training Loss: {total_loss / len(train_indexes)}, TrainingAccuracy: {trainig_accuracy* 100:.2f}%, Test Accuracy: {accuracy * 100:.2f}%')

Max decoding depth: 1, Epoch 1/3, Training Loss: 3.128580997090639, TrainingAccuracy: 52.36%, Test Accuracy: 63.45%
Max decoding depth: 1, Epoch 2/3, Training Loss: 2.794567477263431, TrainingAccuracy: 70.91%, Test Accuracy: 77.78%
Max decoding depth: 1, Epoch 3/3, Training Loss: 2.6824124881675706, TrainingAccuracy: 85.04%, Test Accuracy: 87.44%
Max decoding depth: 2, Epoch 1/3, Training Loss: 1.2233384361171864, TrainingAccuracy: 37.36%, Test Accuracy: 44.68%
Max decoding depth: 2, Epoch 2/3, Training Loss: 1.196088734693178, TrainingAccuracy: 48.20%, Test Accuracy: 52.87%
Max decoding depth: 2, Epoch 3/3, Training Loss: 1.1650459589897346, TrainingAccuracy: 59.30%, Test Accuracy: 61.35%
Max decoding depth: 3, Epoch 1/3, Training Loss: 0.8138397572566332, TrainingAccuracy: 53.95%, Test Accuracy: 60.39%
Max decoding depth: 3, Epoch 2/3, Training Loss: 0.8096095865061339, TrainingAccuracy: 62.22%, Test Accuracy: 63.95%
Max decoding depth: 3, Epoch 3/3, Training Loss: 0.8060436052268769