In [2]:
import os
import torch
from torch import nn
import torch.backends.cudnn as cudnn
from torch.optim import SGD
from torch.utils.data import DataLoader
from glob import glob
from tqdm import tqdm

from dataset.semi import SemiDataset
from model.semseg.deeplabv3plus import DeepLabV3Plus
from supervised import evaluate
from util.ohem import ProbOhemCrossEntropy2d
from util.utils import AverageMeter
from util.dist_helper import setup_distributed
import mlflow

In [3]:
# Experiment configuration shared by the whole notebook. The dict is also
# handed as-is to DeepLabV3Plus(...) and evaluate(...), so keys that are not
# read in this notebook are presumably consumed there -- verify before removing.
CFG = dict(
    IMG_SIZE=1024,
    crop_size=256,             # passed as `size` to every SemiDataset
    EPOCHS=50,                 # last epoch index of the training loop (1-based)
    lr=0.004,                  # base learning rate (backbone parameter group)
    lr_multi=10.0,             # LR multiplier for the non-backbone parameter group
    BATCH_SIZE=10,
    SEED=41,
    num_worker=4,              # DataLoader worker processes
    MEAN=[0.485, 0.456, 0.406],
    STD=[0.229, 0.224, 0.225],
    train_magnification="20X", # selects the train/validation sub-folder and names the MLflow run
    dataset="pathology",       # anything other than 'cityscapes' selects the 'original' eval mode
    nclass=2,
    criterion="CELoss",        # 'CELoss' or 'OHEM' (supervised loss)
    ignore_index=255,          # label value excluded from the supervised loss
    conf_thresh=0.95,          # minimum pseudo-label confidence kept in the unsupervised loss
    backbone='xception',
    dilations=[6, 12, 18],
)

In [4]:
# Glob patterns for the PNG tiles; the labeled train/validation patterns are
# specialised by CFG['train_magnification'] (e.g. ".../train/20X/**/*.png").
labeled_data_path = f"/workspace/git_ignore/PDA_labeled_tile(1024)/train/{CFG['train_magnification']}/**/*.png"
# Unlabeled tiles: PNG files inside any "*_tiles" folder.
unlabeled_data_path = f"/workspace/git_ignore/PDA_unlabeled_tile(1024)/**/*_tiles/*.png"
val_data_path = f"/workspace/git_ignore/PDA_labeled_tile(1024)/validation/{CFG['train_magnification']}/**/*.png"
# Directory where main() reads and writes 'latest.pth' / 'best.pth' checkpoints.
pth_path = "/workspace/FixMatch/pthfile"

In [5]:
# NOTE(review): glob() is called without recursive=True, so "**" in the
# patterns behaves like a single "*" (one directory level) -- confirm that
# this matches the on-disk layout.

# Labeled training tiles. The sorted listing is split by position: even
# indices are used as masks and odd indices as images. Presumably every tile
# has a mask file that sorts directly before its image -- verify the naming.
labeled_train_list = glob(labeled_data_path)
labeled_train_list.sort()
labeld_train_mask = labeled_train_list[::2]
labeld_train_img = labeled_train_list[1::2]

# Validation tiles, split with the same even/odd convention.
val_path_list = glob(val_data_path)
val_path_list.sort()
val_mask = val_path_list[::2]
val_img = val_path_list[1::2]

# Unlabeled tiles have no masks.
unlabeled_train_mask = None
unlabeled_train_img = glob(unlabeled_data_path)
unlabeled_train_img.sort()

In [6]:
# Device string used as `map_location` when restoring checkpoints:
# "cuda" when a CUDA device is available, otherwise "cpu".
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

In [None]:
def main():
    """Train DeepLabV3Plus on labeled + unlabeled tiles with DistributedDataParallel.

    Per iteration the supervised loss on a labeled batch is averaged with an
    unsupervised loss in which confident predictions on one unlabeled view
    (``img_u_w``) act as pseudo-labels for a second, CutMix-ed view
    (``img_u_s``). After every epoch the model is evaluated, metrics are
    printed and logged to MLflow from rank 0, and ``latest.pth`` (plus
    ``best.pth`` when validation mIoU improves) is written to ``pth_path``.

    If ``latest.pth`` exists, model, optimizer, best score and epoch counter
    are restored and training continues from the following epoch.

    Reads the module-level ``CFG``, ``pth_path``, ``device`` and the image /
    mask path lists. Requires the ``LOCAL_RANK`` environment variable.

    Raises:
        NotImplementedError: if ``CFG['criterion']`` is neither ``'CELoss'``
            nor ``'OHEM'``.
    """

    rank, world_size = setup_distributed(port="tcp://0.0.0.0:12345")

    cudnn.enabled = True
    cudnn.benchmark = True

    model = DeepLabV3Plus(CFG)

    # Two parameter groups: the backbone trains at the base LR, every other
    # parameter at base LR * lr_multi.
    optimizer = SGD([{'params': model.backbone.parameters(), 'lr': CFG['lr']},
                     {'params': [param for name, param in model.named_parameters() if 'backbone' not in name],
                      'lr': CFG['lr'] * CFG['lr_multi']}], lr=CFG['lr'], momentum=0.9, weight_decay=1e-4)

    local_rank = int(os.environ["LOCAL_RANK"])
    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
    model.cuda()

    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank], broadcast_buffers=False,
                                                      output_device=local_rank, find_unused_parameters=False)

    if CFG['criterion'] == 'CELoss':
        criterion_l = nn.CrossEntropyLoss(ignore_index = CFG['ignore_index']).cuda(local_rank)
    elif CFG['criterion'] == 'OHEM':
        criterion_l = ProbOhemCrossEntropy2d(ignore_index = CFG['ignore_index']).cuda(local_rank)
    else:
        # Fail early instead of hitting an unbound `criterion_l` inside the loop.
        raise NotImplementedError(f"{CFG['criterion']} criterion is not implemented")

    # Per-pixel (unreduced) loss so it can be masked by confidence below.
    criterion_u = nn.CrossEntropyLoss(reduction='none').cuda(local_rank)

    trainset_u = SemiDataset(img = unlabeled_train_img, mask = unlabeled_train_mask, mode = 'train_u', size = CFG['crop_size'])
    trainset_l = SemiDataset(img = labeld_train_img, mask = labeld_train_mask, mode = 'train_l', size = CFG['crop_size'])
    valset = SemiDataset(img = val_img, mask = val_mask, mode = 'val', size = CFG['crop_size'])

    # The config key is 'BATCH_SIZE'; reading 'batch_size' raised KeyError.
    trainsampler_l = torch.utils.data.distributed.DistributedSampler(trainset_l)
    trainloader_l = DataLoader(trainset_l, batch_size=CFG['BATCH_SIZE'],
                               pin_memory=True, num_workers=CFG['num_worker'], drop_last=True, sampler=trainsampler_l)
    trainsampler_u = torch.utils.data.distributed.DistributedSampler(trainset_u)
    trainloader_u = DataLoader(trainset_u, batch_size=CFG['BATCH_SIZE'],
                               pin_memory=True, num_workers=CFG['num_worker'], drop_last=True, sampler=trainsampler_u)
    valsampler = torch.utils.data.distributed.DistributedSampler(valset)
    valloader = DataLoader(valset, batch_size=1, pin_memory=True, num_workers=CFG['num_worker'],
                           drop_last=False, sampler=valsampler)

    total_iters = len(trainloader_u) * CFG['EPOCHS']
    previous_best = 0.0
    start_epoch = 1

    latest_path = os.path.join(pth_path, 'latest.pth')
    if os.path.exists(latest_path):
        # `map_location` is a torch.load argument; load_state_dict() does not
        # accept it.
        checkpoint = torch.load(latest_path, map_location=device)
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        resumed_epoch = checkpoint['epoch']
        previous_best = checkpoint['previous_best']
        # Continue after the last completed epoch instead of restarting at 1.
        start_epoch = resumed_epoch + 1

        if rank == 0:
            print(f'************ Load from checkpoint at epoch {resumed_epoch}')

    for epoch in range(start_epoch, CFG['EPOCHS']+1):
        if rank == 0:
            print(f'===========> Epoch: {epoch}, LR: {optimizer.param_groups[0]["lr"]:.5f}, Previous best: {previous_best:.2f}')

        total_loss  = AverageMeter()
        total_loss_x = AverageMeter()
        total_loss_s = AverageMeter()
        total_mask_ratio = AverageMeter()

        trainloader_l.sampler.set_epoch(epoch)
        trainloader_u.sampler.set_epoch(epoch)

        # The unlabeled loader is iterated twice in parallel: the second
        # stream supplies the batch that is pasted into the CutMix box.
        # zip() stops at the shortest loader.
        loader = zip(trainloader_l, trainloader_u, trainloader_u)

        for i, ((img_x, mask_x),
                (img_u_w, img_u_s, _, ignore_mask, cutmix_box, _),
                (img_u_w_mix, img_u_s_mix, _, ignore_mask_mix, _, _)) in tqdm(enumerate(loader)):

            img_x, mask_x = img_x.cuda(), mask_x.cuda()
            img_u_w, img_u_s = img_u_w.cuda(), img_u_s.cuda()
            ignore_mask, cutmix_box = ignore_mask.cuda(), cutmix_box.cuda()
            img_u_w_mix, img_u_s_mix = img_u_w_mix.cuda(), img_u_s_mix.cuda()
            ignore_mask_mix = ignore_mask_mix.cuda()

            # Pseudo-labels and confidences for the "mix" batch, computed in
            # eval mode without gradients.
            with torch.no_grad():
                model.eval()

                pred_u_w_mix = model(img_u_w_mix).detach()
                conf_u_w_mix = pred_u_w_mix.softmax(dim=1).max(dim=1)[0]
                mask_u_w_mix = pred_u_w_mix.argmax(dim=1)

            # CutMix on the image: pixels inside the box come from the mix batch.
            img_u_s[cutmix_box.unsqueeze(1).expand(img_u_s.shape) == 1] = img_u_s_mix[cutmix_box.unsqueeze(1).expand(img_u_s.shape) == 1]

            model.train()

            num_lb, num_ulb = img_x.shape[0], img_u_w.shape[0]

            # One forward pass for labeled + unlabeled images, then split.
            pred_x, pred_u_w = model(torch.cat((img_x, img_u_w))).split([num_lb, num_ulb])
            pred_u_s = model(img_u_s)

            # Pseudo-labels must not receive gradients.
            pred_u_w = pred_u_w.detach()
            conf_u_w = pred_u_w.softmax(dim=1).max(dim=1)[0]
            mask_u_w = pred_u_w.argmax(dim=1)

            mask_u_w_cutmixed, conf_u_w_cutmixed, ignore_mask_cutmixed = (mask_u_w.clone(), conf_u_w.clone(), ignore_mask.clone())

            # Apply the same CutMix box to pseudo-labels, confidences and
            # ignore masks so they stay aligned with the mixed image.
            mask_u_w_cutmixed[cutmix_box == 1] = mask_u_w_mix[cutmix_box == 1]
            conf_u_w_cutmixed[cutmix_box == 1] = conf_u_w_mix[cutmix_box == 1]
            ignore_mask_cutmixed[cutmix_box == 1] = ignore_mask_mix[cutmix_box == 1]

            loss_x = criterion_l(pred_x, mask_x)

            # Keep only confident, non-ignored pixels; normalise by the number
            # of non-ignored pixels.
            loss_u_s = criterion_u(pred_u_s, mask_u_w_cutmixed)
            loss_u_s = loss_u_s * ((conf_u_w_cutmixed >= CFG['conf_thresh']) & (ignore_mask_cutmixed != 255))
            loss_u_s = loss_u_s.sum() / (ignore_mask_cutmixed != 255).sum().item()

            loss = (loss_x + loss_u_s) / 2.0

            torch.distributed.barrier()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss.update(loss.item())
            total_loss_x.update(loss_x.item())
            total_loss_s.update(loss_u_s.item())
            # Fraction of non-ignored pixels whose confidence passes the threshold.
            mask_ratio = ((conf_u_w >= CFG['conf_thresh']) & (ignore_mask != 255)).sum().item() / (ignore_mask != 255).sum()
            total_mask_ratio.update(mask_ratio.item())

            # Polynomial LR decay (power 0.9) over the planned number of iterations.
            iters = (epoch-1) * len(trainloader_u) + i
            lr = CFG['lr'] * (1 - iters / total_iters) ** 0.9
            optimizer.param_groups[0]["lr"] = lr
            optimizer.param_groups[1]["lr"] = lr * CFG['lr_multi']

        eval_mode = 'sliding_window' if CFG['dataset'] == 'cityscapes' else 'original'
        mIoU = evaluate(model, valloader, eval_mode, CFG)

        if rank == 0:
            print(f"epoch{epoch}: Total Loss:{total_loss.avg:.2f} Loss x:{total_loss_x.avg:.2f}\
                  Loss s:{total_loss_s.avg:.2f} Mask ratio:{total_mask_ratio.avg:.2f} Val mIOU:{mIoU:.2f}")
            # NOTE(review): this opens a new MLflow run every epoch; consider a
            # single run around the whole loop. The context manager ends the
            # run on exit, so no explicit end_run() is needed.
            with mlflow.start_run(run_name=CFG["train_magnification"], experiment_id=210481695216345952):
                mlflow.log_metric('Total_Loss', total_loss.avg, step=epoch)
                mlflow.log_metric('Loss_x', total_loss_x.avg, step=epoch)
                mlflow.log_metric('Loss_s', total_loss_s.avg, step=epoch)
                mlflow.log_metric('Mask_ratio', total_mask_ratio.avg, step=epoch)
                mlflow.log_metric('Val mIOU', mIoU, step=epoch)


        is_best = mIoU > previous_best
        previous_best = max(mIoU, previous_best)
        if rank == 0:
            checkpoint = {
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'epoch': epoch,
                'previous_best': previous_best,
            }
            # torch.save does not create missing directories.
            os.makedirs(pth_path, exist_ok=True)
            torch.save(checkpoint, os.path.join(pth_path, 'latest.pth'))
            if is_best:
                torch.save(checkpoint, os.path.join(pth_path, 'best.pth'))

In [None]:
# Entry point: start training when this code runs as the top-level module.
if "__main__" == __name__:
    main()