In [None]:
%load_ext autoreload
%autoreload 2
import os, sys
sys.path.append("..")
import logging
import random
from dotmap import DotMap

import numpy as np
import torch
import torch.nn.parallel
import torch.backends.cudnn as cudnn

In [None]:
from MSMatch.utils import net_builder, get_logger, count_parameters, create_dir_str
from MSMatch.train_utils import TBLog, get_OPT, get_cosine_schedule_with_warmup
from MSMatch.models.fixmatch.FixMatch import FixMatch
from MSMatch.datasets.SSL_Dataset import SSL_Dataset
from MSMatch.datasets.data_utils import get_data_loader

In [None]:
# Experiment configuration.
# Every key read by later cells must be defined here. DotMap silently creates
# an empty DotMap when a missing key is accessed, so a forgotten key does not
# raise. It instead flows into SSL_Dataset / FixMatch / get_data_loader as a
# meaningless value.
cfg = DotMap()

# --- data -------------------------------------------------------------------
cfg.dataset = "eurosat_rgb"
cfg.data_dir = "./data"  # TODO: point at where the dataset is stored locally
cfg.num_labels = 100
cfg.train_sampler = "RandomSampler"  # sampler name handed to get_data_loader
cfg.num_workers = 1  # the unlabeled loader uses 4x this many workers
cfg.distributed = False  # single-process run

# --- model ------------------------------------------------------------------
cfg.net = "unet"
cfg.pretrained = False
cfg.ema_m = 0.999

# --- FixMatch ---------------------------------------------------------------
cfg.T = 0.5  # passed to FixMatch; presumably the pseudo-label temperature — confirm
cfg.hard_label = True
cfg.p_cutoff = 0.95
cfg.ulb_loss_ratio = 1.0
cfg.uratio = 7  # unlabeled batch size = batch_size * uratio
cfg.num_eval_iter = 32  # passed to FixMatch(num_eval_iter=...)

# --- optimisation -----------------------------------------------------------
cfg.batch_size = 4
cfg.eval_batch_size = 1024  # lower this if evaluation runs out of memory
cfg.opt = "SGD"
cfg.lr = 0.03
cfg.momentum = 0.9
cfg.weight_decay = 5e-4

# --- bookkeeping ------------------------------------------------------------
cfg.seed = 42
cfg.gpu = 0  # CUDA device index; only used when torch.cuda.is_available()
cfg.save_dir = "./saved_models"
cfg.save_name = "test"

In [None]:
# Append a hyper-parameter-specific sub-directory (built by create_dir_str)
# to the run name.
# The user-chosen base name is stored once in cfg.base_save_name. Re-running
# this cell then rebuilds the same path instead of nesting dir_name again.
dir_name = create_dir_str(cfg)
if "base_save_name" not in cfg:  # DotMap.__contains__ does not auto-create keys
    cfg.base_save_name = cfg.save_name
cfg.save_name = os.path.join(cfg.base_save_name, dir_name)
save_path = os.path.join(cfg.save_dir, cfg.save_name)

In [None]:
# The random seed has to be set so that labeled-data sampling is synchronized
# in each process. Seed Python's `random`, torch and NumPy, in that order.
for _seed_fn in (random.seed, torch.manual_seed, np.random.seed):
    _seed_fn(cfg.seed)

# Restrict cuDNN to deterministic algorithms for reproducibility.
cudnn.deterministic = True

In [None]:
# Resolve the run directory, then create the TBLog writer (presumably
# TensorBoard — confirm) and the named logger that both use it.
logger_level = "INFO"
save_path = os.path.join(cfg["save_dir"], cfg["save_name"])
tb_log = TBLog(save_path, "")
logger = get_logger(cfg["save_name"], save_path, logger_level)

In [None]:
# Training split of cfg.dataset, divided by get_ssl_dset into a labeled and
# an unlabeled dataset. cfg.num_labels presumably sets the labeled size — confirm.
train_dset = SSL_Dataset(
    name=cfg.dataset,
    seed=cfg.seed,
    data_dir=cfg.data_dir,
    train=True,
)
ssl_split = train_dset.get_ssl_dset(cfg.num_labels)
lb_dset, ulb_dset = ssl_split

In [None]:
# Copy dataset properties onto cfg; the network and FixMatch are built from them.
cfg.num_classes, cfg.num_channels = train_dset.num_classes, train_dset.num_channels

# Non-training split (train=False) used as the evaluation set.
_eval_dset = SSL_Dataset(
    name=cfg.dataset,
    seed=cfg.seed,
    data_dir=cfg.data_dir,
    train=False,
)
eval_dset = _eval_dset.get_dset()

In [None]:
# Presumably a BatchNorm momentum derived from the EMA decay.
# NOTE(review): cfg.bn_momentum is only stored on cfg; it is not passed to
# net_builder below. Confirm whether the network is meant to use it.
cfg["bn_momentum"] = 1.0 - cfg["ema_m"]

# Network factory for the chosen architecture and input channel count.
_net_builder = net_builder(
    cfg.net, in_channels=cfg.num_channels, pretrained=cfg.pretrained,
)

In [None]:
# Positional arguments for FixMatch, in the order its constructor takes them.
_fixmatch_args = (
    _net_builder,
    cfg.num_classes,
    cfg.num_channels,
    cfg.ema_m,
    cfg.T,
    cfg.p_cutoff,
    cfg.ulb_loss_ratio,
    cfg.hard_label,
)
model = FixMatch(
    *_fixmatch_args,
    num_eval_iter=cfg.num_eval_iter,
    tb_log=tb_log,
    logger=logger,
)

In [None]:
# Report the size of the network FixMatch trains (model.train_model).
trainable_param_count = count_parameters(model.train_model)
logger.info(f"Number of Trainable Params: {trainable_param_count}")

In [None]:
# Training length in iterations. It also drives the LR schedule and both
# training data loaders. Very small — presumably a quick smoke-test run.
cfg["num_train_iter"] = 64

In [None]:
# Optimizer over the trained network, wrapped in a cosine LR schedule.
# The warm-up length is a fraction of cfg.num_train_iter; 0 means no warm-up.
warmup_fraction = 0
optimizer = get_OPT(model.train_model, cfg.opt, cfg.lr, cfg.momentum, cfg.weight_decay)
scheduler = get_cosine_schedule_with_warmup(
    optimizer,
    cfg.num_train_iter,
    num_warmup_steps=cfg.num_train_iter * warmup_fraction,
)

In [None]:
# Hand the optimizer and LR scheduler built above to the FixMatch wrapper.
model.set_optimizer(optimizer, scheduler)

In [None]:
# When CUDA is available, select device cfg.gpu and move both FixMatch networks
# (train_model and eval_model) onto it. Otherwise leave them untouched.
if torch.cuda.is_available():
    torch.cuda.set_device(cfg.gpu)
    for _attr in ("train_model", "eval_model"):
        setattr(model, _attr, getattr(model, _attr).cuda(cfg.gpu))

In [None]:
# Log the model's repr and the fully resolved configuration for this run.
_arch_msg = f"model_arch: {model}"
logger.info(_arch_msg)
_args_msg = f"Arguments: {cfg}"
logger.info(_args_msg)

In [None]:
# Enable cuDNN auto-tuning only when determinism was not requested.
# With benchmark=True, cuDNN may pick different convolution algorithms from run
# to run. That defeats `cudnn.deterministic = True` set in the seeding cell.
# To trade reproducibility for speed, set deterministic to False there instead.
cudnn.benchmark = not cudnn.deterministic

In [None]:
# Build the data loaders: labeled and unlabeled training streams, plus evaluation.
loader_dict = {}
dset_dict = {"train_lb": lb_dset, "train_ulb": ulb_dset, "eval": eval_dset}

# Options common to both training loaders.
train_loader_opts = dict(
    data_sampler=cfg.train_sampler,
    num_iters=cfg.num_train_iter,
    distributed=cfg.distributed,
)

# Labeled stream: cfg.batch_size samples per batch.
loader_dict["train_lb"] = get_data_loader(
    dset_dict["train_lb"],
    cfg.batch_size,
    num_workers=cfg.num_workers,
    **train_loader_opts,
)

# Unlabeled stream: batches cfg.uratio times larger, and four times the workers.
loader_dict["train_ulb"] = get_data_loader(
    dset_dict["train_ulb"],
    cfg.batch_size * cfg.uratio,
    num_workers=4 * cfg.num_workers,
    **train_loader_opts,
)

# Evaluation loader (no sampler / iteration settings passed).
loader_dict["eval"] = get_data_loader(
    dset_dict["eval"], cfg.eval_batch_size, num_workers=cfg.num_workers
)