In [None]:
%load_ext autoreload
%autoreload 2
import sys
sys.path.append("..")

import torch
import MSMatch as mm

# Load the default MSMatch configuration and seed randomness from cfg.seed via mm.set_seeds.
logger_level = "INFO"
cfg = mm.get_default_cfg()
mm.set_seeds(cfg.seed)

# The text logger and the TensorBoard logger both receive cfg.save_path.
logger = mm.get_logger(cfg.save_name, cfg.save_path, logger_level)
tb_log = mm.TensorBoardLog(cfg.save_path, "")

In [None]:
# Construct the training split (divided into labeled / unlabeled subsets) and the evaluation split.
train_dset = mm.SSL_Dataset(name=cfg.dataset, train=True, data_dir=None, seed=cfg.seed)
lb_dset, ulb_dset = train_dset.get_ssl_dset(cfg.num_labels)

# Copy dataset properties into cfg; the network builder and FixMatch cells read them.
cfg.num_classes, cfg.num_channels = train_dset.num_classes, train_dset.num_channels

_eval_dset = mm.SSL_Dataset(name=cfg.dataset, train=False, data_dir=None, seed=cfg.seed)
eval_dset = _eval_dset.get_dset()

In [None]:
# Builder for the backbone named by cfg.net; input channels come from the training dataset.
net_builder = mm.get_net_builder(cfg.net, pretrained=cfg.pretrained, in_channels=cfg.num_channels)

In [None]:
# Wrap the backbone builder in the FixMatch semi-supervised trainer.
# NOTE(review): T=0.5 and hard_label=True are hard-coded here rather than read from cfg.
model = mm.FixMatch(
    net_builder,
    cfg.num_classes,
    cfg.num_channels,
    cfg.ema_m,
    T=0.5,
    p_cutoff=cfg.p_cutoff,
    lambda_u=cfg.ulb_loss_ratio,
    hard_label=True,
    num_eval_iter=cfg.num_eval_iter,
    tb_log=tb_log,
    logger=logger,
)
# Log how many parameters of the training network will receive gradients.
logger.info(
    "Number of Trainable Params: "
    f"{sum(param.numel() for param in model.train_model.parameters() if param.requires_grad)}"
)

In [None]:
# Training length.
# The step count is normalized to batches of REFERENCE_BATCH_SIZE samples:
#   cfg.num_train_iter * cfg.batch_size ~= cfg.epoch * cfg.num_eval_iter * REFERENCE_BATCH_SIZE
# so changing cfg.batch_size keeps the number of labeled samples drawn roughly constant.
REFERENCE_BATCH_SIZE = 32
cfg.epoch = 10
cfg.num_train_iter = cfg.epoch * cfg.num_eval_iter * REFERENCE_BATCH_SIZE // cfg.batch_size
# Integer division can yield 0 for very large batches; fail here instead of later in the loaders.
assert cfg.num_train_iter > 0, (
    f"num_train_iter is 0: cfg.batch_size={cfg.batch_size} exceeds the sample budget "
    f"{cfg.epoch * cfg.num_eval_iter * REFERENCE_BATCH_SIZE}"
)

In [None]:
# Move the models to the GPU (when available) *before* creating the optimizer.
# The PyTorch docs for `Module.cuda()` say to do this before constructing an
# optimizer over the module's parameters.
if torch.cuda.is_available():
    cfg.gpu = 0
    torch.cuda.set_device(cfg.gpu)
    model.train_model = model.train_model.cuda(cfg.gpu)
    model.eval_model = model.eval_model.cuda(cfg.gpu)

optimizer = mm.get_optimizer(
    model.train_model, cfg.opt, cfg.lr, cfg.momentum, cfg.weight_decay
)
# Cosine schedule over the full run with warm-up disabled; set num_warmup_steps to a
# fraction of cfg.num_train_iter to enable warm-up.
scheduler = mm.get_cosine_schedule_with_warmup(
    optimizer, cfg.num_train_iter, num_warmup_steps=0
)
model.set_optimizer(optimizer, scheduler)

logger.info(f"model_arch: {model}")
logger.info(f"Arguments: {cfg}")

In [None]:
# Construct data loaders.
# `get_data_loader` is reached through the `mm` namespace like every other MSMatch
# helper in this notebook; it is never imported as a bare name, so calling it
# unqualified raises NameError on a fresh kernel.
# NOTE(review): assumes MSMatch exposes `get_data_loader` at package level — confirm.
dset_dict = {"train_lb": lb_dset, "train_ulb": ulb_dset, "eval": eval_dset}

loader_dict = {
    # Labeled stream: RandomSampler, sized by cfg.num_train_iter.
    "train_lb": mm.get_data_loader(
        dset_dict["train_lb"],
        cfg.batch_size,
        data_sampler="RandomSampler",
        num_iters=cfg.num_train_iter,
        num_workers=1,
        distributed=False,
    ),
    # Unlabeled stream: batches are cfg.uratio times larger than the labeled ones.
    "train_ulb": mm.get_data_loader(
        dset_dict["train_ulb"],
        cfg.batch_size * cfg.uratio,
        data_sampler="RandomSampler",
        num_iters=cfg.num_train_iter,
        num_workers=4,
        distributed=False,
    ),
    # Evaluation loader: no sampler / num_iters passed, so the helper's defaults apply.
    "eval": mm.get_data_loader(
        dset_dict["eval"], cfg.eval_batch_size, num_workers=1
    ),
}

# Register the loaders with the FixMatch model.
model.set_data_loader(loader_dict)

In [None]:
# Run training, calling model.train once per configured epoch.
# NOTE(review): cfg.num_train_iter already includes the cfg.epoch factor and sizes
# both the data loaders and the LR schedule. Whether the calls after the first do
# additional work depends on FixMatch.train, which is not visible here — confirm
# the outer loop is intended.
print(cfg)
trainer = model.train
for epoch in range(cfg.epoch):
    print(epoch)
    trainer(cfg)

In [None]:
# Save the model after training as "latest_model.pth" under cfg.save_path.
model.save_model("latest_model.pth", cfg.save_path)