<a href="https://colab.research.google.com/github/nneibaue/ocean_explorer/blob/master/explorer_official.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

<h1>Ocean Data Explorer</h1>

This notebook provides basic visualization tools in Python, using some of Colab's output features. It is intentionally simple, but should serve as a decent example of how Python and Colab can be useful for this kind of analysis. Google provides a free runtime in the cloud, so there is no need to install Python or set anything up locally. The free version of Colab has more than enough features, memory, and drive space for our purposes here.


## Research Context

Samples are collected at different depths for a given location in the ocean (e.g. lat, long). Each of these samples is measured for concentrations of various elements via a 2D scan, yielding concentration values at individual pixels. A given pixel may contain a non-trivial concentration value for one or more elements.

It is of particular interest how a given element (Cu in this case) is distributed among different element groups for a given scan. For example, one pixel could contain non-trivial concentrations of Cu, Mg, Br, and Zn, whereas another pixel might only contain Fe and Mg. 
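
To make the idea of an "element group" concrete, here is a minimal sketch that labels each pixel by the set of elements present and counts how often each group occurs. The pixel data, `element_group`, and `group_counts` are hypothetical and only illustrate the concept; they are not part of the `ocean` library used later in this notebook.

```python
from collections import Counter

# Hypothetical per-pixel results: the set of elements whose concentration
# is non-trivial at each pixel (values invented for illustration).
pixels = [
    {"Cu", "Mg", "Br", "Zn"},
    {"Fe", "Mg"},
    {"Cu", "Mg", "Br", "Zn"},
    {"Cu", "Fe"},
    set(),  # nothing above background at this pixel
]

def element_group(elements):
    """Return a canonical label for a set of elements, e.g. 'Cu+Fe'."""
    return "+".join(sorted(elements))

def group_counts(pixels, filter_by=None):
    """Count pixels per element group.

    If `filter_by` is given, only groups containing that element are kept.
    """
    counts = Counter()
    for elements in pixels:
        if not elements:
            continue  # skip pixels with no non-trivial element
        if filter_by is not None and filter_by not in elements:
            continue  # skip groups that do not include the element of interest
        counts[element_group(elements)] += 1
    return counts

cu_groups = group_counts(pixels, filter_by="Cu")
print(cu_groups.most_common())
```

With these made-up pixels, Cu appears in two groups: `Br+Cu+Mg+Zn` (two pixels) and `Cu+Fe` (one pixel). The `Fe+Mg` pixel is not counted because it contains no Cu.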


## Problem Statement

* Given a dataset for a single location, how does the distribution of an element vary with depth? Assumptions include:
  * There can be many scans at a given depth
  * No two scans overlap in space
  * Concentration values ($[x]$) at a pixel are only considered non-trivial if:
  $$
  [x] > \bar{[x]} + 2 \cdot \sigma_x 
  $$
  where $\bar{[x]}$ is the average concentration value of element $x$ and $\sigma_x$ is its standard deviation
  * Concentration values filtered by an element are only considered non-trivial if the element in question satisfies the above condition
    * E.g. a pixel may contain non-trivial amounts of Ca and Mg, but not Cu. If we are filtering by Cu, then this pixel is rejected
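
The threshold rule and the filter-by rule above can be sketched in a few lines. This is a dependency-free illustration under stated assumptions: each scan is flattened to a list of per-pixel values, the numbers are invented, and `threshold` and `nontrivial_mask` are hypothetical helper names. `statistics.pstdev` is the population standard deviation, which matches the default behavior of `np.std` used later in the notebook.

```python
import statistics

def threshold(values, n=2):
    """Return mean + n * (population) standard deviation of the values."""
    return statistics.fmean(values) + n * statistics.pstdev(values)

def nontrivial_mask(values, n=2):
    """Return one boolean per pixel: True where the value exceeds the threshold."""
    cutoff = threshold(values, n)
    return [v > cutoff for v in values]

# Hypothetical flattened scans for two elements over the same 10 pixels
cu = [1.0] * 9 + [11.0]          # one Cu hot spot at pixel 9
ca = [2.0, 20.0] + [2.0] * 8     # one Ca hot spot at pixel 1

cu_mask = nontrivial_mask(cu)
ca_mask = nontrivial_mask(ca)

# Filtering by Cu: a pixel is kept only if Cu itself is non-trivial there,
# regardless of what the other elements are doing.
kept_pixels = [i for i, has_cu in enumerate(cu_mask) if has_cu]
rejected_ca_pixels = [i for i, has_ca in enumerate(ca_mask)
                      if has_ca and not cu_mask[i]]

print("Cu threshold:", threshold(cu))
print("Pixels kept when filtering by Cu:", kept_pixels)
print("Ca-rich pixels rejected by the Cu filter:", rejected_ca_pixels)
```

Here pixel 1 has a non-trivial amount of Ca but not Cu, so it is rejected when filtering by Cu, which mirrors the Ca/Mg example in the last bullet.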


**Please don't edit this notebook directly. To make changes, first make a copy of the notebook.**

# Setup

The following cell clones the GitHub repo so private libraries can be imported.
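
As background, the clone only makes the libraries importable because the repo directory is appended to `sys.path`. The sketch below shows that mechanism in isolation; a temporary directory stands in for the cloned repo, and `demo_ocean_module` and `GREETING` are hypothetical names invented for this illustration.

```python
import importlib
import os
import sys
import tempfile

# A temporary directory stands in for the cloned repo (hypothetical)
repo_path = tempfile.mkdtemp()
with open(os.path.join(repo_path, "demo_ocean_module.py"), "w") as f:
    f.write("GREETING = 'imported from the cloned repo'\n")

# Same pattern as the setup cell: only add the path once
if repo_path not in sys.path:
    sys.path.append(repo_path)

# Make sure the import system notices the newly written file
importlib.invalidate_caches()
demo = importlib.import_module("demo_ocean_module")
print(demo.GREETING)
```

Without the `sys.path.append` step, the import would raise `ModuleNotFoundError`, since Python only searches the directories listed in `sys.path`.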

In [None]:
#@title Clone Github Repo

BRANCH_NAME = "live_debug" #@param {type:"string"}

import os
import sys
import shutil

ROOT = '/content'
os.chdir(ROOT)
REPO_NAME = 'ocean_explorer'
REPO_URL = f'https://github.com/nneibaue/{REPO_NAME}'
REPO_PATH = os.path.join(ROOT, REPO_NAME)


# Remove old repo
print('Removing old repo...')
!rm -rf $REPO_PATH

print('Cloning from github...')
!git clone $REPO_URL
os.chdir(REPO_PATH)

if BRANCH_NAME != 'master':
  !git checkout --track origin/$BRANCH_NAME
  !git config user.email "colab_anon@gmail.com"
else:
  !git pull
  
if REPO_PATH not in sys.path:
  print(f'Adding {REPO_PATH} to path')
  sys.path.append(REPO_PATH)

os.chdir(ROOT)

In [None]:
#@title Imports

# etsp stuff
from ocean import Detsum, Scan, CombinedScan, Depth
from plotting import ribbon_plot, encode_matplotlib_fig

# Colab output stuff
from google.colab import drive
from google.colab import widgets
from IPython.display import display, HTML
import ipywidgets

# General
import numpy as np
import random
import re
import pandas as pd

# Plotting
from cycler import cycler
import altair as alt
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches

# Namespace class to keep things organized
class Namespace:
  def __init__(self, **kwargs):
    self.__dict__.update(**kwargs)

In [None]:
#@title Connect Google Drive

drive.mount('/content/gdrive')
DRIVE_BASE = '/content/gdrive/My Drive'

# Analysis

In [None]:
#@title Data Import

#@markdown Enter drive path to data folder (do not include 'My Drive/'):
EXPERIMENT_DIR = "software_development/etsp/XRF data deglitch/" #@param{type:"string"}
#depth_path = "software_development/etsp/XRF data/25m" #@param {type:"string"}
#@markdown Enter elements separated by comma
ELEMENTS_OF_INTEREST = "Br,Ca,Cu,Fe,K,Cl,Mn,S,Si,Zn" #@param {type:"string"}
ELEMENTS_OF_INTEREST=ELEMENTS_OF_INTEREST.split(',')
ORBITALS = "K" #@param {type:"string"}

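def import_data(data_path):
  '''Builds a Depth object for each entry in the data folder.

  Entries that raise a NameError while loading are reported and skipped.
  '''
  depths = []
  data_dir = os.path.join(DRIVE_BASE, data_path)
  for name in os.listdir(data_dir):
    try:
      depth = Depth(os.path.join(data_dir, name),
                    ELEMENTS_OF_INTEREST,
                    orbitals=ORBITALS.split(','),  # use the form value above
                    normalized=True)
      depths.append(depth)
      print(f"Successfully imported data for {depth.depth}")
    except NameError as e:
      # Report the problem and move on to the next entry
      print(e)
  return depths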

depths = import_data(EXPERIMENT_DIR)

In [None]:
#@title Element Filter

In [None]:
# Function template
mean_n_std = lambda n: (lambda x: np.mean(x) + np.std(x) * n)

ELEMENT_FILTER = {
    'Br': mean_n_std(2),
    'Ca': mean_n_std(2),
    'Cu': mean_n_std(2),
    'Fe': mean_n_std(2),
    'K': mean_n_std(2),
    'Cl': mean_n_std(2),
    'Mn': mean_n_std(2),
    'S': mean_n_std(2),
    'Si': mean_n_std(2),
    'Zn': mean_n_std(2),
}

del(mean_n_std)

## Looking for Glitches

Check the `show_plots` box and run the following cell to plot all `Detsums` from all depths.
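
The plots can be sorted by element or by depth. Depth labels look like `25m`, so they have to be converted to integers before sorting; otherwise they would be ordered as text. The sketch below illustrates the difference using hypothetical `(element, depth)` tuples in place of real `Detsum` objects; `depth_to_int` is a name invented for this example.

```python
# Hypothetical stand-ins for detsums: (element, depth label)
records = [("Fe", "100m"), ("Cu", "25m"), ("Br", "1000m"), ("Cu", "100m")]

def depth_to_int(label):
    """Turn a depth label such as '25m' into the integer 25."""
    return int(label.split("m")[0])

# Alphabetical by element symbol
by_element = sorted(records, key=lambda r: r[0])

# Numerical by depth (shallowest first)
by_depth = sorted(records, key=lambda r: depth_to_int(r[1]))

# Sorting the raw labels compares them as strings, which puts 1000m first
naive = sorted(records, key=lambda r: r[1])

print("by element:", by_element)
print("by depth:  ", by_depth)
print("naive:     ", naive)
```

Python's `sorted` is stable, so records with the same key keep their original relative order.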

In [None]:
#@title Plot All Detsums

elements_to_plot = 'Br,Ca,Cu,Fe,K,Cl,Mn,S,Si,Zn' #@param {type:"string"}
sort_by = "element" #@param ["element", "depth"]

#@markdown To show plots, check box below and run cell
show_plots = True #@param {type:"boolean"}

def plot_all_detsums(depths, elements=None, sort_by='element'):
  '''Plots the raw data from all detsums of the given elements.

  Args:
    depths: list of Depth objects
    elements: optional list of elements. E.g. ['Cu', 'Fe']. If this 
      is `None`, then all elements will be plotted
    sort_by: string. Can either be 'element' or 'depth'. This will
      determine how the detsums are sorted before they are rendered
      to the screen. This is set to 'element' by default
    
  Returns: None. The raw detsums are rendered to the cell output in a grid
  '''

  # Triple looping to get detsums from all depths
  detsums = []
  for d in depths:
    for s in d.scans:
      for detsum in s.detsums:
        if elements is not None:
          if detsum.element not in elements:
            continue # skip to next iteration
        detsums.append(detsum)

  # Determine sorting function 
  if sort_by == 'element':
    sort_func = lambda d: d.element
  elif sort_by == 'depth':
    sort_func = lambda d: int(d.depth.split('m')[0]) # Turn depth into integer for sorting
  else:
    raise ValueError("`sort_by` must be 'element' or 'depth'")
  
  # Sort detsums
  detsums = sorted(detsums, key=sort_func)

  # Build grid (always at least one row, even when there is nothing to plot)
  ncols = 4
  nrows = max(1, -(-len(detsums) // ncols))  # ceiling division
  g = widgets.Grid(nrows, ncols)
  for i, detsum in enumerate(detsums):
    # Fill the grid left to right, top to bottom
    row, col = divmod(i, ncols)
    with g.output_to(row, col):
      print(f'    {detsum.element}    |    {detsum.depth}    |    {detsum.scan_name}')
      display(ipywidgets.HTML(detsum.plot(raw=True, base64=True)))

## Example Usage
#=====================================
# To plot all detsums from iron and copper only, e.g.:
# plot_all_detsums(depths, elements=['Fe', 'Cu'])

# Plot the elements chosen in the form above, using the selected sort order:
if show_plots:
  plot_all_detsums(depths,
                  elements=elements_to_plot.split(','),
                  sort_by=sort_by)

del(elements_to_plot, sort_by, show_plots)

## **Plotting**

In [None]:
#@title Ribbon Plot UI { form-width: "200px" }
def ribbon_plot_ui(depths):
  experiment_dir = os.path.join(DRIVE_BASE, EXPERIMENT_DIR)
  status_indicator = ipywidgets.Output()
  with status_indicator:
    display(ipywidgets.HTML('<h3 style="color:green">Ready</h3>'))
  graph_output = ipywidgets.Output()
  element_inputs = {}
  element_filter = {}
  smalltextbox = ipywidgets.Layout(width='50px', height='25px')
  filter_func = lambda n: lambda x: np.mean(x) + np.std(x)*n
  
  for e in ELEMENTS_OF_INTEREST:
    element_inputs[e] = ipywidgets.Textarea(value='2', layout=smalltextbox)
    element_filter[e] = filter_func(2)
    
  filter_by_control = ipywidgets.Dropdown(options=ELEMENTS_OF_INTEREST,
                                          value='Cu', description='Filter by:')
  
  combine_scans_checkbox = ipywidgets.Checkbox(value=True, description='Combine Scans')
  
  combine_detsums_checkbox = ipywidgets.Checkbox(value=False, description='Combine Detsums')
  
  normalize_by_control = ipywidgets.Dropdown(options=['counts', 'pixels'],
                                            value='counts',
                                            description='Normalize By')
  
  
  N_input = ipywidgets.Textarea(value='8', layout=ipywidgets.Layout(width='150px'), description='N')
  update_button = ipywidgets.Button(description='Update Plot')                          
  clear_output_control = ipywidgets.Checkbox(value=False, description='Clear output after each run')
  
  element_filter_input = ipywidgets.HBox(
      [ipywidgets.VBox([ipywidgets.HTML(f'<h3>{e}</h3>'), element_inputs[e]]) for e in ELEMENTS_OF_INTEREST]
  )
  
  def update_plot(b):
    status_indicator.clear_output()
    for e in ELEMENTS_OF_INTEREST:
      val = float(element_inputs[e].value)
      element_filter[e] = filter_func(val)
    with status_indicator:
      display(ipywidgets.HTML('<h3 style="color:red">Working...</h3>'))

    info_banner_html = (f'Filter by: {filter_by_control.value} | '
                      f'Comb. Scans: {combine_scans_checkbox.value} | '
                      f'Comb. Detsums: {combine_detsums_checkbox.value} | '
                      f'N: {N_input.value} | '
                      f'Normalize By: {normalize_by_control.value} | ' + \
                      ''.join([f'{e}: {element_inputs[e].value} | ' for e in ELEMENTS_OF_INTEREST]))
    info_banner = ipywidgets.HTML(info_banner_html)

    plot = ribbon_plot(depths, element_filter=element_filter,
                filter_by=filter_by_control.value,
                combine_detsums=combine_detsums_checkbox.value,
                combine_scans=combine_scans_checkbox.value,
                N=int(N_input.value),
                normalize_by=normalize_by_control.value,
                base64=True,
                experiment_dir=experiment_dir)
    if clear_output_control.value:
      graph_output.clear_output()

    with graph_output:
      display(ipywidgets.VBox([ipywidgets.HTML(plot), ipywidgets.HTML(info_banner_html)]))

    status_indicator.clear_output()
    with status_indicator:
      display(ipywidgets.HTML('<h3 style="color:green">Ready</h3>'))
  
  update_button.on_click(update_plot)
  

  update_plot(None)  # Initial render; the button argument is not used
  
  controls_bot = element_filter_input
  controls_top = ipywidgets.HBox([ipywidgets.VBox([update_button, status_indicator]),
                              ipywidgets.VBox([filter_by_control, normalize_by_control]),
                              ipywidgets.VBox([combine_scans_checkbox, combine_detsums_checkbox]),
                              N_input, clear_output_control])

  controls = ipywidgets.VBox([controls_top, controls_bot],
                            layout=ipywidgets.Layout(
                                border='1px solid black',
                                width='100%',
                            ))
  app = ipywidgets.VBox([graph_output, controls])
  display(app)
    
ribbon_plot_ui(depths)

In [None]:
# Filter by Cu, take the top 8 groups. Separate Scans, Combined Detsums
# ribbon_plot(depths, element_filter=ELEMENT_FILTER,
#             filter_by='Cu',
#             combine_scans=False,
#             combine_detsums=True,
#             N=8,
#             normalize_by='counts',
#             base64=True,
#             prop_dict=PROP_DICT)

In [None]:
#@title Image UI { form-width: "200px" }
def make_image_ui(depths):
  plot_area = ipywidgets.Output()
  status_indicator = ipywidgets.Output()
  with status_indicator:
    display(ipywidgets.HTML('<h3 style="color:green">Ready</h3>'))
  element_inputs = {}
  element_filter = {}
  smalltextbox = ipywidgets.Layout(width='50px', height='25px')
  filter_func = lambda n: lambda x: np.mean(x) + np.std(x)*n

  combine_detsums_checkbox = ipywidgets.Checkbox(value=False, description='Combine Detsums')
  
  for e in ELEMENTS_OF_INTEREST:
    element_inputs[e] = ipywidgets.FloatSlider(2, min=0, max=4, step=0.1, description=e)
    element_filter[e] = filter_func(2)
  
  element_selector_boxes = [ipywidgets.Checkbox(value=False, description=e, indent=False) for e in ELEMENTS_OF_INTEREST]

  # Pair each element checkbox with its threshold slider
  element_selectors = ipywidgets.VBox([ipywidgets.HBox([box, slider]) for box, slider in zip(element_selector_boxes, element_inputs.values())])
  
  depth_selector_layout = ipywidgets.Layout(indent=False, width='200px')
  depth_selector_boxes = [ipywidgets.Checkbox(value=False, description=depth.depth, indent=False) for depth in depths]
  depth_selector_labels = [ipywidgets.HTML(depth.depth) for depth in depths]
  
  
  depth_selectors = {d.depth: box for d, box in zip(depths, depth_selector_boxes)}
  
  depth_selector_input = ipywidgets.VBox([ipywidgets.HTML('<h3>Depths</h3>')] + list(depth_selectors.values()))
  
  update_button = ipywidgets.Button(description='Update')
  controls = ipywidgets.VBox([ipywidgets.HBox([depth_selector_input, element_selectors],
                            layout=ipywidgets.Layout(width='75%')), 
                            ipywidgets.HBox([update_button, combine_detsums_checkbox, status_indicator])])

  def update_plot(b):
    plot_area.clear_output()
    status_indicator.clear_output()
    with status_indicator:
      display(ipywidgets.HTML('<h3 style="color:red">Working...</h3>'))
    depths_to_plot = [depth for depth, selector in depth_selectors.items() if selector.value]
    depths_to_plot = [depth for depth in depths if depth.depth in depths_to_plot] 

    elements_to_plot = []
    for element, selector in zip(ELEMENTS_OF_INTEREST, element_selector_boxes):
      if selector.value:
        elements_to_plot.append(element)

    for e in elements_to_plot:
      #element_filter.update({e: filter_func(element_inputs[e].value)})
      element_filter[e] = filter_func(element_inputs[e].value)

    for depth in depths_to_plot:
      depth.apply_element_filter(element_filter,
                                 combine_detsums=combine_detsums_checkbox.value)


    def get_row(depth, elements):
      detsums = sorted(depth.detsums, key=lambda d: d.element)
      plots = [ipywidgets.HTML(detsum.plot(base64=True))
              for detsum in detsums if detsum.element in elements]
      return plots


    rows = [ipywidgets.HBox(get_row(depth, elements_to_plot)) for depth in depths_to_plot]

    with plot_area:
      display(ipywidgets.VBox(rows))
    
    status_indicator.clear_output()
    with status_indicator:
      display(ipywidgets.HTML('<h3 style="color:green">Ready</h3>'))



  update_plot(True)
  update_button.on_click(update_plot)
  # for selector in depth_selectors.values():
  #   selector.observe(update_plot)
  display(ipywidgets.VBox([controls, plot_area]))
  
  
make_image_ui(depths)