In [None]:
import numpy as np
import torch
import h5py

from pathlib import Path
from torch.utils.data import DataLoader
from torch.utils.data.sampler import WeightedRandomSampler
from torchvision.transforms import Compose
from src.harmonization.inet_pn1 import IntensityNet
from src.datasets.tools.lidar_dataset import LidarDatasetNP, LidarDatasetHDF5
from src.datasets.tools.transforms import CloudAngleNormalize
from src.datasets.tools.transforms import Corruption, GlobalShift, CloudJitter
from src.datasets.dublin.config import config as dataset_config
from src.training.config import config as train_config
from src.datasets.tools.dataloaders import get_transforms

# Single lookup combining the dataset and training configs; every later cell
# reads settings through config['dataset'][...] / config['train'][...].
config = dict(
    dataset=dataset_config,
    train=train_config,
)


In [None]:
# Load the trained IntensityNet checkpoint for the selected epoch and put it in eval mode.
results_path = Path(f"{config['train']['results_path']}{config['dataset']['use_ss_str']}{config['dataset']['shift_str']}")
n_size = config['train']['neighborhood_size']
epoch = "14"  # checkpoint epoch to evaluate

device = config['train']['device']
model = IntensityNet(
    n_size,
    interpolation_method="pointnet").double().to(device)

# Build the checkpoint path once so the printed path is exactly the one loaded.
model_path = results_path / f"{n_size}_epoch={epoch}.pt"
# map_location remaps stored tensors onto the configured device, so a checkpoint
# saved on another device (e.g. a GPU) still loads here.
model.load_state_dict(torch.load(model_path, map_location=device))
print(f"Loaded model: {model_path}")
model.eval()


In [None]:
# (Re)build the evaluation tile from scratch: remove any stale HDF5 file,
# initialise a fresh one, then write the tile into it.
from src.datasets.tools.create_dataset import create_eval_tile, setup_eval_hdf5

eval_hdf5_path = config['dataset']['eval_dataset']
if eval_hdf5_path.exists():
    eval_hdf5_path.unlink()
setup_eval_hdf5(config['dataset'])

# chunk_size is deliberately not passed: passing
# chunk_size=config['dataset']['max_chunk_size'] seemed slower than writing
# chunks smaller than the max chunk size set during dataset creation.
create_eval_tile(config['dataset'])

In [None]:
# Build the evaluation dataset and a deterministic (unshuffled) loader over the eval tile.
transforms = get_transforms(config)
print(transforms)

eval_tile = config['dataset']['eval_dataset']
eval_source = config['dataset']['eval_source_scan']  # NOTE(review): not read by any later cell

lidar_dataset = LidarDatasetHDF5(
    Path(eval_tile),
    transform=transforms,
    mode='eval',
    ss=config['dataset']['use_ss'],
)
eval_dataloader = DataLoader(
    lidar_dataset,
    batch_size=config['train']['batch_size'],
    sampler=None,
    # dataset order is preserved so consecutive batches fill consecutive output slots
    shuffle=False,
    num_workers=config['train']['num_workers'],
    drop_last=False,
)


In [None]:
# Run the model over every sample of the eval tile and collect per-point outputs:
#   xyz - batch[:, 0, :3]
#   gt  - batch[:, 0, 3]   (harmonization target)
#   hz  - harmonization output of the model
#   cr  - batch[:, 1, 3]   (corruption; also the interpolation target)
#   ip  - interpolation output of the model
# then report mean absolute errors.
target_scan_num = 1  # written into the last feature of neighborhood point 0; presumably the scan id to harmonize to - confirm
size = config['dataset']['eval_tile_size']
hz = torch.empty(size).double()
ip = torch.empty(size).double()
cr = torch.empty(size).double()
gt = torch.empty(size).double()
xyz = torch.empty(size, 3).double()

n_filled = 0  # running write offset into the buffers above
with torch.no_grad():
    for batch in eval_dataloader:
        # Advance by the actual batch length rather than i * batch_size so offsets
        # stay correct for a short final batch or a loader with another batch size.
        ldx = n_filled
        hdx = ldx + batch.shape[0]
        if hdx > size:
            raise RuntimeError(
                f"eval dataloader yielded more than eval_tile_size={size} samples")

        xyz[ldx:hdx] = batch[:, 0, :3]
        batch[:, 0, -1] = target_scan_num

        batch = batch.to(config['train']['device'])

        h_target = batch[:, 0, 3].clone()
        i_target = batch[:, 1, 3].clone()
        harmonization, interpolation, _ = model(batch)

        hz[ldx:hdx] = harmonization.cpu().squeeze()
        ip[ldx:hdx] = interpolation.cpu().squeeze()
        cr[ldx:hdx] = i_target.cpu()  # corruption
        gt[ldx:hdx] = h_target.cpu()
        n_filled = hdx

# torch.empty leaves memory uninitialised; an unfilled tail would silently
# corrupt the metrics below, so require full coverage.
if n_filled != size:
    raise RuntimeError(
        f"eval dataloader yielded {n_filled} samples, expected eval_tile_size={size}")

scan_error = torch.mean(torch.abs(gt - hz))
corruption_error = torch.mean(torch.abs(cr - gt))
interpolation_error = torch.mean(torch.abs(ip - cr))

print(f"Results: Harmonization MAE: {scan_error}, Corruption MAE: {corruption_error}, Interpolation MAE: {interpolation_error}")

In [None]:
# Density plot of the network's harmonized intensities (clipped to [0, 1]) against ground truth.
from src.datasets.tools.metrics import create_kde

hz_clipped = np.clip(hz.numpy(), 0, 1)
create_kde(gt, hz_clipped, xlabel="gt", ylabel="predicted harmonization")

In [None]:
# Per-point table with columns: x, y, z, gt, hz (clipped), cr, ip (clipped).
print(hz.shape, cr.shape, ip.shape, gt.shape, xyz.shape)
my_cloud = np.column_stack([
    xyz.numpy(),                 # columns 0-2
    gt.numpy(),                  # column 3
    np.clip(hz.numpy(), 0, 1),   # column 4
    cr.numpy(),                  # column 5
    np.clip(ip.numpy(), 0, 1),   # column 6
])
print(my_cloud.shape)

In [None]:
# Open a pptk viewer on the tile coordinates and register one scalar attribute per channel.
from pptk import viewer

v = viewer(my_cloud[:, :3])
# attribute ids: 0 = gt, 1 = hz, 2 = cr, 3 = ip  (my_cloud columns 3..6)
attr = [my_cloud[:, col] for col in range(3, 7)]
v.attributes(*attr)

In [None]:
# Camera aimed at the tile centre (theta ~ pi/2, phi ~ -pi/2; presumably a top-down view),
# with grid/info/axes hidden on a white background and a jet colormap over [0, 1].
camera = dict(r=245.5078125, theta=1.57079637, phi=-1.57079637,
              lookat=config['dataset']['eval_tile_center'])
v.set(**camera)
v.set(show_grid=False, show_info=False, show_axis=False, bg_color=[1, 1, 1, 1])
v.color_map("jet", scale=[0, 1])

In [None]:
# Screenshot each attribute in turn; the sleeps presumably give the viewer time to redraw.
import time

captures = [(0, "et_gt.png"), (3, "et_ip.png"), (1, "et_hz.png"), (2, "et_cr.png")]
for attr_id, png_name in captures:
    v.set(curr_attribute_id=attr_id)
    time.sleep(.5)
    v.capture(png_name)
time.sleep(.5)

In [None]:
# Stitch the gt | hz | cr screenshots side by side into one comparison figure.
from PIL import Image

figname = f"igarss_fig{config['dataset']['shift_str']}.png"

images = []
for png_name in ['et_gt.png', 'et_hz.png', 'et_cr.png']:
    # copy() forces the pixel data to load so the file is closed immediately;
    # a lingering open handle leaks and can block deleting the file on Windows.
    with Image.open(png_name) as img:
        images.append(img.copy())
widths, heights = zip(*(im.size for im in images))

total_width = sum(widths)
max_height = max(heights)

new_im = Image.new('RGB', (total_width, max_height))

x_offset = 0
for im in images:
    new_im.paste(im, (x_offset, 0))
    x_offset += im.size[0]

new_im.save(figname)


In [None]:
from IPython.display import Image as IpImage

# Inline preview of the stitched comparison figure saved above.
IpImage(filename=figname)

In [None]:
print(f"{figname}")  # path of the stitched network figure

In [None]:
# Baseline: histogram-match the corrupted channel (my_cloud column 5) onto the
# intensity distribution of the configured target scan, then render it with the
# same camera and colormap as the network results.
import time

from pptk import viewer
from src.evaluation.histogram_matching import hist_match

corrupted_intensities = my_cloud[:, 5].copy()
target_scan_file = config['dataset']['scans_path'] / (config['dataset']['target_scan'] + '.npy')
target_cloud = np.load(target_scan_file)
print(target_cloud.shape)
# NOTE(review): assumes column 3 of the saved scan is intensity, as in the eval batches - confirm
harmonized_intensities = hist_match(corrupted_intensities, target_cloud[:, 3])
print(harmonized_intensities.shape)

v = viewer(my_cloud[:, :3])
# attribute ids: 0 = gt, 1 = histogram-matched, 2 = cr
attr = [my_cloud[:, 3], harmonized_intensities, my_cloud[:, 5]]
v.attributes(*attr)

v.set(r=245.5078125, theta=1.57079637, phi=-1.57079637, lookat=config['dataset']['eval_tile_center'])
v.set(show_grid=False, show_info=False, show_axis=False, bg_color=[1, 1, 1, 1])
v.color_map("jet", scale=[0, 1])

# Overwrites the earlier et_gt/et_hz/et_cr screenshots (the network figure was already saved).
for attr_id, png_name in [(0, "et_gt.png"), (1, "et_hz.png"), (2, "et_cr.png")]:
    v.set(curr_attribute_id=attr_id)
    time.sleep(.5)
    v.capture(png_name)
time.sleep(.5)



In [None]:
# Stitch the histogram-matching screenshots (gt | hist-matched | cr) side by side.
from PIL import Image  # imported here too so this cell does not depend on an earlier one

figname = f"igarss_fig{config['dataset']['shift_str']}_hm.png"

images = []
for png_name in ['et_gt.png', 'et_hz.png', 'et_cr.png']:
    # copy() forces the pixel data to load so the file is closed immediately;
    # a lingering open handle leaks and can block deleting the file on Windows.
    with Image.open(png_name) as img:
        images.append(img.copy())
widths, heights = zip(*(im.size for im in images))

total_width = sum(widths)
max_height = max(heights)

new_im = Image.new('RGB', (total_width, max_height))

x_offset = 0
for im in images:
    new_im.paste(im, (x_offset, 0))
    x_offset += im.size[0]
new_im = new_im.convert("RGBA")
new_im.save(figname)


In [None]:
# Inline preview of the stitched histogram-matching figure.
IpImage(filename=figname)


In [None]:
print(f"{figname}")  # path of the stitched histogram-matching figure

In [None]:
# Density plot of histogram-matched intensities against ground truth.
from src.datasets.tools.metrics import create_kde

kde_labels = dict(xlabel="gt", ylabel="predicted harmonization")
create_kde(gt, harmonized_intensities, **kde_labels)

In [None]:
# MAE of the histogram-matching baseline against ground truth.
# NOTE: this reassigns scan_error, replacing the network's value from the evaluation cell.
abs_err = np.abs(gt.numpy() - harmonized_intensities)
scan_error = abs_err.mean()
print(scan_error)

In [None]:
# Remove the intermediate screenshots. Files that are already gone (cell re-run,
# or capture cells skipped) are ignored instead of aborting with FileNotFoundError.
for png_name in ['et_gt.png', 'et_hz.png', 'et_cr.png', 'et_ip.png']:
    try:
        Path(png_name).unlink()
    except FileNotFoundError:
        pass