In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt

In [None]:
# Binary boolean operators that define the chained-logic task.
# Operands are expected to be the integers 0/1. Python's `~n` evaluates to
# -(n + 1), so every operator that negates something masks with `& 1` to land
# back in {0, 1}.
LOGIC_OPERATIONS = dict(
    # AND: 1 only when both x and y are 1.
    AND = lambda x, y: x & y,
    # OR: 1 when at least one of x, y is 1.
    OR = lambda x, y: x | y,
    # XOR: 1 when x and y differ.
    XOR = lambda x, y: x ^ y,
    # NAND: negation of AND; 0 only when both are 1.
    NAND = lambda x, y: ~(x & y) & 1,
    # NOR: negation of OR; 1 only when both are 0.
    NOR = lambda x, y: ~(x | y) & 1,
    # XNOR (equivalence): 1 when x and y are equal.
    XNOR = lambda x, y: ~(x ^ y) & 1,
    # IMPLIES (x -> y): 0 only when x is 1 and y is 0.
    IMPLIES = lambda x, y: (~x | y) & 1,
    # REVERSE_IMPLIES (y -> x): 0 only when y is 1 and x is 0.
    REVERSE_IMPLIES = lambda x, y: (x | ~y) & 1,
    # XQ (converse nonimplication, "not (y -> x)"): 1 only when x is 0 and y is 1.
    XQ = lambda x, y: (~x & y) & 1,
    # ABJ (material nonimplication / abjunction, "not (x -> y)"): 1 only when x is 1 and y is 0.
    ABJ = lambda x, y: (x & ~y) & 1,
)

In [None]:
# Print the 2x2 truth table of every operator as "x y op(x, y)" rows,
# followed by a blank separator line.
for op_name, op in LOGIC_OPERATIONS.items():
    print(op_name)
    for x, y in [(a, b) for a in (0, 1) for b in (0, 1)]:
        print(x, y, op(x, y))
    print()

AND
0 0 0
0 1 0
1 0 0
1 1 1

OR
0 0 0
0 1 1
1 0 1
1 1 1

XOR
0 0 0
0 1 1
1 0 1
1 1 0

NAND
0 0 1
0 1 1
1 0 1
1 1 0

NOR
0 0 1
0 1 0
1 0 0
1 1 0

XNOR
0 0 1
0 1 0
1 0 0
1 1 1

IMPLIES
0 0 1
0 1 1
1 0 0
1 1 1

REVERSE_IMPLIES
0 0 1
0 1 0
1 0 1
1 1 1

XQ
0 0 0
0 1 1
1 0 0
1 1 0

ABJ
0 0 0
0 1 0
1 0 1
1 1 0



In [None]:
class LogicDataset(torch.utils.data.Dataset):
    """
    A dataset of chained two-input logic problems.

    Each sample draws two bits ``s1, s2`` and ``num_operations`` operation
    indices into ``LOGIC_OPERATIONS``. The label folds the operations over a
    running result that starts at ``s1`` and always uses ``s2`` as the second
    operand: ``r = s1``, then ``r = OP_k(r, s2)`` for each sampled ``OP_k`` in order.

    The model input for a sample is ``[s1, s2, onehot(OP_1), ..., onehot(OP_n)]``
    (length ``2 + num_operations * len(LOGIC_OPERATIONS)``), repeated ``seq_len``
    times along a time axis, so every step of a recurrent model sees the whole
    problem.

    Attributes populated by ``generate_data``:
        data:       float32, (num_samples, seq_len, 2 + num_operations * num_op_types)
        labels:     int64, (num_samples,), the final running result (0 or 1)
        symbols:    int64, (num_samples, 2), the sampled ``[s1, s2]``
        operations: int64, (num_samples, num_operations), the sampled operation indices
    """

    def __init__(
            self,
            num_samples = 10000,
            num_operations = 10,
            seq_len = 10,
            seed = None,
        ):
        """
        Initialize the dataset and generate all samples eagerly.

        Args:
            num_samples: number of problems to generate.
            num_operations: number of chained operations per problem.
            seq_len: number of time steps the input vector is repeated for.
            seed: optional integer seed for a private ``np.random.RandomState``.
                ``None`` (default) draws from numpy's global RNG, as before.
        """

        self.num_samples = num_samples
        self.num_operations = num_operations
        self.seq_len = seq_len
        self.operation_list = list(LOGIC_OPERATIONS.keys())
        # The np.random module and RandomState expose the same choice/randint API,
        # so seed=None reproduces the exact global-RNG draws of the unseeded version.
        self.rng = np.random if seed is None else np.random.RandomState(seed)

        # generate data
        self.generate_data()


    def generate_data(self):
        """
        Sample ``num_samples`` problems and (re)populate ``data``, ``labels``,
        ``symbols`` and ``operations``.

        Raises:
            ValueError: if an operation yields a value other than 0 or 1.
        """

        num_types = len(self.operation_list)
        symbols = np.zeros((self.num_samples, 2), dtype = np.int64)
        operations = np.zeros((self.num_samples, self.num_operations), dtype = np.int64)
        labels = np.zeros(self.num_samples, dtype = np.int64)

        for i in range(self.num_samples):
            # sample initial variables
            s1, s2 = self.rng.choice([0, 1], size = 2)

            # sample operations
            operation_indices = self.rng.randint(0, num_types, size = self.num_operations)

            # compute ground truth output by folding the operations over s1
            result = int(s1)
            for idx in operation_indices:
                result = int(LOGIC_OPERATIONS[self.operation_list[idx]](result, s2))
                if result not in (0, 1):
                    raise ValueError('Result is not boolean.')

            symbols[i] = (s1, s2)
            operations[i] = operation_indices
            labels[i] = result

        # Encode all samples at once. F.one_hot needs an int64 index tensor; the
        # int64 buffers above guarantee that regardless of numpy's platform default int.
        operations_t = torch.from_numpy(operations)
        one_hot = F.one_hot(operations_t, num_classes = num_types).float()  # (num_samples, num_operations, num_types)
        # explicit sizes (not -1) so num_samples == 0 still reshapes cleanly
        one_hot = one_hot.reshape(self.num_samples, self.num_operations * num_types)
        features = torch.cat([torch.from_numpy(symbols).float(), one_hot], dim = 1)  # (num_samples, 2 + num_operations * num_types)

        # repeat the same feature vector at every time step
        self.data = features.unsqueeze(1).repeat(1, self.seq_len, 1)  # (num_samples, seq_len, feature_size)
        self.labels = torch.from_numpy(labels)  # (num_samples,)
        self.symbols = torch.from_numpy(symbols)  # (num_samples, 2)
        self.operations = operations_t  # (num_samples, num_operations)


    def __len__(self):
        return self.num_samples


    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]

In [None]:
# Build a large dataset once and show the tensor shapes it produces.
dataset = LogicDataset(num_samples = 30000, num_operations = 10, seq_len = 10)
print(dataset.data.shape, dataset.labels.shape, sep = '\n')

torch.Size([30000, 10, 102])
torch.Size([30000])


In [None]:
class LogicGRU(nn.Module):
    """
    A single-layer GRU followed by a linear head that emits 2-way logits at
    every time step.
    """

    def __init__(
            self,
            input_size,
            hidden_size = 128,
        ):
        """
        Initialize the network.

        Args:
            input_size: number of features per time step.
            hidden_size: size of the GRU hidden state.
        """

        super().__init__()
        self.hidden_size = hidden_size
        self.gru = nn.GRU(input_size, hidden_size, batch_first = True)
        self.fc = nn.Linear(hidden_size, 2)


    def forward(self, x):
        """
        Forward the network.

        Args:
            x: a torch.tensor with a shape of (batch_size, seq_len, input_size).

        Returns:
            outputs: logits, a torch.tensor with a shape of (batch_size, seq_len, 2)
            hiddens: GRU states, a torch.tensor with a shape of (batch_size, seq_len, hidden_size)
        """

        batch_size = x.size(0)
        # Zero initial state created on x's device and dtype. A bare torch.zeros(...)
        # is always CPU float32, which breaks GPU or float64 execution.
        hidden_init = x.new_zeros(1, batch_size, self.hidden_size) # (num_layers, batch_size, hidden_size)
        hiddens, _ = self.gru(x, hidden_init) # (batch_size, seq_len, hidden_size)
        outputs = self.fc(hiddens) # (batch_size, seq_len, 2)

        return outputs, hiddens

In [None]:
class Trainer:
    """
    A minimal supervised trainer for models whose forward returns
    ``(outputs, hiddens)`` with ``outputs`` shaped (batch, seq_len, num_classes).

    Only the logits of the last time step are scored with cross-entropy.
    """

    def __init__(
            self,
            model,
            train_loader,
            lr = 1e-3,
            device = 'cpu'
        ):
        """
        Initialize the trainer.

        Args:
            model: the network to optimise; moved to ``device``.
            train_loader: iterable of ``(inputs, targets)`` batches.
            lr: Adam learning rate.
            device: device the model and each batch are moved to.
        """

        self.model = model.to(device)
        self.train_loader = train_loader
        self.device = device
        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = optim.Adam(self.model.parameters(), lr = lr)
        # mean training loss of every completed epoch, across all train() calls
        self.history = []


    def train(self, num_epochs):
        """
        Train the network for ``num_epochs`` epochs.

        After each epoch it prints the epoch's mean loss and appends it to
        ``self.history``. The mean is weighted by batch size, so a short final
        batch does not skew it.

        Raises:
            ValueError: if ``train_loader`` yields no samples in an epoch.
        """

        # training mode in case an earlier evaluation switched the model to eval()
        self.model.train()

        for epoch in range(num_epochs):
            total_loss = 0.0
            num_seen = 0
            for inputs, targets in self.train_loader:
                inputs, targets = inputs.to(self.device), targets.to(self.device)

                self.optimizer.zero_grad()
                outputs, _ = self.model(inputs) # (batch_size, seq_len, 2)
                loss = self.criterion(outputs[:, -1, :], targets)
                loss.backward()
                self.optimizer.step()

                # CrossEntropyLoss averages over the batch; re-weight by batch size
                batch_size = targets.size(0)
                total_loss += loss.item() * batch_size
                num_seen += batch_size

            if num_seen == 0:
                raise ValueError('train_loader yielded no samples.')

            epoch_loss = total_loss / num_seen
            self.history.append(epoch_loss)
            print(f'Epoch {epoch}, Loss: {epoch_loss:.4f}')

In [None]:
# Seed numpy (dataset sampling) and torch (weight init, loader shuffling) so
# this training run is reproducible on a fresh kernel.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

# initialize dataset and dataloader
dataset = LogicDataset(
    num_samples = 10000,
    num_operations = 10,
    seq_len = 10,
)
train_loader = torch.utils.data.DataLoader(dataset, batch_size = 1024, shuffle = True)

# define model and trainer; the per-step feature width is read from the data
# (last axis of dataset.data) rather than hard-coded.
model = LogicGRU(
    input_size = dataset.data.shape[-1],
    hidden_size = 128,
)
trainer = Trainer(
    model = model,
    train_loader = train_loader,
    lr = 1e-3,
)

# train the model
trainer.train(num_epochs = 200)

Epoch 0, Loss: 0.6847
Epoch 1, Loss: 0.6446
Epoch 2, Loss: 0.5989
Epoch 3, Loss: 0.5892
Epoch 4, Loss: 0.5871
Epoch 5, Loss: 0.5843
Epoch 6, Loss: 0.5834
Epoch 7, Loss: 0.5807
Epoch 8, Loss: 0.5767
Epoch 9, Loss: 0.5691
Epoch 10, Loss: 0.5442
Epoch 11, Loss: 0.4754
Epoch 12, Loss: 0.4434
Epoch 13, Loss: 0.4214
Epoch 14, Loss: 0.4037
Epoch 15, Loss: 0.3904
Epoch 16, Loss: 0.3747
Epoch 17, Loss: 0.3548
Epoch 18, Loss: 0.3421
Epoch 19, Loss: 0.3205
Epoch 20, Loss: 0.3131
Epoch 21, Loss: 0.2983
Epoch 22, Loss: 0.2752
Epoch 23, Loss: 0.2557
Epoch 24, Loss: 0.2425
Epoch 25, Loss: 0.2293
Epoch 26, Loss: 0.2289
Epoch 27, Loss: 0.2246
Epoch 28, Loss: 0.2079
Epoch 29, Loss: 0.2041
Epoch 30, Loss: 0.1970
Epoch 31, Loss: 0.1994
Epoch 32, Loss: 0.1785
Epoch 33, Loss: 0.1669
Epoch 34, Loss: 0.1705
Epoch 35, Loss: 0.1914
Epoch 36, Loss: 0.1648
Epoch 37, Loss: 0.1462
Epoch 38, Loss: 0.1324
Epoch 39, Loss: 0.1277
Epoch 40, Loss: 0.1441
Epoch 41, Loss: 0.1388
Epoch 42, Loss: 0.1226
Epoch 43, Loss: 0.122

In [None]:
eval_dataset = LogicDataset(
    num_samples = 10000,
    num_operations = 10,
    seq_len = 10,
)
# no shuffling for evaluation: batch order stays aligned with eval_dataset
eval_loader = torch.utils.data.DataLoader(eval_dataset, batch_size = 1024, shuffle = False)

model.eval()
model_device = next(model.parameters()).device
with torch.no_grad():
    # LogicDataset stores its tensors directly; it has no `.dataset` attribute.
    inputs = eval_dataset.data
    targets = eval_dataset.labels

    outputs, hiddens = model(inputs.to(model_device)) # (num_samples, seq_len, 2) / (num_samples, seq_len, hidden_size)
    # keep results on CPU so later cells can plot them directly
    outputs, hiddens = outputs.cpu(), hiddens.cpu()

    probs = torch.softmax(outputs, dim = -1)
    idxs = torch.argmax(probs, dim = -1) # (num_samples, seq_len)

# the model is trained on the last time step only, so score that step
accuracy = (idxs[:, -1] == targets).float().mean().item()
print(idxs.shape)
print(f'Final-step accuracy: {accuracy:.4f}')

torch.Size([10000, 10, 2])
torch.Size([10000, 10, 128])


In [None]:
# 10x10 grid: per-time-step P(label = 1) for the first 100 samples, with the
# 0.5 decision threshold drawn as a dashed line.
fig, axes = plt.subplots(10, 10, figsize = (10, 10))
for sample_idx, ax in enumerate(axes.flat):
    ax.plot(probs[sample_idx, :, 1], marker = 'o', markersize = 3)
    ax.axhline(y = 0.5, color = 'k', linestyle = '--', linewidth = 2)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_ylim((-0.1, 1.1))
fig.tight_layout()
plt.show()

In [None]:
operation_list = list(LOGIC_OPERATIONS.keys())

# For the first 10 samples, print the running result after each operation
# (s1 folded through the sampled operations, with s2 as the second operand).
for sample_idx in range(10):
    first, second = dataset.symbols[sample_idx].tolist()
    running = first
    trace = []
    for op_idx in dataset.operations[sample_idx].tolist():
        running = int(LOGIC_OPERATIONS[operation_list[op_idx]](running, second))
        trace.append(running)

    print(trace)

In [None]:
# Overlay the model's per-step P(label = 1) on the ground-truth running result
# of each operation chain, for the first 100 *evaluation* samples. `probs` was
# computed on eval_dataset, so the ground truth must come from eval_dataset too.
# Note: the model has seq_len steps and the trace has num_operations steps;
# the two only line up one-to-one when those settings are equal.
operation_list = list(LOGIC_OPERATIONS.keys())

plt.figure(figsize = (10, 10))
for i in range(100):
    plt.subplot(10, 10, i + 1)
    s1, s2 = eval_dataset.symbols[i].tolist()
    operation_indices = eval_dataset.operations[i].tolist()
    results = []
    result = s1
    for idx in operation_indices:
        result = int(LOGIC_OPERATIONS[operation_list[idx]](result, s2))
        results.append(result)

    plt.plot(probs[i, :, 1], marker = 'o', markersize = 3, color = 'tab:blue')
    plt.plot(results, alpha = 0.6, color = 'tab:orange')
    plt.axhline(y = 0.5, color = 'k', linestyle = '--', linewidth = 2)
    plt.xticks([])
    plt.yticks([])
    plt.ylim((-0.1, 1.1))

plt.suptitle('Model P(label = 1) per step (blue) vs. ground-truth running result (orange)')
plt.tight_layout()
plt.show()

In [None]:
# Count how often each sample's predicted P(label = 1) crosses the 0.5
# decision threshold between consecutive time steps.
above = probs[:, :, 1] > 0.5  # (num_samples, seq_len), bool
crossings = above[:, 1:] != above[:, :-1]  # (num_samples, seq_len - 1), bool
num_crossings = crossings.sum(dim = 1)  # (num_samples,)

# one bin per possible integer count 0 .. seq_len - 1, centred on the integer
max_crossings = above.shape[1] - 1
bins = np.arange(max_crossings + 2) - 0.5

fig, ax = plt.subplots()
ax.hist(num_crossings.cpu().numpy(), bins = bins)
ax.set_xlabel('threshold crossings per sample')
ax.set_ylabel('number of samples')
ax.set_title('Crossings of P(label = 1) = 0.5 across time steps')
plt.show()