## `wslearn` example 

In this example, we use `wslearn` to train a ResNet model on the CIFAR10 dataset. `wslearn` is designed to make weakly supervised learning workflows look similar to conventional supervised learning. A `wslearn` script looks very similar to typical Torch style code, with a model, dataset, dataloader, optimizer, and training loop. There are some differences however which we will discuss

In [1]:
import torch
import torch.nn.functional as F
from torch import nn, optim
import numpy as np
from tqdm import tqdm
import wslearn

`wslearn` provides ready-made datasets for use. Examples from `wslearn` datasets have transformations included for use with consistency-regularisation algorithms

In [2]:
from wslearn.datasets import Cifar10

data = Cifar10(lbls_per_class=8)

There are separate labelled and unlabelled datasets. When accessing examples, the output is a dictionary of the original data, it's label and the transformed features

In [3]:
data.get_lbl_dataset()[1].keys()

dict_keys(['y', 'weak', 'strong', 'X'])

The unlabelled observations do not have labels

In [4]:
data.get_ulbl_dataset()[1].keys()

dict_keys(['weak', 'strong', 'X'])

`wslearn` provides an implementation of FixMatch (https://arxiv.org/pdf/2001.07685). There are several parameters we can customise

In [5]:
from wslearn.algorithms import FixMatch

algorithm = FixMatch()

We will use a pretrained Vision transformer and fine tune it for our application. 

In [6]:
from wslearn.networks.vision_transfomers import ViT_Tiny_2

model = ViT_Tiny_2(32, 3, 10)
model.load_checkpoint("https://github.com/microsoft/Semi-supervised-learning/releases/download/v.0.0.0/vit_tiny_patch2_32_mlp_im_1k_32.pth")


# model = vit_tiny_patch2_32(pretrained=True, pretrained_path="https://github.com/microsoft/Semi-supervised-learning/releases/download/v.0.0.0/vit_tiny_patch2_32_mlp_im_1k_32.pth")


  from .autonotebook import tqdm as notebook_tqdm


`wslearn` provides specialised dataloaders for handling labelled and unlabelled batches. The CyclicLoader will reshuffle the labelled and unlabelled data separately once they have been consumed. This means the dataloader will never terminate. Output from the CyclicLoader is a tuple labelled_batch, unlabelled_batch

In [7]:
from wslearn.utils.data import CyclicLoader

lbl_batch_size = 8
ulbl_batch_size = 16
train_loader = CyclicLoader(data.get_lbl_dataset(), data.get_ulbl_dataset(),
                               lbl_batch_size=lbl_batch_size, ulbl_batch_size=ulbl_batch_size)

We can simply use Adam as provided by Torch

In [8]:
lr = 0.0005

optimizer = torch.optim.Adam(params=model.parameters(), lr=lr)

We now need to write a training function. In a wsl context, we prefer to use training iterations rather than epochs, as the idea of an epoch makes less sense with two datasets in parallel. This training loop is otherwise very conventional with the exception of the main training logic being handed over to `algorithm.forward()` 

In [9]:
def dict_to_device(d, device):
    return {k: v.to(device) if torch.is_tensor(v) else v for k, v in d.items()}

def train(model, train_loader, algorithm,  optimizer, num_iters=128,
          num_log_iters = 8, device="cpu"):


    model.to(device)
    model.train()

    training_bar = tqdm(train_loader, total=num_iters, desc="Training",
                        leave=True)
    total_loss = 0.0

    for i, (lbl_batch, ulbl_batch) in enumerate(training_bar):

        lbl_batch = dict_to_device(lbl_batch, device)
        ulbl_batch = dict_to_device(ulbl_batch, device)

        optimizer.zero_grad()

        loss = algorithm.forward(model, lbl_batch, ulbl_batch)

        loss.backward()

        optimizer.step()

        total_loss += loss.item()

        avg_loss = total_loss / (i+1)

        if i % num_log_iters == 0:
            training_bar.set_postfix(avg_loss = round(avg_loss, 4))

        if i > num_iters:
            break

In [10]:
device = "cuda" if torch.cuda.is_available() else "cpu"

In [11]:
train(model=model, train_loader=train_loader, algorithm=algorithm,
      optimizer=optimizer, device=device, num_iters=128)

Training: 129it [00:15,  8.49it/s, avg_loss=0.919]                         


Now the model has finished training, we can evaluate it's performance

In [12]:
from torch.utils.data import DataLoader
from sklearn.metrics import (
    accuracy_score, confusion_matrix
)

def evaluate(model, eval_loader, device="cpu"):
    model.to(device)
    model.eval()
    total_num = 0.0
    y_true = []
    y_pred = []

    with torch.no_grad():
        for batch in eval_loader:
            X = batch["X"].to(device)
            y = batch["y"].to(device)
            num_batch = y.shape[0]
            total_num += num_batch
            logits = model(X)
            y_true.extend(y.cpu().tolist())
            y_pred.extend(torch.max(logits, dim=-1)[1].cpu().tolist())

        y_true = np.array(y_true)
        y_pred = np.array(y_pred)
        acc = accuracy_score(y_true, y_pred)
        print("accuracy: ", acc)
        cf_mat = confusion_matrix(y_true, y_pred, normalize='true')
        print('confusion matrix:\n' + np.array_str(cf_mat))
        model.train()



In [13]:
eval_loader = DataLoader(data.get_eval_dataset(), batch_size=32)
evaluate(model, eval_loader, device)

accuracy:  0.6558
confusion matrix:
[[0.747 0.012 0.    0.001 0.006 0.004 0.002 0.074 0.144 0.01 ]
 [0.    0.916 0.    0.    0.001 0.    0.    0.005 0.026 0.052]
 [0.104 0.007 0.507 0.029 0.15  0.098 0.03  0.061 0.014 0.   ]
 [0.031 0.021 0.015 0.413 0.184 0.21  0.042 0.055 0.027 0.002]
 [0.052 0.002 0.005 0.069 0.624 0.038 0.027 0.17  0.012 0.001]
 [0.016 0.004 0.001 0.032 0.09  0.806 0.001 0.042 0.008 0.   ]
 [0.003 0.006 0.008 0.044 0.08  0.018 0.828 0.008 0.005 0.   ]
 [0.106 0.006 0.003 0.028 0.308 0.112 0.007 0.386 0.041 0.003]
 [0.101 0.069 0.    0.    0.003 0.001 0.001 0.032 0.748 0.045]
 [0.048 0.241 0.    0.    0.005 0.001 0.    0.034 0.088 0.583]]
