In [None]:
# Import necessary libraries
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt

# Set device (GPU if available, else CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Define basic transform for MNIST (normalize to [-1, 1])
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])

# Load original MNIST dataset
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

# Create data loaders
batch_size = 64
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Define the Basic CNN architecture
class BasicCNN(nn.Module):
    def __init__(self):
        super(BasicCNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.fc1 = nn.Linear(64 * 7 * 7, 64)
        self.fc2 = nn.Linear(64, 10)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.pool(self.relu(self.conv1(x)))
        x = self.pool(self.relu(self.conv2(x)))
        x = self.relu(self.conv3(x))
        x = x.view(-1, 64 * 7 * 7)  # Flatten
        x = self.relu(self.fc1(x))
        x = self.fc2(x)
        return x
    def get_reg_loss(self):
        return 0

# Function to flip convolutional kernels horizontally
def horizontal_flip_kernel(kernel):
    return torch.flip(kernel, [3])  # Flip along width dimension

# Custom flip-invariant Conv2D layer with regularization
class FlipInvariantConv2D(nn.Conv2d):
    def __init__(self, in_channels, out_channels, kernel_size, reg_strength=1e-4, **kwargs):
        super(FlipInvariantConv2D, self).__init__(in_channels, out_channels, kernel_size, **kwargs)
        self.reg_strength = reg_strength

    def forward(self, x):
        flipped_kernel = horizontal_flip_kernel(self.weight)
        reg_loss = self.reg_strength * torch.sum((self.weight - flipped_kernel) ** 2)
        self.add_loss(reg_loss)
        return super(FlipInvariantConv2D, self).forward(x)

    def add_loss(self, loss):
        if not hasattr(self, 'reg_losses'):
            self.reg_losses = []
        self.reg_losses.append(loss)

# Define the Flip-Invariant CNN architecture
class FlipInvariantCNN(nn.Module):
    def __init__(self, reg_strength=1e-4):
        super(FlipInvariantCNN, self).__init__()
        self.conv1 = FlipInvariantConv2D(1, 32, kernel_size=3, stride=1, padding=1, reg_strength=reg_strength)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = FlipInvariantConv2D(32, 64, kernel_size=3, stride=1, padding=1, reg_strength=reg_strength)
        self.conv3 = FlipInvariantConv2D(64, 64, kernel_size=3, stride=1, padding=1, reg_strength=reg_strength)
        self.fc1 = nn.Linear(64 * 7 * 7, 64)
        self.fc2 = nn.Linear(64, 10)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.pool(self.relu(self.conv1(x)))
        x = self.pool(self.relu(self.conv2(x)))
        x = self.relu(self.conv3(x))
        x = x.view(-1, 64 * 7 * 7)  # Flatten
        x = self.relu(self.fc1(x))
        x = self.fc2(x)
        return x

    def get_reg_loss(self):
        if hasattr(self, 'reg_losses'):
            return sum(self.reg_losses)
        return 0

# Training function
def train(model, train_loader, criterion, optimizer, epochs=3):
    model.train()
    for epoch in range(epochs):
        for data, target in train_loader:
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target) + model.get_reg_loss()
            loss.backward()
            optimizer.step()
            model.reg_losses = []  # Clear regularization losses

# Evaluation function
def evaluate(model, test_loader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            _, predicted = torch.max(output.data, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()
    return correct / total

# --- Training on Original MNIST ---
print("Training on Original MNIST")

# Train Basic CNN
basic_cnn = BasicCNN().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(basic_cnn.parameters(), lr=0.001)
train(basic_cnn, train_loader, criterion, optimizer)
basic_acc = evaluate(basic_cnn, test_loader)
print(f"Basic CNN Test Accuracy: {basic_acc:.4f}")

# Train Flip-Invariant CNN with two regularization strengths
reg_strengths = [1e-4, 1e-2]
for reg in reg_strengths:
    print(f"\nFlip-Invariant CNN with reg_strength={reg}")
    flip_cnn = FlipInvariantCNN(reg_strength=reg).to(device)
    optimizer = optim.Adam(flip_cnn.parameters(), lr=0.001)
    train(flip_cnn, train_loader, criterion, optimizer)
    flip_acc = evaluate(flip_cnn, test_loader)
    print(f"Flip-Invariant CNN Test Accuracy (reg={reg}): {flip_acc:.4f}")

# --- Training on Flipped MNIST ---
print("\nTraining on Flipped MNIST")
transform_flipped = transforms.Compose([transforms.RandomHorizontalFlip(p=1.0), transform])
train_dataset_flipped = datasets.MNIST(root='./data', train=True, download=True, transform=transform_flipped)
test_dataset_flipped = datasets.MNIST(root='./data', train=False, download=True, transform=transform_flipped)
train_loader_flipped = DataLoader(train_dataset_flipped, batch_size=batch_size, shuffle=True)
test_loader_flipped = DataLoader(test_dataset_flipped, batch_size=batch_size, shuffle=False)

# Basic CNN on flipped MNIST
basic_cnn_flipped = BasicCNN().to(device)
optimizer = optim.Adam(basic_cnn_flipped.parameters(), lr=0.001)
train(basic_cnn_flipped, train_loader_flipped, criterion, optimizer)
basic_flipped_acc = evaluate(basic_cnn_flipped, test_loader_flipped)
print(f"Basic CNN on Flipped MNIST Test Accuracy: {basic_flipped_acc:.4f}")

# Flip-Invariant CNN on flipped MNIST
for reg in reg_strengths:
    print(f"\nFlip-Invariant CNN with reg_strength={reg} on Flipped MNIST")
    flip_cnn_flipped = FlipInvariantCNN(reg_strength=reg).to(device)
    optimizer = optim.Adam(flip_cnn_flipped.parameters(), lr=0.001)
    train(flip_cnn_flipped, train_loader_flipped, criterion, optimizer)
    flip_flipped_acc = evaluate(flip_cnn_flipped, test_loader_flipped)
    print(f"Flip-Invariant CNN Test Accuracy (reg={reg}): {flip_flipped_acc:.4f}")

# --- Training on Rotated MNIST ---
print("\nTraining on Rotated MNIST")
transform_rotated = transforms.Compose([transforms.RandomRotation(degrees=90), transform])
train_dataset_rotated = datasets.MNIST(root='./data', train=True, download=True, transform=transform_rotated)
test_dataset_rotated = datasets.MNIST(root='./data', train=False, download=True, transform=transform_rotated)
train_loader_rotated = DataLoader(train_dataset_rotated, batch_size=batch_size, shuffle=True)
test_loader_rotated = DataLoader(test_dataset_rotated, batch_size=batch_size, shuffle=False)

# Basic CNN on rotated MNIST
basic_cnn_rotated = BasicCNN().to(device)
optimizer = optim.Adam(basic_cnn_rotated.parameters(), lr=0.001)
train(basic_cnn_rotated, train_loader_rotated, criterion, optimizer)
basic_rotated_acc = evaluate(basic_cnn_rotated, test_loader_rotated)
print(f"Basic CNN on Rotated MNIST Test Accuracy: {basic_rotated_acc:.4f}")

# Flip-Invariant CNN on rotated MNIST
for reg in reg_strengths:
    print(f"\nFlip-Invariant CNN with reg_strength={reg} on Rotated MNIST")
    flip_cnn_rotated = FlipInvariantCNN(reg_strength=reg).to(device)
    optimizer = optim.Adam(flip_cnn_rotated.parameters(), lr=0.001)
    train(flip_cnn_rotated, train_loader_rotated, criterion, optimizer)
    flip_rotated_acc = evaluate(flip_cnn_rotated, test_loader_rotated)
    print(f"Flip-Invariant CNN Test Accuracy (reg={reg}): {flip_rotated_acc:.4f}")

Exception ignored in: <bound method IPythonKernel._clean_thread_parent_frames of <ipykernel.ipkernel.IPythonKernel object at 0x7f2a6088c340>>
Traceback (most recent call last):
  File "/homes/55/lvierling/miniconda/envs/lora_project/lib/python3.10/site-packages/ipykernel/ipkernel.py", line 775, in _clean_thread_parent_frames
    def _clean_thread_parent_frames(
KeyboardInterrupt: 


Training on Original MNIST


In [19]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Hyperparameters
num_epochs = 3
batch_size = 32
learning_rate = 0.001
alpha = 1.0  # For weighting function
lambda_recon = 1.0  # Weight for reconstruction loss
num_classes = 10

In [12]:
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

# Hyperparameters (define batch_size)
batch_size = 32

# Define transformations for MNIST (single-channel grayscale images)
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# Load MNIST dataset
train_dataset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)

# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [13]:
class NormalCNN(nn.Module):
    def __init__(self, out_channels=32, in_channels=3):
        super(NormalCNN, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, 16, 3, padding=1)
        self.conv2 = nn.Conv2d(16, out_channels, 3, padding=1)

    def forward(self, x):
        x = F.relu(self.conv1(x))  # Output: (batch, 16, 32, 32)
        x = F.relu(self.conv2(x))  # Output: (batch, 32, 32, 32)
        return x  # F_1

In [14]:
class RotationInvariantCNN(nn.Module):
    def __init__(self, out_channels=32,in_channels=3):
        super(RotationInvariantCNN, self).__init__()
        self.cnn = nn.Sequential(
            nn.Conv2d(in_channels, 16, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(16, out_channels, 3, padding=1),
            nn.ReLU()
        )

    def forward(self, x):
        # Rotate input by 0째, 90째, 180째, 270째
        rotations = [0, 1, 2, 3]  # k values for torch.rot90
        rotated_features = []
        for k in rotations:
            x_rot = torch.rot90(x, k, (2, 3))  # Rotate spatial dimensions
            feat = self.cnn(x_rot)
            feat = torch.rot90(feat, -k, (2, 3))  # Rotate back
            rotated_features.append(feat)
        # Average across rotations
        F_2 = torch.stack(rotated_features).mean(dim=0)
        return F_2  # Shape: (batch, 32, 32, 32)

In [15]:
class Autoencoder(nn.Module):
    def __init__(self, in_channels=32):
        super(Autoencoder, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels, 16, 3, stride=2, padding=1),  # Downsample
            nn.ReLU(),
            nn.Conv2d(16, 8, 3, stride=2, padding=1),
            nn.ReLU()
        )
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(8, 16, 3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(16, in_channels, 3, stride=2, padding=1, output_padding=1),
            nn.Sigmoid()  # Assuming F_2 is normalized between 0 and 1
        )

    def forward(self, x):
        encoded = self.encoder(x)
        reconstructed = self.decoder(encoded)
        return reconstructed  # F_2_hat

In [22]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class TwoPipelineModel(nn.Module):
    def __init__(self, in_channels=3, feature_channels=32, alpha=1.0, lambda_recon=1.0):
        super(TwoPipelineModel, self).__init__()
        self.normal_cnn = NormalCNN(in_channels=in_channels, out_channels=feature_channels)
        self.rotation_cnn = RotationInvariantCNN(in_channels=in_channels, out_channels=feature_channels)
        self.normal_autoencoder = Autoencoder(in_channels=feature_channels)
        self.rotation_autoencoder = Autoencoder(in_channels=feature_channels)
        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),  # Reduce to (batch, 32, 1, 1)
            nn.Flatten(),
            nn.Linear(feature_channels, num_classes)
        )
        self.alpha = alpha  # Unused in this version but kept for compatibility
        self.lambda_recon = lambda_recon

    def forward(self, x):
        # Compute feature maps
        F_1 = self.normal_cnn(x)
        F_2 = self.rotation_cnn(x)

        # Compute reconstructions
        F_1_hat = self.normal_autoencoder(F_1)
        F_2_hat = self.rotation_autoencoder(F_2)

        # Compute per-sample reconstruction losses for both feature maps
        loss_1 = F.mse_loss(F_1, F_1_hat, reduction='none').mean(dim=(1, 2, 3))  # Shape: (batch,)
        loss_2 = F.mse_loss(F_2, F_2_hat, reduction='none').mean(dim=(1, 2, 3))  # Shape: (batch,)

        if self.training:
            # During training, use fixed 50/50 sum
            F_agg = 0.5 * F_1 + 0.5 * F_2
            output = self.classifier(F_agg)
            return output, loss_1, loss_2
        else:
            # During inference, normalize losses to sum to 1 and use as weights
            total_loss = loss_1 + loss_2
            # Avoid division by zero with a small epsilon
            epsilon = 1e-8
            w_1 = loss_1 / (total_loss + epsilon)  # Shape: (batch,)
            w_2 = loss_2 / (total_loss + epsilon)  # Shape: (batch,)
            # Reshape weights for broadcasting
            w_1 = w_1.view(-1, 1, 1, 1)  # Shape: (batch, 1, 1, 1)
            w_2 = w_2.view(-1, 1, 1, 1)  # Shape: (batch, 1, 1, 1)
            # Weighted aggregation
            F_agg = w_1 * F_1 + w_2 * F_2
            output = self.classifier(F_agg)
            return output  # Could return w_1, w_2 for analysis if desired

    def compute_loss(self, output, target, loss_1, loss_2):
        class_loss = F.cross_entropy(output, target)
        # Average the two reconstruction losses
        recon_loss = (loss_1.mean() + loss_2.mean()) / 2
        total_loss = class_loss + self.lambda_recon * recon_loss
        return total_loss

In [25]:
model = TwoPipelineModel(in_channels=1,feature_channels=32, alpha=alpha, lambda_recon=lambda_recon).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

for epoch in range(num_epochs):
    model.train()
    for i, (images, labels) in enumerate(train_loader):
        images, labels = images.to(device), labels.to(device)

        # Forward pass
        outputs, recon_1, recon_2 = model(images)
        loss = model.compute_loss(outputs, labels, recon_1, recon_2)

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (i + 1) % 100 == 0:
            print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}], Loss: {loss.item():.4f}')

Epoch [1/3], Step [100], Loss: 2.2858
Epoch [1/3], Step [200], Loss: 2.2767
Epoch [1/3], Step [300], Loss: 2.2958
Epoch [1/3], Step [400], Loss: 2.3031
Epoch [1/3], Step [500], Loss: 2.2369
Epoch [1/3], Step [600], Loss: 2.1497
Epoch [1/3], Step [700], Loss: 2.1099
Epoch [1/3], Step [800], Loss: 2.1371
Epoch [1/3], Step [900], Loss: 2.0474
Epoch [1/3], Step [1000], Loss: 2.2111
Epoch [1/3], Step [1100], Loss: 2.0798
Epoch [1/3], Step [1200], Loss: 2.0732
Epoch [1/3], Step [1300], Loss: 1.9244
Epoch [1/3], Step [1400], Loss: 1.8319
Epoch [1/3], Step [1500], Loss: 1.8811
Epoch [1/3], Step [1600], Loss: 1.8052
Epoch [1/3], Step [1700], Loss: 1.9679
Epoch [1/3], Step [1800], Loss: 1.8494
Epoch [2/3], Step [100], Loss: 1.9761
Epoch [2/3], Step [200], Loss: 1.7876
Epoch [2/3], Step [300], Loss: 1.8705
Epoch [2/3], Step [400], Loss: 1.8180
Epoch [2/3], Step [500], Loss: 1.9512
Epoch [2/3], Step [600], Loss: 1.8468
Epoch [2/3], Step [700], Loss: 1.8087
Epoch [2/3], Step [800], Loss: 1.9512
Epo

KeyboardInterrupt: 

In [29]:
model.eval()
with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)  # Reconstruction loss is used internally for weighting
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total
    print(f'Test Accuracy: {accuracy:.2f}%')

Test Accuracy: 34.26%
