In [None]:
%load_ext autoreload
%autoreload 2
!pip install numpy matplotlib tqdm
!pip install --upgrade pip

import numpy as np
from pathlib import Path
from tqdm import tqdm
import matplotlib.pyplot as plt


path = Path('/mnt/e/PycharmProjects/Colorization/rawData/results/results.npz')
if not path.is_file():
    raise FileNotFoundError(f'File {path} not found')
path_to_save = Path('figs/scheme/')
path_to_save.mkdir(parents=True, exist_ok=True)
data = {k:v for k,v in np.load(path).items()}
print('Loaded data from', path)
print('Available keys:', list(data.keys()))
print('Number of samples:', len(data['mono']))
print('Shape of samples:', data['mono'][0].shape)

# Display some samples

### PAN

In [None]:
# crop = 1

# indices = [63,99]
# # indices = np.arange(0, len(data['pan']))
# # indices = np.random.choice(len(data['pan']), 30, replace=False)
# print('Selected indices:', indices)
# for idx in sorted(indices):
#     pan = data['pan'][idx][crop:-crop, crop:-crop]
#     plt.figure(frameon=False)
#     plt.imshow(pan, cmap='gray', interpolation='none')
#     # plt.axis('off')
#     plt.title(f'Panchromatic image {idx}')
#     plt.tight_layout(pad=0)
#     plt.show()
#     # plt.savefig(path_to_save / f'h_maker_pan_{idx}.png', bbox_inches='tight', pad_inches=0)


# b=a


### MONO

In [None]:
# Border (px) trimmed by the disabled preview below. Later cells reassign
# `crop` (to 35) before slicing, so this value only matters if the preview
# is re-enabled.
crop = 40

# NOTE(review): np.random.choice below is unseeded, so the previewed indices
# change on every run; seed it (e.g. np.random.default_rng(SEED)) if re-enabled.
# indices = np.random.choice(len(data['mono']), 10, replace=False)
# print('Selected indices:', indices)
# for idx in indices:
#     mono = data['mono'][idx][crop:-crop, crop:-crop]
#     pan = data['pan'][idx][crop:-crop, crop:-crop]
#     mask = data['mask'][idx]
#     plt.figure(figsize=(15, 5))
#     plt.subplot(131)
#     plt.imshow(mono)
#     plt.title('Mono')
#     plt.subplot(132)
#     plt.imshow(pan)
#     plt.title('Pan')
#     plt.subplot(133)
#     plt.imshow(mask)
#     plt.title('Mask')
#     plt.show()



### Mono vs. Pan vs. Petit prediction (samples ranked by L2 loss)

In [None]:
crop = 35  # border (px) trimmed before scoring; must be > 0 for [crop:-crop]


def _l2_error(reference: np.ndarray, prediction: np.ndarray) -> float:
    """Euclidean norm of the pixel-wise difference divided by the pixel count.

    Both images are cast to float64 before subtracting so unsigned-integer
    inputs cannot wrap around. The difference is flattened so that
    ``np.linalg.norm(..., ord=2)`` is the vector 2-norm; on a 2-D array that
    ``ord`` would instead return the largest singular value (spectral norm).

    Args:
        reference: ground-truth image.
        prediction: predicted image, same shape as ``reference``.

    Returns:
        ||reference - prediction||_2 / reference.size as a Python float.
    """
    ref = np.asarray(reference, dtype=np.float64)
    diff = ref - np.asarray(prediction, dtype=np.float64)
    return float(np.linalg.norm(diff.ravel(), ord=2) / ref.size)


losses = []
for idx in tqdm(range(len(data['mono']))):
    mono = data['mono'][idx][crop:-crop, crop:-crop]
    petit = data['pred_petit'][idx][crop:-crop, crop:-crop]
    losses.append(_l2_error(mono, petit))


In [None]:
cmap = 'gray'
n_show = 50  # how many samples to display
show_diff = False  # True -> also plot the signed Mono - Petit residual per sample

# np.argsort is ascending, so these are the samples with the smallest L2 loss.
indices = np.argsort(losses)[:n_show]
print('Selected indices:', indices)
for idx in indices:
    mono = data['mono'][idx][crop:-crop, crop:-crop]
    pan = data['pan_original'][idx][crop:-crop, crop:-crop]
    petit = data['pred_petit'][idx][crop:-crop, crop:-crop]
    loss_l1 = data['l1_petit'][idx]
    loss_l2 = losses[idx]

    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    for ax, image, title in zip(axes, (mono, pan, petit), ('Mono', 'Pan', 'Petit')):
        ax.imshow(image, cmap=cmap)
        ax.set_title(title)
    # Metrics go in the figure-level title; a second per-axes title call
    # would overwrite the 'Petit' panel title.
    fig.suptitle(f"Index {idx}, l1 {loss_l1:.2f}, l2 {loss_l2:.2f}")
    plt.show()
    plt.close(fig)  # avoid accumulating n_show open figures on non-inline backends

    if show_diff:
        # Cast before subtracting so integer images cannot wrap around.
        residual = mono.astype(np.float64) - petit
        # Symmetric limits keep zero error at the neutral centre of 'bwr'.
        limit = float(np.abs(residual).max()) or 1.0
        fig, ax = plt.subplots()
        im = ax.imshow(residual, cmap='bwr', vmin=-limit, vmax=limit)
        fig.colorbar(im, ax=ax)
        ax.set_title(f"Index {idx}, Mono - Petit, l1 {loss_l1:.2f}, l2 {loss_l2:.2f}")
        plt.show()
        plt.close(fig)



In [None]:
# Hand-written 3x3 homography used only to illustrate the schematic figure.
# In the usual homogeneous-coordinate convention the upper-left 2x2 is a
# slight scale/shear, the last column a (12, -26) translation and the bottom
# row a tiny perspective term.
h_scheme = np.array(
    [1.02, -1e-3, 12,
     0.01, 0.98, -26,
     1e-5, -1e-6, 1]
).reshape(3, 3)
h_scheme

In [None]:
# Render (and optionally save) the single-image panels used in the schematic.
idx = 78  # sample chosen for the schematic figure
crop = 35  # border (px) trimmed from the Petit prediction panel only
save_figs = False  # True -> also write PNGs into `path_to_save`
panels = ['petit']  # any subset of: 'mono', 'warped', 'pan', 'mask', 'petit'

mono = data['mono'][idx]
warped = data['pan'][idx]
pan = data['pan_original'][idx]
petit = data['pred_petit'][idx]

# Zero pixels of the warped pan are treated as "no data" (they form the mask
# panel); using the smallest positive value as vmin keeps them from squashing
# the contrast of the valid region. Guarded so an all-zero image doesn't crash.
warped_valid = warped[warped > 0]
warped_kwargs = {'vmin': warped_valid.min()} if warped_valid.size else {}

# name -> (image, extra imshow kwargs); `name` also forms the output filename.
panel_specs = {
    'mono': (mono, {}),
    'warped': (warped, warped_kwargs),
    'pan': (pan, {}),
    'mask': (warped == 0, {}),
    'petit': (petit[crop:-crop, crop:-crop], {}),
}

for name in panels:
    image, extra = panel_specs[name]
    fig, ax = plt.subplots(frameon=False)
    ax.imshow(image, cmap='gray', interpolation='none', **extra)
    ax.axis('off')
    fig.tight_layout(pad=0)
    if save_figs:
        # Save *before* plt.show(): the inline backend closes figures on show,
        # so a later savefig would write an empty image.
        fig.savefig(path_to_save / f'scheme_data_{name}_{idx}.png',
                    bbox_inches='tight', pad_inches=0)
    plt.show()
    plt.close(fig)

