In [None]:
# Π-Model (Laine & Aila, 2017): semi-supervised MNIST classification with a consistency loss

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader,Subset

# Run on the GPU when one is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed the global RNGs so Restart & Run All is reproducible: torch's generator drives
# weight initialization, DataLoader shuffling, dropout masks and torch.randn_like noise.
# NOTE: some cuDNN kernels are still nondeterministic on GPU even with a fixed seed.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

In [None]:
# -------------------------
# Step 1: Data Preparation
# -------------------------

# Transform: convert images to tensors and normalize with the MNIST mean/std constants
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

# Load full MNIST dataset (both splits may need downloading on a fresh machine)
train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST('./data', train=False, download=True, transform=transform)

# Create a class-balanced labeled subset: the first `num_labels // num_classes`
# training examples of every class (deterministic, not a random draw).
num_labels = 100
labels = np.asarray(train_dataset.targets)
num_classes = len(train_dataset.classes)
# Without this check, e.g. num_labels=105 would silently yield only 100 labels.
assert num_labels % num_classes == 0, "num_labels must be divisible by the number of classes"
per_class = num_labels // num_classes

labeled_idx = []
for c in range(num_classes):
    class_idx = np.flatnonzero(labels == c)[:per_class]
    assert len(class_idx) == per_class, f"class {c} has fewer than {per_class} examples"
    labeled_idx.extend(class_idx.tolist())

# Every training example not in the labeled subset is used as unlabeled data
# (ascending index order, so the split is reproducible).
labeled_set = set(labeled_idx)
unlabeled_idx = [i for i in range(len(train_dataset)) if i not in labeled_set]
assert len(labeled_idx) == num_labels
assert len(labeled_idx) + len(unlabeled_idx) == len(train_dataset)

labeled_dataset = Subset(train_dataset, labeled_idx)
unlabeled_dataset = Subset(train_dataset, unlabeled_idx)

labeled_loader = DataLoader(labeled_dataset, batch_size=64, shuffle=True)
unlabeled_loader = DataLoader(unlabeled_dataset, batch_size=128, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=256, shuffle=False)

In [None]:
# -------------------------
# Step 2: Define the Model
# -------------------------

class PiCNN(nn.Module):
    """Small convolutional classifier used as the Π-model network.

    Two unpadded 3x3 convolutions followed by a 2x2 max-pool turn a 1x28x28
    input into 64x12x12 = 9216 features, which is the input width of ``fc1``.
    ``forward`` returns unnormalized logits over 10 classes. Dropout is only
    active in train mode, which is what makes two forward passes differ.
    """

    def __init__(self):
        super().__init__()
        # Modules are created in the same order as before, so parameter
        # initialization draws from the RNG identically.
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1)
        self.dropout = nn.Dropout(p=0.5)
        self.fc1 = nn.Linear(64 * 12 * 12, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        # Feature extractor: conv -> ReLU -> conv -> ReLU -> 2x2 max-pool.
        h = F.relu(self.conv1(x))
        h = F.relu(self.conv2(h))
        h = F.max_pool2d(h, kernel_size=2)
        # Classifier head: flatten per sample, dropout, hidden FC + ReLU, logits.
        h = self.dropout(h.flatten(start_dim=1))
        return self.fc2(F.relu(self.fc1(h)))

In [None]:
# -------------------------
# Step 3: Training Utilities
# -------------------------

def add_noise(x, noise_std=0.15):
    """Return ``x`` plus i.i.d. Gaussian noise with standard deviation ``noise_std``.

    This is the stochastic input augmentation used by the Π-model. ``x`` itself is not modified.
    """
    return x + torch.randn_like(x) * noise_std


def _cycle(loader):
    """Yield batches from ``loader`` forever, re-iterating it (so it reshuffles) on every pass."""
    while True:
        empty = True
        for batch in loader:
            empty = False
            yield batch
        if empty:
            # Guard against an infinite loop when the labeled loader has no data.
            raise ValueError("labeled_loader yielded no batches")


def train_pi_model(model, optimizer, labeled_loader, unlabeled_loader, alpha):
    """Run one Π-model training epoch: one optimizer step per unlabeled batch.

    The labeled loader is much smaller than the unlabeled one, so it is cycled
    rather than zipped directly. A plain ``zip`` would end the epoch after
    ``len(labeled_loader)`` batches; with 100 labels and batch_size=64 that is
    only 2 steps, and almost all unlabeled data would never be seen.

    Args:
        model: classifier returning logits. Batches are moved to the device of
            its parameters.
        optimizer: optimizer over ``model``'s parameters.
        labeled_loader: iterable of ``(x, y)`` labeled batches.
        unlabeled_loader: iterable of ``(x, _)`` batches; the labels are ignored.
        alpha: weight of the consistency loss (MSE between the softmax outputs
            of two independently perturbed passes).

    Raises:
        ValueError: if ``labeled_loader`` is empty while ``unlabeled_loader`` is not.
    """
    model.train()
    # Follow the model's device instead of relying on a module-level global.
    model_device = next(model.parameters()).device

    # Unlabeled loader first: zip stops on it without pulling an extra labeled batch.
    for (x_u, _), (x_l, y_l) in zip(unlabeled_loader, _cycle(labeled_loader)):
        x_l, y_l = x_l.to(model_device), y_l.to(model_device)
        x_u = x_u.to(model_device)

        # Supervised loss on labeled data (dropout active in train mode)
        logits_l = model(x_l)
        loss_sup = F.cross_entropy(logits_l, y_l)

        # Consistency loss on unlabeled data: the same batch is passed twice with
        # independent Gaussian noise and independent dropout masks.
        probs_u1 = F.softmax(model(add_noise(x_u)), dim=1)
        probs_u2 = F.softmax(model(add_noise(x_u)), dim=1)
        loss_unsup = F.mse_loss(probs_u1, probs_u2)

        loss = loss_sup + alpha * loss_unsup

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

def evaluate(model, loader):
    """Return the top-1 classification accuracy of ``model`` over ``loader``.

    Batches are moved to the device of the model's parameters. The model's
    train/eval mode is restored afterwards, so calling this mid-training does
    not silently disable dropout for later steps.

    Args:
        model: classifier returning logits of shape ``(batch, classes)``.
        loader: iterable of ``(x, y)`` batches with integer class labels ``y``.

    Returns:
        Fraction of correctly classified samples among those actually yielded.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    was_training = model.training
    model_device = next(model.parameters()).device
    model.eval()
    correct = 0
    seen = 0
    try:
        with torch.no_grad():
            for x, y in loader:
                x, y = x.to(model_device), y.to(model_device)
                pred = model(x).argmax(dim=1)
                correct += (pred == y).sum().item()
                # Count what was evaluated rather than len(loader.dataset), which is
                # wrong with drop_last / custom samplers and 0 for empty datasets.
                seen += y.size(0)
    finally:
        model.train(was_training)
    if seen == 0:
        raise ValueError("evaluate() received a loader with no samples")
    return correct / seen

In [None]:
# -------------------------
# Step 4: Train the Model
# -------------------------

model = PiCNN().to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

epochs = 20
alpha = 20.0        # maximum weight of the consistency loss (reached after ramp-up)
rampup_epochs = 5   # epochs over which the consistency weight ramps up to `alpha`


def consistency_weight(epoch, max_weight, rampup_epochs):
    """Gaussian ramp-up: ``max_weight * exp(-5 * (1 - t)**2)``.

    Here ``t = (epoch - 1) / rampup_epochs``, clipped to [0, 1], and epochs are 1-based.
    The weight starts near 0 (about 0.7% of ``max_weight``) and reaches
    ``max_weight`` after ``rampup_epochs``. Laine & Aila (2017) ramp the
    unsupervised weight up like this because the network's early predictions
    are unreliable targets for the consistency loss.
    """
    if rampup_epochs <= 0:
        return max_weight
    t = min(max(epoch - 1, 0) / rampup_epochs, 1.0)
    return max_weight * float(np.exp(-5.0 * (1.0 - t) ** 2))


for epoch in range(1, epochs + 1):
    epoch_alpha = consistency_weight(epoch, alpha, rampup_epochs)
    train_pi_model(model, optimizer, labeled_loader, unlabeled_loader, epoch_alpha)
    test_acc = evaluate(model, test_loader)
    print(f"Epoch {epoch:02d} | Test Accuracy: {test_acc:.4f}")