# Setting

In [None]:
import glob
import fastremap
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import pickle
from PIL import Image
import pingouin
import random
import re
from scipy import ndimage
import scipy.stats

from sklearn.metrics import cohen_kappa_score, confusion_matrix, ConfusionMatrixDisplay
from tqdm.autonotebook import tqdm
from typing import Dict, List, Optional, Sequence, Tuple, Union
import wandb

from omegaconf import DictConfig, OmegaConf
from hydra import initialize, initialize_config_module, initialize_config_dir, compose
from hydra.utils import instantiate

import torch
from monai.utils import set_determinism

from data.transforms import (
    get_longest_consecutive_positive, ConvertLabel, ConvertLabeld,
)
from utils.metrics import (
    psg_event_analysis, PSGDetectionMatrixMetric, PSGDetectionMatrixMetricV2, PSGAHIMetric, PSGAHIMetricV2, PSGAHIMetricV3,
    apply_valid_consecutive, count_valid_consecutive,
    apply_valid_sleep_event, get_matrix_sleep_event, get_matrix_sleep_event_V2,
    calculate_icc,
)

# Valid

## Raw predictions and best thresholds

In [None]:
# Root directory of the training run whose checkpoints are evaluated.
ckpt_base_dir = './runs/psgradar-seg-re-2311/2023-12-25_23-36-15'

# Only checkpoints whose file name contains this tag are kept.
parse_str = 'best_metric'

# Recursively collect every *.ckpt below the run directory (sorted), then keep
# the ones tagged with `parse_str`.
pretrained_ckpt_list = [
    ckpt_path
    for ckpt_path in sorted(glob.glob(os.path.join(ckpt_base_dir, '**/*.ckpt'), recursive=True))
    if parse_str in os.path.basename(ckpt_path)
]

# Number of selected checkpoints (cell output).
len(pretrained_ckpt_list)

In [None]:
# Everything produced by this notebook is written below ./results/<run folder name>/.
save_output_dir_basebase = './results'
save_output_dir_base = os.path.join(save_output_dir_basebase, os.path.basename(ckpt_base_dir))

# Folder for the raw (softmax) validation predictions, tagged with the checkpoint selector.
save_output_dir = os.path.join(save_output_dir_base, 'valid_raw_{}'.format(parse_str))
os.makedirs(save_output_dir, exist_ok=True)
print(save_output_dir)

In [None]:
# Validation pass: for every selected checkpoint, rebuild its model and data
# module from the run's Hydra overrides, write the per-case softmax output to
# `save_output_dir`, and record the threshold exposed by the model's metrics.
best_pred_thresholds = []
for pretrained_ckpt in tqdm(pretrained_ckpt_list):
    # The run directory is two path components above the checkpoint file.
    run_dir = os.sep.join(pretrained_ckpt.split(os.sep)[:-2])
    cfg_override_path = os.path.join(run_dir, 'config/overrides.yaml')

    # Start from the overrides stored with the run, then force validation-only settings.
    overrides = list(OmegaConf.load(cfg_override_path))
    overrides += [
        "paths.output_dir=temp/logs",
        "train=False",
        "valid=True",
        "model.inferer.sw_batch_size=32",
    ]

    with initialize(version_base=None, config_path="config"):
        cfg = compose(
            config_name="train",
            overrides=overrides,
            return_hydra_config=True,
        )

    if cfg.seed.seed:
        torch.manual_seed(cfg.seed.seed)
    model = instantiate(cfg.model, _recursive_ = False)

    model.load_pretrained(pretrained_ckpt)
    model.eval()

    device_type = 'cuda' if torch.cuda.is_available() else 'cpu'
    device = torch.device(device_type)
    model.to(device)

    if cfg.seed.seed:
        set_determinism(cfg.seed.seed)
    dm = instantiate(cfg.data, _recursive_ = False)

    dm.setup('valid')
    dl_valid = dm.val_dataloader()

    for batch in tqdm(dl_valid, total=len(dl_valid)):
        for k in batch.keys():
            # Move tensors only; other entries (e.g. the 'metadata' entry) stay as they are.
            # A failing transfer now raises instead of being silently ignored.
            if isinstance(batch[k], torch.Tensor):
                batch[k] = batch[k].to(device)

        with torch.autocast(device_type, enabled=(cfg.trainer.precision==16)):
            with torch.no_grad():
                model.set_input(batch)
                outputs = model.inferer(model.image, model.forward)
                # Softmax over dim 1 turns the network output into per-class scores.
                bin_outputs = torch.softmax(outputs, 1)

        # One array per sample, stored under the base name of its label file.
        tkeys = batch['metadata']['label_path']
        touts = bin_outputs.detach().cpu().numpy()
        for tk, to in zip(tkeys, touts):
            new_path = os.path.join(save_output_dir, os.path.basename(tk))
            np.save(new_path, to)

        # Update every metric; the `stage` attribute is passed as a mask when the model has one.
        for k in model.metrics.keys():
            if hasattr(model, 'stage'):
                model.metrics[k](bin_outputs.float(), model.label.float(), mask=model.stage.float())
            else:
                model.metrics[k](bin_outputs.float(), model.label.float())

    n_thresholds_before = len(best_pred_thresholds)
    for k in model.metrics.keys():
        if model.metrics[k].get_buffer() is not None:
            mean_metric = model.metrics[k].aggregate()
            if isinstance(mean_metric, list):
                # A list result is reported under the '__'-separated parts of the metric key.
                kks = k.split('__')
                for i in range(len(mean_metric)):
                    mmetric = mean_metric[i].item()
                    print(f'test_metrics/{kks[i]} = {mmetric}')                    
            else:
                mean_metric = mean_metric.item()
                print(f'test_metrics/{k} = {mean_metric}') 

        if hasattr(model.metrics[k], 'best_pred_threshold'):
            best_pred_thresholds.append(model.metrics[k].best_pred_threshold)

        model.metrics[k].reset()

    # The thresholds are later paired with `pretrained_ckpt_list` by position,
    # so each checkpoint must contribute exactly one entry.
    n_new_thresholds = len(best_pred_thresholds) - n_thresholds_before
    if n_new_thresholds != 1:
        raise RuntimeError(
            f'expected exactly one best_pred_threshold for {pretrained_ckpt}, got {n_new_thresholds}'
        )

In [None]:
# Pair every checkpoint path with the threshold recorded for it (same order).
data_thr = dict(
    ckpt=pretrained_ckpt_list,
    best_pred_threshold=best_pred_thresholds,
)

# Store the table next to the raw validation predictions.
thresholds_path = os.path.join(save_output_dir, 'thresholds.pkl')
with open(thresholds_path, 'wb') as f:
    pickle.dump(data_thr, f, protocol=pickle.HIGHEST_PROTOCOL)

# Tabular view of the same data (cell output).
pd.DataFrame(data=data_thr)

## Apply best thresholds to validation predictions

In [None]:
# Output layout for the thresholded validation predictions.
save_output_dir_basebase = './results'
save_output_dir_base = os.path.join(
    save_output_dir_basebase,
    '2023-12-25_23-36-15',
)

# Declared here (same value as in the other section setup cells) so this section
# does not depend on a variable left over from an earlier cell.
parse_str = 'best_metric'

# Raw softmax predictions are read from `save_output_dir_raw`;
# thresholded predictions are written to `save_output_dir`.
save_output_dir_raw = os.path.join(save_output_dir_base, f'valid_raw_{parse_str}')
save_output_dir = os.path.join(save_output_dir_base, f'valid_best_{parse_str}')
os.makedirs(save_output_dir, exist_ok=True)

In [None]:
# Reload the checkpoint/threshold table stored with the raw validation predictions.
# NOTE: pickle can execute arbitrary code on load; only open files written by this notebook.
thresholds_path = os.path.join(save_output_dir_raw, 'thresholds.pkl')
with open(thresholds_path, 'rb') as f:
    data_thr = pickle.load(f)

# Unpack the two columns into the names used by the following cells.
pretrained_ckpt_list, best_pred_thresholds = data_thr['ckpt'], data_thr['best_pred_threshold']
pd.DataFrame(data=data_thr)

In [None]:
def _binarize_with_threshold(raw_tx, threshold):
    """Turn a per-class score array into a binary, channel-first mask.

    Args:
        raw_tx: array whose first axis indexes the classes; channel 0 is treated
            as the complement of all remaining channels.
        threshold: cut-off applied to the sum of channels 1..end.

    Returns:
        uint8 array with the same first-axis length as ``raw_tx``. Channel 0 is
        1 where the summed score is below ``threshold``. With more than two
        channels, every connected above-threshold region is assigned entirely
        to the channel with the largest summed score inside that region.
    """
    event_scores = raw_tx[1:]
    bin_tx = (event_scores.sum(0, keepdims=True) >= threshold).astype('uint8')
    bin_tx_0 = 1 - bin_tx
    if raw_tx.shape[0] > 2:
        # Label connected above-threshold regions, then vote per region.
        region_map, n_regions = ndimage.label(bin_tx[0])
        bin_tx = np.zeros_like(event_scores).astype('uint8')
        for region_id in range(1, n_regions + 1):
            in_region = region_map == region_id
            winner = event_scores[:, in_region].sum(1).argmax()
            bin_tx[winner, in_region] = 1
    return np.concatenate([bin_tx_0, bin_tx], axis=0)


# For every checkpoint, walk its validation split and convert the stored raw
# predictions into thresholded masks using that checkpoint's own threshold.
for i, pretrained_ckpt in tqdm(enumerate(pretrained_ckpt_list), total=len(pretrained_ckpt_list)):
    # The run directory is two path components above the checkpoint file.
    run_dir = os.sep.join(pretrained_ckpt.split(os.sep)[:-2])
    cfg_override_path = os.path.join(run_dir, 'config/overrides.yaml')

    overrides = list(OmegaConf.load(cfg_override_path))
    overrides += [
        "paths.output_dir=temp/logs",
        "train=False",
        "valid=True",
        "model.inferer.sw_batch_size=32",
    ]

    with initialize(version_base=None, config_path="config"):
        cfg = compose(
            config_name="train",
            overrides=overrides,
            return_hydra_config=True,
        )

    # Seed torch and MONAI before the data module is built, if a seed is configured.
    if cfg.seed.seed:
        torch.manual_seed(cfg.seed.seed)
        set_determinism(cfg.seed.seed)
    dm = instantiate(cfg.data, _recursive_ = False)

    dm.setup('valid')
    dl_valid = dm.val_dataloader()

    # The dataloader is only used to enumerate the label paths of this split.
    for batch in tqdm(dl_valid, total=len(dl_valid)):
        tkeys = batch['metadata']['label_path']
        for tk in tkeys:
            raw_tx = np.load(os.path.join(save_output_dir_raw, os.path.basename(tk)))
            bin_tx = _binarize_with_threshold(raw_tx, best_pred_thresholds[i][0])
            new_path = os.path.join(save_output_dir, os.path.basename(tk))
            np.save(new_path, bin_tx)

## valid metrics 

In [None]:
# Result folders of the evaluated run.
save_output_dir_basebase = './results'
save_output_dir_base = os.path.join(save_output_dir_basebase, '2023-12-25_23-36-15')

# Checkpoint selector used in the folder names.
parse_str = 'best_metric'
save_output_dir_raw = os.path.join(save_output_dir_base, 'valid_raw_{}'.format(parse_str))
save_output_dir = os.path.join(save_output_dir_base, 'valid_best_{}'.format(parse_str))

# Per-case .npy files read by the metric cells below.
label_dir = './temp/data/230821_resample_data_v0/Event-re-2311'
stage_dir = './temp/data/230821_resample_data_v0/Stage'

In [None]:
# Every prediction file in the thresholded-output folder, in sorted order.
preds_files_list = glob.glob(os.path.join(save_output_dir, '*.npy'))
preds_files_list.sort()

# Case id = file name up to the first '.npy'.
cases_list = [os.path.basename(preds_path).partition('.npy')[0] for preds_path in preds_files_list]

In [None]:
# Evaluation settings.
use_HA = True
use_best = True          # True: thresholded predictions, False: raw softmax predictions
max_label = 6
iou_thr = 0.00001
use_thr = 0.5
threshold_consecutive_seconds=10
sampling_frequency=8
# Minimum run length in samples (seconds * samples per second).
threshold_consecutive = threshold_consecutive_seconds * sampling_frequency

det_metric_name = ['f1', 'precision', 'recall']
ahi_metric_name = ['ahimae', 'ahimape', 'eventmae', 'ahicor', 'ahiicc', 'osakappalinear']
sleep_postprocess_list = [False, True]

use_output_dir = save_output_dir if use_best else save_output_dir_raw

detmat = []    # one matrix per case, as returned by get_matrix_sleep_event_V2
gts_tst = []   # per-case count of stage samples > 0.5, converted to hours

for tcase in tqdm(cases_list):
    tpreds_file = os.path.join(use_output_dir, f'{tcase}.npy')
    tlabel_file = os.path.join(label_dir, f'{tcase}.npy')
    tstage_file = os.path.join(stage_dir, f'{tcase}.npy')
    tpreds = np.load(tpreds_file)
    tlabel = np.load(tlabel_file)
    tstage = np.load(tstage_file)
    # presumably marks the samples scored as sleep — TODO confirm stage encoding
    simage = tstage > 0.5

    # samples -> seconds -> hours; uses the configured rate instead of a literal 8
    gts_tst.append(simage.sum().item()/sampling_frequency/3600)

    det2 = get_matrix_sleep_event_V2(tpreds, tlabel, simage, threshold_consecutive=threshold_consecutive, iou_thr=iou_thr, max_label=max_label)
    detmat.append(det2)

detmat = np.array(detmat)
gts_tst = np.array(gts_tst)
# Counts per hour: every case's matrix divided by its own duration.
ahimat = detmat / gts_tst.reshape(-1, 1, 1)

# presumably axis 1 is the reference class and axis 2 the predicted class —
# TODO confirm against get_matrix_sleep_event_V2. Index 0 is excluded on the summed axis.
ahi = ahimat[:,1:].sum((1,2))
est_ahi = ahimat[:,:,1:].sum((1,2))
nev = detmat[:,1:].sum((1,2))
est_nev = detmat[:,:,1:].sum((1,2))
# NOTE(review): `api` sums rows 2-4 while `est_api` sums every column from 2 on;
# both cover the same classes only if the matrix has exactly five classes — confirm.
api = ahimat[:,np.array([2,3,4])].sum((1,2))
est_api = ahimat[:,:,1:].sum((1,2)) - ahimat[:,:,1:2].sum((1,2))

def ahi_to_osa(x):
    """Map an AHI value to an integer grade from 0 to 3.

    Cut-offs are inclusive lower bounds: >= 30 -> 3, >= 15 -> 2, >= 5 -> 1,
    anything else -> 0.
    """
    # Check the highest cut-off first so the first match is the final grade.
    for grade, lower_bound in ((3, 30), (2, 15), (1, 5)):
        if x >= lower_bound:
            return grade
    return 0

# Grade every case twice: from the reference index and from the estimated one.
osa_ahi = np.array([ahi_to_osa(value) for value in ahi])
osa_est_ahi = np.array([ahi_to_osa(value) for value in est_ahi])

In [None]:
# Report header: which predictions were scored and with which settings.
print(use_output_dir)
if use_best:
    print(f'iou_thr: {iou_thr}')
else:
    print(f'iou_thr: {iou_thr}, use_thr: {use_thr}')
print('')

# --- AHI agreement ---
met_ahimae = np.abs(est_ahi - ahi).mean()
print(f'AHI MAE = {met_ahimae}')
met_eventmae = np.abs(est_nev - nev).mean()
print(f'Event MAE = {met_eventmae}')
# Relative error with 1 added to the denominator, which keeps it non-zero for ahi == 0.
met_ahimape = np.abs((est_ahi - ahi)/(1 + ahi)).mean()
print(f'AHI MAPE = {met_ahimape}')
met_ahicor = torch.corrcoef(torch.stack([torch.tensor(est_ahi), torch.tensor(ahi)]))[0,1].item()
print(f'AHI COR = {met_ahicor}')
# Long-format table for pingouin: each case is rated by rater 1 (reference) and rater 2 (estimate).
n_samples = len(ahi)
icc_df = pd.DataFrame(data={
    'targets': [*np.arange(n_samples),]*2, 
    'ratings': ahi.tolist() + est_ahi.tolist(), 
    'raters': [1]*n_samples + [2]*n_samples,
})
icc_stat = pingouin.intraclass_corr(data=icc_df, targets='targets', raters='raters', ratings='ratings')
met_ahiicc = icc_stat[icc_stat['Type']=='ICC2']['ICC'].item()
print(f'AHI ICC = {met_ahiicc}')

# --- Same statistics for the `api` / `est_api` pair ---
met_apimae = np.abs(est_api - api).mean()
print(f'AI MAE = {met_apimae}')
met_apicor = torch.corrcoef(torch.stack([torch.tensor(est_api), torch.tensor(api)]))[0,1].item()
print(f'AI COR = {met_apicor}')
n_samples = len(api)
icc_df = pd.DataFrame(data={
    'targets': [*np.arange(n_samples),]*2, 
    'ratings': api.tolist() + est_api.tolist(), 
    'raters': [1]*n_samples + [2]*n_samples,
})
icc_stat = pingouin.intraclass_corr(data=icc_df, targets='targets', raters='raters', ratings='ratings')
met_apiicc = icc_stat[icc_stat['Type']=='ICC2']['ICC'].item()
print(f'AI ICC = {met_apiicc}')

# --- Severity grades ---
# Pass all four grades explicitly: otherwise a grade that never occurs shrinks the
# confusion matrix (which then no longer matches the four display labels) and
# changes the distance-based kappa weights.
osa_labels = [0, 1, 2, 3]
met_osakappalinear = cohen_kappa_score(osa_est_ahi, osa_ahi, labels=osa_labels, weights='linear')
print(f'OSA Kappa Linear = {met_osakappalinear}')

cm = confusion_matrix(osa_ahi, osa_est_ahi, labels=osa_labels)
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=['Normal', 'Mild', 'Moderate', 'Severe'],)
disp.plot()
disp.ax_.set_title('OSA');

# Test

## save results

In [None]:
# Output layout for the test split.
save_output_dir_basebase = './results'
save_output_dir_base = os.path.join(save_output_dir_basebase, '2023-12-25_23-36-15')

parse_str = 'best_metric'
save_output_dir_raw = os.path.join(save_output_dir_base, 'test_raw_{}'.format(parse_str))
save_output_dir = os.path.join(save_output_dir_base, 'test_best_{}'.format(parse_str))

label_dir = './temp/data/230821_resample_data_v0/Event-re-2311'
stage_dir = './temp/data/230821_resample_data_v0/Stage'

# Create the raw folder first, then the thresholded one.
for out_dir in (save_output_dir_raw, save_output_dir):
    os.makedirs(out_dir, exist_ok=True)
print(save_output_dir)

# The checkpoint/threshold table comes from the validation folder.
# NOTE: pickle can execute arbitrary code on load; only open files written by this notebook.
valid_raw_dir = os.path.join(save_output_dir_base, 'valid_raw_{}'.format(parse_str))
with open(os.path.join(valid_raw_dir, 'thresholds.pkl'), 'rb') as f:
    data_thr = pickle.load(f)
pretrained_ckpt_list, best_pred_thresholds = data_thr['ckpt'], data_thr['best_pred_threshold']

# Keep a copy of the table in both test output folders.
for out_dir in (save_output_dir_raw, save_output_dir):
    with open(os.path.join(out_dir, 'thresholds.pkl'), 'wb') as f:
        pickle.dump(data_thr, f, protocol=pickle.HIGHEST_PROTOCOL)

pd.DataFrame(data=data_thr)

In [None]:
# Test pass: every checkpoint predicts the test split. Per checkpoint index `i`
# the raw softmax output goes to <save_output_dir_raw>/<i>/ and the thresholded
# mask (using that checkpoint's own threshold) to <save_output_dir>/<i>/.
for i, pretrained_ckpt in tqdm(enumerate(pretrained_ckpt_list), total=len(pretrained_ckpt_list)):
    # The run directory is two path components above the checkpoint file.
    run_dir = os.sep.join(pretrained_ckpt.split(os.sep)[:-2])
    cfg_override_path = os.path.join(run_dir, 'config/overrides.yaml')

    # Start from the overrides stored with the run, then force test-only settings.
    overrides = list(OmegaConf.load(cfg_override_path))
    overrides += [
        "paths.output_dir=temp/logs",
        "train=False",
        "valid=False",
        "test=True",
        "model.inferer.sw_batch_size=32",
    ]

    with initialize(version_base=None, config_path="config"):
        cfg = compose(
            config_name="train",
            overrides=overrides,
            return_hydra_config=True,
        )

    if cfg.seed.seed:
        torch.manual_seed(cfg.seed.seed)
    model = instantiate(cfg.model, _recursive_ = False)

    model.load_pretrained(pretrained_ckpt)
    model.eval()

    device_type = 'cuda' if torch.cuda.is_available() else 'cpu'
    device = torch.device(device_type)
    model.to(device)

    if cfg.seed.seed:
        set_determinism(cfg.seed.seed)
    dm = instantiate(cfg.data, _recursive_ = False)

    dm.setup('test')
    dl_test = dm.test_dataloader()

    # Output folders depend only on the checkpoint index, so create them once here.
    new_raw_path_dir = os.path.join(save_output_dir_raw, str(i))
    os.makedirs(new_raw_path_dir, exist_ok=True)
    new_path_dir = os.path.join(save_output_dir, str(i))
    os.makedirs(new_path_dir, exist_ok=True)

    for batch in tqdm(dl_test, total=len(dl_test)):
        for k in batch.keys():
            # Move tensors only; other entries (e.g. the 'metadata' entry) stay as they are.
            # A failing transfer now raises instead of being silently ignored.
            if isinstance(batch[k], torch.Tensor):
                batch[k] = batch[k].to(device)

        with torch.autocast(device_type, enabled=(cfg.trainer.precision==16)):
            with torch.no_grad():
                model.set_input(batch)
                outputs = model.inferer(model.image, model.forward)
                # Softmax over dim 1 turns the network output into per-class scores.
                bin_outputs = torch.softmax(outputs, 1)

        tkeys = batch['metadata']['label_path']
        touts = bin_outputs.detach().cpu().numpy()
        for tk, to in zip(tkeys, touts):
            # Raw per-class scores, stored under the base name of the label file.
            new_raw_path = os.path.join(new_raw_path_dir, os.path.basename(tk))
            np.save(new_raw_path, to)

            raw_tx = to
            # Event mask: summed score of channels 1..end reaches the checkpoint's threshold.
            bin_tx = (raw_tx[1:].sum(0, keepdims=True) >= best_pred_thresholds[i][0]).astype('uint8')
            # Channel 0 is the complement of the event mask.
            bin_tx_0 = 1 - bin_tx
            if raw_tx.shape[0] > 2:
                # More than one event channel: assign each connected event region
                # entirely to the channel with the largest summed score inside it.
                _y_im, _nb_labels = ndimage.label(bin_tx[0])
                bin_tx = np.zeros_like(raw_tx[1:]).astype('uint8')
                for j in range(1, _nb_labels+1):
                    _cidx = _y_im == j
                    _argm = raw_tx[1:][:,_cidx].sum(1).argmax()
                    bin_tx[_argm, _cidx] = 1
            bin_tx = np.concatenate([bin_tx_0, bin_tx], axis=0)
            new_path = os.path.join(new_path_dir, os.path.basename(tk))
            np.save(new_path, bin_tx)

## ensemble metrics

In [None]:
# Result folders of the evaluated run.
save_output_dir_basebase = './results'
save_output_dir_base = os.path.join(save_output_dir_basebase, '2023-12-25_23-36-15')

parse_str = 'best_metric'
save_output_dir_raw = os.path.join(save_output_dir_base, 'test_raw_{}'.format(parse_str))
save_output_dir = os.path.join(save_output_dir_base, 'test_best_{}'.format(parse_str))

# One sub-directory per checkpoint, in sorted order; plain files are ignored.
save_output_dir_raw_folds = sorted(filter(os.path.isdir, glob.glob(os.path.join(save_output_dir_raw, '*'))))
save_output_dir_folds = sorted(filter(os.path.isdir, glob.glob(os.path.join(save_output_dir, '*'))))

# Destinations for the ensembled predictions (raw average / re-binarized).
save_ensemble_dir_raw = os.path.join(save_output_dir_base, 'ensemble_test_raw_{}'.format(parse_str))
save_ensemble_dir = os.path.join(save_output_dir_base, 'ensemble_test_{}'.format(parse_str))

label_dir = './temp/data/230821_resample_data_v0/Event-re-2311'
stage_dir = './temp/data/230821_resample_data_v0/Stage'

In [None]:
# save ensemble
# The case list is taken from the first per-checkpoint folder; every other
# folder is expected to contain a file of the same name for each case.
preds_files_list = sorted(glob.glob(os.path.join(save_output_dir_folds[0], '*.npy')))
cases_list = [os.path.basename(x).split('.npy')[0] for x in preds_files_list]

# Output folders do not depend on the case, so create them once.
os.makedirs(save_ensemble_dir_raw, exist_ok=True)
os.makedirs(save_ensemble_dir, exist_ok=True)

for tcase in tqdm(cases_list):
    # Raw ensemble: plain average of the per-checkpoint score arrays.
    folds_preds = [np.load(os.path.join(fold, f'{tcase}.npy')) for fold in save_output_dir_raw_folds]
    tpreds = np.mean(np.stack(folds_preds), axis=0)

    new_path = os.path.join(save_ensemble_dir_raw, f'{tcase}.npy')
    np.save(new_path, tpreds)

    # Binarized ensemble: average the per-checkpoint binary masks, then vote.
    folds_preds = [np.load(os.path.join(fold, f'{tcase}.npy')) for fold in save_output_dir_folds]
    tpreds_float = np.mean(np.stack(folds_preds), axis=0)

    # Channel 0 wins when at least half of the checkpoints voted for it. Using
    # `>=` here (the event test below stays `< 0.5`) makes the two cases
    # complementary; an exact 0.5 tie used to leave every channel at zero.
    is_background = tpreds_float[0] >= 0.5
    tpreds = np.zeros_like(tpreds_float)
    tpreds[0][is_background] = 1
    # Elsewhere, pick the event channel with the highest averaged vote.
    pos_tpreds = tpreds_float[1:].argmax(0)
    for i in range(1, tpreds.shape[0]):
        tpreds[i][~is_background & (pos_tpreds == (i-1))] = 1

    new_path = os.path.join(save_ensemble_dir, f'{tcase}.npy')
    np.save(new_path, tpreds)

In [None]:
# Every ensembled prediction file, in sorted order.
preds_files_list = glob.glob(os.path.join(save_ensemble_dir, '*.npy'))
preds_files_list.sort()

# Case id = file name up to the first '.npy'.
cases_list = [os.path.basename(preds_path).partition('.npy')[0] for preds_path in preds_files_list]

In [None]:
# Evaluation settings.
use_HA = True
use_best = True          # True: thresholded predictions, False: raw predictions
max_label = 6
iou_thr = 0.00001
use_thr = 0.5
threshold_consecutive_seconds=10
sampling_frequency=8
# Minimum run length in samples (seconds * samples per second).
threshold_consecutive = threshold_consecutive_seconds * sampling_frequency

det_metric_name = ['f1', 'precision', 'recall']
ahi_metric_name = ['ahimae', 'ahimape', 'eventmae', 'ahicor', 'ahiicc', 'osakappalinear']
sleep_postprocess_list = [False, True]

# None scores the ensemble; an integer scores that single per-checkpoint folder.
foldN = None
#foldN = 4
if foldN is None:
    use_output_dir = save_ensemble_dir if use_best else save_ensemble_dir_raw
else:
    use_output_dir = save_output_dir_folds[foldN] if use_best else save_output_dir_raw_folds[foldN]  


detmat = []    # one matrix per case, as returned by get_matrix_sleep_event_V2
gts_tst = []   # per-case count of stage samples > 0.5, converted to hours

for tcase in tqdm(cases_list):
    tpreds_file = os.path.join(use_output_dir, f'{tcase}.npy')
    tlabel_file = os.path.join(label_dir, f'{tcase}.npy')
    tstage_file = os.path.join(stage_dir, f'{tcase}.npy')
    tpreds = np.load(tpreds_file)
    tlabel = np.load(tlabel_file)
    tstage = np.load(tstage_file)
    # presumably marks the samples scored as sleep — TODO confirm stage encoding
    simage = tstage > 0.5

    # samples -> seconds -> hours; uses the configured rate instead of a literal 8
    gts_tst.append(simage.sum().item()/sampling_frequency/3600)

    det2 = get_matrix_sleep_event_V2(tpreds, tlabel, simage, threshold_consecutive=threshold_consecutive, iou_thr=iou_thr, max_label=max_label)
    detmat.append(det2)

detmat = np.array(detmat)
gts_tst = np.array(gts_tst)
# Counts per hour: every case's matrix divided by its own duration.
ahimat = detmat / gts_tst.reshape(-1, 1, 1)

# presumably axis 1 is the reference class and axis 2 the predicted class —
# TODO confirm against get_matrix_sleep_event_V2. Index 0 is excluded on the summed axis.
ahi = ahimat[:,1:].sum((1,2))
est_ahi = ahimat[:,:,1:].sum((1,2))
nev = detmat[:,1:].sum((1,2))
est_nev = detmat[:,:,1:].sum((1,2))
# NOTE(review): `api` sums rows 2-4 while `est_api` sums every column from 2 on;
# both cover the same classes only if the matrix has exactly five classes — confirm.
api = ahimat[:,np.array([2,3,4])].sum((1,2))
est_api = ahimat[:,:,1:].sum((1,2)) - ahimat[:,:,1:2].sum((1,2))

def ahi_to_osa(x):
    """Map an AHI value to an integer grade from 0 to 3.

    Cut-offs are inclusive lower bounds: >= 30 -> 3, >= 15 -> 2, >= 5 -> 1,
    anything else -> 0.
    """
    # Check the highest cut-off first so the first match is the final grade.
    for grade, lower_bound in ((3, 30), (2, 15), (1, 5)):
        if x >= lower_bound:
            return grade
    return 0

# Grade every case twice: from the reference index and from the estimated one.
osa_ahi = np.array([ahi_to_osa(value) for value in ahi])
osa_est_ahi = np.array([ahi_to_osa(value) for value in est_ahi])

In [None]:
# Report header: which predictions were scored and with which settings.
print(use_output_dir)
if use_best:
    print(f'iou_thr: {iou_thr}')
else:
    print(f'iou_thr: {iou_thr}, use_thr: {use_thr}')
print('')

# --- AHI agreement ---
met_ahimae = np.abs(est_ahi - ahi).mean()
print(f'AHI MAE = {met_ahimae}')
met_eventmae = np.abs(est_nev - nev).mean()
print(f'Event MAE = {met_eventmae}')
# Relative error with 1 added to the denominator, which keeps it non-zero for ahi == 0.
met_ahimape = np.abs((est_ahi - ahi)/(1 + ahi)).mean()
print(f'AHI MAPE = {met_ahimape}')
met_ahicor = torch.corrcoef(torch.stack([torch.tensor(est_ahi), torch.tensor(ahi)]))[0,1].item()
print(f'AHI COR = {met_ahicor}')
# Long-format table for pingouin: each case is rated by rater 1 (reference) and rater 2 (estimate).
n_samples = len(ahi)
icc_df = pd.DataFrame(data={
    'targets': [*np.arange(n_samples),]*2, 
    'ratings': ahi.tolist() + est_ahi.tolist(), 
    'raters': [1]*n_samples + [2]*n_samples,
})
icc_stat = pingouin.intraclass_corr(data=icc_df, targets='targets', raters='raters', ratings='ratings')
met_ahiicc = icc_stat[icc_stat['Type']=='ICC2']['ICC'].item()
print(f'AHI ICC = {met_ahiicc}')

# --- Same statistics for the `api` / `est_api` pair ---
met_apimae = np.abs(est_api - api).mean()
print(f'AI MAE = {met_apimae}')
met_apicor = torch.corrcoef(torch.stack([torch.tensor(est_api), torch.tensor(api)]))[0,1].item()
print(f'AI COR = {met_apicor}')
n_samples = len(api)
icc_df = pd.DataFrame(data={
    'targets': [*np.arange(n_samples),]*2, 
    'ratings': api.tolist() + est_api.tolist(), 
    'raters': [1]*n_samples + [2]*n_samples,
})
icc_stat = pingouin.intraclass_corr(data=icc_df, targets='targets', raters='raters', ratings='ratings')
met_apiicc = icc_stat[icc_stat['Type']=='ICC2']['ICC'].item()
print(f'AI ICC = {met_apiicc}')

# --- Severity grades ---
# Pass all four grades explicitly: otherwise a grade that never occurs shrinks the
# confusion matrix (which then no longer matches the four display labels) and
# changes the distance-based kappa weights.
osa_labels = [0, 1, 2, 3]
met_osakappalinear = cohen_kappa_score(osa_est_ahi, osa_ahi, labels=osa_labels, weights='linear')
print(f'OSA Kappa Linear = {met_osakappalinear}')

cm = confusion_matrix(osa_ahi, osa_est_ahi, labels=osa_labels)
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=['Normal', 'Mild', 'Moderate', 'Severe'],)
disp.plot()
disp.ax_.set_title('OSA');