## `wslearn` basic example 

In this example, we use `wslearn` to train a ResNet model on the CIFAR10 dataset. `wslearn` is designed to make weakly supervised learning workflows look similar to conventional supervised learning: a `wslearn` script resembles typical PyTorch code, with a model, dataset, dataloader, optimizer and training loop. There are, however, some differences, which we discuss below.

In [1]:
import torch
import torch.nn.functional as F
from torch import nn, optim
import numpy as np
from tqdm import tqdm
import wslearn

`wslearn` provides ready-made datasets. Each example from a `wslearn` dataset already includes transformed (augmented) versions of its features, for use with consistency-regularisation algorithms.
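
To illustrate what "weak" and "strong" transformations usually mean in FixMatch-style consistency regularisation, here is a small NumPy sketch. It is not `wslearn`'s actual transform pipeline, which this notebook does not show: it assumes images are arrays of shape `(H, W, C)`, and the functions `weak_augment` and `strong_augment` (flip plus translation, and the same plus a random cutout) are hypothetical stand-ins.

```python
import numpy as np

def weak_augment(img, rng, max_shift=4):
    """Hypothetical weak view: random horizontal flip plus a random translation."""
    out = img[:, ::-1] if rng.random() < 0.5 else img
    # Reflect-pad, then crop back to the original size at a random offset
    padded = np.pad(out, ((max_shift, max_shift), (max_shift, max_shift), (0, 0)),
                    mode="reflect")
    dy, dx = rng.integers(0, 2 * max_shift + 1, size=2)
    h, w = img.shape[:2]
    return padded[dy:dy + h, dx:dx + w]

def strong_augment(img, rng, cutout=16):
    """Hypothetical strong view: the weak view plus a random square cutout."""
    out = weak_augment(img, rng).copy()
    h, w = out.shape[:2]
    cy, cx = rng.integers(0, h), rng.integers(0, w)
    y0, y1 = max(cy - cutout // 2, 0), min(cy + cutout // 2, h)
    x0, x1 = max(cx - cutout // 2, 0), min(cx + cutout // 2, w)
    out[y0:y1, x0:x1] = 0  # blank out a patch of the image
    return out

rng = np.random.default_rng(0)
image = rng.random((32, 32, 3))  # a fake CIFAR10-sized image
example = {"X": image,
           "weak": weak_augment(image, rng),
           "strong": strong_augment(image, rng)}
```

Both views keep the original image shape, so they can be batched and fed to the same model.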

In [2]:
from wslearn.datasets import Cifar10

data = Cifar10(num_lbl=8)


There are separate labelled and unlabelled datasets. Each labelled example is a dictionary containing the original image `X`, its label `y`, and three transformed versions of the image: `weak`, `medium` and `strong`.

In [6]:
data.get_lbl_dataset()[1].keys()

dict_keys(['X', 'y', 'weak', 'medium', 'strong'])

Unlabelled examples have the same structure, but without the `y` key.

In [7]:
data.get_ulbl_dataset()[1].keys()

dict_keys(['X', 'weak', 'medium', 'strong'])
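
The dictionary layout above can be mimicked with a small dataset class. `ToyWSLDataset` and its `augment` argument are hypothetical names used only for illustration, not part of `wslearn`; the point is that labelled items carry a `y` key while unlabelled items simply omit it.

```python
import numpy as np

class ToyWSLDataset:
    """Hypothetical dataset returning dict examples shaped like the wslearn outputs above."""

    def __init__(self, X, y=None, augment=None):
        self.X = X
        self.y = y  # None for the unlabelled split
        # Stand-in augmentation: a callable (image, strength) -> image
        self.augment = augment or (lambda img, strength: img)

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        x = self.X[idx]
        item = {"X": x}
        if self.y is not None:  # only labelled examples get a label
            item["y"] = self.y[idx]
        for strength in ("weak", "medium", "strong"):
            item[strength] = self.augment(x, strength)
        return item

X = np.zeros((4, 32, 32, 3), dtype=np.float32)
lbl = ToyWSLDataset(X[:2], y=np.array([3, 7]))
ulbl = ToyWSLDataset(X[2:])
print(lbl[1].keys())   # dict_keys(['X', 'y', 'weak', 'medium', 'strong'])
print(ulbl[1].keys())  # dict_keys(['X', 'weak', 'medium', 'strong'])
```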

`wslearn` provides an implementation of FixMatch (https://arxiv.org/pdf/2001.07685). Several of its parameters can be customised; here we use the defaults.
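
As a rough guide to what the algorithm computes, here is a NumPy sketch of the FixMatch objective from the paper: cross-entropy on the labelled batch, plus cross-entropy between the model's predictions on strongly augmented unlabelled images and pseudo-labels taken from the weakly augmented view, kept only when the pseudo-label confidence reaches a threshold. The values `threshold=0.95` and `lambda_u=1.0` are the paper's defaults; the names of the corresponding parameters in `wslearn`'s `FixMatch` are not shown here, and this sketch only evaluates the loss (no gradients).

```python
import numpy as np

def log_softmax(logits):
    z = logits - logits.max(axis=-1, keepdims=True)  # subtract max for numerical stability
    return z - np.log(np.exp(z).sum(axis=-1, keepdims=True))

def cross_entropy(logits, targets):
    """Per-example cross-entropy between logits and integer class targets."""
    return -log_softmax(logits)[np.arange(len(targets)), targets]

def fixmatch_loss(lbl_logits, lbl_y, ulbl_weak_logits, ulbl_strong_logits,
                  threshold=0.95, lambda_u=1.0):
    # Supervised term on the (weakly augmented) labelled batch
    sup = cross_entropy(lbl_logits, lbl_y).mean()
    # Pseudo-labels come from the weak view (treated as constants, i.e. no gradient)
    probs = np.exp(log_softmax(ulbl_weak_logits))
    confidence, pseudo = probs.max(axis=-1), probs.argmax(axis=-1)
    mask = (confidence >= threshold).astype(float)
    # Unsupervised term: the strong view must match confident pseudo-labels;
    # averaged over the whole unlabelled batch, as in the paper
    unsup = (cross_entropy(ulbl_strong_logits, pseudo) * mask).mean()
    return sup + lambda_u * unsup
```

With uninformative weak-view logits no pseudo-label passes the threshold, so the loss reduces to the supervised term alone.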

In [15]:
from wslearn.algorithms import FixMatch

algorithm = FixMatch()

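We use the torchvision implementation of ResNet-50, loaded through `torch.hub` (randomly initialised, since no pretrained weights are requested), and replace its final fully connected layer `model.fc` with a 10-class output layer to match CIFAR10.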
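
Replacing `model.fc` swaps the final affine map from ResNet-50's 2048-dimensional pooled features (the value of `model.fc.in_features`) to 10 logits. The NumPy sketch below shows what the new head computes; the random weights are stand-ins for the parameters of the `torch.nn.Linear` layer, not its actual initialisation scheme.

```python
import numpy as np

in_features, num_classes = 2048, 10  # ResNet-50 pooled feature size, CIFAR10 classes
rng = np.random.default_rng(0)

# Stand-ins for nn.Linear's weight (out x in) and bias (out,)
W = rng.normal(scale=0.01, size=(num_classes, in_features))
b = np.zeros(num_classes)

features = rng.normal(size=(4, in_features))  # a batch of 4 pooled feature vectors
logits = features @ W.T + b                   # same affine map as nn.Linear: x W^T + b

num_params = W.size + b.size
print(logits.shape, num_params)  # (4, 10) 20490
```

Only these 20,490 parameters are new; the rest of the network keeps its original structure.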

In [8]:
model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet50')
model.fc = torch.nn.Linear(model.fc.in_features, 10)


`wslearn` provides specialised dataloaders for handling labelled and unlabelled batches together. The `CyclicLoader` reshuffles the labelled and unlabelled datasets independently each time one of them is exhausted, so it never terminates on its own. Each iteration yields a tuple `(labelled_batch, unlabelled_batch)`.
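
The cycling behaviour can be sketched in plain Python with generators. `cyclic_loader` is a hypothetical function, not `wslearn`'s implementation (which is built on the library's own dataset objects and may, for example, collate tensors or drop incomplete batches); it only shows the idea of two independently reshuffled, never-ending streams zipped together.

```python
import random
from itertools import islice

def batches(indices, batch_size):
    """Split a list of indices into consecutive batches (the last may be smaller)."""
    return [indices[i:i + batch_size] for i in range(0, len(indices), batch_size)]

def cyclic_loader(lbl_data, ulbl_data, lbl_batch_size, ulbl_batch_size, seed=0):
    """Hypothetical infinite loader yielding (lbl_batch, ulbl_batch) pairs.

    Each dataset is reshuffled independently whenever its own batches run out,
    so the generator never terminates on its own.
    """
    rng = random.Random(seed)

    def endless(data, batch_size):
        while True:
            order = list(range(len(data)))
            rng.shuffle(order)  # new random order for every pass over this dataset
            for idx in batches(order, batch_size):
                yield [data[i] for i in idx]

    lbl_iter = endless(lbl_data, lbl_batch_size)
    ulbl_iter = endless(ulbl_data, ulbl_batch_size)
    while True:
        yield next(lbl_iter), next(ulbl_iter)

loader = cyclic_loader(list(range(10)), list(range(100, 130)),
                       lbl_batch_size=4, ulbl_batch_size=8)
for lbl_batch, ulbl_batch in islice(loader, 3):
    print(len(lbl_batch), len(ulbl_batch))  # 4 8, then 4 8, then 2 8
```

The small labelled set wraps around after three batches while the larger unlabelled set is still in its first pass, which is why the two streams must be cycled separately.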

In [10]:
from wslearn.utils.data import CyclicLoader

lbl_batch_size = 8
ulbl_batch_size = 16
train_loader = CyclicLoader(data.get_lbl_dataset(), data.get_ulbl_dataset(),
                            lbl_batch_size=lbl_batch_size,
                            ulbl_batch_size=ulbl_batch_size)

For the optimizer, we simply use Adam from `torch.optim`.
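
For reference, a single Adam update can be written out in NumPy. This follows the update rule documented for `torch.optim.Adam` with its default `betas=(0.9, 0.999)`, `eps=1e-8` and no weight decay; the function `adam_step` is an illustrative helper, not part of PyTorch.

```python
import numpy as np

def adam_step(param, grad, m, v, t, lr=0.0005, beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam update for step t (1-based), returning the new param and moment estimates."""
    m = beta1 * m + (1 - beta1) * grad          # first-moment (mean) estimate
    v = beta2 * v + (1 - beta2) * grad ** 2     # second-moment (uncentred variance) estimate
    m_hat = m / (1 - beta1 ** t)                # bias corrections for the zero initialisation
    v_hat = v / (1 - beta2 ** t)
    param = param - lr * m_hat / (np.sqrt(v_hat) + eps)
    return param, m, v

p = np.array([1.0, -2.0, 3.0])
g = np.array([0.1, -5.0, 100.0])
p_new, m, v = adam_step(p, g, np.zeros_like(p), np.zeros_like(p), t=1)
print(p - p_new)  # approximately [0.0005, -0.0005, 0.0005]
```

After bias correction the first step is roughly `lr * sign(grad)` regardless of the gradient's scale, which is one reason Adam is forgiving about the choice of learning rate.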

In [11]:
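lr = 0.0005

# Alternative: SGD with Nesterov momentum and weight decay
# (momentum, nesterov and weight_decay are only used by this commented-out line)
momentum = 0.9
nesterov = True
weight_decay = 0.0005
# optimizer = torch.optim.SGD(params=model.parameters(), lr=lr, momentum=momentum, nesterov=nesterov, weight_decay=weight_decay)

optimizer = torch.optim.Adam(params=model.parameters(), lr=lr)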

Next, we write a training function. In weakly supervised learning (WSL) we count training iterations rather than epochs, because an epoch is ill-defined when two datasets of different sizes are iterated in parallel. Otherwise the training loop is conventional, except that the main training logic (computing the loss from a labelled and an unlabelled batch) is handed over to `algorithm.forward()`.
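
A quick calculation shows why "one epoch" is ambiguous here. The dataset sizes below are assumptions for illustration only: 80 labelled images (what `num_lbl=8` would give if it meant 8 per class, which the notebook does not state) and the 50,000-image CIFAR10 training set as unlabelled data. `passes_over_dataset` is a hypothetical helper.

```python
def passes_over_dataset(num_iters, batch_size, dataset_size):
    """How many full passes over a dataset a run of `num_iters` iterations makes."""
    return num_iters * batch_size / dataset_size

num_iters = 2048
# Assumed sizes: 80 labelled images and 50,000 unlabelled images
print(passes_over_dataset(num_iters, batch_size=8, dataset_size=80))       # 204.8
print(passes_over_dataset(num_iters, batch_size=16, dataset_size=50_000))  # 0.65536
```

Under these assumptions the same 2048 iterations are about 205 epochs of the labelled data but less than one epoch of the unlabelled data, so the iteration count is the only unambiguous measure of training length.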

In [12]:
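def dict_to_device(d, device):
    """Move every tensor in a batch dictionary to `device`, leaving other values untouched."""
    return {k: v.to(device) if torch.is_tensor(v) else v for k, v in d.items()}

def train(model, train_loader, algorithm, optimizer, num_iters=128,
          num_log_iters=8, device="cpu"):
    model.to(device)
    model.train()

    # CyclicLoader never terminates, so `total` only sets the progress-bar length;
    # the loop is stopped explicitly after `num_iters` iterations below.
    training_bar = tqdm(train_loader, total=num_iters, desc="Training",
                        leave=True)
    total_loss = 0.0
    avg_loss = 0.0

    for i, (lbl_batch, ulbl_batch) in enumerate(training_bar):
        lbl_batch = dict_to_device(lbl_batch, device)
        ulbl_batch = dict_to_device(ulbl_batch, device)

        optimizer.zero_grad()

        # The algorithm computes the training loss from the labelled and unlabelled batches
        loss = algorithm.forward(model, lbl_batch, ulbl_batch)

        loss.backward()
        optimizer.step()

        # Running mean of the loss over all iterations so far
        total_loss += loss.item()
        avg_loss = total_loss / (i + 1)

        if i % num_log_iters == 0:
            training_bar.set_postfix(avg_loss=round(avg_loss, 4))

        # Stop after exactly `num_iters` iterations
        if i + 1 >= num_iters:
            break

    training_bar.close()
    return avg_loss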
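
Because `CyclicLoader` never terminates, the loop has to be bounded explicitly. An alternative to a manual `break` is `itertools.islice`, which stops after exactly `num_iters` items. The sketch below uses a hypothetical infinite generator, `endless_pairs`, in place of the real loader.

```python
from itertools import count, islice

def endless_pairs():
    """Hypothetical stand-in for CyclicLoader: yields (lbl_batch, ulbl_batch) forever."""
    for step in count():
        yield {"step": step}, {"step": step}

num_iters = 5
seen = []
for i, (lbl_batch, ulbl_batch) in enumerate(islice(endless_pairs(), num_iters)):
    seen.append(lbl_batch["step"])  # the optimisation step would go here
print(seen)  # [0, 1, 2, 3, 4]
```

With a progress bar, the same idea becomes `tqdm(islice(train_loader, num_iters), total=num_iters)`, which removes the need for any stopping condition inside the loop body.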

In [16]:
device = "cuda" if torch.cuda.is_available() else "cpu"

train(model=model, train_loader=train_loader, algorithm=algorithm,
      optimizer=optimizer, device=device, num_iters=2048)

Training:   0%|          | 0/2048 [00:00<?, ?it/s]


AttributeError: 'FixMatch' object has no attribute 'forward'

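To evaluate the model, we compute standard classification metrics on the evaluation dataset with scikit-learn.

In [14]: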
from torch.utils.data import DataLoader
from sklearn.metrics import (
    accuracy_score, balanced_accuracy_score, precision_score, recall_score,
    f1_score, confusion_matrix
)

def evaluate(model, eval_loader, device="cpu"):
    """Evaluate `model` on a labelled loader and print summary metrics."""
    model.to(device)
    model.eval()
    total_loss = 0.0
    total_num = 0
    y_true = []
    y_pred = []
    y_probs = []
    y_logits = []

    with torch.no_grad():
        for batch in eval_loader:
            X = batch["X"].to(device)
            y = batch["y"].to(device)

            num_batch = y.shape[0]
            total_num += num_batch

            logits = model(X)

            # Mean cross-entropy over the batch, re-weighted by batch size when accumulated
            loss = F.cross_entropy(logits, y, reduction='mean', ignore_index=-1)
            total_loss += loss.item() * num_batch

            y_true.extend(y.cpu().tolist())
            y_pred.extend(torch.argmax(logits, dim=-1).cpu().tolist())
            y_logits.append(logits.cpu().numpy())
            y_probs.extend(torch.softmax(logits, dim=-1).cpu().tolist())

    y_true = np.array(y_true)
    y_pred = np.array(y_pred)
    y_logits = np.concatenate(y_logits)
    top1 = accuracy_score(y_true, y_pred)
    # Top-k accuracy needs class scores rather than hard predictions:
    # top5 = top_k_accuracy_score(y_true, np.array(y_probs), k=5)
    balanced_top1 = balanced_accuracy_score(y_true, y_pred)
    # zero_division=0 scores classes that are never predicted as 0 without a warning
    precision = precision_score(y_true, y_pred, average='macro', zero_division=0)
    recall = recall_score(y_true, y_pred, average='macro', zero_division=0)
    F1 = f1_score(y_true, y_pred, average='macro', zero_division=0)

    print("accuracy: ", top1)
    # print("accuracy top 5: ", top5)
    print("balanced-accuracy: ", balanced_top1)
    print("recall: ", recall)
    print("f1: ", F1)

    # Row-normalised: each row sums to 1 and the diagonal holds per-class recall
    cf_mat = confusion_matrix(y_true, y_pred, normalize='true')
    print('confusion matrix:\n' + np.array_str(cf_mat))

    # Restore training mode so the model can be trained further
    model.train()



In [27]:
eval_loader = DataLoader(data.get_eval_dataset(), batch_size=32)
evaluate(model, eval_loader, device)

accuracy:  0.1299
balanced-accuracy:  0.1299
recall:  0.1299
f1:  0.07572577356234492
confusion matrix:
[[0.    0.003 0.203 0.    0.021 0.    0.481 0.    0.218 0.074]
 [0.    0.002 0.24  0.008 0.083 0.    0.513 0.    0.124 0.03 ]
 [0.    0.    0.157 0.001 0.097 0.    0.691 0.    0.035 0.019]
 [0.    0.001 0.174 0.    0.148 0.    0.65  0.    0.013 0.014]
 [0.    0.    0.222 0.002 0.088 0.    0.659 0.    0.023 0.006]
 [0.    0.    0.119 0.001 0.219 0.    0.64  0.    0.013 0.008]
 [0.    0.001 0.143 0.    0.12  0.    0.729 0.    0.003 0.004]
 [0.    0.    0.091 0.001 0.135 0.    0.716 0.    0.048 0.009]
 [0.    0.002 0.242 0.001 0.058 0.    0.383 0.    0.302 0.012]
 [0.    0.    0.096 0.001 0.041 0.    0.667 0.    0.174 0.021]]
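
Two features of this output can be checked by hand. Accuracy, balanced accuracy and macro recall are all 0.1299 because balanced accuracy is defined as macro-averaged recall, and the evaluation set (presumably the class-balanced CIFAR10 test set) has equal support per class, which makes plain accuracy equal too; the mean of the diagonal of the row-normalised matrix above is also 0.1299. Columns 0, 5 and 7 are all zero, so those classes are never predicted and their precision is 0/0, which is what scikit-learn's truncated warning refers to. The sketch below, with a hypothetical helper `macro_metrics` and a toy confusion matrix of counts, reproduces both effects.

```python
import numpy as np

def macro_metrics(cm):
    """Accuracy, balanced accuracy and macro precision from a count confusion matrix.

    Rows are true classes and columns are predicted classes; assumes every true
    class appears at least once.
    """
    cm = np.asarray(cm, dtype=float)
    tp = np.diag(cm)
    true_per_class = cm.sum(axis=1)   # support of each true class
    pred_per_class = cm.sum(axis=0)   # how often each class was predicted
    recall = tp / true_per_class
    # A class that is never predicted has 0/0 precision; count it as 0
    precision = np.divide(tp, pred_per_class, out=np.zeros_like(tp),
                          where=pred_per_class > 0)
    return {
        "accuracy": tp.sum() / cm.sum(),
        "balanced_accuracy": recall.mean(),  # balanced accuracy = macro-averaged recall
        "macro_precision": precision.mean(),
    }

# Toy 3-class example with equal support (10 per class); class 2 is never predicted
cm = [[8, 2, 0],
      [5, 5, 0],
      [6, 4, 0]]
print(macro_metrics(cm))
```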


