In [3]:
!pip install einops
!pip install tqdm

[0m

In [11]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from difflogic import LogicLayer, GroupSum
import einops
import time
import numpy as np
from tqdm import tqdm
from typing import Literal
import os

In [15]:
# Type definitions
InitializationType = Literal['residual', 'random']

# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Hyperparameters
batch_size = 100
learning_rate = 0.01
num_epochs = 100
k = 16  # Base number of kernels (from paper: k=16 for small model)

print(f"Base kernel count k = {k}")
print(f"Expected shapes from paper:")
print(f"After conv1 + pool1: {k} × 12 × 12")
print(f"After conv2 + pool2: {3*k} × 6 × 6") 
print(f"After conv3 + pool3: {9*k} × 3 × 3")
print(f"After flattening: {81*k}")

# Data loading and preprocessing
transform = transforms.Compose([
    transforms.ToTensor(),
])

train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=0)

print(f"Training samples: {len(train_dataset)}")
print(f"Test samples: {len(test_dataset)}")

# Logic gate definitions
logic_gates = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]

def apply_logic_gate(a: torch.Tensor, b: torch.Tensor, logic_gate: int):
    return {
        0:  torch.zeros_like(a),
        1:  a * b,
        2:  a - a * b,
        3:  a,
        4:  b - a * b,
        5:  b,
        6:  a + b - 2 * a * b,
        7:  a + b - a * b,
        8:  1 - (a + b - a * b),
        9:  1 - (a + b - 2 * a * b),
        10: 1 - b,
        11: 1 - b + a * b,
        12: 1 - a,
        13: 1 - a + a * b,
        14: 1 - a * b,
        15: torch.ones_like(a),
    }[logic_gate]

class Logic(nn.Module):
    def __init__(self,
                 in_dim: int,
                 out_dim: int,
                 initialization_type: InitializationType = 'residual',
                 device=None
                 ):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.initialization_type = initialization_type
        self.device = device or torch.device('cpu')
        
        a, b = self.get_connections()
        self.register_buffer('a', a)
        self.register_buffer('b', b)
        
        weights = torch.randn(out_dim, len(logic_gates), device=self.device)
        if self.initialization_type == 'residual':
            weights[:, :] = 0
            weights[:, 3] = 5  # Initialize to identity gate
        self.weights = torch.nn.parameter.Parameter(weights)

    def forward(self, x: torch.Tensor):
        a, b = x[:, self.a, ...], x[:, self.b, ...]
        
        if self.training:
            normalized_weights = torch.nn.functional.softmax(self.weights, dim=-1).to(x.dtype).to(self.device)
            r = torch.zeros_like(a).to(x.dtype).to(self.device)
            for logic_gate in logic_gates:
                if len(a.shape) > 2:
                    nw = einops.repeat(normalized_weights[..., logic_gate], 'weights -> weights depth', depth=a.shape[-1])
                else:
                    nw = normalized_weights[..., logic_gate]
                r = r + nw * apply_logic_gate(a, b, logic_gate)
            return r
        else:
            one_hot_weights = torch.nn.functional.one_hot(self.weights.argmax(-1), len(logic_gates)).to(torch.float32).to(self.device)
            with torch.no_grad():
                r = torch.zeros_like(a).to(x.dtype).to(self.device)
                for logic_gate in logic_gates:
                    if len(a.shape) > 2:
                        ohw = einops.repeat(one_hot_weights[..., logic_gate], 'weights -> weights depth', depth=a.shape[-1])
                    else:
                        ohw = one_hot_weights[..., logic_gate]
                    r = r + ohw * apply_logic_gate(a, b, logic_gate)
                return r

    def get_connections(self):
        connections = torch.randperm(2 * self.out_dim) % self.in_dim
        connections = torch.randperm(self.in_dim)[connections]
        connections = connections.reshape(2, self.out_dim)
        a, b = connections[0], connections[1]
        a, b = a.to(torch.int64), b.to(torch.int64)
        a, b = a.to(self.device), b.to(self.device)
        return a, b

class LogicTree(nn.Module):
    def __init__(self,
                 in_dim: int,
                 depth: int = 3,
                 initialization_type: InitializationType = 'residual',
                 device=None,
                 ):
        super().__init__()
        self.device = device or torch.device('cpu')
        
        layers = [LogicLayer(in_dim, int(2 ** (depth - 1)), initialization_type=initialization_type, device=self.device,implementation='cuda' if device.type == 'cuda' else 'python',
                    connections='random',grad_factor=1.5 )]
        for i in range(0, depth - 1, 1):
            layers.append(LogicLayer(int(2 ** (depth - 1 - i)), int(2 ** (depth - 1 - i - 1)), 
                            initialization_type=initialization_type, device=self.device,implementation='cuda' if device.type == 'cuda' else 'python',
                            connections='random',grad_factor=1.5))
        
        self.tree = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor):
        return self.tree(x)

class Conv(nn.Module):
    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 depth: int = 3,
                 kernel_size: int = 3,
                 stride: int = 1,
                 padding: int = 1,
                 initialization_type: InitializationType = 'residual',
                 device=None
                 ):
        super().__init__()
        self.kernel_size = kernel_size
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.stride = stride
        self.padding = padding
        self.device = device or torch.device('cpu')
        
        self.filters = nn.ModuleList([
            LogicTree(in_dim=kernel_size ** 2 * in_channels, depth=depth, 
                     initialization_type=initialization_type, device=self.device) 
            for _ in range(out_channels)
        ])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        batch_size, _, height, width = x.shape
        x = F.pad(x, (self.padding, self.padding, self.padding, self.padding), mode='constant', value=0)
        out_height = (height + 2 * self.padding - self.kernel_size) // self.stride + 1
        out_width = (width + 2 * self.padding - self.kernel_size) // self.stride + 1
        patches = F.unfold(x, kernel_size=self.kernel_size, stride=self.stride)
        outputs = []
        
        patches = einops.rearrange(patches, 'b h w -> (b w) h', h=patches.shape[1], w=patches.shape[2]) # Input is (100,25,576) Output: (57600,25)
        for filter in self.filters:
            out = filter(patches)  # Shape: (batch_size, 1, out_height * out_width)
            out = einops.rearrange(out, '(b h w) 1 -> b (h w)', h=out_height, w=out_width)
            outputs.append(out)
        
        output_tensor = torch.stack(outputs, dim=1)  # Shape: (batch_size, out_channels, out_height * out_width)
        output_tensor = einops.rearrange(output_tensor, 'b c (h w) -> b c h w', h=out_height, w=out_width)
        return output_tensor

class CustomOrPool2d(nn.Module):
    def __init__(self, kernel_size=2, stride=2, padding=0):
        super(CustomOrPool2d, self).__init__()
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        
    def forward(self, x):
        # Use MaxPool2d as approximation to OR pooling
        # In binary logic, max operation approximates OR
        return torch.max_pool2d(x, self.kernel_size, self.stride, self.padding)



class ConvDiffLogicMNIST(nn.Module):
    def __init__(self, k=16):
        super(ConvDiffLogicMNIST, self).__init__()
        self.k = k
        
        # Convolutional block 1: k kernels, 5x5, depth=3, no padding
        # Input: 1 × 28 × 28 -> Output: k × 24 × 24 (28-5+1=24)
        self.conv1 = Conv(in_channels=1, out_channels=k, kernel_size=5, depth=3, 
                         padding=0, initialization_type='residual', device=device)
        
        # OR pooling 1: 2x2, stride 2
        # k × 24 × 24 -> k × 12 × 12
        self.pool1 = CustomOrPool2d(kernel_size=2, stride=2)
        
        # Convolutional block 2: 3*k kernels, 3x3, depth=3
        # k × 12 × 12 -> 3*k × 12 × 12 (with padding=1), then pooled to 3*k × 6 × 6
        self.conv2 = Conv(in_channels=k, out_channels=3*k, kernel_size=3, depth=3, 
                         padding=1, initialization_type='residual', device=device)
        
        # OR pooling 2: 2x2, stride 2
        # 3*k × 12 × 12 -> 3*k × 6 × 6
        self.pool2 = CustomOrPool2d(kernel_size=2, stride=2)
        
        # Convolutional block 3: 9*k kernels, 3x3, depth=3
        # 3*k × 6 × 6 -> 9*k × 6 × 6 (with padding=1), then pooled to 9*k × 3 × 3
        self.conv3 = Conv(in_channels=3*k, out_channels=9*k, kernel_size=3, depth=3, 
                         padding=1, initialization_type='residual', device=device)
        
        # OR pooling 3: 2x2, stride 2
        # 9*k × 6 × 6 -> 9*k × 3 × 3
        self.pool3 = CustomOrPool2d(kernel_size=2, stride=2)
        
        # Flatten: 9*k × 3 × 3 -> 81*k
        self.flatten = nn.Flatten()
        
        # Regular differentiable logic layers (as specified in paper)
        # 81*k → 1280*k
        self.fc1 = LogicLayer(
            in_dim=81*k,
            out_dim=1280*k,
            device=device,
            implementation='cuda' if device.type == 'cuda' else 'python',
            connections='random',
            grad_factor=1.5 , # Higher for deeper networks
        )
        
        # 1280*k → 640*k
        self.fc2 = LogicLayer(
            in_dim=1280*k,
            out_dim=640*k,
            device=device,
            implementation='cuda' if device.type == 'cuda' else 'python',
            connections='random',
            grad_factor=1.5
        )
        
        # 640*k → 320*k
        self.fc3 = LogicLayer(
            in_dim=640*k,
            out_dim=320*k,
            device=device,
            implementation='cuda' if device.type == 'cuda' else 'python',
            connections='random',
            grad_factor=1.5
        )
        
        # GroupSum: 320*k → 10 (10 classes)
        # Using tau=30 as in the paper specifications
        self.group_sum = GroupSum(k=10, tau=30)
        
    def forward(self, x):
        # Input thresholding for binary processing (as mentioned in paper)
        # The paper mentions using binary inputs
        x = (x > 0.5).float()
        
        # Debug shape printing (uncomment for debugging)
        # print(f"Input shape: {x.shape}")
        
        # Convolutional processing with logic gates
        x = self.conv1(x)
        # print(f"After conv1: {x.shape}")
        
        x = self.pool1(x)
        # print(f"After pool1: {x.shape}")
        
        x = self.conv2(x)
        # print(f"After conv2: {x.shape}")
        
        x = self.pool2(x)
        # print(f"After pool2: {x.shape}")
        
        x = self.conv3(x)
        # print(f"After conv3: {x.shape}")
        
        x = self.pool3(x)
        # print(f"After pool3: {x.shape}")
        
        # Flatten
        x = self.flatten(x)
        # print(f"After flatten: {x.shape}")
        
        # Fully connected logic layers
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.fc3(x)
        
        # GroupSum for classification
        x = self.group_sum(x)
        
        return x

# Initialize model
model = ConvDiffLogicMNIST(k=k).to(device)
print(f"Model created with {sum(p.numel() for p in model.parameters())} parameters")

# Print architecture details
print("\n" + "="*80)
print("LOGIC GATE CONVOLUTIONAL DIFFLOGIC MNIST")
print("="*80)
print(f"Input: 1 × 28 × 28")
print(f"Conv1: {k} logic gate filters, 5×5, depth=3, no padding -> {k} × 24 × 24")
print(f"Pool1: OR pooling 2×2, stride 2 -> {k} × 12 × 12")
print(f"Conv2: {3*k} logic gate filters, 3×3, depth=3 -> {3*k} × 12 × 12")
print(f"Pool2: OR pooling 2×2, stride 2 -> {3*k} × 6 × 6")
print(f"Conv3: {9*k} logic gate filters, 3×3, depth=3 -> {9*k} × 6 × 6")
print(f"Pool3: OR pooling 2×2, stride 2 -> {9*k} × 3 × 3")
print(f"Flatten: -> {81*k}")
print(f"FC1: Regular differentiable logic layer {81*k} -> {1280*k}")
print(f"FC2: Regular differentiable logic layer {1280*k} -> {640*k}")
print(f"FC3: Regular differentiable logic layer {640*k} -> {320*k}")
print(f"GroupSum: {320*k} -> 10 classes")
print("="*80)

# Loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Training function
def train_epoch(model, train_loader, criterion, optimizer, device):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    
    progress_bar = tqdm(train_loader, desc='Training')
    for batch_idx, (data, target) in enumerate(progress_bar):
        data, target = data.to(device), target.to(device)
        
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
        _, predicted = torch.max(output.data, 1)
        total += target.size(0)
        correct += (predicted == target).sum().item()
        
        progress_bar.set_postfix({
            'Loss': f'{running_loss/(batch_idx+1):.4f}',
            'Acc': f'{100.*correct/total:.2f}%'
        })
    
    return running_loss / len(train_loader), 100. * correct / total

# Evaluation function
def evaluate(model, test_loader, criterion, device):
    model.eval()
    test_loss = 0.0
    correct = 0
    total = 0
    
    with torch.no_grad():
        progress_bar = tqdm(test_loader, desc='Testing')
        for data, target in progress_bar:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += criterion(output, target).item()
            
            _, predicted = torch.max(output.data, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()
            
            progress_bar.set_postfix({
                'Acc': f'{100.*correct/total:.2f}%'
            })
    
    return test_loss / len(test_loader), 100. * correct / total

if __name__ == '__main__':
    print("Starting training with logic gate convolutions...")
    train_losses = []
    train_accuracies = []
    test_losses = []
    test_accuracies = []

    best_test_acc = 0.0
    start_time = time.time()

    for epoch in range(num_epochs):
        print(f"\nEpoch {epoch+1}/{num_epochs}")
        print("-" * 50)
        
        # Train
        train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)
        train_losses.append(train_loss)
        train_accuracies.append(train_acc)
        
        # Evaluate
        test_loss, test_acc = evaluate(model, test_loader, criterion, device)
        test_losses.append(test_loss)
        test_accuracies.append(test_acc)
        
        # Save best model
        if test_acc > best_test_acc:
            best_test_acc = test_acc
            torch.save(model.state_dict(), 'best_logic_conv_difflogic_mnist.pth')
            print(f"New best model saved! Test accuracy: {test_acc:.2f}%")
        
        print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%")
        print(f"Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.2f}%")

    total_time = time.time() - start_time
    print(f"\nTraining completed in {total_time:.2f} seconds")
    print(f"Best test accuracy: {best_test_acc:.2f}%")

    # Test discrete inference (switch to hard logic gates)
    print("\nTesting discrete inference...")
    model.eval()  # This switches to discrete/hard logic mode

    start_time = time.time()
    correct = 0
    total = 0

    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            _, predicted = torch.max(output.data, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()

    inference_time = time.time() - start_time
    inference_speed = inference_time / total

    print(f"Discrete inference accuracy: {100. * correct / total:.2f}%")
    print(f"Inference speed: {inference_speed:.6f} seconds per sample")

    # Save final model
    torch.save(model.state_dict(), 'final_logic_conv_difflogic_mnist.pth')

    # Final summary
    print("\n" + "="*80)
    print("LOGIC GATE CONVOLUTIONAL DIFFLOGIC MNIST RESULTS")
    print("="*80)
    print(f"Architecture: Custom logic gate convolutions + DiffLogic FC layers")
    print(f"Base kernel count (k): {k}")
    print(f"Total parameters: {sum(p.numel() for p in model.parameters()):,}")
    print(f"Training epochs: {num_epochs}")
    print(f"Best test accuracy: {best_test_acc:.2f}%")
    print(f"Final test accuracy: {test_accuracies[-1]:.2f}%")
    print(f"Inference speed: {inference_speed:.6f} seconds per sample")
    print()
    print("Logic gate convolutions replace traditional conv layers with:")
    print("- 16 different logic gates (AND, OR, XOR, NOT, etc.)")
    print("- Tree-structured logic processing with configurable depth")
    print("- Soft logic during training, hard logic during inference")
    print("="*80)

Using device: cuda
Base kernel count k = 16
Expected shapes from paper:
After conv1 + pool1: 16 × 12 × 12
After conv2 + pool2: 48 × 6 × 6
After conv3 + pool3: 144 × 3 × 3
After flattening: 1296
Training samples: 60000
Test samples: 10000
Model created with 596736 parameters

LOGIC GATE CONVOLUTIONAL DIFFLOGIC MNIST
Input: 1 × 28 × 28
Conv1: 16 logic gate filters, 5×5, depth=3, no padding -> 16 × 24 × 24
Pool1: OR pooling 2×2, stride 2 -> 16 × 12 × 12
Conv2: 48 logic gate filters, 3×3, depth=3 -> 48 × 12 × 12
Pool2: OR pooling 2×2, stride 2 -> 48 × 6 × 6
Conv3: 144 logic gate filters, 3×3, depth=3 -> 144 × 6 × 6
Pool3: OR pooling 2×2, stride 2 -> 144 × 3 × 3
Flatten: -> 1296
FC1: Regular differentiable logic layer 1296 -> 20480
FC2: Regular differentiable logic layer 20480 -> 10240
FC3: Regular differentiable logic layer 10240 -> 5120
GroupSum: 5120 -> 10 classes
Starting training with logic gate convolutions...

Epoch 1/100
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.53it/s, Loss=1.2357, Acc=66.49%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.25it/s, Acc=81.31%]


New best model saved! Test accuracy: 81.31%
Train Loss: 1.2357, Train Acc: 66.49%
Test Loss: 0.6696, Test Acc: 81.31%

Epoch 2/100
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.53it/s, Loss=0.5833, Acc=84.55%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.49it/s, Acc=88.50%]


New best model saved! Test accuracy: 88.50%
Train Loss: 0.5833, Train Acc: 84.55%
Test Loss: 0.4603, Test Acc: 88.50%

Epoch 3/100
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.54it/s, Loss=0.4446, Acc=88.75%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.39it/s, Acc=90.64%]


New best model saved! Test accuracy: 90.64%
Train Loss: 0.4446, Train Acc: 88.75%
Test Loss: 0.3871, Test Acc: 90.64%

Epoch 4/100
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.54it/s, Loss=0.3857, Acc=90.52%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.52it/s, Acc=91.49%]


New best model saved! Test accuracy: 91.49%
Train Loss: 0.3857, Train Acc: 90.52%
Test Loss: 0.3516, Test Acc: 91.49%

Epoch 5/100
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:11<00:00,  4.55it/s, Loss=0.3538, Acc=91.49%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.25it/s, Acc=92.33%]


New best model saved! Test accuracy: 92.33%
Train Loss: 0.3538, Train Acc: 91.49%
Test Loss: 0.3311, Test Acc: 92.33%

Epoch 6/100
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.53it/s, Loss=0.3332, Acc=92.14%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.36it/s, Acc=92.82%]


New best model saved! Test accuracy: 92.82%
Train Loss: 0.3332, Train Acc: 92.14%
Test Loss: 0.3153, Test Acc: 92.82%

Epoch 7/100
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.53it/s, Loss=0.3190, Acc=92.55%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.25it/s, Acc=93.02%]


New best model saved! Test accuracy: 93.02%
Train Loss: 0.3190, Train Acc: 92.55%
Test Loss: 0.3054, Test Acc: 93.02%

Epoch 8/100
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.53it/s, Loss=0.3084, Acc=92.86%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.35it/s, Acc=93.21%]


New best model saved! Test accuracy: 93.21%
Train Loss: 0.3084, Train Acc: 92.86%
Test Loss: 0.2973, Test Acc: 93.21%

Epoch 9/100
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.53it/s, Loss=0.3005, Acc=93.05%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.29it/s, Acc=93.46%]


New best model saved! Test accuracy: 93.46%
Train Loss: 0.3005, Train Acc: 93.05%
Test Loss: 0.2911, Test Acc: 93.46%

Epoch 10/100
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.53it/s, Loss=0.2947, Acc=93.24%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.40it/s, Acc=93.62%]


New best model saved! Test accuracy: 93.62%
Train Loss: 0.2947, Train Acc: 93.24%
Test Loss: 0.2871, Test Acc: 93.62%

Epoch 11/100
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.53it/s, Loss=0.2902, Acc=93.44%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.29it/s, Acc=93.73%]


New best model saved! Test accuracy: 93.73%
Train Loss: 0.2902, Train Acc: 93.44%
Test Loss: 0.2836, Test Acc: 93.73%

Epoch 12/100
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.53it/s, Loss=0.2860, Acc=93.57%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.33it/s, Acc=93.75%]


New best model saved! Test accuracy: 93.75%
Train Loss: 0.2860, Train Acc: 93.57%
Test Loss: 0.2803, Test Acc: 93.75%

Epoch 13/100
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.54it/s, Loss=0.2826, Acc=93.65%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.47it/s, Acc=93.92%]


New best model saved! Test accuracy: 93.92%
Train Loss: 0.2826, Train Acc: 93.65%
Test Loss: 0.2777, Test Acc: 93.92%

Epoch 14/100
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.54it/s, Loss=0.2795, Acc=93.74%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.40it/s, Acc=94.07%]


New best model saved! Test accuracy: 94.07%
Train Loss: 0.2795, Train Acc: 93.74%
Test Loss: 0.2741, Test Acc: 94.07%

Epoch 15/100
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:11<00:00,  4.55it/s, Loss=0.2756, Acc=93.86%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.48it/s, Acc=94.07%]


Train Loss: 0.2756, Train Acc: 93.86%
Test Loss: 0.2713, Test Acc: 94.07%

Epoch 16/100
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.54it/s, Loss=0.2734, Acc=93.94%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.46it/s, Acc=94.10%]


New best model saved! Test accuracy: 94.10%
Train Loss: 0.2734, Train Acc: 93.94%
Test Loss: 0.2697, Test Acc: 94.10%

Epoch 17/100
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.54it/s, Loss=0.2711, Acc=94.00%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.48it/s, Acc=94.13%]


New best model saved! Test accuracy: 94.13%
Train Loss: 0.2711, Train Acc: 94.00%
Test Loss: 0.2673, Test Acc: 94.13%

Epoch 18/100
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.54it/s, Loss=0.2688, Acc=94.06%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.40it/s, Acc=94.14%]


New best model saved! Test accuracy: 94.14%
Train Loss: 0.2688, Train Acc: 94.06%
Test Loss: 0.2649, Test Acc: 94.14%

Epoch 19/100
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.54it/s, Loss=0.2665, Acc=94.13%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.46it/s, Acc=94.22%]


New best model saved! Test accuracy: 94.22%
Train Loss: 0.2665, Train Acc: 94.13%
Test Loss: 0.2622, Test Acc: 94.22%

Epoch 20/100
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.54it/s, Loss=0.2645, Acc=94.20%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.45it/s, Acc=94.26%]


New best model saved! Test accuracy: 94.26%
Train Loss: 0.2645, Train Acc: 94.20%
Test Loss: 0.2608, Test Acc: 94.26%

Epoch 21/100
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.54it/s, Loss=0.2628, Acc=94.25%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.45it/s, Acc=94.42%]


New best model saved! Test accuracy: 94.42%
Train Loss: 0.2628, Train Acc: 94.25%
Test Loss: 0.2590, Test Acc: 94.42%

Epoch 22/100
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.54it/s, Loss=0.2610, Acc=94.28%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.49it/s, Acc=94.41%]


Train Loss: 0.2610, Train Acc: 94.28%
Test Loss: 0.2578, Test Acc: 94.41%

Epoch 23/100
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.54it/s, Loss=0.2594, Acc=94.32%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.42it/s, Acc=94.43%]


New best model saved! Test accuracy: 94.43%
Train Loss: 0.2594, Train Acc: 94.32%
Test Loss: 0.2565, Test Acc: 94.43%

Epoch 24/100
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.54it/s, Loss=0.2583, Acc=94.36%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.46it/s, Acc=94.46%]


New best model saved! Test accuracy: 94.46%
Train Loss: 0.2583, Train Acc: 94.36%
Test Loss: 0.2558, Test Acc: 94.46%

Epoch 25/100
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.54it/s, Loss=0.2574, Acc=94.40%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.44it/s, Acc=94.54%]


New best model saved! Test accuracy: 94.54%
Train Loss: 0.2574, Train Acc: 94.40%
Test Loss: 0.2549, Test Acc: 94.54%

Epoch 26/100
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.54it/s, Loss=0.2566, Acc=94.40%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.44it/s, Acc=94.62%]


New best model saved! Test accuracy: 94.62%
Train Loss: 0.2566, Train Acc: 94.40%
Test Loss: 0.2544, Test Acc: 94.62%

Epoch 27/100
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.53it/s, Loss=0.2559, Acc=94.43%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.37it/s, Acc=94.55%]


Train Loss: 0.2559, Train Acc: 94.43%
Test Loss: 0.2535, Test Acc: 94.55%

Epoch 28/100
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.53it/s, Loss=0.2552, Acc=94.46%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.27it/s, Acc=94.59%]


Train Loss: 0.2552, Train Acc: 94.46%
Test Loss: 0.2529, Test Acc: 94.59%

Epoch 29/100
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.53it/s, Loss=0.2545, Acc=94.51%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.49it/s, Acc=94.53%]


Train Loss: 0.2545, Train Acc: 94.51%
Test Loss: 0.2525, Test Acc: 94.53%

Epoch 30/100
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.53it/s, Loss=0.2537, Acc=94.53%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.33it/s, Acc=94.69%]


New best model saved! Test accuracy: 94.69%
Train Loss: 0.2537, Train Acc: 94.53%
Test Loss: 0.2518, Test Acc: 94.69%

Epoch 31/100
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.53it/s, Loss=0.2531, Acc=94.56%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.27it/s, Acc=94.69%]


Train Loss: 0.2531, Train Acc: 94.56%
Test Loss: 0.2511, Test Acc: 94.69%

Epoch 32/100
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.53it/s, Loss=0.2525, Acc=94.58%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.25it/s, Acc=94.72%]


New best model saved! Test accuracy: 94.72%
Train Loss: 0.2525, Train Acc: 94.58%
Test Loss: 0.2506, Test Acc: 94.72%

Epoch 33/100
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.53it/s, Loss=0.2519, Acc=94.58%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.27it/s, Acc=94.72%]


Train Loss: 0.2519, Train Acc: 94.58%
Test Loss: 0.2501, Test Acc: 94.72%

Epoch 34/100
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.53it/s, Loss=0.2514, Acc=94.60%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.31it/s, Acc=94.80%]


New best model saved! Test accuracy: 94.80%
Train Loss: 0.2514, Train Acc: 94.60%
Test Loss: 0.2493, Test Acc: 94.80%

Epoch 35/100
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.53it/s, Loss=0.2509, Acc=94.61%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.42it/s, Acc=94.81%]


New best model saved! Test accuracy: 94.81%
Train Loss: 0.2509, Train Acc: 94.61%
Test Loss: 0.2489, Test Acc: 94.81%

Epoch 36/100
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.54it/s, Loss=0.2504, Acc=94.63%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.46it/s, Acc=94.80%]


Train Loss: 0.2504, Train Acc: 94.63%
Test Loss: 0.2482, Test Acc: 94.80%

Epoch 37/100
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.54it/s, Loss=0.2500, Acc=94.63%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.41it/s, Acc=94.75%]


Train Loss: 0.2500, Train Acc: 94.63%
Test Loss: 0.2481, Test Acc: 94.75%

Epoch 38/100
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.54it/s, Loss=0.2496, Acc=94.64%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.46it/s, Acc=94.76%]


Train Loss: 0.2496, Train Acc: 94.64%
Test Loss: 0.2477, Test Acc: 94.76%

Epoch 39/100
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.54it/s, Loss=0.2493, Acc=94.67%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.46it/s, Acc=94.78%]


Train Loss: 0.2493, Train Acc: 94.67%
Test Loss: 0.2475, Test Acc: 94.78%

Epoch 40/100
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.54it/s, Loss=0.2490, Acc=94.68%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.49it/s, Acc=94.80%]


Train Loss: 0.2490, Train Acc: 94.68%
Test Loss: 0.2471, Test Acc: 94.80%

Epoch 41/100
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.53it/s, Loss=0.2487, Acc=94.69%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.28it/s, Acc=94.77%]


Train Loss: 0.2487, Train Acc: 94.69%
Test Loss: 0.2470, Test Acc: 94.77%

Epoch 42/100
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.53it/s, Loss=0.2484, Acc=94.70%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.28it/s, Acc=94.84%]


New best model saved! Test accuracy: 94.84%
Train Loss: 0.2484, Train Acc: 94.70%
Test Loss: 0.2469, Test Acc: 94.84%

Epoch 43/100
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.53it/s, Loss=0.2482, Acc=94.70%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.34it/s, Acc=94.79%]


Train Loss: 0.2482, Train Acc: 94.70%
Test Loss: 0.2466, Test Acc: 94.79%

Epoch 44/100
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.53it/s, Loss=0.2480, Acc=94.68%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.30it/s, Acc=94.81%]


Train Loss: 0.2480, Train Acc: 94.68%
Test Loss: 0.2464, Test Acc: 94.81%

Epoch 45/100
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.53it/s, Loss=0.2478, Acc=94.70%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.27it/s, Acc=94.86%]


New best model saved! Test accuracy: 94.86%
Train Loss: 0.2478, Train Acc: 94.70%
Test Loss: 0.2462, Test Acc: 94.86%

Epoch 46/100
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.53it/s, Loss=0.2476, Acc=94.71%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.34it/s, Acc=94.79%]


Train Loss: 0.2476, Train Acc: 94.71%
Test Loss: 0.2459, Test Acc: 94.79%

Epoch 47/100
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.53it/s, Loss=0.2474, Acc=94.72%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.23it/s, Acc=94.81%]


Train Loss: 0.2474, Train Acc: 94.72%
Test Loss: 0.2457, Test Acc: 94.81%

Epoch 48/100
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.53it/s, Loss=0.2472, Acc=94.75%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.33it/s, Acc=94.85%]


Train Loss: 0.2472, Train Acc: 94.75%
Test Loss: 0.2455, Test Acc: 94.85%

Epoch 49/100
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.53it/s, Loss=0.2471, Acc=94.74%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.32it/s, Acc=94.84%]


Train Loss: 0.2471, Train Acc: 94.74%
Test Loss: 0.2454, Test Acc: 94.84%

Epoch 50/100
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.53it/s, Loss=0.2469, Acc=94.73%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.35it/s, Acc=94.84%]


Train Loss: 0.2469, Train Acc: 94.73%
Test Loss: 0.2453, Test Acc: 94.84%

Epoch 51/100
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.53it/s, Loss=0.2468, Acc=94.74%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.37it/s, Acc=94.89%]


New best model saved! Test accuracy: 94.89%
Train Loss: 0.2468, Train Acc: 94.74%
Test Loss: 0.2452, Test Acc: 94.89%

Epoch 52/100
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.53it/s, Loss=0.2466, Acc=94.75%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.21it/s, Acc=94.88%]


Train Loss: 0.2466, Train Acc: 94.75%
Test Loss: 0.2451, Test Acc: 94.88%

Epoch 53/100
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.54it/s, Loss=0.2465, Acc=94.74%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.28it/s, Acc=94.88%]


Train Loss: 0.2465, Train Acc: 94.74%
Test Loss: 0.2451, Test Acc: 94.88%

Epoch 54/100
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.53it/s, Loss=0.2464, Acc=94.74%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.31it/s, Acc=94.91%]


New best model saved! Test accuracy: 94.91%
Train Loss: 0.2464, Train Acc: 94.74%
Test Loss: 0.2448, Test Acc: 94.91%

Epoch 55/100
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.54it/s, Loss=0.2462, Acc=94.75%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.55it/s, Acc=94.93%]


New best model saved! Test accuracy: 94.93%
Train Loss: 0.2462, Train Acc: 94.75%
Test Loss: 0.2452, Test Acc: 94.93%

Epoch 56/100
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.53it/s, Loss=0.2461, Acc=94.74%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.29it/s, Acc=94.87%]


Train Loss: 0.2461, Train Acc: 94.74%
Test Loss: 0.2452, Test Acc: 94.87%

Epoch 57/100
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.53it/s, Loss=0.2460, Acc=94.74%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.32it/s, Acc=94.84%]


Train Loss: 0.2460, Train Acc: 94.74%
Test Loss: 0.2451, Test Acc: 94.84%

Epoch 58/100
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.53it/s, Loss=0.2459, Acc=94.75%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.52it/s, Acc=94.83%]


Train Loss: 0.2459, Train Acc: 94.75%
Test Loss: 0.2452, Test Acc: 94.83%

Epoch 59/100
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.53it/s, Loss=0.2457, Acc=94.75%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.35it/s, Acc=94.81%]


Train Loss: 0.2457, Train Acc: 94.75%
Test Loss: 0.2450, Test Acc: 94.81%

Epoch 60/100
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:11<00:00,  4.55it/s, Loss=0.2455, Acc=94.74%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.49it/s, Acc=94.81%]


Train Loss: 0.2455, Train Acc: 94.74%
Test Loss: 0.2450, Test Acc: 94.81%

Epoch 61/100
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:11<00:00,  4.55it/s, Loss=0.2454, Acc=94.75%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.47it/s, Acc=94.87%]


Train Loss: 0.2454, Train Acc: 94.75%
Test Loss: 0.2446, Test Acc: 94.87%

Epoch 62/100
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.54it/s, Loss=0.2453, Acc=94.75%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.35it/s, Acc=94.88%]


Train Loss: 0.2453, Train Acc: 94.75%
Test Loss: 0.2446, Test Acc: 94.88%

Epoch 63/100
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.54it/s, Loss=0.2452, Acc=94.75%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.23it/s, Acc=94.88%]


Train Loss: 0.2452, Train Acc: 94.75%
Test Loss: 0.2446, Test Acc: 94.88%

Epoch 64/100
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.54it/s, Loss=0.2451, Acc=94.77%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.69it/s, Acc=94.86%]


Train Loss: 0.2451, Train Acc: 94.77%
Test Loss: 0.2446, Test Acc: 94.86%

Epoch 65/100
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:11<00:00,  4.55it/s, Loss=0.2450, Acc=94.77%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.42it/s, Acc=94.85%]


Train Loss: 0.2450, Train Acc: 94.77%
Test Loss: 0.2445, Test Acc: 94.85%

Epoch 66/100
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.53it/s, Loss=0.2449, Acc=94.77%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.26it/s, Acc=94.84%]


Train Loss: 0.2449, Train Acc: 94.77%
Test Loss: 0.2444, Test Acc: 94.84%

Epoch 67/100
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.53it/s, Loss=0.2448, Acc=94.78%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.30it/s, Acc=94.84%]


Train Loss: 0.2448, Train Acc: 94.78%
Test Loss: 0.2444, Test Acc: 94.84%

Epoch 68/100
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.53it/s, Loss=0.2448, Acc=94.76%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.29it/s, Acc=94.82%]


Train Loss: 0.2448, Train Acc: 94.76%
Test Loss: 0.2443, Test Acc: 94.82%

Epoch 69/100
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.53it/s, Loss=0.2447, Acc=94.78%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.24it/s, Acc=94.84%]


Train Loss: 0.2447, Train Acc: 94.78%
Test Loss: 0.2442, Test Acc: 94.84%

Epoch 70/100
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.53it/s, Loss=0.2446, Acc=94.77%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.32it/s, Acc=94.86%]


Train Loss: 0.2446, Train Acc: 94.77%
Test Loss: 0.2440, Test Acc: 94.86%

Epoch 71/100
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.53it/s, Loss=0.2445, Acc=94.77%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.33it/s, Acc=94.87%]


Train Loss: 0.2445, Train Acc: 94.77%
Test Loss: 0.2440, Test Acc: 94.87%

Epoch 72/100
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.53it/s, Loss=0.2444, Acc=94.78%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.33it/s, Acc=94.85%]


Train Loss: 0.2444, Train Acc: 94.78%
Test Loss: 0.2439, Test Acc: 94.85%

Epoch 73/100
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.53it/s, Loss=0.2444, Acc=94.78%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.30it/s, Acc=94.85%]


Train Loss: 0.2444, Train Acc: 94.78%
Test Loss: 0.2438, Test Acc: 94.85%

Epoch 74/100
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.53it/s, Loss=0.2443, Acc=94.77%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.31it/s, Acc=94.87%]


Train Loss: 0.2443, Train Acc: 94.77%
Test Loss: 0.2438, Test Acc: 94.87%

Epoch 75/100
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.53it/s, Loss=0.2442, Acc=94.78%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.27it/s, Acc=94.87%]


Train Loss: 0.2442, Train Acc: 94.78%
Test Loss: 0.2437, Test Acc: 94.87%

Epoch 76/100
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.53it/s, Loss=0.2441, Acc=94.80%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.29it/s, Acc=94.86%]


Train Loss: 0.2441, Train Acc: 94.80%
Test Loss: 0.2434, Test Acc: 94.86%

Epoch 77/100
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.53it/s, Loss=0.2440, Acc=94.79%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.24it/s, Acc=94.87%]


Train Loss: 0.2440, Train Acc: 94.79%
Test Loss: 0.2434, Test Acc: 94.87%

Epoch 78/100
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.52it/s, Loss=0.2439, Acc=94.81%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.23it/s, Acc=94.90%]


Train Loss: 0.2439, Train Acc: 94.81%
Test Loss: 0.2432, Test Acc: 94.90%

Epoch 79/100
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.53it/s, Loss=0.2437, Acc=94.81%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.32it/s, Acc=94.90%]


Train Loss: 0.2437, Train Acc: 94.81%
Test Loss: 0.2431, Test Acc: 94.90%

Epoch 80/100
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.53it/s, Loss=0.2435, Acc=94.81%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.33it/s, Acc=94.89%]


Train Loss: 0.2435, Train Acc: 94.81%
Test Loss: 0.2428, Test Acc: 94.89%

Epoch 81/100
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.53it/s, Loss=0.2434, Acc=94.83%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.44it/s, Acc=94.88%]


Train Loss: 0.2434, Train Acc: 94.83%
Test Loss: 0.2428, Test Acc: 94.88%

Epoch 82/100
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.53it/s, Loss=0.2433, Acc=94.83%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.27it/s, Acc=94.91%]


Train Loss: 0.2433, Train Acc: 94.83%
Test Loss: 0.2427, Test Acc: 94.91%

Epoch 83/100
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.53it/s, Loss=0.2433, Acc=94.83%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.30it/s, Acc=94.85%]


Train Loss: 0.2433, Train Acc: 94.83%
Test Loss: 0.2427, Test Acc: 94.85%

Epoch 84/100
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.53it/s, Loss=0.2432, Acc=94.83%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.27it/s, Acc=94.84%]


Train Loss: 0.2432, Train Acc: 94.83%
Test Loss: 0.2427, Test Acc: 94.84%

Epoch 85/100
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.53it/s, Loss=0.2431, Acc=94.83%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.32it/s, Acc=94.86%]


Train Loss: 0.2431, Train Acc: 94.83%
Test Loss: 0.2427, Test Acc: 94.86%

Epoch 86/100
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.53it/s, Loss=0.2431, Acc=94.84%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.40it/s, Acc=94.85%]


Train Loss: 0.2431, Train Acc: 94.84%
Test Loss: 0.2426, Test Acc: 94.85%

Epoch 87/100
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.54it/s, Loss=0.2430, Acc=94.85%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.47it/s, Acc=94.87%]


Train Loss: 0.2430, Train Acc: 94.85%
Test Loss: 0.2426, Test Acc: 94.87%

Epoch 88/100
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.54it/s, Loss=0.2430, Acc=94.84%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.46it/s, Acc=94.91%]


Train Loss: 0.2430, Train Acc: 94.84%
Test Loss: 0.2425, Test Acc: 94.91%

Epoch 89/100
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.54it/s, Loss=0.2429, Acc=94.85%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.46it/s, Acc=94.92%]


Train Loss: 0.2429, Train Acc: 94.85%
Test Loss: 0.2427, Test Acc: 94.92%

Epoch 90/100
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.54it/s, Loss=0.2428, Acc=94.84%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.41it/s, Acc=94.90%]


Train Loss: 0.2428, Train Acc: 94.84%
Test Loss: 0.2426, Test Acc: 94.90%

Epoch 91/100
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.54it/s, Loss=0.2428, Acc=94.85%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.46it/s, Acc=94.90%]


Train Loss: 0.2428, Train Acc: 94.85%
Test Loss: 0.2426, Test Acc: 94.90%

Epoch 92/100
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.54it/s, Loss=0.2427, Acc=94.86%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.23it/s, Acc=94.92%]


Train Loss: 0.2427, Train Acc: 94.86%
Test Loss: 0.2425, Test Acc: 94.92%

Epoch 93/100
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.54it/s, Loss=0.2427, Acc=94.86%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.28it/s, Acc=94.90%]


Train Loss: 0.2427, Train Acc: 94.86%
Test Loss: 0.2424, Test Acc: 94.90%

Epoch 94/100
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.54it/s, Loss=0.2426, Acc=94.85%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.41it/s, Acc=94.93%]


Train Loss: 0.2426, Train Acc: 94.85%
Test Loss: 0.2424, Test Acc: 94.93%

Epoch 95/100
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.54it/s, Loss=0.2426, Acc=94.86%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.42it/s, Acc=94.93%]


Train Loss: 0.2426, Train Acc: 94.86%
Test Loss: 0.2423, Test Acc: 94.93%

Epoch 96/100
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:11<00:00,  4.55it/s, Loss=0.2426, Acc=94.86%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.41it/s, Acc=94.95%]


New best model saved! Test accuracy: 94.95%
Train Loss: 0.2426, Train Acc: 94.86%
Test Loss: 0.2423, Test Acc: 94.95%

Epoch 97/100
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.54it/s, Loss=0.2425, Acc=94.86%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.45it/s, Acc=94.95%]


Train Loss: 0.2425, Train Acc: 94.86%
Test Loss: 0.2423, Test Acc: 94.95%

Epoch 98/100
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.53it/s, Loss=0.2425, Acc=94.87%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.29it/s, Acc=94.92%]


Train Loss: 0.2425, Train Acc: 94.87%
Test Loss: 0.2423, Test Acc: 94.92%

Epoch 99/100
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.53it/s, Loss=0.2425, Acc=94.86%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.28it/s, Acc=94.92%]


Train Loss: 0.2425, Train Acc: 94.86%
Test Loss: 0.2422, Test Acc: 94.92%

Epoch 100/100
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.53it/s, Loss=0.2425, Acc=94.85%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.34it/s, Acc=94.91%]


Train Loss: 0.2425, Train Acc: 94.85%
Test Loss: 0.2422, Test Acc: 94.91%

Training completed in 13645.03 seconds
Best test accuracy: 94.95%

Testing discrete inference...
Discrete inference accuracy: 94.91%
Inference speed: 0.000404 seconds per sample

LOGIC GATE CONVOLUTIONAL DIFFLOGIC MNIST RESULTS
Architecture: Custom logic gate convolutions + DiffLogic FC layers
Base kernel count (k): 16
Total parameters: 596,736
Training epochs: 100
Best test accuracy: 94.95%
Final test accuracy: 94.91%
Inference speed: 0.000404 seconds per sample

Logic gate convolutions replace traditional conv layers with:
- 16 different logic gates (AND, OR, XOR, NOT, etc.)
- Tree-structured logic processing with configurable depth
- Soft logic during training, hard logic during inference


In [14]:
# Retraining
# Type definitions
InitializationType = Literal['residual', 'random']

# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Hyperparameters
batch_size = 100
learning_rate = 0.01
total_epochs = 70  # That's make 100 total epochs
k = 16  # Base number of kernels

print(f"Base kernel count k = {k}")

# Data loading and preprocessing
transform = transforms.Compose([
    transforms.ToTensor(),
])

train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=0)

print(f"Training samples: {len(train_dataset)}")
print(f"Test samples: {len(test_dataset)}")

# Logic gate definitions
logic_gates = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]

def apply_logic_gate(a: torch.Tensor, b: torch.Tensor, logic_gate: int):
    return {
        0:  torch.zeros_like(a),
        1:  a * b,
        2:  a - a * b,
        3:  a,
        4:  b - a * b,
        5:  b,
        6:  a + b - 2 * a * b,
        7:  a + b - a * b,
        8:  1 - (a + b - a * b),
        9:  1 - (a + b - 2 * a * b),
        10: 1 - b,
        11: 1 - b + a * b,
        12: 1 - a,
        13: 1 - a + a * b,
        14: 1 - a * b,
        15: torch.ones_like(a),
    }[logic_gate]

class Logic(nn.Module):
    def __init__(self,
                 in_dim: int,
                 out_dim: int,
                 initialization_type: InitializationType = 'residual',
                 device=None
                 ):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.initialization_type = initialization_type
        self.device = device or torch.device('cpu')
        
        a, b = self.get_connections()
        self.register_buffer('a', a)
        self.register_buffer('b', b)
        
        weights = torch.randn(out_dim, len(logic_gates), device=self.device)
        if self.initialization_type == 'residual':
            weights[:, :] = 0
            weights[:, 3] = 5  # Initialize to identity gate
        self.weights = torch.nn.parameter.Parameter(weights)

    def forward(self, x: torch.Tensor):
        a, b = x[:, self.a, ...], x[:, self.b, ...]
        
        if self.training:
            normalized_weights = torch.nn.functional.softmax(self.weights, dim=-1).to(x.dtype).to(self.device)
            r = torch.zeros_like(a).to(x.dtype).to(self.device)
            for logic_gate in logic_gates:
                if len(a.shape) > 2:
                    nw = einops.repeat(normalized_weights[..., logic_gate], 'weights -> weights depth', depth=a.shape[-1])
                else:
                    nw = normalized_weights[..., logic_gate]
                r = r + nw * apply_logic_gate(a, b, logic_gate)
            return r
        else:
            one_hot_weights = torch.nn.functional.one_hot(self.weights.argmax(-1), len(logic_gates)).to(torch.float32).to(self.device)
            with torch.no_grad():
                r = torch.zeros_like(a).to(x.dtype).to(self.device)
                for logic_gate in logic_gates:
                    if len(a.shape) > 2:
                        ohw = einops.repeat(one_hot_weights[..., logic_gate], 'weights -> weights depth', depth=a.shape[-1])
                    else:
                        ohw = one_hot_weights[..., logic_gate]
                    r = r + ohw * apply_logic_gate(a, b, logic_gate)
                return r

    def get_connections(self):
        connections = torch.randperm(2 * self.out_dim) % self.in_dim
        connections = torch.randperm(self.in_dim)[connections]
        connections = connections.reshape(2, self.out_dim)
        a, b = connections[0], connections[1]
        a, b = a.to(torch.int64), b.to(torch.int64)
        a, b = a.to(self.device), b.to(self.device)
        return a, b

class LogicTree(nn.Module):
    def __init__(self,
                 in_dim: int,
                 depth: int = 3,
                 initialization_type: InitializationType = 'residual',
                 device=None,
                 ):
        super().__init__()
        self.device = device or torch.device('cpu')
        
        layers = [LogicLayer(in_dim, int(2 ** (depth - 1)), initialization_type=initialization_type, device=self.device,implementation='cuda' if device.type == 'cuda' else 'python',
                    connections='random',grad_factor=1.5 )]
        for i in range(0, depth - 1, 1):
            layers.append(LogicLayer(int(2 ** (depth - 1 - i)), int(2 ** (depth - 1 - i - 1)), 
                            initialization_type=initialization_type, device=self.device,implementation='cuda' if device.type == 'cuda' else 'python',
                            connections='random',grad_factor=1.5))
        
        self.tree = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor):
        return self.tree(x)

class Conv(nn.Module):
    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 depth: int = 3,
                 kernel_size: int = 3,
                 stride: int = 1,
                 padding: int = 1,
                 initialization_type: InitializationType = 'residual',
                 device=None
                 ):
        super().__init__()
        self.kernel_size = kernel_size
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.stride = stride
        self.padding = padding
        self.device = device or torch.device('cpu')
        
        self.filters = nn.ModuleList([
            LogicTree(in_dim=kernel_size ** 2 * in_channels, depth=depth, 
                     initialization_type=initialization_type, device=self.device) 
            for _ in range(out_channels)
        ])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        batch_size, _, height, width = x.shape
        x = F.pad(x, (self.padding, self.padding, self.padding, self.padding), mode='constant', value=0)
        out_height = (height + 2 * self.padding - self.kernel_size) // self.stride + 1
        out_width = (width + 2 * self.padding - self.kernel_size) // self.stride + 1
        patches = F.unfold(x, kernel_size=self.kernel_size, stride=self.stride)
        outputs = []
        
        patches = einops.rearrange(patches, 'b h w -> (b w) h', h=patches.shape[1], w=patches.shape[2]) # Input is (100,25,576) Output: (57600,25)
        for filter in self.filters:
            out = filter(patches)  # Shape: (batch_size, 1, out_height * out_width)
            out = einops.rearrange(out, '(b h w) 1 -> b (h w)', h=out_height, w=out_width)
            outputs.append(out)
        
        output_tensor = torch.stack(outputs, dim=1)  # Shape: (batch_size, out_channels, out_height * out_width)
        output_tensor = einops.rearrange(output_tensor, 'b c (h w) -> b c h w', h=out_height, w=out_width)
        return output_tensor

class CustomOrPool2d(nn.Module):
    def __init__(self, kernel_size=2, stride=2, padding=0):
        super(CustomOrPool2d, self).__init__()
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        
    def forward(self, x):
        # Use MaxPool2d as approximation to OR pooling
        # In binary logic, max operation approximates OR
        return torch.max_pool2d(x, self.kernel_size, self.stride, self.padding)

class ConvDiffLogicMNIST(nn.Module):
    def __init__(self, k=16):
        super(ConvDiffLogicMNIST, self).__init__()
        self.k = k
        
        # Convolutional block 1: k kernels, 5x5, depth=3, no padding
        # Input: 1 × 28 × 28 -> Output: k × 24 × 24 (28-5+1=24)
        self.conv1 = Conv(in_channels=1, out_channels=k, kernel_size=5, depth=3, 
                         padding=0, initialization_type='residual', device=device)
        
        # OR pooling 1: 2x2, stride 2
        # k × 24 × 24 -> k × 12 × 12
        self.pool1 = CustomOrPool2d(kernel_size=2, stride=2)
        
        # Convolutional block 2: 3*k kernels, 3x3, depth=3
        # k × 12 × 12 -> 3*k × 12 × 12 (with padding=1), then pooled to 3*k × 6 × 6
        self.conv2 = Conv(in_channels=k, out_channels=3*k, kernel_size=3, depth=3, 
                         padding=1, initialization_type='residual', device=device)
        
        # OR pooling 2: 2x2, stride 2
        # 3*k × 12 × 12 -> 3*k × 6 × 6
        self.pool2 = CustomOrPool2d(kernel_size=2, stride=2)
        
        # Convolutional block 3: 9*k kernels, 3x3, depth=3
        # 3*k × 6 × 6 -> 9*k × 6 × 6 (with padding=1), then pooled to 9*k × 3 × 3
        self.conv3 = Conv(in_channels=3*k, out_channels=9*k, kernel_size=3, depth=3, 
                         padding=1, initialization_type='residual', device=device)
        
        # OR pooling 3: 2x2, stride 2
        # 9*k × 6 × 6 -> 9*k × 3 × 3
        self.pool3 = CustomOrPool2d(kernel_size=2, stride=2)
        
        # Flatten: 9*k × 3 × 3 -> 81*k
        self.flatten = nn.Flatten()
        
        # Regular differentiable logic layers (as specified in paper)
        # 81*k → 1280*k
        self.fc1 = LogicLayer(
            in_dim=81*k,
            out_dim=1280*k,
            device=device,
            implementation='cuda' if device.type == 'cuda' else 'python',
            connections='random',
            grad_factor=1.5 , # Higher for deeper networks
        )
        
        # 1280*k → 640*k
        self.fc2 = LogicLayer(
            in_dim=1280*k,
            out_dim=640*k,
            device=device,
            implementation='cuda' if device.type == 'cuda' else 'python',
            connections='random',
            grad_factor=1.5
        )
        
        # 640*k → 320*k
        self.fc3 = LogicLayer(
            in_dim=640*k,
            out_dim=320*k,
            device=device,
            implementation='cuda' if device.type == 'cuda' else 'python',
            connections='random',
            grad_factor=1.5
        )
        
        # GroupSum: 320*k → 10 (10 classes)
        # Using tau=30 as in the paper specifications
        self.group_sum = GroupSum(k=10, tau=30)
        
    def forward(self, x):
        # Input thresholding for binary processing (as mentioned in paper)
        # The paper mentions using binary inputs
        x = (x > 0.5).float()
        
        # Debug shape printing (uncomment for debugging)
        # print(f"Input shape: {x.shape}")
        
        # Convolutional processing with logic gates
        x = self.conv1(x)
        # print(f"After conv1: {x.shape}")
        
        x = self.pool1(x)
        # print(f"After pool1: {x.shape}")
        
        x = self.conv2(x)
        # print(f"After conv2: {x.shape}")
        
        x = self.pool2(x)
        # print(f"After pool2: {x.shape}")
        
        x = self.conv3(x)
        # print(f"After conv3: {x.shape}")
        
        x = self.pool3(x)
        # print(f"After pool3: {x.shape}")
        
        # Flatten
        x = self.flatten(x)
        # print(f"After flatten: {x.shape}")
        
        # Fully connected logic layers
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.fc3(x)
        
        # GroupSum for classification
        x = self.group_sum(x)
        
        return x

# Initialize model
model = ConvDiffLogicMNIST(k=k).to(device)
print(f"Model created with {sum(p.numel() for p in model.parameters())} parameters")

# Print architecture details
print("\n" + "="*80)
print("LOGIC GATE CONVOLUTIONAL DIFFLOGIC MNIST")
print("="*80)
print(f"Input: 1 × 28 × 28")
print(f"Conv1: {k} logic gate filters, 5×5, depth=3, no padding -> {k} × 24 × 24")
print(f"Pool1: OR pooling 2×2, stride 2 -> {k} × 12 × 12")
print(f"Conv2: {3*k} logic gate filters, 3×3, depth=3 -> {3*k} × 12 × 12")
print(f"Pool2: OR pooling 2×2, stride 2 -> {3*k} × 6 × 6")
print(f"Conv3: {9*k} logic gate filters, 3×3, depth=3 -> {9*k} × 6 × 6")
print(f"Pool3: OR pooling 2×2, stride 2 -> {9*k} × 3 × 3")
print(f"Flatten: -> {81*k}")
print(f"FC1: Regular differentiable logic layer {81*k} -> {1280*k}")
print(f"FC2: Regular differentiable logic layer {1280*k} -> {640*k}")
print(f"FC3: Regular differentiable logic layer {640*k} -> {320*k}")
print(f"GroupSum: {320*k} -> 10 classes")
print("="*80)

# Loss function
criterion = nn.CrossEntropyLoss()

model_path = 'best_logic_conv_difflogic_mnist.pth'
current_best_acc = 0.0

# Load checkpoint if exists
if os.path.exists(model_path):
    model.load_state_dict(torch.load(model_path))
    print(f"Loaded best model checkpoint from {model_path}")
    
    # Evaluate the loaded model to get current best accuracy
    model.eval()
    with torch.no_grad():
        correct = 0
        total = 0
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            outputs = model(data)
            _, predicted = torch.max(outputs.data, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()
        
        current_best_acc = 100. * correct / total
        print(f"Loaded model accuracy: {current_best_acc:.2f}%")

# Initialize optimizer AFTER potential model loading
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Training function
def train_epoch(model, train_loader, criterion, optimizer, device):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    
    progress_bar = tqdm(train_loader, desc='Training')
    for batch_idx, (data, target) in enumerate(progress_bar):
        data, target = data.to(device), target.to(device)
        
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
        _, predicted = torch.max(output.data, 1)
        total += target.size(0)
        correct += (predicted == target).sum().item()
        
        progress_bar.set_postfix({
            'Loss': f'{running_loss/(batch_idx+1):.4f}',
            'Acc': f'{100.*correct/total:.2f}%'
        })
    
    return running_loss / len(train_loader), 100. * correct / total

# Evaluation function
def evaluate(model, test_loader, criterion, device):
    model.eval()
    test_loss = 0.0
    correct = 0
    total = 0
    
    with torch.no_grad():
        progress_bar = tqdm(test_loader, desc='Testing')
        for data, target in progress_bar:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += criterion(output, target).item()
            
            _, predicted = torch.max(output.data, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()
            
            progress_bar.set_postfix({
                'Acc': f'{100.*correct/total:.2f}%'
            })
    
    return test_loss / len(test_loader), 100. * correct / total

if __name__ == '__main__':
    print("Starting training with logic gate convolutions...")
    train_losses = []
    train_accuracies = []
    test_losses = []
    test_accuracies = []

    start_time = time.time()

    for epoch in range(0, total_epochs):
        print(f"\nEpoch {epoch+1}/{total_epochs}")
        print("-" * 50)
        
        # Train
        train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)
        train_losses.append(train_loss)
        train_accuracies.append(train_acc)
        
        # Evaluate
        test_loss, test_acc = evaluate(model, test_loader, criterion, device)
        test_losses.append(test_loss)
        test_accuracies.append(test_acc)
        
        # Save best model
        if test_acc > current_best_acc:
            current_best_acc = test_acc
            torch.save(model.state_dict(), model_path)
            print(f"New best model saved! Test accuracy: {test_acc:.2f}%")
        
        print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%")
        print(f"Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.2f}%")

    total_time = time.time() - start_time
    print(f"\nTraining completed in {total_time:.2f} seconds")
    print(f"Best test accuracy: {current_best_acc:.2f}%")

    # Test discrete inference with best model
    print("\nTesting discrete inference with BEST model...")
    model.load_state_dict(torch.load(model_path))  # Load best model
    model.eval()  # This switches to discrete/hard logic mode

    start_time = time.time()
    correct = 0
    total = 0

    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            _, predicted = torch.max(output.data, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()

    inference_time = time.time() - start_time
    inference_speed = inference_time / total

    print(f"Discrete inference accuracy: {100. * correct / total:.2f}%")
    print(f"Inference speed: {inference_speed:.6f} seconds per sample")

    # Save final model
    torch.save(model.state_dict(), 'final_logic_conv_difflogic_mnist.pth')

    # Final summary
    print("\n" + "="*80)
    print("LOGIC GATE CONVOLUTIONAL DIFFLOGIC MNIST RESULTS")
    print("="*80)
    print(f"Architecture: Custom logic gate convolutions + DiffLogic FC layers")
    print(f"Base kernel count (k): {k}")
    print(f"Total parameters: {sum(p.numel() for p in model.parameters()):,}")
    print(f"Training epochs: {total_epochs}")
    print(f"Best test accuracy: {current_best_acc:.2f}%")
    print(f"Inference speed: {inference_speed:.6f} seconds per sample")
    print()
    print("Logic gate convolutions replace traditional conv layers with:")
    print("- 16 different logic gates (AND, OR, XOR, NOT, etc.)")
    print("- Tree-structured logic processing with configurable depth")
    print("- Soft logic during training, hard logic during inference")
    print("="*80)

Using device: cuda
Base kernel count k = 16
Training samples: 60000
Test samples: 10000
Model created with 596736 parameters

LOGIC GATE CONVOLUTIONAL DIFFLOGIC MNIST
Input: 1 × 28 × 28
Conv1: 16 logic gate filters, 5×5, depth=3, no padding -> 16 × 24 × 24
Pool1: OR pooling 2×2, stride 2 -> 16 × 12 × 12
Conv2: 48 logic gate filters, 3×3, depth=3 -> 48 × 12 × 12
Pool2: OR pooling 2×2, stride 2 -> 48 × 6 × 6
Conv3: 144 logic gate filters, 3×3, depth=3 -> 144 × 6 × 6
Pool3: OR pooling 2×2, stride 2 -> 144 × 3 × 3
Flatten: -> 1296
FC1: Regular differentiable logic layer 1296 -> 20480
FC2: Regular differentiable logic layer 20480 -> 10240
FC3: Regular differentiable logic layer 10240 -> 5120
GroupSum: 5120 -> 10 classes
Loaded best model checkpoint from best_logic_conv_difflogic_mnist.pth
Loaded model accuracy: 13.09%
Starting training with logic gate convolutions...

Epoch 1/70
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.53it/s, Loss=1.5881, Acc=62.49%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.50it/s, Acc=82.02%]


New best model saved! Test accuracy: 82.02%
Train Loss: 1.5881, Train Acc: 62.49%
Test Loss: 0.9703, Test Acc: 82.02%

Epoch 2/70
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.53it/s, Loss=0.8845, Acc=82.77%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.60it/s, Acc=85.72%]


New best model saved! Test accuracy: 85.72%
Train Loss: 0.8845, Train Acc: 82.77%
Test Loss: 0.7682, Test Acc: 85.72%

Epoch 3/70
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.53it/s, Loss=0.7575, Acc=85.09%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.54it/s, Acc=87.28%]


New best model saved! Test accuracy: 87.28%
Train Loss: 0.7575, Train Acc: 85.09%
Test Loss: 0.6980, Test Acc: 87.28%

Epoch 4/70
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.53it/s, Loss=0.7055, Acc=85.83%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.57it/s, Acc=87.55%]


New best model saved! Test accuracy: 87.55%
Train Loss: 0.7055, Train Acc: 85.83%
Test Loss: 0.6662, Test Acc: 87.55%

Epoch 5/70
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.53it/s, Loss=0.6776, Acc=86.38%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.60it/s, Acc=87.71%]


New best model saved! Test accuracy: 87.71%
Train Loss: 0.6776, Train Acc: 86.38%
Test Loss: 0.6449, Test Acc: 87.71%

Epoch 6/70
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.53it/s, Loss=0.6559, Acc=86.75%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.46it/s, Acc=88.42%]


New best model saved! Test accuracy: 88.42%
Train Loss: 0.6559, Train Acc: 86.75%
Test Loss: 0.6290, Test Acc: 88.42%

Epoch 7/70
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.53it/s, Loss=0.6405, Acc=87.12%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.55it/s, Acc=88.45%]


New best model saved! Test accuracy: 88.45%
Train Loss: 0.6405, Train Acc: 87.12%
Test Loss: 0.6160, Test Acc: 88.45%

Epoch 8/70
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.53it/s, Loss=0.6284, Acc=87.36%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.56it/s, Acc=88.88%]


New best model saved! Test accuracy: 88.88%
Train Loss: 0.6284, Train Acc: 87.36%
Test Loss: 0.6041, Test Acc: 88.88%

Epoch 9/70
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.53it/s, Loss=0.6180, Acc=87.65%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.56it/s, Acc=89.01%]


New best model saved! Test accuracy: 89.01%
Train Loss: 0.6180, Train Acc: 87.65%
Test Loss: 0.5964, Test Acc: 89.01%

Epoch 10/70
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.53it/s, Loss=0.6100, Acc=87.80%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.60it/s, Acc=89.04%]


New best model saved! Test accuracy: 89.04%
Train Loss: 0.6100, Train Acc: 87.80%
Test Loss: 0.5905, Test Acc: 89.04%

Epoch 11/70
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.53it/s, Loss=0.6036, Acc=87.99%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.58it/s, Acc=89.20%]


New best model saved! Test accuracy: 89.20%
Train Loss: 0.6036, Train Acc: 87.99%
Test Loss: 0.5852, Test Acc: 89.20%

Epoch 12/70
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.53it/s, Loss=0.5966, Acc=88.25%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.57it/s, Acc=89.24%]


New best model saved! Test accuracy: 89.24%
Train Loss: 0.5966, Train Acc: 88.25%
Test Loss: 0.5785, Test Acc: 89.24%

Epoch 13/70
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.53it/s, Loss=0.5908, Acc=88.39%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.55it/s, Acc=89.29%]


New best model saved! Test accuracy: 89.29%
Train Loss: 0.5908, Train Acc: 88.39%
Test Loss: 0.5739, Test Acc: 89.29%

Epoch 14/70
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.53it/s, Loss=0.5851, Acc=88.44%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.49it/s, Acc=89.39%]


New best model saved! Test accuracy: 89.39%
Train Loss: 0.5851, Train Acc: 88.44%
Test Loss: 0.5671, Test Acc: 89.39%

Epoch 15/70
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.53it/s, Loss=0.5791, Acc=88.57%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.57it/s, Acc=89.54%]


New best model saved! Test accuracy: 89.54%
Train Loss: 0.5791, Train Acc: 88.57%
Test Loss: 0.5626, Test Acc: 89.54%

Epoch 16/70
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.53it/s, Loss=0.5750, Acc=88.68%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.62it/s, Acc=89.61%]


New best model saved! Test accuracy: 89.61%
Train Loss: 0.5750, Train Acc: 88.68%
Test Loss: 0.5586, Test Acc: 89.61%

Epoch 17/70
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.53it/s, Loss=0.5714, Acc=88.73%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.52it/s, Acc=89.73%]


New best model saved! Test accuracy: 89.73%
Train Loss: 0.5714, Train Acc: 88.73%
Test Loss: 0.5554, Test Acc: 89.73%

Epoch 18/70
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.53it/s, Loss=0.5685, Acc=88.78%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.56it/s, Acc=89.80%]


New best model saved! Test accuracy: 89.80%
Train Loss: 0.5685, Train Acc: 88.78%
Test Loss: 0.5528, Test Acc: 89.80%

Epoch 19/70
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.54it/s, Loss=0.5645, Acc=88.94%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.57it/s, Acc=89.90%]


New best model saved! Test accuracy: 89.90%
Train Loss: 0.5645, Train Acc: 88.94%
Test Loss: 0.5484, Test Acc: 89.90%

Epoch 20/70
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.53it/s, Loss=0.5614, Acc=89.03%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.46it/s, Acc=90.07%]


New best model saved! Test accuracy: 90.07%
Train Loss: 0.5614, Train Acc: 89.03%
Test Loss: 0.5453, Test Acc: 90.07%

Epoch 21/70
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.52it/s, Loss=0.5578, Acc=89.10%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.59it/s, Acc=89.93%]


Train Loss: 0.5578, Train Acc: 89.10%
Test Loss: 0.5418, Test Acc: 89.93%

Epoch 22/70
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.53it/s, Loss=0.5545, Acc=89.22%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.51it/s, Acc=90.00%]


Train Loss: 0.5545, Train Acc: 89.22%
Test Loss: 0.5383, Test Acc: 90.00%

Epoch 23/70
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.53it/s, Loss=0.5507, Acc=89.30%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.53it/s, Acc=90.33%]


New best model saved! Test accuracy: 90.33%
Train Loss: 0.5507, Train Acc: 89.30%
Test Loss: 0.5328, Test Acc: 90.33%

Epoch 24/70
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.53it/s, Loss=0.5468, Acc=89.40%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.58it/s, Acc=90.28%]


Train Loss: 0.5468, Train Acc: 89.40%
Test Loss: 0.5306, Test Acc: 90.28%

Epoch 25/70
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.53it/s, Loss=0.5444, Acc=89.42%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.55it/s, Acc=90.38%]


New best model saved! Test accuracy: 90.38%
Train Loss: 0.5444, Train Acc: 89.42%
Test Loss: 0.5282, Test Acc: 90.38%

Epoch 26/70
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.53it/s, Loss=0.5419, Acc=89.48%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.55it/s, Acc=90.44%]


New best model saved! Test accuracy: 90.44%
Train Loss: 0.5419, Train Acc: 89.48%
Test Loss: 0.5263, Test Acc: 90.44%

Epoch 27/70
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.53it/s, Loss=0.5399, Acc=89.52%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.42it/s, Acc=90.31%]


Train Loss: 0.5399, Train Acc: 89.52%
Test Loss: 0.5244, Test Acc: 90.31%

Epoch 28/70
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.52it/s, Loss=0.5386, Acc=89.53%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.40it/s, Acc=90.46%]


New best model saved! Test accuracy: 90.46%
Train Loss: 0.5386, Train Acc: 89.53%
Test Loss: 0.5236, Test Acc: 90.46%

Epoch 29/70
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.52it/s, Loss=0.5373, Acc=89.59%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.39it/s, Acc=90.49%]


New best model saved! Test accuracy: 90.49%
Train Loss: 0.5373, Train Acc: 89.59%
Test Loss: 0.5226, Test Acc: 90.49%

Epoch 30/70
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.52it/s, Loss=0.5358, Acc=89.61%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.56it/s, Acc=90.57%]


New best model saved! Test accuracy: 90.57%
Train Loss: 0.5358, Train Acc: 89.61%
Test Loss: 0.5207, Test Acc: 90.57%

Epoch 31/70
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.53it/s, Loss=0.5345, Acc=89.61%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.57it/s, Acc=90.69%]


New best model saved! Test accuracy: 90.69%
Train Loss: 0.5345, Train Acc: 89.61%
Test Loss: 0.5197, Test Acc: 90.69%

Epoch 32/70
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.52it/s, Loss=0.5332, Acc=89.64%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.42it/s, Acc=90.73%]


New best model saved! Test accuracy: 90.73%
Train Loss: 0.5332, Train Acc: 89.64%
Test Loss: 0.5193, Test Acc: 90.73%

Epoch 33/70
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.52it/s, Loss=0.5323, Acc=89.66%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.47it/s, Acc=90.78%]


New best model saved! Test accuracy: 90.78%
Train Loss: 0.5323, Train Acc: 89.66%
Test Loss: 0.5181, Test Acc: 90.78%

Epoch 34/70
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.52it/s, Loss=0.5314, Acc=89.67%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.45it/s, Acc=90.73%]


Train Loss: 0.5314, Train Acc: 89.67%
Test Loss: 0.5173, Test Acc: 90.73%

Epoch 35/70
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.52it/s, Loss=0.5300, Acc=89.71%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.40it/s, Acc=90.67%]


Train Loss: 0.5300, Train Acc: 89.71%
Test Loss: 0.5156, Test Acc: 90.67%

Epoch 36/70
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.52it/s, Loss=0.5261, Acc=89.78%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.46it/s, Acc=90.62%]


Train Loss: 0.5261, Train Acc: 89.78%
Test Loss: 0.5120, Test Acc: 90.62%

Epoch 37/70
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.52it/s, Loss=0.5245, Acc=89.78%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.42it/s, Acc=90.66%]


Train Loss: 0.5245, Train Acc: 89.78%
Test Loss: 0.5107, Test Acc: 90.66%

Epoch 38/70
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.52it/s, Loss=0.5237, Acc=89.83%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.42it/s, Acc=90.72%]


Train Loss: 0.5237, Train Acc: 89.83%
Test Loss: 0.5099, Test Acc: 90.72%

Epoch 39/70
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.52it/s, Loss=0.5231, Acc=89.80%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.42it/s, Acc=90.73%]


Train Loss: 0.5231, Train Acc: 89.80%
Test Loss: 0.5096, Test Acc: 90.73%

Epoch 40/70
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.52it/s, Loss=0.5223, Acc=89.85%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.42it/s, Acc=90.74%]


Train Loss: 0.5223, Train Acc: 89.85%
Test Loss: 0.5086, Test Acc: 90.74%

Epoch 41/70
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.53it/s, Loss=0.5217, Acc=89.85%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.53it/s, Acc=90.71%]


Train Loss: 0.5217, Train Acc: 89.85%
Test Loss: 0.5081, Test Acc: 90.71%

Epoch 42/70
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.52it/s, Loss=0.5208, Acc=89.86%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.45it/s, Acc=90.77%]


Train Loss: 0.5208, Train Acc: 89.86%
Test Loss: 0.5073, Test Acc: 90.77%

Epoch 43/70
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.51it/s, Loss=0.5201, Acc=89.89%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.40it/s, Acc=90.76%]


Train Loss: 0.5201, Train Acc: 89.89%
Test Loss: 0.5063, Test Acc: 90.76%

Epoch 44/70
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.53it/s, Loss=0.5192, Acc=89.98%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.55it/s, Acc=90.78%]


Train Loss: 0.5192, Train Acc: 89.98%
Test Loss: 0.5055, Test Acc: 90.78%

Epoch 45/70
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.53it/s, Loss=0.5181, Acc=89.96%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.54it/s, Acc=90.87%]


New best model saved! Test accuracy: 90.87%
Train Loss: 0.5181, Train Acc: 89.96%
Test Loss: 0.5040, Test Acc: 90.87%

Epoch 46/70
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.53it/s, Loss=0.5172, Acc=89.97%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.43it/s, Acc=90.88%]


New best model saved! Test accuracy: 90.88%
Train Loss: 0.5172, Train Acc: 89.97%
Test Loss: 0.5034, Test Acc: 90.88%

Epoch 47/70
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.51it/s, Loss=0.5168, Acc=89.96%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.39it/s, Acc=90.82%]


Train Loss: 0.5168, Train Acc: 89.96%
Test Loss: 0.5030, Test Acc: 90.82%

Epoch 48/70
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.51it/s, Loss=0.5162, Acc=89.98%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.44it/s, Acc=90.84%]


Train Loss: 0.5162, Train Acc: 89.98%
Test Loss: 0.5027, Test Acc: 90.84%

Epoch 49/70
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.52it/s, Loss=0.5156, Acc=89.94%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.42it/s, Acc=90.86%]


Train Loss: 0.5156, Train Acc: 89.94%
Test Loss: 0.5014, Test Acc: 90.86%

Epoch 50/70
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.51it/s, Loss=0.5146, Acc=89.97%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.38it/s, Acc=90.86%]


Train Loss: 0.5146, Train Acc: 89.97%
Test Loss: 0.5008, Test Acc: 90.86%

Epoch 51/70
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.52it/s, Loss=0.5138, Acc=89.98%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.56it/s, Acc=90.86%]


Train Loss: 0.5138, Train Acc: 89.98%
Test Loss: 0.4997, Test Acc: 90.86%

Epoch 52/70
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.53it/s, Loss=0.5132, Acc=89.97%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.51it/s, Acc=90.89%]


New best model saved! Test accuracy: 90.89%
Train Loss: 0.5132, Train Acc: 89.97%
Test Loss: 0.4990, Test Acc: 90.89%

Epoch 53/70
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.53it/s, Loss=0.5128, Acc=89.99%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.56it/s, Acc=90.84%]


Train Loss: 0.5128, Train Acc: 89.99%
Test Loss: 0.4986, Test Acc: 90.84%

Epoch 54/70
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.53it/s, Loss=0.5125, Acc=89.98%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.59it/s, Acc=90.86%]


Train Loss: 0.5125, Train Acc: 89.98%
Test Loss: 0.4984, Test Acc: 90.86%

Epoch 55/70
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.52it/s, Loss=0.5121, Acc=89.99%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.40it/s, Acc=90.82%]


Train Loss: 0.5121, Train Acc: 89.99%
Test Loss: 0.4977, Test Acc: 90.82%

Epoch 56/70
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.52it/s, Loss=0.5117, Acc=89.97%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.48it/s, Acc=90.92%]


New best model saved! Test accuracy: 90.92%
Train Loss: 0.5117, Train Acc: 89.97%
Test Loss: 0.4972, Test Acc: 90.92%

Epoch 57/70
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.51it/s, Loss=0.5113, Acc=89.99%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.46it/s, Acc=90.92%]


Train Loss: 0.5113, Train Acc: 89.99%
Test Loss: 0.4971, Test Acc: 90.92%

Epoch 58/70
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.52it/s, Loss=0.5110, Acc=90.01%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.61it/s, Acc=90.97%]


New best model saved! Test accuracy: 90.97%
Train Loss: 0.5110, Train Acc: 90.01%
Test Loss: 0.4962, Test Acc: 90.97%

Epoch 59/70
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.53it/s, Loss=0.5103, Acc=89.98%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.58it/s, Acc=90.91%]


Train Loss: 0.5103, Train Acc: 89.98%
Test Loss: 0.4954, Test Acc: 90.91%

Epoch 60/70
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.53it/s, Loss=0.5094, Acc=90.03%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.60it/s, Acc=90.93%]


Train Loss: 0.5094, Train Acc: 90.03%
Test Loss: 0.4950, Test Acc: 90.93%

Epoch 61/70
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.53it/s, Loss=0.5080, Acc=90.07%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.56it/s, Acc=90.93%]


Train Loss: 0.5080, Train Acc: 90.07%
Test Loss: 0.4938, Test Acc: 90.93%

Epoch 62/70
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.53it/s, Loss=0.5075, Acc=90.09%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.60it/s, Acc=90.96%]


Train Loss: 0.5075, Train Acc: 90.09%
Test Loss: 0.4933, Test Acc: 90.96%

Epoch 63/70
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.53it/s, Loss=0.5072, Acc=90.08%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.60it/s, Acc=90.97%]


Train Loss: 0.5072, Train Acc: 90.08%
Test Loss: 0.4931, Test Acc: 90.97%

Epoch 64/70
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.53it/s, Loss=0.5068, Acc=90.08%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.54it/s, Acc=90.95%]


Train Loss: 0.5068, Train Acc: 90.08%
Test Loss: 0.4926, Test Acc: 90.95%

Epoch 65/70
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.53it/s, Loss=0.5064, Acc=90.09%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.56it/s, Acc=90.91%]


Train Loss: 0.5064, Train Acc: 90.09%
Test Loss: 0.4921, Test Acc: 90.91%

Epoch 66/70
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.53it/s, Loss=0.5059, Acc=90.11%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.61it/s, Acc=90.94%]


Train Loss: 0.5059, Train Acc: 90.11%
Test Loss: 0.4918, Test Acc: 90.94%

Epoch 67/70
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.51it/s, Loss=0.5053, Acc=90.14%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.45it/s, Acc=90.86%]


Train Loss: 0.5053, Train Acc: 90.14%
Test Loss: 0.4911, Test Acc: 90.86%

Epoch 68/70
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.52it/s, Loss=0.5048, Acc=90.14%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.45it/s, Acc=91.01%]


New best model saved! Test accuracy: 91.01%
Train Loss: 0.5048, Train Acc: 90.14%
Test Loss: 0.4908, Test Acc: 91.01%

Epoch 69/70
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.51it/s, Loss=0.5046, Acc=90.14%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.41it/s, Acc=90.95%]


Train Loss: 0.5046, Train Acc: 90.14%
Test Loss: 0.4906, Test Acc: 90.95%

Epoch 70/70
--------------------------------------------------


Training: 100%|██████████| 600/600 [02:12<00:00,  4.51it/s, Loss=0.5043, Acc=90.15%]
Testing: 100%|██████████| 100/100 [00:04<00:00, 24.45it/s, Acc=90.95%]


Train Loss: 0.5043, Train Acc: 90.15%
Test Loss: 0.4904, Test Acc: 90.95%

Training completed in 9568.44 seconds
Best test accuracy: 91.01%

Testing discrete inference with BEST model...
Discrete inference accuracy: 91.01%
Inference speed: 0.000403 seconds per sample

LOGIC GATE CONVOLUTIONAL DIFFLOGIC MNIST RESULTS
Architecture: Custom logic gate convolutions + DiffLogic FC layers
Base kernel count (k): 16
Total parameters: 596,736
Training epochs: 70
Best test accuracy: 91.01%
Inference speed: 0.000403 seconds per sample

Logic gate convolutions replace traditional conv layers with:
- 16 different logic gates (AND, OR, XOR, NOT, etc.)
- Tree-structured logic processing with configurable depth
- Soft logic during training, hard logic during inference


In [10]:
# Added Validation
# Type definitions
InitializationType = Literal['residual', 'random']

# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Hyperparameters
batch_size = 256
learning_rate = 0.01
num_epochs = 100
k = 16  # Base number of kernels (from paper: k=16 for small model)

print(f"Base kernel count k = {k}")
print(f"Expected shapes from paper:")
print(f"After conv1 + pool1: {k} × 12 × 12")
print(f"After conv2 + pool2: {3*k} × 6 × 6") 
print(f"After conv3 + pool3: {9*k} × 3 × 3")
print(f"After flattening: {81*k}")

# Data loading and preprocessing
transform = transforms.Compose([
    transforms.ToTensor(),
])

# Load dataset
full_train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, transform=transform)

# Split into 50k train and 10k validation
train_size = 50000
val_size = 10000
torch.manual_seed(42)  # For reproducibility
train_dataset, val_dataset = torch.utils.data.random_split(
    full_train_dataset, [train_size, val_size]
)

# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=0)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=0)

print(f"Training samples: {len(train_dataset)}")
print(f"Validation samples: {len(val_dataset)}")
print(f"Test samples: {len(test_dataset)}")

# Logic gate definitions
logic_gates = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]

def apply_logic_gate(a: torch.Tensor, b: torch.Tensor, logic_gate: int):
    return {
        0:  torch.zeros_like(a),
        1:  a * b,
        2:  a - a * b,
        3:  a,
        4:  b - a * b,
        5:  b,
        6:  a + b - 2 * a * b,
        7:  a + b - a * b,
        8:  1 - (a + b - a * b),
        9:  1 - (a + b - 2 * a * b),
        10: 1 - b,
        11: 1 - b + a * b,
        12: 1 - a,
        13: 1 - a + a * b,
        14: 1 - a * b,
        15: torch.ones_like(a),
    }[logic_gate]

class Logic(nn.Module):
    def __init__(self,
                 in_dim: int,
                 out_dim: int,
                 initialization_type: InitializationType = 'residual',
                 device=None
                 ):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.initialization_type = initialization_type
        self.device = device or torch.device('cpu')
        
        a, b = self.get_connections()
        self.register_buffer('a', a)
        self.register_buffer('b', b)
        
        weights = torch.randn(out_dim, len(logic_gates), device=self.device)
        if self.initialization_type == 'residual':
            weights[:, :] = 0
            weights[:, 3] = 5  # Initialize to identity gate
        self.weights = torch.nn.parameter.Parameter(weights)

    def forward(self, x: torch.Tensor):
        a, b = x[:, self.a, ...], x[:, self.b, ...]
        
        if self.training:
            normalized_weights = torch.nn.functional.softmax(self.weights, dim=-1).to(x.dtype).to(self.device)
            r = torch.zeros_like(a).to(x.dtype).to(self.device)
            for logic_gate in logic_gates:
                if len(a.shape) > 2:
                    nw = einops.repeat(normalized_weights[..., logic_gate], 'weights -> weights depth', depth=a.shape[-1])
                else:
                    nw = normalized_weights[..., logic_gate]
                r = r + nw * apply_logic_gate(a, b, logic_gate)
            return r
        else:
            one_hot_weights = torch.nn.functional.one_hot(self.weights.argmax(-1), len(logic_gates)).to(torch.float32).to(self.device)
            with torch.no_grad():
                r = torch.zeros_like(a).to(x.dtype).to(self.device)
                for logic_gate in logic_gates:
                    if len(a.shape) > 2:
                        ohw = einops.repeat(one_hot_weights[..., logic_gate], 'weights -> weights depth', depth=a.shape[-1])
                    else:
                        ohw = one_hot_weights[..., logic_gate]
                    r = r + ohw * apply_logic_gate(a, b, logic_gate)
                return r

    def get_connections(self):
        connections = torch.randperm(2 * self.out_dim) % self.in_dim
        connections = torch.randperm(self.in_dim)[connections]
        connections = connections.reshape(2, self.out_dim)
        a, b = connections[0], connections[1]
        a, b = a.to(torch.int64), b.to(torch.int64)
        a, b = a.to(self.device), b.to(self.device)
        return a, b

class LogicTree(nn.Module):
    def __init__(self,
                 in_dim: int,
                 depth: int = 3,
                 initialization_type: InitializationType = 'residual',
                 device=None,
                 ):
        super().__init__()
        self.device = device or torch.device('cpu')
        
        layers = [LogicLayer(in_dim, int(2 ** (depth - 1)), initialization_type=initialization_type, device=self.device,implementation='cuda' if device.type == 'cuda' else 'python',
                    connections='random',grad_factor=1.5 )]
        for i in range(0, depth - 1, 1):
            layers.append(LogicLayer(int(2 ** (depth - 1 - i)), int(2 ** (depth - 1 - i - 1)), 
                            initialization_type=initialization_type, device=self.device,implementation='cuda' if device.type == 'cuda' else 'python',
                            connections='random',grad_factor=1.5))
        
        self.tree = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor):
        return self.tree(x)

class Conv(nn.Module):
    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 depth: int = 3,
                 kernel_size: int = 3,
                 stride: int = 1,
                 padding: int = 1,
                 initialization_type: InitializationType = 'residual',
                 device=None
                 ):
        super().__init__()
        self.kernel_size = kernel_size
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.stride = stride
        self.padding = padding
        self.device = device or torch.device('cpu')
        
        self.filters = nn.ModuleList([
            LogicTree(in_dim=kernel_size ** 2 * in_channels, depth=depth, 
                     initialization_type=initialization_type, device=self.device) 
            for _ in range(out_channels)
        ])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        batch_size, _, height, width = x.shape
        x = F.pad(x, (self.padding, self.padding, self.padding, self.padding), mode='constant', value=0)
        out_height = (height + 2 * self.padding - self.kernel_size) // self.stride + 1
        out_width = (width + 2 * self.padding - self.kernel_size) // self.stride + 1
        patches = F.unfold(x, kernel_size=self.kernel_size, stride=self.stride)
        outputs = []
        
        patches = einops.rearrange(patches, 'b h w -> (b w) h', h=patches.shape[1], w=patches.shape[2]) # Input is (100,25,576) Output: (57600,25)
        for filter in self.filters:
            out = filter(patches)  # Shape: (batch_size, 1, out_height * out_width)
            out = einops.rearrange(out, '(b h w) 1 -> b (h w)', h=out_height, w=out_width)
            outputs.append(out)
        
        output_tensor = torch.stack(outputs, dim=1)  # Shape: (batch_size, out_channels, out_height * out_width)
        output_tensor = einops.rearrange(output_tensor, 'b c (h w) -> b c h w', h=out_height, w=out_width)
        return output_tensor

class CustomOrPool2d(nn.Module):
    def __init__(self, kernel_size=2, stride=2, padding=0):
        super(CustomOrPool2d, self).__init__()
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        
    def forward(self, x):
        # Use MaxPool2d as approximation to OR pooling
        # In binary logic, max operation approximates OR
        return torch.max_pool2d(x, self.kernel_size, self.stride, self.padding)



class ConvDiffLogicMNIST(nn.Module):
    def __init__(self, k=16):
        super(ConvDiffLogicMNIST, self).__init__()
        self.k = k
        
        # Convolutional block 1: k kernels, 5x5, depth=3, no padding
        # Input: 1 × 28 × 28 -> Output: k × 24 × 24 (28-5+1=24)
        self.conv1 = Conv(in_channels=1, out_channels=k, kernel_size=5, depth=3, 
                         padding=0, initialization_type='residual', device=device)
        
        # OR pooling 1: 2x2, stride 2
        # k × 24 × 24 -> k × 12 × 12
        self.pool1 = CustomOrPool2d(kernel_size=2, stride=2)
        
        # Convolutional block 2: 3*k kernels, 3x3, depth=3
        # k × 12 × 12 -> 3*k × 12 × 12 (with padding=1), then pooled to 3*k × 6 × 6
        self.conv2 = Conv(in_channels=k, out_channels=3*k, kernel_size=3, depth=3, 
                         padding=1, initialization_type='residual', device=device)
        
        # OR pooling 2: 2x2, stride 2
        # 3*k × 12 × 12 -> 3*k × 6 × 6
        self.pool2 = CustomOrPool2d(kernel_size=2, stride=2)
        
        # Convolutional block 3: 9*k kernels, 3x3, depth=3
        # 3*k × 6 × 6 -> 9*k × 6 × 6 (with padding=1), then pooled to 9*k × 3 × 3
        self.conv3 = Conv(in_channels=3*k, out_channels=9*k, kernel_size=3, depth=3, 
                         padding=1, initialization_type='residual', device=device)
        
        # OR pooling 3: 2x2, stride 2
        # 9*k × 6 × 6 -> 9*k × 3 × 3
        self.pool3 = CustomOrPool2d(kernel_size=2, stride=2)
        
        # Flatten: 9*k × 3 × 3 -> 81*k
        self.flatten = nn.Flatten()
        
        # Regular differentiable logic layers (as specified in paper)
        # 81*k → 1280*k
        self.fc1 = LogicLayer(
            in_dim=81*k,
            out_dim=1280*k,
            device=device,
            implementation='cuda' if device.type == 'cuda' else 'python',
            connections='random',
            grad_factor=1.5 , # Higher for deeper networks
        )
        
        # 1280*k → 640*k
        self.fc2 = LogicLayer(
            in_dim=1280*k,
            out_dim=640*k,
            device=device,
            implementation='cuda' if device.type == 'cuda' else 'python',
            connections='random',
            grad_factor=1.5
        )
        
        # 640*k → 320*k
        self.fc3 = LogicLayer(
            in_dim=640*k,
            out_dim=320*k,
            device=device,
            implementation='cuda' if device.type == 'cuda' else 'python',
            connections='random',
            grad_factor=1.5
        )
        
        # GroupSum: 320*k → 10 (10 classes)
        # Using tau=30 as in the paper specifications
        self.group_sum = GroupSum(k=10, tau=30)
        
    def forward(self, x):
        # Input thresholding for binary processing (as mentioned in paper)
        # STE for binary inputs - FIXED GRADIENT FLOW
        x_bin = (x > 0.5).float()
        x = x + (x_bin - x).detach()
        
        # Debug shape printing (uncomment for debugging)
        # print(f"Input shape: {x.shape}")
        
        # Convolutional processing with logic gates
        x = self.conv1(x)
        # print(f"After conv1: {x.shape}")
        
        x = self.pool1(x)
        # print(f"After pool1: {x.shape}")
        
        x = self.conv2(x)
        # print(f"After conv2: {x.shape}")
        
        x = self.pool2(x)
        # print(f"After pool2: {x.shape}")
        
        x = self.conv3(x)
        # print(f"After conv3: {x.shape}")
        
        x = self.pool3(x)
        # print(f"After pool3: {x.shape}")
        
        # Flatten
        x = self.flatten(x)
        # print(f"After flatten: {x.shape}")
        
        # Fully connected logic layers
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.fc3(x)
        
        # GroupSum for classification
        x = self.group_sum(x)
        
        return x

# Initialize model
model = ConvDiffLogicMNIST(k=k).to(device)
print(f"Model created with {sum(p.numel() for p in model.parameters())} parameters")

# Print architecture details
print("\n" + "="*80)
print("LOGIC GATE CONVOLUTIONAL DIFFLOGIC MNIST")
print("="*80)
print(f"Input: 1 × 28 × 28")
print(f"Conv1: {k} logic gate filters, 5×5, depth=3, no padding -> {k} × 24 × 24")
print(f"Pool1: OR pooling 2×2, stride 2 -> {k} × 12 × 12")
print(f"Conv2: {3*k} logic gate filters, 3×3, depth=3 -> {3*k} × 12 × 12")
print(f"Pool2: OR pooling 2×2, stride 2 -> {3*k} × 6 × 6")
print(f"Conv3: {9*k} logic gate filters, 3×3, depth=3 -> {9*k} × 6 × 6")
print(f"Pool3: OR pooling 2×2, stride 2 -> {9*k} × 3 × 3")
print(f"Flatten: -> {81*k}")
print(f"FC1: Regular differentiable logic layer {81*k} -> {1280*k}")
print(f"FC2: Regular differentiable logic layer {1280*k} -> {640*k}")
print(f"FC3: Regular differentiable logic layer {640*k} -> {320*k}")
print(f"GroupSum: {320*k} -> 10 classes")
print("="*80)

# Loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Training function
def train_epoch(model, train_loader, val_loader, criterion, optimizer, device, global_step, best_val_acc):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    
    progress_bar = tqdm(train_loader, desc='Training')
    for batch_idx, (data, target) in enumerate(progress_bar):
        data, target = data.to(device), target.to(device)
        
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
        _, predicted = torch.max(output.data, 1)
        total += target.size(0)
        correct += (predicted == target).sum().item()
        
        progress_bar.set_postfix({
            'Loss': f'{running_loss/(batch_idx+1):.4f}',
            'Acc': f'{100.*correct/total:.2f}%'
        })
        
        # Update global step
        global_step += 1
        
        # Validate every 5000 steps
        if global_step % 5000 == 0:
            val_loss, val_acc = evaluate(model, val_loader, criterion, device, desc='Validation')
            print(f"Step {global_step}: Val Acc: {val_acc:.2f}%")
            
            # Save best model based on validation accuracy
            if val_acc > best_val_acc:
                best_val_acc = val_acc
                torch.save(model.state_dict(), 'best_logic_conv_difflogic_mnist_val.pth')
                print(f"New best model saved! Validation accuracy: {val_acc:.2f}%")
    
    return running_loss / len(train_loader), 100. * correct / total, global_step, best_val_acc

def evaluate(model, loader, criterion, device, desc='Evaluating'):
    model.eval()
    test_loss = 0.0
    correct = 0
    total = 0
    
    with torch.no_grad():
        progress_bar = tqdm(loader, desc=desc)
        for data, target in progress_bar:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += criterion(output, target).item()
            
            _, predicted = torch.max(output.data, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()
            
            progress_bar.set_postfix({
                'Acc': f'{100.*correct/total:.2f}%'
            })
    
    return test_loss / len(loader), 100. * correct / total

if __name__ == '__main__':
    print("Starting training with logic gate convolutions...")
    train_losses = []
    train_accuracies = []
    test_losses = []
    test_accuracies = []
    val_accuracies = []

    best_val_acc = 0.0  # Track best validation accuracy
    global_step = 0     # Global training step counter
    start_time = time.time()

    for epoch in range(num_epochs):
        print(f"\nEpoch {epoch+1}/{num_epochs}")
        print("-" * 50)
        
        # Train (returns updated global_step and best_val_acc)
        train_loss, train_acc, global_step, best_val_acc = train_epoch(
            model, train_loader, val_loader, criterion, optimizer, device, 
            global_step, best_val_acc
        )
        train_losses.append(train_loss)
        train_accuracies.append(train_acc)
        
        # Evaluate on test set
        test_loss, test_acc = evaluate(model, test_loader, criterion, device, desc='Testing')
        test_losses.append(test_loss)
        test_accuracies.append(test_acc)
        
        # Evaluate on validation set
        val_loss, val_acc = evaluate(model, val_loader, criterion, device, desc='Validation')
        val_accuracies.append(val_acc)
        
        print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%")
        print(f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%")
        print(f"Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.2f}%")

    total_time = time.time() - start_time
    print(f"\nTraining completed in {total_time:.2f} seconds")
    print(f"Best validation accuracy: {best_val_acc:.2f}%")

    # Load best model for final evaluation
    model.load_state_dict(torch.load('best_logic_conv_difflogic_mnist_val.pth'))
    print("Loaded best model based on validation accuracy")

    # Final evaluation on test set
    print("\nEvaluating best model on test set...")
    test_loss, test_acc = evaluate(model, test_loader, criterion, device, desc='Final Test')
    print(f"Best model test accuracy: {test_acc:.2f}%")

    # Test discrete inference
    print("\nTesting discrete inference...")
    model.eval()
    start_time = time.time()
    correct = 0
    total = 0

    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            _, predicted = torch.max(output.data, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()

    inference_time = time.time() - start_time
    inference_speed = inference_time / total

    print(f"Discrete inference accuracy: {100. * correct / total:.2f}%")
    print(f"Inference speed: {inference_speed:.6f} seconds per sample")

    # Save final model
    torch.save(model.state_dict(), 'final_logic_conv_difflogic_mnist.pth')

    # Final summary
    print("\n" + "="*80)
    print("LOGIC GATE CONVOLUTIONAL DIFFLOGIC MNIST RESULTS")
    print("="*80)
    print(f"Training samples: {len(train_dataset)}")
    print(f"Validation samples: {len(val_dataset)}")
    print(f"Test samples: {len(test_dataset)}")
    print(f"Best validation accuracy: {best_val_acc:.2f}%")
    print(f"Final test accuracy: {test_acc:.2f}%")
    print(f"Inference speed: {inference_speed:.6f} seconds per sample")
    print("="*80)

Using device: cuda
Base kernel count k = 16
Expected shapes from paper:
After conv1 + pool1: 16 × 12 × 12
After conv2 + pool2: 48 × 6 × 6
After conv3 + pool3: 144 × 3 × 3
After flattening: 1296
Training samples: 50000
Validation samples: 10000
Test samples: 10000
Model created with 596736 parameters

LOGIC GATE CONVOLUTIONAL DIFFLOGIC MNIST
Input: 1 × 28 × 28
Conv1: 16 logic gate filters, 5×5, depth=3, no padding -> 16 × 24 × 24
Pool1: OR pooling 2×2, stride 2 -> 16 × 12 × 12
Conv2: 48 logic gate filters, 3×3, depth=3 -> 48 × 12 × 12
Pool2: OR pooling 2×2, stride 2 -> 48 × 6 × 6
Conv3: 144 logic gate filters, 3×3, depth=3 -> 144 × 6 × 6
Pool3: OR pooling 2×2, stride 2 -> 144 × 3 × 3
Flatten: -> 1296
FC1: Regular differentiable logic layer 1296 -> 20480
FC2: Regular differentiable logic layer 20480 -> 10240
FC3: Regular differentiable logic layer 10240 -> 5120
GroupSum: 5120 -> 10 classes
Starting training with logic gate convolutions...

Epoch 1/100
------------------------------------

Training: 100%|██████████| 196/196 [00:57<00:00,  3.44it/s, Loss=1.7863, Acc=55.38%]
Testing: 100%|██████████| 40/40 [00:02<00:00, 18.52it/s, Acc=30.42%]
Validation: 100%|██████████| 40/40 [00:02<00:00, 18.40it/s, Acc=30.64%]


Train Loss: 1.7863, Train Acc: 55.38%
Val Loss: 2.0326, Val Acc: 30.64%
Test Loss: 2.0275, Test Acc: 30.42%

Epoch 2/100
--------------------------------------------------


Training: 100%|██████████| 196/196 [00:57<00:00,  3.44it/s, Loss=0.9933, Acc=72.98%]
Testing: 100%|██████████| 40/40 [00:02<00:00, 18.46it/s, Acc=77.75%]
Validation: 100%|██████████| 40/40 [00:02<00:00, 18.43it/s, Acc=76.55%]


Train Loss: 0.9933, Train Acc: 72.98%
Val Loss: 0.8033, Val Acc: 76.55%
Test Loss: 0.7755, Test Acc: 77.75%

Epoch 3/100
--------------------------------------------------


Training: 100%|██████████| 196/196 [00:57<00:00,  3.44it/s, Loss=0.7217, Acc=80.25%]
Testing: 100%|██████████| 40/40 [00:02<00:00, 18.48it/s, Acc=83.24%]
Validation: 100%|██████████| 40/40 [00:02<00:00, 18.39it/s, Acc=81.98%]


Train Loss: 0.7217, Train Acc: 80.25%
Val Loss: 0.6245, Val Acc: 81.98%
Test Loss: 0.5927, Test Acc: 83.24%

Epoch 4/100
--------------------------------------------------


Training: 100%|██████████| 196/196 [00:56<00:00,  3.44it/s, Loss=0.5974, Acc=83.97%]
Testing: 100%|██████████| 40/40 [00:02<00:00, 18.48it/s, Acc=85.84%]
Validation: 100%|██████████| 40/40 [00:02<00:00, 18.43it/s, Acc=84.35%]


Train Loss: 0.5974, Train Acc: 83.97%
Val Loss: 0.5497, Val Acc: 84.35%
Test Loss: 0.5198, Test Acc: 85.84%

Epoch 5/100
--------------------------------------------------


Training: 100%|██████████| 196/196 [00:57<00:00,  3.44it/s, Loss=0.5233, Acc=86.15%]
Testing: 100%|██████████| 40/40 [00:02<00:00, 18.46it/s, Acc=87.83%]
Validation: 100%|██████████| 40/40 [00:02<00:00, 18.39it/s, Acc=86.38%]


Train Loss: 0.5233, Train Acc: 86.15%
Val Loss: 0.4938, Val Acc: 86.38%
Test Loss: 0.4661, Test Acc: 87.83%

Epoch 6/100
--------------------------------------------------


Training: 100%|██████████| 196/196 [00:57<00:00,  3.44it/s, Loss=0.4763, Acc=87.51%]
Testing: 100%|██████████| 40/40 [00:02<00:00, 18.52it/s, Acc=88.57%]
Validation: 100%|██████████| 40/40 [00:02<00:00, 18.37it/s, Acc=87.09%]


Train Loss: 0.4763, Train Acc: 87.51%
Val Loss: 0.4652, Val Acc: 87.09%
Test Loss: 0.4389, Test Acc: 88.57%

Epoch 7/100
--------------------------------------------------


Training: 100%|██████████| 196/196 [00:57<00:00,  3.44it/s, Loss=0.4477, Acc=88.43%]
Testing: 100%|██████████| 40/40 [00:02<00:00, 18.47it/s, Acc=89.31%]
Validation: 100%|██████████| 40/40 [00:02<00:00, 18.44it/s, Acc=88.14%]


Train Loss: 0.4477, Train Acc: 88.43%
Val Loss: 0.4421, Val Acc: 88.14%
Test Loss: 0.4173, Test Acc: 89.31%

Epoch 8/100
--------------------------------------------------


Training: 100%|██████████| 196/196 [00:56<00:00,  3.44it/s, Loss=0.4229, Acc=89.25%]
Testing: 100%|██████████| 40/40 [00:02<00:00, 18.51it/s, Acc=90.13%]
Validation: 100%|██████████| 40/40 [00:02<00:00, 18.43it/s, Acc=88.79%]


Train Loss: 0.4229, Train Acc: 89.25%
Val Loss: 0.4224, Val Acc: 88.79%
Test Loss: 0.3975, Test Acc: 90.13%

Epoch 9/100
--------------------------------------------------


Training: 100%|██████████| 196/196 [00:57<00:00,  3.44it/s, Loss=0.4046, Acc=89.83%]
Testing: 100%|██████████| 40/40 [00:02<00:00, 18.47it/s, Acc=90.80%]
Validation: 100%|██████████| 40/40 [00:02<00:00, 18.43it/s, Acc=89.13%]


Train Loss: 0.4046, Train Acc: 89.83%
Val Loss: 0.4076, Val Acc: 89.13%
Test Loss: 0.3828, Test Acc: 90.80%

Epoch 10/100
--------------------------------------------------


Training: 100%|██████████| 196/196 [00:57<00:00,  3.44it/s, Loss=0.3915, Acc=90.26%]
Testing: 100%|██████████| 40/40 [00:02<00:00, 18.46it/s, Acc=91.21%]
Validation: 100%|██████████| 40/40 [00:02<00:00, 18.43it/s, Acc=89.70%]


Train Loss: 0.3915, Train Acc: 90.26%
Val Loss: 0.3962, Val Acc: 89.70%
Test Loss: 0.3722, Test Acc: 91.21%

Epoch 11/100
--------------------------------------------------


Training: 100%|██████████| 196/196 [00:56<00:00,  3.44it/s, Loss=0.3815, Acc=90.58%]
Testing: 100%|██████████| 40/40 [00:02<00:00, 18.51it/s, Acc=91.51%]
Validation: 100%|██████████| 40/40 [00:02<00:00, 18.36it/s, Acc=89.97%]


Train Loss: 0.3815, Train Acc: 90.58%
Val Loss: 0.3878, Val Acc: 89.97%
Test Loss: 0.3650, Test Acc: 91.51%

Epoch 12/100
--------------------------------------------------


Training: 100%|██████████| 196/196 [00:57<00:00,  3.44it/s, Loss=0.3731, Acc=90.83%]
Testing: 100%|██████████| 40/40 [00:02<00:00, 18.44it/s, Acc=91.76%]
Validation: 100%|██████████| 40/40 [00:02<00:00, 18.42it/s, Acc=90.25%]


Train Loss: 0.3731, Train Acc: 90.83%
Val Loss: 0.3820, Val Acc: 90.25%
Test Loss: 0.3590, Test Acc: 91.76%

Epoch 13/100
--------------------------------------------------


Training: 100%|██████████| 196/196 [00:57<00:00,  3.44it/s, Loss=0.3663, Acc=91.06%]
Testing: 100%|██████████| 40/40 [00:02<00:00, 18.49it/s, Acc=91.74%]
Validation: 100%|██████████| 40/40 [00:02<00:00, 18.40it/s, Acc=90.44%]


Train Loss: 0.3663, Train Acc: 91.06%
Val Loss: 0.3775, Val Acc: 90.44%
Test Loss: 0.3543, Test Acc: 91.74%

Epoch 14/100
--------------------------------------------------


Training: 100%|██████████| 196/196 [00:56<00:00,  3.44it/s, Loss=0.3607, Acc=91.21%]
Testing: 100%|██████████| 40/40 [00:02<00:00, 18.45it/s, Acc=91.85%]
Validation: 100%|██████████| 40/40 [00:02<00:00, 18.40it/s, Acc=90.60%]


Train Loss: 0.3607, Train Acc: 91.21%
Val Loss: 0.3724, Val Acc: 90.60%
Test Loss: 0.3495, Test Acc: 91.85%

Epoch 15/100
--------------------------------------------------


Training: 100%|██████████| 196/196 [00:57<00:00,  3.44it/s, Loss=0.3547, Acc=91.42%]
Testing: 100%|██████████| 40/40 [00:02<00:00, 18.46it/s, Acc=92.01%]
Validation: 100%|██████████| 40/40 [00:02<00:00, 18.43it/s, Acc=90.90%]


Train Loss: 0.3547, Train Acc: 91.42%
Val Loss: 0.3677, Val Acc: 90.90%
Test Loss: 0.3448, Test Acc: 92.01%

Epoch 16/100
--------------------------------------------------


Training: 100%|██████████| 196/196 [00:57<00:00,  3.44it/s, Loss=0.3496, Acc=91.57%]
Testing: 100%|██████████| 40/40 [00:02<00:00, 18.52it/s, Acc=92.04%]
Validation: 100%|██████████| 40/40 [00:02<00:00, 18.37it/s, Acc=90.92%]


Train Loss: 0.3496, Train Acc: 91.57%
Val Loss: 0.3641, Val Acc: 90.92%
Test Loss: 0.3412, Test Acc: 92.04%

Epoch 17/100
--------------------------------------------------


Training: 100%|██████████| 196/196 [00:56<00:00,  3.44it/s, Loss=0.3454, Acc=91.80%]
Testing: 100%|██████████| 40/40 [00:02<00:00, 18.49it/s, Acc=92.28%]
Validation: 100%|██████████| 40/40 [00:02<00:00, 18.43it/s, Acc=91.26%]


Train Loss: 0.3454, Train Acc: 91.80%
Val Loss: 0.3588, Val Acc: 91.26%
Test Loss: 0.3347, Test Acc: 92.28%

Epoch 18/100
--------------------------------------------------


Training: 100%|██████████| 196/196 [00:56<00:00,  3.44it/s, Loss=0.3409, Acc=91.85%]
Testing: 100%|██████████| 40/40 [00:02<00:00, 18.47it/s, Acc=92.52%]
Validation: 100%|██████████| 40/40 [00:02<00:00, 18.38it/s, Acc=91.44%]


Train Loss: 0.3409, Train Acc: 91.85%
Val Loss: 0.3544, Val Acc: 91.44%
Test Loss: 0.3312, Test Acc: 92.52%

Epoch 19/100
--------------------------------------------------


Training: 100%|██████████| 196/196 [00:56<00:00,  3.44it/s, Loss=0.3376, Acc=92.00%]
Testing: 100%|██████████| 40/40 [00:02<00:00, 18.50it/s, Acc=92.52%]
Validation: 100%|██████████| 40/40 [00:02<00:00, 18.44it/s, Acc=91.29%]


Train Loss: 0.3376, Train Acc: 92.00%
Val Loss: 0.3527, Val Acc: 91.29%
Test Loss: 0.3295, Test Acc: 92.52%

Epoch 20/100
--------------------------------------------------


Training: 100%|██████████| 196/196 [00:56<00:00,  3.44it/s, Loss=0.3344, Acc=92.06%]
Testing: 100%|██████████| 40/40 [00:02<00:00, 18.45it/s, Acc=92.74%]
Validation: 100%|██████████| 40/40 [00:02<00:00, 18.41it/s, Acc=91.41%]


Train Loss: 0.3344, Train Acc: 92.06%
Val Loss: 0.3480, Val Acc: 91.41%
Test Loss: 0.3244, Test Acc: 92.74%

Epoch 21/100
--------------------------------------------------


Training: 100%|██████████| 196/196 [00:56<00:00,  3.44it/s, Loss=0.3305, Acc=92.16%]
Testing: 100%|██████████| 40/40 [00:02<00:00, 18.45it/s, Acc=92.87%]
Validation: 100%|██████████| 40/40 [00:02<00:00, 18.35it/s, Acc=91.58%]


Train Loss: 0.3305, Train Acc: 92.16%
Val Loss: 0.3452, Val Acc: 91.58%
Test Loss: 0.3211, Test Acc: 92.87%

Epoch 22/100
--------------------------------------------------


Training: 100%|██████████| 196/196 [00:57<00:00,  3.44it/s, Loss=0.3261, Acc=92.25%]
Testing: 100%|██████████| 40/40 [00:02<00:00, 18.48it/s, Acc=92.92%]
Validation: 100%|██████████| 40/40 [00:02<00:00, 18.43it/s, Acc=91.51%]


Train Loss: 0.3261, Train Acc: 92.25%
Val Loss: 0.3422, Val Acc: 91.51%
Test Loss: 0.3180, Test Acc: 92.92%

Epoch 23/100
--------------------------------------------------


Training: 100%|██████████| 196/196 [00:56<00:00,  3.44it/s, Loss=0.3224, Acc=92.40%]
Testing: 100%|██████████| 40/40 [00:02<00:00, 18.50it/s, Acc=92.95%]
Validation: 100%|██████████| 40/40 [00:02<00:00, 18.40it/s, Acc=91.53%]


Train Loss: 0.3224, Train Acc: 92.40%
Val Loss: 0.3393, Val Acc: 91.53%
Test Loss: 0.3147, Test Acc: 92.95%

Epoch 24/100
--------------------------------------------------


Training: 100%|██████████| 196/196 [00:56<00:00,  3.44it/s, Loss=0.3186, Acc=92.50%]
Testing: 100%|██████████| 40/40 [00:02<00:00, 18.46it/s, Acc=93.08%]
Validation: 100%|██████████| 40/40 [00:02<00:00, 18.42it/s, Acc=91.76%]


Train Loss: 0.3186, Train Acc: 92.50%
Val Loss: 0.3351, Val Acc: 91.76%
Test Loss: 0.3112, Test Acc: 93.08%

Epoch 25/100
--------------------------------------------------


Training: 100%|██████████| 196/196 [00:57<00:00,  3.44it/s, Loss=0.3156, Acc=92.60%]
Testing: 100%|██████████| 40/40 [00:02<00:00, 18.51it/s, Acc=93.13%]
Validation: 100%|██████████| 40/40 [00:02<00:00, 18.43it/s, Acc=91.72%]


Train Loss: 0.3156, Train Acc: 92.60%
Val Loss: 0.3331, Val Acc: 91.72%
Test Loss: 0.3088, Test Acc: 93.13%

Epoch 26/100
--------------------------------------------------


Training:  51%|█████     | 99/196 [00:29<00:28,  3.44it/s, Loss=0.3146, Acc=92.64%]
Validation:   0%|          | 0/40 [00:00<?, ?it/s][A
Validation:   0%|          | 0/40 [00:00<?, ?it/s, Acc=89.84%][A
Validation:   0%|          | 0/40 [00:00<?, ?it/s, Acc=91.21%][A
Validation:   5%|▌         | 2/40 [00:00<00:02, 18.17it/s, Acc=91.21%][A
Validation:   5%|▌         | 2/40 [00:00<00:02, 18.17it/s, Acc=91.28%][A
Validation:   5%|▌         | 2/40 [00:00<00:02, 18.17it/s, Acc=91.31%][A
Validation:  10%|█         | 4/40 [00:00<00:01, 18.08it/s, Acc=91.31%][A
Validation:  10%|█         | 4/40 [00:00<00:01, 18.08it/s, Acc=91.33%][A
Validation:  10%|█         | 4/40 [00:00<00:01, 18.08it/s, Acc=91.60%][A
Validation:  15%|█▌        | 6/40 [00:00<00:01, 18.04it/s, Acc=91.60%][A
Validation:  15%|█▌        | 6/40 [00:00<00:01, 18.04it/s, Acc=91.96%][A
Validation:  15%|█▌        | 6/40 [00:00<00:01, 18.04it/s, Acc=91.94%][A
Validation:  20%|██        | 8/40 [00:00<00:01, 18.05it/s, Acc=9

Step 5000: Val Acc: 91.84%
New best model saved! Validation accuracy: 91.84%


RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn