## Lab 9 - Differentiable Logic Gate Networks
Zaimplementuj Differentiable Logic Gate Network zgodnie z załączonym artykułem. Przetestuj na wybranym problemie.

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
class DifferentiableLogicGate(nn.Module):
    def __init__(self, temperature=1):
        super().__init__()
        self.gate_logits = nn.Parameter(torch.rand(16))
        self.temperature = temperature

        self.logic_operators = [
            lambda x, y: torch.zeros_like(x),
            lambda x, y: x * y,
            lambda x, y: x - x * y,
            lambda x, y: x,
            lambda x, y: y - x * y,
            lambda x, y: y,
            lambda x, y: x + y - 2 * x * y,
            lambda x, y: x + y - x * y,
            lambda x, y: 1 - (x + y - x * y),
            lambda x, y: 1 - (x + y - 2 * x * y),
            lambda x, y: 1 - y,
            lambda x, y: 1 - y + x * y,
            lambda x, y: 1 - x,
            lambda x, y: 1 - x + x * y,
            lambda x, y: 1 - x * y,
            lambda x, y: torch.ones_like(x)
        ]

    def gate_probability(self):
        return F.softmax(self.gate_logits, dim=0)

    def output(self, x1, x2, probabilities):
        """
        x1: shape [batch_size]
        x2: shape [batch_size]
        probabilities: shape [16]
        """
        if self.training:
            x1 = x1.unsqueeze(1)
            x2 = x2.unsqueeze(1)

            operators = torch.stack([op(x1, x2) for op in self.logic_operators], dim=0)

            probabilities = probabilities.view(-1, 1, 1)

            result_mid = torch.sum(probabilities * operators, dim=0)
            result = result_mid.squeeze(-1)
        else:
            x1 = x1.unsqueeze(1)
            x2 = x2.unsqueeze(1)

            result = self.logic_operators[torch.argmax(probabilities)](x1, x2).squeeze(1)

        return result

    def forward(self, x1, x2):
        probs = self.gate_probability()
        return self.output(x1, x2, probs)

In [3]:
class LogicLayer(nn.Module):
    def __init__(self, input_size, num_gates, temperature=1):
        super().__init__()
        self.num_gates = num_gates
        self.gates = nn.ModuleList([
            DifferentiableLogicGate(temperature) for _ in range(num_gates)
        ])

        # Random unchangeable connections for each Gate
        self.register_buffer('connections', torch.randint(low=0, high=input_size, size=(num_gates, 2)))

    def forward(self, inputs):
        outputs = []

        for i in range(self.num_gates):
            con1, con2 = self.connections[i]
            outputs.append(self.gates[i](inputs[:, con1], inputs[:, con2]))

        return torch.stack(outputs, dim=1)

In [4]:
import numpy as np

class DifferentiableLogicGateNetwork(nn.Module):
    def __init__(self, layer_sizes, n_outputs, temperature, beta):
        super().__init__()
        self.temperature = temperature
        self.beta = beta
        self.n_outputs = n_outputs

        self.layers = nn.ModuleList([
            LogicLayer(layer_sizes[i], layer_sizes[i+1], temperature) for i in range(len(layer_sizes)-1)
        ])

        last_layer_size = layer_sizes[-1]

        kernel_size = last_layer_size // n_outputs
        assert kernel_size * n_outputs == last_layer_size, "last_layer_size must be divisible by n_outputs"
        self.pool = nn.AvgPool1d(kernel_size=kernel_size, stride=kernel_size)

    def aggregate_output_neurons(self, outputs):
        batch_size = outputs.shape[0]
        x = outputs.squeeze(1)
        pooled = self.pool(x)
        aggregated = pooled.squeeze(1)
        scaled = aggregated / self.temperature + self.beta

        if self.training:
            return F.softmax(scaled, dim=-1)
        else:
            max_indices = torch.argmax(scaled, dim=-1)
            one_hot = F.one_hot(max_indices, num_classes=self.n_outputs).float()
            return one_hot

    def forward(self, x):
        layer_output = x
        for layer in self.layers:
            layer_output = layer(layer_output)
        return self.aggregate_output_neurons(layer_output)


In [5]:
%%capture
!pip install datasets

In [6]:
from datasets import load_dataset

ds = load_dataset("ylecun/mnist")

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


README.md:   0%|          | 0.00/6.97k [00:00<?, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/15.6M [00:00<?, ?B/s]

test-00000-of-00001.parquet:   0%|          | 0.00/2.60M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/60000 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/10000 [00:00<?, ? examples/s]

In [7]:
from torch.utils.data import Dataset, DataLoader, Subset
from torchvision import transforms

class MNISTDataset(Dataset):
    def __init__(self, hf_dataset, transform=None, target_transform=None):
        self.dataset = hf_dataset
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        image = self.dataset['image'][idx]
        label = self.dataset['label'][idx]

        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)

        return torch.flatten(image), label

class SubsetDataset(Dataset):
    def __init__(self, subset, transform=None, target_transform=None):
        self.subset = subset
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.subset)

    def __getitem__(self, idx):
        x, y, = self.subset[idx]['image'], self.subset[idx]['label']
        if self.transform:
            x = self.transform(x)
        if self.target_transform:
            y = self.target_transform(y)
        return torch.flatten(x), y

In [8]:
import random

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# Train - 20k elementów
train_ind = list(random.choices(range(60000), k=20000))
# Validation - 500 elementów
val_ind = list(range(0, 500))
# Test - 2k elementów
test_ind = list(random.choices(range(500, 10000), k=2000))

train_set = Subset(ds['train'], train_ind)
val_set = Subset(ds['test'], val_ind)
test_set = Subset(ds['test'], test_ind)
train_dataset = SubsetDataset(train_set, transform=transform)
val_dataset = SubsetDataset(val_set, transform=transform)
test_dataset = SubsetDataset(test_set, transform=transform)

train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=64, shuffle=True)

In [9]:
def train_model(device, model, train_loader, val_loader, loss_fn, optimizer, num_epoch):
    train_losses = []
    val_losses = []

    print('Starting training')
    for epoch in range(num_epoch):
        model.train()
        train_loss = 0.0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)

            outputs = outputs.to(torch.float32)

            labels_one_hot = F.one_hot(labels, num_classes=10).float()

            loss = loss_fn(outputs, labels_one_hot)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()#*inputs.size(0)

        model.eval()
        val_loss = 0.0
        correct = 0
        total = 0

        with torch.inference_mode():
            for inputs, labels in val_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                outputs = outputs.unsqueeze(1)
                labels_one_hot = F.one_hot(labels, num_classes=10).float()
                loss = loss_fn(outputs.squeeze(1), labels_one_hot)

                val_loss += loss.item()*inputs.size(0)

                outputs = outputs.squeeze(1)
                _, predicted = torch.max(outputs.data, 1)
                #_, true_labels = torch.max(labels, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        val_loss /= len(val_loader.dataset)

        val_accuracy = 100 * correct / total
        val_losses.append(val_loss)

        print(f'Epoch [{epoch+1}/{num_epoch}], '
              f'Train Loss: {train_loss:.4f}, '
              f'Validation Loss: {val_loss:.4f}, '
              f'Validation Accuracy: {val_accuracy:.2f}%')

    return train_losses, val_losses

In [10]:
def test_model(device, model, test_loader, loss_fn):
    model.eval()
    test_loss = 0.0
    correct = 0
    total = 0

    with torch.inference_mode():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            outputs = model(inputs)
            outputs = outputs.unsqueeze(1)
            labels_one_hot = F.one_hot(labels, num_classes=10).float()
            loss = loss_fn(outputs.squeeze(1), labels_one_hot)
            test_loss += loss.item() * inputs.size(0)

            outputs = outputs.squeeze(1)
            _, predicted = torch.max(outputs.data, 1)
            #_, true_labels = torch.max(labels, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    test_loss /= len(test_loader.dataset)
    test_accuracy = 100 * correct / total

    print(f'Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.2f}%')
    return test_loss, test_accuracy

In [12]:
device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)

In [13]:
device

'cuda'

In [14]:
net = DifferentiableLogicGateNetwork([784, 512, 256, 128, 40], 10, 0.5, 1)
optimizer = torch.optim.Adam(net.parameters(), lr=0.15)
loss_fn = torch.nn.BCELoss()

net.to(device)

train_losses, val_losses = train_model(device, net, train_dataloader, val_dataloader, loss_fn, optimizer, 20)

Starting training
Epoch [1/20], Train Loss: 62.9458, Validation Loss: 13.2800, Validation Accuracy: 33.60%
Epoch [2/20], Train Loss: 52.6344, Validation Loss: 12.2400, Validation Accuracy: 38.80%
Epoch [3/20], Train Loss: 50.8949, Validation Loss: 11.6000, Validation Accuracy: 42.00%
Epoch [4/20], Train Loss: 50.0197, Validation Loss: 11.1600, Validation Accuracy: 44.20%
Epoch [5/20], Train Loss: 49.6357, Validation Loss: 14.2400, Validation Accuracy: 28.80%
Epoch [6/20], Train Loss: 49.3356, Validation Loss: 14.5200, Validation Accuracy: 27.40%
Epoch [7/20], Train Loss: 49.3013, Validation Loss: 12.6000, Validation Accuracy: 37.00%
Epoch [8/20], Train Loss: 49.3397, Validation Loss: 11.8400, Validation Accuracy: 40.80%
Epoch [9/20], Train Loss: 49.3423, Validation Loss: 12.2400, Validation Accuracy: 38.80%
Epoch [10/20], Train Loss: 49.1489, Validation Loss: 11.2400, Validation Accuracy: 43.80%
Epoch [11/20], Train Loss: 49.0275, Validation Loss: 13.2000, Validation Accuracy: 34.00%
E

In [15]:
test_loss, test_accuracy = test_model(device, net, test_dataloader, loss_fn)

Test Loss: 10.7400, Test Accuracy: 46.30%
