# Custom Dataset

## In this tutorial, we provide an example of adapting usb to custom dataset.

In [1]:
import sys
import numpy as np
from torchvision import transforms
from semilearn import get_data_loader, get_net_builder, get_algorithm, get_config, Trainer
from semilearn import split_ssl_data, BasicDataset

## Specifiy configs and define the model

In [17]:
# define configs and create config
config = {
    'algorithm': 'fixmatch',
    'net': 'wrn_28_2',
    'use_pretrain': False,  # todo: add pretrain

    # optimization configs
    'epoch': 3,
    'num_train_iter': 150,
    'num_eval_iter': 50,
    'optim': 'SGD',
    'lr': 0.03,
    'momentum': 0.9,
    'batch_size': 64,
    'eval_batch_size': 64,

    # dataset configs
    'dataset': 'none',
    'num_labels': 40,
    'num_classes': 10,
    'img_size': 32,
    'crop_ratio': 0.875,
    'data_dir': './data',

    # algorithm specific configs
    'hard_label': True,
    'uratio': 3,
    'ulb_loss_ratio': 1.0,

    # device configs
    'gpu': 0,
    'world_size': 1,
    "num_workers": 2,
    'distributed': False,
}
config = get_config(config)

/bin/sh: 1: netstat: not found


In [3]:
# create model and specify algorithm
algorithm = get_algorithm(config,  get_net_builder(config.net, from_name=False), tb_log=None, logger=None)

## Create dataset

In [13]:
# replace with your own code
data = np.random.randint(0, 255, size=3072 * 1000).reshape((-1, 32, 32, 3))
data = np.uint8(data)
target = np.random.randint(0, 10, size=1000)
lb_data, lb_target, ulb_data, ulb_target = split_ssl_data(config, data, target,
                                                          num_labels=config.num_labels,
                                                          num_classes=config.num_classes)

train_transform = transforms.Compose([transforms.RandomHorizontalFlip(),
                                      transforms.RandomCrop(32, padding=int(32 * 0.125), padding_mode='reflect'),
                                      transforms.ToTensor(),
                                      transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])

train_strong_transform = transforms.Compose([transforms.RandomHorizontalFlip(),
                                             transforms.RandomCrop(32, padding=int(32 * 0.125), padding_mode='reflect'),
                                             transforms.ToTensor(),
                                             transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])

lb_dataset = BasicDataset(config.algorithm, lb_data, lb_target, config.num_classes, train_transform, is_ulb=False)
ulb_dataset = BasicDataset(config.algorithm, lb_data, lb_target, config.num_classes, train_transform, is_ulb=True, strong_transform=train_strong_transform)

In [14]:
# replace with your own code
eval_data = np.random.randint(0, 255, size=3072 * 100).reshape((-1, 32, 32, 3))
eval_data = np.uint8(eval_data)
eval_target = np.random.randint(0, 10, size=100)

eval_transform = transforms.Compose([transforms.Resize(32),
                                      transforms.ToTensor(),
                                      transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])

eval_dataset = BasicDataset(config.algorithm, lb_data, lb_target, config.num_classes, eval_transform, is_ulb=False)

In [18]:
# define data loaders
train_lb_loader = get_data_loader(config, lb_dataset, config.batch_size)
train_ulb_loader = get_data_loader(config, ulb_dataset, int(config.batch_size * config.uratio))
eval_loader = get_data_loader(config, eval_dataset, config.eval_batch_size)

## Training and evaluation

In [19]:
# training and evaluation
trainer = Trainer(config, algorithm)
trainer.fit(train_lb_loader, train_ulb_loader, eval_loader)
trainer.evaluate(eval_loader)

Epoch: 0


  _warn_prf(average, modifier, msg_start, len(result))
[2022-08-21 18:38:14,350 INFO] confusion matrix
[2022-08-21 18:38:14,352 INFO] [[0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]]
[2022-08-21 18:38:14,355 INFO] evaluation metric
[2022-08-21 18:38:14,356 INFO] acc: 0.2000
[2022-08-21 18:38:14,357 INFO] precision: 0.0200
[2022-08-21 18:38:14,357 INFO] recall: 0.1000
[2022-08-21 18:38:14,358 INFO] f1: 0.0333


model saved: ./saved_models/fixmatch_none/latest_model.pth
model saved: ./saved_models/fixmatch_none/model_best.pth
Epoch: 1


  _warn_prf(average, modifier, msg_start, len(result))
[2022-08-21 18:38:24,027 INFO] confusion matrix
[2022-08-21 18:38:24,029 INFO] [[0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]]
[2022-08-21 18:38:24,032 INFO] evaluation metric
[2022-08-21 18:38:24,033 INFO] acc: 0.2000
[2022-08-21 18:38:24,033 INFO] precision: 0.0200
[2022-08-21 18:38:24,034 INFO] recall: 0.1000
[2022-08-21 18:38:24,035 INFO] f1: 0.0333


model saved: ./saved_models/fixmatch_none/latest_model.pth
Epoch: 2


  _warn_prf(average, modifier, msg_start, len(result))
[2022-08-21 18:38:33,855 INFO] confusion matrix
[2022-08-21 18:38:33,856 INFO] [[0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]]
[2022-08-21 18:38:33,860 INFO] evaluation metric
[2022-08-21 18:38:33,860 INFO] acc: 0.2000
[2022-08-21 18:38:33,861 INFO] precision: 0.0200
[2022-08-21 18:38:33,862 INFO] recall: 0.1000
[2022-08-21 18:38:33,862 INFO] f1: 0.0333
[2022-08-21 18:38:34,118 INFO] Best acc 0.2000 at epoch 0
[2022-08-21 18:38:34,119 INFO] Training finished.


model saved: ./saved_models/fixmatch_none/latest_model.pth


  _warn_prf(average, modifier, msg_start, len(result))
[2022-08-21 18:38:35,003 INFO] confusion matrix
[2022-08-21 18:38:35,004 INFO] [[0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]]
[2022-08-21 18:38:35,007 INFO] evaluation metric
[2022-08-21 18:38:35,008 INFO] acc: 0.2000
[2022-08-21 18:38:35,009 INFO] precision: 0.0200
[2022-08-21 18:38:35,009 INFO] recall: 0.1000
[2022-08-21 18:38:35,010 INFO] f1: 0.0333


{'acc': 0.2, 'precision': 0.02, 'recall': 0.1, 'f1': 0.03333333333333334}