# Notebook for plotting GARISOMv3 model timesteps output

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from itertools import cycle # used for cycling colors

In [None]:
# Paths of the CSV files to plot; each entry is handed to pd.read_csv.
files = []
# When True, a column whose (max - min) range exceeds 10**2 is plotted as
# log1p(value); set to False to always plot raw values.
log_large_differences = True
# When True, multiplot caps each column at the largest 99th-percentile value
# found across the loaded files before plotting.
clip = True

In [None]:
def plot_across_index(dt, var_name, log=False, ylim=None, xlim=None):
    """Draw one column of ``dt`` against its index in a new 10x6 figure.

    Args:
        dt: Table-like object indexable by column name (``dt[var_name]``).
        var_name: Column to plot; also used for the legend entry, the
            y-axis label and the title.
        log: If true, plot ``np.log1p`` of the column and prefix the
            y-axis label with "log ".
        ylim: Optional y-axis limits, forwarded to ``plt.ylim`` when truthy.
        xlim: Optional x-axis limits, forwarded to ``plt.xlim`` when truthy.
    """
    plt.figure(figsize=(10, 6))

    if log:
        series = np.log1p(dt[var_name])
        axis_prefix = "log "
    else:
        series = dt[var_name]
        axis_prefix = ""

    plt.plot(series, label=f'{var_name}', alpha=0.5, color="b")
    plt.xlabel('Index')
    plt.ylabel(f'{axis_prefix}{var_name}')
    plt.title(f'{var_name}')

    # Apply only the axis limits the caller actually supplied.
    for apply_limits, bounds in ((plt.ylim, ylim), (plt.xlim, xlim)):
        if bounds:
            apply_limits(bounds)

    plt.legend()
    plt.show()

In [None]:
def plot(files):
    """Plot every data column of each CSV file in its own figure.

    Each file is read with ``pd.read_csv`` and rows whose ``year`` equals 0
    are dropped. The first three columns are skipped and every remaining
    column is printed by name and drawn with ``plot_across_index``.

    A column is drawn on a log1p scale when its value range (max - min)
    exceeds 10**2 and the module-level ``log_large_differences`` flag is set.

    Args:
        files: Iterable of CSV paths (anything ``pd.read_csv`` accepts).
    """
    skip = 3  # skip year, month, and jd
    for file in files:
        dt = pd.read_csv(file)
        dt = dt[dt['year'] != 0]

        for col_name in dt.columns[skip:]:
            column = dt[col_name]
            log = (column.max() - column.min()) > 10**2 and log_large_differences
            print(col_name)
            plot_across_index(dt, col_name, log=log)

In [None]:
# Line colors, reused cyclically when more series than colors are plotted.
colors = ['b', 'g', 'r', 'c']

def multiplot_across_index(dts, var_name, log=False, clip=False, ylim=None, xlim=None):
    """Overlay one column from several tables in a single 10x6 figure.

    Args:
        dts: Sequence of ``(table, name)`` pairs. ``table[var_name]`` is
            plotted and ``name`` appears in the legend entry.
        var_name: Column to plot from every table.
        log: If true, plot ``np.log1p`` of the (possibly clipped) values
            and prefix the y-axis label with "log ".
        clip: If true, cap values at the largest 99th percentile found
            across all tables so a few outliers do not flatten the plot.
        ylim: Optional y-axis limits, forwarded to ``plt.ylim`` when truthy.
        xlim: Optional x-axis limits, forwarded to ``plt.xlim`` when truthy.
    """
    plt.figure(figsize=(10, 6))

    # The threshold is only needed when clipping, so skip the percentile
    # pass otherwise. The NaN-aware variants keep missing values from
    # turning the threshold into NaN, and an empty ``dts`` no longer
    # raises from ``max()`` of an empty sequence.
    threshold = None
    if clip:
        percentiles = [np.nanpercentile(dt[var_name], 99) for dt, _ in dts]
        if percentiles:
            threshold = np.nanmax(percentiles)

    for (dt, file), color in zip(dts, cycle(colors)):
        plt_data = dt[var_name]
        if clip:
            plt_data = np.clip(plt_data, None, threshold)
        if log:
            plt_data = np.log1p(plt_data)

        plt.plot(plt_data, label=f'{file},{var_name}', alpha=0.5, color=color)
    plt.xlabel('Index')
    log_str = "log "
    plt.ylabel(f'{log_str if log else ""}{var_name}')
    plt.title(f'{var_name} across files')
    if ylim:
        plt.ylim(ylim)
    if xlim:
        plt.xlim(xlim)
    plt.legend()
    plt.show()

In [None]:
def multiplot(files):
    """Overlay the same column from several CSV files, one figure per column.

    Each file is read with ``pd.read_csv`` and rows whose ``year`` equals 0
    are dropped. The columns to plot are taken from the last file loaded,
    skipping its first three columns. Every remaining column is printed by
    name and drawn with ``multiplot_across_index``.

    A column is drawn on a log1p scale when its value range across all
    files (max - min) exceeds 10**2 and the module-level
    ``log_large_differences`` flag is set. The module-level ``clip`` flag is
    forwarded to ``multiplot_across_index``.

    Args:
        files: Iterable of CSV paths (anything ``pd.read_csv`` accepts).
            If it yields nothing, a message is printed and nothing is drawn.
    """
    dts = []
    for file in files:
        frame = pd.read_csv(file)
        dts.append((frame[frame['year'] != 0], file))

    if not dts:
        # Without this guard the column loop below has no frame to read
        # column names from.
        print("No files to plot.")
        return

    skip = 3  # skip year, month, and jd
    reference, _ = dts[-1]  # column names come from the last file loaded
    for col_name in reference.columns[skip:]:
        max_val = max(frame[col_name].max() for frame, _ in dts)
        min_val = min(frame[col_name].min() for frame, _ in dts)
        log = (max_val - min_val) > 10**2 and log_large_differences
        print(col_name)
        multiplot_across_index(dts, col_name, log=log, clip=clip)
    

In [None]:
# Overlay every file listed in `files`, producing one figure per data column.
multiplot(files)