# Custom Algorithm

## In this tutorial, we provide an example of creating a new SSL algorithm by reusing the component hooks

In [1]:
from semilearn import get_dataset, get_data_loader, get_net_builder, get_algorithm, get_config, Trainer


## Step 1: Create and Register Algorithm

We will create a new algorithm based on FixMatch and combine it with entropy loss, using the following steps:

* Inherit FixMatch algorithm
* Rewrite the 'train_step' function by adding entropy loss

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from semilearn.core.utils import ALGORITHMS
from semilearn.algorithms.fixmatch import FixMatch
from inspect import signature



def entropy_loss(ul_y):
    """Mean entropy of the softmax distribution over dim 1 of `ul_y`.

    Args:
        ul_y: tensor of logits; softmax is taken over dim 1 and the
            per-row entropies are averaged over dim 0.

    Returns:
        Tensor holding the averaged entropy.
    """
    log_probs = F.log_softmax(ul_y, dim=1)
    probs = F.softmax(ul_y, dim=1)
    # sum_c p_c * log p_c for every row (this is the negative entropy)
    neg_entropy_per_row = (probs * log_probs).sum(dim=1)
    return -neg_entropy_per_row.mean(dim=0)


@ALGORITHMS.register('fixmatch_entropy')
class FixMatchEntropy(FixMatch):
    """FixMatch with an extra entropy-minimization term on the weak unlabeled view.

    Only `train_step` is overridden; everything else is inherited from
    `FixMatch`. The class is registered under the name 'fixmatch_entropy'.
    """

    # Weight of the entropy term in the total loss (1.0 = plain sum).
    # Override on a subclass or instance to tune it.
    lambda_ent = 1.0

    def train_step(self, idx_lb, x_lb, y_lb, x_ulb_w, x_ulb_s):
        """Run one training step: FixMatch losses plus an entropy loss.

        Args:
            idx_lb: not used in this method; kept so the signature matches the original.
            x_lb: labeled inputs.
            y_lb: labels for `x_lb`; its first dimension gives the labeled batch size.
            x_ulb_w: weakly augmented unlabeled inputs.
            x_ulb_s: strongly augmented unlabeled inputs.

        Returns:
            (out_dict, log_dict) as built by `process_out_dict` and
            `process_log_dict`: the total loss and features, and scalar
            values for logging.
        """
        num_lb = y_lb.shape[0]

        # inference and calculate sup/unsup losses
        with self.amp_cm():
            if self.use_cat:
                # Single forward pass over the concatenated batch, then split the
                # outputs back into labeled / weak / strong parts.
                inputs = torch.cat((x_lb, x_ulb_w, x_ulb_s))
                outputs = self.model(inputs)
                logits_x_lb = outputs['logits'][:num_lb]
                logits_x_ulb_w, logits_x_ulb_s = outputs['logits'][num_lb:].chunk(2)
                feats_x_lb = outputs['feat'][:num_lb]
                feats_x_ulb_w, feats_x_ulb_s = outputs['feat'][num_lb:].chunk(2)
            else:
                outs_x_lb = self.model(x_lb)
                logits_x_lb = outs_x_lb['logits']
                feats_x_lb = outs_x_lb['feat']
                outs_x_ulb_s = self.model(x_ulb_s)
                logits_x_ulb_s = outs_x_ulb_s['logits']
                feats_x_ulb_s = outs_x_ulb_s['feat']
                # NOTE: the weak view is computed without gradients here, so in this
                # branch the entropy term below is still added and logged but does
                # not contribute any gradient.
                with torch.no_grad():
                    outs_x_ulb_w = self.model(x_ulb_w)
                    logits_x_ulb_w = outs_x_ulb_w['logits']
                    feats_x_ulb_w = outs_x_ulb_w['feat']
            feat_dict = {'x_lb': feats_x_lb, 'x_ulb_w': feats_x_ulb_w, 'x_ulb_s': feats_x_ulb_s}

            sup_loss = self.ce_loss(logits_x_lb, y_lb, reduction='mean')

            # Probabilities of the weak view; detached so pseudo-labels carry no gradient.
            probs_x_ulb_w = self.compute_prob(logits_x_ulb_w.detach())

            # if distribution alignment hook is registered, call it
            # this is implemented for imbalanced algorithm - CReST
            if self.registered_hook("DistAlignHook"):
                probs_x_ulb_w = self.call_hook("dist_align", "DistAlignHook", probs_x_ulb=probs_x_ulb_w.detach())

            # compute mask (probabilities are passed in, hence softmax_x_ulb=False)
            mask = self.call_hook("masking", "MaskingHook", logits_x_ulb=probs_x_ulb_w, softmax_x_ulb=False)

            # generate unlabeled targets using pseudo label hook
            pseudo_label = self.call_hook("gen_ulb_targets", "PseudoLabelingHook",
                                          logits=probs_x_ulb_w,
                                          use_hard_label=self.use_hard_label,
                                          T=self.T,
                                          softmax=False)

            unsup_loss = self.consistency_loss(logits_x_ulb_s,
                                               pseudo_label,
                                               'ce',
                                               mask=mask)

            # NOTE: add entropy loss here (computed on the non-detached weak logits)
            loss_entmin = entropy_loss(logits_x_ulb_w)

            # The entropy term must be part of the optimized objective; previously it
            # was computed but never added to the total loss.
            total_loss = sup_loss + self.lambda_u * unsup_loss + self.lambda_ent * loss_entmin

        out_dict = self.process_out_dict(loss=total_loss, feat=feat_dict)
        log_dict = self.process_log_dict(sup_loss=sup_loss.item(),
                                         unsup_loss=unsup_loss.item(),
                                         entmin_loss=loss_entmin.item(),
                                         total_loss=total_loss.item(),
                                         util_ratio=mask.float().mean().item())
        return out_dict, log_dict

## Step 2: define configs and create config

In [3]:
# Algorithm (the name registered above) and backbone / pretrained-weights selection.
model_cfg = {
    'algorithm': 'fixmatch_entropy',
    'net': 'vit_tiny_patch2_32',
    'use_pretrain': True,
    'pretrain_path': 'https://github.com/microsoft/Semi-supervised-learning/releases/download/v.0.0.0/vit_tiny_patch2_32_mlp_im_1k_32.pth',
}

# Optimization settings, shortened for the tutorial; trailing comments give
# the values suggested for a full run.
optim_cfg = {
    'epoch': 1,  # set to 100
    'num_train_iter': 5000,  # set to 102400
    'num_eval_iter': 500,   # set to 1024
    'num_log_iter': 50,    # set to 256
    'optim': 'AdamW',
    'lr': 5e-4,
    'layer_decay': 0.5,
    'batch_size': 16,
    'eval_batch_size': 16,
}

# Dataset settings.
data_cfg = {
    'dataset': 'cifar10',
    'num_labels': 40,
    'num_classes': 10,
    'img_size': 32,
    'crop_ratio': 0.875,
    'data_dir': './data',
    'ulb_samples_per_class': None,
}

# Algorithm-specific settings.
algo_cfg = {
    'hard_label': True,
    'uratio': 2,
    'ulb_loss_ratio': 1.0,
}

# Device settings.
device_cfg = {
    'gpu': 0,
    'world_size': 1,
    'distributed': False,
    "num_workers": 2,
}

# Merge the sections in their original key order and build the config object.
config = get_config({**model_cfg, **optim_cfg, **data_cfg, **algo_cfg, **device_cfg})

here 1


/bin/sh: 1: netstat: not found


## Step 3: create algorithm and dataset

In [4]:
# Look up the network builder named in the config, then build the algorithm
# around it (no tensorboard logger / logger passed).
net_builder = get_net_builder(config.net, from_name=False)
algorithm = get_algorithm(config, net_builder, tb_log=None, logger=None)

Files already downloaded and verified
lb count: [4, 4, 4, 4, 4, 4, 4, 4, 4, 4]
ulb count: [5000, 5000, 5000, 5000, 5000, 5000, 5000, 5000, 5000, 5000]
Files already downloaded and verified
unlabeled data number: 50000, labeled data number 40
Create train and test data loaders
[!] data loader keys: dict_keys(['train_lb', 'train_ulb', 'eval'])
_IncompatibleKeys(missing_keys=['head.weight', 'head.bias'], unexpected_keys=[])
Create optimizer and scheduler


In [5]:
# Build the datasets described by the config; the loaders below use its
# 'train_lb', 'train_ulb' and 'eval' entries.
dataset_dict = get_dataset(
    config,
    config.algorithm,
    config.dataset,
    config.num_labels,
    config.num_classes,
    data_dir=config.data_dir,
    include_lb_to_ulb=config.include_lb_to_ulb,
)

# The unlabeled batch is `uratio` times the labeled batch size.
ulb_batch_size = int(config.batch_size * config.uratio)

train_lb_loader = get_data_loader(config, dataset_dict['train_lb'], config.batch_size)
train_ulb_loader = get_data_loader(config, dataset_dict['train_ulb'], ulb_batch_size)
eval_loader = get_data_loader(config, dataset_dict['eval'], config.eval_batch_size)

Files already downloaded and verified
lb count: [4, 4, 4, 4, 4, 4, 4, 4, 4, 4]
ulb count: [5000, 5000, 5000, 5000, 5000, 5000, 5000, 5000, 5000, 5000]
Files already downloaded and verified


## Step 4: train and evaluate

In [6]:
# Training and evaluation: fit on the labeled and unlabeled loaders (the eval
# loader is passed to fit as well), then run a final evaluation. The evaluate
# call is the cell's last expression, so its return value is shown as output.
trainer = Trainer(config, algorithm)
trainer.fit(train_lb_loader, train_ulb_loader, eval_loader)
trainer.evaluate(eval_loader)

Epoch: 0


  _warn_prf(average, modifier, msg_start, len(result))
[2022-12-18 16:02:52,476 INFO] confusion matrix
[2022-12-18 16:02:52,478 INFO] [[0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]]
[2022-12-18 16:02:52,481 INFO] evaluation metric
[2022-12-18 16:02:52,483 INFO] acc: 0.1000
[2022-12-18 16:02:52,484 INFO] precision: 0.0100
[2022-12-18 16:02:52,485 INFO] recall: 0.1000
[2022-12-18 16:02:52,486 INFO] f1: 0.0182


model saved: ./saved_models/fixmatch/latest_model.pth


[2022-12-18 16:02:53,467 INFO] Best acc 0.1000 at epoch 0
[2022-12-18 16:02:53,468 INFO] Training finished.


model saved: ./saved_models/fixmatch/model_best.pth


  _warn_prf(average, modifier, msg_start, len(result))
[2022-12-18 16:04:06,346 INFO] confusion matrix
[2022-12-18 16:04:06,348 INFO] [[0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]]
[2022-12-18 16:04:06,351 INFO] evaluation metric
[2022-12-18 16:04:06,352 INFO] acc: 0.1000
[2022-12-18 16:04:06,353 INFO] precision: 0.0100
[2022-12-18 16:04:06,354 INFO] recall: 0.1000
[2022-12-18 16:04:06,354 INFO] f1: 0.0182


{'acc': 0.1, 'precision': 0.01, 'recall': 0.1, 'f1': 0.01818181818181818}