### Proving Theorems In Propositional Logic with LSTM-Based Text Generators
#### Author: Omar Afifi

Consider a number of propositional (i.e. variable-free) sentences (premises). For example: 

1. p→q 
2. o
3. p ∨ ¬o

A propositional proof from the premises to a conclusion (another sentence) is a sequence of variable-free statements, each of which follows from the premises or from earlier statements by a logical deduction rule (e.g. modus ponens, modus tollens, modus tollendo ponens, disjunctive elimination, etc.).

For example, a proof of the propositional sentence (q) from the preceding premises is as follows: 

4. ¬¬o (from 2, double negation)
5. p (from 3 and 4, disjunctive elimination)
6. q (from 1 and 5, modus ponens)

QED.
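
As a sanity check on the example above, the claim that the three premises entail q can be verified by brute force over all truth assignments. This is a minimal sketch and not part of the notebook's pipeline; the helper names `implies` and `entails` are introduced here for illustration only.

```python
from itertools import product

def implies(a: bool, b: bool) -> bool:
    """Material implication: a → b."""
    return (not a) or b

# The three premises from the example, as functions of a truth assignment.
premises = [
    lambda p, q, o: implies(p, q),  # 1. p → q
    lambda p, q, o: o,              # 2. o
    lambda p, q, o: p or (not o),   # 3. p ∨ ¬o
]
conclusion = lambda p, q, o: q

def entails(premises, conclusion) -> bool:
    """True if every assignment satisfying all premises also satisfies the conclusion."""
    for p, q, o in product([True, False], repeat=3):
        if all(prem(p, q, o) for prem in premises) and not conclusion(p, q, o):
            return False
    return True

print(entails(premises, conclusion))  # True
```

A truth-table check only confirms that the conclusion follows; it does not produce the step-by-step derivation, which is what the text generator is meant to learn.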


This notebook explores whether LSTM text generators can produce a propositional proof from a collection of propositional sentences. Our hope is that it serves as a stepping stone toward progress on stochastic theorem provers. 

Credits: Hugging Face User ergotts for building this dataset: https://huggingface.co/datasets/ergotts/propositional-logic


### Loading Data and Preparing the Input

In [1]:
import process_data

# Load the data from Hugging Face. mode = 'w' tokenizes by word rather than by character or sentence.
proofs_dataset = process_data.LoadLogicData(mode = 'w') 

# Format the proofs: map words to integers, then create n-gram sequences.
word_to_int, int_to_word, sequenced_proofs = process_data.generate_sequences(proofs_dataset)

# Split the data into inputs and labels, where each label is the next word.
# sequence_lengths holds the length of each sequence, which lets us pack them during training.
X, sequence_lengths,y = process_data.makeXy(sequenced_proofs)
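
The `process_data` module is not shown here, so the following is only a sketch of what the word-to-integer mapping and n-gram step described in the comments might look like. The functions `build_vocab` and `make_xy` are hypothetical stand-ins, and the choice to right-pad with 0 is an assumption; the actual module may differ.

```python
def build_vocab(proofs):
    """Map each distinct word to a positive integer; 0 is reserved for padding (assumption)."""
    words = sorted({w for proof in proofs for w in proof.split()})
    word_to_int = {w: i + 1 for i, w in enumerate(words)}
    int_to_word = {i: w for w, i in word_to_int.items()}
    return word_to_int, int_to_word

def make_xy(proofs, word_to_int):
    """Turn every proof prefix into a (padded input, true length, next-word label) triple."""
    encoded = [[word_to_int[w] for w in proof.split()] for proof in proofs]
    max_len = max(len(seq) for seq in encoded) - 1
    X, lengths, y = [], [], []
    for seq in encoded:
        for end in range(1, len(seq)):
            prefix = seq[:end]
            X.append(prefix + [0] * (max_len - len(prefix)))  # right-pad with zeros
            lengths.append(len(prefix))                        # length before padding
            y.append(seq[end])                                 # label is the next word
    return X, lengths, y

# Toy "proofs" used purely for illustration.
proofs = ["p and q therefore p", "o therefore o"]
word_to_int, int_to_word = build_vocab(proofs)
X, lengths, y = make_xy(proofs, word_to_int)

print(X[0], lengths[0], y[0])  # [3, 0, 0, 0] 1 1
print(len(X))                  # 6
```

Keeping the true lengths alongside the padded inputs is what later allows the sequences to be packed, so the LSTM never processes padding positions.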




### Making the Data Compatible with the torch API

In [2]:
import torch as t
import torch.utils.data  as data

X = t.tensor(X, dtype = t.int64)
sequence_lengths =  t.tensor(sequence_lengths, dtype = t.int64)
y = t.tensor(y, dtype = t.int64).view(-1,1)

torch_data = data.DataLoader(data.TensorDataset(X,sequence_lengths, y),
                            batch_size = 100)

In [3]:
X.shape


torch.Size([50472, 106])

In [4]:
y


tensor([[63],
        [14],
        [52],
        ...,
        [57],
        [13],
        [ 1]])

### Loading the Model and Training

In [5]:
import torch 
from torch import nn
from torch.nn import functional as F
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

class LSTM(nn.Module):
    # constructor; the class inherits from nn.Module
    def __init__(self, seq_length,  hidden_size, num_layers,vocab_size):

        super(LSTM, self).__init__()
        # hidden states are not initialized explicitly; nn.LSTM defaults them to zeros
        self.seq_length = seq_length
        self.hidden_dim = hidden_size
        self.vocab_size = vocab_size
        self.num_layers = num_layers

        # we need to embed the words; a rule of thumb is that the embedding
        # dimension is about the fourth root of the vocabulary size. Here the
        # embedding dimension is simply set to hidden_size.
        self.embedding = nn.Embedding(vocab_size, hidden_size)

        # initialize an lstm layer
        self.lstm = nn.LSTM(hidden_size, hidden_size, num_layers, batch_first = True)
        
        # output layer: maps each hidden state to one score per vocabulary word
        self.fc = nn.Linear(hidden_size, self.vocab_size)

    def forward(self, X, sequence_lengths):
        """ forward pass through network"""

        X = self.embedding(X)

        X = pack_padded_sequence(X, sequence_lengths, 
                                 batch_first = True, 
                                 enforce_sorted = False)


        X, (H,C) = self.lstm(X)
        X, _ = pad_packed_sequence(X, batch_first = True)

        fc_out = self.fc(X)

        return F.log_softmax(fc_out, dim = -1)


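    def train(self, train_loader, epochs, 
              loss_function, optimizer):
        # note: this method shadows nn.Module.train(mode), so calling lstm.eval()
        # (which internally calls self.train(False)) will fail on this class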

        for epoch in range(epochs): # for each epoch

            epoch_loss = 0
            correct_count = 0
            prediction_count = 0

            for index, data in enumerate(train_loader): #one pass over the training data

                X, sequence_lengths, y = data

                optimizer.zero_grad() # reset gradients so they do not accumulate across batches
                output = self.forward(X, sequence_lengths)

                # keep only the prediction at the last real (non-padded) position of each sequence
                output = output[range(len(X)), sequence_lengths-1]

                output = output.view(-1, self.vocab_size)

                y = y.view(-1)

                loss = loss_function(output, y)

                loss.backward()
                #gradient clipping helps avoid blowup, which was a problem with training
                torch.nn.utils.clip_grad_norm_(self.parameters(), max_norm=4)  
                optimizer.step()
   
                #update metrics
                epoch_loss += loss.item()
                _, y_hat = torch.max(output, dim = 1)
                correct_count += (y_hat == y).sum().item()
                prediction_count += y.size(0)


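            # print metrics; note that Loss is the loss of the final batch in the epoch
            # (epoch_loss holds the running total over all batches but is not reported)
            print(f'Epoch [{epoch+1}/{epochs}], Loss: {loss.item():.4f}')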
            print(f'Epoch [{epoch+1}/{epochs}], Accuracy: {correct_count/prediction_count:.4f}')
            print('   ')
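
The indexing step `output[range(len(X)), sequence_lengths-1]` in the training loop is easy to misread, so here is a small standalone sketch of what it selects. The toy batch, its dimensions, and the single-layer LSTM are assumptions chosen for illustration; they are not the notebook's actual settings.

```python
import torch
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

torch.manual_seed(0)

# Toy batch: 3 right-padded sequences of already-embedded tokens (embedding dim 4).
batch = torch.randn(3, 5, 4)
lengths = torch.tensor([5, 2, 3])  # true lengths, kept on the CPU

lstm = torch.nn.LSTM(input_size=4, hidden_size=6, batch_first=True)

# Packing tells the LSTM to skip the padded positions.
packed = pack_padded_sequence(batch, lengths, batch_first=True, enforce_sorted=False)
packed_out, (h_n, c_n) = lstm(packed)
out, out_lengths = pad_packed_sequence(packed_out, batch_first=True)

# For each sequence, pick the output at its last real (non-padded) position.
last = out[torch.arange(out.size(0)), lengths - 1]

print(out.shape)   # torch.Size([3, 5, 6])
print(last.shape)  # torch.Size([3, 6])

# For a unidirectional LSTM, this equals the final hidden state of the top layer.
print(torch.allclose(last, h_n[-1]))  # True
```

Since the selected rows coincide with the final hidden state, an alternative design would apply the output layer to `h_n[-1]` directly instead of unpacking the full output and indexing into it.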




In [6]:

vocab_size = len(word_to_int)
hidden_size = 6
num_layers = 2
loss_function = torch.nn.CrossEntropyLoss(reduction = "mean")
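
One detail worth noting: the model's `forward` already returns `log_softmax` outputs, while `CrossEntropyLoss` applies a log-softmax of its own. Because log-softmax leaves already-normalized log-probabilities unchanged, the resulting loss is the same as on raw scores, and `NLLLoss` would be the more conventional pairing. The sketch below checks this on random tensors; the shapes and targets are arbitrary assumptions.

```python
import torch
from torch.nn import functional as F

torch.manual_seed(0)
logits = torch.randn(4, 10)           # 4 predictions over a 10-word vocabulary
targets = torch.tensor([1, 0, 7, 3])  # next-word labels

log_probs = F.log_softmax(logits, dim=-1)

ce_on_logits = F.cross_entropy(logits, targets)        # standard usage
ce_on_log_probs = F.cross_entropy(log_probs, targets)  # log-softmax applied twice
nll_on_log_probs = F.nll_loss(log_probs, targets)      # conventional pairing

print(torch.allclose(ce_on_logits, ce_on_log_probs, atol=1e-6))
print(torch.allclose(ce_on_logits, nll_on_log_probs, atol=1e-6))
```

Both comparisons should agree up to floating-point error, so the extra log-softmax costs a little computation but does not change the training signal.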

In [7]:

lstm = LSTM(seq_length = len(X[0]), 
            hidden_size = hidden_size, 
            num_layers = num_layers, 
            vocab_size = vocab_size)


optimizer = torch.optim.Adam(params = lstm.parameters(), lr = .001)
lstm.train(torch_data, 100, loss_function, optimizer)




Epoch [1/100], Loss: 3.0755
Epoch [1/100], Accuracy: 0.0793
   
Epoch [2/100], Loss: 2.7707
Epoch [2/100], Accuracy: 0.1847
   
Epoch [3/100], Loss: 2.6553
Epoch [3/100], Accuracy: 0.2484
   
Epoch [4/100], Loss: 2.5764
Epoch [4/100], Accuracy: 0.2663
   
Epoch [5/100], Loss: 2.5018
Epoch [5/100], Accuracy: 0.2715
   
Epoch [6/100], Loss: 2.3714
Epoch [6/100], Accuracy: 0.3112
   
Epoch [7/100], Loss: 2.2223
Epoch [7/100], Accuracy: 0.3421
   
Epoch [8/100], Loss: 2.0735
Epoch [8/100], Accuracy: 0.3480
   
Epoch [9/100], Loss: 1.9747
Epoch [9/100], Accuracy: 0.3574
   
Epoch [10/100], Loss: 1.9227
Epoch [10/100], Accuracy: 0.3686
   
Epoch [11/100], Loss: 1.8898
Epoch [11/100], Accuracy: 0.3733
   
Epoch [12/100], Loss: 1.8638
Epoch [12/100], Accuracy: 0.3783
   
Epoch [13/100], Loss: 1.8406
Epoch [13/100], Accuracy: 0.3870
   
Epoch [14/100], Loss: 1.8178
Epoch [14/100], Accuracy: 0.3964
   
Epoch [15/100], Loss: 1.7944
Epoch [15/100], Accuracy: 0.4063
   
Epoch [16/100], Loss: 1.7706
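
To make the logged loss values easier to interpret, they can be converted to perplexity, which is the exponential of the mean cross-entropy (PyTorch's `CrossEntropyLoss` uses the natural logarithm). A minimal sketch, using three loss values copied from the log above; the helper name `perplexity` is introduced here for illustration.

```python
import math

def perplexity(mean_cross_entropy: float) -> float:
    """Convert a mean cross-entropy loss (in nats) to perplexity."""
    return math.exp(mean_cross_entropy)

# Loss values copied from the training log above.
logged_losses = {1: 3.0755, 8: 2.0735, 16: 1.7706}

for epoch, loss in logged_losses.items():
    print(f"epoch {epoch:>2}: loss {loss:.4f} -> perplexity {perplexity(loss):.2f}")
```

Roughly speaking, a perplexity of k means the model is about as uncertain as if it were choosing uniformly among k candidate next words.
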

Building a Proofs DataSet