In [1]:
# --- Script to generate plots from spinup run
import glob
import os
import re

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# --- set figure-generating parameters
runs2save = "all"  # (later, build functionality for this to be a list)
domain = ["field"]  # ["field, "lab"] or just one
prof_list = [
    "prof_aq",
    "prof_aq(ads%cec)",
    "prof_gas",
    "prof_sld(wt%)",
    "sa",
    "rate",
]  # "bsd"...
flx_list = ["flx_co2sp", "flx_gas", "flx_aq", "flx_sld"]
flx_plot_excludezeros = True  # True means we don't render lines that contribute no flux the entire timeseries
logtime_color = [
    True,
    False,
]  # [True, False] or just one; whether to make time colorbar a log scale

# --- where to save
save_base = "/home/tykukla/aglime-swap-cdr/scepter/process/spinups/figures"

In [2]:
# --- plotting decisions
# ***** PROFILE PLOTS ******
num_cols = 3  # number of columns for multipanel
varmax_threshold = (
    1e-6  # concentrations smaller than this won't get zoomed in on (effectively zero)
)
xax_titlesize = 20
yax_titlesize = 13
xticksize = 14
yticksize = 12
cbar_titlesize = 15
cbar_ticksize = 10

# ***** FLUX PLOTS ******
Fnum_cols = 3  # number of columns for multipanel
Fxax_titlesize = 20
Fyax_titlesize = 13
Fxticksize = 14
Fyticksize = 12
Fcbar_titlesize = 15
Fcbar_ticksize = 10

In [3]:
# --- read in spinup table
csv_fn = "spinup-inputs.csv"
csv_loc = "/home/tykukla/aglime-swap-cdr/scepter/batch-inputs"
dfin = pd.read_csv(os.path.join(csv_loc, csv_fn))

# add column for fullname (minus field / lab)
dfin["spinname_full"] = dfin["spinname"] + "_spintuneup"

dfin

Unnamed: 0,site,spinname,lat,lon,mat,soilmoisture,qrun,tsom,erosion,nitrif,tph,cec,tec,tsoilco2,poro,alpha,spinname_full
0,site_311,site_311,42.5,-91,8.22219,0.282727,0.351361,2.051667,0.001013,1.005952,6.058007,21.10329,20.98031,-1.80371,0.447,2.0,site_311_spintuneup
1,site_411,site_411,32.0,-83,18.52789,0.231552,0.243426,2.276667,0.00084,0.831883,5.200242,1.96125,46.91557,-1.61194,0.419,2.0,site_411_spintuneup


In [4]:
# --- FUNCTION to preprocess .txt files for consistent delimiters


def preprocess_txt(file_path):
    data = []  # Initialize a list to store the processed data

    # Initialize a flag to determine if we are reading the header
    is_header = True

    # Read the file line by line and process the data
    with open(file_path) as file:
        for line in file:
            line = line.strip()  # Remove leading/trailing whitespace
            if is_header:
                # Split the first line into column names
                column_names = re.split(r"\s+", line)
                is_header = False
            else:
                # Split the other lines into data values
                values = re.split(r"\s+", line)
                data.append(values)

    # Create a DataFrame with the processed data and set column names
    df = pd.DataFrame(data, columns=column_names)
    # return
    return df

In [5]:
# --- function to read in profile data


def read_prof_dat(resdir, runname_in, domain_in, prof_list):
    # where results are stored
    dirname = runname_in + "_" + domain_in
    results_path = os.path.join(resdir, dirname)
    prof_path = os.path.join(results_path, "prof")

    # define file name pattern
    fn_pref = prof_list
    fn_ext = ".txt"

    # loop through variables
    df = pd.DataFrame()  # initialize empty df to store dat
    for var in fn_pref:
        # read out status
        print("reading in " + var + "...")
        # set pattern
        fn_pattern = f"{var}-*{fn_ext}"
        # get list of filenames
        file_paths = glob.glob(f"{prof_path}/{fn_pattern}")

        # read in data and concatenate
        for file_path in file_paths:
            dfi = preprocess_txt(file_path)
            # apply pd.to_numeric to all columns using the "map" method
            dfi = dfi.map(pd.to_numeric)
            # add var
            dfi["var"] = var
            # combine
            df = pd.concat([df, dfi], ignore_index=True)

    # sort by time and depth
    df = df.sort_values(by=["var", "time", "z"])
    return df

In [6]:
# --- function to read in flux data
def read_flx_dat(resdir, runname_in, domain_in, flx_list):
    # where results are stored
    dirname = runname_in + "_" + domain_in
    results_path = os.path.join(resdir, dirname)
    flx_path = os.path.join(results_path, "flx")

    # define file name pattern
    fn_pref = flx_list
    fn_varInclude = []
    varCheck = True if len(fn_varInclude) > 0 else False
    fn_ext = ".txt"

    df = pd.DataFrame()  # initialize empty df to store dat

    for fset in fn_pref:
        # read out status
        print("reading in " + fset + "...")
        # set pattern
        fn_pattern = f"{fset}-*{fn_ext}"
        # get list of filenames
        file_paths = glob.glob(f"{flx_path}/{fn_pattern}")

        # read in data and concatenate
        for file_path in file_paths:
            # get the variable
            varpattern = re.escape(fset) + r"-(.*?).txt"
            varmatch = re.search(varpattern, file_path)
            var = varmatch.group(1)
            # skip this step if it's not in the include arr
            if varCheck:
                if var not in fn_varInclude:
                    continue
            # read in
            dfi = preprocess_txt(file_path)
            # apply pd.to_numeric to all columns using the "map" method
            dfi = dfi.map(pd.to_numeric)
            # add set, var, spinrun, ctrl
            dfi["set"] = fset
            dfi["var"] = var

            # combine
            df = pd.concat([df, dfi], ignore_index=True)

    # drop all time slices dangerously close to zero (these produce astronomical (like 10^10 or higher) residuals)
    # df = df.loc[df['time'] > 1e-3]

    # sort by time and depth
    df = df.sort_values(by=["set", "var", "time"])
    return df

In [7]:
def plot_profile(
    df,
    tdf,
    save_base,
    domain_in,
    logtime_color,
    num_cols=3,
    xax_titlesize=20,
    yax_titlesize=13,
    xticksize=14,
    yticksize=12,
    cbar_titlesize=15,
    cbar_ticksize=10,
    save_dpi=250,
    save_transparent=False,
    plot_prefix="PROF_",
):

    # --- turn off interactive mode
    plt.ioff()

    # --- set the output directory
    save_prof_dir = os.path.join(save_base, tdf["site"], domain_in)
    # create it if it doesn't exist
    if not os.path.exists(save_prof_dir):
        os.makedirs(save_prof_dir)

    # --- get list of variables to loop through
    thesevars = df["var"].unique()

    # LOOP ONE --- VARS
    for thisvar in thesevars:
        # extract just the variable we want
        dfx = df[df["var"] == thisvar]
        # remove columns with all nan
        dfx1 = dfx.dropna(axis=1, how="all")
        # remove columns whose names are numbers
        number_pattern = re.compile(
            r"^[-+]?\d*\.?\d+(?:[eE][-+]?\d+)?$"
        )  # required to detect columns that are numbers
        for col in dfx1.columns:
            if number_pattern.match(col):
                dfx1.drop(columns=col, inplace=True)
        # set the plot df
        dfp = dfx1

        # LOOP 2 --- WHETHER COLOR BAR IS A LOG SCALE
        for log_col in logtime_color:
            # Get the list of variables
            variables = [col for col in dfp.columns if col not in ["z", "time", "var"]]

            # group by time
            grouped = dfp.groupby("time")

            # Calculate number of rows and columns for subplots
            num_rows = -(-len(variables) // num_cols)  # Round up division

            # Create a colormap
            cmap = plt.get_cmap("magma")  # Get the colormap
            cmap = cmap.reversed()  # flip colormap
            # -- no log normalization
            if not log_col:
                norm = plt.Normalize(
                    df["time"].min(), df["time"].max()
                )  # Normalize time values for colormap
            # -- log normalization WIP
            else:
                log_time = np.log(df["time"])
                norm = plt.Normalize(
                    log_time.min(), log_time.max()
                )  # Normalize time values for colormap

            # Create a colorbar based on the Viridis colormap
            sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)

            # Create subplots
            fig, axes = plt.subplots(
                num_rows, num_cols, figsize=(5 * num_cols, 4 * num_rows)
            )

            # Plot each variable versus depth
            for i, var in enumerate(variables):
                row = i // num_cols
                col = i % num_cols
                ax = axes[row, col] if num_rows > 1 else axes[col]
                # get max (to determine how to handle y-axis)
                varmax = np.nanmax(dfp[var])

                # Create a plot for each time step
                for time, group in grouped:
                    color = cmap(norm(time))  # Map time to color using Viridis colormap
                    ax.plot(
                        group[var], group["z"], color=color, label=None, linewidth=3
                    )
                ax.set_ylabel("Depth", size=yax_titlesize)
                ax.set_xlabel(var, size=xax_titlesize)
                ax.tick_params(
                    axis="x", which="major", labelsize=xticksize
                )  # Adjust the size as needed
                ax.tick_params(
                    axis="y", which="major", labelsize=yticksize
                )  # Adjust the size as needed
                # set limits if needed
                if varmax < varmax_threshold:
                    ax.set_xlim(0, varmax_threshold)
                ax.invert_yaxis()

            # Remove empty subplots
            for i in range(len(variables), num_rows * num_cols):
                row = i // num_cols
                col = i % num_cols
                fig.delaxes(axes[row, col] if num_rows > 1 else axes[col])

            # add colorbar
            sm.set_array([])  # Set an empty array for the colorbar data
            if log_col:
                colorlabel = "logTime (yr)"
                cbar = plt.colorbar(sm, label=colorlabel, ax=plt.gca())
            else:
                colorlabel = "Time (yr)"
                cbar = plt.colorbar(sm, label=colorlabel, ax=plt.gca())
            cbar.set_label(colorlabel, fontsize=cbar_titlesize)
            cbar.ax.tick_params(labelsize=cbar_ticksize)

            plt.tight_layout()
            # plt.show()

            # --- save the result
            if log_col:
                fname = plot_prefix + thisvar + "_logColor.png"
            else:
                fname = plot_prefix + thisvar + ".png"
            plt.savefig(
                os.path.join(save_prof_dir, fname),
                dpi=250,
                bbox_inches="tight",
                transparent=save_transparent,
            )
            # --- close to release memory
            plt.close()

    # --- turn interactive mode back on
    plt.ion()

In [8]:
def plot_flx(
    df,
    tdf,
    save_base,
    domain_in,
    mycmap="magma",
    color_start=0.1,
    color_end=0.9,
    Fnum_cols=3,
    Fxax_titlesize=20,
    Fyax_titlesize=13,
    Fxticksize=14,
    Fyticksize=12,
    save_dpi=250,
    save_transparent=False,
    plot_prefix="FLUX_",
):
    # --- turn interactive mode off
    plt.ioff()
    # --- set the output directory
    save_prof_dir = os.path.join(save_base, tdf["site"], domain_in)
    # create it if it doesn't exist
    if not os.path.exists(save_prof_dir):
        os.makedirs(save_prof_dir)

    # --- get list of variables to loop through
    thesesets = df["set"].unique()

    # --- LOOP 1: THIS SET
    for thisset in thesesets:
        dfset = df[df["set"] == thisset]

        # --- get a color dictionary
        if flx_plot_excludezeros:  # remove columns with all zeros
            dftmp = dfset.loc[:, (dfset != 0).any(axis=0)]
        allflux_components = [
            col for col in dftmp.columns if col not in ["set", "time", "var"]
        ]
        cmap = plt.get_cmap(mycmap)
        num_colors = len(allflux_components)
        colors = cmap(np.linspace(color_start, color_end, num_colors))
        color_dict = dict(zip(allflux_components, colors))

        # --- get vars
        thesevars = dfset["var"].unique()

        # --- Calculate number of rows and columns for subplots
        num_rows = -(-len(thesevars) // Fnum_cols)  # Round up division
        # count which panel we're on
        cnt = 0

        # --- Create subplots
        fig, axes = plt.subplots(
            num_rows, Fnum_cols, figsize=(5 * Fnum_cols, 4 * num_rows)
        )

        # --- LOOP 2: THESEVARS
        for thisvar in thesevars:
            # extract just the variable we want
            dfx = dfset[dfset["var"] == thisvar]
            # remove columns with all nan
            dfx1 = dfx.dropna(axis=1, how="all")
            # remove columns with all zeros
            if flx_plot_excludezeros:
                dfx2 = dfx1.loc[:, (dfx1 != 0).any(axis=0)]
            else:
                dfx2 = dfx1
            # remove columns whose names are numbers
            number_pattern = re.compile(
                r"^[-+]?\d*\.?\d+(?:[eE][-+]?\d+)?$"
            )  # required to detect columns that are numbers
            for col in dfx2.columns:
                if number_pattern.match(col):
                    dfx2.drop(columns=col, inplace=True)
            # set the plot df
            dfp = dfx2

            # Get the list of variables
            variables = [
                col for col in dfp.columns if col not in ["set", "time", "var"]
            ]

            # Plot each variable versus time
            row = cnt // Fnum_cols
            col = cnt % Fnum_cols

            for i, var in enumerate(variables):
                ax = axes[row, col] if num_rows > 1 else axes[col]
                if i % 2 == 0:
                    linestylex = "--"  # if even,
                else:
                    linestylex = "solid"
                ax.plot(
                    dfp["time"],
                    dfp[var],
                    color=color_dict[var],
                    label=var,
                    linewidth=3,
                    linestyle=linestylex,
                )

            ax.set_title(thisvar)
            ax.set_ylabel("flux", size=Fyax_titlesize)
            ax.set_xlabel("Time (yr)", size=Fxax_titlesize)
            ax.tick_params(
                axis="x", which="major", labelsize=Fxticksize
            )  # Adjust the size as needed
            ax.tick_params(
                axis="y", which="major", labelsize=Fyticksize
            )  # Adjust the size as needed
            ax.legend()

            # update cnt
            cnt += 1

        # Remove empty subplots
        for i in range(len(variables), num_rows * Fnum_cols):
            row = i // Fnum_cols
            col = i % Fnum_cols
            fig.delaxes(axes[row, col] if num_rows > 1 else axes[col])

        # add colorbar
        plt.tight_layout()
        # plt.show()

        # --- save the result
        fname = plot_prefix + thisset + ".png"
        plt.savefig(
            os.path.join(save_prof_dir, fname),
            dpi=250,
            bbox_inches="tight",
            transparent=save_transparent,
        )
        # --- close to release memory
        plt.close()

    # --- turn interactive mode back on
    plt.ion()

In [9]:
# --- run the loop and generate figures
resdir = "/home/tykukla/SCEPTER/scepter_output"  # location of output directories

# select which runs to use
if runs2save == "all":
    dfin = dfin
# ELSE... select all the spinnames from a list

# LOOP -------------------------------------------------------
# --- first across all runs
for trun in range(len(dfin)):
    runname_in = dfin["spinname_full"][trun]
    tdf = dfin.loc[trun]

    # --- second across lab / field domains
    for domain_in in domain:
        # ***** PROFILE DATA ***** #
        # read in
        df = read_prof_dat(resdir, runname_in, domain_in, prof_list)
        # plot + save ----------------------------------------------------
        plot_profile(df, tdf, save_base, domain_in, logtime_color)
        # ----------------------------------------------------------------

        # ***** FLUX DATA ***** #
        # read in
        df_flx = read_flx_dat(resdir, runname_in, domain_in, flx_list)
        # plot + save ----------------------------------------------------
        plot_flx(df_flx, tdf, save_base, domain_in)
        # ----------------------------------------------------------------

reading in prof_aq...
reading in prof_aq(ads%cec)...
reading in prof_gas...
reading in prof_sld(wt%)...
reading in sa...
reading in rate...


A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  dfx1.drop(columns=col, inplace=True)
A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  dfx1.drop(columns=col, inplace=True)


reading in flx_co2sp...
reading in flx_gas...
reading in flx_aq...
reading in flx_sld...
reading in prof_aq...
reading in prof_aq(ads%cec)...
reading in prof_gas...
reading in prof_sld(wt%)...
reading in sa...
reading in rate...


A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  dfx1.drop(columns=col, inplace=True)
A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  dfx1.drop(columns=col, inplace=True)


reading in flx_co2sp...
reading in flx_gas...
reading in flx_aq...
reading in flx_sld...


In [11]:
df = df_flx
tdf = dfin.loc[0]
runname_in = dfin["spinname_full"][0]
domain_in = domain[0]
runname_in

'site_311_spintuneup'