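This notebook computes the PraNet weighted cross-entropy loss on the test split of the polyp dataset for three kinds of models: the pre-trained model, the cross-entropy fine-tuned model, and the conformal risk training (e2e-crc) models. Losses are computed once per seed and then summarized by their mean and standard deviation.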

In [None]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# reloaded before each cell runs and edits to project code are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

# Move the working directory up one level. The relative paths used below
# ('polyps/...', 'out/polyps/...') are presumably resolved from the
# repository root -- TODO confirm.
%cd ../

In [None]:
import itertools
import json

import pandas as pd
import torch.utils.data
from tqdm.auto import tqdm

from polyps import pranet_utils
from polyps.dataloader import get_loaders
from polyps.PraNet_Res2Net import PraNet

# Type alias for a device specifier: either a device string (e.g. 'cuda:2')
# or a torch.device object.
Device = str | torch.device

In [None]:
# Checkpoint loaded into PraNet as the pre-trained weights.
CKPT_PATH = 'polyps/PraNet-19.pth'
# Seeds iterated over by every evaluation loop and result lookup below.
SEEDS = range(10)

In [None]:
def get_pranet_loss(
    model: PraNet,
    loader: torch.utils.data.DataLoader,
    device: Device,
) -> float:
    """Return the per-sample average PraNet loss of `model` over `loader`.

    Every batch loss returned by ``pranet_utils.pranet_loss`` is weighted by
    its batch size before averaging (presumably the batch loss is a mean over
    the batch -- TODO confirm against ``pranet_utils``), so a smaller final
    batch does not skew the result.

    The average is taken over the samples actually iterated rather than over
    ``len(loader.dataset)``, so it stays correct when the loader drops the
    last batch or draws from a subset of the dataset.

    Gradients are disabled for the whole pass. The model's train/eval mode is
    left untouched.
    NOTE(review): confirm that the caller (or ``pranet_utils.pranet_loss``)
    puts the model in the intended mode before evaluation.

    Args:
        model: Network handed to ``pranet_utils.pranet_loss``.
        loader: Iterable yielding ``(images, masks)`` batches.
        device: Device each batch is moved to before computing the loss.

    Returns:
        Sum of batch-size-weighted batch losses divided by the number of
        samples seen.

    Raises:
        ValueError: If `loader` yields no samples, since the average is
            undefined in that case.
    """
    weighted_loss_sum = 0.
    num_samples = 0
    with torch.no_grad():
        for images, masks in loader:
            images = images.to(device, non_blocking=True)
            masks = masks.to(device=device, dtype=torch.float32, non_blocking=True)
            batch_size = len(images)
            batch_loss = pranet_utils.pranet_loss(model, images, masks).item()
            weighted_loss_sum += batch_loss * batch_size
            num_samples += batch_size

    if num_samples == 0:
        raise ValueError('loader yielded no samples; cannot average the PraNet loss')
    return weighted_loss_sum / num_samples

In [None]:
# Device on which the model is placed and to which evaluation batches are moved.
device = 'cuda:2'

model = PraNet()
# Deserialize onto the CPU explicitly: without map_location, torch.load
# restores each tensor to the device it was saved from, which may be a GPU
# that is missing or already occupied on this machine. The model is moved to
# `device` right after the weights are loaded.
model.load_state_dict(torch.load(CKPT_PATH, map_location='cpu', weights_only=True))
model.to(device=device)

# PraNet loss of the pre-trained checkpoint on the 'test' loader, one entry per seed.
pranet_losses = []
for seed in tqdm(SEEDS):
    loaders = get_loaders(splits=('test',), batch_size=64, seed=seed)
    pranet_loss = get_pranet_loss(model, loader=loaders['test'], device=device)
    pranet_losses.append(pranet_loss)

In [None]:
# Results table indexed by seed; its first column holds the losses of the
# pre-trained model computed above.
pranet_loss_df = pd.DataFrame({'pretrain': pranet_losses}, index=pd.Index(SEEDS, name='seed'))

In [None]:
# Result files under out/polyps/trainbase are looked up for every
# (seed, learning rate) pair; `alphas` is only used to sanity-check their contents.
alphas = [0.01, 0.05, 0.1]
lrs = [1e-2, 1e-3, 1e-4, 1e-5, 1e-6]

rows = []
for s, lr in itertools.product(SEEDS, lrs):
    basename = f'lr{lr:.2g}_s{s}'

    # Only opening the file can raise FileNotFoundError; a missing run is
    # reported and skipped.
    try:
        with open(f'out/polyps/trainbase/{basename}.json') as fp:
            run_info = json.load(fp)
    except FileNotFoundError:
        print(f'File not found: {basename}.json')
        continue

    # The stored metadata must agree with the run encoded in the file name.
    assert run_info['seed'] == s
    assert run_info['lr'] == lr
    assert set(alphas).issubset(run_info['alphas'])

    rows.append(dict(
        seed=s,
        lr=lr,
        val_loss=run_info['val_loss'],
        ckpt_path=f'out/polyps/trainbase/{basename}.pt',
    ))

df_trainbase = pd.DataFrame(rows)

# For each seed, pick the row whose validation loss is lowest.
min_val_loss_idx = df_trainbase.groupby('seed')['val_loss'].idxmin()
best_hps = min_val_loss_idx.values.tolist()
df_trainbase_best = df_trainbase.loc[best_hps].reset_index().set_index('seed')
display(df_trainbase_best)

In [None]:
# PraNet loss on the 'test' loader of the selected trainbase checkpoint for each seed.
pranet_losses = []
for seed in tqdm(SEEDS):
    ckpt_path = df_trainbase_best.loc[seed, 'ckpt_path']
    # map_location sends the weights straight to the evaluation device;
    # without it torch.load restores tensors to whichever device they were
    # saved from, which may be a different (or unavailable) GPU.
    model.load_state_dict(torch.load(ckpt_path, map_location=device, weights_only=True))
    loaders = get_loaders(splits=('test',), batch_size=64, seed=seed)
    pranet_loss = get_pranet_loss(model, loader=loaders['test'], device=device)
    pranet_losses.append(pranet_loss)

pranet_loss_df['cross-entropy'] = pranet_losses

In [None]:
# Result files under out/polyps/e2ecrc are looked up for every
# (alpha, seed, learning rate) combination.
alphas = [0.01, 0.05, 0.1]
lrs = [1e-2, 1e-3, 1e-4, 1e-5, 1e-6]

rows = []
for a, s, lr in itertools.product(alphas, SEEDS, lrs):
    basename = f'a{a:.2f}_lr{lr:.2g}_s{s}'

    # Only opening the file can raise FileNotFoundError; a missing run is
    # reported and skipped.
    try:
        with open(f'out/polyps/e2ecrc/{basename}.json') as fp:
            run_info = json.load(fp)
    except FileNotFoundError:
        print(f'File not found: {basename}.json')
        continue

    # The stored metadata must agree with the run encoded in the file name.
    assert run_info['seed'] == s
    assert run_info['lr'] == lr
    assert run_info['alpha'] == a

    rows.append(dict(
        alpha=a,
        seed=s,
        lr=lr,
        val_fpr=run_info['val_fpr'],
        ckpt_path=f'out/polyps/e2ecrc/{basename}.pt',
    ))

df_e2ecrc = pd.DataFrame(rows)

# For each (alpha, seed) pair, pick the row whose 'val_fpr' is lowest.
min_val_fpr_idx = df_e2ecrc.groupby(['alpha', 'seed'])['val_fpr'].idxmin()
best_hps = min_val_fpr_idx.values.tolist()
df_e2ecrc_best = df_e2ecrc.loc[best_hps].reset_index().set_index(['alpha', 'seed'])
display(df_e2ecrc_best)

In [None]:
# For each alpha, evaluate the selected e2ecrc checkpoint of every seed on the
# 'test' loader and store the losses as one column of the results table.
for a in alphas:
    pranet_losses = []
    for seed in tqdm(SEEDS):
        ckpt_path = df_e2ecrc_best.loc[(a, seed), 'ckpt_path']
        # map_location sends the weights straight to the evaluation device;
        # without it torch.load restores tensors to whichever device they
        # were saved from, which may be a different (or unavailable) GPU.
        model.load_state_dict(torch.load(ckpt_path, map_location=device, weights_only=True))
        loaders = get_loaders(splits=('test',), batch_size=64, seed=seed)
        pranet_loss = get_pranet_loss(model, loader=loaders['test'], device=device)
        pranet_losses.append(pranet_loss)

    pranet_loss_df[f'e2e-crc α{a:.2f}'] = pranet_losses

In [None]:
# Persist the results table (the seed index is written as the first CSV column).
pranet_loss_df.to_csv('out/polyps/pranet_losses.csv')

In [None]:
# Read the saved results table back, using the first CSV column as the index.
pranet_loss_df = pd.read_csv('out/polyps/pranet_losses.csv', index_col=0)

In [None]:
# Show the full results table.
display(pranet_loss_df)

In [None]:
# Show the mean and standard deviation of each column over the table's rows.
display(pranet_loss_df.agg(['mean', 'std']))