In [None]:
import keras
import random
import numpy as np
from pathlib import Path
from PIL import Image
from utils import DataGenerator
import matplotlib.pyplot as plt

# Repository root. Assumes the kernel's working directory is exactly one level
# below it (e.g. a notebooks/ folder); adjust if the notebook is run elsewhere.
working_dir = Path.cwd()
root = working_dir.parent

In [None]:
# Reload the trained model checkpoint and recompile it with an explicit
# mean-squared-error loss, reusing the optimizer restored from the file.
checkpoint_path = root / 'models' / 'MAE_ep108_loss0.0043.keras'
model = keras.saving.load_model(checkpoint_path)
loss = keras.losses.MeanSquaredError(reduction="sum_over_batch_size")
model.compile(optimizer=model.optimizer, loss=loss)

In [None]:
# Sanity check: loss of the reloaded model on the validation crops.
validation_dataset = DataGenerator(
    root / 'data' / 'validation' / 'crop' / 'npz',
    128,  # presumably the batch size; TODO confirm against utils.DataGenerator
)
model.evaluate(validation_dataset)

In [None]:
def to_pil(img, normalize: bool = False):
    """Convert a numeric array into an 8-bit PIL image.

    Parameters
    ----------
    img : array-like
        Image data. Singleton dimensions are squeezed away before conversion,
        so e.g. an (H, W, 1) array becomes a single-channel (H, W) image.
    normalize : bool, default False
        If True, linearly rescale ``img`` so its minimum maps to 0 and its
        maximum to 255 (a constant image maps to all zeros). If False, values
        are expected to already be on a 0-255 scale; anything outside that
        range is clamped.

    Returns
    -------
    PIL.Image.Image

    Raises
    ------
    AssertionError
        If ``img`` contains NaN.
    """
    assert not np.isnan(img).any(), 'NAN'

    pixels = np.asarray(img, dtype=np.float64)
    if normalize:
        lo, hi = pixels.min(), pixels.max()
        span = hi - lo
        if span > 0:
            pixels = (pixels - lo) * (255.0 / span)
        else:
            # A constant image has no contrast to stretch (and dividing by a
            # zero span would produce NaN/inf), so render it black.
            pixels = np.zeros_like(pixels)

    # NumPy leaves float -> uint8 casts of out-of-range values undefined
    # (wraparound or garbage), and integers > 255 would wrap silently, so
    # clamp to the representable range before casting. In-range values are
    # truncated exactly as a plain astype('uint8') would.
    pixels = np.clip(pixels, 0, 255)
    return Image.fromarray(pixels.squeeze().astype('uint8'))

def npz_to_png(dataset, folder: Path, focal_planes=('000', '040', '080', '120', '160', '200')):
    """Export every sample of ``dataset`` as a folder of 8-bit PNG images.

    For each sample, a sub-folder of ``folder`` named after the sample's file
    (final extension removed) is created. It contains ``gt.png`` for the target
    and one ``<label>.png`` per entry of ``focal_planes``, taken from the
    matching channel (``stack[:, :, i]``) of the input stack.

    Parameters
    ----------
    dataset : utils.DataGenerator
        Must expose ``batch_size``, ``filenames`` and ``__getitem__``.
        NOTE(review): assumes that with ``batch_size = None``, ``dataset[0]``
        returns the whole dataset as one ``(inputs, targets)`` batch, in the
        same order as ``dataset.filenames``. TODO confirm in ``utils``.
    folder : Path
        Output root; created if missing.
    focal_planes : sequence of str
        Labels of the input channels to export, in channel order. Defaults to
        the six labels this notebook was written for.

    Raises
    ------
    ValueError
        If inputs, targets and filenames do not have the same length.
    """
    # Temporarily load everything as a single batch, then restore the caller's
    # setting so the dataset object is left as we found it.
    previous_batch_size = getattr(dataset, 'batch_size', None)
    dataset.batch_size = None
    try:
        x, y = dataset[0]
    finally:
        dataset.batch_size = previous_batch_size

    filenames = list(dataset.filenames)
    # zip() would silently drop trailing samples on a mismatch.
    if not (len(x) == len(y) == len(filenames)):
        raise ValueError(
            f'length mismatch: {len(x)} inputs, {len(y)} targets, '
            f'{len(filenames)} filenames'
        )

    folder = Path(folder)
    for stack, gt, filename in zip(x, y, filenames):
        # with_suffix('') strips only the final extension, so names containing
        # dots (e.g. 'scene.0.5m.npz') get their own folder instead of being
        # truncated at the first dot and colliding with other samples.
        sample_folder = folder / Path(filename).with_suffix('')
        sample_folder.mkdir(exist_ok=True, parents=True)
        to_pil(gt.squeeze()).save(sample_folder / 'gt.png')
        for i, focal in enumerate(focal_planes):
            to_pil(stack[:, :, i]).save(sample_folder / f'{focal}.png')

In [None]:
# Export every test case (one sub-folder of npz files per case) to PNGs.
path = root / 'data' / 'test' / 'cases' / 'npz'
# Only sub-directories are cases: skip stray files (e.g. .DS_Store) and
# process cases in a stable, sorted order.
for p in sorted(entry for entry in path.iterdir() if entry.is_dir()):
    target = path.parent / 'images' / p.name
    target.mkdir(exist_ok=True, parents=True)
    npz_to_png(DataGenerator(p), target)

In [None]:
# Export the full original test set (the source of the per-case subsets) too.
original_test_dir = root / 'data' / 'test' / 'original'
npz_to_png(DataGenerator(original_test_dir / 'npz'), original_test_dir / 'images')

In [None]:
# create the filtered datasets
# Build the per-case test subsets (one folder per tree density x pose) from the
# original test set. Disabled by default; set to True to (re)generate them.
CREATE_FILTERED_DATASETS = False
FILTERED_DATASETS_SEED = 0

if CREATE_FILTERED_DATASETS:
    import shutil

    # NOTE(review): assumes DataGenerator(shuffle=True) draws from the global
    # `random` / `np.random` state. TODO confirm in utils. If so, seeding both
    # makes the 45-file subsets reproducible across runs.
    random.seed(FILTERED_DATASETS_SEED)
    np.random.seed(FILTERED_DATASETS_SEED)

    datasets = dict()
    trees = [0, 100, 200]
    poses = ['no_person', 'idle', 'sitting', 'laying']
    source_dir = root / 'data' / 'test' / 'original' / 'npz'
    cases_dir = root / 'data' / 'test' / 'cases' / 'npz'
    for tree in trees:
        for pose in poses:
            case_name = f'{tree}_trees_{pose}'
            dataset = DataGenerator(
                basedir=source_dir,
                included_trees=[tree],
                included_poses=[pose],
                only_use_n=45,
                shuffle=True,
            )
            datasets[case_name] = dataset
            target = cases_dir / case_name
            # Never mix a new random subset into an existing case folder:
            # copying on top of old files could leave more than `only_use_n`.
            if target.exists() and any(target.iterdir()):
                raise FileExistsError(f'{target} is not empty; remove it before regenerating')
            target.mkdir(exist_ok=True, parents=True)
            print(case_name, len(dataset.filenames))
            for filename in dataset.filenames:
                shutil.copyfile(dataset.basedir / filename, target / filename)

In [None]:
# Evaluate the model separately on every test case; loss keyed by case name.
path = root / 'data' / 'test' / 'cases' / 'npz'
performance = dict()
# Only sub-directories are cases: skip stray files and use a stable order.
for p in sorted(entry for entry in path.iterdir() if entry.is_dir()):
    # 45 matches `only_use_n=45` used when the case folders were built
    # (presumably the batch size, so each case is one batch; TODO confirm).
    dataset = DataGenerator(p, 45)
    # Use the full folder name: `.stem` would drop anything after a dot.
    performance[p.name] = model.evaluate(dataset)

In [None]:
# Per-case test loss, keyed by case-folder name (filled by the evaluation loop above).
performance

In [None]:
# Plot test loss against tree density, one line per pose, plus the overall mean.
POSES = ['no_person', 'idle', 'sitting', 'laying']

# Case names have the form '<n>_trees_<pose>'. Parse them instead of
# substring-matching the pose, so one pose name can never pick up another
# pose's cases. Malformed names are ignored.
loss_by_pose = {pose: {} for pose in POSES}
for case_name, case_loss in performance.items():
    n_trees, sep, pose = case_name.partition('_trees_')
    if sep and pose in loss_by_pose and n_trees.isdigit():
        loss_by_pose[pose][int(n_trees)] = case_loss

# Derive the x-axis from the data rather than hard-coding [0, 100, 200].
tree_counts = sorted({n for losses in loss_by_pose.values() for n in losses})
average_loss = np.mean(list(performance.values()))

fig, ax = plt.subplots(figsize=(6, 6))
ax.plot(tree_counts, [average_loss] * len(tree_counts), linestyle='--', color='red',
        label='average', alpha=0.5, linewidth=2)
for pose, losses in loss_by_pose.items():
    if not losses:
        continue
    xs = sorted(losses)
    ax.plot(xs, [losses[n] for n in xs], marker='o', alpha=0.5, linewidth=2, label=pose)

ax.set_xticks(tree_counts)
ax.set_xlabel('number of trees per ha')
ax.set_ylabel('test loss')
ax.set_title('Test loss by pose')
ax.legend()
fig.tight_layout()
fig.savefig('test-loss-by-trees-and-pose.png')