In [None]:
import joblib
import torch
import tqdm
import matplotlib.pyplot as plt
import numpy as np
import pathlib

from data import Channel as Ch, KelpNCDataset
import torch_simple_unet as unet
from torchmetrics import Dice
import shared

In [None]:
def get_val_score(ckpt):
    """Return the best model score recorded by a checkpoint's ModelCheckpoint callback.

    Parameters
    ----------
    ckpt : mapping
        Loaded checkpoint. Must contain a ``"callbacks"`` mapping; the first
        key of that mapping containing the substring ``"ModelCheckpoint"`` is
        used, and its entry must hold a ``"best_model_score"`` value that
        supports ``.cpu()`` (e.g. a scalar tensor).

    Returns
    -------
    float
        The stored ``best_model_score`` as a Python float.

    Raises
    ------
    KeyError
        If ``"callbacks"`` is missing or no callback key contains
        ``"ModelCheckpoint"``.
    """
    callbacks = ckpt["callbacks"]
    # First matching key wins, mirroring a linear scan with early exit.
    cb_key = next((k for k in callbacks if "ModelCheckpoint" in k), None)
    if cb_key is None:
        # Fail with an explicit message instead of an opaque ``KeyError: None``.
        raise KeyError(
            "no 'ModelCheckpoint' callback state found in checkpoint; "
            f"available callback keys: {list(callbacks)}"
        )
    return float(callbacks[cb_key]["best_model_score"].cpu())

In [None]:
# Directory containing the segmentation ensemble's checkpoints and predictions.
seg_dir = pathlib.Path("ens_seg/20240219_163535")
# Sorted so the score order is deterministic across runs.
seg_ckpt_files = sorted(seg_dir.glob("*.ckpt"))
# Only a scalar from the callback state is needed, so map every stored tensor
# to the CPU: this works on machines without a GPU and avoids placing whole
# model checkpoints in GPU memory just to read one number.
seg_val_scores = np.array([
    get_val_score(torch.load(ckpt_path, map_location="cpu"))
    for ckpt_path in tqdm.tqdm(seg_ckpt_files)
])

In [None]:
# Show the best_model_score read from each checkpoint (one entry per *.ckpt file).
seg_val_scores

In [None]:
# Load the stored ensemble results: per-member scores, the stacked member
# predictions, and the channel list saved alongside them.
scores, y_hat_seg_ens, use_ch_list = joblib.load(seg_dir / "pred_seg.joblib")
# Inspect the predictions that were just loaded. The previous `y_hat_ens` is
# not defined at this point on a fresh kernel.
y_hat_seg_ens.shape

In [None]:
# Member weights based on score.
# If no test scores were stored (first entry is None), fall back to the
# scores read from the checkpoints.
if scores[0] is None:
    print("Using val scores for weighting")
    # Copy so that normalising the weights never modifies seg_val_scores.
    w = seg_val_scores.copy()
else:
    print("Using test scores for weighting")
    w = np.array([s[0]["test_dice"] for s in scores])

# Normalise out of place so the weights sum to 1; this also avoids the
# in-place true-division error if the scores happen to be integer-typed.
w = w / w.sum()
w

In [None]:
# Build the dataset splits in segmentation mode and keep only the third one,
# then load it.
_, _, ds_test = shared.get_dataset(
    use_channels=None,
    tile_seed=1337,
    split_seed=shared.GLOBAL_SEED,
    mode="seg",
)
ds_test.load()

In [None]:
# Dataset length divided by the tile sampler's n_tiles.
# NOTE(review): presumably the number of source images in the test split — confirm.
len(ds_test) / ds_test.tile_sampler.n_tiles

In [None]:
# Pull the IS_LAND channel out of every test item and stack the results into
# one tensor (first axis indexes the dataset items).
lsm = torch.stack(
    [features[Ch.IS_LAND] for features, _ in tqdm.tqdm(ds_test)]
)
lsm.shape

In [None]:
# Postproc predictions: Force land values to zero
# `lsm` holds the IS_LAND channel per item; assumes it is a 0/1 mask so that
# `1 - lsm` is 1 over sea and 0 over land — TODO confirm.
is_sea = 1 - lsm
# Relies on broadcasting of `is_sea` against the ensemble predictions.
# Note this rebinds y_hat_seg_ens to the masked predictions.
y_hat_seg_ens = y_hat_seg_ens * is_sea

In [None]:
# Shape of the land-masked ensemble predictions.
y_hat_seg_ens.shape

In [None]:
def assemble_tiles(tiles, orig_size, inds):
    """Accumulate tiles onto a square canvas and count per-pixel overlaps.

    Parameters
    ----------
    tiles : array-like of shape (n_tiles, tile_h, tile_w)
        Tile values to place on the canvas.
    orig_size : int
        Side length of the square output canvas.
    inds : iterable of (i, j)
        Top-left (row, column) position of each tile, in the same order as
        ``tiles``. Every tile must fit entirely inside the canvas.

    Returns
    -------
    count : numpy.ndarray of shape (orig_size, orig_size), integer dtype
        How many tiles covered each pixel.
    array : numpy.ndarray of shape (orig_size, orig_size), float
        Sum of tile values at each pixel. Divide by ``count`` (where
        non-zero) to obtain the per-pixel mean.
    """
    _, tile_h, tile_w = tiles.shape
    # A wide integer counter: uint8 would silently wrap around once more than
    # 255 tiles overlap the same pixel.
    count = np.zeros((orig_size, orig_size), dtype=np.int64)
    array = np.zeros((orig_size, orig_size))
    for tile, (i, j) in zip(tiles, inds):
        count[i:i + tile_h, j:j + tile_w] += 1
        array[i:i + tile_h, j:j + tile_w] += tile
    return count, array

In [None]:
# Side length (pixels) of the square canvas each group of tiles is assembled onto.
IMG_SIZE = 350

n = ds_test.tile_sampler.n_tiles
inds = ds_test.tile_sampler.inds_

# For every ensemble member, stitch each consecutive run of `n` tile
# predictions back into one IMG_SIZE x IMG_SIZE image, averaging where tiles overlap.
y_hat_ens_ass = []
for j in tqdm.trange(len(y_hat_seg_ens)):
    y_hat_ass = []
    y_hat_ass_cnt = []

    # Items i .. i+n-1 are treated as the tiles of one image.
    for i in range(0, len(ds_test), n):
        c, a = assemble_tiles(y_hat_seg_ens[j, i:i+n].numpy(), IMG_SIZE, inds)
        y_hat_ass.append(a)
        y_hat_ass_cnt.append(c)

    y_hat_ass = np.array(y_hat_ass)
    y_hat_ass_cnt = np.array(y_hat_ass_cnt)
    # Mean over overlapping tiles. Pixels not covered by any tile have a count
    # of 0; leave them at 0 instead of producing NaN from 0/0.
    y_hat_ass = np.divide(
        y_hat_ass,
        y_hat_ass_cnt,
        out=np.zeros_like(y_hat_ass),
        where=y_hat_ass_cnt > 0,
    )
    y_hat_ens_ass.append(y_hat_ass)

y_hat_ens_ass = np.array(y_hat_ens_ass)

In [None]:
# Shape of the reassembled predictions: first axis is the ensemble member,
# second the assembled image, followed by the two canvas axes.
y_hat_ens_ass.shape

In [None]:
# Weighted average of the member predictions (soft voting); `w` is broadcast
# over the image axes. The earlier thresholded assignment to y_hat_agg was
# overwritten immediately and never used, so it has been dropped.
y_hat_agg = (y_hat_ens_ass * w[:, None, None, None]).sum(axis=0)
y_hat_agg.shape

In [None]:
# Dice score of the aggregated ensemble prediction against the test masks.
dice_metric = Dice()
dice_metric(torch.tensor(y_hat_agg), torch.tensor(ds_test.masks.values))

In [None]:
# Dice score of each individual ensemble member.
# NOTE(review): `y_hat_ens` and `y` are not defined in any cell above this one;
# they are only assigned further down, so this cell depends on execution order.
# Also note that `n` is rebound here (it held the tile count earlier).
n, _, _, _ = y_hat_ens.shape  # fixed typo: was `.shapea` (AttributeError)
for i in range(n):
    d = Dice()(y_hat_ens[i], y)
    print(d)

In [None]:
# Dice score of the unweighted mean over the first axis of y_hat_ens.
# NOTE(review): `y_hat_ens` and `y` are not defined above this cell — confirm
# which predictions/targets are intended.
Dice()(y_hat_ens.mean(dim=0), y)

In [None]:
# Shape of the standard deviation taken over the first axis of y_hat_ens
# (presumably the spread across ensemble members — TODO confirm).
y_hat_ens.std(dim=0).shape

In [None]:
# Load the stored outputs from "pred_clf.joblib" (relative to the working
# directory), squash them through a sigmoid and binarise at 0.5.
y_hat_ens = torch.sigmoid(joblib.load("pred_clf.joblib")) > 0.5

y_hat_ens.shape

In [None]:
# Count the positive (True) predictions along the first axis and drop
# singleton dimensions.
# NOTE(review): presumably the number of ensemble members voting positive per
# sample — confirm the axis meaning.
y = y_hat_ens.sum(dim=0).squeeze()
y.shape

In [None]:
# Histogram of the per-sample counts in `y` with unit-width bins from 0 to 24,
# drawn on a logarithmic y-axis.
fig, ax = plt.subplots()
ax.hist(y, bins=np.arange(25))
ax.set_yscale("log")