In [None]:
import tensorflow as tf
import numpy as np
import pandas as pd
from plotly import graph_objs as go
from plotly import tools
import plotly.offline as pyo
from google.cloud import storage
%matplotlib inline
pyo.init_notebook_mode(connected=True)

In [None]:
# Change the working directory to the project base so that the `trainer`
# imports and the relative data paths used below resolve.
import os

# Only step up while the `trainer` package directory is not visible from the
# current directory. This keeps the cell idempotent: re-running it does not
# climb one directory higher on every execution.
# NOTE(review): assumes `trainer/` lives directly in the project base - confirm.
if not os.path.isdir("trainer"):
    os.chdir(os.pardir)
print("Current directory: %s" % os.getcwd())

In [None]:
import trainer.data_pipeline as dp
import trainer.constants as cst
from trainer.helpers import print_dict_keys, simple_plotly
from trainer.evaluation import get_predictions_results, plot_predictions_and_errors

In [None]:
# Make sure that the saved_model has been created with the same version of
# TensorFlow; the bare expression below shows the installed version as the
# cell output.
tf.__version__

In [None]:
# Directory of the exported model (relative path, resolved against the working directory)
model_dir = "data/ion_age_20190614_205447/saved_model"
# Dataset location used for the evaluation, taken from the project's constants module
dataset_dir = cst.TEST_SET

In [None]:
# Restore the Keras model from the saved_model directory given by `model_dir`.
# NOTE(review): this loader lives in the `experimental` namespace and may not
# exist in other TensorFlow releases - confirm against the pinned TF version.
model = tf.keras.experimental.load_from_saved_model(model_dir)

# To look at available model methods:
# list(method for method in dir(model) if not method.startswith("_"))

In [None]:
# Build the dataset that is fed to the model for prediction
window_size, shift, stride = 20, 5, 1
batch_size = 16

# All pipeline settings in one place:
#   - shift and batch_size can vary during validation
#   - cycle_length=1 keeps the original order (no files get interleaved)
#   - num_parallel_calls has to be equal to or below cycle_length
#   - shuffle=False keeps the original order as well
pipeline_options = dict(
    window_size=window_size,
    shift=shift,
    stride=stride,
    batch_size=batch_size,
    cycle_length=1,
    num_parallel_calls=1,
    shuffle=False,
    repeat=False,
)
dataset = dp.create_dataset(dataset_dir, **pipeline_options)
scaling_factors = dp.load_scaling_factors()

In [None]:
# Run the model over the dataset and collect the outcome; the cells below use
# `results` as a DataFrame with "target_*" and "pred_*" columns.
results = get_predictions_results(model, dataset, scaling_factors)

In [None]:
def create_cell_index(results_df, cell_index_col_name="cell_index", inplace=False):
    """Label every row with the ordinal of the cell it belongs to.

    A new cell is assumed to start at every row whose "target_current_cycle"
    is lower than the value in the previous row. The first cell gets index 0
    and the index grows by one at each such drop.

    Args:
        results_df: DataFrame containing a "target_current_cycle" column, with
            rows ordered so that all rows of one cell are contiguous.
        cell_index_col_name: Name of the column that receives the cell index.
        inplace: If True, the column is added to `results_df` itself and
            nothing is returned. If False, a labelled copy is returned and
            `results_df` is left untouched.

    Returns:
        The labelled copy when `inplace` is False, otherwise None.
    """
    results = results_df if inplace else results_df.copy()

    # True on the first row of every cell except the very first one
    starts_new_cell = results["target_current_cycle"].diff() < 0

    # The running number of boundaries seen so far is the cell ordinal.
    # Assigning the whole column at once avoids chained `.iloc` assignment
    # (which does not write through under pandas copy-on-write) and does not
    # depend on the frame having a default 0..n-1 index.
    results[cell_index_col_name] = starts_new_cell.cumsum().astype("int64")

    if not inplace:
        return results

# Add the "cell_index" column; with the default inplace=False the call returns
# a labelled copy, which replaces `results` here.
results = create_cell_index(results)

In [None]:
# Preview the first rows of the results, now including the cell index column
results.head()

In [None]:
def get_binned_cycle_count_trace(results_df, window_size, cycle_bin_width, column="target_current_cycle"):
    """Build a Scatter trace of how many rows fall into each cycle bin.

    The per-bin counts are min-max scaled to the range 0..100, because the
    absolute counts vary widely when `cycle_bin_width` changes.

    Args:
        results_df: DataFrame holding the integer cycle values in `column`.
        window_size: Left edge of the first bin; it is also subtracted from
            the bin edges to obtain the x values of the trace.
        cycle_bin_width: Width of one bin, in cycles.
        column: Name of the column whose values are counted.

    Returns:
        A `go.Scatter` named "Cells count" with the shifted bin edges as x and
        the scaled per-bin counts as y.
    """
    # Occurrences per distinct cycle value, sorted from low to high. The frame
    # is built explicitly so the column names do not depend on how the
    # installed pandas version names the output of `value_counts`.
    counts = results_df[column].value_counts().sort_index()
    current_cycle_counts = pd.DataFrame({
        "current_cycle_value": counts.index.values,
        "count": counts.values,
    })

    # The actual cycle counts can be {20: 42,  21: 2,  25: 42,  26: 2} since some cycles were dropped.
    # Binning aggregates these "outliers" into their neighbouring bin.
    # The maximum is read from the frame passed in, not from a notebook global.
    max_cycle = results_df[column].max()
    bins = list(range(window_size, max_cycle, cycle_bin_width))
    bins.append(max_cycle)

    # NOTE(review): pd.cut builds right-closed intervals by default, so rows
    # whose value equals `window_size` land in no bin - confirm this is intended.
    cycle_bins = pd.cut(current_cycle_counts["current_cycle_value"], bins=bins)

    # observed=False keeps empty bins (count 0), so y stays aligned with the bin edges
    grouped_cycle_counts = (current_cycle_counts
                            .groupby(cycle_bins, observed=False)["count"]
                            .sum())

    # Convert to percent, since the absolute counts can vary widely, when cycle_bin_width changes
    lowest = grouped_cycle_counts.min()
    highest = grouped_cycle_counts.max()
    grouped_cycle_counts = (grouped_cycle_counts - lowest) / (highest - lowest) * 100

    return go.Scatter(
        x=np.array(bins) - window_size,  # shift necessary to line up with error traces
        y=grouped_cycle_counts,
        name="Cells count"
    )

def get_errors_over_cycle_traces(results_df, cycle_bin_width=100):
    """Build bar traces of the absolute prediction errors, binned by cycle.

    The absolute errors of the current-cycle and remaining-cycles predictions
    are grouped into bins of "target_current_cycle" and, per bin, their mean
    and their standard deviation are computed.

    Args:
        results_df: DataFrame with the columns "target_current_cycle",
            "pred_current_cycle", "target_remaining_cycles" and
            "pred_remaining_cycles". It is not modified.
        cycle_bin_width: Width of one bin, in cycles.

    Returns:
        A tuple of four `go.Bar` traces, in this order: mean absolute error of
        the current cycle, mean absolute error of the remaining cycles,
        standard deviation of the current-cycle error, standard deviation of
        the remaining-cycles error.
    """
    results = results_df.copy()

    # Calculate absolute errors
    results["ae_current_cycle"] = (results["target_current_cycle"] - results["pred_current_cycle"]).abs()
    results["ae_remaining_cycles"] = (results["target_remaining_cycles"] - results["pred_remaining_cycles"]).abs()

    # Create bin intervals, closing the last bin at the largest observed cycle
    max_cycle = results["target_current_cycle"].max()
    bins = list(range(0, max_cycle, cycle_bin_width))
    bins.append(max_cycle)

    # Group once and aggregate only the two error columns. Selecting the
    # columns before aggregating keeps unrelated (possibly non-numeric)
    # columns out of mean()/std(); observed=False keeps empty bins in place.
    error_columns = ["ae_current_cycle", "ae_remaining_cycles"]
    cycle_bins = pd.cut(results["target_current_cycle"], bins=bins)
    binned_errors = results.groupby(cycle_bins, observed=False)[error_columns]
    mae_binned = binned_errors.mean()
    std_binned = binned_errors.std()

    def _bar(values, name):
        # One bar trace over the shared bin edges
        return go.Bar(x=bins, y=values, name=name)

    return (_bar(mae_binned["ae_current_cycle"], "mae_current_cycle"),
            _bar(mae_binned["ae_remaining_cycles"], "mae_remaining_cycles"),
            _bar(std_binned["ae_current_cycle"], "std_current_cycle"),
            _bar(std_binned["ae_remaining_cycles"], "std_remaining_cycles"))



# Width of the cycle bins, shared by the error bars and the count area so that
# both are binned on the same grid
cycle_bin_width = 40

# Only the two mean-absolute-error traces are plotted; the standard deviation
# traces returned after them are not used here.
mae_cc, mae_rc = get_errors_over_cycle_traces(results, cycle_bin_width)[:2]

# Draw the binned counts as a grey filled area on the secondary y axis
count_trace = get_binned_cycle_count_trace(results, window_size, cycle_bin_width)
count_trace.update(dict(
    mode='none',
    fill='tozeroy',
    fillcolor="rgba(210, 210, 210, 0.5)",
    yaxis="y2",
))

layout = dict(
    height=600,
    width=1000,
    xaxis=dict(
        title="Cycle"
    ),
    yaxis=dict(
        title="Mean absolute error",
        overlaying="y2",
        # Evenly spaced ticks every 100 units. "linear" is the tick mode that
        # honours `dtick`, and it needs no precomputed list of tick values
        # (the previous layout referenced an undefined `tickvals_error`).
        tickmode="linear",
        tick0=0,
        dtick=100
    ),
    yaxis2=dict(
        title="Cell count",
        side="right",
    )
)

fig = go.Figure(data=[count_trace,
                      mae_cc,
                      mae_rc],
                layout=layout)

pyo.iplot(fig)

In [None]:
# import numpy as np

# for i, (cell_k, cell_v) in enumerate(preprocessed_pkl.items()):
#     if i == 1:
#         break
#     cycle_keys = list(cell_v["cycles"].keys())

#     window_cycle_keys = []
#     for i, w_slice in enumerate(range(0, len(cycle_keys), shift)):
#         cycle_keys_slice = cycle_keys[w_slice:w_slice + window_size]
#         if len(cycle_keys_slice) % window_size == 0:  # drop remainder
#             window_cycle_keys.append(cycle_keys_slice)
    
#     assert np.all([len(w) == window_size for w in window_cycle_keys]), \
#         "Not all windows have the correct window_size"
#     print("Number of windows in '{}': {}".format(cell_k, len(window_cycle_keys)))
#     #print(window_cycle_keys)
#     for j, (w_cycle_keys, example) in enumerate(zip(window_cycle_keys, dataset)):  # Iterate over all windows
#         if j == 3:
#             break
#         print("\n##### Window {} #####".format(j))
        
#         # Dataset values
#         print("Example shapes:", {k:v.shape for k, v in example[0].items()})
#         print("\nDataset Target:", example[1].numpy())
#         # Processed Values
#         processed_target = cell_v["summary"][cst.REMAINING_CYCLES_NAME] \
#                                 / scaling_factors[cst.REMAINING_CYCLES_NAME]
#         #print("Processed Target:", processed_target[(window_size + stride * j)])
#         for z, k in enumerate(w_cycle_keys):  # Iterate over all strides
#             print("\n# Step")
#             scaled = example[0][cst.QDLIN_NAME][:, z, :, :].numpy().squeeze() * scaling_factors[cst.QDLIN_NAME]
#             print(scaled.shape)
#             processed = cell_v["cycles"][k][cst.QDLIN_NAME].squeeze()
#             print(processed.shape)
#             print("Num close:, ", np.sum(np.isclose(scaled, processed, atol=1e-8)))
#             processed_rc = (cell_v["summary"][cst.REMAINING_CYCLES_NAME][z + j * (window_size - stride)] \
#                                 / scaling_factors[cst.REMAINING_CYCLES_NAME]).astype(np.float32)
#             print(processed_rc)
#             print(cell_v["summary"].keys())