In [1]:
from careamics import CAREamist
from datasets import load_split_datasets, load_datasets_yml
import os
import matplotlib.pyplot as plt
import numpy as np

In [None]:
def visualize_predictions(input_, predictions, sample=0, figsize=None, dpi=300):
    """Plot one input image next to each model's prediction, one row per channel.

    Parameters
    ----------
    input_ : array
        Input stack indexed as ``input_[sample, channel, ...]``; presumably
        (N, C, Y, X) — TODO confirm against callers.
    predictions : dict[str, array]
        Maps a label (used as the column title) to that model's prediction.
        Each prediction is indexed the same way as ``input_``.
    sample : int, optional
        Index along the first axis of the image to display (default 0).
    figsize : tuple[float, float], optional
        Figure size in inches. Defaults to 15 x 5 inches per channel row.
    dpi : int, optional
        Figure resolution (default 300).
    """
    n_channels = input_.shape[1]
    panels = [('Input', input_)] + list(predictions.items())
    if figsize is None:
        figsize = (15, 5 * n_channels)

    # Size/dpi are passed here: a separate plt.figure() call would create an
    # unused extra figure. squeeze=False keeps `axs` 2-D (rows x cols) even for
    # a single channel or an empty `predictions`.
    fig, axs = plt.subplots(
        n_channels, len(panels), figsize=figsize, dpi=dpi, squeeze=False
    )

    for ch in range(n_channels):
        for col, (title, images) in enumerate(panels):
            ax = axs[ch, col]
            ax.imshow(images[sample, ch, ...])
            # Label the channel explicitly only when there is more than one row.
            ax.set_title(title if n_channels == 1 else f'{title} (ch {ch})')
            ax.axis('off')
    plt.show()

In [None]:
BASE_FOLDER = 'models/run4test/checkpoints/'
MODEL_TYPES = ['n2v', 'n2v2']
N_VAL_IMAGES = 5  # number of validation images fed to each model

# For every dataset: load a few validation images, run each available
# checkpoint on them, and plot inputs vs. predictions side by side.
for dset in load_datasets_yml():
    name = dset['name']
    print(name)
    (_, val), _, _ = load_split_datasets(name)
    val_images = val[:N_VAL_IMAGES]
    if val_images.ndim == 3:
        # Insert a singleton axis at position 1 so downstream code can index
        # [sample, channel, ...]; presumably (N, Y, X) -> (N, 1, Y, X).
        val_images = val_images[:, None, ...]
    multichannel = (val_images.shape[1] != 1)
    print(val_images.shape, multichannel)

    # Checkpoint filename suffixes to try; the '' variant is only looked up
    # for multichannel datasets. Independent of model type, so computed once.
    train_modalities = ['indch'] + ([''] if multichannel else [])

    model_predictions = {}
    for model_type in MODEL_TYPES:
        for train_modality in train_modalities:
            model_path = os.path.join(BASE_FOLDER, f'{model_type}_{name}{train_modality}.ckpt')
            if not os.path.isfile(model_path):
                # Skip missing checkpoints instead of aborting the whole sweep.
                print(f'Missing checkpoint, skipping: {model_path}')
                continue

            model = CAREamist(model_path)
            try:
                # NOTE(review): assumes predict returns an array shaped like
                # val_images; some careamics versions may return a list — confirm.
                predictions = model.predict(val_images)
            except RuntimeError as err:
                # Best-effort: keep a blank panel so the failure is visible in the plot.
                predictions = np.zeros_like(val_images)
                print(f'Failed to predict with {model_path}: {err}')

            model_predictions[f'{name}_{model_type}_{train_modality}'] = predictions

    if model_predictions:
        visualize_predictions(val_images, model_predictions)
    else:
        print(f'No checkpoints found for {name}')
    
        