# GRSS DFC 2018 Data Analysis Notebook

In [10]:
### Futures ###
#TODO

### Built-in Imports ###
import datetime
import os
import time

### Other Library Imports ###
from chart_studio import tools
import chart_studio.plotly as py
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
# from plotly import tools
from plotly.offline import init_notebook_mode, iplot
import plotly.figure_factory as ff
import plotly.graph_objects as go
# import plotly.plotly as py
import seaborn as sns

### Local Imports ###
from data.grss_dfc_2018_uh import UH_2018_Dataset

In [11]:
### Environment ###
init_notebook_mode(connected=True)
pd.set_option('display.max_columns', 100)

In [12]:
def load_data():
    """
    """
    dataset = UH_2018_Dataset()
    gt_train = dataset.load_full_gt_image(train_only=True)
    gt_test = dataset.load_full_gt_image(test_only=True)
    hs_data = dataset.load_full_hs_image(thres=False, 
                                         normalize=False, 
                                         resampling='average',   # Also need to try 'nearest'
                                        )
    lidar_ms_data = dataset.load_full_lidar_ms_image(thres=False,
                                                     normalize=False,
                                                     resampling=None,
                                                    )
    lidar_dsm_data = dataset.load_full_lidar_dsm_image(thres=False,
                                                       normalize=False,
                                                       resampling=None,
                                                      )
    lidar_dem_data = dataset.load_full_lidar_dem_image(thres=False,
                                                       normalize=False,
                                                       resampling=None,
                                                      )
    vhr_rgb_data = dataset.load_full_vhr_image(thres=False,
                                               normalize=False,
                                               resampling='cubic_spline',
                                              )

In [13]:
def load_hs_data():
    """
    """
    dataset = UH_2018_Dataset()
    gt_train = dataset.load_full_gt_image(train_only=True)
    hs_data = dataset.load_full_hs_image(thres=False, 
                                         normalize=False, 
                                         resampling='average',   # Also need to try 'nearest'
                                        )

    return hs_data, gt_train

In [14]:
def get_sample_pixels(data, gt, ignored_labels, num_samples):
    """
    """
    pass

In [15]:
def get_class_dataframes(data, gt, ignored_labels, class_labels, channel_labels=['data']):
    """
    """
    if data.ndim == 3:
        height, width, channels = data.shape
    elif data.ndim == 2:
        height, width = data.shape
        channels = 1
    else:
        raise Exception('Data does not have 2 or 3 dimensions!')
    
    class_info = {label: [] for label in class_labels if label not in ignored_labels}

    for row in range(height):
        for col in range(width):
            label = class_labels[gt[row, col]]
            if label in ignored_labels: continue

            class_info[label].append(data[row, col])
    
    for label in class_info:
        class_info[label] = pd.DataFrame(np.array(class_info[label]),
                                         columns=channel_labels)
    
    return class_info

In [16]:
def save_dataframe_descriptions(dataframes, output_path='./'):
    """
    """
    file_name = f'{label}_description.csv'
    file_path = os.path.join(output_path, file_name)
    for label in dataframes:
        dataframes[label].describe().to_csv(file_path)

In [17]:
def get_box_plots(dataframes, output_path='./'):
    """
    """
    
    for label in dataframes:
        # Make dataframe variable
        df = dataframes[label]

        # Generate file path for saving image
        file_name = f'{label}__box_plots.png'
        file_path = os.path.join(output_path, file_name)

        # Get random color profile for box plot
        red, green, blue = list(np.random.choice(256, size=3))

        data = [
            go.Box(
                y=df[column],
                name=column,
                marker=dict(
                    color = f'rgb({red},{green},{blue})',
                ),
            )
            for column in df.columns
        ]

        layout = go.Layout(
            title = f'Plots of channel intensity distribution for "{label}"'
        )

        # Create figure and save image
        fig = go.Figure(data=data, layout=layout)
        # fig.write_image(file_path)

        # Show figure
        py.iplot(fig)

In [18]:
dataset = UH_2018_Dataset()
ignored_labels = dataset.gt_ignored_labels
class_labels = dataset.gt_class_label_list
channel_labels = dataset.hs_band_wavelength_labels

data, gt = load_hs_data()
dataframes = get_class_dataframes(data, gt, 
                                    ignored_labels, 
                                    class_labels, 
                                    channel_labels=channel_labels)
# save_dataframe_descriptions(dataframes, output_path='./analysis/grss_dfc_2018/descriptions/')
get_box_plots(dataframes, output_path='./analysis/grss_dfc_2018/box_plots/')

Loading training ground truth image...
Loading training ground truth tiles...
Merging image tiles...
Loading full hyperspectral image...



Woah there! Look at all those points! Due to browser limitations, the Plotly SVG drawing functions have a hard time graphing more than 500k data points for line charts, or 40k points for other types of charts. Here are some suggestions:
(1) Use the `plotly.graph_objs.Scattergl` trace object to generate a WebGl graph.
(2) Trying using the image API to return an image instead of a graph URL
(3) Use matplotlib
(4) See if you can create your visualization with fewer data points




KeyboardInterrupt: 