# Paleosol ternary plots

In [37]:
import numpy as np
import pandas as pd
import re
import datetime
import itertools
import mendeleev
import openpyxl
from pyrolite import geochem
from pyrolite.geochem.ind import (
    simple_oxides,
    common_oxides,
    common_elements,
    get_cations,
)
from openpyxl.utils.dataframe import dataframe_to_rows
from openpyxl.utils import get_column_letter

## Define needed data

In [38]:
# Define wanted columns.
# Metadata columns carried over into the compiled table. Tables that lack some of
# these columns are still accepted (cull_paleosol() keeps only the ones present).
metacols = [
    "Age",
    "Paleosol",
    "Reference",
    "Section",
    "Sample",
    "Depth",
    "Position",
    "New position",
    "Relative kaolinite",
]
# Chemical species every table is converted to (passed as the conversion target to
# cull_paleosol()) and later turned into molar columns by add_molar().
compcols = ["CaO", "Al2O3", "MgO", "K2O", "Na2O", "P2O5", "TiO2", "Cr"]

# Lookup table handed to add_molar(); its column names are configurable there
# (defaults: "oxide", "element", "cation_number", "oxygen_number").
ox_convert_df = pd.read_excel("oxide_conversion.xlsx")

# Source workbook holding the paleosol tables.
infile = "vanad_paleomullad_allikas_2021-01-22.xlsx"

## Import and stitch all paleosols

In [39]:
def beautify_table(df, searchcol=0, identifier="Age", inplace=False):
    """
    Drop useless rows and set proper headers for df.

    The header row is located by searching column number 'searchcol' for the first
    cell that contains 'identifier' (case-insensitive). That row becomes the column
    header; it and every row above it are dropped, as are rows with a missing value
    in the first column. The old index is kept as a new "index" column. Nothing is
    changed when the current header of 'searchcol' already contains 'identifier'.

    If 'inplace' is True, 'df' itself is modified and None is returned; otherwise a
    modified copy is returned and 'df' is left untouched. Raises ValueError when no
    header row can be found.
    """
    if not inplace:
        df = df.copy()

    # Check if the header already works. str() keeps the test from failing on
    # non-string column labels (e.g. numbers).
    if identifier not in str(df.columns[searchcol]):
        # Find the row that contains the searched for string in its particular
        # column.
        condition = (
            df.iloc[:, searchcol]
            .astype("str")
            .str.contains(identifier, case=False, na=False)
        )
        if not condition.any():
            raise ValueError(
                f"no header row containing {identifier!r} found in column {searchcol}"
            )
        # Position (not label) of the first matching row.
        header_pos = int(condition.to_numpy().argmax())

        # Promote that row to column header.
        df.columns = df.iloc[header_pos]

        # Every step below mutates 'df' itself instead of rebinding the name, so
        # that 'inplace=True' really changes the caller's frame.
        # Drop the header row and all rows above it.
        df.drop(index=df.index[: header_pos + 1], inplace=True)
        # Drop all rows that are empty in the first column.
        df.dropna(subset=[df.columns[0]], inplace=True)

        # Reset index (the previous labels end up in an "index" column).
        df.reset_index(inplace=True)

    if not inplace:
        return df
    

def add_relative_kaolinite(
    df,
    alumosilicates=[
        "K-feldspar",
        "Mica/Biotite",
        "Plagioclase",
        "Hornblende",
        "Kaolinite",
        "Illite/MLM",
        "Chlorite",
        "Vermiculite",
        "Cordierite",
    ],
    kaolinite="Kaolinite",
    header="Relative kaolinite",
):
    """
    Write a new column 'header' into df (in place): the 'kaolinite' column as a
    percentage of the row-wise sum of all 'alumosilicates' columns. Mineral names
    that are not columns of df are ignored. Returns nothing.
    """
    # Only sum the minerals this particular table actually reports.
    present = df.columns.intersection(alumosilicates)
    alumosilicate_total = df[present].sum(axis="columns")
    df[header] = df[kaolinite] / alumosilicate_total * 100

    
def cull_paleosol(df, metacols, compcols, el_unit='ppm', ox_unit='wt%'):
    """
    Reduce one paleosol table to the requested metadata and chemistry columns.

    Rows whose "Position" is "cover", "sill" or "weathered" are removed. The
    compositional columns are coerced to numbers, the chemistry is converted to the
    species listed in 'compcols', and only the 'metacols' that exist in this table
    plus 'compcols' are returned. 'el_unit' and 'ox_unit' are the units that
    elemental and oxide columns are assumed to be reported in.
    """
    # Remove unneeded samples. (An "Exclude" == "yes" column filter used to live
    # here as an alternative; it is not applied.)
    unwanted_positions = ["cover", "sill", "weathered"]
    df = df[~df["Position"].isin(unwanted_positions)]

    # Make every compositional column numeric; unparsable entries become NaN.
    numeric_chemistry = df.pyrochem.compositional.apply(pd.to_numeric, errors='coerce')
    df.pyrochem.compositional = numeric_chemistry

    # Bring elemental columns to the oxide unit so everything shares one unit
    # while the chemistry is converted...
    df.pyrochem.elements = df.pyrochem.elements.pyrochem.scale(
        in_unit=el_unit, target_unit=ox_unit
    )
    df = df.pyrochem.convert_chemistry(to=compcols)
    # ...then scale whatever is still elemental back to its own unit.
    df.pyrochem.elements = df.pyrochem.elements.pyrochem.scale(
        in_unit=ox_unit, target_unit=el_unit
    )

    # Keep only the meta columns this table has, followed by the chemistry.
    available_meta = df.columns.intersection(metacols).to_list()
    return df[available_meta + compcols]

    
def stitch_paleosols(paleosol_dict, metacols, compcols):
    """
    Stitch together all paleosol tables into one.

    'paleosol_dict' maps a name to a dataframe; only the dataframes are used. Each
    one is passed through cull_paleosol() with the given 'metacols' and 'compcols',
    and the results are stacked in dictionary order. Returns a single dataframe with
    a fresh 0..n-1 index (an empty dataframe if 'paleosol_dict' is empty).
    """
    # Unify paleosol structure.
    culled = [
        cull_paleosol(df, metacols=metacols, compcols=compcols)
        for df in paleosol_dict.values()
    ]
    if not culled:
        # pd.concat() refuses an empty list.
        return pd.DataFrame()

    # Concatenate once instead of growing a frame inside the loop. `sort=False`
    # keeps the columns in order of first appearance when the tables don't share
    # all column names.
    fulldf = pd.concat(culled, axis="index", sort=False)

    return fulldf.reset_index(drop=True)


# Read in paleosol tables (sheet_name=None loads every sheet into a dict keyed by
# sheet name).
paleosol_tables = pd.read_excel(infile, sheet_name=None)

# Place headers correctly. Store the cleaned copies that beautify_table() returns
# rather than relying on in-place mutation, so the tables used below are exactly
# the cleaned ones.
paleosol_tables = {name: beautify_table(df) for name, df in paleosol_tables.items()}

# Add kaolinite calculation where needed.
add_relative_kaolinite(paleosol_tables["Balti paleomuld 0.55"])
add_relative_kaolinite(paleosol_tables["Balti paleomuld 0.55 part 2"])

# Stitch all together.
fulldf = stitch_paleosols(paleosol_tables, metacols, compcols)

## Rename and rearrange

In [40]:
# Sort according to age.
# Non-numeric ages are coerced to NaN first, so the sort compares numbers only.
fulldf["Age"] = pd.to_numeric(fulldf["Age"], errors="coerce")
# Mergesort is the only one that conserves order of non-sorted rows.
fulldf.sort_values("Age", ascending=False, inplace=True, kind="mergesort")
fulldf.reset_index(drop=True, inplace=True)

# Rename the columns.
# The new names carry the unit labels; later helpers look columns up by these
# headers (e.g. "Relative position", "Cr (ppm)", "TiO2 (wt.%)").
renamedict = {
    "Age": "Age (Ga)",
    "Position": "Relative position",
    "CaO": "CaO (wt.%)",
    "Al2O3": "Al2O3 (wt.%)",
    "MgO": "MgO (wt.%)",
    "K2O": "K2O (wt.%)",
    "Na2O": "Na2O (wt.%)",
    "P2O5": "P2O5 (wt.%)",
    "TiO2": "TiO2 (wt.%)",
    "Depth": "Depth (m)",
    "Relative kaolinite": "Relative kaolinite (wt.%)",
    "Cr": "Cr (ppm)",
}
fulldf.rename(columns=renamedict, inplace=True)

# Rearrange the columns.
# Selecting by an explicit list also drops every column that is not listed
# (e.g. "New position") and raises KeyError if a listed column is missing.
fulldf = fulldf[
    [
        "Age (Ga)",
        "Paleosol",
        "Reference",
        "Section",
        "Sample",
        "Depth (m)",
        "Relative position",
        "Relative kaolinite (wt.%)",
        "CaO (wt.%)",
        "Al2O3 (wt.%)",
        "MgO (wt.%)",
        "K2O (wt.%)",
        "Na2O (wt.%)",
        "P2O5 (wt.%)",
        "TiO2 (wt.%)",
        "Cr (ppm)",
    ]
]

In [41]:
# Preview the first rows of the compiled table.
fulldf.head()

Unnamed: 0,Age (Ga),Paleosol,Reference,Section,Sample,Depth (m),Relative position,Relative kaolinite (wt.%),CaO (wt.%),Al2O3 (wt.%),MgO (wt.%),K2O (wt.%),Na2O (wt.%),P2O5 (wt.%),TiO2 (wt.%),Cr (ppm)
0,3.46,Panorama,"Retallack, 2018",Jurl 1,3756,,top,,0.03,0.77,0.02,0.21,0.01,0.04,0.04,
1,3.46,Panorama,"Retallack, 2018",Jurl 1,3757,,,,0.02,0.56,0.02,0.14,0.01,0.009,0.02,68.420242
2,3.46,Panorama,"Retallack, 2018",Jurl 1,3758,,,,0.04,0.67,0.02,0.19,0.01,0.05,0.03,
3,3.46,Panorama,"Retallack, 2018",Jurl 1,3759,,proto,,0.04,0.72,0.02,0.19,0.01,0.05,0.02,68.420242
4,3.46,Panorama,"Retallack, 2018",Ngumpu,3761,,top,,0.02,0.71,0.02,0.17,0.01,0.02,0.03,


## Switch to openpyxl

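Writing the table with openpyxl lets us store the calculations as Excel formulas rather than fixed numbers, which makes the spreadsheet easier to share and check with collaborators.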

In [42]:
# Convert df to openpyxl workbook.
wb = openpyxl.Workbook()
ws = wb.active
# index=False, header=True: row 1 of the sheet holds the column names and the
# dataframe index is not written, so sheet row r + 2 corresponds to dataframe row r.
for r in dataframe_to_rows(fulldf, index=False, header=True):
    ws.append(r)

In [43]:
# General helper methods for openpyxl.


def col_from_string(ws, name, flatten=True, letter=True, strict=False):
    """
    Fetch column number or letter where header matches string. Returns list of
    columns where header matches, unless only one column matches and flatten=True,
    in which case that single column is returned on its own. An empty list is
    returned when nothing matches.

    'letter' converts column numbers to letters. 'strict' will enforce an exact,
    not partial, match. A partial match is case-insensitive and succeeds when
    either the header contains 'name' or 'name' contains the header. Header cells
    that are empty or do not hold text never match partially.
    """
    if strict:
        colids = [cell.column for cell in ws[1] if cell.value == name]
    else:
        needle = name.lower()
        colids = []
        for cell in ws[1]:
            # Empty/non-text headers have no .lower(), and an empty string would be
            # "contained" in every name -- neither is a meaningful match.
            if not isinstance(cell.value, str) or not cell.value:
                continue
            header = cell.value.lower()
            if header in needle or needle in header:
                colids.append(cell.column)

    if letter:
        colids = [get_column_letter(i) for i in colids]

    if flatten and len(colids) == 1:
        return colids[0]
    return colids


def make_isnumber(cols, row, chain="AND"):
    """
    Build an Excel test that the cells of columns 'cols' in row 'row' hold numbers.
    A single column gives a bare ISNUMBER() call; several columns are wrapped in the
    Excel function named by 'chain' (usually 'AND' or 'OR', case-insensitive).
    Returns None when 'cols' is empty.
    """
    checks = []
    for col in cols:
        checks.append("ISNUMBER({}{})".format(col, row))

    if not checks:
        return None
    if len(checks) == 1:
        return checks[0]
    return "{}({})".format(chain.upper(), ", ".join(checks))


def get_sections(
    ws,
    sectioncols=["Paleosol", "Section"],
    getcols=True,
    strict=False,
    which="starts",
    boundrows=None,
):
    """
    Returns list of row numbers per section. A new section starts on every row
    where the value in any of the section columns differs from the row above.

    'sectioncols' is a list of column header strings where section identifiers are
    found. Unless 'getcols=False', in which case it is a list of column letters. If
    `getcols`, `strict` controls whether its a strict or partial match. 'which' is
    either 'starts' or 'rows', returning either a sorted list of section start rows,
    or a 2-dimensional list of all rows in all sections. 'boundrows' is None or a
    list of 2 row numbers (first and last row, inclusive), between which the
    worksheet is evaluated; without it everything below the header row is used.

    Raises ValueError if 'which' is neither 'starts' nor 'rows'.
    """
    if which not in ("starts", "rows"):
        raise ValueError("'which' needs to be either 'rows' or 'starts'")

    # Find all section columns.
    if getcols:
        sectioncols = [col_from_string(ws, name, strict=strict) for name in sectioncols]

    # Handle if you want to only evaluate between certain rows.
    if boundrows:
        startr, endr = boundrows[0], boundrows[1]
    else:
        startr, endr = 1, None

    # Collect section starts from every column; the set removes rows that start a
    # section in more than one column.
    starts = set()
    for col in sectioncols:
        # Fetch the column once: asking the worksheet for it again on every row
        # makes the scan quadratic in the number of rows.
        column = ws[col]
        # The first evaluated row always opens a section, if it is not the header.
        if startr > 1:
            starts.add(column[startr - 1].row)
        # `i` is the 0-based position of `cell` in the column, so column[i - 1] is
        # the cell directly above it.
        for i, cell in enumerate(column[startr:endr], start=startr):
            if cell.value != column[i - 1].value:
                starts.add(cell.row)
    starts = sorted(starts)

    if which == "starts":
        return starts

    # Each section runs up to the row before the next start; the last one runs to
    # the end of the evaluated range.
    last_row = endr if endr else ws.max_row
    bounds = starts + [last_row + 1]
    return [list(range(first, stop)) for first, stop in zip(bounds, bounds[1:])]


def searchrows(ws, col, search, rows=None, getcols=True, strict=False):
    """
    List the row numbers whose cell in column 'col' equals 'search'. 'col' is a
    column header name, or a column letter when 'getcols=False'; with 'getcols',
    'strict' selects exact instead of partial header matching. 'rows' is None (scan
    the whole column) or an iterable whose first two items are the first and last
    row to scan, both inclusive.
    """
    column_key = col_from_string(ws, col, strict=strict) if getcols else col

    cells = ws[column_key]
    if rows:
        # Rows are 1-based while the cell tuple is 0-based.
        first_row, last_row = rows[0], rows[1]
        cells = cells[first_row - 1 : last_row]

    return [cell.row for cell in cells if cell.value == search]

In [44]:
# Specific methods to calculate corrections and indices for paleosol table.


def add_molar(
    ws,
    oxidecols,
    ox_convert_df,
    getcols=True,
    strict=False,
    oxconv_oxide="oxide",
    oxconv_element="element",
    oxconv_cation="cation_number",
    oxconv_oxygen="oxygen_number",
    headeradd=" (mol)",
):
    """
    Normalizes oxide masses by molar mass: for every column in 'oxidecols' a new
    column is appended to the right of the worksheet holding the Excel formula
    "=<source cell> / (<molar mass>)".

    'oxidecols' can be either column letters of the worksheet, or column header
    names, depending on 'getcols'. If getcols=True, 'strict' controls whether exact
    or partial match is used. 'ox_convert_df' is a Pandas dataframe that contains
    oxide formula, cation element and cation + oxygen amounts in oxide molecule.
    'oxconv_' variables are header names of said dataframe. 'headeradd' is
    concatenated to the cation element symbol to form the new column headers.
    """
    # Store column header names and letters in different lists.
    if getcols:
        oxideletters = [
            col_from_string(ws, oxide, strict=strict) for oxide in oxidecols
        ]
    else:
        oxideletters = oxidecols
        oxidecols = [ws[col_letter][0].value for col_letter in oxideletters]

    for oxide, letter in zip(oxidecols, oxideletters):
        # Strip (wt.%), so that you can match values in the oxide conversion table.
        pureoxide = re.sub(r"\([^()]*\)", "", oxide).strip()
        # Select row of current oxide (case-insensitive "contains" match; this
        # presumably expects exactly one matching row -- verify against the table).
        oxrow_sel = ox_convert_df[oxconv_oxide].str.contains(
            pureoxide, case=False, na=False
        )
        oxrow = ox_convert_df[oxrow_sel].squeeze()

        # Build molecular weight calculation string using mendeleev.
        els = mendeleev.element([oxrow[oxconv_element], "O"])
        weights = []
        for el, numcol in zip(els, [oxconv_cation, oxconv_oxygen]):
            # A count of 1 needs no multiplier; a count below 1 (e.g. 0 oxygens)
            # contributes nothing to the weight.
            if oxrow[numcol] == 1:
                weights.append(str(el.atomic_weight))
            elif oxrow[numcol] > 1:
                weights.append(f"{oxrow[numcol]}*{el.atomic_weight}")
        mol_weight = " + ".join(weights)

        # Add data to new column. The header is read through 'oxconv_element' so
        # that conversion tables with a custom element column name work too.
        col = get_column_letter(ws.max_column + 1)
        ws[col][0].value = oxrow[oxconv_element] + headeradd
        for cell in ws[col][1:]:
            cell.value = f"={letter}{cell.row} / ({mol_weight})"


def add_ca_corrections(
    ws,
    getcols=True,
    ca="Ca (mol)",
    p="P (mol)",
    na="Na (mol)",
    header_ap="Ca (-ap)",
    header_sil="Ca*",
    header_carb="Ca (carb)",
    strict=False,
):
    """
    Append three formula columns with conservative Ca corrections to a worksheet:
    'header_ap' (Ca minus 10/3 of P, floored at 0), 'header_sil' (the smaller of
    that value and Na) and 'header_carb' (the remainder, floored at 0).

    'ca', 'p' and 'na' are column headers, or column letters if 'getcols=False'.
    With 'getcols', 'strict' demands exact header matches.
    """
    # Get column letters.
    if getcols:
        ca, p, na = [
            col_from_string(ws, label, strict=strict) for label in (ca, p, na)
        ]

    def open_column(title):
        """Start a new rightmost column; return its letter and its data cells."""
        letter = get_column_letter(ws.max_column + 1)
        cells = ws[letter]
        cells[0].value = title
        return letter, cells[1:]

    # Remove apatite Ca.
    ap_col, ap_cells = open_column(header_ap)
    for cell in ap_cells:
        row = cell.row
        correction = f"{ca}{row} - 10/3*{p}{row}"
        cell.value = f"=IF(({correction}) > 0, {correction}, 0)"

    # Estimate silicate Ca: never more than Na.
    sil_col, sil_cells = open_column(header_sil)
    for cell in sil_cells:
        row = cell.row
        ap_ref, na_ref = f"{ap_col}{row}", f"{na}{row}"
        cell.value = f"=IF({ap_ref}>{na_ref}, {na_ref}, {ap_ref})"

    # Calculate carbonate Ca: what is left after the silicate share.
    _, carb_cells = open_column(header_carb)
    for cell in carb_cells:
        row = cell.row
        correction = f"{ap_col}{row}-{sil_col}{row}"
        cell.value = f"=IF(({correction}) > 0, {correction}, 0)"


def add_mg_correction(
    ws,
    getcols=True,
    mg="Mg (mol)",
    ca_carb="Ca (carb)",
    p="P (mol)",
    header="Mg*",
    strict=False,
    apatite=False,
):
    """
    Append a formula column 'header' with a conservative carbonate correction for
    Mg: Mg minus carbonate Ca, floored at 0. With 'apatite=True', 3/10 of P is
    subtracted as well.

    'mg', 'ca_carb' and 'p' are column headers, or column letters if
    'getcols=False'. With 'getcols', 'strict' demands exact header matches.
    """
    # Get column letters.
    if getcols:
        mg, ca_carb, p = [
            col_from_string(ws, label, strict=strict) for label in (mg, ca_carb, p)
        ]

    # Apatite and carbonate correction for Mg.
    target = get_column_letter(ws.max_column + 1)
    target_cells = ws[target]
    target_cells[0].value = header
    for cell in target_cells[1:]:
        row = cell.row
        if apatite:
            correction = f"{mg}{row} - 3/10*{p}{row} - {ca_carb}{row}"
        else:
            correction = f"{mg}{row}-{ca_carb}{row}"
        cell.value = f"=IF(({correction}) > 0, {correction}, 0)"


def add_cia(
    ws,
    which="both",
    getcols=True,
    al="Al (mol)",
    ca="Ca*",
    na="Na (mol)",
    k="K (mol)",
    header_cia="CIA",
    header_ciw="CIW",
    strict=False,
):
    """
    Append chemical index of alteration columns to a worksheet as formulas:
    'header_cia' is Al / (Al + Ca + Na + K) * 100 and 'header_ciw' is the same
    without K. 'which' selects the column to add by its header (case-insensitive),
    or 'both'. A cell stays empty when none of its inputs is a number.

    'al', 'ca', 'na' and 'k' are column headers, or column letters if
    'getcols=False'. With 'getcols', 'strict' demands exact header matches.
    """
    # Get column letters.
    if getcols:
        al, ca, na, k = [
            col_from_string(ws, label, strict=strict) for label in (al, ca, na, k)
        ]

    wanted = which.lower()
    variants = ((header_cia, [al, ca, na, k]), (header_ciw, [al, ca, na]))
    for title, elements in variants:
        if title.lower() != wanted and wanted != "both":
            continue

        col = get_column_letter(ws.max_column + 1)
        cells = ws[col]
        cells[0].value = title
        for cell in cells[1:]:
            row = cell.row
            test = make_isnumber(elements, row, chain="OR")
            refs = ", ".join(f"{e}{row}" for e in elements)
            cell.value = f'=IF({test}, ({al}{row} / SUM({refs}))*100, "")'


def add_cn_cnm(
    ws,
    which="both",
    getcols=True,
    ca="Ca*",
    na="Na (mol)",
    mg="Mg*",
    cn_header="CN",
    cnm_header="CNM",
    strict=False,
):
    """
    Append sum columns to a worksheet as formulas: 'cn_header' is Ca + Na and
    'cnm_header' is Ca + Na + Mg. 'which' is either cn_header, cnm_header
    (case-insensitive) or 'both'. The headers are written in upper case.

    'ca', 'na' and 'mg' are column headers, or column letters if 'getcols=False'.
    With 'getcols', 'strict' demands exact header matches.
    """
    # Get column letters.
    if getcols:
        ca, na, mg = [
            col_from_string(ws, label, strict=strict) for label in (ca, na, mg)
        ]

    wanted = which.lower()
    variants = ((cn_header, [ca, na]), (cnm_header, [ca, na, mg]))
    for title, elements in variants:
        if title.lower() != wanted and wanted != "both":
            continue

        col = get_column_letter(ws.max_column + 1)
        cells = ws[col]
        cells[0].value = title.upper()
        for cell in cells[1:]:
            refs = ", ".join(f"{e}{cell.row}" for e in elements)
            cell.value = f"=SUM({refs})"
        
        
def add_proto_ratio(
    ws,
    numerator,
    denominator,
    header=None,
    sectioncols=["Paleosol", "Section"],
    position="Relative position",
    protoname="proto",
    getcols=True,
    strict=False,
):
    """
    Append a column with the protolith ratio of two chemical columns (named
    'numerator' and 'denominator') in the 'ws' worksheet. Every row of a section
    gets the same formula: the (average) protolith numerator divided by the
    (average) protolith denominator, or an empty string if either is not a number.

    Protolith rows are looked for in the section itself, then in the whole paleosol;
    if none are found, the last row of the section is used instead.

    'sectioncols' is a list of column names that hold section info, first of which
    is the main divider (e.g., paleosols), others are smaller subdivisions (e.g.,
    sections). Unless 'getcols=False', in which case it is a list of column letters.
    If only the main divider is given, each paleosol is treated as one section.
    'position' is the header/letter of the column that tells if a sample is the
    protolith. If 'getcols', then 'strict' controls whether to search for partial or
    strict match (the first matching column is used). 'protoname' is the value that
    marks protolith samples in the 'position' column. 'header' is the column name of
    the output column. By default 'numerator'/'denominator' (protolith).
    """
    # Fix header name (built from the names as given, before they are turned into
    # column letters).
    if header is None:
        header = f"{numerator}/{denominator} (protolith)"

    # Get existing column letters.
    if getcols:
        position = col_from_string(ws, position, strict=strict, flatten=False)[0]
        numerator = col_from_string(ws, numerator, strict=strict, flatten=False)[0]
        denominator = col_from_string(ws, denominator, strict=strict, flatten=False)[0]
        sectioncols = [
            col_from_string(ws, col, strict=strict, flatten=False)[0] for col in sectioncols
        ]

    # Get new column letter.
    ratio_col = get_column_letter(ws.max_column + 1)
    ws[ratio_col][0].value = header

    # Protolith calculation is done per paleosol...
    sols = get_sections(
        ws, sectioncols=sectioncols[0:1], getcols=False, which="rows"
    )
    subsection_cols = sectioncols[1:]
    for sol in sols:
        sol_bounds = (sol[0], sol[-1])
        # ...and per section. Without subdivision columns the paleosol itself is
        # the only section; otherwise no row would ever be filled in.
        if subsection_cols:
            sections = get_sections(
                ws,
                sectioncols=subsection_cols,
                getcols=False,
                strict=strict,
                which="rows",
                boundrows=sol_bounds,
            )
        else:
            sections = [sol]

        for sec in sections:
            # Get protolith rows.
            protorows = searchrows(
                ws, col=position, getcols=False, search=protoname, rows=(sec[0], sec[-1])
            )

            # Use other sections of the same paleosol, if none in the same section.
            if not protorows:
                protorows = searchrows(
                    ws, col=position, getcols=False, search=protoname, rows=sol_bounds
                )

            # Just grab the last sample in the section, if still none found.
            if not protorows:
                protorows = [sec[-1]]

            # Build one expression per source column; "$" pins the row so the
            # reference survives copying the formula to other rows.
            expressions = {}
            for letter in (numerator, denominator):
                cells = [f"{letter}${row}" for row in protorows]
                if len(cells) > 1:
                    expressions[letter] = f"AVERAGE({', '.join(cells)})"
                else:
                    expressions[letter] = cells[0]

            # Construct output string for cell.
            ratio_str = f"{expressions[numerator]} / {expressions[denominator]}"
            conditional_str = (
                f"AND(ISNUMBER({expressions[numerator]}), "
                f"ISNUMBER({expressions[denominator]}))"
            )
            output_str = f'=IF({conditional_str}, {ratio_str}, "")'

            # Fill in column.
            for row in sec:
                ws[f"{ratio_col}{row}"].value = output_str

                
def add_k_correction(
    ws,
    sectioncols=["Paleosol", "Section"],
    position="Relative position",
    k="K (mol)",
    al="Al (mol)",
    getcols=True,
    strict=False,
    protoname="proto",
    header_proto="K/Al (protolith)",
    header_kcalc="K*",
):
    """
    Calculate pre-metasomatism K as per Ärps (based on the K/Al ratio of the
    protolith). Two columns are appended: 'header_proto' with the protolith K/Al
    ratio, and 'header_kcalc' with Al times that ratio, capped at the measured K.

    'sectioncols' is a list of column names that hold section info, first of which
    is the main divider (e.g., paleosols), others are smaller subdivisions (e.g.,
    sections). 'position' is the column that tells if a sample is the protolith and
    'protoname' the value that marks protolith samples in it. 'k' and 'al' are the
    K and Al molar value columns. All columns are given as headers, or as column
    letters if 'getcols=False'; with 'getcols', 'strict' controls whether to search
    for partial or strict match.
    """
    # Get existing column letters.
    if getcols:
        position = col_from_string(ws, position, strict=strict)
        k = col_from_string(ws, k, strict=strict)
        al = col_from_string(ws, al, strict=strict)
        sectioncols = [
            col_from_string(ws, name, strict=strict, flatten=False)[0]
            for name in sectioncols
        ]

    # Make a new column for all protolith values; it becomes the last column.
    add_proto_ratio(
        ws,
        numerator=k,
        denominator=al,
        header=header_proto,
        sectioncols=sectioncols,
        position=position,
        protoname=protoname,
        getcols=False,
    )

    # The ratio column is the current last column, K* goes right after it.
    last_col = ws.max_column
    kal_col = get_column_letter(last_col)
    kcalc_col = get_column_letter(last_col + 1)

    # Add K* column.
    kcalc_cells = ws[kcalc_col]
    kcalc_cells[0].value = header_kcalc
    for cell in kcalc_cells[1:]:
        row = cell.row
        measured = f"{k}{row}"
        # Al(mol) * K/Al (protolith).
        kcalc = f"{al}{row}*{kal_col}{row}"
        # Use Kcalc only if it is smaller than K (mol).
        cell.value = f"=IF(({kcalc}) < {measured}, {kcalc}, {measured})"
        
        
def add_tau(
    ws,
    mobile="Cr (ppm)",
    immobile='TiO2 (wt.%)',
    sectioncols=["Paleosol", "Section"],
    position="Relative position",
    getcols=True,
    strict=False,
    protoname="proto",
    header_proto="Cr/TiO2 (protolith)",
    header_tau="Tau Cr/Ti",
):
    """
    Append Tau values (change relative to the protolith) of a 'mobile' column
    compared to an 'immobile' column. Two columns are added: 'header_proto' with
    the protolith mobile/immobile ratio and 'header_tau' with
    (mobile/immobile) / protolith ratio - 1, left empty unless all three inputs are
    numbers.

    Columns are given as headers, or as column letters if 'getcols=False'; with
    'getcols', 'strict' demands exact header matches. 'sectioncols', 'position' and
    'protoname' are passed on to add_proto_ratio().
    """
    # Get existing column letters.
    if getcols:
        position = col_from_string(ws, position, strict=strict)
        mobile = col_from_string(ws, mobile, strict=strict)
        immobile = col_from_string(ws, immobile, strict=strict)
        sectioncols = [
            col_from_string(ws, name, strict=strict, flatten=False)[0]
            for name in sectioncols
        ]

    # Add a protolith ratio; it becomes the last column.
    add_proto_ratio(
        ws,
        numerator=mobile,
        denominator=immobile,
        header=header_proto,
        sectioncols=sectioncols,
        position=position,
        protoname=protoname,
        getcols=False,
    )

    # The ratio column is the current last column, tau goes right after it.
    last_col = ws.max_column
    proto_col = get_column_letter(last_col)
    tau_col = get_column_letter(last_col + 1)

    # Add tau column.
    tau_cells = ws[tau_col]
    tau_cells[0].value = header_tau
    for cell in tau_cells[1:]:
        row = cell.row
        refs = (f"{mobile}{row}", f"{immobile}{row}", f"{proto_col}{row}")
        # (mobile/immobile) / (mobile/immobile)protolith - 1
        tau = f"({refs[0]}/{refs[1]}) / {refs[2]} - 1"
        condition = "AND(" + ", ".join(f"ISNUMBER({ref})" for ref in refs) + ")"
        cell.value = f'=IF({condition}, {tau}, "")'
                        
                
def add_kaolinite_estimate(
    ws,
    al="Al (mol)",
    k="K*",
    cnm="CNM",
    getcols=True,
    strict=False,
    header="A/(CNM+A+K*)",
):
    """
    Append a formula column 'header' with a relative kaolinite estimate (vs other
    alumosilicate phases): Al / (CNM + Al + K). 'al', 'k' and 'cnm' are column
    headers for source data, or column letters if getcols=False. With 'getcols',
    'strict' controls whether the match is exact or partial.
    """
    if getcols:
        al, k, cnm = [
            col_from_string(ws, label, strict=strict) for label in (al, k, cnm)
        ]

    kaol_col = get_column_letter(ws.max_column + 1)
    kaol_cells = ws[kaol_col]
    kaol_cells[0].value = header
    for cell in kaol_cells[1:]:
        row = cell.row
        al_ref = f"{al}{row}"
        cell.value = f"={al_ref} / SUM({cnm}{row}, {al_ref}, {k}{row})"

In [45]:
# Add in all corrections and calculations.
# Order matters: every helper appends its columns to the right of the sheet, and
# later helpers look up headers written by earlier ones (e.g. add_mg_correction
# needs "Ca (carb)" from add_ca_corrections, add_cia needs "Ca*").
add_molar(ws, compcols, ox_convert_df)
add_ca_corrections(ws)
add_mg_correction(ws)
add_k_correction(ws)
#add_proto_ratio(ws, 'Cr (ppm)', 'TiO2 (wt.%)')
add_cn_cnm(ws)
add_cia(ws)
# Second CIA variant that uses the corrected K* instead of measured K.
add_cia(ws, which="CIA (K*)", k="K*", header_cia="CIA (K*)")
# strict=True: by now a partial match of "K*" would also hit "CIA (K*)", and "CNM"
# would also hit "CN".
add_kaolinite_estimate(ws, strict=True)
add_tau(ws)
add_tau(ws, immobile="Al2O3 (wt.%)", header_proto='Cr/Al2O3 (protolith)', header_tau='Tau Cr/Al')
add_tau(ws, mobile='P2O5 (wt.%)',header_proto="P2O5/TiO2 (protolith)", header_tau="Tau P/Ti")

In [46]:
# Functions to style and format the table.


def as_text(val):
    """
    Return str(val), or an empty string when val is None.
    """
    return "" if val is None else str(val)


# TODO: the column selection needs to be general-cased somehow. See how to connect it to
# the col_from_string() f-n.
def size_columns(
    ws, indexcols, long_cutoff=12, short_cutoff=7, middle_factor=1.1, long_factor=0.9
):
    """
    Set optimal column widths for worksheet. Columns whose header contains one of
    the strings in 'indexcols' (case-insensitive) are sized by their longest cell;
    all other columns only take the header into account. Also, uglily hacks lengths
    to account for variable-width fonts: columns shorter than short_cutoff will be
    set to width short_cutoff, columns longer than long_cutoff will be multiplied by
    long_factor. Columns in between will be multiplied by middle_factor.
    """
    # TODO: base on col_from_string()
    # Find column numbers whose header contains some element of indexcols. Going
    # through as_text() keeps empty or non-text header cells from raising.
    wanted = [ihead.lower() for ihead in indexcols]
    inums = {
        cell.column
        for cell in ws[1]
        if any(ihead in as_text(cell.value).lower() for ihead in wanted)
    }

    for col in ws.columns:
        # Find longest cell length for index columns, header cell length for numerical
        # columns.
        if col[0].column in inums:
            length = max(len(as_text(cell.value)) for cell in col)
        else:
            length = len(as_text(col[0].value))

        # Ugly hack to make the col widths fit better (var-width fonts are a pain for
        # this, but look pretty)
        if length > long_cutoff:
            length *= long_factor
        elif length >= short_cutoff:
            length *= middle_factor
        else:
            length = short_cutoff

        col_letter = get_column_letter(col[0].column)
        ws.column_dimensions[col_letter].width = length


def color_alternate_sections(ws, color="D9D9D9", sectioncols=["Paleosol", "Section"]):
    """
    Fill every other section (starting with the first) with the solid 'color'.
    Sections are the row groups reported by get_sections() for the 'sectioncols'
    headers.
    """
    sections = get_sections(ws, sectioncols, which="rows")

    shade = openpyxl.styles.PatternFill(
        fill_type="solid", start_color=color, end_color=color
    )
    # Sections 0, 2, 4, ... get the fill; the ones in between stay as they are.
    for section in sections[::2]:
        for row in section:
            for cell in ws[row]:
                cell.fill = shade


# TODO: the column selection needs to be general-cased somehow. See how to connect it to
# the col_from_string() f-n.
def format_values(ws, column_strings, cell_format="0.00", opposite=False):
    """
    Format columns that contain the given strings in their header
    (case-insensitive). If 'opposite=True', format the columns that DON'T contain
    these strings. The header row itself is never touched.

    'cell_format' can either be a string that translates to a valid numeric format,
    which is set as the number format of every cell, or 'text', which converts the
    cell values to strings. Missing values (None, NaN, pd.NA) are left untouched
    when converting to text, so that empty cells stay empty.
    """
    # TODO: base on col_from_string().
    # Get column numbers that contain strings in column_strings. Empty or non-text
    # header cells cannot contain any of them.
    needles = [string.lower() for string in column_strings]
    header_cells = list(ws[1])
    selected_cols = {
        cell.column
        for cell in header_cells
        if isinstance(cell.value, str)
        and any(needle in cell.value.lower() for needle in needles)
    }
    if opposite:
        selected_cols = {cell.column for cell in header_cells} - selected_cols

    for col in ws.columns:
        if col[0].column not in selected_cols:
            continue
        for cell in col[1:]:
            if cell_format == "text":
                # pd.isna() recognises every missing-value marker; an identity
                # test against np.nan misses None and NaN objects other than the
                # np.nan singleton, which would be written out as 'None'/'nan'.
                if not pd.isna(cell.value):
                    cell.value = str(cell.value)
            else:
                cell.number_format = cell_format

In [47]:
# Style header row.
ft = openpyxl.styles.Font(name="Calibri", bold=True)
for cell in ws[1]:
    cell.font = ft
# Freezing at A2 keeps the header row visible while scrolling.
ws.freeze_panes = ws["A2"]

# Color by section.
color_alternate_sections(ws)

# Set optimal column widths.
# Columns matching the metadata names are sized by content, the rest by header.
size_columns(ws, metacols)

# Change columns into proper format.
# Depth and relative kaolinite are metadata but numeric, so they are taken off the
# text list and get a number format like the chemistry columns.
nonnumcols = metacols.copy()
nonnumcols.remove("Depth")
nonnumcols.remove("Relative kaolinite")
format_values(ws, nonnumcols, cell_format="text")
format_values(ws, nonnumcols, opposite=True)
# Later calls override the default "0.00" set above for the columns they match.
format_values(ws, ["mol", "sil", "carb", "*", "CN", "CNM"], cell_format="0.0000")
format_values(ws, ["CIA (K*)", "A/(CNM+A+K*)"])

ws.title = "Paleosol compilation"

# The output file name carries today's date (ISO format).
wb.save(f"paleosol_compilation_cnm_{datetime.date.today()}.xlsx")