In [None]:
import lightning as L
import tqdm
import torch
import matplotlib.pyplot as plt
import pandas as pd

import torch_deeplabv3 as dlv3

In [None]:
# Restore the DeepLabV3 Lightning module from a saved checkpoint file.
checkpoint_path = "lightning_logs/version_45/checkpoints/epoch=19-step=840.ckpt"
model = dlv3.LitDeepLabV3.load_from_checkpoint(
    checkpoint_path, n_ch=15, ens_prediction=True
)
# Bare last expression so the notebook displays the module's repr.
model

In [None]:
# get_loaders' result is unpacked into three values; only the third is kept
# (as the test loader) and the first two are discarded.
_, _, test_loader = dlv3.get_loaders(kf_weighing=False, batch_size=32)

In [None]:
# Materialise the whole test set in memory: gather every (inputs, targets)
# batch from the loader, then join the batches along the first dimension.
X_batches, y_batches = [], []
for X_batch, y_batch in tqdm.tqdm(test_loader):
    X_batches.append(X_batch)
    y_batches.append(y_batch)
X = torch.concat(X_batches)
y = torch.concat(y_batches)
X.shape, y.shape

In [None]:
# Single-device trainer with logging disabled.
trainer = L.Trainer(devices=1, logger=False)
# Run the test loop for the restored model on the unmodified test loader.
trainer.test(model, test_loader)

# Permutation feature importance (PFI)

In [None]:
def shuffle_channel(data, channel, generator=None):
    """Return a copy of ``data`` with one channel shuffled within each sample.

    For every sample along the first dimension, the values of
    ``data[i, channel]`` are randomly permuted among themselves. All other
    channels are left untouched and the input tensor is never modified.

    Args:
        data: Tensor indexed as ``data[sample, channel, ...]``.
        channel: Index of the channel to shuffle.
        generator: Optional ``torch.Generator`` used to draw the permutations,
            so a run can be reproduced. When ``None`` (the default) the global
            torch RNG is used, exactly as before.

    Returns:
        A new tensor with the same shape and dtype as ``data``.
    """
    # Clone the data to avoid modifying the original dataset
    shuffled_data = data.clone()
    # Draw an independent permutation for each sample in the batch
    for i in range(shuffled_data.shape[0]):
        plane = shuffled_data[i, channel]
        perm = torch.randperm(plane.nelement(), generator=generator)
        # Indexing with `perm` produces a new tensor, so writing the result
        # back into the same slot cannot read already-overwritten values.
        shuffled_data[i, channel] = plane.flatten()[perm].view(plane.shape)

    return shuffled_data

In [None]:
def get_Xy_loader(X, y, **kwargs):
    """Wrap in-memory samples (and optional targets) in a ``DataLoader``.

    Args:
        X: Iterable of input samples (e.g. a tensor, iterated along its first
            dimension).
        y: Iterable of targets paired element-wise with ``X``, or ``None``
            when no targets are available.
        **kwargs: Forwarded unchanged to ``torch.utils.data.DataLoader``.

    Returns:
        A ``DataLoader`` over ``(Xi, yi)`` pairs. When ``y`` is ``None`` every
        sample is paired with the dummy target ``0``.
    """
    if y is None:
        # The default collate function cannot batch ``None`` values, so an
        # integer placeholder is used (same convention as ``get_X_loader``).
        ds = [(Xi, 0) for Xi in X]
    else:
        ds = list(zip(X, y))
    return torch.utils.data.DataLoader(ds, **kwargs)

In [None]:
# Load the complete test set into memory (every batch, not just the first).
# For a quick single-batch run use instead: X, y = next(iter(test_loader))
test_batches = list(tqdm.tqdm(test_loader))
X = torch.concat([X_batch for X_batch, y_batch in test_batches])
y = torch.concat([y_batch for X_batch, y_batch in test_batches])
X.shape, y.shape

In [None]:
# Show sample 0, channel 0 of the collected test inputs.
plt.imshow(X[0, 0])

In [None]:
# Sanity check: shuffle channel 1 of the first two samples only.
X_shuf = shuffle_channel(X[:2], 1)

In [None]:
# Show sample 0, channel 1 after shuffling.
plt.imshow(X_shuf[0, 1])

In [None]:
# Single-device trainer with logging disabled, used by the test runs that follow.
trainer = L.Trainer(devices=1, logger=False)

In [None]:
# Reference metrics on the unshuffled data, served through the same kind of
# in-memory loader that the permuted runs use.
baseline_loader = get_Xy_loader(X, y, batch_size=32)
baseline_results = trainer.test(model, baseline_loader)
# trainer.test returns a list; keep its first entry.
score_baseline = baseline_results[0]
score_baseline

In [None]:
# Re-evaluate the model once per input channel, each time with that channel shuffled.
scores_perm = [
    trainer.test(
        model, get_Xy_loader(shuffle_channel(X, ch), y, batch_size=32)
    )[0]
    for ch in range(model.n_ch)
]

# Express every metric as a ratio to its baseline value (updates the dicts in place).
for ch_scores in scores_perm:
    for metric, baseline_value in score_baseline.items():
        ch_scores[metric] /= baseline_value

fi_df = pd.DataFrame(scores_perm)
# Relative drop in dice: a larger drop means a more important channel.
fi_df["test_dice"] = 1 - fi_df["test_dice"]
# Relative increase in test_nz_std: more uncertainty means a more important channel.
fi_df["test_nz_std"] = fi_df["test_nz_std"] - 1

In [None]:
# Bar chart of the relative dice drop; the bar index is the shuffled channel index.
fi_df["test_dice"].plot.bar()

In [None]:
# Bar chart of the relative test_nz_std increase; the bar index is the shuffled channel index.
fi_df["test_nz_std"].plot.bar()

# Ensemble

In [None]:
# Fresh single-device trainer (no logger) for inference over the test loader.
trainer = L.Trainer(devices=1, logger=False)
pred_batches = trainer.predict(model, test_loader)
# Join the per-batch predictions into one tensor.
y_hat = torch.concat(pred_batches)

In [None]:
# Show the prediction at index 3.
plt.imshow(y_hat[3])

In [None]:
# Show the target at index 3.
plt.imshow(y[3])

In [None]:
# Build the four quarter-turn rotations (0, 1, 2 and 3 turns) of sample 3,
# rotating in the plane spanned by dims 1 and 2.
Xi = X[3]
Xrot = [torch.rot90(Xi, k=n_turns, dims=[1, 2]) for n_turns in range(4)]

# Show channel 0 of each rotated copy in a 2x2 grid.
fig, axarr = plt.subplots(nrows=2, ncols=2)
panels = axarr.ravel()
for idx in range(len(panels)):
    panels[idx].imshow(Xrot[idx][0])

In [None]:
def get_X_loader(X, **kwargs):
    """Wrap unlabeled samples in a ``DataLoader``.

    Each sample is paired with the dummy target ``0`` so that batches keep an
    ``(inputs, targets)`` structure.

    Args:
        X: Iterable of input samples.
        **kwargs: Optional settings forwarded to
            ``torch.utils.data.DataLoader`` (e.g. ``batch_size``). When none
            are given the DataLoader defaults apply, exactly as before.

    Returns:
        A ``DataLoader`` over ``(Xi, 0)`` pairs.
    """
    ds = [(Xi, 0) for Xi in X]
    return torch.utils.data.DataLoader(ds, **kwargs)

In [None]:
# Predict on each rotated copy and cast to float (std/mean are taken on it in later cells).
rot_pred_batches = trainer.predict(model, get_X_loader(Xrot))
y_hat = torch.concat(rot_pred_batches).float()

# Undo each rotation: entry i was rotated by i quarter turns, so turn it back by -i.
y_hat = torch.stack(
    [torch.rot90(pred, k=-n_turns) for n_turns, pred in enumerate(y_hat)]
)

y_hat.shape

In [None]:
# Show the first four entries of y_hat in a 2x2 grid.
fig, axarr = plt.subplots(nrows=2, ncols=2)
panels = axarr.ravel()
for idx in range(len(panels)):
    panels[idx].imshow(y_hat[idx])

In [None]:
# Element-wise standard deviation across the first dimension of y_hat.
plt.imshow(y_hat.std(axis=0))

In [None]:
# Element-wise mean across the first dimension of y_hat.
plt.imshow(y_hat.mean(axis=0))