Building likelihood maps for a set of representative images.

> TODO: use images from dataset and superimpose the bounding boxes on them

In [None]:
import retinoto_py as fovea

# Notebook-wide parameters (do_fovea=True, one image per batch).
args = fovea.Params(batch_size=1, do_fovea=True)
args  # display the parameters in the notebook output

In [None]:
import torch
from torchvision.io import read_image

In [None]:
# Map from label name to the column index used in the model's probability output.
label2idx = fovea.get_label_to_idx(args)
# Quick sanity check on a known label.
label2idx['impala']

In [None]:
# List every label whose name mentions "wolf" (case-insensitive match).
for label_name in label2idx:
    if 'wolf' not in label_name.lower():
        continue
    print(label_name)

In [None]:
# Load the model checkpoint trained on the 'bbox' dataset variant.
dataset = 'bbox'
checkpoint_name = f'32_fovea_model_name={args.model_name}_dataset={dataset}.pth'
model_filename = args.data_cache / checkpoint_name
model = fovea.load_model(args, model_filename=model_filename)
model_filename  # display the checkpoint path in the notebook output

In [None]:
from torchvision.transforms.functional import InterpolationMode, resize
import cmocean

# Parameters shared by the likelihood-map cells.
image_size_full = 512  # size passed to `resize` for the representative images
# Resolution of the position grid passed to fovea.get_positions.
# Other values tried: (54, 34) and (34, 21).
resolution = (21, 13)
size_ratio = .4        # passed to fovea.compute_likelihood_map
alpha = .8             # marker transparency in the scatter plots
s_max = 200            # marker area for a probability of 1 (area = proba * s_max)
cmap = cmocean.cm.haline

In [None]:
# Representative images paired with the label expected to be found in each.
# Commented-out entries are alternative candidates.
list_images = [('leopard.jpg',  'leopard'),
                # ('wolf.jpg',  'white_wolf'),
                ('frog.jpg',  'tree_frog'),
                # ('Hidden_snow_leopard.png',  'snow_leopard'),
                ('Hidden_leopard.png',  'leopard'),
                # ('Hidden_giraffe.png',  'girafe'),
                ('Hidden_tiger.png',  'tiger'),
                ('Hidden_owl.png',  'great_grey_owl'),
                ]

# list_images = []

for image_url, true_label in list_images:
    # Keep the first three channels (drops e.g. a PNG alpha channel) and scale to [0, 1].
    full_image = read_image('./images/' + image_url)[:3, :, :] / 255
    full_image = resize(full_image, image_size_full, interpolation=InterpolationMode.BILINEAR, antialias=True)

    # These depend only on the image, not on `do_fovea`: compute them once per image.
    pos_H, pos_W = fovea.get_positions(full_image.shape[1], full_image.shape[2], resolution=resolution)
    full_image_np = full_image.permute(1, 2, 0).numpy()  # (C, H, W) -> (H, W, C) for imshow
    label_idx = label2idx[true_label]

    for do_fovea in [True, False]:
        print('do_fovea = ', do_fovea)
        # Use a dedicated name so the notebook-wide `args` (batch_size=1) is not
        # overwritten: later cells build their data loaders from `args`.
        map_args = fovea.Params(do_mask=not do_fovea, do_fovea=do_fovea)
        probas = fovea.compute_likelihood_map(map_args, model, full_image, size_ratio=size_ratio).cpu()
        # Probability of the expected label at each sampled position.
        proba_label = probas[:, label_idx]
        idx_max = proba_label.argmax()

        fig, ax = fovea.plt.subplots()
        ax.imshow(full_image_np)
        # Marker colour and area both encode the probability of `true_label`.
        scatter = ax.scatter(pos_W, pos_H, s=proba_label.abs()*s_max, c=proba_label, alpha=alpha,
                             edgecolors='none', cmap=cmap, vmin=0, vmax=1)
        print(f'Max proba_label ={proba_label.max().item():.3f}')

        # Red star on the position with the highest probability.
        ax.scatter(pos_W[idx_max], pos_H[idx_max], s=proba_label[idx_max]*s_max, marker='*', c='red', alpha=alpha)
        fig.colorbar(scatter, ax=ax)

        fig.set_facecolor(color='white')
        fovea.plt.show()

In [None]:

for image_url, true_label in list_images:
    # Keep the first three channels (drops e.g. a PNG alpha channel) and scale to [0, 1].
    full_image = read_image('./images/' + image_url)[:3, :, :] / 255
    full_image = resize(full_image, image_size_full, interpolation=InterpolationMode.BILINEAR, antialias=True)

    # These depend only on the image, not on `do_fovea`: compute them once per image.
    pos_H, pos_W = fovea.get_positions(full_image.shape[1], full_image.shape[2], resolution=resolution)
    full_image_np = full_image.permute(1, 2, 0).numpy()  # (C, H, W) -> (H, W, C) for imshow
    label_idx = label2idx[true_label]

    for do_fovea in [True, False]:
        # Use a dedicated name so the notebook-wide `args` (batch_size=1) is not
        # overwritten: later cells build their data loaders from `args`.
        map_args = fovea.Params(do_mask=not do_fovea, do_fovea=do_fovea)
        probas = fovea.compute_likelihood_map(map_args, model, full_image, size_ratio=size_ratio).cpu()
        # Probability of the expected label at each sampled position.
        proba_label = probas[:, label_idx]
        idx_max = proba_label.argmax()

        print(50*'=-')
        print(f'{image_url=} contains a {true_label}?')
        print(50*'.-')
        fig, ax = fovea.plt.subplots()
        ax.imshow(full_image_np)
        # Marker colour and area both encode the probability of `true_label`.
        scatter = ax.scatter(pos_W, pos_H, s=proba_label.abs()*s_max, c=proba_label, alpha=alpha,
                             edgecolors='none', cmap=cmap, vmin=0, vmax=1)
        print(f'Max proba_label ={proba_label.max().item():.3f}')

        # Red star on the position with the highest probability.
        ax.scatter(pos_W[idx_max], pos_H[idx_max], s=proba_label[idx_max]*s_max, marker='*', c='red', alpha=alpha)
        fig.colorbar(scatter, ax=ax)

        fig.set_facecolor(color='white')
        fovea.plt.show()

In [None]:
# Number of validation images visualised by the dataset cells (loop stops at this index).
n_batch = 5

In [None]:
from torchvision.transforms import v2 as transforms

# Rebuild the parameters explicitly: earlier cells may have rebound `args`,
# and the per-sample indexing below relies on batch_size=1.
args = fovea.Params(do_fovea=True, batch_size=1)
dataset = 'full'
VAL_DATA_DIR = args.DATAROOT / f'Imagenet_{dataset}' / 'val'
val_dataset = fovea.get_dataset(args, VAL_DATA_DIR, do_full_preprocess=False)
val_loader = fovea.get_loader(args, val_dataset)

for i_batch, (image, true_idx) in enumerate(val_loader):
    if i_batch >= n_batch:
        break
    # With batch_size=1 the sample index equals i_batch -- this assumes the
    # loader does not shuffle (TODO confirm in fovea.get_loader).
    path = val_loader.dataset.imgs[i_batch][0]
    print("Image path:", path)
    image, true_idx = image.to(args.device), true_idx.to(args.device)
    image = image.squeeze(0)  # drop the batch dimension -> (C, H, W)

    # Reflect-pad the shorter side so the image becomes a max(H, W) square.
    # v2.Pad expects (left, top, right, bottom); odd differences put the extra
    # pixel on the right/bottom.
    # NOTE: 'reflect' requires each pad to be smaller than that dimension,
    # so very elongated images (aspect ratio around 3:1 or more) will fail here.
    height, width = image.shape[-2], image.shape[-1]
    crop_size = max(height, width)
    extra_w, extra_h = crop_size - width, crop_size - height
    padding = (extra_w // 2, extra_h // 2, extra_w - extra_w // 2, extra_h - extra_h // 2)
    transform = transforms.Compose([
        transforms.Pad(padding, padding_mode='reflect'),
        transforms.CenterCrop(crop_size),  # safeguard; the padded image is already square
    ])
    image = transform(image)

    fig, ax = fovea.plt.subplots()
    image_np = image.permute(1, 2, 0).cpu().numpy()  # (C, H, W) -> (H, W, C) for imshow
    ax.imshow(image_np)
    fig.set_facecolor(color='white')
    fovea.plt.show()


In [None]:
# from torchvision.transforms import v2 as transforms
# args = fovea.Params(do_fovea=True, batch_size=1)
# for dataset in fovea.params.all_datasets:
#     VAL_DATA_DIR = args.DATAROOT / f'Imagenet_{dataset}' / 'val'
#     val_dataset = fovea.get_dataset(args, VAL_DATA_DIR, do_full_preprocess=False)
#     val_loader = fovea.get_loader(args, val_dataset)
    
#     print('dataset=', dataset, ' from', VAL_DATA_DIR, 'dataset=', dataset)
#     print('dataset=', VAL_DATA_DIR)

#     for i_batch, (image, true_idx) in fovea.tqdm(enumerate(val_loader), total=n_batch):
#         print(50*'=-')
#         # print(f'{image_url=} contains a {true_label}?')
#         print(f'image contains a {true_idx} = {val_loader.dataset.idx2label[true_idx]}')
#         print(50*'.-')

#         if i_batch >= n_batch : break
#         image, true_idx = image.to(args.device), true_idx.to(args.device)
#         image = image.squeeze(0)
#         crop_size = min(image.shape[1], image.shape[2])
#         image = transforms.CenterCrop(crop_size)(image)

#         pos_H, pos_W, probas = fovea.compute_likelihood_map(args, model, image.squeeze(0), size_ratio=size_ratio, resolution=resolution)
#         probas = probas.cpu()
#         proba_label = probas[:, true_idx.cpu()]
#         idx_max = proba_label.argmax()



#         fig, ax = fovea.plt.subplots()
#         image_np = torch.movedim(image, (1, 2, 0), (0, 1, 2)).cpu().numpy()
#         ax.imshow(image_np)
#         scatter = ax.scatter(pos_W, pos_H, s=proba_label.abs()*s_max, c=proba_label, alpha=alpha, edgecolors='none', cmap=cmap, vmin=0, vmax=1)
#         print(f'Max proba_label ={max(proba_label).item():.3f}')

#         ax.scatter(pos_W[idx_max], pos_H[idx_max], s=proba_label[idx_max]*s_max, marker='*', c='red', alpha=alpha)
#         fig.colorbar(scatter, ax=ax)  # Add colorbar

#         fig.set_facecolor(color='white')
#         fovea.plt.show()


In [None]:
from torchvision.transforms import v2 as transforms
# from torchvision import transforms
args = fovea.Params(do_fovea=True, batch_size=1)
for dataset in fovea.params.all_datasets:
    VAL_DATA_DIR = args.DATAROOT / f'Imagenet_{dataset}' / 'val'
    val_dataset = fovea.get_dataset(args, VAL_DATA_DIR, do_full_preprocess=False)
    val_loader = fovea.get_loader(args, val_dataset)

    for i_batch, (image, true_idx) in fovea.tqdm(enumerate(val_loader), total=n_batch):
        if i_batch >= n_batch:
            break

        # batch_size=1, so the batch holds a single class index. A plain int works
        # whether idx2label is a list or an int-keyed dict, and keeps
        # `proba_label` one-dimensional.
        label_idx = int(true_idx.item())

        print(50*'=-')
        print(f'image contains a {true_idx} = {val_loader.dataset.idx2label[label_idx]}')
        print(50*'.-')

        image = image.to(args.device).squeeze(0)  # drop the batch dimension -> (C, H, W)

        # Reflect-pad the shorter side so the image becomes a max(H, W) square.
        # v2.Pad expects (left, top, right, bottom); odd differences put the extra
        # pixel on the right/bottom.
        height, width = image.shape[-2], image.shape[-1]
        crop_size = max(height, width)
        extra_w, extra_h = crop_size - width, crop_size - height
        padding = (extra_w // 2, extra_h // 2, extra_w - extra_w // 2, extra_h - extra_h // 2)
        transform = transforms.Compose([
            transforms.Pad(padding, padding_mode='reflect'),
            transforms.CenterCrop(crop_size),  # safeguard; the padded image is already square
        ])
        image = transform(image)

        pos_H, pos_W = fovea.get_positions(image.shape[1], image.shape[2], resolution=resolution)
        # The likelihood map must be computed on the current validation image.
        probas = fovea.compute_likelihood_map(args, model, image, size_ratio=size_ratio).cpu()
        proba_label = probas[:, label_idx]
        idx_max = proba_label.argmax()

        fig, ax = fovea.plt.subplots()
        image_np = image.permute(1, 2, 0).cpu().numpy()  # (C, H, W) -> (H, W, C) for imshow
        ax.imshow(image_np)
        # Marker colour and area both encode the probability of the true class.
        scatter = ax.scatter(pos_W, pos_H, s=proba_label.abs()*s_max, c=proba_label, alpha=alpha,
                             edgecolors='none', cmap=cmap, vmin=0, vmax=1)
        print(f'Max proba_label ={proba_label.max().item():.3f}')

        # Red star on the position with the highest probability.
        ax.scatter(pos_W[idx_max], pos_H[idx_max], s=proba_label[idx_max]*s_max, marker='*', c='red', alpha=alpha)
        fig.colorbar(scatter, ax=ax)

        fig.set_facecolor(color='white')
        fovea.plt.show()
