In [1]:
%cd ~/protein-transfer

/home/francesca/protein-transfer


In [2]:
%load_ext blackcellmagic

In [3]:
from scr.utils import pickle_load

In [4]:
"""Analyzing per layer output"""

from __future__ import annotations

from collections import defaultdict

import os
from glob import glob
import numpy as np

import matplotlib.pyplot as plt
from matplotlib.ticker import FormatStrFormatter

from scr.encoding.encoding_classes import get_emb_info
from scr.params.emb import TRANSFORMER_INFO, CARP_INFO
from scr.params.vis import CHECKPOINT_COLOR
from scr.utils import pickle_load, get_filename, checkNgen_folder


class LayerLoss:
    """A class for handling layer analysis"""

    def __init__(
        self,
        add_checkpoint: bool = True,
        checkpoint_list: list = [0.5, 0.25, 0.125],
        input_path: str = "results/sklearn",
        output_path: str = "results/sklearn_layer",
        metric_dict: dict[list[str]] = {
            "proeng": ["train_mse", "val_mse", "test_mse", "test_ndcg", "test_rho"],
            "annotation": [
                "train_cross-entropy",
                "val_cross-entropy",
                "test_cross-entropy",
                "test_acc",
                "test_rocauc",
            ],
            "structure": [
                "train_cross-entropy",
                "val_cross-entropy",
                "casp12_acc",
                "casp12_rocauc",
                "cb513_acc",
                "cb513_rocauc",
                "ts115_acc",
                "ts115_rocauc",
            ],
        },
    ):
        """
        Args:
        - add_checkpoint: bool = True, if add checkpoint for carp
        - checkpoint_list: list = [0.5, 0.25, 0.125],
        - input_path: str = "results/sklearn",
        - output_path: str = "results/sklearn_layer"
        - metric_dict: list[str] = ["train_mse", "test_ndcg", "test_rho"]
        """
        self._add_checkpoint = add_checkpoint
        self._checkpoint_list = checkpoint_list
        # get rid of the last "/" if any
        self._input_path = os.path.normpath(input_path)
        # get the list of subfolders for each dataset
        self._dataset_folders = glob(f"{self._input_path}/*/*/*/*/*")
        # glob("results/train_val_test/*/*/*/*/*")

        # get rid of the last "/" if any
        self._output_path = os.path.normpath(output_path)
        self._metric_dict = metric_dict

        # init a dictionary for recording outputs
        self._onehot_baseline_dict = defaultdict(dict)
        self._layer_analysis_dict = defaultdict(dict)
        self._rand_layer_analysis_dict = defaultdict(dict)
        self._stat_layer_analysis_dict = defaultdict(dict)

        # init a dict for metric params
        self._metric_numb = defaultdict(dict)

        # init
        self._checkpoint_analysis_dict = defaultdict(dict)
        if self._add_checkpoint:
            for checkpoint in self._checkpoint_list:
                self._checkpoint_analysis_dict[checkpoint] = defaultdict(dict)

        for dataset_folder in self._dataset_folders:
            # dataset_folder = "results/train_val_test/proeng/gb1/two_vs_rest/esm1b_t33_650M_UR50S/max"
            # get the details for the dataset such as proeng/gb1/two_vs_rest
            task_subfolder = dataset_folder.split(self._input_path + "/")[-1]
            # task_subfolder = "proeng/gb1/two_vs_rest/esm1b_t33_650M_UR50S/max"
            task, dataset, split, encoder_name, flatten_emb = task_subfolder.split("/")
            # get collage_name
            collage_name = f"{task}_{dataset}_{split}_{flatten_emb}"

            # get number of metircs
            self._metric_numb[collage_name] = len(self._metric_dict[task])

            # parse results for plotting the collage and onehot
            self._layer_analysis_dict[collage_name][
                encoder_name
            ] = self.parse_result_dicts(
                dataset_folder, task, dataset, split, encoder_name, flatten_emb
            )

            # init
            # check if check points exists
            if self._add_checkpoint:
                for checkpoint in self._checkpoint_list:
                    checkpoint_path = f"{self._input_path}-{str(checkpoint)}"

                    if os.path.exists(checkpoint_path):
                        self._checkpoint_analysis_dict[checkpoint][
                            f"{task}_{dataset}_{split}_{flatten_emb}"
                        ][encoder_name] = self.parse_result_dicts(
                            dataset_folder.replace(self._input_path, checkpoint_path),
                            task,
                            dataset,
                            split,
                            encoder_name,
                            flatten_emb,
                        )

            # check if reset param experimental results exist
            reset_param_path = f"{self._input_path}-rand"

            if os.path.exists(reset_param_path):
                self._rand_layer_analysis_dict[
                    f"{task}_{dataset}_{split}_{flatten_emb}"
                ][encoder_name] = self.parse_result_dicts(
                    dataset_folder.replace(self._input_path, reset_param_path),
                    task,
                    dataset,
                    split,
                    encoder_name,
                    flatten_emb,
                )
                add_rand = True
            else:
                add_rand = False

            # check if resample param experimental results exist
            resample_param_path = f"{self._input_path}-stat"

            if os.path.exists(resample_param_path):
                self._stat_layer_analysis_dict[
                    f"{task}_{dataset}_{split}_{flatten_emb}"
                ][encoder_name] = self.parse_result_dicts(
                    dataset_folder.replace(self._input_path, resample_param_path),
                    task,
                    dataset,
                    split,
                    encoder_name,
                    flatten_emb,
                )
                add_stat = True
            else:
                add_stat = False

            # check if onehot experimental results exist
            onehot_path = f"{self._input_path}-onehot"

            if os.path.exists(onehot_path):
                if task == "structure":
                    onehot_flatten_emb_name = "noflatten"
                else:
                    onehot_flatten_emb_name = "flatten"
                self._onehot_baseline_dict[
                    f"{task}_{dataset}_{split}"
                ] = self.parse_result_dicts(
                    dataset_folder.replace(self._input_path, onehot_path)
                    .replace(encoder_name, "onehot")
                    .replace(flatten_emb, onehot_flatten_emb_name),
                    task,
                    dataset,
                    split,
                    "onehot",
                    onehot_flatten_emb_name,
                )
                add_onehot = True
            else:
                add_onehot = False

        # combine different model into one big plot with different encoders
        collage_folder = os.path.join(self._output_path, "collage")
        checkNgen_folder(collage_folder)

        for collage_name, encoder_dict in self._layer_analysis_dict.items():

            print(f"Plotting collage_name {collage_name}...")

            onehot_name = "_".join(collage_name.split("_")[:-1])

            if set(list(TRANSFORMER_INFO.keys())) == set(encoder_dict.keys()):
                # set the key rankings to default
                encoder_names = list(TRANSFORMER_INFO.keys())
                encoder_label = "esm"
            elif set(list(CARP_INFO.keys())) == set(encoder_dict.keys()):
                # set the key rankings to default
                encoder_names = list(CARP_INFO.keys())
                encoder_label = "carp"
            else:
                encoder_names = list(set(encoder_dict.keys()))
                encoder_label = "pretrained"

            fig, axs = plt.subplots(
                self._metric_numb[collage_name],
                len(encoder_names),
                sharey="row",
                sharex="col",
                figsize=(20, 2 * self._metric_numb[collage_name]),
                squeeze=False # not get rid off the extra dim if 1D
            )

            for m, metric in enumerate(self._metric_dict[collage_name.split("_")[0]]):

                for n, encoder_name in enumerate(encoder_names):
                    axs[m, n].plot(
                        encoder_dict[encoder_name][metric],
                        label=encoder_label,
                        color="#f79646ff",  # orange
                    )

                    # add checkpoints
                    if self._add_checkpoint:
                        for checkpoint in self._checkpoint_list:

                            checkpoint_vals = self._checkpoint_analysis_dict[
                                checkpoint
                            ][collage_name][encoder_name][metric]

                            if not np.all(checkpoint_vals == 0):
                                axs[m, n].plot(
                                    checkpoint_vals,
                                    label=f"{encoder_label}-{checkpoint}",
                                    color=CHECKPOINT_COLOR[
                                        checkpoint
                                    ],  # darker oranges
                                    linestyle="dashed",
                                )

                    # overlay random init
                    if add_rand:
                        axs[m, n].plot(
                            self._rand_layer_analysis_dict[collage_name][encoder_name][
                                metric
                            ],
                            label="random init",
                            color="#4bacc6",  # blue
                            linestyle="dashed"
                            # color="#D3D3D3",  # light grey
                        )

                    # overlay stat init
                    if add_stat:
                        axs[m, n].plot(
                            self._stat_layer_analysis_dict[collage_name][encoder_name][
                                metric
                            ],
                            label="stat transfer",
                            color="#9bbb59",  # green
                            linestyle="dashed"
                            # color="#A9A9A9",  # dark grey
                            # linestyle="dotted",
                        )

                    # overlay onehot baseline
                    if add_onehot:
                        axs[m, n].axhline(
                            self._onehot_baseline_dict[onehot_name][metric],
                            label="onehot",
                            color="#000000",  # black or #D3D3D3 light grey
                            linestyle="dotted",
                        )

            # add xlabels
            for ax in axs[self._metric_numb[collage_name] - 1]:
                ax.set_xlabel("layers", fontsize=16)
                ax.tick_params(axis="x", labelsize=16)

            # add column names
            for ax, col in zip(axs[0], encoder_names):
                ax.set_title(col, fontsize=16)

            # add row names
            for ax, row in zip(
                axs[:, 0], self._metric_dict[collage_name.split("_")[0]]
            ):
                ax.set_ylabel(
                    row.replace("_", " ").replace("cross-entropy", "ce"), fontsize=16
                )
                ax.tick_params(
                    axis="y",
                    which="major",
                    reset=True,
                    labelsize=16,
                    left=True,
                    right=False,  # no right side tick on the plot
                    labelleft=True,
                    labelright=False,
                )
                ax.relim()  # make sure all the data fits
                ax.autoscale()

            # set the plot yticks
            plt.gca().yaxis.set_major_formatter(FormatStrFormatter("%.2f"))
            plt.gca().yaxis.tick_left()
            plt.gca().autoscale()

            # add legend
            handles, labels = axs[0, 0].get_legend_handles_labels()

            if len(labels) == 7:

                # Add two empty dummy legend items
                # using the first label info

                axs[0, 0].axhline(
                    self._onehot_baseline_dict[onehot_name][
                        self._metric_dict[collage_name.split("_")[0]][0]
                    ],
                    label=" ",
                    color="w",
                    alpha=0,
                )

                axs[0, 0].axhline(
                    self._onehot_baseline_dict[onehot_name][
                        self._metric_dict[collage_name.split("_")[0]][0]
                    ],
                    label=" ",
                    color="w",
                    alpha=0,
                )

                adjusted_handles, adjusted_labels = axs[
                    0, 0
                ].get_legend_handles_labels()
                adjusted_y = 1.045
                ncol = 3
                legend_params = {
                    "labelspacing": 0.1,  # vertical space between the legend entries, default 0.5
                    "handletextpad": 0.2,  # space between the legend the text, default 0.8
                    "handlelength": 0.95,  # length of the legend handles, default 2.0
                    "columnspacing": 1,  # spacing between columns, default 2.0
                }

            else:
                adjusted_handles, adjusted_labels = handles, labels
                adjusted_y = 1.025
                ncol = 2
                legend_params = {}

            fig.legend(
                adjusted_handles,
                adjusted_labels,
                loc="upper left",
                bbox_to_anchor=[0.05, adjusted_y],
                fontsize=16,
                frameon=False,
                ncol=ncol,
                **legend_params,
            )

            # add whole plot level title
            fig.suptitle(
                collage_name.replace("_", " ").replace("cross-entropy", "ce"),
                y=1.0025,
                fontsize=24,
                fontweight="bold",
            )
            fig.align_labels()
            fig.tight_layout()

            for plot_ext in [".svg", ".png"]:
                plt.savefig(
                    os.path.join(collage_folder, collage_name + plot_ext),
                    bbox_inches="tight",
                )

            plt.close()

    def parse_result_dicts(
        self,
        folder_path: str,
        task: str,
        dataset: str,
        split: str,
        encoder_name: str,
        flatten_emb: bool | str,
    ):
        """
        Parse the output result dictionaries for plotting

        Args:
        - folder_path: str, the folder path for the datasets

        Returns:
        - dict, encode name as key with a dict as its value
            where metric name as keys and the array of losses as values
        - str, details for collage plot
        """

        # get the list of output pickle files
        pkl_list = glob(f"{folder_path}/*.pkl")

        _, _, max_layer_numb = get_emb_info(encoder_name)

        # init the ouput dict
        output_numb_dict = {
            metric: np.zeros([max_layer_numb]) for metric in self._metric_dict[task]
        }

        # loop through the list of the pickle files
        for pkl_file in pkl_list:
            # get the layer number
            layer_numb = int(get_filename(pkl_file).split("-")[-1].split("_")[-1])
            # load the result dictionary
            try:
                result_dict = pickle_load(pkl_file)
            except Exception as e:
                print(f"{pkl_file} with err: ", e)

            # populate the processed dictionary
            for metric in self._metric_dict[task]:
                subset, kind = metric.split("_")
                if kind == "rho":
                    output_numb_dict[metric][layer_numb] = result_dict[subset][kind][0]
                else:
                    output_numb_dict[metric][layer_numb] = result_dict[subset][kind]

        # get some details for plotting and saving
        output_subfolder = checkNgen_folder(
            folder_path.replace(self._input_path, self._output_path)
        )

        for metric in output_numb_dict.keys():

            plot_name = f"{encoder_name}_{flatten_emb}_{metric}"
            plot_prefix = f"{task}_{dataset}_{split}"

            plt.figure()
            plt.plot(output_numb_dict[metric])
            plt.title(f"{plot_prefix} \n {plot_name}")
            plt.xlabel("layers")
            plt.ylabel("loss")

            for plot_ext in [".svg", ".png"]:
                plt.savefig(
                    os.path.join(output_subfolder, plot_name + plot_ext),
                    bbox_inches="tight",
                )
            plt.close()

        return output_numb_dict

    @property
    def layer_analysis_dict(self) -> dict:
        """Return a dict with dataset name as the key"""
        return self._layer_analysis_dict

    @property
    def rand_layer_analysis_dict(self) -> dict:
        """Return a dict with dataset name as the key for rand"""
        return self._rand_layer_analysis_dict

    @property
    def stat_layer_analysis_dict(self) -> dict:
        """Return a dict with dataset name as the key for stat"""
        return self._stat_layer_analysis_dict

    @property
    def onehot_baseline_dict(self) -> dict:
        """Return a dict with dataset name as the key for onehot"""
        return self._onehot_baseline_dict

    @property
    def checkpoint_analysis_dict(self) -> dict:
        """Return a dict with dataset name as the key for checkpoints"""
        return self._checkpoint_analysis_dict

In [5]:
pytorch_carp = LayerLoss(
    input_path="results/pytorch-carp",
    output_path="results/pytorch-carp_layer"
)

Plotting collage_name annotation_scl_balanced_mean...
Plotting collage_name structure_ss3_tape_processed_noflatten...


In [6]:
import pandas as pd

In [40]:
pytorch_carp.onehot_baseline_dict

defaultdict(dict,
            {'annotation_scl_balanced': {'train_cross-entropy': array([0.35247527]),
              'val_cross-entropy': array([1.37385221]),
              'test_cross-entropy': array([2.01599532]),
              'test_acc': array([0.37402597]),
              'test_rocauc': array([0.66190611])},
             'structure_ss3_tape_processed': {'train_cross-entropy': array([1.00965377]),
              'val_cross-entropy': array([1.00609217]),
              'casp12_acc': array([0.48194598]),
              'casp12_rocauc': array([0.62021253]),
              'cb513_acc': array([0.48816847]),
              'cb513_rocauc': array([0.64051902]),
              'ts115_acc': array([0.50855104]),
              'ts115_rocauc': array([0.64197196])}})

In [7]:
pd.json_normalize(pytorch_carp.layer_analysis_dict)

Unnamed: 0,annotation_scl_balanced_mean.carp_600k.train_cross-entropy,annotation_scl_balanced_mean.carp_600k.val_cross-entropy,annotation_scl_balanced_mean.carp_600k.test_cross-entropy,annotation_scl_balanced_mean.carp_600k.test_acc,annotation_scl_balanced_mean.carp_600k.test_rocauc,annotation_scl_balanced_mean.carp_76M.train_cross-entropy,annotation_scl_balanced_mean.carp_76M.val_cross-entropy,annotation_scl_balanced_mean.carp_76M.test_cross-entropy,annotation_scl_balanced_mean.carp_76M.test_acc,annotation_scl_balanced_mean.carp_76M.test_rocauc,...,structure_ss3_tape_processed_noflatten.carp_38M.ts115_acc,structure_ss3_tape_processed_noflatten.carp_38M.ts115_rocauc,structure_ss3_tape_processed_noflatten.carp_600k.train_cross-entropy,structure_ss3_tape_processed_noflatten.carp_600k.val_cross-entropy,structure_ss3_tape_processed_noflatten.carp_600k.casp12_acc,structure_ss3_tape_processed_noflatten.carp_600k.casp12_rocauc,structure_ss3_tape_processed_noflatten.carp_600k.cb513_acc,structure_ss3_tape_processed_noflatten.carp_600k.cb513_rocauc,structure_ss3_tape_processed_noflatten.carp_600k.ts115_acc,structure_ss3_tape_processed_noflatten.carp_600k.ts115_rocauc
0,"[1.8769907888613249, 1.7527372648841457, 1.695...","[1.9061075278690882, 1.7826685564858573, 1.727...","[1.9521732330322266, 1.9269767999649048, 1.910...","[0.2623376623376623, 0.2753246753246753, 0.316...","[0.6495972923926321, 0.6703754074198381, 0.673...","[1.71919468202089, 1.4959022904697217, 1.44056...","[1.7575252737317766, 1.534557512828282, 1.4600...","[2.0076744556427, 2.0575806498527527, 1.995046...","[0.3038961038961039, 0.35324675324675325, 0.35...","[0.6315851180553764, 0.7013079524965977, 0.710...",...,"[0.46471855642337734, 0.6232830595206033, 0.69...","[0.5852224245003864, 0.7806285091979067, 0.844...","[1.0358617109795139, 0.89385723659437, 0.81611...","[1.0326182277579057, 0.8892227097561485, 0.810...","[0.45355567805953695, 0.5525082690187431, 0.57...","[0.5753563391975582, 0.6990188518563044, 0.746...","[0.44878082571349404, 0.5898309781102798, 0.64...","[0.593458300550087, 0.7515188960364423, 0.8002...","[0.4690614058712631, 0.5999192028009696, 0.655...","[0.5998865729309671, 0.7559258255636626, 0.809..."


In [17]:
pytorch_carp.layer_analysis_dict["annotation_scl_balanced_mean"]["carp_600k"]["train_cross-entropy"]

array([1.87699079, 1.75273726, 1.69514218, 1.60551146, 1.53652087,
       1.46996107, 1.40448275, 1.33026753, 1.30898465, 1.26282578,
       1.26654993, 1.21502749, 1.19079669, 1.183124  , 1.19900847,
       1.17011334, 1.18096242])

In [41]:
pd.DataFrame.from_dict(pytorch_carp.onehot_baseline_dict)

Unnamed: 0,annotation_scl_balanced,structure_ss3_tape_processed
train_cross-entropy,[0.3524752729817441],[1.0096537658613023]
val_cross-entropy,[1.373852210385459],[1.0060921719199734]
test_cross-entropy,[2.0159953236579895],
test_acc,[0.37402597402597404],
test_rocauc,[0.6619061060529179],
casp12_acc,,[0.48194597574421166]
casp12_rocauc,,[0.6202125329637339]
cb513_acc,,[0.48816846771958994]
cb513_rocauc,,[0.6405190154371231]
ts115_acc,,[0.5085510368973876]


In [47]:
pd.DataFrame.from_dict(sklearn_carp.onehot_baseline_dict)

Unnamed: 0,proeng_aav_one_vs_many,proeng_aav_two_vs_many,proeng_gb1_low_vs_high,proeng_gb1_two_vs_rest,proeng_gb1_sampled,proeng_thermo_mixed_split
train_mse,[6.606138986229867],[16.67168599756062],[0.1770679055989013],[2.090415559759884],[2.0079257240201596],[2760.1282354723244]
val_mse,[8.361448368900476],[16.951583715058266],[0.17003734650743707],[2.0447116896087842],[1.9726564028720246],[2778.184268114801]
test_mse,[66.71775739461496],[24.814891839771693],[4.028476326030474],[5.335454660039244],[2.025209681370326],[2838.811461721218]
test_ndcg,[0.960145115831988],[0.9629290404042454],[0.9178409557926928],[0.8905737288756033],[0.9342600571731149],[0.9554623184824335]
test_rho,[0.1903477426748507],[-0.0015626678198874598],[0.32173083298452543],[0.5428396414906184],[0.7885047525700145],[0.1226657473529901]


In [8]:
pd.DataFrame.from_dict(pytorch_carp.layer_analysis_dict)

Unnamed: 0,annotation_scl_balanced_mean,structure_ss3_tape_processed_noflatten
carp_600k,"{'train_cross-entropy': [1.8769907888613249, 1...","{'train_cross-entropy': [1.0358617109795139, 0..."
carp_76M,"{'train_cross-entropy': [1.71919468202089, 1.4...","{'train_cross-entropy': [1.0396245995612994, 0..."
carp_38M,"{'train_cross-entropy': [1.7385239789360447, 1...","{'train_cross-entropy': [1.0429124097301536, 0..."
carp_640M,"{'train_cross-entropy': [1.7193063999477185, 1...","{'train_cross-entropy': [1.0420560608171436, 0..."


In [20]:
pd.json_normalize(pd.DataFrame.from_dict(pytorch_carp.layer_analysis_dict)['annotation_scl_balanced_mean'])

Unnamed: 0,train_cross-entropy,val_cross-entropy,test_cross-entropy,test_acc,test_rocauc
0,"[1.8769907888613249, 1.7527372648841457, 1.695...","[1.9061075278690882, 1.7826685564858573, 1.727...","[1.9521732330322266, 1.9269767999649048, 1.910...","[0.2623376623376623, 0.2753246753246753, 0.316...","[0.6495972923926321, 0.6703754074198381, 0.673..."
1,"[1.71919468202089, 1.4959022904697217, 1.44056...","[1.7575252737317766, 1.534557512828282, 1.4600...","[2.0076744556427, 2.0575806498527527, 1.995046...","[0.3038961038961039, 0.35324675324675325, 0.35...","[0.6315851180553764, 0.7013079524965977, 0.710..."
2,"[1.7385239789360447, 1.4665500490288985, 1.395...","[1.7737009525299072, 1.5041547673089164, 1.417...","[2.035592496395111, 2.0621028542518616, 2.0096...","[0.2857142857142857, 0.36883116883116884, 0.37...","[0.6189812048918136, 0.7025105662414319, 0.713..."
3,"[1.7193063999477185, 1.5111642197558754, 1.437...","[1.7644925117492676, 1.5379446830068315, 1.462...","[2.0392618775367737, 2.05305552482605, 2.01585...","[0.2961038961038961, 0.34805194805194806, 0.36...","[0.6273352139525914, 0.7005424935091661, 0.706..."


In [None]:
pd.json_normalize(pd.DataFrame.from_dict(pytorch_carp.layer_analysis_dict)['annotation_scl_balanced_mean'])

In [9]:
sklearn_carp = LayerLoss(
    input_path="results/sklearn-carp",
    output_path="results/sklearn-carp_layer"
)

Plotting collage_name proeng_aav_one_vs_many_mean...
Plotting collage_name proeng_aav_two_vs_many_mean...
Plotting collage_name proeng_gb1_low_vs_high_mean...
Plotting collage_name proeng_gb1_two_vs_rest_mean...
Plotting collage_name proeng_gb1_sampled_mean...
Plotting collage_name proeng_thermo_mixed_split_mean...


In [21]:
pd.json_normalize(pd.DataFrame.from_dict(sklearn_carp.layer_analysis_dict)['proeng_aav_one_vs_many_mean'])

Unnamed: 0,train_mse,val_mse,test_mse,test_ndcg,test_rho
0,"[9.59526940116963, 6.894266263718891, 6.377686...","[11.18654191999164, 8.622644328345816, 8.33617...","[13.582533778622887, 53.712009437112215, 74.04...","[0.9704321391244177, 0.9629028051372428, 0.962...","[0.19544242309285326, 0.3282696048440418, 0.31..."
1,"[9.582026132905252, 6.934608975116892, 6.40550...","[11.20643614220332, 8.899376014238195, 8.20574...","[13.697288684643041, 48.961780681241656, 69.96...","[0.9703413636783854, 0.9633589617754629, 0.962...","[0.2116112794657769, 0.3396909252934983, 0.320..."
2,"[9.468360495727067, 8.190111410830674, 7.61644...","[10.9554842436175, 9.782055506244191, 9.209345...","[14.20679802985108, 24.53599377907496, 23.3161...","[0.9698271299324172, 0.96692411292824, 0.96943...","[0.23682032708767906, 0.3808350241317865, 0.43..."
3,"[9.574777226192223, 6.690699098364959, 6.35389...","[11.175331793579538, 8.655740728556914, 8.5822...","[13.968634660845268, 65.68949789195031, 78.864...","[0.969236054842432, 0.9626272649406287, 0.9625...","[0.18915018144628126, 0.3242017861013232, 0.31..."


In [18]:
pd.DataFrame.from_dict(sklearn_carp.layer_analysis_dict)

Unnamed: 0,proeng_aav_one_vs_many_mean,proeng_aav_two_vs_many_mean,proeng_gb1_low_vs_high_mean,proeng_gb1_two_vs_rest_mean,proeng_gb1_sampled_mean,proeng_thermo_mixed_split_mean
carp_640M,"{'train_mse': [9.59526940116963, 6.89426626371...","{'train_mse': [20.068343175878088, 16.92011420...","{'train_mse': [0.2772588428285639, 0.408520416...","{'train_mse': [2.725838148323684, 1.9707604225...","{'train_mse': [2.5043217133381526, 1.876857337...","{'train_mse': [2772.749931574079, 2768.7387720..."
carp_76M,"{'train_mse': [9.582026132905252, 6.9346089751...","{'train_mse': [19.986434661076075, 16.91308441...","{'train_mse': [0.2853286156211369, 0.411029697...","{'train_mse': [2.7027404465889115, 1.969533968...","{'train_mse': [2.4793497727220144, 1.871709821...","{'train_mse': [2772.772707139839, 2768.7354513..."
carp_600k,"{'train_mse': [9.468360495727067, 8.1901114108...","{'train_mse': [19.808120463435134, 17.93898064...","{'train_mse': [0.280335127217654, 0.3869098418...","{'train_mse': [2.674702299445244, 2.1214922598...","{'train_mse': [2.4329909760979573, 2.039620513...","{'train_mse': [2772.624627474502, 2770.3332577..."
carp_38M,"{'train_mse': [9.574777226192223, 6.6906990983...","{'train_mse': [20.04262815315973, 16.802693836...","{'train_mse': [0.27452013074623355, 0.40600096...","{'train_mse': [2.7247057205531227, 1.959347953...","{'train_mse': [2.507311420931766, 1.8454394133...","{'train_mse': [2772.7840732976197, 2768.293532..."


In [28]:
pd.merge(
    pd.DataFrame.from_dict(sklearn_carp.layer_analysis_dict),
    pd.DataFrame.from_dict(pytorch_carp.layer_analysis_dict),
    left_index=True,
    right_index=True,
).to_csv("results/summary/all_carp.csv")

In [29]:
pd.merge(
    pd.DataFrame.from_dict(sklearn_carp.rand_layer_analysis_dict),
    pd.DataFrame.from_dict(pytorch_carp.rand_layer_analysis_dict),
    left_index=True,
    right_index=True,
).to_csv("results/summary/all_carp_rand.csv")

In [30]:
pd.merge(
    pd.DataFrame.from_dict(sklearn_carp.stat_layer_analysis_dict),
    pd.DataFrame.from_dict(pytorch_carp.stat_layer_analysis_dict),
    left_index=True,
    right_index=True,
).to_csv("results/summary/all_carp_stat.csv")

In [42]:
pd.merge(
    pd.DataFrame.from_dict(sklearn_carp.onehot_baseline_dict),
    pd.DataFrame.from_dict(pytorch_carp.onehot_baseline_dict),
    left_index=True,
    right_index=True,
).to_csv("results/summary/all_carp_oh.csv")

In [50]:
pd.DataFrame.from_dict(sklearn_carp.onehot_baseline_dict).merge(
    pd.DataFrame.from_dict(pytorch_carp.onehot_baseline_dict),
    how="outer",
    left_index=True,
    right_index=True,
).fillna(False).to_csv("results/summary/all_carp_oh.csv")

In [45]:
# pd.merge(
#     pd.DataFrame.from_dict(sklearn_carp.checkpoint_analysis_dict),
#     pd.DataFrame.from_dict(pytorch_carp.checkpoint_analysis_dict),
#     left_index=True,
#     right_index=True,
# ).to_csv("results/summary/all_carp_cp.csv")

pd.DataFrame.from_dict(sklearn_carp.checkpoint_analysis_dict).merge(
    pd.DataFrame.from_dict(pytorch_carp.checkpoint_analysis_dict), how="outer"
).fillna(False).to_csv("results/summary/all_carp_cp.csv")

TypeError: unhashable type: 'dict'

In [34]:
pytorch_esm = LayerLoss(
    input_path="results/pytorch-esm",
    output_path="results/pytorch-esm_layer",
    add_checkpoint=False
)

Plotting collage_name structure_ss3_tape_processed_noflatten...
Plotting collage_name annotation_scl_balanced_mean...


In [35]:
sklearn_esm = LayerLoss(
    input_path="results/sklearn-esm",
    output_path="results/sklearn-esm_layer",
    add_checkpoint=False
)

Plotting collage_name proeng_thermo_mixed_split_mean...
Plotting collage_name proeng_gb1_two_vs_rest_mean...
Plotting collage_name proeng_gb1_low_vs_high_mean...
Plotting collage_name proeng_gb1_sampled_mean...
Plotting collage_name proeng_aav_one_vs_many_mean...
Plotting collage_name proeng_aav_two_vs_many_mean...


In [36]:
pd.merge(
    pd.DataFrame.from_dict(sklearn_esm.layer_analysis_dict),
    pd.DataFrame.from_dict(pytorch_esm.layer_analysis_dict),
    left_index=True,
    right_index=True,
).to_csv("results/summary/all_esm.csv")

In [37]:
pd.merge(
    pd.DataFrame.from_dict(sklearn_esm.rand_layer_analysis_dict),
    pd.DataFrame.from_dict(pytorch_esm.rand_layer_analysis_dict),
    left_index=True,
    right_index=True,
).to_csv("results/summary/all_esm_rand.csv")

In [38]:
pd.merge(
    pd.DataFrame.from_dict(sklearn_esm.stat_layer_analysis_dict),
    pd.DataFrame.from_dict(pytorch_esm.stat_layer_analysis_dict),
    left_index=True,
    right_index=True,
).to_csv("results/summary/all_esm_stat.csv")

In [39]:
pd.merge(
    pd.DataFrame.from_dict(sklearn_esm.onehot_baseline_dict),
    pd.DataFrame.from_dict(pytorch_esm.onehot_baseline_dict),
    left_index=True,
    right_index=True,
).to_csv("results/summary/all_esm_oh.csv")

In [None]:
class LayerLoss2CSV(LayerLoss):
    """A class for reorganizing layer by layer results into one csv"""

    def __init__(
        self,
        add_checkpoint: bool = True,
        checkpoint_list: list = [0.5, 0.25, 0.125],
        input_path: str = "results/sklearn",
        output_path: str = "results/sklearn_layer",
        metric_dict: dict[list[str]] = {
            "proeng": ["train_mse", "val_mse", "test_mse", "test_ndcg", "test_rho"],
            "annotation": [
                "train_cross-entropy",
                "val_cross-entropy",
                "test_cross-entropy",
                "test_acc",
                "test_rocauc",
            ],
            "structure": [
                "train_cross-entropy",
                "val_cross-entropy",
                "casp12_acc",
                "casp12_rocauc",
                "cb513_acc",
                "cb513_rocauc",
                "ts115_acc",
                "ts115_rocauc",
            ],
        },
    ):
        """
        Args:
        - add_checkpoint: bool = True, if add checkpoint for carp
        - checkpoint_list: list = [0.5, 0.25, 0.125],
        - input_path: str = "results/sklearn",
        - output_path: str = "results/sklearn_layer"
        - metric_dict: list[str] = ["train_mse", "test_ndcg", "test_rho"]
        """
