In [1]:
# Semi-supervised learning (SSL): self-training with pseudo-labels on MNIST

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader,Subset,ConcatDataset

# Run on the GPU when CUDA is usable, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# -------------------------
# Step 1: Data Preparation
# -------------------------

# Transform: convert images to tensors, then normalize with mean 0.1307 / std 0.3081
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

# Load full MNIST dataset
train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST('./data', train=False, transform=transform)

# Create labeled subset: `num_labels` examples split evenly across the classes
# that actually occur in the training targets (no hard-coded class count).
num_labels = 100
labels = np.array(train_dataset.targets)
class_ids = np.unique(labels)
per_class = num_labels // len(class_ids)

labeled_idx = []
for class_id in class_ids:
    # First `per_class` occurrences of this class, stored as plain Python ints.
    class_positions = np.flatnonzero(labels == class_id)[:per_class]
    labeled_idx.extend(class_positions.tolist())

# Everything not picked above is treated as unlabeled. Built in ascending index
# order so the list is identical on every run (a raw set difference gives no
# ordering guarantee).
labeled_lookup = set(labeled_idx)
unlabeled_idx = [i for i in range(len(train_dataset)) if i not in labeled_lookup]

# The two index lists must partition the training set exactly.
assert len(labeled_idx) == per_class * len(class_ids)
assert len(labeled_idx) + len(unlabeled_idx) == len(train_dataset)

labeled_dataset = Subset(train_dataset, labeled_idx)
unlabeled_dataset = Subset(train_dataset, unlabeled_idx)

labeled_loader = DataLoader(labeled_dataset, batch_size=64, shuffle=True)
unlabeled_loader = DataLoader(unlabeled_dataset, batch_size=128, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=256, shuffle=False)

In [None]:
from torch.utils.data import DataLoader, Subset, ConcatDataset, Dataset

# Step 2: CNN for self-training
class BaseCNN(nn.Module):
    """Small convolutional classifier that outputs 10 class logits.

    Two blocks of 3x3 convolution (padding 1) -> ReLU -> 2x2 max-pool halve the
    spatial size twice. The first linear layer expects 32 * 7 * 7 features, so
    the network is sized for single-channel 28x28 inputs.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 16, 3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, 3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(32 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        """Map a batch of images to logits.

        Args:
            x: Tensor of shape (N, 1, 28, 28), or a single unbatched image of
                shape (1, 28, 28), which is treated as a batch of one.

        Returns:
            Tensor of shape (N, 10) holding unnormalized class scores.

        Raises:
            RuntimeError: If the spatial size does not produce 32 * 7 * 7
                features per sample (raised by the first linear layer).
        """
        if x.dim() == 3:
            # Unbatched image: add the batch axis so the output is (1, 10).
            x = x.unsqueeze(0)
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # Flatten per sample. Unlike view(-1, 32 * 7 * 7), this never folds
        # surplus spatial features into the batch dimension, so a wrongly sized
        # input fails loudly in fc1 instead of returning extra rows.
        x = x.flatten(start_dim=1)
        x = F.relu(self.fc1(x))
        return self.fc2(x)

# Step 3: Supervised training
def train(model, loader, optimizer):
    """Run one supervised epoch over ``loader`` and report the mean loss.

    Args:
        model: ``nn.Module`` mapping a batch of inputs to class logits.
        loader: Iterable yielding ``(inputs, targets)`` tensor batches.
        optimizer: Optimizer that updates ``model``'s parameters.

    Returns:
        The sample-weighted mean cross-entropy loss over the epoch as a float,
        or ``0.0`` when ``loader`` yields no batches.
    """
    model.train()
    # Send batches to wherever the model lives rather than relying on a
    # notebook-level global; fall back to the CPU for parameter-free models.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device("cpu")

    # Accumulate on-device so the loop does not force a host sync per batch.
    running_loss = torch.zeros((), device=target_device)
    num_seen = 0
    for x, y in loader:
        x, y = x.to(target_device), y.to(target_device)
        optimizer.zero_grad()
        loss = F.cross_entropy(model(x), y)
        loss.backward()
        optimizer.step()
        batch_size = y.size(0)
        running_loss += loss.detach() * batch_size
        num_seen += batch_size

    if num_seen == 0:
        return 0.0
    return (running_loss / num_seen).item()

# Step 4: Generate pseudo-labels for unlabeled data
def generate_pseudo_labels(model, loader, threshold=0.95):
    """Label the samples in ``loader`` that the model predicts confidently.

    Args:
        model: ``nn.Module`` mapping a batch of inputs to class logits.
        loader: Iterable yielding ``(inputs, _)`` batches; the second element
            is ignored.
        threshold: A sample is kept only when its largest softmax probability
            is strictly greater than this value.

    Returns:
        A ``TensorDataset`` of ``(input, predicted_class)`` pairs for the kept
        samples, with both tensors on the CPU, or ``None`` when no sample
        passes the threshold.
    """
    model.eval()
    # Send batches to wherever the model lives rather than relying on a
    # notebook-level global; fall back to the CPU for parameter-free models.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device("cpu")

    kept_inputs, kept_labels = [], []
    with torch.no_grad():
        for x, _ in loader:
            x = x.to(target_device)
            probs = F.softmax(model(x), dim=1)
            conf, pred = probs.max(dim=1)
            mask = conf > threshold
            if mask.any():
                # Store on the CPU: the result can then be batched together
                # with CPU-resident datasets and does not pin accelerator
                # memory between rounds.
                kept_inputs.append(x[mask].cpu())
                kept_labels.append(pred[mask].cpu())

    if not kept_inputs:
        return None
    return torch.utils.data.TensorDataset(torch.cat(kept_inputs), torch.cat(kept_labels))

def evaluate(model, loader):
    """Compute classification accuracy of ``model`` over ``loader``.

    Args:
        model: ``nn.Module`` mapping a batch of inputs to class logits.
        loader: Iterable yielding ``(inputs, targets)`` tensor batches.

    Returns:
        Fraction of evaluated samples whose arg-max prediction equals the
        target, as a float in [0, 1]; ``0.0`` when ``loader`` yields nothing.
    """
    model.eval()
    # Send batches to wherever the model lives rather than relying on a
    # notebook-level global; fall back to the CPU for parameter-free models.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device("cpu")

    correct, total = 0, 0
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(target_device), y.to(target_device)
            pred = model(x).argmax(1)
            correct += (pred == y).sum().item()
            # Count what was actually evaluated: len(loader.dataset) is wrong
            # with drop_last or a sampler, and needs a `.dataset` attribute.
            total += y.size(0)
    return correct / total if total else 0.0

# Step 5: Self-Training Loop
model = BaseCNN().to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)


def _collate_image_label(batch):
    """Batch (image, label) pairs whose labels are ints or 0-dim tensors.

    The labelled subset yields integer labels while the pseudo-labelled
    TensorDataset yields 0-dim tensors; the default collate cannot mix them.
    """
    images = torch.stack([image for image, _ in batch])
    targets = torch.tensor([int(label) for _, label in batch], dtype=torch.long)
    return images, targets


# Start from the purely labelled loader. A separate name keeps `labeled_loader`
# untouched, so re-running this cell always starts from the same state.
train_loader = labeled_loader

for epoch in range(1, 11):
    train(model, train_loader, optimizer)
    pseudo_dataset = generate_pseudo_labels(model, unlabeled_loader, threshold=0.95)
    if pseudo_dataset is not None:
        print(f"Epoch {epoch}: Adding {len(pseudo_dataset)} pseudo-labeled samples.")
        # Move pseudo-labelled tensors to the CPU once so they can be batched
        # together with the CPU-resident labelled samples.
        pseudo_cpu = torch.utils.data.TensorDataset(
            *(tensor.cpu() for tensor in pseudo_dataset.tensors)
        )
        labeled_set = ConcatDataset([labeled_dataset, pseudo_cpu])
        # Train the next round on labelled + pseudo-labelled data. Previously
        # the loader was rebuilt from `labeled_dataset`, so the pseudo-labels
        # were never used.
        train_loader = DataLoader(
            labeled_set, batch_size=64, shuffle=True, collate_fn=_collate_image_label
        )
    acc = evaluate(model, test_loader)
    print(f"[Self-Training] Epoch {epoch} - Test Accuracy: {acc:.4f}")