In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader, Subset
import torch.backends.cudnn as cudnn
import numpy as np
import random
from itertools import cycle
from copy import deepcopy

In [2]:
# Random parameters for the reproducibility

seed = 2025

cudnn.benchmark = False
cudnn.deterministic = True
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
random.seed(seed)

In [3]:
#Set the number of labeled, validated

labeled_num = 500
val_num = 1000

#Base Augmentation
transform = transforms.Compose([transforms.ToTensor()])

#Load MNIST dataset
dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)

indices = np.arange(len(dataset))
np.random.shuffle(indices)

#The number of labeled samples
labeled_idx = indices[:labeled_num]

#The number of validation samples
val_idx = indices[labeled_num:labeled_num + val_num]

#The number of unlabeled samples
unlabeled_idx = indices[labeled_num + val_num:]

# Get labeled set by predefined indices
labeled_set = Subset(dataset, labeled_idx)

# Get validated set by predefined indices
val_set = Subset(dataset, val_idx)

#Set up loader for labeled set
labeled_loader = DataLoader(labeled_set, batch_size = 512, shuffle = True)

#Set up loader for validation set
val_loader = DataLoader(val_set, batch_size = 512, shuffle = False)


100%|██████████| 9.91M/9.91M [00:00<00:00, 53.2MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 1.72MB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 14.9MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 8.94MB/s]


In [4]:
# Weak-Strong augmentation
weak_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(28, padding = 4),
    transforms.ToTensor()
])

strong_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(28, padding = 4),
    transforms.RandomRotation(15),
    transforms.ToTensor()
])

#Load dataset with weak-strong augmentation
weak_transformed_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=weak_transform)
strong_transformed_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=strong_transform)

In [5]:
#Define a class for every pair of data from the weak and strong transformation

class PairedDataset(torch.utils.data.Dataset):
    def __init__(self, weak_dataset, strong_dataset):
        self.weak_dataset = weak_dataset
        self.strong_dataset = strong_dataset

    def __len__(self):
        return len(self.weak_dataset)

    def __getitem__(self, idx):
        weak_image, weak_label = self.weak_dataset[idx]
        strong_image, strong_label = self.strong_dataset[idx]
        return (weak_image, weak_label), (strong_image, strong_label)

In [6]:
# Get weak, strong augmented set by predefined indices
weak_unlabeled_set = Subset(weak_transformed_dataset, unlabeled_idx)
strong_unlabeled_set = Subset(strong_transformed_dataset, unlabeled_idx)

# Set up loader for unlabeled data
paired_unlabeled_set = PairedDataset(weak_unlabeled_set, strong_unlabeled_set)
unlabeled_loader = DataLoader(paired_unlabeled_set, batch_size = 512, shuffle = True)

In [7]:
# Create a model

class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, kernel_size=2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, kernel_size=2)
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

In [11]:
# Init model, optimizer, loss
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_student = CNN().to(device)
model_teacher = CNN().to(device)
model_teacher.load_state_dict(model_student.state_dict())

optimizer = optim.Adam(model_student.parameters(), lr= 0.001)
criterion = nn.CrossEntropyLoss()
thresh = 0.95
ema_decay = 0.99

#Variables for storing best model
iter_per_eval = 5
iter_num = 0
max_iter_num = 200

best_model = None
best_acc = 0

In [18]:
# Define the function for validation

def validation(model):
  # Move to evaluation model
  model.eval()
  correct = 0
  with torch.no_grad():
    for data, target in val_loader:
      # data, target = data.cuda(), target.cuda()
      data, target = data.to(device), target.to(device)
      output = model(data)
      pred = output.argmax(dim=1)
      correct += pred.eq(target).sum().item()

  acc_score = correct / len(val_set)
  print(f"Validation Accuracy: {acc_score}")
  model.train()
  return acc_score

In [10]:
# Define the function to update teacher model with EMA (Exponential Moving Average) formula

def update_teacher(model_student, model_teacher, ema_decay):
  with torch.no_grad():
    for t_param, s_param in zip(model_teacher.parameters(), model_student.parameters()):
      t_param.data.mul_(ema_decay).add_(s_param.data, alpha= 1 - ema_decay)

In [19]:
# Train model

model_student.train()
model_teacher.train()

for _ in range(100):
  for (labeled_data, labeled_target), ((weak_unlabeled_data, _), (strong_unlabeled_data, _)) in zip(cycle(labeled_loader), unlabeled_loader):
    labeled_data, labeled_target = labeled_data.to(device), labeled_target.to(device)
    weak_unlabeled_data, strong_unlabeled_data = weak_unlabeled_data.to(device), strong_unlabeled_data.to(device)

    #Supervised loss
    logist_x = model_student(labeled_data)
    loss_x = criterion(logist_x, labeled_target)

    #Unspervised consistency loss
    with torch.no_grad():
      teacher_preds = torch.softmax(model_teacher(weak_unlabeled_data), dim=1)
    student_preds = model_student(strong_unlabeled_data)

    max_probs, pseudo_labels = torch.max(teacher_preds, dim=1)
    mask = max_probs.ge(thresh).float()

    loss_u = (criterion(student_preds, pseudo_labels) * mask).mean()
    loss = loss_x + loss_u

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    update_teacher(model_student, model_teacher, ema_decay)

    iter_num += 1
    print(f"Iteration {iter_num}, Loss: {loss.item():.4f}")

    if iter_num % iter_per_eval == 0:
      acc_score = validation(model_student)
      if acc_score >= best_acc:
        best_acc = acc_score
        best_model = deepcopy(model_student)
        print(f"Save new best model at iter: {iter_num}")

    if iter_num == max_iter_num:
      break

  if iter_num == max_iter_num:
    break


Iteration 11, Loss: 1.1411
Iteration 12, Loss: 1.0057
Iteration 13, Loss: 0.8878
Iteration 14, Loss: 0.7889
Iteration 15, Loss: 0.7084
Validation Accuracy: 0.76
Save new best model at iter: 15
Iteration 16, Loss: 0.6432
Iteration 17, Loss: 0.5936
Iteration 18, Loss: 0.5562
Iteration 19, Loss: 0.5260
Iteration 20, Loss: 0.4936
Validation Accuracy: 0.804
Save new best model at iter: 20
Iteration 21, Loss: 0.4638
Iteration 22, Loss: 0.4388
Iteration 23, Loss: 0.4130
Iteration 24, Loss: 0.3964
Iteration 25, Loss: 0.3671
Validation Accuracy: 0.824
Save new best model at iter: 25
Iteration 26, Loss: 0.3557
Iteration 27, Loss: 0.3268
Iteration 28, Loss: 0.3153
Iteration 29, Loss: 0.2911
Iteration 30, Loss: 0.2780
Validation Accuracy: 0.841
Save new best model at iter: 30
Iteration 31, Loss: 0.2583
Iteration 32, Loss: 0.2460
Iteration 33, Loss: 0.2298
Iteration 34, Loss: 0.2189
Iteration 35, Loss: 0.2046
Validation Accuracy: 0.851
Save new best model at iter: 35
Iteration 36, Loss: 0.1949
Iter

In [20]:
# Evaluate model

def test(model):
  model.eval()
  test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
  test_loader = DataLoader(test_dataset, batch_size=512, shuffle=False)
  correct = 0
  with torch.no_grad():
    for data, target in test_loader:
      data, target = data.to(device), target.to(device)
      output = model(data)
      pred = output.argmax(dim=1)
      correct += pred.eq(target).sum().item()
  print(f"Test Accuracy: {correct/len(test_dataset):.4f}")

test(best_model)

Test Accuracy: 0.9440
