# Transform example

In this example we will use a consistency regularisation method with some transformations to achieve semi supervised learning. We are using MNIST with a simple CNN for classification, trained using the FixMatch algorithm

In [None]:
import sslpack
import torch
import torch.nn.functional as F
from torch import nn
import numpy as np
from tqdm import tqdm

Importing the data from torchvision and normalising the pixel values to [0, 1]

In [None]:
from torchvision.datasets import MNIST
from torchvision import transforms

mnist_tr = MNIST(root="~/.sslpack/datasets", train=True, download=True)
mnist_ts = MNIST(root="~/.sslpack/datasets", train=False, download=True)

X_tr, y_tr = mnist_tr.data.float()/255, mnist_tr.targets
X_ts, y_ts = mnist_ts.data.float()/255, mnist_ts.targets


Let's define some transformations. FixMatch expects a weak and strong transformation to be defined. Since the data is currently a torch tensor, we convert it to PIL image then apply any transforms, before converting it back to a torch tensor. The weak transform is just a random horizontal flip, the strong transform is a flip along with a random augment

In [16]:
weak_transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor()
])
strong_transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.RandomHorizontalFlip(),
    transforms.RandAugment(),
    transforms.ToTensor()
])

We are using a CNN so we will add a channels dimension (only one channel in this case for graycsale images). Then we can split the data into the labelled and unlabelled parts, and use TransformDataset from wslearn to obtain samples. TransformDataset can handle the weak and strong transformations for us

In [None]:
from sslpack.utils.data import TransformDataset
from sslpack.utils.data import stratify_lbl_ulbl

X_tr = X_tr.unsqueeze(1)
X_ts = X_ts.unsqueeze(1)
num_labels_per_class = 4


X_tr_lb, y_tr_lb, X_tr_ulb, y_tr_ulb = stratify_lbl_ulbl(X_tr, y_tr, num_labels_per_class)


lbl_dataset = TransformDataset(X_tr_lb, y_tr_lb, weak_transform=weak_transform, strong_transform=strong_transform)
ulbl_dataset = TransformDataset(X_tr_ulb, y_tr_ulb, weak_transform=weak_transform, strong_transform=strong_transform)

In [None]:
from sslpack.algorithms import FixMatch

algorithm = FixMatch(lambda_u=1/3)

In [None]:
from sslpack.utils.data import CyclicLoader

lbl_batch_size = 20
ulbl_batch_size = 60
train_loader = CyclicLoader(lbl_dataset, ulbl_dataset, lbl_batch_size=lbl_batch_size, ulbl_batch_size=ulbl_batch_size)

In [20]:
def dict_to_device(d, device):
    return {k: v.to(device) if torch.is_tensor(v) else v for k, v in d.items()}

def train(model, train_loader, algorithm,  optimizer, num_iters=128,
          num_log_iters = 8, device="cpu"):


    model.to(device)
    model.train()

    training_bar = tqdm(train_loader, total=num_iters, desc="Training",
                        leave=True)

    for i, (lbl_batch, ulbl_batch) in enumerate(training_bar):

        lbl_batch = dict_to_device(lbl_batch, device)
        ulbl_batch = dict_to_device(ulbl_batch, device)

        optimizer.zero_grad()

        loss = algorithm.forward(model, lbl_batch, ulbl_batch)

        loss.backward()

        optimizer.step()

        if i % num_log_iters == 0:
            training_bar.set_postfix(loss = round(loss.item(), 4))

        if i > num_iters:
            break

In [None]:
import torch.nn as nn
import torch.nn.functional as F

class CNN(nn.Module):
    def __init__(self, num_classes=10):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(32 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

In [22]:
model = CNN()
lr = 0.03
optimizer = torch.optim.SGD(params=model.parameters(), lr=lr, momentum=0.9, weight_decay=0.0005)

device = "cuda" if torch.cuda.is_available() else "cpu"

In [23]:
train(model=model, train_loader=train_loader, algorithm=algorithm,
      optimizer=optimizer, device=device, num_iters=1000)

Training: 1001it [00:45, 22.23it/s, loss=0.0957]                         


In [None]:
from sslpack.utils.data import BasicDataset
from torch.utils.data import DataLoader

X_ts, y_ts = X_ts.float(), y_ts.float()

test_dataset = BasicDataset(X_ts, y_ts)
test_loader = DataLoader(test_dataset, 32)

In [25]:

from sklearn.metrics import (
    accuracy_score, confusion_matrix
)

def evaluate(model, eval_loader, device="cpu"):
    model.to(device)
    model.eval()
    total_num = 0.0
    y_true = []
    y_pred = []

    with torch.no_grad():
        for batch in eval_loader:
            X = batch["X"].to(device)
            y = batch["y"].to(device)
            num_batch = y.shape[0]
            total_num += num_batch
            logits = model(X)
            y_true.extend(y.cpu().tolist())
            y_pred.extend(torch.max(logits, dim=-1)[1].cpu().tolist())

        y_true = np.array(y_true)
        y_pred = np.array(y_pred)
        acc = accuracy_score(y_true, y_pred)
        print("accuracy: ", acc)
        cf_mat = confusion_matrix(y_true, y_pred, normalize='true')

        with np.printoptions(suppress=True, precision=3):
            print('confusion matrix:\n' + np.array_str(cf_mat))
        model.train()

In [26]:
evaluate(model, test_loader, device=device)

accuracy:  0.7618
confusion matrix:
[[0.839 0.019 0.    0.034 0.011 0.03  0.05  0.008 0.004 0.005]
 [0.    0.99  0.004 0.003 0.    0.    0.004 0.    0.    0.   ]
 [0.008 0.003 0.774 0.084 0.012 0.063 0.045 0.009 0.003 0.   ]
 [0.001 0.009 0.011 0.775 0.005 0.158 0.01  0.006 0.013 0.012]
 [0.    0.005 0.002 0.    0.983 0.    0.004 0.005 0.    0.001]
 [0.    0.015 0.017 0.226 0.002 0.685 0.022 0.027 0.002 0.003]
 [0.029 0.038 0.017 0.015 0.023 0.01  0.867 0.    0.001 0.   ]
 [0.001 0.01  0.006 0.003 0.053 0.02  0.    0.878 0.    0.029]
 [0.001 0.082 0.003 0.094 0.066 0.07  0.041 0.054 0.586 0.002]
 [0.01  0.014 0.    0.028 0.603 0.002 0.    0.136 0.001 0.207]]
