In [None]:
# Semi-supervised learning (SSL): co-training on MNIST with two half-image views

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import ConcatDataset, DataLoader, Dataset, Subset
from torchvision import datasets, transforms

# Seed PyTorch's global RNG so weight initialisation and DataLoader shuffling
# repeat across "Restart Kernel -> Run All". Change SEED to study run-to-run
# variance. (Some CUDA kernels may still be non-deterministic.)
SEED = 42
torch.manual_seed(SEED)

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# -------------------------
# Step 1: Data Preparation
# -------------------------

# Transform: Normalize MNIST images
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

# Load full MNIST dataset
train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST('./data', train=False, transform=transform)

# Create labeled subset (e.g., 100 examples evenly distributed across classes)
num_labels = 100
num_classes = 10
labels = np.array(train_dataset.targets)
labeled_idx = []

for i in range(num_classes):
    # Keep the first `num_labels // num_classes` occurrences of digit `i`.
    idx = np.flatnonzero(labels == i)[:num_labels // num_classes]
    # .tolist() stores plain Python ints rather than np.int64 scalars.
    labeled_idx.extend(idx.tolist())

# Every remaining training index forms the unlabeled pool. np.setdiff1d returns
# a sorted array, so the pool order does not depend on set iteration order.
unlabeled_idx = np.setdiff1d(np.arange(len(train_dataset)), labeled_idx).tolist()

# The two index lists must partition the training set exactly.
assert len(labeled_idx) + len(unlabeled_idx) == len(train_dataset), \
    "labeled/unlabeled split does not cover the training set"

labeled_dataset = Subset(train_dataset, labeled_idx)
unlabeled_dataset = Subset(train_dataset, unlabeled_idx)

labeled_loader = DataLoader(labeled_dataset, batch_size=64, shuffle=True)
unlabeled_loader = DataLoader(unlabeled_dataset, batch_size=128, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=256, shuffle=False)

In [None]:
# Step 2: Two-View Dataset Wrapper
class SplitMNIST(Dataset):
    """Expose one vertical half of every image in a wrapped dataset.

    Each item of ``dataset`` must be an ``(image, label)`` pair whose image is
    a tensor with width as its last dimension. The image is cut at the middle
    column (column 14 for 28-pixel-wide MNIST digits); the label is passed
    through untouched.

    Args:
        dataset: Any indexable dataset yielding ``(image, label)`` pairs.
        side: ``'left'`` for columns ``[0, mid)`` or ``'right'`` for
            columns ``[mid, width)``.

    Raises:
        ValueError: If ``side`` is neither ``'left'`` nor ``'right'``.
    """

    _SIDES = ('left', 'right')

    def __init__(self, dataset, side='left'):
        # Fail fast on typos such as 'Right'; previously any unknown value
        # silently selected the right half.
        if side not in self._SIDES:
            raise ValueError(f"side must be 'left' or 'right', got {side!r}")
        self.dataset = dataset
        self.side = side

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        x, y = self.dataset[idx]
        # Split at the middle of the width axis instead of a hard-coded 14.
        mid = x.shape[-1] // 2
        if self.side == 'left':
            return x[..., :mid], y
        return x[..., mid:], y

# Step 3: Define Half-CNN Model
class HalfCNN(nn.Module):
    """Two-stage CNN that classifies a single 28x14 half of an MNIST digit.

    Each stage is a 3x3 same-padding convolution with 16 output channels,
    a ReLU and a 2x2 max-pool. A 64-unit hidden layer then feeds a
    10-way output layer that returns raw logits.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        # Two poolings shrink 28x14 -> 14x7 -> 7x3, hence 16 * 7 * 3 features.
        self.fc1 = nn.Linear(in_features=16 * 7 * 3, out_features=64)
        self.fc2 = nn.Linear(in_features=64, out_features=10)

    def forward(self, x):
        """Map a ``[B, 1, 28, 14]`` batch to ``[B, 10]`` class logits."""
        # Stage 1 -> [B, 16, 14, 7]; stage 2 -> [B, 16, 7, 3]
        for conv in (self.conv1, self.conv2):
            x = self.pool(F.relu(conv(x)))
        flat = x.view(x.shape[0], -1)
        hidden = F.relu(self.fc1(flat))
        return self.fc2(hidden)

# Step 4: Training and Evaluation Functions
def train(model, loader, optimizer):
    """Run one supervised pass over ``loader``, updating ``model`` in place.

    Every batch is moved to the module-level ``device``, scored with
    cross-entropy against its targets, and followed by one optimizer step.
    The model is left in training mode; nothing is returned.
    """
    model.train()
    for inputs, targets in loader:
        inputs = inputs.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()
        batch_loss = F.cross_entropy(model(inputs), targets)
        batch_loss.backward()
        optimizer.step()

def evaluate(model, loader):
    """Return the fraction of samples in ``loader`` that ``model`` classifies correctly.

    The model is switched to eval mode and gradients are disabled. The
    prediction for each sample is the arg-max over dimension 1 of the model
    output, and the hit count is divided by ``len(loader.dataset)``.
    """
    model.eval()
    hits = 0
    with torch.no_grad():
        for inputs, targets in loader:
            inputs, targets = inputs.to(device), targets.to(device)
            predictions = model(inputs).argmax(dim=1)
            hits += int((predictions == targets).sum())
    total = len(loader.dataset)
    return hits / total

def generate_pseudo_labels(model, loader, threshold=0.95):
    """Collect the inputs that ``model`` classifies with high confidence.

    Args:
        model: Classifier returning logits of shape ``[B, num_classes]``.
        loader: Iterable of ``(inputs, _)`` batches; the second element is ignored.
        threshold: A sample is kept only when its top softmax probability is
            strictly greater than this value.

    Returns:
        A ``TensorDataset`` (on the CPU) pairing the kept inputs, exactly as
        they were fed to the model, with the predicted class indices.
        Returns ``None`` when no sample passes the threshold.
    """
    model.eval()
    kept_inputs = []
    kept_labels = []
    with torch.no_grad():
        for batch, _ in loader:
            batch = batch.to(device)
            confidence, predicted = F.softmax(model(batch), dim=1).max(dim=1)
            confident = confidence > threshold
            if not confident.any():
                continue
            kept_inputs.append(batch[confident].cpu())
            kept_labels.append(predicted[confident].cpu())

    if not kept_inputs:
        return None
    return torch.utils.data.TensorDataset(torch.cat(kept_inputs), torch.cat(kept_labels))

# Step 5: Initialize Models and Loaders
# One classifier and one Adam optimizer per view:
# model1 sees the left half of each digit, model2 the right half.
model1, model2 = HalfCNN().to(device), HalfCNN().to(device)
opt1, opt2 = (optim.Adam(net.parameters(), lr=1e-3) for net in (model1, model2))

# Labeled halves, reshuffled each epoch, used for supervised training.
view1_loader, view2_loader = (
    DataLoader(SplitMNIST(labeled_dataset, half), batch_size=64, shuffle=True)
    for half in ('left', 'right')
)

# Unlabeled halves, iterated in a fixed order.
unlabeled1, unlabeled2 = (
    DataLoader(SplitMNIST(unlabeled_dataset, half), batch_size=256)
    for half in ('left', 'right')
)

# Step 6: Co-Training Loop
class _HalfViewScorer(nn.Module):
    """Let a half-image classifier score full-width images.

    generate_pseudo_labels returns the very tensors it fed to the model.
    Scoring full-width images and cropping inside ``forward`` therefore keeps
    both halves of each confident sample, so the *other* half can be handed
    to the peer model together with the predicted label.
    """

    def __init__(self, model, side):
        super().__init__()
        self.model = model  # registered as a submodule, so .eval() reaches it
        self.side = side

    def forward(self, x):
        # Same boundary SplitMNIST uses for 28-pixel-wide images (column 14).
        mid = x.shape[-1] // 2
        return self.model(x[..., :mid] if self.side == 'left' else x[..., mid:])


def _cotrain_loader(side, pseudo):
    """Build a training loader for ``side`` from labeled plus peer-labeled full images."""
    combined = ConcatDataset([labeled_full, pseudo])
    return DataLoader(SplitMNIST(combined, side), batch_size=64, shuffle=True)


# Full-width unlabeled images: the source from which pseudo-labels are drawn.
unlabeled_full = DataLoader(unlabeled_dataset, batch_size=256)

# Materialise the small labeled set as tensors so its labels have the same type
# (tensor) as the pseudo-labels; mixing label types within one batch is not
# something the default collate function can be relied on to handle.
_labeled_images, _labeled_targets = zip(
    *(labeled_dataset[k] for k in range(len(labeled_dataset)))
)
labeled_full = torch.utils.data.TensorDataset(
    torch.stack(_labeled_images),
    torch.tensor([int(t) for t in _labeled_targets]),
)

scorer1 = _HalfViewScorer(model1, 'left')
scorer2 = _HalfViewScorer(model2, 'right')

for epoch in range(1, 11):
    print(f"\nEpoch {epoch}")

    # Train both models on their current training sets
    train(model1, view1_loader, opt1)
    train(model2, view2_loader, opt2)

    # Each model labels the full-width unlabeled images it is confident about,
    # looking only at its own half.
    p1 = generate_pseudo_labels(scorer1, unlabeled_full)
    p2 = generate_pseudo_labels(scorer2, unlabeled_full)

    if p1 is not None and p2 is not None:
        print(f"  Adding pseudo-labels: View1 ← {len(p2)} from model2, View2 ← {len(p1)} from model1")
        # Cross-teaching: model1 learns left halves labeled by model2, and
        # model2 learns right halves labeled by model1. The pseudo-labeled set
        # is rebuilt from scratch each epoch rather than accumulated.
        view1_loader = _cotrain_loader('left', p2)
        view2_loader = _cotrain_loader('right', p1)

    # Accuracy is measured against the true labels of the unlabeled pool.
    acc1 = evaluate(model1, unlabeled1)
    acc2 = evaluate(model2, unlabeled2)
    print(f"  View1 Accuracy: {acc1:.4f}, View2 Accuracy: {acc2:.4f}")