This notebook can be used to visualize the performance of the fine-tuned models.

In [None]:
import os
import glob
import re
import pandas as pd
import numpy as np

In [None]:
# Point to the task outputs you would like to visualize.
# The aggregation cell below searches this directory recursively for *.txt log
# files and also writes aggregated_results.csv into it.
results_path = "../outputs/Phenotype"

In [3]:
def parse_filename(filename):
    """
    Extract hyperparameters from the filename.
    Expected format: lr=<lr>_wd=<wd>_epochs=<epochs>_seed=<seed>_effective_batch_size=<effective_bs>.txt

    Args:
        filename: File name to search for the hyperparameter pattern.

    Returns:
        A dict with the keys "lr" and "wd" (floats) and "epochs", "seed" and
        "effective_batch_size" (ints), or None when the pattern is not found
        or one of the captured fields is not a valid number.
    """
    pattern = r"lr=([^_]+)_wd=([^_]+)_epochs=([^_]+)_seed=([^_]+)_effective_batch_size=([^\.]+)"
    match = re.search(pattern, filename)
    if match is None:
        return None

    lr, wd, epochs, seed, effective_bs = match.groups()
    try:
        return {
            "lr": float(lr),
            "wd": float(wd),
            "epochs": int(epochs),
            "seed": int(seed),
            "effective_batch_size": int(effective_bs),
        }
    except ValueError:
        # The name has the right shape but a field is not numeric
        # (e.g. "lr=abc_..."). Report it as a non-match so the caller skips
        # this one file instead of aborting the whole run.
        return None

def parse_log_file(filepath):
    """
    Parse the log file content, which is assumed to contain lines of the form 'Key: Value'.

    Lines that are empty or contain no colon are skipped, as are lines whose
    (case-insensitive) key is in the ignore list, so that the DataFrame doesn't
    pick up headers or unwanted entries.

    Args:
        filepath: Path of the log file to read.

    Returns:
        A dict mapping each kept key to its value. Values are converted to int
        when possible, otherwise to float (this covers decimals as well as
        scientific notation such as "1e-05", and "nan"/"inf"), otherwise kept
        as the stripped string. An empty value becomes np.nan. If a key occurs
        more than once, the last occurrence wins.
    """
    ignore_keys = {"test evaluation results", "epoch", "seed", "effective batch size", "eval_samples_per_second", "eval_steps_per_second", "eval_loss"}
    data = {}
    with open(filepath, 'r') as f:
        for line in f:
            line = line.strip()
            # Skip lines that are empty or do not contain a colon.
            if not line or ":" not in line:
                continue
            # Split on the first colon only, so values may contain colons.
            key, value = line.split(":", 1)
            key = key.strip()
            # If the lowercase key is in the ignore list, skip it.
            if key.lower() in ignore_keys:
                continue
            value = value.strip()
            # If the value is an empty string, set it to np.nan.
            if value == "":
                data[key] = np.nan
                continue
            # Try int first, then float. Checking for a "." to pick the
            # converter would leave values like "1e-05" as strings.
            for convert in (int, float):
                try:
                    data[key] = convert(value)
                    break
                except ValueError:
                    continue
            else:
                data[key] = value  # keep as string if conversion fails
    return data

In [None]:
records = []

# Recursively find all .txt files in the directory. glob returns paths in
# filesystem order, so sort them to keep the row order reproducible.
filepaths = sorted(glob.glob(os.path.join(results_path, "**", "*.txt"), recursive=True))
if not filepaths:
    print(f"No log files found in {results_path}")

for filepath in filepaths:
    filename = os.path.basename(filepath)
    # Assume that the model name is the immediate parent folder.
    model = os.path.basename(os.path.dirname(filepath))

    params = parse_filename(filename)
    if params is None:
        print(f"Filename {filename} does not match the expected pattern. Skipping file.")
        continue

    log_data = parse_log_file(filepath)

    # Combine the data from the filename, file content, and model name.
    # Keys parsed from the file content override same-named keys set here.
    record = {
        "Model": model,
        **params,  # Contains keys: lr, wd, epochs, seed, effective_batch_size
    }
    record.update(log_data)
    records.append(record)

# Without any parsed record the DataFrame has no columns and the groupby below
# would fail with an opaque KeyError, so stop here with a clear message.
if not records:
    raise ValueError(f"No parsable log files found in {results_path}; nothing to aggregate.")

# Create a DataFrame from the records.
df = pd.DataFrame(records)
print("Individual results:")
print(df.head(), "\n")

# Group on Model, lr, wd, effective_batch_size, and epochs; the seed is left
# out of the grouping key and is not treated as a metric either.
group_cols = ["Model", "lr", "wd", "effective_batch_size", "epochs"]
ignore_cols = set(group_cols + ["seed"])
metric_cols = [col for col in df.columns if col not in ignore_cols]

# Convert each metric column to numeric (coercing errors to NaN).
for col in metric_cols:
    df[col] = pd.to_numeric(df[col], errors="coerce")

# Define aggregation: for each metric, compute mean, median, min, and max.
agg_funcs = {col: ["mean", "median", "min", "max"] for col in metric_cols}
grouped = df.groupby(group_cols).agg(agg_funcs)

# Flatten the multi-level column index into "<metric>_<stat>" names.
grouped.columns = ["_".join(col).strip() for col in grouped.columns.values]
grouped = grouped.reset_index()

# Scale every column whose name contains "f1" by 100 and round to one decimal
# (presumably the logs store F1 as a fraction in [0, 1]; confirm).
for col in grouped.columns:
    if "f1" in col.lower():
        grouped[col] = (grouped[col] * 100).round(1)


print("Aggregated results over seeds:")
print(grouped.head())

# Optionally, save the aggregated DataFrame to CSV.
output_csv = os.path.join(results_path, "aggregated_results.csv")
grouped.to_csv(output_csv, index=False)
print(f"\nAggregated results saved to {output_csv}")

In [None]:
# Show the aggregated groups ranked by mean accuracy, best first.
# NOTE: the 'accuracy_mean' column exists only if the parsed logs contained an
# 'accuracy' entry; otherwise this raises a KeyError.
grouped.sort_values(by='accuracy_mean', ascending=False)