# Transform example

In this example we will use a consistency regularisation method with some transformations to achieve semi-supervised learning. We are using MNIST with a simple CNN for classification, trained using the FixMatch algorithm.

In [1]:
import wslearn
import torch
import torch.nn.functional as F
from torch import nn
import numpy as np
from tqdm import tqdm

Import the data from torchvision and normalise the pixel values to [0, 1].

In [2]:
from torchvision.datasets import MNIST
from torchvision import transforms

# Fetch both MNIST splits, downloading them into the local cache on first use.
mnist_tr = MNIST(root="~/.wslearn/datasets", train=True, download=True)
mnist_ts = MNIST(root="~/.wslearn/datasets", train=False, download=True)

# Cast the pixel data to float and divide by 255 to rescale it to [0, 1];
# the targets are used as they are.
X_tr = mnist_tr.data.float().div(255)
y_tr = mnist_tr.targets
X_ts = mnist_ts.data.float().div(255)
y_ts = mnist_ts.targets


Let's define some transformations. FixMatch expects both a weak and a strong transformation. Since the data is currently a torch tensor, each pipeline converts it to a PIL image, applies the augmentations, and then converts it back to a torch tensor. The weak transform is just a random horizontal flip; the strong transform is the same flip followed by RandAugment.

In [None]:
# Both pipelines follow the same shape: tensor -> PIL image -> augmentation(s)
# -> tensor.

# Weak augmentation: a random horizontal flip only.
weak_transform = transforms.Compose(
    [
        transforms.ToPILImage(),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ]
)

# Strong augmentation: the same flip followed by RandAugment.
strong_transform = transforms.Compose(
    [
        transforms.ToPILImage(),
        transforms.RandomHorizontalFlip(),
        transforms.RandAugment(),
        transforms.ToTensor(),
    ]
)

In [4]:
def split_lb_ulb_balanced(X, y, num_lbl, num_ulbl = None, return_idx = False,
                          seed=None):
    """
    Split features and labels into class-balanced labelled and unlabelled
    sets.

    For every distinct value in ``y`` the indices of that class are shuffled;
    the first ``num_lbl`` of them become labelled and the following
    ``num_ulbl`` (or all remaining ones) become unlabelled. If a class has
    fewer samples than requested, as many as are available are taken.

    Args:
        X: the features, indexable along the first dimension by an index
            tensor
        y: a tensor holding the labels; only first-dimension indices are
            used, so it is expected to be 1-D
        num_lbl: The number of samples per class to be labelled
        num_ulbl: The number of samples per class to be unlabelled.
            If left unspecified, all remaining unlabelled data is taken
        return_idx: If True, return the selected indices instead of the
            indexed features and labels
        seed: Optional seed that makes the split reproducible. It seeds a
            private generator, so torch's global RNG state is left untouched

    Returns
        If return_idx is True:
            Returns a tuple of two long tensors containing the labelled and
            unlabelled indices
        Else:
            Returns a 4-tuple containing the labelled features and labels,
            and unlabelled features and labels
    """

    # A dedicated generator keeps the split reproducible without reseeding
    # the global RNG (which would also affect e.g. later weight
    # initialisation). With seed=None the global RNG is used as before.
    generator = None
    if seed is not None:
        generator = torch.Generator().manual_seed(seed)

    lbls = []
    ulbls = []

    for c in torch.unique(y):
        idx = torch.where(y == c)[0]
        shuffled = idx[torch.randperm(len(idx), generator=generator)]

        lbls.extend(shuffled[:num_lbl].tolist())

        # A stop of None means "everything that is left" for this class.
        stop = None if num_ulbl is None else num_lbl + num_ulbl
        ulbls.extend(shuffled[num_lbl:stop].tolist())

    lbls = torch.tensor(lbls, dtype=torch.long)
    ulbls = torch.tensor(ulbls, dtype=torch.long)

    if return_idx:
        return lbls, ulbls

    return X[lbls], y[lbls], X[ulbls], y[ulbls]

We are using a CNN, so we will add a channel dimension (only one channel in this case, for grayscale images). Then we can split the data into the labelled and unlabelled parts, and use TransformDataset from wslearn to obtain samples.

In [5]:
from wslearn.utils.data import TransformDataset
from wslearn.utils.data import split_lb_ulb_balanced

# Add the channel dimension expected by the CNN: (N, H, W) -> (N, 1, H, W).
# The guard keeps this cell idempotent, so re-running it does not stack
# additional singleton dimensions onto the already reshaped tensors.
if X_tr.dim() == 3:
    X_tr = X_tr.unsqueeze(1)
if X_ts.dim() == 3:
    X_ts = X_ts.unsqueeze(1)
num_labels_per_class = 4

# Class-balanced split: `num_labels_per_class` labelled samples per class;
# num_ulbl is left unspecified for the unlabelled part.
X_tr_lb, y_tr_lb, X_tr_ulb, y_tr_ulb = split_lb_ulb_balanced(X_tr, y_tr, num_labels_per_class)

# Both datasets are given the weak and the strong transform.
lbl_dataset = TransformDataset(X_tr_lb, y_tr_lb, weak_transform=weak_transform, strong_transform=strong_transform)
ulbl_dataset = TransformDataset(X_tr_ulb, y_tr_ulb, weak_transform=weak_transform, strong_transform=strong_transform)

In [6]:
from wslearn.algorithms import FixMatch

# Instantiate the FixMatch implementation imported from wslearn, leaving all
# of its constructor arguments at their defaults.
algorithm = FixMatch()

In [7]:
import torch
from torch.nn.functional import cross_entropy

from wslearn.algorithms import Algorithm
from wslearn.utils.criterions import ce_consistency_loss

class FixMatch(Algorithm):
    """ An implementation of FixMatch (https://arxiv.org/pdf/2001.07685)

    By default the algorithm uses cross entropy loss for the supervised part,
    and cross entropy consistency loss for the unsupervised part.
    """

    def __init__(self, lambda_u=0.5,  conf_threshold=0.95,
                 max_pseudo_labels = None, sup_loss_func=None,
                 unsup_loss_func=None):
        """
        Initialise a fixmatch algorithm.

        Args:
            lambda_u: the weight of unlabelled loss in the total loss; the
                supervised loss is weighted by ``1 - lambda_u``
            conf_threshold: the confidence threshold for pseudo-labels
            max_pseudo_labels: if not None, at most this many pseudo-labels
                (the most confident ones of the unlabelled batch) are kept;
                they must still reach ``conf_threshold``
            sup_loss_func: a function with signature f(pred, true) to compute
                the loss on the supervised batch
            unsup_loss_func: a function with signature f(pred, true, mask) to
                compute the loss on the unsupervised batch
        """
        super().__init__()

        self.lambda_u = lambda_u
        self.conf_threshold = conf_threshold
        self.max_pseudo_labels = max_pseudo_labels

        if sup_loss_func is None:
            # Default reduction is 'mean'
            self.sup_loss_func = cross_entropy
        else:
            self.sup_loss_func = sup_loss_func

        if unsup_loss_func is None:
            self.unsup_loss_func = ce_consistency_loss
        else:
            self.unsup_loss_func = unsup_loss_func

    def forward(self, model, lbl_batch, ulbl_batch, log_func=None):
        """
        Performs a forward pass of FixMatch

        Args:
            model: The predictor model
            lbl_batch: A dictionary with labelled data using keys "weak"
                (weakly augmented inputs) and "y" (labels)
            ulbl_batch: A dictionary with unlabelled data using keys "weak"
                and "strong" (weakly / strongly augmented inputs)
            log_func: A function which accepts a dictionary containing some
                training information

        Returns:
            The total loss
            ``(1 - lambda_u) * sup_loss + lambda_u * unsup_loss``
        """

        x_lbl_weak = lbl_batch["weak"]
        x_ulbl_weak = ulbl_batch["weak"]
        x_ulbl_strong = ulbl_batch["strong"]

        # Pseudo-labels come from the weakly augmented unlabelled inputs and
        # are treated as constants: no gradient flows through this pass.
        with torch.no_grad():
            logits = model(x_ulbl_weak)
            probs = torch.softmax(logits, dim=1)
            confidences, pseudo_labels = torch.max(probs, dim=1)
            mask = confidences.ge(self.conf_threshold)
            if self.max_pseudo_labels is not None:
                # Restrict the mask to the top-k most confident predictions.
                _, indices = torch.topk(confidences,
                                        min(self.max_pseudo_labels,
                                            len(confidences)))
                keep = torch.zeros_like(mask, dtype=torch.bool)
                keep[indices] = True
                mask &= keep

        # One forward pass over the labelled (weak) and unlabelled (strong)
        # inputs; the output is split again by the labelled batch size.
        x = torch.concat([x_lbl_weak, x_ulbl_strong])
        out = model(x)
        out_lbl_weak = out[:x_lbl_weak.size(0)]
        out_ulbl_strong = out[x_lbl_weak.size(0):]

        sup_loss = self.sup_loss_func(out_lbl_weak, lbl_batch["y"])

        unsup_loss = self.unsup_loss_func(out_ulbl_strong, pseudo_labels, mask)

        total_loss = (1 - self.lambda_u) * sup_loss + self.lambda_u * unsup_loss

        if log_func is not None:
            log_func({
                "sup_loss": sup_loss,
                "unsup_loss": unsup_loss,
                "total_loss": total_loss,
                "mask": mask,
                "pseudo_labels": pseudo_labels
            })

        return total_loss


# Must live at module level: inside the class body this line would only
# create a class attribute (built from whatever `FixMatch` was bound to
# before the class statement finished) and would leave the global
# `algorithm` untouched.
algorithm = FixMatch(lambda_u=1/3)

In [8]:
from wslearn.utils.data import CyclicLoader

# Batch sizes of the labelled and the unlabelled stream.
lbl_batch_size = 20
ulbl_batch_size = 60

train_loader = CyclicLoader(
    lbl_dataset,
    ulbl_dataset,
    lbl_batch_size=lbl_batch_size,
    ulbl_batch_size=ulbl_batch_size,
)

In [9]:
def dict_to_device(d, device):
    """
    Return a new dictionary in which every tensor value of ``d`` has been
    moved to ``device``; values that are not tensors are passed through
    unchanged. The input dictionary itself is not modified.
    """
    moved = {}
    for key, value in d.items():
        if torch.is_tensor(value):
            moved[key] = value.to(device)
        else:
            moved[key] = value
    return moved

def train(model, train_loader, algorithm,  optimizer, num_iters=128,
          num_log_iters = 8, device="cpu"):
    """
    Run a fixed number of optimisation steps of a semi-supervised algorithm.

    Args:
        model: the network to optimise; it is moved to ``device`` and put
            into training mode
        train_loader: an iterable yielding ``(lbl_batch, ulbl_batch)`` pairs
            of dictionaries
        algorithm: an object whose ``forward(model, lbl_batch, ulbl_batch)``
            returns the scalar loss to minimise
        optimizer: the optimiser that updates the parameters of ``model``
        num_iters: the number of optimisation steps to perform (fewer if
            ``train_loader`` is exhausted earlier)
        num_log_iters: the loss shown in the progress bar is refreshed every
            ``num_log_iters`` steps; a falsy value disables the refresh
        device: the device on which the model and the batches are placed
    """

    model.to(device)
    model.train()

    training_bar = tqdm(train_loader, total=num_iters, desc="Training",
                        leave=True)

    # zip() with the range first stops after exactly `num_iters` steps and
    # never pulls a surplus batch from the loader. (Checking `i > num_iters`
    # after the step, as before, ran num_iters + 2 steps.)
    for i, (lbl_batch, ulbl_batch) in zip(range(num_iters), training_bar):

        lbl_batch = dict_to_device(lbl_batch, device)
        ulbl_batch = dict_to_device(ulbl_batch, device)

        optimizer.zero_grad()

        loss = algorithm.forward(model, lbl_batch, ulbl_batch)

        loss.backward()

        optimizer.step()

        # Guard against num_log_iters == 0, which would divide by zero.
        if num_log_iters and i % num_log_iters == 0:
            training_bar.set_postfix(loss = round(loss.item(), 4))

    # The loader is usually not exhausted, so finalise the bar explicitly.
    training_bar.close()

In [10]:
import torch.nn as nn
import torch.nn.functional as F

class CNN(nn.Module):
    """
    A small convolutional classifier for single-channel 28x28 images.

    Two 3x3 convolutions (16 and 32 channels), each followed by ReLU and 2x2
    max pooling, feed a 128-unit hidden layer and a final linear layer that
    produces ``num_classes`` logits.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        # A single 2x2 max-pooling layer is reused after both convolutions.
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        # Each pooling halves the spatial size: 28 -> 14 -> 7.
        self.fc1 = nn.Linear(32 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x):
        """Map a batch of images ``[batch, 1, 28, 28]`` to class logits."""
        features = self.pool(F.relu(self.conv1(x)))         # [batch, 16, 14, 14]
        features = self.pool(F.relu(self.conv2(features)))  # [batch, 32, 7, 7]
        flat = features.view(features.size(0), -1)          # [batch, 32 * 7 * 7]
        hidden = F.relu(self.fc1(flat))
        return self.fc2(hidden)

In [11]:
# Use the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Fresh network plus an Adam optimiser over all of its parameters.
lr = 0.01
model = CNN()
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

In [12]:
# Train the CNN with the configured algorithm on the selected device.
train(
    model=model,
    train_loader=train_loader,
    algorithm=algorithm,
    optimizer=optimizer,
    num_iters=1000,
    device=device,
)

Training:   0%|          | 0/1000 [00:00<?, ?it/s]


TypeError: pic should be PIL Image or ndarray. Got <class 'torch.Tensor'>

In [None]:
from wslearn.utils.data import BasicDataset

# Cast the test features and the test labels to float.
# NOTE(review): this also turns the (presumably integer) class labels into
# floats — confirm that BasicDataset / the evaluation code really needs that.
X_ts, y_ts = X_ts.float(), y_ts.float()

# Wrap the test tensors in a wslearn BasicDataset.
test_dataset = BasicDataset(X_ts, y_ts)

In [None]:
from torch.utils.data import DataLoader

# Iterate over the test set in batches of 32, in dataset order.
test_loader = DataLoader(test_dataset, batch_size=32)

In [None]:

from sklearn.metrics import (
    accuracy_score, confusion_matrix
)

def evaluate(model, eval_loader, device="cpu"):
    """
    Evaluate a classifier and print its accuracy and confusion matrix.

    Args:
        model: the classifier; it is moved to ``device`` and evaluated in
            eval mode, after which its previous train/eval mode is restored
        eval_loader: an iterable yielding dictionaries with the inputs under
            key "X" and the true labels under key "y"
        device: the device on which the model is run
    """
    model.to(device)
    # Remember the current mode so that evaluation has no lasting side
    # effect on the model.
    was_training = model.training
    model.eval()
    y_true = []
    y_pred = []

    try:
        with torch.no_grad():
            for batch in eval_loader:
                X = batch["X"].to(device)
                logits = model(X)
                # The labels are only collected, so they never need to be
                # moved to the compute device.
                y_true.extend(batch["y"].tolist())
                # The predicted class is the index of the largest logit.
                y_pred.extend(logits.argmax(dim=-1).cpu().tolist())
    finally:
        model.train(was_training)

    y_true = np.array(y_true)
    y_pred = np.array(y_pred)
    acc = accuracy_score(y_true, y_pred)
    print("accuracy: ", acc)
    cf_mat = confusion_matrix(y_true, y_pred)
    print('confusion matrix:\n' + np.array_str(cf_mat))

In [None]:
# Print the accuracy and the confusion matrix of the model on the test set.
evaluate(model, test_loader, device=device)

accuracy:  0.7617
confusion matrix:
[[ 931    2    3    2    0   18   10    5    9    0]
 [   1 1118    8    0    1    1    2    3    1    0]
 [  10    1  697  206    3   83   10   17    5    0]
 [   0    0   12  934    0   37    2   18    6    1]
 [   4    1    3    0  959    1    9    4    0    1]
 [   2    0    3  563    0  277    3   36    7    1]
 [  36    6   39   18    9   16  819    1   14    0]
 [   0    4    7   11    7   25    1  966    1    6]
 [  16   13    4  172    5   64    2   93  605    0]
 [  14    8    3   35   68    4    0  522   44  311]]
