In [None]:
import os
# Select the AWS profile and the database name through environment variables.
# NOTE(review): presumably these are read by the mirrorverse helpers imported
# below, which is why they are set before those imports -- confirm.
os.environ['AWS_PROFILE'] = 'admin'
os.environ['HAVEN_DATABASE'] = 'haven'

import plotly.graph_objects as go
import plotly.express as px
import pandas as pd
import numpy as np
import h3
from tqdm import tqdm

from mirrorverse.utils import read_data_w_cache
from mirrorverse.plotting import build_geojson

In [None]:
# Chlorophyll at H3 resolution 4 and the 25.0 depth bin, restricted to the
# first day of each month for the years 2015-2018, i.e. one row per cell per
# month. The query text is deliberately left untouched.
# NOTE(review): read_data_w_cache presumably keys its cache on this exact
# string, so even whitespace edits may trigger a fresh query -- confirm.
sql = '''
select 
    h3_index,
    chlorophyll, 
    time,
    extract(year from time) as year,
    extract(month from time) as month
from 
    copernicus_biochemistry
where 
    h3_resolution = 4
    and depth_bin = 25.0
    and extract(year from time) in (2015, 2016, 2017, 2018)
    and extract(day from time) = 1
'''
# Run the query through the project's caching reader.
raw_data = read_data_w_cache(sql)

# Resolve every H3 cell to its (lat, lon) centre once and unpack both
# coordinates from that single pass, instead of converting each index twice.
centroids = [h3.h3_to_geo(cell) for cell in raw_data['h3_index']]
raw_data['lat'] = [lat for lat, _ in centroids]
raw_data['lon'] = [lon for _, lon in centroids]

# Seconds since the Unix epoch; only used as a sort key below.
# NOTE(review): assumes 'time' is a nanosecond-resolution datetime column -- confirm.
raw_data['epoch'] = raw_data['time'].astype('int64') // 10**9

# Keep the untransformed values: 'chlorophyll' is overwritten with a
# standardized version later in the notebook.
raw_data['raw_chlorophyll'] = raw_data['chlorophyll']

# Spatial bounding box: east of 170W and between 42N and 64N.
raw_data = raw_data[(raw_data['lon'] > -170) & (raw_data['lat'] > 42) & (raw_data['lat'] < 64)]

# Newest first. The sliding-window cell slices consecutive rows per cell, so
# with this ordering each window starts at its labelled month and runs
# backwards in time.
raw_data = raw_data.sort_values(['epoch', 'h3_index'], ascending=False).reset_index(drop=True)
print(raw_data.shape)
raw_data.head()

In [None]:
# Distribution of the untransformed chlorophyll values (the copy kept in
# 'raw_chlorophyll' when the data was loaded).
px.histogram(raw_data['raw_chlorophyll'])

In [None]:
# Replace 'chlorophyll' with the z-score of its natural log: take the log
# once, then centre on its mean and scale by its standard deviation.
log_chlorophyll = np.log(raw_data['raw_chlorophyll'])
raw_data['chlorophyll'] = (log_chlorophyll - log_chlorophyll.mean()) / log_chlorophyll.std()

In [None]:
# Distribution of 'chlorophyll' after it has been replaced by the
# standardized log values.
px.histogram(raw_data['chlorophyll'])

In [None]:
# Build one training row per (cell, starting position): the year/month/cell
# label of the starting row plus 12 consecutive 'chlorophyll' values taken in
# raw_data's row order (newest first as sorted at load time, so
# chlorophyll_0 is the labelled month and later columns step back in time).
# NOTE(review): windows are taken over consecutive rows; nothing here checks
# that those rows are consecutive calendar months -- confirm there are no gaps.
window = 12
value_columns = [f'chlorophyll_{j}' for j in range(window)]

rows = []
# groupby(sort=False) visits cells in first-appearance order (the same order
# .unique() yields) while splitting the frame once, instead of re-scanning
# all of raw_data with a boolean mask for every cell.
cell_groups = raw_data.groupby('h3_index', sort=False)
for h3_index, df in tqdm(cell_groups, total=cell_groups.ngroups):
    cell_years = df['year'].to_numpy()
    cell_months = df['month'].to_numpy()
    cell_values = df['chlorophyll'].to_numpy()
    # Only positions with a full window available; cells with fewer than
    # `window` rows contribute nothing.
    for start in range(len(cell_values) - window + 1):
        row = {
            'year': cell_years[start],
            'month': cell_months[start],
            'h3_index': h3_index,
        }
        row.update(zip(value_columns, cell_values[start:start + window]))
        rows.append(row)

# Explicit columns keep the expected schema even when no window was produced.
data = pd.DataFrame(rows, columns=['year', 'month', 'h3_index'] + value_columns)
print(data.shape)
data.head()

In [None]:
from sklearn.model_selection import train_test_split

# The 12 window columns are the autoencoder's input and reconstruction target.
feature_columns = [f'chlorophyll_{i}' for i in range(12)]
X = data[feature_columns]

# Min-max scale into [0, 1] (the decoder ends in a sigmoid) using ONE global
# minimum and maximum across all 12 columns. This is spelled out explicitly
# because np.min/np.max on a DataFrame is pandas-version dependent: per-column
# before pandas 2.0, a single scalar from 2.0 onwards. A shared scale keeps
# the columns of a window comparable with each other.
x_min = X.min().min()
x_max = X.max().max()
X = (X - x_min) / (x_max - x_min)

# NOTE(review): rows are overlapping sliding windows, so a purely random
# split places near-duplicate windows in both train and test.
X_train, X_test = train_test_split(X, test_size=0.33, random_state=42)

In [None]:
from tensorflow.keras.models import Model
import tensorflow as tf
from tensorflow.keras import layers, losses

In [None]:
class Autoencoder(Model):
    """Two-stage dense autoencoder.

    The encoder is a single linear Dense layer of width ``latent_dim``; the
    decoder is a single sigmoid Dense layer of width ``shape``.
    """

    def __init__(self, latent_dim, shape):
        """Build the encoder and decoder stacks.

        Args:
            latent_dim: number of units in the encoder's Dense layer.
            shape: number of units in the decoder's Dense layer.
        """
        super().__init__()
        self.latent_dim = latent_dim
        self.shape = shape
        self.encoder = tf.keras.Sequential(
            [layers.Dense(latent_dim, activation='linear')]
        )
        # Sigmoid output: the targets must already be scaled into [0, 1].
        self.decoder = tf.keras.Sequential(
            [layers.Dense(shape, activation='sigmoid')]
        )

    def call(self, x):
        """Encode ``x`` and return the decoder's reconstruction of it."""
        return self.decoder(self.encoder(x))

# Seed the Python, NumPy and TensorFlow RNGs before any layer is created so
# that weight initialisation and fit-time shuffling repeat on every
# Restart & Run All (the train/test split is already pinned to seed 42).
tf.keras.utils.set_random_seed(42)

shape = 12       # one unit per chlorophyll_* window column
latent_dim = 1   # compress each window to a single value
autoencoder = Autoencoder(latent_dim, shape)

In [None]:
# Train with Adam against the mean-squared reconstruction error.
autoencoder.compile(
    loss=losses.MeanSquaredError(),
    optimizer='adam',
)

In [None]:
# Inputs double as targets: the model learns to reconstruct each window.
# The held-out windows are used for the per-epoch validation loss.
autoencoder.fit(
    x=X_train,
    y=X_train,
    epochs=10,
    shuffle=True,
    validation_data=(X_test, X_test),
)

In [None]:
# Show the held-out windows (already min-max scaled before the split).
X_test

In [None]:
# Reconstruct the held-out windows and report their mean-squared error.
X_test_pred = autoencoder.predict(X_test)
reconstruction_loss = losses.MeanSquaredError()
reconstruction_loss(X_test, X_test_pred)

In [None]:
# Run every window (train and test alike) through the encoder only and put
# the latent output next to its cell/month/year label. The latent column
# keeps pandas' default integer label 0, which the plotting cell reads.
latent_values = pd.DataFrame(autoencoder.encoder(X))
window_labels = data[['h3_index', 'month', 'year']]
indices = pd.concat([window_labels, latent_values], axis=1)

In [None]:
# Map the encoder output (column 0 of `indices`) per H3 cell for 2016:
# one choropleth trace per month, with a slider that shows one at a time.
df = indices[indices['year'] == 2016]
months = sorted(df['month'].unique())
geojson = build_geojson(df, 'h3_index')

# One colour range for the whole year so months are directly comparable.
z_low, z_high = df[0].min(), df[0].max()

fig = go.Figure()
for position, month in enumerate(months):
    sdf = df[df['month'] == month]
    fig.add_trace(
        go.Choroplethmapbox(
            geojson=geojson,
            locations=sdf['h3_index'],
            z=sdf[0],
            # Only the first month starts visible. Setting it here (rather
            # than via fig.data[0] afterwards) also avoids an IndexError
            # when the filter above matches no rows.
            visible=(position == 0),
            marker_line_color='rgba(255,255,255,0)',
            zmin=z_low,
            zmax=z_high,
            colorscale='algae'
        )
    )

# Each slider step reveals exactly one trace and updates the title.
steps = [
    dict(
        method="update",
        args=[
            {"visible": [position == i for position in range(len(months))]},
            {"title": f"month: {slider_val}"},
        ],
        label=f"{slider_val}"
    )
    for i, slider_val in enumerate(months)
]

sliders = [dict(
    active=0,
    # The slider steps through months, so label the current value as such
    # (this used to read "feature: ").
    currentvalue={"prefix": "month: "},
    pad={"t": 50, "b": 25, "l": 25},
    steps=steps
)]

fig.update_layout(
    sliders=sliders,
    autosize=False,  # fixed 800x800 px canvas
    width=800,
    height=800,
    margin={"r": 0, "t": 30, "l": 0, "b": 0},
    mapbox=dict(style="carto-positron", zoom=3, center={"lat": 57, "lon": -150}),
)

fig.show()

In [None]:
# Map raw_data's 'chlorophyll' column (the standardized log values once the
# normalization cell has run) per H3 cell for 2016: one choropleth trace per
# month, with a slider that shows one at a time.
df = raw_data[raw_data['year'] == 2016]
months = sorted(df['month'].unique())
geojson = build_geojson(df, 'h3_index')

# One colour range for the whole year so months are directly comparable.
z_low, z_high = df['chlorophyll'].min(), df['chlorophyll'].max()

fig = go.Figure()
for position, month in enumerate(months):
    sdf = df[df['month'] == month]
    fig.add_trace(
        go.Choroplethmapbox(
            geojson=geojson,
            locations=sdf['h3_index'],
            z=sdf['chlorophyll'],
            # Only the first month starts visible. Setting it here (rather
            # than via fig.data[0] afterwards) also avoids an IndexError
            # when the filter above matches no rows.
            visible=(position == 0),
            marker_line_color='rgba(255,255,255,0)',
            zmin=z_low,
            zmax=z_high,
            colorscale='algae'
        )
    )

# Each slider step reveals exactly one trace and updates the title.
steps = [
    dict(
        method="update",
        args=[
            {"visible": [position == i for position in range(len(months))]},
            {"title": f"month: {slider_val}"},
        ],
        label=f"{slider_val}"
    )
    for i, slider_val in enumerate(months)
]

sliders = [dict(
    active=0,
    # The slider steps through months, so label the current value as such
    # (this used to read "feature: ").
    currentvalue={"prefix": "month: "},
    pad={"t": 50, "b": 25, "l": 25},
    steps=steps
)]

fig.update_layout(
    sliders=sliders,
    autosize=False,  # fixed 800x800 px canvas
    width=800,
    height=800,
    margin={"r": 0, "t": 30, "l": 0, "b": 0},
    mapbox=dict(style="carto-positron", zoom=3, center={"lat": 57, "lon": -150}),
)

fig.show()