<a href="https://colab.research.google.com/github/cosmaadrian/ml-environment/blob/master/DSM_Lab_FixMatch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Semi-Supervised Learning with FixMatch

In this lab, we will explore the implementation of FixMatch, a method for performing semi-supervised learning.

We will train on CIFAR-10, a dataset of natural images, while keeping the labels for only a small fraction (10%) of its training set.

You can find the original paper here: FixMatch: Simplifying Semi-Supervised Learning with Consistency and Confidence (https://arxiv.org/abs/2001.07685)



In [None]:
import sys

import numpy as np
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import transforms, datasets
from torchvision.models import resnet18
from torch.utils.data import Dataset, DataLoader
from torch.utils.data import RandomSampler, SequentialSampler

import matplotlib.pyplot as plt
from PIL import Image

In [None]:
# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

def train_epoch(model, dataloader, device, optimizer, criterion, epoch):
    """Run one supervised training epoch and return the mean per-sample loss.

    Args:
        model: network to train; switched to train mode here.
        dataloader: iterable of ``(images, labels)`` batches that also supports ``len()``.
        device: device each batch is moved to before the forward pass.
        optimizer: optimizer stepped once per batch.
        criterion: loss called as ``criterion(logits, labels)``; assumed to return a
            batch mean (e.g. default ``nn.CrossEntropyLoss``), since it is
            re-weighted by batch size below.
        epoch: epoch index, only used for the progress-bar label.

    Returns:
        Mean training loss over all samples seen, rounded to 2 decimals
        (0.0 if ``dataloader`` yields no batches).
    """
    model.train()

    total_train_loss = 0.0
    dataset_size = 0
    # Defined up front so an empty loader returns 0.0 instead of raising UnboundLocalError.
    epoch_loss = 0.0

    bar = tqdm(dataloader, total=len(dataloader), colour='cyan', file=sys.stdout)

    for images, labels in bar:
        images = images.to(device)
        labels = labels.to(device)

        batch_size = images.shape[0]

        optimizer.zero_grad()
        pred = model(images)
        loss = criterion(pred, labels)

        loss.backward()
        optimizer.step()

        # Weight by batch size so a smaller final batch doesn't skew the epoch mean.
        total_train_loss += loss.item() * batch_size
        dataset_size += batch_size

        epoch_loss = np.round(total_train_loss / dataset_size, 2)
        bar.set_postfix(Epoch=epoch, Train_Loss=epoch_loss)

    return epoch_loss

@torch.no_grad()
def valid_epoch(model, dataloader, device, criterion, epoch):
    """Evaluate ``model`` on ``dataloader`` and return ``(accuracy, loss)``.

    Gradients are disabled inside this function, so it is safe to call with or
    without an enclosing ``torch.no_grad()`` block.

    Args:
        model: classifier whose output dim 1 indexes classes; switched to eval mode here.
        dataloader: iterable of ``(images, labels)`` batches that also supports ``len()``.
        device: device each batch is moved to.
        criterion: loss called as ``criterion(logits, labels)``; assumed to return a
            batch mean, since it is re-weighted by batch size below.
        epoch: epoch index, only used for the progress-bar label.

    Returns:
        Tuple ``(accuracy, epoch_loss)``: top-1 accuracy in percent and mean
        per-sample loss, both rounded to 2 decimals (``(0.0, 0.0)`` for an
        empty loader).
    """
    model.eval()

    total_val_loss = 0.0
    dataset_size = 0
    correct = 0
    # Defaults for an empty loader (the loop would otherwise leave these unbound).
    accuracy = 0.0
    epoch_loss = 0.0

    bar = tqdm(dataloader, total=len(dataloader), colour='cyan', file=sys.stdout)
    for images, labels in bar:
        images = images.to(device)
        labels = labels.to(device)

        batch_size = images.shape[0]

        pred = model(images)
        loss = criterion(pred, labels)

        # Top-1 prediction: index of the largest output along the class dimension.
        _, predicted = torch.max(pred, 1)
        correct += (predicted == labels).sum().item()

        total_val_loss += loss.item() * batch_size
        dataset_size += batch_size

        # Running metrics over the samples seen so far, shown live in the progress bar.
        epoch_loss = np.round(total_val_loss / dataset_size, 2)
        accuracy = np.round(100 * correct / dataset_size, 2)

        bar.set_postfix(Epoch=epoch, Valid_Acc=accuracy, Valid_Loss=epoch_loss)

    return accuracy, epoch_loss

def run_training(model, trainloader, testloader, criterion, optimizer, num_epochs):
    """Train ``model`` with plain supervision, validating after every epoch.

    Uses the module-level ``device``. Prints the training loss each epoch and a
    message whenever validation accuracy improves.

    Args:
        model: network to train.
        trainloader: labeled training batches ``(images, labels)``.
        testloader: validation batches ``(images, labels)``.
        criterion: loss used for both training and validation.
        optimizer: optimizer over ``model``'s parameters.
        num_epochs: number of passes over ``trainloader``.

    Returns:
        Best validation accuracy (percent) observed across all epochs.
    """
    if torch.cuda.is_available():
        print("[INFO] Using GPU: {}\n".format(torch.cuda.get_device_name()))

    top_accuracy = 0.0

    for epoch in range(num_epochs):
        train_loss = train_epoch(model, trainloader, device, optimizer, criterion, epoch)
        print(f"Epoch {epoch} Train Loss: {train_loss:.2f}")

        with torch.no_grad():
            val_accuracy, val_loss = valid_epoch(model, testloader, device, criterion, epoch)

        if val_accuracy > top_accuracy:
            print(f"Validation Accuracy Improved ({top_accuracy} ---> {val_accuracy})")
            top_accuracy = val_accuracy
        print()

    return top_accuracy

In [None]:
# Don't touch this for now.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

cifar_trainset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
cifar_testset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

# Let's just use 10% of the labels to make it harder.
# Group sample indices by class so the labeled split is class-balanced.
# Labels are read from `.targets` directly: indexing `cifar_trainset[i]` would decode
# and transform every one of the training images just to obtain its label.
from collections import defaultdict
indices_per_class = defaultdict(list)
for i, class_label in enumerate(cifar_trainset.targets):
    indices_per_class[class_label].append(i)

labeled_indices = []
unlabeled_indices = []
for indices in indices_per_class.values():
    n_labeled = int(0.1 * len(indices))
    labeled_indices.extend(indices[:n_labeled])    # first 10% of each class: labeled
    unlabeled_indices.extend(indices[n_labeled:])  # remaining 90%: treated as unlabeled

cifar_labeled_trainset = torch.utils.data.Subset(dataset = cifar_trainset, indices = labeled_indices)

cifar_labeled_trainloader = DataLoader(cifar_labeled_trainset, batch_size=64, shuffle=True)
cifar_testloader = DataLoader(cifar_testset, batch_size=64, shuffle=False)


# Baseline: Training just with the supervised data.

In [None]:
learning_rate = 0.001
epochs = 10

# Randomly initialized ResNet18: we train it from scratch, with no pretrained weights.
# `weights=None` is the current torchvision spelling of the deprecated `pretrained=False`.
model = torchvision.models.resnet18(weights=None)
# Replace the default classification head with one output per CIFAR-10 class.
model.fc = nn.Linear(in_features=model.fc.in_features, out_features=10, bias=True)

model.to(device)
criterion = nn.CrossEntropyLoss()

# Adam: gradient descent with per-parameter adaptive learning rates.
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

In [None]:
# Supervised-only baseline: train on the 10% labeled subset, validate on the test set.
run_training(
    model,
    cifar_labeled_trainloader,
    cifar_testloader,
    criterion=criterion,
    optimizer=optimizer,
    num_epochs=epochs,
)

## Results suck. Let's make them better.

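First, let's define a set of weak and strong augmentations. FixMatch uses the model's confident predictions on a weakly augmented image as pseudo-labels, and trains the model to predict those same pseudo-labels on a strongly augmented version of that image.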

In [None]:
# Weak augmentation: random horizontal flip, then a random 32x32 crop
# taken after padding each side by int(32 * 0.125) == 4 pixels.
weak_transforms = transforms.Compose(
    [
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomCrop(32, padding=4),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)

# Strong augmentation: the weak pipeline plus color jitter and occasional
# conversion to grayscale.
strong_transforms = transforms.Compose(
    [
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomCrop(32, padding=4),
        transforms.ColorJitter(
            brightness=0.2,
            contrast=0.2,
            saturation=0.2,
            hue=0.02,
        ),
        transforms.RandomGrayscale(p=0.2),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)

# Evaluation: no augmentation, only tensor conversion and normalization.
val_transforms = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)

class TransformFixMatch:
    """Produce a (weak, strong) pair of augmented views of the same input.

    Calling an instance on ``x`` returns ``(weak(x), strong(x))``; the weak
    view is computed first.
    """

    def __init__(self, weak, strong):
        # Callables applied to every input, e.g. torchvision Compose pipelines.
        self.weak = weak
        self.strong = strong

    def __call__(self, x):
        weak_view = self.weak(x)
        strong_view = self.strong(x)
        return (weak_view, strong_view)

class CIFAR10SSL(datasets.CIFAR10):
    """CIFAR-10 restricted to a chosen subset of sample indices.

    Args:
        root, train, transform, target_transform, download: forwarded to
            ``datasets.CIFAR10``.
        indexs: indices of the samples to keep, or ``None`` to keep them all.
            When given, ``targets`` is converted to a numpy array before slicing.
    """

    def __init__(self, root, indexs, train=True, transform=None, target_transform=None, download=False):
        super().__init__(
            root,
            train=train,
            transform=transform,
            target_transform=target_transform,
            download=download,
        )
        if indexs is None:
            return
        self.data = self.data[indexs]
        self.targets = np.array(self.targets)[indexs]

    def __getitem__(self, index):
        # Stored images are arrays; wrap as PIL so torchvision transforms apply.
        image = Image.fromarray(self.data[index])
        label = self.targets[index]

        if self.transform is not None:
            image = self.transform(image)
        if self.target_transform is not None:
            label = self.target_transform(label)

        return image, label


In [None]:
# Rebuild the datasets with the new transforms:
#   labeled   -> weak augmentation only
#   unlabeled -> (weak, strong) view pairs via TransformFixMatch
#   test      -> no augmentation
cifar_labeled_trainset = CIFAR10SSL(
    root='./data',
    indexs=labeled_indices,
    train=True,
    transform=weak_transforms,
)
cifar_unlabeled_trainset = CIFAR10SSL(
    root='./data',
    indexs=unlabeled_indices,
    train=True,
    transform=TransformFixMatch(weak=weak_transforms, strong=strong_transforms),
)
cifar_testset = datasets.CIFAR10(
    root='./data',
    train=False,
    transform=val_transforms,
    download=False,
)

# Each step pairs 64 labeled images with 7 * 64 unlabeled ones,
# i.e. 1/8 labeled and 7/8 unlabeled data per training batch.
cifar_labeled_trainloader = DataLoader(
    cifar_labeled_trainset,
    batch_size=64,
    sampler=RandomSampler(cifar_labeled_trainset),
)
cifar_unlabeled_trainloader = DataLoader(
    cifar_unlabeled_trainset,
    batch_size=7 * 64,
    sampler=RandomSampler(cifar_unlabeled_trainset),
)

# The test set is fully labeled and evaluated in a fixed order.
cifar_testloader = DataLoader(cifar_testset, batch_size=64, shuffle=False)

In [None]:
def train_fixmatch_epoch(model, labeled_dataloader, unlabeled_dataloader, device, optimizer, epoch,
                         threshold=0.90, lambda_u=1.0):
    """Run one FixMatch training epoch and return the mean total loss per step.

    One epoch is one pass over ``unlabeled_dataloader``. The labeled loader is
    restarted whenever it runs out, so every step consumes one labeled batch
    and one unlabeled batch.

    Args:
        model: classifier returning per-class logits; switched to train mode here.
        labeled_dataloader: yields ``(images, labels)`` batches.
        unlabeled_dataloader: yields ``((weak_images, strong_images), _)`` batches
            (the labels are ignored); must support ``len()``.
        device: device each batch is moved to.
        optimizer: optimizer stepped once per step.
        epoch: epoch index, only used for progress-bar display.
        threshold: minimum softmax confidence on the weak view for a pseudo-label
            to contribute to the unlabeled loss.
        lambda_u: weight of the unlabeled (consistency) loss term.

    Returns:
        Mean of ``loss_labeled + lambda_u * loss_consistency`` over all steps
        (0.0 if the unlabeled loader is empty).
    """
    criterion_labeled = nn.CrossEntropyLoss()
    criterion_unlabeled = nn.CrossEntropyLoss(reduction='none')  # per-example, so it can be masked

    model.train()
    epoch_loss = 0.0
    num_steps = len(unlabeled_dataloader)

    bar = tqdm(unlabeled_dataloader, total=num_steps, colour='cyan', file=sys.stdout)
    bar.set_description(f"Epoch {epoch}")

    labeled_iterator = iter(labeled_dataloader)

    for (unlabeled_images_weak, unlabeled_images_strong), _ in bar:
        unlabeled_images_weak = unlabeled_images_weak.to(device)
        unlabeled_images_strong = unlabeled_images_strong.to(device)

        # The labeled set is much smaller than the unlabeled one: cycle through it.
        try:
            labeled_images, labels = next(labeled_iterator)
        except StopIteration:
            labeled_iterator = iter(labeled_dataloader)
            labeled_images, labels = next(labeled_iterator)

        labeled_images = labeled_images.to(device)
        labels = labels.to(device).long()

        optimizer.zero_grad()

        # Supervised loss: cross-entropy between true labels and predictions.
        pred_labeled = model(labeled_images)
        loss_labeled = criterion_labeled(pred_labeled, labels)

        # Pseudo-labels from the weakly augmented view; no gradient flows through them.
        with torch.no_grad():
            pred_weak = model(unlabeled_images_weak)
            probs_weak = torch.softmax(pred_weak, dim=-1)
            max_values, pseudo_labels = torch.max(probs_weak, dim=-1)

            # Keep only confident pseudo-labels (FixMatch's indicator: max prob >= threshold).
            fixmatch_mask = (max_values >= threshold).float()

        # TODO other things to try out
        # add mixup between labeled data and unlabeled data
        # (interpolate labeled images with strongly augmented unlabeled images and between corresponding true labels and pseudo-labels)
        # mixup: BEYOND EMPIRICAL RISK MINIMIZATION https://arxiv.org/pdf/1710.09412

        # TODO (more complicated)
        # some pseudo-labels might be wrong and stay wrong throughout training
        # maybe figure out which ones are wrong by looking at training dynamics
        # Identifying Mislabeled Data using the Area Under the Margin Ranking https://arxiv.org/pdf/2001.10528
        # MarginMatch: Improving Semi-Supervised Learning with Pseudo-Margins https://arxiv.org/pdf/2308.09037

        pred_strong = model(unlabeled_images_strong)

        # Consistency loss: strong-view predictions should match the confident pseudo-labels.
        # Averaged over the whole unlabeled batch; masked-out examples contribute zero.
        loss_consistency = (criterion_unlabeled(pred_strong, pseudo_labels) * fixmatch_mask).mean()

        loss = loss_labeled + lambda_u * loss_consistency

        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

        bar.set_postfix(
            Epoch=epoch,
            LabeledLoss=loss_labeled.item(),
            ConsistencyLoss=loss_consistency.item(),
            # Share of unlabeled samples filtered out by the confidence threshold.
            FractionMasked=(1 - fixmatch_mask.mean()).item(),
        )

    return epoch_loss / num_steps if num_steps else 0.0

def run_training_fixmatch(model, labeled_trainloader, unlabeled_trainloader, testloader, optimizer, num_epochs):
    """Train ``model`` with FixMatch, validating on ``testloader`` after every epoch.

    Uses the module-level ``device``. Validation loss is plain cross-entropy.

    Args:
        model: network to train.
        labeled_trainloader: labeled batches ``(images, labels)``.
        unlabeled_trainloader: unlabeled batches ``((weak, strong), _)``.
        testloader: validation batches ``(images, labels)``.
        optimizer: optimizer over ``model``'s parameters.
        num_epochs: number of passes over ``unlabeled_trainloader``.

    Returns:
        Best validation accuracy (percent) observed across all epochs.
    """
    print(f"[INFO] Using device: {device}")

    top_accuracy = 0.0
    criterion = nn.CrossEntropyLoss()  # only used to report the validation loss

    for epoch in range(num_epochs):
        train_loss = train_fixmatch_epoch(model, labeled_trainloader, unlabeled_trainloader, device, optimizer, epoch)
        print(f"Epoch {epoch} Loss: {train_loss:.2f}")

        # Validate the model after each epoch
        with torch.no_grad():
            val_accuracy, val_loss = valid_epoch(model, testloader, device, criterion, epoch)
        print(f"Validation Accuracy: {val_accuracy:.2f} | Validation Loss: {val_loss:.2f}")

        # Track the best validation accuracy (no checkpoint is written to disk).
        if val_accuracy > top_accuracy:
            print(f"Validation Accuracy Improved ({top_accuracy:.2f} ---> {val_accuracy:.2f})")
            top_accuracy = val_accuracy

        print()

    return top_accuracy

In [None]:
learning_rate = 0.001
epochs = 40

# Fresh, randomly initialized ResNet18 (no pretrained weights).
# `weights=None` is the current torchvision spelling of the deprecated `pretrained=False`.
model = torchvision.models.resnet18(weights=None)
# Replace the default classification head with one output per CIFAR-10 class.
model.fc = nn.Linear(in_features=model.fc.in_features, out_features=10, bias=True)
model.to(device)

# AdamW: Adam with decoupled weight decay (the default weight_decay is used here).
optimizer = optim.AdamW(model.parameters(), lr=learning_rate)

In [None]:
# FixMatch training: each step mixes a labeled batch with a pseudo-labeled unlabeled batch.
run_training_fixmatch(
    model=model,
    labeled_trainloader=cifar_labeled_trainloader,
    unlabeled_trainloader=cifar_unlabeled_trainloader,
    testloader=cifar_testloader,
    optimizer=optimizer,
    num_epochs=epochs,
)