# Iris basic example

In this example, we will use pseudo-labelling to classify the iris dataset with a simple MLP using `sslpack`.

In [None]:
import sslpack
import torch
import torch.nn.functional as F
from torch import nn
import numpy as np
from tqdm import tqdm

Iris is not provided in `sslpack`, so we first load the dataset with scikit-learn.

In [2]:
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

X, y = load_iris(return_X_y=True)
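As a quick sanity check, the iris data should contain 150 samples with 4 features each, split evenly across 3 classes. The feature dtype is also worth noting, since it matters when the model is built later.

```python
import numpy as np
from sklearn.datasets import load_iris

X, y = load_iris(return_X_y=True)

print(X.shape)         # (150, 4): 150 flowers, 4 measurements each
print(np.bincount(y))  # [50 50 50]: 50 samples for each of the 3 species
print(X.dtype)         # float64
```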

We would like to generate a train-test split, and also partition the training data into a labelled and an unlabelled part. We can use an `sslpack` function for the partition.

In [None]:
from sslpack.utils.data import split_lb_ulb_balanced

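# Hold out 30% of the data as a test set
X_tr, X_ts, y_tr, y_ts = train_test_split(X, y, test_size=0.3)

# Keep num_labels_per_class labelled samples per class; the rest become unlabelled
num_labels_per_class = 2
X_tr_lb, y_tr_lb, X_tr_ulb, y_tr_ulb = split_lb_ulb_balanced(X_tr, y_tr, num_labels_per_class)

# Convert the NumPy arrays to PyTorch tensors
X_tr_lb, y_tr_lb = torch.tensor(X_tr_lb), torch.tensor(y_tr_lb)
X_tr_ulb, y_tr_ulb = torch.tensor(X_tr_ulb), torch.tensor(y_tr_ulb)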
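Judging by its name and arguments, `split_lb_ulb_balanced` draws the same number of labelled samples from every class. The sketch below is not `sslpack`'s implementation; it is a hypothetical `balanced_split` helper in NumPy that illustrates the idea, assuming the labelled samples are chosen at random within each class and every remaining sample is treated as unlabelled. Under that assumption, 3 classes with 2 labels per class give a labelled set of 6 samples.

```python
import numpy as np

def balanced_split(X, y, num_labels_per_class, seed=0):
    """Hypothetical helper: pick `num_labels_per_class` labelled samples per class.

    Returns (X_lb, y_lb, X_ulb, y_ulb); all samples not chosen become unlabelled.
    """
    rng = np.random.default_rng(seed)
    lb_idx = []
    for c in np.unique(y):
        class_idx = np.flatnonzero(y == c)
        lb_idx.extend(rng.choice(class_idx, size=num_labels_per_class, replace=False))
    # Boolean mask marking the samples that keep their labels
    is_lb = np.zeros(len(y), dtype=bool)
    is_lb[np.array(lb_idx)] = True
    return X[is_lb], y[is_lb], X[~is_lb], y[~is_lb]

# Toy data: 3 classes with 10 samples each
X_toy = np.arange(60, dtype=float).reshape(30, 2)
y_toy = np.repeat([0, 1, 2], 10)
X_lb, y_lb, X_ulb, y_ulb = balanced_split(X_toy, y_toy, num_labels_per_class=2)
print(np.bincount(y_lb))  # [2 2 2]
print(len(y_ulb))         # 24
```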

`sslpack` expects datasets to return dictionaries, which supports the various transformations required by consistency regularisation methods. This example applies no transformations, but the datasets still need to be in the correct format.

In [None]:
from sslpack.utils.data import BasicDataset
lbl_dataset = BasicDataset(X_tr_lb, y_tr_lb)
ulbl_dataset = BasicDataset(X_tr_ulb)
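A dataset that returns dictionaries can be written with a standard `torch.utils.data.Dataset`. The sketch below is a hypothetical `DictDataset`, not `sslpack`'s `BasicDataset`; the keys `"x"` and `"y"` are assumptions chosen for illustration and may differ from the keys `sslpack` actually uses. Omitting the labels, as for `ulbl_dataset`, yields dictionaries without a `"y"` entry.

```python
import torch
from torch.utils.data import Dataset

class DictDataset(Dataset):
    """Hypothetical dataset returning one dict per sample (key names are assumptions)."""

    def __init__(self, X, y=None):
        self.X = X
        self.y = y  # None for the unlabelled dataset

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        sample = {"x": self.X[idx]}
        if self.y is not None:
            sample["y"] = self.y[idx]
        return sample

lbl = DictDataset(torch.randn(6, 4), torch.tensor([0, 0, 1, 1, 2, 2]))
ulbl = DictDataset(torch.randn(10, 4))
print(lbl[0].keys())   # dict_keys(['x', 'y'])
print(ulbl[0].keys())  # dict_keys(['x'])
```

PyTorch's default `collate_fn` batches dictionaries key by key, so a `DataLoader` over such a dataset yields dictionaries of stacked tensors.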

Now we can import and instantiate the `PseudoLabel` algorithm.

In [None]:
from sslpack.algorithms import PseudoLabel
algorithm = PseudoLabel()
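Pseudo-labelling uses the model's own predictions on unlabelled data as training targets. The function below is a generic sketch of a common variant that only keeps confident predictions; it is not `sslpack`'s `PseudoLabel` code, and the confidence `threshold` and the unlabelled-loss weight `lambda_u` are assumed values chosen for illustration.

```python
import torch
import torch.nn.functional as F
from torch import nn

def pseudo_label_loss(model, x_lb, y_lb, x_ulb, threshold=0.95, lambda_u=1.0):
    """Generic pseudo-label loss sketch (threshold and lambda_u are assumptions)."""
    # Ordinary supervised cross-entropy on the labelled batch
    sup_loss = F.cross_entropy(model(x_lb), y_lb)

    # The targets come from the model's own predictions, detached from the graph
    logits_ulb = model(x_ulb)
    probs = torch.softmax(logits_ulb.detach(), dim=1)
    confidence, pseudo_labels = probs.max(dim=1)
    mask = (confidence >= threshold).float()  # 1 for confident samples, else 0

    # Cross-entropy against the pseudo-labels, ignoring unconfident samples
    ulb_losses = F.cross_entropy(logits_ulb, pseudo_labels, reduction="none")
    unsup_loss = (ulb_losses * mask).mean()

    return sup_loss + lambda_u * unsup_loss

torch.manual_seed(0)
net = nn.Sequential(nn.Linear(4, 10), nn.ReLU(), nn.Linear(10, 3))
loss = pseudo_label_loss(net, torch.randn(6, 4), torch.tensor([0, 0, 1, 1, 2, 2]),
                         torch.randn(12, 4))
loss.backward()
print(loss.item())
```

Early in training few predictions pass the threshold, so the loss is dominated by the labelled term; as the model grows more confident, more unlabelled samples contribute.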

Unlike in conventional supervised learning, our dataloader needs to handle two streams: a labelled part and an unlabelled part. `sslpack` provides dataloaders for this, such as `CyclicLoader`. We only need to specify the batch sizes of the labelled and unlabelled parts.

In [None]:
from sslpack.utils.data import CyclicLoader

lbl_batch_size = 6
ulbl_batch_size = 12
train_loader = CyclicLoader(lbl_dataset, ulbl_dataset, lbl_batch_size=lbl_batch_size, ulbl_batch_size=ulbl_batch_size)
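The labelled set is much smaller than the unlabelled set, so a loader that pairs their batches has to restart the shorter stream whenever it runs out. The generator below is a hypothetical `cyclic_batches` sketch built on plain `DataLoader`s, assuming each stream is reshuffled on restart; it is not `sslpack`'s `CyclicLoader`, whose exact behaviour may differ.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

def cyclic_batches(lbl_dataset, ulbl_dataset, lbl_batch_size, ulbl_batch_size):
    """Hypothetical infinite generator yielding (labelled, unlabelled) batch pairs."""
    lbl_loader = DataLoader(lbl_dataset, batch_size=lbl_batch_size, shuffle=True)
    ulbl_loader = DataLoader(ulbl_dataset, batch_size=ulbl_batch_size, shuffle=True)
    lbl_iter, ulbl_iter = iter(lbl_loader), iter(ulbl_loader)
    while True:
        try:
            lbl_batch = next(lbl_iter)
        except StopIteration:
            # Labelled data exhausted: start a new, reshuffled pass
            lbl_iter = iter(lbl_loader)
            lbl_batch = next(lbl_iter)
        try:
            ulbl_batch = next(ulbl_iter)
        except StopIteration:
            # Same for the unlabelled stream
            ulbl_iter = iter(ulbl_loader)
            ulbl_batch = next(ulbl_iter)
        yield lbl_batch, ulbl_batch

lbl_ds = TensorDataset(torch.randn(6, 4), torch.tensor([0, 0, 1, 1, 2, 2]))
ulbl_ds = TensorDataset(torch.randn(99, 4))
stream = cyclic_batches(lbl_ds, ulbl_ds, lbl_batch_size=6, ulbl_batch_size=12)
for step in range(20):  # more steps than either dataset alone provides
    (x_lb, y_lb), (x_ulb,) = next(stream)
print(x_lb.shape, x_ulb.shape)
```

Re-creating the iterator, rather than wrapping the loader in `itertools.cycle`, matters here: `itertools.cycle` stores the batches from the first pass and replays them, so the data would never be reshuffled.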

Now we can write a training function. `sslpack` is designed to resemble conventional PyTorch SSL code, so the training loop should look familiar. The main difference is that the loss computation is replaced by a call to the algorithm's forward pass, which returns the loss for a pair of labelled and unlabelled batches.

In [7]:
def dict_to_device(d, device):
    # Move every tensor in a batch dictionary to the target device
    return {k: v.to(device) if torch.is_tensor(v) else v for k, v in d.items()}

def train(model, train_loader, algorithm, optimizer, num_iters=128,
          num_log_iters=8, device="cpu"):

    model.to(device)
    model.train()

    training_bar = tqdm(train_loader, total=num_iters, desc="Training",
                        leave=True)

    for i, (lbl_batch, ulbl_batch) in enumerate(training_bar):

        # Stop after exactly num_iters optimisation steps
        if i >= num_iters:
            break

        lbl_batch = dict_to_device(lbl_batch, device)
        ulbl_batch = dict_to_device(ulbl_batch, device)

        optimizer.zero_grad()

        # The SSL algorithm computes the loss from both batches
        loss = algorithm.forward(model, lbl_batch, ulbl_batch)

        loss.backward()

        optimizer.step()

        if i % num_log_iters == 0:
            training_bar.set_postfix(loss=round(loss.item(), 4))
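`train` only relies on `algorithm` exposing a `forward(model, lbl_batch, ulbl_batch)` method that returns a scalar loss. For comparison, the hypothetical `SupervisedOnly` class below satisfies that interface while ignoring the unlabelled batch entirely, which gives a purely supervised baseline. The dictionary keys `"x"` and `"y"` are assumptions; check which keys `sslpack`'s datasets actually produce before using it with `train`.

```python
import torch
import torch.nn.functional as F
from torch import nn

class SupervisedOnly:
    """Hypothetical baseline with the same forward() signature used in train()."""

    def forward(self, model, lbl_batch, ulbl_batch):
        # Assumed keys: "x" for inputs, "y" for labels; ulbl_batch is ignored
        logits = model(lbl_batch["x"])
        return F.cross_entropy(logits, lbl_batch["y"])

net = nn.Sequential(nn.Linear(4, 3))
baseline = SupervisedOnly()
lbl_batch = {"x": torch.randn(6, 4), "y": torch.tensor([0, 0, 1, 1, 2, 2])}
ulbl_batch = {"x": torch.randn(12, 4)}
loss = baseline.forward(net, lbl_batch, ulbl_batch)
loss.backward()
```

Training the same MLP with such a baseline and comparing test accuracy is one way to check whether the unlabelled data is actually helping.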

All that remains is to specify a model and an optimizer.

In [8]:
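# A small MLP: 4 input features -> two hidden layers of 10 units -> 3 classes
model = torch.nn.Sequential(
    nn.Linear(4, 10),
    nn.ReLU(),
    nn.Linear(10, 10),
    nn.ReLU(),
    nn.Linear(10, 3)
)

# The data tensors are float64 (converted from NumPy), so match the model's dtype
model.double()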

lr = 0.01
optimizer = torch.optim.Adam(params=model.parameters(), lr=lr)

device = "cuda" if torch.cuda.is_available() else "cpu"
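The call to `model.double()` is needed because `torch.tensor` preserves the `float64` dtype of the NumPy arrays returned by scikit-learn, whereas `nn.Linear` creates `float32` parameters by default. A small check of the two dtypes:

```python
import numpy as np
import torch
from torch import nn

features = torch.tensor(np.array([[5.1, 3.5, 1.4, 0.2]]))  # NumPy float64 -> torch.float64
layer = nn.Linear(4, 3)
print(features.dtype, layer.weight.dtype)  # torch.float64 torch.float32

# Converting the layer's parameters to float64 makes the dtypes match
layer.double()
out = layer(features)
print(out.dtype)  # torch.float64
```

An equivalent alternative is to keep the model in `float32` and convert the data with `.float()` instead; `float32` is the more common choice when training on a GPU.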

We can now train the model.

In [9]:
train(model=model, train_loader=train_loader, algorithm=algorithm,
      optimizer=optimizer, device=device, num_iters=100)

Training: 101it [00:00, 127.66it/s, loss=0.0037]                        


Finally, we evaluate the model on the held-out test set.

In [10]:
def test(model, X, y):
    model.eval()
    with torch.no_grad():
        predictions = model(X).argmax(dim=1)
        acc = (predictions == y).float().mean()
        return float(acc)

In [11]:
test(model, torch.tensor(X_ts).to(device), torch.tensor(y_ts).to(device))

0.9333333373069763
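To put this accuracy in context: with `test_size=0.3`, 45 of the 150 iris samples are held out and 105 are used for training. Assuming the labelled split is balanced as its name suggests, only 6 of those 105 training samples carry labels. The reported accuracy then corresponds to 42 correct predictions out of 45:

```python
import math

n_samples = 150
n_test = math.ceil(0.3 * n_samples)   # 45 held-out samples
n_train = n_samples - n_test          # 105 training samples
n_labelled = 2 * 3                    # 2 labels per class, 3 classes
n_unlabelled = n_train - n_labelled   # 99 unlabelled samples

reported_acc = 0.9333333373069763     # value printed by test() above
n_correct = round(reported_acc * n_test)
print(n_test, n_labelled, n_unlabelled, n_correct)  # 45 6 99 42
```

Neither `train_test_split` nor the labelled/unlabelled split is seeded here, so rerunning the notebook will likely give a somewhat different accuracy.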