In [2]:
# Evaluation (test) script for the MACMD comparative networks.

import torch
from torch import nn
from torch.cuda.amp import autocast, GradScaler
from torch.utils.data import DataLoader
from loader import *
from torchvision import transforms


from network import PVT_B2_MACMD

from engine import *
import os
import sys
os.environ["CUDA_VISIBLE_DEVICES"] = "0" # "0, 1, 2, 3"

from utils import *
from configs.config_setting import setting_config

import warnings
warnings.filterwarnings("ignore")



# Dataset names handled by each evaluation path below.
_ISIC_DATASETS = ("isic2017", "isic2018")
_VOLUME_DATASETS = ("synapse", "acdc")


def _build_test_loader(config):
    """Create the test dataset and a batch-size-1 DataLoader over it.

    Args:
        config: experiment config. Reads ``datasets_name``, ``datasets``,
            ``num_workers`` and the dataset-specific fields (``data_path``
            for ISIC; ``volume_path``, ``list_dir``, ``distributed`` for
            Synapse/ACDC).

    Returns:
        ``(test_dataset, test_loader)``.

    Raises:
        ValueError: if ``config.datasets_name`` is not a supported dataset.
    """
    name = config.datasets_name
    sampler = None
    if name in _ISIC_DATASETS:
        test_dataset = config.datasets(path_Data=config.data_path, train=False, Test=True)
    elif name in _VOLUME_DATASETS:
        test_dataset = config.datasets(base_dir=config.volume_path, split="test", list_dir=config.list_dir)
        if config.distributed:
            # Imported lazily: only needed for multi-process evaluation.
            from torch.utils.data.distributed import DistributedSampler
            sampler = DistributedSampler(test_dataset, shuffle=False)
    else:
        # Fail early instead of hitting a NameError on test_loader later.
        raise ValueError(
            f"Unsupported datasets_name {name!r}; expected one of "
            f"{_ISIC_DATASETS + _VOLUME_DATASETS}"
        )

    test_loader = DataLoader(test_dataset,
                             batch_size=1,
                             shuffle=False,
                             pin_memory=True,
                             num_workers=config.num_workers,
                             sampler=sampler,
                             drop_last=True)
    return test_dataset, test_loader


def main(config=None, work_dir='/users/MACMD/weight/'):
    """Evaluate the ``best.pth`` MACMD checkpoint on the configured test split.

    Args:
        config: experiment config object. When ``None``, falls back to
            ``setting_config``. (Previously any passed config was silently
            replaced by ``setting_config``.)
        work_dir: directory containing ``best.pth``. The test log is
            written to ``<work_dir>/log``. Per-case outputs still go to
            ``<config.work_dir>/outputs``, as before.

    Returns:
        dict of metrics:
            - ISIC datasets: ``loss``, ``miou`` and ``f1_or_dsc``, as
              returned by ``test_one_epoch``.
            - Synapse/ACDC: ``mean_dice`` and ``mean_hd95``, as returned
              by ``test_sy_ac``.

    Raises:
        ValueError: if ``config.datasets_name`` is not supported.
    """
    if config is None:
        config = setting_config
    config.set_datasets()
    config.set_opt_sch()

    print('#----------Creating logger----------#')
    # NOTE(review): adds the weights dir to the import path; this looks
    # unnecessary for testing — confirm before removing.
    sys.path.append(work_dir + '/')
    log_dir = os.path.join(work_dir, 'log')
    outputs = os.path.join(config.work_dir, 'outputs')
    os.makedirs(outputs, exist_ok=True)

    # Kept module-global so other cells/helpers in this file can reach it.
    global logger
    logger = get_logger('test', log_dir)
    log_config_info(config, logger)

    print('#----------GPU init----------#')
    set_seed(config.seed)
    gpu_ids = [0]  # e.g. [0, 1, 2, 3] for multi-GPU
    torch.cuda.empty_cache()

    print('#----------Preparing dataset----------#')
    test_dataset, test_loader = _build_test_loader(config)

    print('#----------Prepareing Models----------#')
    model_cfg = config.model_config
    model = PVT_B2_MACMD(n_classes=model_cfg['num_classes'],
                         n_channels=model_cfg['input_channels'],
                         img_size=config.input_size_h)
    model = torch.nn.DataParallel(model.cuda(), device_ids=gpu_ids, output_device=gpu_ids[0])

    print('#----------Preparing loss----------#')
    criterion = config.criterion

    print('#----------Testing----------#')
    # Load on CPU first; load_state_dict copies the tensors onto the model's device.
    best_weight = torch.load(os.path.join(work_dir, 'best.pth'),
                             map_location=torch.device('cpu'),
                             weights_only=True)
    model.module.load_state_dict(best_weight)

    if config.datasets_name in _ISIC_DATASETS:
        loss, miou, f1_or_dsc = test_one_epoch(
            test_loader,
            model,
            criterion,
            logger,
            config,
        )
        return {'loss': loss, 'miou': miou, 'f1_or_dsc': f1_or_dsc}

    # Only Synapse/ACDC remain: _build_test_loader rejected everything else.
    mean_dice, mean_hd95 = test_sy_ac(
        test_dataset,
        test_loader,
        model,
        logger,
        config,
        test_save_path=outputs,
        val_or_test=True,
    )
    return {'mean_dice': mean_dice, 'mean_hd95': mean_hd95}


if __name__ == '__main__':
    # Script entry point: evaluate using the project's default settings.
    # `config` stays bound at module level so later notebook cells can read it.
    config = setting_config
    main(config=config)

loss_weight [0.4, 0.6]
data path: /users/data/Synapse/train/
#----------Creating logger----------#
#----------GPU init----------#
#----------Preparing dataset----------#
#----------Prepareing Models----------#
Pretrain weights loaded.
Model pvt_v2_b2 backbone:  created, param count: 24849856
Model MACMD decoder:  created, param count: 9750881
#----------Prepareing loss, opt, sch and amp----------#
#----------Set other params----------#
#----------Testing----------#


100%|██████████| 12/12 [16:03<00:00, 80.30s/it]

test  mean_dice: 0.832706639180243, mean_hd95: 14.91915603442167, time(s): 963.73



