In [1]:
import sys
sys.path.append('../')

In [2]:
import os
import sys
import json
import time
import rich
import numpy as np
import pickle
import wandb
import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F

from configs.finetune import FinetuneConfig
from tasks.classification import Classification

from models.backbone.base import calculate_out_features
from models.backbone.densenet import DenseNetBackbone
from models.backbone.resnet import build_resnet_backbone
from models.head.projector import MLPHead
from models.head.classifier import LinearClassifier

from datasets.brain import BrainProcessor, Brain, BrainMoCo
from datasets.transforms import make_transforms, compute_statistics

from utils.logging import get_rich_logger
from utils.gpu import set_gpu

In [3]:
from easydict import EasyDict as edict
from torch.utils.data import DataLoader
from utils.metrics import classification_result

import collections
from sklearn.ensemble import GradientBoostingClassifier, RandomForestClassifier
from sklearn.tree import DecisionTreeClassifier

In [4]:
@torch.no_grad()
def evaluate_finetuned_model(data_loader, backbone, classifier, device, adjusted):
    """Score ``classifier(backbone(x))`` on every batch of ``data_loader``.

    Args:
        data_loader: iterable of dict batches with an input tensor under ``'x'``
            and a label tensor under ``'y'``.
        backbone: ``nn.Module`` feature extractor.
        classifier: ``nn.Module`` head applied to the backbone output. Its output is
            treated as logits, with classes along dim 1 (softmax is taken over dim 1).
        device: anything accepted by ``Tensor.to`` / ``Module.to`` (e.g. ``'cuda:0'`` or ``0``).
        adjusted: forwarded unchanged to ``classification_result``.

    Returns:
        The dict produced by ``classification_result``, extended with ``'y_true'``
        (int64 labels) and ``'y_pred'`` (softmax probabilities) as numpy arrays.

    Raises:
        ValueError: if ``data_loader`` yields no batches.

    Both modules are switched to eval mode for inference. They are moved back to the
    CPU before returning, also when inference raises.
    """
    backbone.to(device)
    backbone.eval()
    classifier.to(device)
    classifier.eval()

    try:
        labels, logits = [], []
        for batch in data_loader:
            x = batch['x'].float().to(device)
            logits.append(classifier(backbone(x)))
            # Labels are only needed on the host for scoring, so they are not sent to `device`.
            labels.append(batch['y'].long())
        if not logits:
            raise ValueError('data_loader yielded no batches')

        # Convert once; the same arrays feed the metrics and the returned dict.
        y_true = torch.cat(labels, dim=0).cpu().numpy()
        y_prob = torch.cat(logits, dim=0).to(torch.float32).softmax(1).cpu().numpy()
    finally:
        backbone.to('cpu')
        classifier.to('cpu')

    # Give classification_result its own copies so the arrays stored under
    # 'y_true'/'y_pred' are independent of whatever it does with its inputs.
    clf_result = classification_result(y_true=y_true.copy(),
                                       y_pred=y_prob.copy(),
                                       adjusted=adjusted)
    clf_result['y_true'] = y_true
    clf_result['y_pred'] = y_prob

    return clf_result

In [5]:
# Run directory names under ../checkpoints/pet/resnet/, in descending timestamp order.
hashs = [
    "2022-07-01_06-45-52",
    "2022-07-01_06-19-43",
    "2022-07-01_05-52-53",
    "2022-07-01_05-27-16",
    "2022-07-01_05-02-03",
    "2022-07-01_04-36-14",
    "2022-07-01_04-11-23",
    "2022-07-01_03-45-52",
    "2022-07-01_03-20-04",
    "2022-07-01_02-54-33",
]

In [6]:
# Execution environment copied onto every run's config before set_gpu(config) is called.
server = 'workstation2'
gpus = ['30']

In [7]:
# First run, handy for inspecting a single checkpoint by hand.
# Named `run_hash` so the builtin hash() is not shadowed for the rest of the notebook.
run_hash = hashs[0]

In [15]:
# Evaluate the last checkpoint of every run in `hashs` on its test split and collect
# both the unadjusted and the `adjusted=True` metrics from `classification_result`.
result_list = []
result_adj_list = []

# Keys copied (when present) from each run's saved configs.json onto the evaluation config.
# Loop-invariant, so built once.
finetune_config_names = [
    # data_parser
    'data_type', 'root', 'data_info', 'mci_only', 'n_splits', 'n_cv',
    'image_size', 'small_kernel', 'random_state',
    'intensity', 'crop', 'crop_size', 'rotate', 'flip', 'affine', 'blur', 'blur_std', 'prob',
    # model_parser
    'backbone_type', 'init_features', 'growth_rate', 'block_config', 'bn_size', 'dropout_rate',
    'arch', 'no_max_pool',
    # train
    'batch_size',
    # moco / supmoco
    'alphas',
    # others
    'task', 'projector_dim'
]
local_rank = 0  # passed to torch.cuda.set_device and used as the inference device
eval_batch_size = 16  # evaluation only; the run's own `batch_size` is loaded but not used here

# `run_hash` instead of `hash` so the loop does not shadow the builtin hash().
for run_hash in tqdm.tqdm(hashs):

    config = edict()
    config.server = server
    config.gpus = gpus

    run_dir = os.path.join('../checkpoints/pet/resnet', run_hash)
    config.finetune_file = os.path.join(run_dir, 'ckpt.last.pth.tar')
    with open(os.path.join(run_dir, 'configs.json'), 'rb') as fb:
        finetune_config = json.load(fb)

    for name in finetune_config_names:
        if name in finetune_config:
            setattr(config, name, finetune_config[name])

    set_gpu(config)
    np.random.seed(config.random_state)
    torch.manual_seed(config.random_state)
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.deterministic = False
    torch.backends.cudnn.allow_tf32 = True
    torch.cuda.set_device(local_rank)

    # Networks: rebuild the architecture recorded in the run's config.
    if config.backbone_type == 'densenet':
        backbone = DenseNetBackbone(in_channels=1,
                                    init_features=config.init_features,
                                    growth_rate=config.growth_rate,
                                    block_config=config.block_config,
                                    bn_size=config.bn_size,
                                    dropout_rate=config.dropout_rate,
                                    semi=False)
        activation = True
    elif config.backbone_type == 'resnet':
        backbone = build_resnet_backbone(arch=config.arch,
                                         no_max_pool=config.no_max_pool,
                                         in_channels=1,
                                         semi=False)
        activation = False
    else:
        raise NotImplementedError

    if config.small_kernel:
        backbone._fix_first_conv()

    # The classifier's input width depends on the spatial size actually fed to the backbone.
    input_size = config.crop_size if config.crop else config.image_size
    out_dim = calculate_out_features(backbone=backbone, in_channels=1, image_size=input_size)
    classifier = LinearClassifier(in_channels=out_dim, num_classes=2, activation=activation)

    backbone.load_weights_from_checkpoint(path=config.finetune_file, key='backbone')
    classifier.load_weights_from_checkpoint(path=config.finetune_file, key='classifier')

    # Data: built with the run's own random_state / n_splits / n_cv
    # (presumably reproducing its train/test split - TODO confirm in BrainProcessor).
    data_processor = BrainProcessor(root=config.root,
                                    data_info=config.data_info,
                                    data_type=config.data_type,
                                    mci_only=config.mci_only,
                                    random_state=config.random_state)
    datasets = data_processor.process(n_splits=config.n_splits, n_cv=config.n_cv)

    # Intensity normalization statistics; only 'minmax' needs precomputed values here.
    assert config.intensity in [None, 'scale', 'minmax']
    mean_std, min_max = (None, None), (None, None)
    if config.intensity == 'minmax':
        # NOTE: pickle.load can execute arbitrary code; `root` must point at trusted data.
        with open(os.path.join(config.root, 'labels/minmax.pkl'), 'rb') as fb:
            minmax_stats = pickle.load(fb)
        min_max = (minmax_stats[config.data_type]['min'], minmax_stats[config.data_type]['max'])
    elif config.intensity not in (None, 'scale'):
        # Unreachable while the assert above is active; kept for runs under `python -O`.
        raise NotImplementedError

    # Only the test-time transform is used for evaluation.
    train_transform, test_transform = make_transforms(image_size=config.image_size,
                                                      intensity=config.intensity,
                                                      mean_std=mean_std,
                                                      min_max=min_max,
                                                      crop=config.crop,
                                                      crop_size=config.crop_size,
                                                      rotate=config.rotate,
                                                      flip=config.flip,
                                                      affine=config.affine,
                                                      blur=config.blur,
                                                      blur_std=config.blur_std,
                                                      prob=config.prob)

    test_set = Brain(dataset=datasets['test'], data_type=config.data_type, transform=test_transform)
    test_loader = DataLoader(dataset=test_set, batch_size=eval_batch_size, drop_last=False)

    clf_result = evaluate_finetuned_model(test_loader, backbone, classifier, local_rank, False)
    # `adjusted` only changes how classification_result scores the predictions, so reuse
    # the predictions just computed instead of a second forward pass over the test set.
    clf_result_adj = classification_result(y_true=clf_result['y_true'].copy(),
                                           y_pred=clf_result['y_pred'].copy(),
                                           adjusted=True)
    clf_result_adj['y_true'] = clf_result['y_true'].copy()
    clf_result_adj['y_pred'] = clf_result['y_pred'].copy()

    result_list.append(clf_result)
    result_adj_list.append(clf_result_adj)

100%|██████████| 10/10 [01:28<00:00,  8.83s/it]


In [16]:
import pandas as pd

# One row per run; average each metric after removing the raw label/probability arrays.
result_list = pd.DataFrame(result_list)
result_list.drop(columns=['y_true', 'y_pred']).mean()

acc      0.800276
auroc    0.826641
sens     0.641365
spec     0.863746
prec     0.623422
f1       0.624594
gmean    0.739130
dtype: float64

In [17]:
# One row per run of adjusted metrics; average them after removing the raw arrays.
result_adj_list = pd.DataFrame(result_adj_list)
result_adj_list.drop(columns=['y_true', 'y_pred']).mean()

acc      0.787682
auroc    0.826641
sens     0.742389
spec     0.800103
prec     0.592215
f1       0.655790
gmean    0.769247
dtype: float64