In [1]:
import os
import time
import gc

import cv2
import numpy as np
import torch
import pandas as pd
from tqdm import tqdm
from torch.utils.data import DataLoader
from skimage import measure

from seg_lib.dataloaders.sam_dataset import SamGeneralDataset
from seg_lib.eval.metrics import Metrics
from seg_lib.eval.fusion import LogitsFusion
from seg_lib.io.files import read_json
from seg_lib.models.cafe_net.pvt import CAFE
from seg_lib.models.pvt_v2 import SegPVT
from seg_lib.models.selector import predictor_selector
from seg_lib.models.normalizer import Normalizer
from seg_lib.prompt import Sampler

KeyboardInterrupt: 

In [2]:
TH = 0.5
# total number of parallel workers used in the dataloader
N_CPUS = os.cpu_count()
N_GPUS = torch.cuda.device_count()
DEVICE = 'cuda' if N_GPUS > 0 else 'cpu'

In [11]:
BASE_PATH = 'E:\\UNIPD\\SAM'
TEST_PATH = f'{BASE_PATH}/data/test'
TEST_DESCRIPTOR = 'metadata/ribs_test_old.csv'
BASE_WEIGHTS = f'{BASE_PATH}/pretrained_models/pvt_v2_b2.pth'
MODELS_SETS = [
    [
        f'{BASE_PATH}/outputs/train/ribs_pvt_small_da1',
        f'{BASE_PATH}/outputs/train/ribs_cafe_small_da1',
    ],
    [
        f'{BASE_PATH}/outputs/train/ribs_pvt_da1',
        f'{BASE_PATH}/outputs/train/ribs_cafe_da1',
    ],
    [
        f'{BASE_PATH}/outputs/train/ribs_pvt_small_da2',
        f'{BASE_PATH}/outputs/train/ribs_cafe_small_da2',
    ],
    [
        f'{BASE_PATH}/outputs/train/ribs_pvt_da2',
        f'{BASE_PATH}/outputs/train/ribs_cafe_da2',
    ],
    [
        f'{BASE_PATH}/outputs/train/ribs_pvt_small_da1',
        f'{BASE_PATH}/outputs/train/ribs_cafe_small_da1',
    ],
    [
        f'{BASE_PATH}/outputs/train/ribs_pvt_da1',
        f'{BASE_PATH}/outputs/train/ribs_cafe_da1',
    ]
]
SAM_WEIGHTS = f'{BASE_PATH}/pretrained_models/sam_vit_b_01ec64.pth'
SAM_MODELS = [
    f'{BASE_PATH}/outputs/train/v0_9_0_small_da1/best_SAMUS.pth',
    f'{BASE_PATH}/outputs/train/v0_9_0_da1/best_SAMUS.pth',
    f'{BASE_PATH}/outputs/train/v0_9_0_small_da2/best_SAMUS.pth',
    f'{BASE_PATH}/outputs/train/v0_9_0_da2/best_SAMUS.pth',
    f'{BASE_PATH}/outputs/train/v1_1_0_small/best_SAMUS.pth',
    f'{BASE_PATH}/outputs/train/v1_1_0/best_SAMUS.pth'
]

In [12]:
def select_model(model_type: str, ckpt_path: str):
    if model_type == 'pvt':
        return SegPVT(backbone_ckpt_path=ckpt_path)
    
    return CAFE(pvtv2_path=ckpt_path)

def get_model(ckpt_path: str, model_path: str, model_type: str = 'pvt'):
    model = select_model(model_type, ckpt_path)
    model.to(DEVICE)
    
    checkpoint = torch.load(model_path, map_location=torch.device(DEVICE))
    model.load_state_dict(checkpoint)

    return model

In [13]:
def get_dataset(
        data_path: str,
        data_desc_path: str,
        batch_size: int = 8,
        embedding_size: int = 128,
        input_size: int = 352):
    csv_path = os.path.join(data_path, data_desc_path)
    df = pd.read_csv(csv_path)
    test_split = 'test' if 'test' in df['split'].unique() else 'val'
    del df

    test_dataset = SamGeneralDataset(
        data_path,
        split=test_split, 
        point_sampler=None,
        df_file_path=data_desc_path,
        img_size=input_size,
        embedding_size=embedding_size,
        prompt=None,
        read_img_as_grayscale=False)
    
    return DataLoader(
        test_dataset,
        batch_size=batch_size * max(1, N_GPUS),
        shuffle=False,
        num_workers=N_CPUS,
        pin_memory=True)

In [14]:
def flatten_preds(preds):
    if isinstance(preds, tuple):
        final_preds = preds[0]
        for i in range(1, len(preds)):
            final_preds += preds[i]
        preds = final_preds

    return preds

def eval(test_dataset, model_1, model_2):
    metrics = Metrics()
    batch_sizes = []
    latency_p_batch = []
    pred_masks = []
    pred_logits = []
    file_names = []
    
    data_config = {'dtype': torch.float32, 'device': DEVICE}
    for batch in tqdm(test_dataset):
        imgs = batch['image'].to(**data_config)
        labels = batch['label'].to(**data_config)
        orig_sizes = np.array(
            list(zip(*batch['original_img_size']) ), dtype=int
        )

        with torch.no_grad():
            _start = time.time()
            preds_1 = model_1(imgs)
            preds_2 = model_2(imgs)
            _end = time.time()

        latency_p_batch.append(_end - _start)
        batch_sizes.append(imgs.shape[0])

        preds_1 = flatten_preds(preds_1)
        preds_2 = flatten_preds(preds_2)
        preds = (preds_1 + preds_2) / 2
        
        logits = preds.sigmoid().detach().numpy()[:, 0, :, :]
        bin_masks = (logits > TH).astype('uint8')
        labels = labels.detach().numpy()[:, 0, :, :].astype('uint8')
        for i in range(bin_masks.shape[0]):
            bin_mask = cv2.resize(
                bin_masks[i], orig_sizes[i], cv2.INTER_NEAREST
            )
            label = cv2.resize(labels[i], orig_sizes[i], cv2.INTER_NEAREST)
            metrics.step(bin_mask, label)

            logits_i = cv2.resize(logits[i], orig_sizes[i], cv2.INTER_NEAREST)
            pred_logits.append(logits_i)
            pred_masks.append(bin_mask)
            file_names.append(batch['img_name'][i])

    metrics = {
        **metrics.get_results(),
        'fps': sum(batch_sizes) / sum(latency_p_batch),
        'latency': sum(latency_p_batch) / sum(batch_sizes),
        'latency_p_batch': sum(latency_p_batch) / len(batch_sizes)
    }
    preds = {
        file_names[i]: {
            'logits': pred_logits[i],
            'mask': pred_masks[i]
        }
        for i in range(len(file_names))
    }
    return metrics, preds


In [15]:
def read_img(i_row: dict):
    base_path = os.path.join(TEST_PATH, i_row['subset']) 
    img_path = os.path.join(base_path, 'img', i_row['img_name'])
    label_path = os.path.join(base_path, 'label', i_row['label_name'])

    if not os.path.exists(img_path):
        raise ValueError(f'Image does not exist on disk: {img_path}')
    if not os.path.exists(label_path):
        raise ValueError(f'Label does not exist on disk: {label_path}')

    # load the image (H, W, 3) and convert it from BGR to RGB
    image = cv2.imread(img_path, 1)[:, :, ::-1]
    # read the mask as (H, W), grayscale
    mask = cv2.imread(label_path, 0)
    mask[mask > 1] = 1

    return image, mask

def predict_sam(data_df, predictor, sampler, seg_preds):
    sam_metrics = Metrics()
    preds = {}
    batch_sizes = []
    latency_p_batch = []

    for _, row in tqdm(data_df.iterrows(), total=data_df.shape[0]):
        img, label = read_img(row)        
        mask_of_blobs = measure.label(seg_preds[row['img_name']]['mask'])
        input_point, input_label = sampler.sample(
            mask_of_blobs, seg_preds[row['img_name']]['mask']
        )

        _start = time.time()
        predictor.set_image(img)
        sam_logits, iou_scores, _ = predictor.predict(
            point_coords=input_point,
            point_labels=input_label,
            multimask_output=False,
            return_logits=True
        )
        _end = time.time()
        batch_sizes.append(1)
        latency_p_batch.append(_end - _start)

        logits =  sam_logits[np.argmax(iou_scores)]
        binary_mask = logits > 0.0
        sam_metrics.step(binary_mask, label)

        preds[row['img_name']] = {
            'label': label,
            'logits': Normalizer.sigmoid(logits).astype(np.float32)
        }

    sam_metrics = {
        **sam_metrics.get_results(),
        'fps': sum(batch_sizes) / sum(latency_p_batch),
        'latency': sum(latency_p_batch) / sum(batch_sizes),
        'latency_p_batch': sum(latency_p_batch) / len(batch_sizes)
    }

    return sam_metrics, preds


In [16]:
def fuse(preds, sam_preds):
    fusion_metrics = Metrics()
    fusion_preds = {}
    for img_name in preds:
        fusion = LogitsFusion.apply(
            preds[img_name]['logits'],
            sam_preds[img_name]['logits'],
            method='default')
        fusion_preds[img_name] = { 'logits': fusion }
        fusion_metrics.step(fusion > TH, sam_preds[img_name]['label'])
    
    return fusion_metrics.get_results(), fusion_preds
    

In [17]:
test_df = pd.read_csv(os.path.join(TEST_PATH, TEST_DESCRIPTOR))
test_ds = get_dataset(
    TEST_PATH,
    TEST_DESCRIPTOR,
    batch_size=8,
    embedding_size=128,
    input_size=352)
sampler = Sampler(
    sampling_step=50,
    min_blob_count=10,
    mode='grid',
    erode_grid='off')

In [18]:
full_preds = []
full_sam_preds = []
full_fusion_preds = []

for i in range(len(MODELS_SETS)):
    print(MODELS_SETS[i])
    pvt_path = MODELS_SETS[i][0]
    cafe_path = MODELS_SETS[i][1]

    pvt_config = read_json(os.path.join(pvt_path, 'config.json'))
    cafe_config = read_json(os.path.join(cafe_path, 'config.json'))

    pvt_model = get_model(
        BASE_WEIGHTS,
        os.path.join(pvt_path, f"best_{pvt_config['model_type']}.pth"),
        model_type='pvt')
    cafe_model = get_model(
        BASE_WEIGHTS,
        os.path.join(cafe_path, f"best_{cafe_config['model_type']}.pth"),
        model_type='cafe')
    predictor = predictor_selector(
        model_topology='SAMUS',
        checkpoint_path=[SAM_WEIGHTS, SAM_MODELS[i]],
        model_type='vit_b',
        device=DEVICE)
    
    metrics, preds = eval(test_ds, pvt_model, cafe_model)
    print('Seg. Metrics', metrics)
    sam_metrics, sam_preds = predict_sam(test_df, predictor, sampler, preds)
    print('SAM Metrics', sam_metrics)
    fusion_metrics, fusion_preds = fuse(preds, sam_preds)
    print('Fusion Metrics', fusion_metrics)

    full_preds.append(preds)
    full_sam_preds.append(sam_preds)
    full_fusion_preds.append(fusion_preds)
    gc.collect()


['E:\\UNIPD\\SAM/outputs/train/ribs_pvt_small_da1', 'E:\\UNIPD\\SAM/outputs/train/ribs_cafe_small_da1']


100%|██████████| 7/7 [01:09<00:00,  9.91s/it]


Seg. Metrics {'iou': 0.6204378660941883, 'dice': 0.7644125092063282, 'mae': 9.989186664486024, 'f-measure': 0.7488671886674982, 'e-measure': 0.901772510338703, 'fps': 0.8618318118397545, 'latency': 1.160319201800288, 'latency_p_batch': 8.122234412602015}


100%|██████████| 49/49 [02:21<00:00,  2.89s/it]


SAM Metrics {'iou': 0.6360637084708394, 'dice': 0.7768735194441058, 'mae': 9.048432445946698, 'f-measure': 0.7549317705489784, 'e-measure': 0.9046942609827254, 'fps': 0.3515353106050143, 'latency': 2.844664447161616, 'latency_p_batch': 2.844664447161616}
Fusion Metrics {'iou': 0.6409931903576157, 'dice': 0.7800227726055444, 'mae': 9.739050742670903, 'f-measure': 0.7668036046733301, 'e-measure': 0.9080097325914086}
['E:\\UNIPD\\SAM/outputs/train/ribs_pvt_da1', 'E:\\UNIPD\\SAM/outputs/train/ribs_cafe_da1']


100%|██████████| 7/7 [00:56<00:00,  8.12s/it]


Seg. Metrics {'iou': 0.7461357452480621, 'dice': 0.8538623229118802, 'mae': 6.362168843005437, 'f-measure': 0.8455649588581431, 'e-measure': 0.9447117200033565, 'fps': 1.0827218201380575, 'latency': 0.9235982700270049, 'latency_p_batch': 6.465187890189035}


100%|██████████| 49/49 [02:34<00:00,  3.15s/it]


SAM Metrics {'iou': 0.7080126979930932, 'dice': 0.8284067812642778, 'mae': 7.338039628605223, 'f-measure': 0.8162090164048258, 'e-measure': 0.9322124411575289, 'fps': 0.3220353704954169, 'latency': 3.1052489621298656, 'latency_p_batch': 3.1052489621298656}
Fusion Metrics {'iou': 0.7558047330521354, 'dice': 0.8602138686596486, 'mae': 6.411769390678653, 'f-measure': 0.8540952186698397, 'e-measure': 0.9464095806998338}


### Late Ensemble Model Fusion

In [39]:
# da1 + da2
small_da1_preds = full_preds[0]
small_da2_preds = full_preds[2]
small_fus_preds = {
    _k: {
        **small_da1_preds[_k],
        'logits': (
            small_da1_preds[_k]['logits'] + small_da2_preds[_k]['logits']
        ) / 2
    }
    for _k in small_da1_preds
}

small_da1_sam_preds = full_sam_preds[0]
small_da2_sam_preds = full_sam_preds[2]
small_fus_sam_preds = {
    _k: {
        **small_da1_sam_preds[_k],
        'logits': (
            small_da1_sam_preds[_k]['logits']
                + small_da2_sam_preds[_k]['logits']) / 2
    }
    for _k in small_da1_sam_preds
}


In [20]:
fusion_metrics, small_da3_preds = fuse(small_fus_preds, small_da1_sam_preds)
print('Fusion Metrics', fusion_metrics)

Fusion Metrics {'iou': 0.6602243498209209, 'dice': 0.7943975631104319, 'mae': 9.385327547340767, 'f-measure': 0.7836134115752975, 'e-measure': 0.9162794015060504}


In [40]:
# da1 + da2
da1_preds = full_preds[1]
da2_preds = full_preds[3]
fus_preds = {
    _k: {
        **da1_preds[_k],
        'logits': (
            da1_preds[_k]['logits'] + da2_preds[_k]['logits']
        ) / 2
    }
    for _k in da1_preds
}

da1_sam_preds = full_sam_preds[1]
da2_sam_preds = full_sam_preds[3]
fus_sam_preds = {
    _k: {
        **da1_sam_preds[_k],
        'logits': (
            da1_sam_preds[_k]['logits']
                + da2_sam_preds[_k]['logits']) / 2
    }
    for _k in da1_sam_preds
}

In [22]:
fusion_metrics, da3_preds = fuse(fus_preds, fus_sam_preds)
print('Fusion Metrics', fusion_metrics)

Fusion Metrics {'iou': 0.7686429309002708, 'dice': 0.8684520701086881, 'mae': 5.111098781019195, 'f-measure': 0.8530498327533637, 'e-measure': 0.9467031925486953}


In [None]:
samus_path = f'{BASE_PATH}/outputs/train/v1_1_0_small/best_SAMUS.pth'
predictor = predictor_selector(
    model_topology='SAMUS',
    checkpoint_path=[SAM_WEIGHTS, samus_path],
    model_type='vit_b',
    device=DEVICE)

In [23]:
sam_metrics, small_sam_da3_preds = predict_sam(
    test_df, predictor, sampler, small_fus_preds)
print('SAM Metrics', sam_metrics)

49it [02:52,  3.51s/it]

SAM Metrics {'iou': 0.6358890626159537, 'dice': 0.7767126719795495, 'mae': 8.973565291054921, 'f-measure': 0.7539693235311709, 'e-measure': 0.9040754384652328, 'fps': 0.28949259309765196, 'latency': 3.4543198128135835, 'latency_p_batch': 3.4543198128135835}





In [25]:
fusion_metrics, fusion_preds = fuse(small_fus_preds, small_sam_da3_preds)
print('Fusion Metrics', fusion_metrics)

Fusion Metrics {'iou': 0.6725449702235992, 'dice': 0.8033278654421616, 'mae': 8.306923982501816, 'f-measure': 0.7861679291655427, 'e-measure': 0.9173022302621321}


In [27]:
sam_metrics, sam_da3_preds = predict_sam(
    test_df, predictor, sampler, fus_preds)
print('SAM Metrics', sam_metrics)

100%|██████████| 49/49 [02:31<00:00,  3.10s/it]

SAM Metrics {'iou': 0.6378849821719765, 'dice': 0.7782654476356223, 'mae': 9.104353290168705, 'f-measure': 0.7574649069311447, 'e-measure': 0.9068881372491581, 'fps': 0.3277945676507628, 'latency': 3.0506911910310084, 'latency_p_batch': 3.0506911910310084}





In [28]:
fusion_metrics, fusion_preds = fuse(fus_preds, sam_da3_preds)
print('Fusion Metrics', fusion_metrics)

Fusion Metrics {'iou': 0.7705527010067807, 'dice': 0.8697553929427443, 'mae': 5.600976977547058, 'f-measure': 0.8595721182579256, 'e-measure': 0.9488832734141}


In [54]:
# SAMUS 1+2+3
fus_sam_preds = {
    _k: {
        **da1_sam_preds[_k],
        'logits': np.mean([
            da1_sam_preds[_k]['logits'],
            da2_sam_preds[_k]['logits'],
            sam_da3_preds[_k]['logits']
        ], axis=0)
    }
    for _k in da1_sam_preds
}
small_fus_sam_preds = {
    _k: {
        **small_da1_sam_preds[_k],
        'logits': np.mean([
            small_da1_sam_preds[_k]['logits'],
            small_da2_sam_preds[_k]['logits'],
            small_sam_da3_preds[_k]['logits']
        ], axis=0)
    }
    for _k in small_da1_preds
}

In [55]:
fusion_metrics, fusion_preds = fuse(small_fus_preds, small_fus_sam_preds)
print('Fusion Metrics', fusion_metrics)

Fusion Metrics {'iou': 0.6689124031586137, 'dice': 0.8007171744662946, 'mae': 8.47706125985157, 'f-measure': 0.7840170173855323, 'e-measure': 0.9165229807920677}
