In [1]:
from careamics import CAREamist
from datasets import load_split_datasets, load_datasets_yml, iter_tiff_batch
import os
import matplotlib.pyplot as plt
import numpy as np
import tifffile

In [None]:
BASE_FOLDER = 'models/run4/'
MODEL_TYPES = ['n2v', 'n2v2']
N_VAL_IMAGES = 5  # number of validation images predicted (and later plotted) per dataset

# One record per dataset:
# [{'name': str, 'input': ndarray, 'predictions': [{'model_name': str, 'prediction': ndarray}, ...]}, ...]
imgs_to_show = []

for dset in load_datasets_yml():

    print(dset['name'])
    # Only the validation split is used below; train split and normalisation stats are unused here.
    (train, val), (train_mean, train_std), (val_mean, val_std) = load_split_datasets(dset['name'])
    val_images = val[:N_VAL_IMAGES]

    pred_record = {'name': dset['name'], 'input': val_images, 'predictions': []}

    for model_type in MODEL_TYPES:
        model_name = f'{model_type}_{dset["name"]}_chwise'
        model_path = os.path.join(BASE_FOLDER, model_name, 'checkpoints', 'last.ckpt')
        has_checkpoint = os.path.isfile(model_path)
        print(model_path, has_checkpoint)

        if not has_checkpoint:
            # Keep a blank placeholder instead of letting CAREamist raise and abort the
            # whole loop (losing every record collected so far); mirrors the
            # prediction-failure fallback below.
            predictions = np.zeros_like(val_images)
            print(f'Missing checkpoint {model_path}, using blank prediction')
        else:
            model = CAREamist(model_path)
            try:
                predictions = model.predict(val_images)
            except RuntimeError:
                # Best-effort: a failed model shows up as an all-zero panel in the plots.
                predictions = np.zeros_like(val_images)
                print(f'Failed to predict with {model_path}')
        pred_record['predictions'].append({'model_name': model_name, 'prediction': predictions})

    imgs_to_show.append(pred_record)

        

In [None]:
def _as_channel_first(img):
    """Return ``img`` with a leading channel axis: (H, W) -> (1, H, W); other shapes are returned unchanged."""
    return img[None, ...] if img.ndim == 2 else img


def visualize_predictions(prediction_list, figsize=(50, 10), dpi=500):
    """
        Plot each input image next to every model's prediction, one figure per dataset.

        Prediction list is a list of dictionaries with keys 'input' and 'predictions'
        (and an optional 'name', shown in the figure title instead of the list index).
        Each dict is a record for multiple inputs coming from the same dataset.

        'input' is an ndarray of shape (N, C, H, W) if C > 1, else (N, H, W)
        'predictions' is a list of dictionaries with keys 'model_name' and 'prediction'.
        'model_name' is a string, 'prediction' is an ndarray of shape (N, C, H, W) if C > 1, else (N, H, W)

        Layout: each figure holds N side-by-side subfigures; each subfigure is a grid with
        one row per channel and columns [input, model_1, ..., model_k].

        figsize, dpi: forwarded to ``plt.figure``. The defaults give a very large raster
        (50x10 in at 500 dpi); lower them for quick previews.

        Returns the list of created figures, one per dataset.
    """
    figures = []
    for d, dataset in enumerate(prediction_list):
        fig = plt.figure(figsize=figsize, layout='constrained', dpi=dpi)
        fig.suptitle(f'Dataset: {dataset.get("name", d)}')

        input_ = dataset['input']
        predictions = dataset['predictions']

        # squeeze=False keeps a 2-D array even when N == 1; the default would return a
        # bare SubFigure, which cannot be iterated.
        subfigs = fig.subfigures(1, input_.shape[0], squeeze=False)[0]
        for i, sf in enumerate(subfigs):
            input_img = _as_channel_first(input_[i])
            # Single-channel predictions are (N, H, W) like the input, so they need the
            # same channel axis before being indexed with [c].
            pred_imgs = [_as_channel_first(pred['prediction'][i]) for pred in predictions]
            n_channels = input_img.shape[0]
            # squeeze=False: axs is always (n_channels, 1 + n_models), including the
            # single-channel and no-prediction cases.
            axs = sf.subplots(n_channels, len(predictions) + 1, squeeze=False)
            for c in range(n_channels):
                axs[c, 0].imshow(input_img[c, ...], cmap='gray')
                axs[c, 0].set_title(f'Input {i} , Channel {c}')
                axs[c, 0].axis('off')
                for j, (pred, pred_img) in enumerate(zip(predictions, pred_imgs), start=1):
                    axs[c, j].imshow(pred_img[c, ...], cmap='gray')
                    axs[c, j].set_title(pred['model_name'].split('_')[0])
                    axs[c, j].axis('off')
        fig.show()
        figures.append(fig)
    return figures





# Render one figure per dataset record collected in `imgs_to_show`, then flush them.
visualize_predictions(prediction_list=imgs_to_show)
plt.show()