In [None]:
from pathlib import Path
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
from src.dataset.tools.checkerboard import checkerboard
import pprint

# Pretty-printer kept from earlier revisions of this notebook.
pp = pprint.PrettyPrinter(indent=4)

# Repository root; assumes the notebook is run from two directories below it — TODO confirm.
project_root = Path(".").absolute().parents[1]
print(project_root)

# Reference tiles: "gt" is the ground truth, "alt" the synthetically corrupted tile.
_base_tile_dir = project_root / "dataset" / "synth_crptn" / "big_tile_no_overlap"
base_tile_paths = {
    "gt": _base_tile_dir / "gt.txt.gz",
    "alt": _base_tile_dir / "alt.txt.gz",
}

# Same reference tiles for the dataset variant with a global intensity shift.
_shift_tile_dir = project_root / "dataset" / "synth_crptn+shift" / "big_tile_no_overlap"
shift_tile_paths = {
    "gt": _shift_tile_dir / "gt.txt.gz",
    "alt": _shift_tile_dir / "alt.txt.gz",
}

# White text/labels/ticks suit the dark notebook theme; change COLOR for a light theme.
# for dark-mode: `jt -t monokai -f fira -fs 10 -nf fira -nfs 11 -N -kl -cursw 2 -cursc r -cellw 95% -T`
# for default: `jt`
COLOR = 'white'
matplotlib.rcParams.update({
    'text.color': COLOR,
    'axes.labelcolor': COLOR,
    'xtick.color': COLOR,
    'ytick.color': COLOR,
    # larger font for readability
    'font.size': 18,
    # wide canvas: the comparison figures are a single row of four panels
    'figure.figsize': [32, 8],
})

In [None]:
def make_single_plot(tile, name=""):
    """Scatter-plot a single fixed tile, coloured by its fix channel.

    Args:
        tile: 2-D array; columns 0/1 are plotted as x/y and column 4 (the fix
            channel of a fixed tile) is used as colour on a 0-512 scale.
        name: path to save the figure to. When empty (the default) nothing is
            saved; previously an empty name was still handed to ``savefig``.
    """
    fig, ax = plt.subplots(1)

    ax.scatter(tile[:, 0], tile[:, 1], c=tile[:, 4], s=1, vmin=0, vmax=512)   # fixed tiles have gt and fix channels
    ax.axis("off")
    if name:
        # Save the figure created here, not whatever pyplot considers "current".
        fig.savefig(name)
    
def get_dl_tile_paths(n_size):
    """Return plot-title -> path for the deep-learning fixed tiles of a given ``n_size``.

    One entry points at the default dataset and one (title ending in
    "(Global Shift)") at the shifted dataset; both use the file
    ``big_tile_no_overlap/fixed_dl_{n_size}.txt.gz``. ``n_size`` is only used in the
    file name (presumably the interpolation neighbourhood size N).
    """
    file_name = f"fixed_dl_{n_size}.txt.gz"
    dataset_dir = project_root / "dataset"
    return {
        "Deep Learning Harmonization": dataset_dir / "synth_crptn" / "big_tile_no_overlap" / file_name,
        "Deep Learning Harmonization (Global Shift)": dataset_dir / "synth_crptn+shift" / "big_tile_no_overlap" / file_name,
    }

def get_li_paths(n_size, interp_method, harmonization_method):
    """Return plot-title -> path for linear-interpolation + least-squares fixed tiles.

    The file name is ``fixed_li_{n_size}_{interp_method}_{harmonization_method}.txt.gz``
    under ``big_tile_no_overlap`` of the default dataset and of the shifted dataset
    (title ending in "(Global Shift)").
    """
    file_name = f"fixed_li_{n_size}_{interp_method}_{harmonization_method}.txt.gz"
    dataset_dir = project_root / "dataset"
    base_path = dataset_dir / "synth_crptn" / "big_tile_no_overlap" / file_name
    shift_path = dataset_dir / "synth_crptn+shift" / "big_tile_no_overlap" / file_name
    return {
        "Linear Interpolation + Least Squares Harmonization": base_path,
        "Linear Interpolation + Least Squares Harmonization (Global Shift)": shift_path,
    }

def _load_reference_tiles(key):
    """Load the ``(gt, alt)`` reference tiles that match a plot title.

    Titles containing "Global Shift" select ``shift_tile_paths``; every other
    title selects ``base_tile_paths``.
    """
    paths = shift_tile_paths if "Global Shift" in key else base_tile_paths
    return np.loadtxt(paths['gt']), np.loadtxt(paths['alt'])


def _scatter_panel(axis, tile, channel, title):
    """Draw ``tile`` as an x/y scatter coloured by column ``channel`` on a 0-512 scale."""
    axis.scatter(tile[:, 0], tile[:, 1], c=tile[:, channel], s=1, vmin=0, vmax=512)
    axis.set_title(title)
    axis.axis("off")


def make_plots(fixed_tile_paths, save=True):
    """Show a four-panel comparison figure for every fixed tile.

    Panels, left to right: the synthetically corrupted tile, the fixed tile
    (titled with its MAE against the ground truth), the ground-truth tile, and a
    checkerboard of ground truth vs. fix.

    Args:
        fixed_tile_paths: mapping of plot title -> path of a fixed tile file.
            Titles containing "Global Shift" are compared against the shifted
            reference tiles, all others against the default ones.
        save: accepted for backward compatibility; currently unused (figures are
            only shown, never written to disk).
    """
    for key, path in fixed_tile_paths.items():
        fig, ax = plt.subplots(1, 4)
        fig.suptitle(f"{key}")

        # Load the reference tiles once and reuse them for every panel.
        gt_tile, alt_tile = _load_reference_tiles(key)

        # Fixed tiles carry a gt channel (column 3) and a fix channel (column 4).
        fixed_tile = np.loadtxt(path)
        # Clip the fix to [0, 1], then scale both channels into the 0-512 range of the other tiles.
        fixed_tile[:, 4] = np.clip(fixed_tile[:, 4], 0, 1)
        fixed_tile[:, 3] *= 512
        fixed_tile[:, 4] *= 512

        # Report the MAE back on the [0, 1] scale.
        mae = np.mean(np.abs(fixed_tile[:, 3] - fixed_tile[:, 4])) / 512
        _scatter_panel(ax.flat[1], fixed_tile, 4, f"fixed (MAE: {mae:.5f})")

        cb_t1, cb_t2 = checkerboard(gt_tile, fixed_tile)
        ax.flat[3].scatter(cb_t1[:, 0], cb_t1[:, 1], c=cb_t1[:, 3], s=1, vmin=0, vmax=512)
        ax.flat[3].scatter(cb_t2[:, 0], cb_t2[:, 1], c=cb_t2[:, 4], s=1, vmin=0, vmax=512)
        ax.flat[3].set_title("checkerboard fixed vs gt")
        ax.flat[3].axis("off")

        _scatter_panel(ax.flat[2], gt_tile, 3, "ground truth")
        _scatter_panel(ax.flat[0], alt_tile, 3, "synthetic corruption")
        plt.show()

# Custom paths
# Histogram-matching fixes for the default and global-shift datasets.
hm_tile_paths = {
    "Histogram Matching": project_root / "dataset" / "synth_crptn" / "big_tile_no_overlap" / "fixed_hm.txt.gz",
    "Histogram Matching (Global Shift)": project_root / "dataset" / "synth_crptn+shift" / "big_tile_no_overlap" / "fixed_hm.txt.gz"
}

# Deep-learning fixes (fixed_dl_0) produced with the interpolation target as an input feature.
# make_plots picks the shifted gt/alt tiles for any title containing "Global Shift", so that
# entry's fix file must come from the shifted dataset (synth_crptn+shift) as well.
cheat_tile_path = {
    "Deep Learning Harmonization (interpolation target as input feature)": project_root / "dataset" / "synth_crptn" / "big_tile_no_overlap" / "fixed_dl_0.txt.gz",
    "Deep Learning Harmonization (interpolation target as input feature) (Global Shift)" : project_root / "dataset" / "synth_crptn+shift" / "big_tile_no_overlap" / "fixed_dl_0.txt.gz"
}


## Dataset Information
Below is the mean absolute difference in intensity (divided by 512) between the default and global-shift ground-truth evaluation tiles:

In [None]:
# Mean absolute intensity difference between the default and the global-shift
# ground-truth tiles, with intensities divided by 512 (the scale used throughout).
# Compares rows position-by-position, so assumes both files list the same points
# in the same order — TODO confirm.
default_gt = np.loadtxt(base_tile_paths['gt'])
shifted_gt = np.loadtxt(shift_tile_paths['gt'])
mae = np.mean(np.abs(default_gt[:, 3] / 512 - shifted_gt[:, 3] / 512))
print(mae)

## Special Cases

### Histogram Matching
Histogram matching performs very well when the scanned regions have similar intensity distributions. However, as the scan area grows, the distributions inevitably stop lining up. Below, histogram matching is evaluated on two test cases: one where the intensity distribution is relatively uniform across all scans, and one where a global shift is applied to the eastern half of the scan area. As the results show, performance is significantly worse in the global-shift case.

In [None]:
make_plots(fixed_tile_paths=hm_tile_paths)  # histogram matching: default vs. global-shift dataset

### PointNet Interpolation and MLP Harmonization with the Interpolation Target as an Input Feature
The deep learning method proposed in this project works on overlapping lidar scans. Ideally, the points in the overlap region would be perfectly aligned, giving a one-to-one correspondence from which an intensity mapping could be derived directly. In practice, this is rarely the case.

Given two scans, a source and a target, that share an overlapping region, correspondence between points can be established by interpolation. Given a point $X_t$ in the target scan (with intensity $I_t$), a neighborhood $N_s$ in the source scan is obtained by finding the source points closest to $X_t$.

In this project, each source scan is given a unique (monotonically increasing) intensity transformation to simulate differing sensor/camera configurations. This transformation is applied to $N_s$ as well as to the interpolated intensity target, yielding $I_s$. $I_t$ then becomes the harmonization target, and $I_s$ becomes the interpolation target for the transformed neighborhood.

Normally, $I_s$ is stripped from each training example, since it would not be available in a real-world test case. If $I_s$ is instead supplied as an input feature, the network never needs to perform interpolation, and harmonization becomes its only task. Below is the result for this case, which yields an exceedingly small error.

In [None]:
make_plots(fixed_tile_paths=get_dl_tile_paths(n_size=0))  # fixed_dl_0: interpolation target supplied as an input feature

### PointNet Interpolation, N = 5

In [None]:
make_plots(fixed_tile_paths=get_dl_tile_paths(n_size=5))  # fixed_dl_5 tiles, default and global-shift datasets

### PointNet Interpolation, N = 20

In [None]:
make_plots(fixed_tile_paths=get_dl_tile_paths(n_size=20))  # fixed_dl_20 tiles, default and global-shift datasets

### PointNet Interpolation, N = 50

In [None]:
make_plots(fixed_tile_paths=get_dl_tile_paths(n_size=50))  # fixed_dl_50 tiles, default and global-shift datasets