In [2]:
#!/usr/bin/env python
import os
#os.environ['CUDA_VISIBLE_DEVICES'] = '1'

import logging
import time
from os.path import exists, join, split
from torch import nn
import torch.backends.cudnn as cudnn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
import matplotlib.pyplot as plt
try:
  from modules import batchnormsync
except ImportError:
  pass
import pdb
from tqdm import tqdm

import data_transforms as transforms
from utils import *
from Par_CRF import apply_dcrf_par
from Par_CRF import apply_dcrf_single
from Par_CRF import apply_dcrf
from Par_CRF import save_compute_crf
from DataClass import *

from sklearn.mixture import GaussianMixture
from sklearn.mixture import BayesianGaussianMixture

# --- Logging -----------------------------------------------------------------
# Each record carries: timestamp, source file:line, function name, message.
FORMAT = "[%(asctime)-15s %(filename)s:%(lineno)d %(funcName)s] %(message)s"
logging.basicConfig(format=FORMAT)

# Module-level logger; DEBUG so this module's own messages are never filtered
# at the logger level (handlers may still filter).
logger = logging.getLogger(name=__name__)
logger.setLevel("DEBUG")

# --- Reproducibility ---------------------------------------------------------
# Fix torch's global RNG seed at import time.
torch.manual_seed(0)

def validate(val_loader, model, epoch, doc_directory, args, print_freq=10):
    """Evaluate ``model`` on the validation set and append the epoch scores to a log.

    For every batch the ground-truth label is rescaled to [0, 1] (by its batch
    maximum, when non-zero), the model's first output is passed through a
    softmax over dim 1 and then either refined with ``apply_dcrf`` or reduced
    to channel 0, and the result is scored against the ground truth by a
    ``Trackometer``.

    Args:
        val_loader: iterable of ``(input, GT_label, pseudolabels, name)``
            batches; ``pseudolabels`` are only shape-checked here.
        model: network whose call returns an indexable result; element 0 is
            used as the logits.
        epoch: epoch index, used for the tracker and the log lines.
        doc_directory: path prefix for ``loss_val.txt`` (concatenated
            directly, so it is expected to end with a separator).
        args: namespace; a truthy ``args.DCRF`` enables ``apply_dcrf``, and the
            values 'Color'/'color' request its colour mode.
        print_freq: approximate number of progress log lines per pass.

    Returns:
        ``(DOC.F_GT.avg, DOC.L1_GT.avg)`` -- mean F-score and mean absolute
        error against the ground truth.
    """
    batch_time = AverageMeter()
    DOC = Trackometer(epoch)

    # switch to evaluate mode
    model.eval()

    # Loop invariants, computed once instead of once per batch.
    softmax = torch.nn.Softmax(dim=1)
    n_batches = len(val_loader)
    log_every = max(n_batches // print_freq, 1)
    use_color_crf = args.DCRF == 'Color' or args.DCRF == 'color'

    end = time.time()
    # Evaluation never back-propagates; without no_grad every forward pass
    # would build (and keep alive) an autograd graph.
    with torch.no_grad():
        for i, (image, GT_label, pseudolabels, name) in tqdm(enumerate(val_loader)):
            for target in pseudolabels:
                assert target.shape == GT_label.shape

            # make target float and normalize to range [0,1] for each pixel
            GT_label = GT_label.float()
            gt_max = GT_label.max().item()
            if gt_max != 0:
                GT_label = GT_label / gt_max

            image_var = image.cuda()
            GT_label_var = GT_label.cuda()

            # compute output and normalize it; optionally refine with a dense CRF
            sal_pred = softmax(model(image_var)[0])
            if args.DCRF:
                sal_pred = apply_dcrf(sal_pred, name, Color=use_color_crf)
            else:
                sal_pred = sal_pred[:, 0, :, :]

            DOC.update(sal_pred, GT_label_var, [], [])

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % log_every == 0:
                logger.info('Test: [{0}/{1}]\t'
                            'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                            'MAE GT {MAE_GT.val:.4f} ({MAE_GT.avg:.4f})\t'
                            'F-Score GT {F_GT.val:.3f} ({F_GT.avg:.3f})'
                            .format(i, n_batches, batch_time=batch_time, MAE_GT=DOC.L1_GT, F_GT=DOC.F_GT))

    logger.info('\n\nValidation Epoch {}:\t\tMAE (GT) = {:.1f} %\t\tF-score (GT) = {:.1f} %\n'.format(epoch, DOC.L1_GT.avg*100, DOC.F_GT.avg*100))

    with open(doc_directory + "loss_val.txt", "a") as f:
        f.write('{}\t{}\t{}\n'.format(epoch, DOC.L1_GT.avg, DOC.F_GT.avg))

    return DOC.F_GT.avg, DOC.L1_GT.avg

def _clean_probability_from_losses(all_losses):
    """Fit a two-component GMM to per-sample losses; return P(lower-mean component).

    ``all_losses`` is the 1-D CPU tensor built by ``eval_train``. It is min-max
    scaled to [0, 1] first. When every loss is identical that scaling would
    divide by zero and hand NaNs to sklearn (which rejects them), so the scaled
    losses are all set to 0 instead. NaN losses still propagate and make
    sklearn raise, as before.
    """
    losses = all_losses.detach().cpu().reshape(-1, 1)
    low, high = losses.min(), losses.max()
    spread = high - low
    if spread == 0:
        losses = torch.zeros_like(losses)
    else:
        losses = (losses - low) / spread
    losses = losses.numpy()

    # fit a two-component GMM to the loss
    gmm = GaussianMixture(n_components=2, max_iter=10, tol=1e-2, reg_covar=5e-4)
    gmm.fit(losses)
    prob = gmm.predict_proba(losses)
    # The component with the smaller mean loss is treated as the "clean" one.
    return prob[:, gmm.means_.argmin()]


def eval_train(train_loader, model, epoch, doc_directory, args, discretization_threshold, refined_labels_directory=None, iter_size=5,
           print_freq=10, TrainMapsOut=False,mva_preds=None,image2indx=None):
    """Score every training sample with ``model`` and estimate which labels are clean.

    The model runs in eval mode under ``torch.no_grad()``. For each batch the
    pseudolabels are discretized with ``discretization_threshold`` and a
    per-sample loss is computed (``mean_loss=False``). These losses are
    collected by dataset index and modelled with a two-component Gaussian
    mixture.

    Args:
        train_loader: yields ``(index, Data)``; ``index`` holds dataset indices.
        model: network evaluated through ``BatchData.compute_saliency``.
        epoch: epoch number used for the trackers and log lines.
        doc_directory: path prefix for ``loss_eval_train.txt`` (concatenated
            directly, so it is expected to end with a separator).
        args: namespace; ``args.beta_sq`` is forwarded to the loss.
        discretization_threshold: forwarded to ``discretize_pseudolabels``.
        refined_labels_directory: output root for exported maps; required when
            ``TrainMapsOut`` is True.
        iter_size: unused here; kept so the signature matches ``warmup``.
        print_freq: approximate number of progress log lines per pass.
        TrainMapsOut: if True, also export plain/CRF/MVA maps through
            ``save_compute_crf`` and write their result histories.
        mva_preds: only read when ``TrainMapsOut``; presumably an (N, H, W)
            tensor of running-average maps -- TODO confirm.
        image2indx: maps a batch of image names to dataset rows (used only
            when ``TrainMapsOut``).

    Returns:
        numpy array holding, per dataset sample, the posterior probability of
        the mixture component with the smaller mean loss.
    """
    if TrainMapsOut:
        # Fail fast instead of after a full pass over the training set.
        assert refined_labels_directory is not None, 'Directory for output of refined Maps needs to be specified'

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()

    DOC = Trackometer(epoch)
    if TrainMapsOut:
        DOC_plain = Trackometer(epoch)
        DOC_CRF = Trackometer(epoch)
        DOC_MVA = Trackometer(epoch)
        # Dataset-sized buffers for the map export. They are only needed (and
        # only allocated) here, so mva_preds may stay None otherwise.
        shape = mva_preds.shape
        raw_preds = torch.zeros((shape[0], 2, shape[1], shape[2]))
        gt_targets = torch.zeros(shape)
        pseudo_targets = torch.zeros(shape)

    # inference only: switch to evaluate mode
    model.eval()

    softmax = torch.nn.Softmax(dim=1)
    n_batches = len(train_loader)
    log_every = max(n_batches // print_freq, 1)
    # NOTE(review): samples the loader never yields (e.g. because of
    # drop_last=True) keep a loss of 0 here, so they look "clean" to the GMM.
    all_losses = torch.zeros(len(train_loader.dataset))

    end = time.time()
    with torch.no_grad():
        for i, (index, Data) in tqdm(enumerate(train_loader)):
            # measure data loading time
            data_time.update(time.time() - end)

            # Wrap the batch, check label sizes, scale labels to [0,1], move to GPU.
            batch_data = BatchData(Data, active=True)
            batch_data.check_dimension()
            batch_data.normalize_labels()
            batch_data.create_vars_on_cuda()

            # Forward pass; the second argument is False, as in the original call
            # (the original comment says it controls an optional threshold).
            batch_data.compute_saliency(model, False)

            if TrainMapsOut:
                # Check before pseudolabels[0] is read, not after.
                assert len(batch_data.pseudolabels) == 1, 'Only one map should be refined at a time in order to not lose information'
                rows = image2indx(batch_data.names)
                gt_targets[rows] = batch_data.GT_label
                pseudo_targets[rows] = batch_data.pseudolabels[0]
                raw_preds[rows] = softmax(batch_data.output).detach().cpu()

            # Discretize all pseudolabels, then compute one loss per sample.
            batch_data.discretize_pseudolabels(discretization_threshold)
            batch_data.compute_loss(mean_loss=False, beta=args.beta_sq)
            loss = torch.mean(batch_data.loss)

            for b, sample_idx in enumerate(index):
                all_losses[sample_idx] = batch_data.loss[b]

            DOC.update(batch_data.sal_pred, batch_data.GT_label_var, batch_data.sal_pred_list, batch_data.pseudolabels_var)
            # `losses` duplicates DOC's loss tracking; kept for the progress log.
            losses.update(loss.item(), batch_data.input.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % log_every == 0:
                logger.info('Epoch: [{0}][{1}/{2}]\t'
                             'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                             'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                             'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                             'L1 Loss GT {loss_L1_GT.val:.4f} ({loss_L1_GT.avg:.4f})\t'
                             .format(epoch,
                                     i,
                                     n_batches,
                                     batch_time=batch_time,
                                     data_time=data_time,
                                     loss=losses,
                                     loss_L1_GT=DOC.L1_GT,
                                     ))

    prob = _clean_probability_from_losses(all_losses)

    if TrainMapsOut:
        # Create Output Directories
        path_train = refined_labels_directory
        path_plain = join(path_train, 'PlainMaps/')
        path_CRF = join(path_train, 'CRFMaps/')
        path_MVA = join(path_train, 'MVAMaps/')
        for path in [path_train, path_plain, path_CRF, path_MVA]:
            os.makedirs(path, exist_ok=True)
        name = train_loader.dataset.image_list  # the order is kept correctly, 0...2499

        save_compute_crf(path_plain, path_CRF, path_MVA,
                         name, gt_targets, pseudo_targets, raw_preds, mva_preds,
                         image2indx,
                         DOC_plain, DOC_CRF, DOC_MVA,
                         args)

        assert mva_preds.sum() != 0, 'mva_preds was not updated!?'

        logger.info('\n\n\nTraining Maps Extracted in this epoch {}. Results:\n\nPlain:{}\nCRF:{}\nMVA:{}'
                    .format(epoch, str(DOC_plain), str(DOC_CRF), str(DOC_MVA)))

        DOC_plain.write_history(refined_labels_directory + "Results_plain.txt")
        DOC_CRF.write_history(refined_labels_directory + "Results_CRF.txt")
        DOC_MVA.write_history(refined_labels_directory + "Results_MVA.txt")
    else:
        DOC.write_history(doc_directory + "loss_eval_train.txt")

    return prob


def warmup(train_loader, model,  optimizer, epoch, doc_directory, args, discretization_threshold, refined_labels_directory=None, iter_size=5, print_freq=10, TrainMapsOut=False,mva_preds=None,image2indx=None):
    """Train ``model`` for one epoch on ``train_loader`` with gradient accumulation.

    Gradients of ``iter_size`` consecutive batches are summed before each
    ``optimizer.step()``. A trailing partial group at the end of the epoch is
    applied as well, instead of being silently discarded.

    Args:
        train_loader: yields ``(index, Data)`` batches wrapped by ``BatchData``.
        model: network to train (switched to train mode).
        optimizer: optimizer over ``model``'s parameters.
        epoch: epoch number used for the trackers and log lines.
        doc_directory: path prefix for ``loss_train.txt`` (concatenated
            directly, so it is expected to end with a separator).
        args: namespace; ``args.beta_sq`` is forwarded to the loss.
        discretization_threshold: forwarded to ``discretize_pseudolabels``.
        refined_labels_directory: output root for exported maps; required when
            ``TrainMapsOut`` is True.
        iter_size: number of batches whose gradients are accumulated per step.
        print_freq: approximate number of progress log lines per epoch.
        TrainMapsOut: if True, record pre-step predictions and export
            plain/CRF/MVA maps through ``save_compute_crf``.
        mva_preds: only read when ``TrainMapsOut``; presumably an (N, H, W)
            tensor of running-average maps -- TODO confirm.
        image2indx: maps a batch of image names to dataset rows (used only
            when ``TrainMapsOut``).

    Returns:
        ``(losses.avg, mva_preds)``. ``mva_preds`` is the object passed in; the
        "was not updated" assertion suggests ``save_compute_crf`` fills it in
        place.
    """
    if TrainMapsOut:
        # Fail fast instead of after a full training epoch.
        assert refined_labels_directory is not None, 'Directory for output of refined Maps needs to be specified'

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()

    DOC = Trackometer(epoch)
    if TrainMapsOut:
        DOC_plain = Trackometer(epoch)
        DOC_CRF = Trackometer(epoch)
        DOC_MVA = Trackometer(epoch)
        # Dataset-sized buffers for the map export. They are only needed (and
        # only allocated) here, so mva_preds may stay None otherwise.
        shape = mva_preds.shape
        raw_preds = torch.zeros((shape[0], 2, shape[1], shape[2]))
        gt_targets = torch.zeros(shape)
        pseudo_targets = torch.zeros(shape)

    # switch to train mode
    model.train()

    softmax = torch.nn.Softmax(dim=1)
    n_batches = len(train_loader)
    log_every = max(n_batches // print_freq, 1)

    batches_seen = 0
    end = time.time()
    for i, (index, Data) in tqdm(enumerate(train_loader)):
        batches_seen = i + 1
        # measure data loading time
        data_time.update(time.time() - end)

        # Wrap the batch, check label sizes, scale labels to [0,1], move to GPU.
        batch_data = BatchData(Data, active=True)
        batch_data.check_dimension()
        batch_data.normalize_labels()
        batch_data.create_vars_on_cuda()

        # Forward pass; the second argument is False, as in the original call
        # (the original comment says it controls an optional threshold).
        batch_data.compute_saliency(model, False)

        # Record predictions before the optimizer step changes the weights.
        if TrainMapsOut:
            # Check before pseudolabels[0] is read, not after.
            assert len(batch_data.pseudolabels) == 1, 'Only one map should be refined at a time in order to not lose information'
            rows = image2indx(batch_data.names)
            gt_targets[rows] = batch_data.GT_label
            pseudo_targets[rows] = batch_data.pseudolabels[0]
            raw_preds[rows] = softmax(batch_data.output).detach().cpu()

        # Discretize all pseudolabels, then compute the loss.
        batch_data.discretize_pseudolabels(discretization_threshold)
        batch_data.compute_loss(beta=args.beta_sq)
        loss = batch_data.loss

        # Accumulate gradients over iter_size batches before each step.
        if i % iter_size == 0:
            optimizer.zero_grad()
        loss.backward()
        if i % iter_size == iter_size - 1:
            optimizer.step()

        DOC.update(batch_data.sal_pred, batch_data.GT_label_var, batch_data.sal_pred_list, batch_data.pseudolabels_var)
        # `losses` duplicates DOC's loss tracking; kept for the progress log.
        losses.update(loss.item(), batch_data.input.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % log_every == 0:
            logger.info('Epoch: [{0}][{1}/{2}]\t'
                         'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                         'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                         'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                         'L1 Loss GT {loss_L1_GT.val:.4f} ({loss_L1_GT.avg:.4f})\t'
                         .format(epoch,
                                 i,
                                 n_batches,
                                 batch_time=batch_time,
                                 data_time=data_time,
                                 loss=losses,
                                 loss_L1_GT=DOC.L1_GT,
                                 ))

    # If the batch count is not a multiple of iter_size, the last group was
    # back-propagated but never stepped; the next zero_grad() would discard it.
    if batches_seen % iter_size != 0:
        optimizer.step()

    if TrainMapsOut:
        # Create Output Directories
        path_train = refined_labels_directory
        path_plain = join(path_train, 'PlainMaps/')
        path_CRF = join(path_train, 'CRFMaps/')
        path_MVA = join(path_train, 'MVAMaps/')
        for path in [path_train, path_plain, path_CRF, path_MVA]:
            os.makedirs(path, exist_ok=True)
        name = train_loader.dataset.image_list  # the order is kept correctly, 0...2499

        save_compute_crf(path_plain, path_CRF, path_MVA,
                         name, gt_targets, pseudo_targets, raw_preds, mva_preds,
                         image2indx,
                         DOC_plain, DOC_CRF, DOC_MVA,
                         args)

        assert mva_preds.sum() != 0, 'mva_preds was not updated!?'

        logger.info('\n\n\nTraining Maps Extracted in this epoch {}. Results:\n\nPlain:{}\nCRF:{}\nMVA:{}'
                    .format(epoch, str(DOC_plain), str(DOC_CRF), str(DOC_MVA)))

        DOC_plain.write_history(refined_labels_directory + "Results_plain.txt")
        DOC_CRF.write_history(refined_labels_directory + "Results_CRF.txt")
        DOC_MVA.write_history(refined_labels_directory + "Results_MVA.txt")
    else:
        DOC.write_history(doc_directory + "loss_train.txt")

    return losses.avg, mva_preds

def _load_matching_weights(model, checkpoint_path):
    """Copy every tensor in ``checkpoint_path`` whose name and size match into ``model``.

    Names missing from the model trigger a warning. Size mismatches are
    printed, and the model keeps its own initialization for that tensor.
    """
    import warnings

    load_dict = torch.load(checkpoint_path)
    own_dict = model.state_dict()
    for name, param in load_dict.items():
        if name not in own_dict:
            warnings.warn(' Model could not be loaded ! Thats bad ! ')
            continue
        if own_dict[name].size() != param.size():
            print('Size of pretrained model and your model does not match in {} ({} vs. {}). Layer stays initialized randomly.'
                  .format(name, own_dict[name].size(), param.size()))
        else:
            # state_dict() tensors share storage with the model, so this
            # in-place copy updates the model itself.
            own_dict[name].copy_(param)


def _resume_from_checkpoint(checkpoint_path, model, optimizer):
    """Restore model and optimizer state; return the stored epoch, or None if the file is missing."""
    if not os.path.isfile(checkpoint_path):
        print("=> no checkpoint found at '{}'".format(checkpoint_path))
        return None
    print("=> loading checkpoint '{}'".format(checkpoint_path))
    checkpoint = torch.load(checkpoint_path)
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    print("=> loaded checkpoint '{}' (epoch {})"
          .format(checkpoint_path, checkpoint['epoch']))
    return checkpoint['epoch']


def _save_checkpoint(path, epoch, arch, model, optimizer):
    """Write the epoch, architecture name, model state and optimizer state to ``path``."""
    torch.save({
        'epoch': epoch,
        'arch': arch,
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
    }, path)


def _seglist_loader(args, data_dir, split, transform, shuffle, targets, **seglist_kwargs):
    """Build a DataLoader over ``SegList`` with image/GT directories under ``args.root_dir``."""
    dataset = SegList(args, data_dir, split, transform,
                      image_dir=join(args.root_dir, 'Data/01_img/'),
                      gt_dir=join(args.root_dir, 'Data/02_gt/'),
                      targets=targets, list_dir=args.data_dir, out_name=True,
                      **seglist_kwargs)
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size, shuffle=shuffle, num_workers=args.workers,
        pin_memory=True, drop_last=True,
    )


def _check_no_dropped_images(mva_preds):
    """Assert that no per-image map in ``mva_preds`` sums to NaN."""
    # The message attributes NaN rows to images the loader dropped.
    assert torch.isnan(mva_preds.sum(dim=(1, 2))).sum().item() == 0, 'images are droped since size of data set is not a multiple of batch size'


def train_round(args, target_dirs, output_dir_it, discretization_threshold, MapsOut = False):
    """Co-train two networks for one refinement round and checkpoint them.

    Each epoch, each network is trained with ``train`` on a 'labeled' loader
    built from its peer's selection (``pred``/``prob``). Unless ``MapsOut``,
    both networks are then validated. Checkpoints are written every
    ``args.checkpoint_freq`` epochs and after the final epoch.

    Log records from ``logger`` go to ``<output_dir_it>/log.txt`` for the
    duration of this call only. The handler is removed afterwards, so later
    rounds do not also write into this round's log.

    Args:
        args: experiment namespace (architecture, data paths, optimizer and
            schedule settings, resume/pretrained paths, ...).
        target_dirs: pseudolabel directories forwarded to ``SegList`` as ``targets``.
        output_dir_it: output directory for this round (concatenated directly
            with file names, so it is expected to end with a separator).
        discretization_threshold: forwarded to the training routines.
        MapsOut: if True, export refined training maps and skip validation.

    Returns:
        The loss returned by the last training call, or None if no epoch ran.
    """
    log_handler = logging.FileHandler(output_dir_it+'/log.txt')
    logger.addHandler(log_handler)
    try:
        return _train_round_impl(args, target_dirs, output_dir_it, discretization_threshold, MapsOut)
    finally:
        logger.removeHandler(log_handler)
        log_handler.close()


def _train_round_impl(args, target_dirs, output_dir_it, discretization_threshold, MapsOut):
    """Body of ``train_round``; runs while the per-round log handler is attached."""
    import json

    crop_size = args.crop_size
    iter_size_train = args.iter_size

    with open(output_dir_it + "params.txt", "w") as f:
        f.write("Parameters:\n\n")
        for k, v in args.__dict__.items():
            f.write("{}:\t\t{}\n".format(k, v))

    single_model1 = DRNSeg(args.arch, 2, None, pretrained=True)
    single_model2 = DRNSeg(args.arch, 2, None, pretrained=True)

    # load pretrained model for layers that match in size.
    if args.pretrained:
        print('Loading model 1 state dict\n')
        _load_matching_weights(single_model1, args.pretrained)
        print('Loading model 2 state dict\n')
        _load_matching_weights(single_model2, args.pretrained)
        print('\n')

    model1 = torch.nn.DataParallel(single_model1.cuda())
    model2 = torch.nn.DataParallel(single_model2.cuda())

    # Data loading code
    data_dir = args.data_dir
    with open(join(data_dir, 'info.json'), 'r') as info_file:
        info = json.load(info_file)
    normalize = transforms.Normalize(mean=info['mean'], std=info['std'])
    t = [transforms.Resize_Image(crop_size), transforms.ToTensor(), normalize]

    train_loader = _seglist_loader(args, data_dir, 'train', transforms.Compose(t),
                                   shuffle=True, targets=target_dirs)
    # Built once and reused every epoch (it was previously rebuilt per epoch).
    val_loader = None
    if not MapsOut:
        val_loader = _seglist_loader(args, data_dir, 'val', transforms.Compose(t),
                                     shuffle=False, targets=None)

    optimizer1 = torch.optim.Adam(single_model1.optim_parameters(), lr=args.lr)
    optimizer2 = torch.optim.Adam(single_model2.optim_parameters(), lr=args.lr)

    cudnn.benchmark = True

    # optionally resume from a checkpoint; model 2's epoch wins if both exist
    start_epoch = 0
    for resume_path, model, optimizer in ((args.resume_model1, model1, optimizer1),
                                          (args.resume_model2, model2, optimizer2)):
        if resume_path:
            resumed_epoch = _resume_from_checkpoint(resume_path, model, optimizer)
            if resumed_epoch is not None:
                start_epoch = resumed_epoch

    mva_preds, image2indx = init_mva_preds(args, train_loader)
    n_train = len(train_loader.dataset)
    trainloss = None  # returned as-is when no epoch runs
    in_warmup = False  # warm-up is currently disabled; was: epoch < args.warm_up

    for epoch in range(start_epoch, args.epochs):
        lr = adjust_learning_rate(args, optimizer1, epoch)
        lr = adjust_learning_rate(args, optimizer2, epoch)
        logger.info('Epoch: [{0}]\tlr {1:.2e}'.format(epoch, lr))

        # Keyword arguments shared by every warmup()/train() call.
        common = dict(refined_labels_directory=output_dir_it,
                      iter_size=iter_size_train,
                      TrainMapsOut=MapsOut,
                      image2indx=image2indx)

        if in_warmup:
            for label, net, optimizer in (('Net1', model1, optimizer1),
                                          ('Net2', model2, optimizer2)):
                print('Warmup ' + label)
                trainloss, mva_preds = warmup(train_loader, net, optimizer, epoch,
                                              output_dir_it, args, discretization_threshold,
                                              print_freq=2, mva_preds=mva_preds, **common)
                _check_no_dropped_images(mva_preds)
            continue

        # NOTE: GMM-based sample selection is disabled; every training sample is
        # treated as clean (probability 1). To re-enable, replace the two ones()
        # below with:
        #   prob1 = eval_train(train_loader, model1, epoch, output_dir_it, args,
        #                      discretization_threshold, print_freq=2,
        #                      mva_preds=mva_preds, **common)
        #   prob2 = eval_train(train_loader, model2, ...same arguments...)
        prob1 = torch.ones(n_train)
        prob2 = torch.ones(n_train)
        pred1 = prob1 > args.p_threshold
        pred2 = prob2 > args.p_threshold

        # Each net is trained on the samples selected by its peer.
        for label, net, peer, optimizer, pred, prob in (
                ('Net1', model1, model2, optimizer1, pred2, prob2),
                ('Net2', model2, model1, optimizer2, pred1, prob1)):
            labeled_train_loader = _seglist_loader(args, data_dir, 'labeled', transforms.Compose(t),
                                                   shuffle=True, targets=target_dirs,
                                                   pred=pred, prob=prob)
            print('Training ' + label)
            trainloss, mva_preds = train(labeled_train_loader,
                                         train_loader,
                                         net,
                                         peer,
                                         optimizer,
                                         epoch,
                                         output_dir_it,
                                         args,
                                         discretization_threshold,
                                         print_freq=3,
                                         mva_preds=mva_preds,
                                         **common)
            _check_no_dropped_images(mva_preds)

        if not MapsOut:
            # evaluate on validation set
            print('Validation Net1')
            validate(val_loader, model1, epoch, output_dir_it, args, print_freq=6)
            print('Validation Net2')
            validate(val_loader, model2, epoch, output_dir_it, args, print_freq=6)

        # Save every checkpoint_freq epochs and always after the last epoch
        # (the old test `epoch == args.epochs` could never be true here).
        if (epoch + 1) % args.checkpoint_freq == 0 or epoch + 1 == args.epochs:
            _save_checkpoint(output_dir_it + 'checkpoint_model1_{:03d}.pth.tar'.format(epoch + 1),
                             epoch + 1, args.arch, model1, optimizer1)
            _save_checkpoint(output_dir_it + 'checkpoint_model2_{:03d}.pth.tar'.format(epoch + 1),
                             epoch + 1, args.arch, model2, optimizer2)

    return trainloss

def test(args, eval_data_loader, model, num_classes,
          output_dir='pred', save_vis=False):
    """Evaluate a saliency network and record per-batch metrics.

    Each image batch is reflection-padded on the bottom/right so that both
    spatial sizes are multiples of 8, passed through ``model`` and cropped back
    to the original size. Channel 0 of the softmax over dim 1 (or, when
    ``args.DCRF`` is set, the output of ``apply_dcrf`` on the full softmax) is
    compared against the ground truth with a ``Trackometer``. For every batch
    the current L1, F-measure, precision and recall values are written as one
    tab-separated line to ``<output_dir>/Results.txt``.

    Args:
        args: run configuration; only ``args.DCRF`` is read here.
        eval_data_loader: yields ``(image, GT_label, pseudolabels, name)``.
        model: network; the first element of its output is used as scores.
        num_classes: unused; kept for interface compatibility.
        output_dir: directory for ``Results.txt`` and saved visualisations.
        save_vis: if true, save predicted and ground-truth maps as images.

    Returns:
        The ``Trackometer`` holding the accumulated metrics.
    """
    with torch.no_grad():
        model.eval()
        batch_time = AverageMeter()
        data_time = AverageMeter()
        DOC = Trackometer(0)
        use_color_crf = args.DCRF in ('Color', 'color')

        # join() also works when output_dir has no trailing slash (plain string
        # concatenation turned the default 'pred' into 'predResults.txt').
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)
        results_path = join(output_dir, 'Results.txt')

        end = time.time()
        with open(results_path, 'w') as f:
            for i, (image, GT_label, pseudolabels, name) in enumerate(eval_data_loader):
                data_time.update(time.time() - end)

                # image name without its directory, with a .png extension
                imname = [(path.split('/')[-1])[:-4] + '.png' for path in name]

                # Scale GT to [0, 1]; an all-zero mask stays zero instead of 0/0 = NaN.
                GT_label = GT_label.float()
                gt_max = torch.max(GT_label).item()
                if gt_max > 0:
                    GT_label = GT_label / gt_max

                # Reflection-pad bottom/right so both spatial dims are multiples of 8;
                # the padded border is cropped off the prediction again below.
                n = 8
                w0, h0 = image.shape[2], image.shape[3]
                dw, dh = -w0 % n, -h0 % n
                im_new = nn.ReflectionPad2d((0, dh, 0, dw))(image)
                assert torch.all(torch.eq(image, im_new[:, :, :w0, :h0]))

                image_var = Variable(im_new, requires_grad=False)
                final = model(image_var)[0]

                sal_pred = torch.softmax(final, dim=1)
                if args.DCRF:
                    sal_pred = apply_dcrf(sal_pred[:, :, :w0, :h0], name, Color=use_color_crf)
                else:
                    sal_pred = sal_pred[:, 0, :w0, :h0]

                assert sal_pred.shape == GT_label.shape

                gt_cuda = GT_label.cuda()
                DOC.update(sal_pred, gt_cuda, [sal_pred], [gt_cuda])

                if save_vis:
                    # maps are stored as 8-bit intensities
                    save_output_images((sal_pred * 255).int().cpu().data.numpy(), imname, output_dir)
                    save_output_images((GT_label * 255).int().cpu().data.numpy(), imname, output_dir,
                                       name_suffix='_GT')

                batch_time.update(time.time() - end)
                end = time.time()
                if i % 50 == 0:
                    logger.info('Eval: [{0}/{1}]\t'
                                'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                                'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                                'L1 {L1.val:.3f} ({L1.avg:.3f})\t'
                                'F-measure {F.val:.3f} ({F.avg:.3f})\t'
                                .format(i, len(eval_data_loader), batch_time=batch_time, data_time=data_time, L1=DOC.L1_GT, F=DOC.F_GT))

                f.write('{}\t{}\t{}\t{}\n'.format(DOC.L1_GT.val, DOC.F_GT.val, DOC.prec_GT.val, DOC.recall_GT.val))

        print(DOC)

        return DOC


def test_saliency(args):
    """Load a trained saliency network and evaluate it on the test split.

    Builds a two-class ``DRNSeg`` model wrapped in ``DataParallel`` on CUDA,
    reads the normalization statistics from ``<args.data_dir>/info.json``,
    loads the checkpoint ``args.resume`` (defaulting to
    ``Doc/Phase_II_Fusion/checkpoint_200.pth.tar`` under ``args.root_dir``;
    the chosen path is written back to ``args.resume``) and runs :func:`test`,
    saving visualisations and ``Results.txt`` to ``Doc/Test/``.

    Returns:
        The ``Trackometer`` returned by :func:`test`.
    """
    batch_size = args.batch_size
    num_workers = args.workers
    test_dir = join(args.root_dir, 'Doc/Test/')

    for k, v in args.__dict__.items():
        print(k, ':', v)

    single_model = DRNSeg(args.arch, 2, pretrained_model=None, pretrained=True)
    model = torch.nn.DataParallel(single_model).cuda()

    data_dir = args.data_dir
    with open(join(data_dir, 'info.json'), 'r') as info_file:
        info = json.load(info_file)
    normalize = transforms.Normalize(mean=info['mean'], std=info['std'])

    dataset = SegList_test(args, data_dir, 'test', transforms.Compose([
        transforms.Resize_Image(args.crop_size),
        transforms.ToTensor(),
        normalize,
    ]), image_dir=join(args.root_dir, 'Data/01_img/'), gt_dir=join(args.root_dir, 'Data/02_gt/'),
        list_dir=args.data_dir, out_name=True)
    test_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size, shuffle=False, num_workers=num_workers,
        pin_memory=False
    )

    cudnn.benchmark = True

    # getattr: the namespace built in main() has no `resume` field at all,
    # so plain attribute access raised AttributeError.
    if not getattr(args, 'resume', None):
        args.resume = join(args.root_dir, 'Doc/Phase_II_Fusion/checkpoint_200.pth.tar')

    if os.path.isfile(args.resume):
        logger.info("=> loading checkpoint '{}'".format(args.resume))
        checkpoint = torch.load(args.resume)
        model.load_state_dict(checkpoint['state_dict'])
        logger.info("=> loaded checkpoint '{}' (epoch {})"
                    .format(args.resume, checkpoint['epoch']))
    else:
        # Evaluation continues with the freshly built weights; make that visible.
        logger.warning("=> no checkpoint found at '{}'".format(args.resume))

    os.makedirs(test_dir, exist_ok=True)

    DOC = test(args, test_loader, model, 2, save_vis=True, output_dir=test_dir)

    logger.info('MAE = %f', DOC.L1_GT.avg)

    return DOC



def test_all_datasets(args):
    """Evaluate the current checkpoint on every enabled benchmark and tabulate the results.

    Each dataset is described by its name, its Parameters directory, the
    prefix of its ``*_names.txt`` list and a test batch size. For each enabled
    dataset, ``args.dataset_name``, ``args.data_dir``, ``args.test_data`` and
    ``args.batch_size`` are overwritten before :func:`test_saliency` is called.
    The MAE / F / precision / recall averages are printed and appended to
    ``Doc/Test_all/Test_Results.txt`` under ``args.root_dir``.
    """
    test_dir = join(args.root_dir, 'Doc/Test_all/')
    os.makedirs(test_dir, exist_ok=True)

    # Enable/disable individual benchmarks here.
    MSRAB = True
    ECSSD = True
    DUT = True
    SED2 = True
    THUR = False

    # getattr: the namespace built in main() has no `resume` field at all.
    if not getattr(args, 'resume', None):
        args.resume = join(args.root_dir, 'Doc/Phase_II_Fusion/checkpoint_latest.pth.tar')

    # (enabled, name, Parameters directory, data prefix)
    # NOTE(review): the MSRA-B path starts with '/notebookd/' while main() uses
    # '/notebooks/'; possibly a typo - confirm before relying on it.
    candidates = [
        (MSRAB, '01_MSRAB', '/notebookd/DeepUSPS/Data/01_MSRAB/Parameters/', 'test'),
        (ECSSD, '02_ECSSD', '/media/bigData/_80_User/Dax/UnsupSD/SD_beta/Data/02_ECSSD/Parameters/', 'all'),
        (DUT, '03_DUT', '/media/bigData/_80_User/Dax/UnsupSD/SD_beta/Data/03_DUT/Parameters/', 'all'),
        (SED2, '04_SED2', '/media/bigData/_80_User/Dax/UnsupSD/SD_beta/Data/04_SED2/Parameters/', 'all'),
        (THUR, '06_THUR', '/media/bigData/_80_User/Dax/UnsupSD/SD_beta/Data/06_THUR/Parameters/', 'GT'),
    ]
    # Named to avoid shadowing the module-level `datasets` import.
    dataset_configs = [
        {'name': name, 'param_dir': param_dir, 'data_prefix': prefix, 'batch_size': 1}
        for enabled, name, param_dir, prefix in candidates
        if enabled
    ]

    # Iterate through the configurations and test each dataset.
    for config in dataset_configs:
        args.dataset_name = config['name']
        args.data_dir = config['param_dir']
        args.test_data = config['data_prefix']
        args.batch_size = config['batch_size']
        config['Result'] = test_saliency(args)

    header = "\t\t\tMAE\t\tF\t\tprecision\trecall"
    print("\n\n" + header)
    for config in dataset_configs:
        print("{name}: \t\t{DOC.L1_GT.avg:.3f}\t\t{DOC.F_GT.avg:.3f}\t\t{DOC.prec_GT.avg:.3f}\t\t{DOC.recall_GT.avg:.3f}"
              .format(name=config['name'], DOC=config['Result']))

    result_file = join(test_dir, 'Test_Results.txt')
    with open(result_file, 'a') as f:
        f.write(header + "\n")
        for config in dataset_configs:
            f.write("{name}: \t\t{DOC.L1_GT.avg:.5f}\t\t{DOC.F_GT.avg:.5f}\t\t{DOC.prec_GT.avg:.5f}\t\t{DOC.recall_GT.avg:.5f}\n"
                    .format(name=config['name'], DOC=config['Result']))




def train_unsupervised(args):
    """Run the two-phase unsupervised saliency pipeline.

    Phase I (refinement of pseudolabels): for every handcrafted pseudolabel
    source (MC, HS, DSR, RBD) one directory per refinement iteration is
    created under ``Doc/Phase_I_Refined_Maps/<name>/``. The ``train_round``
    call for this phase is commented out, so the phase only prepares
    directories and selects the ``MVAMaps/`` folder of the last iteration as
    that source's refined label directory.

    Phase II (fusion of refined pseudolabels): ``train_round`` is run on all
    refined label directories together, then the Phase II plots are drawn.
    ``args`` is mutated (``beta_sq``, ``epochs``, ``iter_size``, ``lr``).
    """
    # ---------------- Phase I: refinement of pseudolabels ----------------
    lr_schedule = [1e-6, 2e-6]  # a third step of 5e-6 was used previously
    args.beta_sq = 1.0
    args.epochs = 25
    # NOTE(review): min(1, ...) yields 1 for batch_size <= 40 and 0 above it;
    # if an effective batch of ~40 via accumulation was intended, max() was
    # probably meant - confirm against train_round.
    args.iter_size = min(1, int(40/args.batch_size))

    doc_directory = join(args.root_dir, 'Doc/')
    refined_labels_directory = join(doc_directory, 'Phase_I_Refined_Maps/')
    for directory in (doc_directory, refined_labels_directory):
        os.makedirs(directory, exist_ok=True)

    # name, data sub-directory, discretization threshold, F-score, MAE
    pseudolabel_specs = [
        ('MC', 'Data/03_mc/', 0.31, 71.65, 14.41),
        ('HS', 'Data/04_hs/', 0.36, 71.29, 16.09),
        ('DSR', 'Data/05_dsr/', 0.23, 72.27, 12.07),
        ('RBD', 'Data/06_rbd/', 0.25, 75.08, 11.71),
    ]
    pseudolabels = [
        {'name': label_name, 'data_directory': join(args.root_dir, sub_dir),
         'discretization_threshold': threshold,
         'F-score_plain': [f_score], 'MAE_plain': [mae],
         'F-score_mva': [f_score], 'MAE_mva': [mae]}
        for label_name, sub_dir, threshold, f_score, mae in pseudolabel_specs
    ]

    target_dirs_refined = []
    for label_source in pseudolabels:
        source_out_dir = join(refined_labels_directory, label_source['name'] + '/')
        os.makedirs(source_out_dir, exist_ok=True)
        current_targets = [label_source['data_directory']]
        # only consumed by the (disabled) train_round call below
        threshold = label_source['discretization_threshold']
        for iteration, lr in enumerate(lr_schedule, start=1):
            args.lr = lr
            iteration_dir = join(source_out_dir, 'Iteration_' + str(iteration) + '/')
            os.makedirs(iteration_dir, exist_ok=True)
            #train_round(args, current_targets, iteration_dir, threshold, MapsOut = True)
            #update_plots(refined_labels_directory, iteration_dir, label_source)
            # after the first iteration the threshold does not matter much
            threshold = 0.5
            current_targets = [join(iteration_dir, 'MVAMaps/')]
        target_dirs_refined.append(current_targets[0])

    # ---------------- Phase II: fusion of refined pseudolabels ----------------
    phase_2_directory = join(doc_directory, 'Phase_II_Fusion/')
    os.makedirs(phase_2_directory, exist_ok=True)
    args.lr = 1e-4
    args.beta_sq = 4.0
    args.epochs = 200
    args.iter_size = min(1, int(100/args.batch_size))
    train_round(args, target_dirs_refined, phase_2_directory, 0.5, MapsOut = False)

    create_phase2_plots(phase_2_directory)

def main():
    """Build the run configuration and dispatch on ``args.cmd``.

    ``'train'`` runs :func:`train_unsupervised`, ``'test'`` evaluates on
    MSRA-B via :func:`test_saliency` and ``'test_all'`` runs
    :func:`test_all_datasets`.

    Raises:
        ValueError: if ``args.cmd`` is none of the commands above.
    """
    # Local import: SimpleNamespace is otherwise only imported in a later
    # notebook cell, so main() could not run on its own.
    from types import SimpleNamespace

    args = SimpleNamespace(
        arch='drn_d_22',
        batch_size=8,
        beta_sq=1.0,
        bn_sync=False,
        checkpoint_freq=25,
        cmd='train',
        crop_size=432,
        data_dir='/notebooks/deepdividemix/Parameters/',
        pretrained='/notebooks/deepdividemix/Pretrained_Models/drn_pretraining/drn_d_22_cityscapes.pth',
        iter_size=1,
        root_dir='/notebooks/deepdividemix/',
        resume_model1='/notebooks/deepdividemix/Doc/Phase_II_Fusion/checkpoint_model1_075.pth.tar',
        resume_model2='/notebooks/deepdividemix/Doc/Phase_II_Fusion/checkpoint_model2_075.pth.tar',
        warm_up=1,
        p_threshold=0.5,
        T=0.5,
        alpha=4,
        DCRF=None,
        lambda_u=25,
        workers=0,
    )

    if args.cmd == 'train':
        # A pdb-based guard that paused when Doc/ already existed (to avoid
        # overwriting results) used to sit here; it is disabled.
        train_unsupervised(args)
    elif args.cmd == 'test':
        args.dataset_name = '01_MSRAB'
        test_saliency(args)
    elif args.cmd == 'test_all':
        test_all_datasets(args)
    else:
        raise ValueError("unknown cmd {!r}; expected 'train', 'test' or 'test_all'".format(args.cmd))
# Alternative checkpoint settings for the namespace built in main(), kept for reference:
#   resume_model1='/notebooks/deepdividemix/Doc/Phase_II_Fusion/checkpoint_model1_100.pth.tar'
#   resume_model2='/notebooks/deepdividemix/Doc/Phase_II_Fusion/checkpoint_model2_100.pth.tar'
#   resume_model1=''
#   resume_model2=''


"                           \n, resume_model1='/notebooks/deepdividemix/Doc/Phase_II_Fusion/checkpoint_model1_100.pth.tar'                            , resume_model2='/notebooks/deepdividemix/Doc/Phase_II_Fusion/checkpoint_model2_100.pth.tar'                            , resume_model1=''                            , resume_model2='' \n                           \n                           "

In [3]:

def linear_rampup(current, warm_up, rampup_length=16, lambda_u=25):
    """Return the weight of the unlabeled loss, ramped up linearly after warm-up.

    The weight is 0 until ``warm_up``, grows linearly over ``rampup_length``
    epochs and then saturates at ``lambda_u``.

    Args:
        current: current epoch.
        warm_up: epoch at which the ramp starts.
        rampup_length: number of epochs the ramp lasts; a value ``<= 0`` turns
            the ramp into a step (full weight once ``current > warm_up``).
        lambda_u: final weight (formerly hard-coded as 25).

    Returns:
        A float in ``[0, lambda_u]``.
    """
    if rampup_length <= 0:
        # avoid division by zero; matches the limit of the linear ramp
        return float(lambda_u) if current > warm_up else 0.0
    progress = np.clip((current - warm_up) / rampup_length, 0.0, 1.0)
    return lambda_u * float(progress)

class SemiLoss(object):
    """Unlabeled consistency loss for semi-supervised training.

    Calling an instance returns ``(Lu, weight)``:

    * ``Lu`` - mean squared difference between ``softmax(outputs_u, dim=1)``
      and ``targets_u``;
    * ``weight`` - ``linear_rampup(epoch, warm_up)``.

    The labeled arguments ``outputs_x`` / ``targets_x`` are accepted but
    not used.
    """

    def __call__(self, outputs_x, targets_x, outputs_u, targets_u, epoch, warm_up):
        residual = torch.softmax(outputs_u, dim=1) - targets_u
        unlabeled_loss = residual.pow(2).mean()
        rampup_weight = linear_rampup(epoch, warm_up)
        return unlabeled_loss, rampup_weight


# Module-level loss instance.
criterion = SemiLoss()

def _unlabeled_consistency_loss(Data_u, model1, model2, args):
    """Return model1's loss on one unlabeled batch against the co-refined target.

    Both networks predict saliency on ``Data_u``; their average, sharpened with
    temperature ``args.T`` and detached from the graph, replaces every
    pseudolabel of model1's batch before the loss is computed.
    """
    batch_net1 = BatchData(Data_u, active=True)
    batch_net2 = BatchData(Data_u, active=True)
    batch_net1.check_dimension()
    batch_net2.check_dimension()
    batch_net1.create_vars_on_cuda()
    batch_net2.create_vars_on_cuda()
    batch_net1.compute_saliency(model1, False)
    batch_net2.compute_saliency(model2, False)

    averaged = (batch_net1.sal_pred + batch_net2.sal_pred) / 2
    # temperature sharpening; detached so no gradient flows into the target
    targets_u = (averaged ** (1 / args.T)).detach()
    for k in range(len(batch_net1.pseudolabels_var)):
        batch_net1.pseudolabels_var[k] = targets_u
    batch_net1.compute_loss(beta=args.beta_sq)
    return batch_net1.loss


def train(labeled_train_loader, unlabeled_train_loader, model1, model2, optimizer, epoch, doc_directory, args, discretization_threshold, refined_labels_directory=None, iter_size=5, print_freq=10, TrainMapsOut=False,mva_preds=None,image2indx=None):
    """Train ``model1`` for one epoch with targets co-refined by ``model2``.

    For each labeled batch both networks predict saliency. Their average is
    blended per sample with the merged pseudolabels using the weights ``w_x``
    (``w * label + (1 - w) * prediction``). The blend is sharpened with
    temperature ``args.T`` and detached, then discretized at
    ``discretization_threshold``. It is finally used as target for
    ``model1``'s loss. Only ``optimizer`` is stepped (presumably it holds
    ``model1``'s parameters - confirm at the call site).

    The unlabeled consistency branch (weighted by :func:`linear_rampup`) is
    currently switched off via ``use_unlabeled``.

    Args:
        labeled_train_loader: yields ``(index, data, w_x)`` batches.
        unlabeled_train_loader: yields ``(index, data)`` batches; only
            iterated when the unlabeled branch is enabled.
        model1: network being trained (put in train mode).
        model2: peer network used for target co-refinement (put in eval mode).
        optimizer: optimizer stepped once per labeled batch.
        epoch: current epoch (used for logging and the ramp-up weight).
        args: needs ``T``, ``beta_sq`` and ``warm_up``.
        discretization_threshold: threshold passed to
            ``discretize_pseudolabels``.
        print_freq: approximate number of log lines per epoch.
        mva_preds: returned unchanged.
        doc_directory, refined_labels_directory, iter_size, TrainMapsOut,
        image2indx: unused; kept for interface compatibility.

    Returns:
        ``(average training loss over the epoch, mva_preds)``.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()

    # switch to train mode
    model1.train()
    model2.eval()
    end = time.time()

    # The unlabeled branch is disabled (it was forced off by a "Removeme"
    # override); set to True to re-enable it.
    use_unlabeled = False
    unlabeled_train_iter = iter(unlabeled_train_loader) if use_unlabeled else None

    # max(1, ...): with fewer batches than print_freq the modulus would be 0.
    log_every = max(1, len(labeled_train_loader) // print_freq)

    Lu = 0
    for i, (index_x, Data_x, w_x) in tqdm(enumerate(labeled_train_loader)):
        data_time.update(time.time() - end)

        if use_unlabeled:
            try:
                index_u, Data_u = next(unlabeled_train_iter)
            except StopIteration:
                # unlabeled data exhausted: continue with the labeled loss only
                use_unlabeled = False
                Lu = 0
            else:
                Lu = _unlabeled_consistency_loss(Data_u, model1, model2, args)

        # ---- labeled data ----
        w_x = w_x.view(-1, 1).type(torch.FloatTensor).cuda()
        batch_net1 = BatchData(Data_x, active=True)
        batch_net2 = BatchData(Data_x, active=True)
        batch_net1.check_dimension()
        batch_net2.check_dimension()
        # only net1's labels are used to build the targets
        batch_net1.normalize_labels()
        batch_net1.create_vars_on_cuda()
        batch_net2.create_vars_on_cuda()

        batch_net1.compute_saliency(model1, False)
        batch_net2.compute_saliency(model2, False)

        batch_net1.merge_pseudolabels()
        labels_x = batch_net1.merged_labels

        px = (batch_net1.sal_pred + batch_net2.sal_pred) / 2
        px = torch.stack([w_x[b] * labels_x[b] + (1 - w_x[b]) * px[b] for b in range(len(w_x))])
        # temperature sharpening; detached so backward does not run through model2
        targets_x = (px ** (1 / args.T)).detach()

        for k in range(len(batch_net1.pseudolabels_var)):
            batch_net1.pseudolabels_var[k] = targets_x

        batch_net1.discretize_pseudolabels(discretization_threshold)
        batch_net1.compute_loss(beta=args.beta_sq)
        Lx = batch_net1.loss

        lamb = linear_rampup(epoch, args.warm_up)
        loss = Lx + lamb * Lu

        if i % log_every == 0:
            logger.info('\rEpoch [%3d] Iter[%3d/%3d]\t Total loss: %.4f Labeled loss: %.4f Unlabeled lamb*loss: %.4f'%(epoch, i+1, len(labeled_train_loader), loss, Lx, lamb*Lu))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # previously never updated, so the returned average was always 0
        losses.update(loss.item(), w_x.size(0))
        batch_time.update(time.time() - end)
        end = time.time()

    return losses.avg, mva_preds

In [None]:
from types import SimpleNamespace
import torch.nn.functional as F

# Run only when executed directly (notebook cell or script), not on import.
if __name__ == "__main__":
    main()

Loading model 1 state dict

Size of pretrained model and your model does not match in seg.weight (torch.Size([2, 512, 1, 1]) vs. torch.Size([19, 512, 1, 1])). Layer stays initialized randomly.
Size of pretrained model and your model does not match in seg.bias (torch.Size([2]) vs. torch.Size([19])). Layer stays initialized randomly.
Size of pretrained model and your model does not match in up.weight (torch.Size([2, 1, 16, 16]) vs. torch.Size([19, 1, 16, 16])). Layer stays initialized randomly.
Loading model 2 state dict

Size of pretrained model and your model does not match in seg.weight (torch.Size([2, 512, 1, 1]) vs. torch.Size([19, 512, 1, 1])). Layer stays initialized randomly.
Size of pretrained model and your model does not match in seg.bias (torch.Size([2]) vs. torch.Size([19])). Layer stays initialized randomly.
Size of pretrained model and your model does not match in up.weight (torch.Size([2, 1, 16, 16]) vs. torch.Size([19, 1, 16, 16])). Layer stays initialized randomly.


=>

[2021-10-19 16:11:47,094 <ipython-input-2-84e4c762e3d8>:514 train_round] Epoch: [75]	lr 6.55e-05
[2021-10-19 16:11:47,131 utils.py:189 read_lists] labeled dataset size is 2500


Training Net1


Epoch [ 75] Iter[  1/312]	 Total loss: 0.0092 Labeled loss: 0.0092 Unlabeled lamb*loss: 0.0000
Epoch [ 75] Iter[105/312]	 Total loss: 0.0120 Labeled loss: 0.0120 Unlabeled lamb*loss: 0.0000
Epoch [ 75] Iter[209/312]	 Total loss: 0.0103 Labeled loss: 0.0103 Unlabeled lamb*loss: 0.0000
312it [03:46,  1.47it/s]
[2021-10-19 16:15:35,552 utils.py:189 read_lists] labeled dataset size is 2500


Training Net2


Epoch [ 75] Iter[  1/312]	 Total loss: 0.0084 Labeled loss: 0.0084 Unlabeled lamb*loss: 0.0000
Epoch [ 75] Iter[105/312]	 Total loss: 0.0108 Labeled loss: 0.0108 Unlabeled lamb*loss: 0.0000
Epoch [ 75] Iter[209/312]	 Total loss: 0.0108 Labeled loss: 0.0108 Unlabeled lamb*loss: 0.0000
312it [03:27,  1.50it/s]
0it [00:00, ?it/s]

Validation Net1


[2021-10-19 16:19:04,801 <ipython-input-2-84e4c762e3d8>:100 validate] Test: [0/62]	Time 0.362 (0.362)	MAE GT 0.0790 (0.0790)	F-Score GT 0.852 (0.852)
10it [00:02,  3.47it/s][2021-10-19 16:19:07,674 <ipython-input-2-84e4c762e3d8>:100 validate] Test: [10/62]	Time 0.271 (0.294)	MAE GT 0.0328 (0.0604)	F-Score GT 0.917 (0.859)
20it [00:05,  3.63it/s][2021-10-19 16:19:10,470 <ipython-input-2-84e4c762e3d8>:100 validate] Test: [20/62]	Time 0.259 (0.287)	MAE GT 0.0380 (0.0596)	F-Score GT 0.914 (0.854)
30it [00:08,  3.71it/s][2021-10-19 16:19:13,156 <ipython-input-2-84e4c762e3d8>:100 validate] Test: [30/62]	Time 0.297 (0.281)	MAE GT 0.0556 (0.0569)	F-Score GT 0.802 (0.861)
40it [00:11,  3.81it/s][2021-10-19 16:19:15,853 <ipython-input-2-84e4c762e3d8>:100 validate] Test: [40/62]	Time 0.257 (0.278)	MAE GT 0.0488 (0.0601)	F-Score GT 0.866 (0.862)
50it [00:14,  3.05it/s][2021-10-19 16:19:18,900 <ipython-input-2-84e4c762e3d8>:100 validate] Test: [50/62]	Time 0.286 (0.284)	MAE GT 0.0378 (0.0613)	F-Sco

Validation Net2


[2021-10-19 16:19:22,501 <ipython-input-2-84e4c762e3d8>:100 validate] Test: [0/62]	Time 0.257 (0.257)	MAE GT 0.0586 (0.0586)	F-Score GT 0.902 (0.902)
10it [00:02,  4.37it/s][2021-10-19 16:19:24,801 <ipython-input-2-84e4c762e3d8>:100 validate] Test: [10/62]	Time 0.226 (0.232)	MAE GT 0.0335 (0.0568)	F-Score GT 0.908 (0.866)
20it [00:04,  4.46it/s][2021-10-19 16:19:27,070 <ipython-input-2-84e4c762e3d8>:100 validate] Test: [20/62]	Time 0.244 (0.230)	MAE GT 0.0391 (0.0569)	F-Score GT 0.906 (0.859)
30it [00:06,  4.33it/s][2021-10-19 16:19:29,436 <ipython-input-2-84e4c762e3d8>:100 validate] Test: [30/62]	Time 0.225 (0.232)	MAE GT 0.0566 (0.0547)	F-Score GT 0.802 (0.863)
40it [00:09,  3.77it/s][2021-10-19 16:19:31,971 <ipython-input-2-84e4c762e3d8>:100 validate] Test: [40/62]	Time 0.224 (0.237)	MAE GT 0.0478 (0.0580)	F-Score GT 0.866 (0.863)
50it [00:11,  4.38it/s][2021-10-19 16:19:34,274 <ipython-input-2-84e4c762e3d8>:100 validate] Test: [50/62]	Time 0.221 (0.236)	MAE GT 0.0376 (0.0593)	F-Sco

Training Net1


Epoch [ 76] Iter[  1/312]	 Total loss: 0.0115 Labeled loss: 0.0115 Unlabeled lamb*loss: 0.0000
Epoch [ 76] Iter[105/312]	 Total loss: 0.1366 Labeled loss: 0.1366 Unlabeled lamb*loss: 0.0000
Epoch [ 76] Iter[209/312]	 Total loss: 0.0114 Labeled loss: 0.0114 Unlabeled lamb*loss: 0.0000
312it [03:26,  1.52it/s]
[2021-10-19 16:23:05,261 utils.py:189 read_lists] labeled dataset size is 2500


Training Net2


Epoch [ 76] Iter[  1/312]	 Total loss: 0.0079 Labeled loss: 0.0079 Unlabeled lamb*loss: 0.0000
Epoch [ 76] Iter[105/312]	 Total loss: 0.1317 Labeled loss: 0.1317 Unlabeled lamb*loss: 0.0000
Epoch [ 76] Iter[209/312]	 Total loss: 0.0113 Labeled loss: 0.0113 Unlabeled lamb*loss: 0.0000
312it [03:26,  1.51it/s]
0it [00:00, ?it/s]

Validation Net1


[2021-10-19 16:26:33,735 <ipython-input-2-84e4c762e3d8>:100 validate] Test: [0/62]	Time 0.252 (0.252)	MAE GT 0.0735 (0.0735)	F-Score GT 0.871 (0.871)
10it [00:02,  4.24it/s][2021-10-19 16:26:36,036 <ipython-input-2-84e4c762e3d8>:100 validate] Test: [10/62]	Time 0.230 (0.232)	MAE GT 0.0362 (0.0591)	F-Score GT 0.905 (0.865)
20it [00:04,  4.47it/s][2021-10-19 16:26:38,264 <ipython-input-2-84e4c762e3d8>:100 validate] Test: [20/62]	Time 0.223 (0.228)	MAE GT 0.0377 (0.0584)	F-Score GT 0.917 (0.860)
30it [00:06,  4.42it/s][2021-10-19 16:26:40,640 <ipython-input-2-84e4c762e3d8>:100 validate] Test: [30/62]	Time 0.235 (0.231)	MAE GT 0.0589 (0.0565)	F-Score GT 0.791 (0.865)
40it [00:09,  3.97it/s][2021-10-19 16:26:43,128 <ipython-input-2-84e4c762e3d8>:100 validate] Test: [40/62]	Time 0.233 (0.235)	MAE GT 0.0477 (0.0601)	F-Score GT 0.868 (0.865)
50it [00:11,  4.20it/s][2021-10-19 16:26:45,570 <ipython-input-2-84e4c762e3d8>:100 validate] Test: [50/62]	Time 0.230 (0.237)	MAE GT 0.0407 (0.0613)	F-Sco

Validation Net2


[2021-10-19 16:26:48,491 <ipython-input-2-84e4c762e3d8>:100 validate] Test: [0/62]	Time 0.241 (0.241)	MAE GT 0.0620 (0.0620)	F-Score GT 0.895 (0.895)
10it [00:02,  4.09it/s][2021-10-19 16:26:51,018 <ipython-input-2-84e4c762e3d8>:100 validate] Test: [10/62]	Time 0.252 (0.252)	MAE GT 0.0336 (0.0571)	F-Score GT 0.909 (0.867)
20it [00:04,  4.25it/s][2021-10-19 16:26:53,342 <ipython-input-2-84e4c762e3d8>:100 validate] Test: [20/62]	Time 0.225 (0.242)	MAE GT 0.0378 (0.0573)	F-Score GT 0.914 (0.859)
30it [00:07,  4.43it/s][2021-10-19 16:26:55,676 <ipython-input-2-84e4c762e3d8>:100 validate] Test: [30/62]	Time 0.221 (0.240)	MAE GT 0.0563 (0.0551)	F-Score GT 0.808 (0.864)
40it [00:09,  3.89it/s][2021-10-19 16:26:58,184 <ipython-input-2-84e4c762e3d8>:100 validate] Test: [40/62]	Time 0.221 (0.242)	MAE GT 0.0561 (0.0582)	F-Score GT 0.834 (0.864)
50it [00:12,  4.32it/s][2021-10-19 16:27:00,555 <ipython-input-2-84e4c762e3d8>:100 validate] Test: [50/62]	Time 0.251 (0.241)	MAE GT 0.0377 (0.0595)	F-Sco

Training Net1


Epoch [ 77] Iter[  1/312]	 Total loss: 0.0152 Labeled loss: 0.0152 Unlabeled lamb*loss: 0.0000
Epoch [ 77] Iter[105/312]	 Total loss: 0.0132 Labeled loss: 0.0132 Unlabeled lamb*loss: 0.0000
Epoch [ 77] Iter[209/312]	 Total loss: 0.0165 Labeled loss: 0.0165 Unlabeled lamb*loss: 0.0000
312it [03:26,  1.51it/s]
[2021-10-19 16:30:31,306 utils.py:189 read_lists] labeled dataset size is 2500


Training Net2


Epoch [ 77] Iter[  1/312]	 Total loss: 0.0159 Labeled loss: 0.0159 Unlabeled lamb*loss: 0.0000
Epoch [ 77] Iter[105/312]	 Total loss: 0.0108 Labeled loss: 0.0108 Unlabeled lamb*loss: 0.0000
Epoch [ 77] Iter[209/312]	 Total loss: 0.1375 Labeled loss: 0.1375 Unlabeled lamb*loss: 0.0000
312it [03:27,  1.52it/s]
0it [00:00, ?it/s]

Validation Net1


[2021-10-19 16:34:00,452 <ipython-input-2-84e4c762e3d8>:100 validate] Test: [0/62]	Time 0.248 (0.248)	MAE GT 0.0705 (0.0705)	F-Score GT 0.869 (0.869)
10it [00:02,  3.75it/s][2021-10-19 16:34:02,899 <ipython-input-2-84e4c762e3d8>:100 validate] Test: [10/62]	Time 0.233 (0.245)	MAE GT 0.0315 (0.0570)	F-Score GT 0.915 (0.866)
20it [00:04,  4.27it/s][2021-10-19 16:34:05,283 <ipython-input-2-84e4c762e3d8>:100 validate] Test: [20/62]	Time 0.243 (0.242)	MAE GT 0.0323 (0.0568)	F-Score GT 0.924 (0.859)
30it [00:07,  4.17it/s][2021-10-19 16:34:07,677 <ipython-input-2-84e4c762e3d8>:100 validate] Test: [30/62]	Time 0.225 (0.241)	MAE GT 0.0524 (0.0546)	F-Score GT 0.816 (0.865)
40it [00:09,  4.29it/s][2021-10-19 16:34:10,085 <ipython-input-2-84e4c762e3d8>:100 validate] Test: [40/62]	Time 0.232 (0.241)	MAE GT 0.0507 (0.0576)	F-Score GT 0.857 (0.865)
50it [00:12,  4.25it/s][2021-10-19 16:34:12,575 <ipython-input-2-84e4c762e3d8>:100 validate] Test: [50/62]	Time 0.297 (0.243)	MAE GT 0.0352 (0.0587)	F-Sco

Validation Net2


[2021-10-19 16:34:15,514 <ipython-input-2-84e4c762e3d8>:100 validate] Test: [0/62]	Time 0.244 (0.244)	MAE GT 0.0672 (0.0672)	F-Score GT 0.885 (0.885)
10it [00:02,  4.33it/s][2021-10-19 16:34:17,906 <ipython-input-2-84e4c762e3d8>:100 validate] Test: [10/62]	Time 0.290 (0.240)	MAE GT 0.0339 (0.0579)	F-Score GT 0.907 (0.864)
20it [00:04,  3.97it/s][2021-10-19 16:34:20,350 <ipython-input-2-84e4c762e3d8>:100 validate] Test: [20/62]	Time 0.233 (0.242)	MAE GT 0.0409 (0.0579)	F-Score GT 0.901 (0.856)
30it [00:07,  4.05it/s][2021-10-19 16:34:22,795 <ipython-input-2-84e4c762e3d8>:100 validate] Test: [30/62]	Time 0.263 (0.243)	MAE GT 0.0572 (0.0557)	F-Score GT 0.798 (0.861)
40it [00:09,  4.24it/s][2021-10-19 16:34:25,228 <ipython-input-2-84e4c762e3d8>:100 validate] Test: [40/62]	Time 0.259 (0.243)	MAE GT 0.0505 (0.0590)	F-Score GT 0.860 (0.861)
50it [00:12,  4.15it/s][2021-10-19 16:34:27,705 <ipython-input-2-84e4c762e3d8>:100 validate] Test: [50/62]	Time 0.251 (0.244)	MAE GT 0.0393 (0.0603)	F-Sco

Training Net1


Epoch [ 78] Iter[  1/312]	 Total loss: 0.0095 Labeled loss: 0.0095 Unlabeled lamb*loss: 0.0000
Epoch [ 78] Iter[105/312]	 Total loss: 0.0115 Labeled loss: 0.0115 Unlabeled lamb*loss: 0.0000
Epoch [ 78] Iter[209/312]	 Total loss: 0.0213 Labeled loss: 0.0213 Unlabeled lamb*loss: 0.0000
312it [03:26,  1.52it/s]
[2021-10-19 16:37:58,624 utils.py:189 read_lists] labeled dataset size is 2500


Training Net2


Epoch [ 78] Iter[  1/312]	 Total loss: 0.0086 Labeled loss: 0.0086 Unlabeled lamb*loss: 0.0000
Epoch [ 78] Iter[105/312]	 Total loss: 0.0089 Labeled loss: 0.0089 Unlabeled lamb*loss: 0.0000
228it [02:30,  1.51it/s]