In [1]:
# Work from the repo root so the `scr` package imports and the relative
# "results/..." paths used below resolve.
# NOTE(review): hardcoded home-relative path — adjust per machine.
%cd ~/repo/protein-transfer

/home/t-fli/repo/protein-transfer


In [2]:
# Load the blackcellmagic extension (presumably for formatting cells with black — confirm).
%load_ext blackcellmagic

In [3]:
from scr.analysis.perlayer import LayerLoss

In [17]:
# Import here so this cell also runs on a fresh kernel (it executes before the
# cell that imports defaultdict for LayerLoss).
from collections import defaultdict

# An empty defaultdict is falsy, just like an empty dict.
if defaultdict(dict):
    print("not empty")
else:
    print("empty")

empty


In [20]:
"""Analyzing per layer output"""

from __future__ import annotations

from collections import defaultdict

import os
from glob import glob
import numpy as np

import matplotlib.pyplot as plt
from matplotlib.ticker import FormatStrFormatter

from scr.params.emb import TRANSFORMER_INFO, CARP_INFO
from scr.utils import pickle_load, get_filename, checkNgen_folder


class LayerLoss:
    """A class for handling layer analysis.

    Walks ``<input_path>/<task>/<dataset>/<split>/<encoder_name>/<flatten_emb>``
    folders and loads the per-layer result pickles in each one. It saves one
    line plot per metric under ``output_path``. It then saves one collage per
    ``<task>_<dataset>_<split>_<flatten_emb>`` under ``<output_path>/collage``,
    with metrics as rows and encoders as columns.

    The sibling folders ``<input_path>-rand``, ``<input_path>-stat`` and
    ``<input_path>-onehot`` are optional. When one exists, its results are
    parsed the same way and overlaid on the collages.
    """

    # metrics parsed and plotted for each task when no metric_dict is given
    DEFAULT_METRIC_DICT: dict[str, list[str]] = {
        "proeng": ["train_mse", "test_ndcg", "test_rho"],
        "annotation": ["train_cross-entropy", "test_acc", "test_rocauc"],
    }

    def __init__(
        self,
        input_path: str = "results/sklearn",
        output_path: str = "results/sklearn_layer",
        metric_dict: dict[str, list[str]] | None = None,
    ):
        """
        Args:
        - input_path: str = "results/sklearn", root folder of the per-layer results
        - output_path: str = "results/sklearn_layer", root folder for the plots
        - metric_dict: dict[str, list[str]] | None = None, task name mapped to the
            "<subset>_<kind>" metrics to parse,
            e.g. {"proeng": ["train_mse", "test_ndcg", "test_rho"]};
            None uses LayerLoss.DEFAULT_METRIC_DICT
        """
        # get rid of the last "/" if any
        self._input_path = os.path.normpath(input_path)
        self._output_path = os.path.normpath(output_path)

        if metric_dict is None:
            metric_dict = self.DEFAULT_METRIC_DICT
        # copy so instances never share (or mutate) the default or the caller's lists
        self._metric_dict = {
            task: list(metrics) for task, metrics in metric_dict.items()
        }

        # one folder per task/dataset/split/encoder_name/flatten_emb, e.g.
        # results/train_val_test/proeng/gb1/two_vs_rest/esm1b_t33_650M_UR50S/max;
        # sorted because glob returns entries in arbitrary order
        self._dataset_folders = sorted(
            folder
            for folder in glob(f"{self._input_path}/*/*/*/*/*")
            if os.path.isdir(folder)
        )

        # init dictionaries for recording outputs
        self._onehot_baseline_dict = defaultdict(dict)
        self._layer_analysis_dict = defaultdict(dict)
        self._rand_layer_analysis_dict = defaultdict(dict)
        self._stat_layer_analysis_dict = defaultdict(dict)

        # sibling result folders overlaid on the collages when present:
        # reset param (random init), resample param (stat transfer) and onehot
        reset_param_path = f"{self._input_path}-rand"
        resample_param_path = f"{self._input_path}-stat"
        onehot_path = f"{self._input_path}-onehot"

        add_rand = os.path.exists(reset_param_path)
        add_stat = os.path.exists(resample_param_path)
        add_onehot = os.path.exists(onehot_path)

        # collage name -> (task, onehot baseline key), so that each collage
        # uses its own task's metrics rather than those of the last folder parsed
        collage_info: dict[str, tuple[str, str]] = {}

        for dataset_folder in self._dataset_folders:
            # e.g. "proeng/gb1/two_vs_rest/esm1b_t33_650M_UR50S/max"
            task_subfolder = os.path.relpath(dataset_folder, self._input_path)
            task, dataset, split, encoder_name, flatten_emb = task_subfolder.split(
                os.sep
            )

            collage_name = f"{task}_{dataset}_{split}_{flatten_emb}"
            onehot_name = f"{task}_{dataset}_{split}"
            collage_info[collage_name] = (task, onehot_name)
            emb_details = (task, dataset, split, encoder_name, flatten_emb)

            # parse results for plotting the collage
            self._layer_analysis_dict[collage_name][
                encoder_name
            ] = self.parse_result_dicts(dataset_folder, *emb_details)

            # random init (reset param) results mirror the input folder layout
            if add_rand:
                self._rand_layer_analysis_dict[collage_name][
                    encoder_name
                ] = self.parse_result_dicts(
                    os.path.join(reset_param_path, task_subfolder), *emb_details
                )

            # stat transfer (resample param) results mirror the input folder layout
            if add_stat:
                self._stat_layer_analysis_dict[collage_name][
                    encoder_name
                ] = self.parse_result_dicts(
                    os.path.join(resample_param_path, task_subfolder), *emb_details
                )

            # the onehot baseline is shared by every encoder and flatten_emb of
            # a task/dataset/split, so parse (and plot) it only once
            if add_onehot and onehot_name not in self._onehot_baseline_dict:
                self._onehot_baseline_dict[onehot_name] = self.parse_result_dicts(
                    os.path.join(onehot_path, task, dataset, split, "onehot", "flatten"),
                    task,
                    dataset,
                    split,
                    "onehot",
                    "flatten",
                )

        collage_folder = os.path.join(self._output_path, "collage")
        checkNgen_folder(collage_folder)

        for collage_name, encoder_dict in self._layer_analysis_dict.items():
            task, onehot_name = collage_info[collage_name]
            self._plot_collage(
                collage_folder,
                collage_name,
                task,
                onehot_name,
                encoder_dict,
                add_onehot,
                add_rand,
                add_stat,
            )

    @staticmethod
    def _order_encoders(encoder_dict: dict) -> tuple[list[str], str]:
        """Return the collage column order and legend label for a set of encoders."""
        encoder_keys = set(encoder_dict)
        if encoder_keys == set(TRANSFORMER_INFO):
            # set the key rankings to default
            return list(TRANSFORMER_INFO), "esm"
        if encoder_keys == set(CARP_INFO):
            # set the key rankings to default
            return list(CARP_INFO), "carp"
        # mixed or partial encoder families have no default ranking:
        # sort for a stable column order and use a generic legend label
        return sorted(encoder_keys), "encoder"

    def _plot_collage(
        self,
        collage_folder: str,
        collage_name: str,
        task: str,
        onehot_name: str,
        encoder_dict: dict,
        add_onehot: bool,
        add_rand: bool,
        add_stat: bool,
    ) -> None:
        """Save a metrics x encoders grid of per-layer curves for one collage.

        Args:
        - collage_folder: str, folder the .svg and .png collages are saved in
        - collage_name: str, "<task>_<dataset>_<split>_<flatten_emb>"
        - task: str, key into the metric dict selecting the rows
        - onehot_name: str, "<task>_<dataset>_<split>" key of the onehot baseline
        - encoder_dict: dict, encoder name -> {metric: per-layer array}
        - add_onehot / add_rand / add_stat: bool, whether to overlay the onehot
            baseline, the random init and the stat transfer results
        """
        metrics = self._metric_dict[task]
        encoder_names, encoder_label = self._order_encoders(encoder_dict)

        # squeeze=False keeps axs 2D even with a single metric or encoder
        fig, axs = plt.subplots(
            len(metrics),
            len(encoder_names),
            sharey="row",
            sharex="col",
            figsize=(20, 10),
            squeeze=False,
        )

        for m, metric in enumerate(metrics):
            for n, encoder_name in enumerate(encoder_names):
                ax = axs[m, n]
                ax.plot(encoder_dict[encoder_name][metric], label=encoder_label)

                # overlay onehot baseline
                if add_onehot:
                    ax.axhline(
                        self._onehot_baseline_dict[onehot_name][metric],
                        label="onehot",
                        color="#D3D3D3",  # light grey
                        linestyle="dotted",
                    )

                # overlay random init
                if add_rand:
                    ax.plot(
                        self._rand_layer_analysis_dict[collage_name][encoder_name][
                            metric
                        ],
                        label="random init",
                        color="#D3D3D3",  # light grey
                    )

                # overlay stat transfer
                if add_stat:
                    ax.plot(
                        self._stat_layer_analysis_dict[collage_name][encoder_name][
                            metric
                        ],
                        label="stat transfer",
                        color="#A9A9A9",  # dark grey
                    )

        # add xlabels to the bottom row
        for ax in axs[-1]:
            ax.set_xlabel("layers", fontsize=16)
            ax.tick_params(axis="x", labelsize=16)

        # add column names
        for ax, col in zip(axs[0], encoder_names):
            ax.set_title(col, fontsize=16)

        # add row names
        for ax, row in zip(axs[:, 0], metrics):
            ax.set_ylabel(row.replace("_", " "), fontsize=16)
            ax.tick_params(axis="y", labelsize=16)

        # two-decimal y ticks on every row, not only on the current axes
        for ax in axs.flat:
            ax.yaxis.set_major_formatter(FormatStrFormatter("%.2f"))

        # add legend; every panel gets the same overlays, so the first
        # panel's handles cover all of them
        handles, labels = axs[0, 0].get_legend_handles_labels()
        fig.legend(
            handles,
            labels,
            loc="upper left",
            bbox_to_anchor=[0.05, 1.025],
            fontsize=16,
            frameon=False,
            ncol=2,
        )

        # add whole plot level title
        fig.suptitle(
            collage_name.replace("_", " "), y=1.0025, fontsize=24, fontweight="bold"
        )
        fig.tight_layout()

        for plot_ext in (".svg", ".png"):
            fig.savefig(
                os.path.join(collage_folder, collage_name + plot_ext),
                bbox_inches="tight",
            )

        # close this specific figure so looping over many collages does not pile up figures
        plt.close(fig)

    def parse_result_dicts(
        self,
        folder_path: str,
        task: str,
        dataset: str,
        split: str,
        encoder_name: str,
        flatten_emb: bool | str,
    ) -> dict[str, np.ndarray]:
        """
        Parse the per-layer result pickles of one folder and plot each metric

        Args:
        - folder_path: str, the folder holding one result pickle per layer
        - task: str, key into the metric dict selecting which metrics to parse
        - dataset: str, used in the plot titles
        - split: str, used in the plot titles
        - encoder_name: str, selects the number of layers and names the plots
        - flatten_emb: bool | str, embedding flattening variant
            (e.g. "mean", "max", "flatten"), used in the plot names

        Returns:
        - dict, metric name as keys and an array indexed by layer number as
            values; layers without a result pickle stay 0
        """
        # get the list of output pickle files
        pkl_list = sorted(glob(os.path.join(folder_path, "*.pkl")))

        # get the max layer number for the array
        # NOTE(review): the +1 for transformers presumably counts the embedding
        # layer (layer 0) — confirm against TRANSFORMER_INFO / CARP_INFO
        if encoder_name in TRANSFORMER_INFO:
            max_layer_numb = TRANSFORMER_INFO[encoder_name][1] + 1
        elif encoder_name in CARP_INFO:
            max_layer_numb = CARP_INFO[encoder_name][1]
        else:
            max_layer_numb = 1

        metrics = self._metric_dict[task]

        # init the output dict
        output_numb_dict = {metric: np.zeros([max_layer_numb]) for metric in metrics}

        for pkl_file in pkl_list:
            # the layer number is the last "_"-separated token of the file name
            layer_numb = int(get_filename(pkl_file).split("-")[-1].split("_")[-1])
            result_dict = pickle_load(pkl_file)

            # populate the processed dictionary
            for metric in metrics:
                # split once: "<subset>_<kind>", where kind may itself contain "_"
                subset, kind = metric.split("_", 1)
                value = result_dict[subset][kind]
                # rho is stored as a sequence (presumably (rho, p-value));
                # only its first entry is kept
                output_numb_dict[metric][layer_numb] = (
                    value[0] if kind == "rho" else value
                )

        # mirror the folder layout under the output path
        output_subfolder = checkNgen_folder(
            folder_path.replace(self._input_path, self._output_path)
        )

        plot_prefix = f"{task}_{dataset}_{split}"
        for metric, values in output_numb_dict.items():
            plot_name = f"{encoder_name}_{flatten_emb}_{metric}"

            fig, ax = plt.subplots()
            ax.plot(values)
            ax.set_title(f"{plot_prefix} \n {plot_name}")
            ax.set_xlabel("layers")
            # label with the metric itself; not every metric is a loss
            ax.set_ylabel(metric.replace("_", " "))

            for plot_ext in (".svg", ".png"):
                fig.savefig(
                    os.path.join(output_subfolder, plot_name + plot_ext),
                    bbox_inches="tight",
                )
            plt.close(fig)

        return output_numb_dict

    @property
    def layer_analysis_dict(self) -> dict:
        """Return a dict with "<task>_<dataset>_<split>_<flatten_emb>" as the key
        and {encoder_name: {metric: per-layer array}} as the value"""
        return self._layer_analysis_dict

In [19]:
# Per-layer analysis of results/train_val_test, plotted under results/analysis_layer
layer_analysis_dict = LayerLoss(
    output_path="results/analysis_layer",
    input_path="results/train_val_test",
).layer_analysis_dict

In [10]:
# Per-layer analysis of the reset-param (random init) sklearn results
layer_analysis_dict = LayerLoss(
    output_path="results/sklearn-scaley-noloader-rand_layer",
    input_path="results/sklearn-scaley-noloader-rand",
).layer_analysis_dict

Making results/sklearn-scaley-noloader-rand_layer/proeng/gb1/sampled ...
Making results/sklearn-scaley-noloader-rand_layer/proeng/gb1/sampled/esm1b_t33_650M_UR50S ...
Making results/sklearn-scaley-noloader-rand_layer/proeng/gb1/sampled/esm1b_t33_650M_UR50S/mean ...
Making results/sklearn-scaley-noloader-rand_layer/proeng/gb1/sampled/esm1_t6_43M_UR50S ...
Making results/sklearn-scaley-noloader-rand_layer/proeng/gb1/sampled/esm1_t6_43M_UR50S/mean ...
Making results/sklearn-scaley-noloader-rand_layer/proeng/gb1/sampled/esm1_t12_85M_UR50S ...
Making results/sklearn-scaley-noloader-rand_layer/proeng/gb1/sampled/esm1_t12_85M_UR50S/mean ...
Making results/sklearn-scaley-noloader-rand_layer/proeng/gb1/sampled/esm1_t34_670M_UR50S ...
Making results/sklearn-scaley-noloader-rand_layer/proeng/gb1/sampled/esm1_t34_670M_UR50S/mean ...


In [6]:
# Per-layer analysis of results/sklearn-scaley-noloader
layer_analysis_dict = LayerLoss(
    output_path="results/sklearn-scaley-noloader_layer",
    input_path="results/sklearn-scaley-noloader",
).layer_analysis_dict

In [8]:
# Per-layer analysis of results/sklearn-scaley-fixa-noloader
layer_analysis_dict = LayerLoss(
    output_path="results/sklearn-scaley-fixa-noloader_layer",
    input_path="results/sklearn-scaley-fixa-noloader",
).layer_analysis_dict

In [6]:
# Per-layer analysis of results/sklearn-scaley-fixa
layer_analysis_dict = LayerLoss(
    output_path="results/sklearn-scaley-fixa_layer",
    input_path="results/sklearn-scaley-fixa",
).layer_analysis_dict

In [32]:
# Inspect the collage keys, the encoders of one collage, and the metrics of one encoder
(
    layer_analysis_dict.keys(),
    layer_analysis_dict["proeng_gb1_low_vs_high_mean"].keys(),
    layer_analysis_dict["proeng_gb1_low_vs_high_mean"]["esm1b_t33_650M_UR50S"].keys(),
)

(dict_keys(['proeng_gb1_two_vs_rest_mean', 'proeng_gb1_low_vs_high_mean', 'proeng_aav_one_vs_many_mean']),
 dict_keys(['esm1b_t33_650M_UR50S', 'esm1_t6_43M_UR50S', 'esm1_t12_85M_UR50S', 'esm1_t34_670M_UR50S']),
 dict_keys(['train_mse', 'test_ndcg', 'test_rho']))

In [26]:
# Re-run of the results/sklearn-scaley-noloader analysis (picks up the -onehot baseline)
layer_analysis_dict = LayerLoss(
    output_path="results/sklearn-scaley-noloader_layer",
    input_path="results/sklearn-scaley-noloader",
).layer_analysis_dict

Making results/sklearn-scaley-noloader_layer-onehot ...
Making results/sklearn-scaley-noloader_layer-onehot/proeng ...
Making results/sklearn-scaley-noloader_layer-onehot/proeng/gb1 ...
Making results/sklearn-scaley-noloader_layer-onehot/proeng/gb1/sampled ...
Making results/sklearn-scaley-noloader_layer-onehot/proeng/gb1/sampled/onehot ...
Making results/sklearn-scaley-noloader_layer-onehot/proeng/gb1/sampled/onehot/flatten ...
Making results/sklearn-scaley-noloader_layer-onehot/proeng/gb1/two_vs_rest ...
Making results/sklearn-scaley-noloader_layer-onehot/proeng/gb1/two_vs_rest/onehot ...
Making results/sklearn-scaley-noloader_layer-onehot/proeng/gb1/two_vs_rest/onehot/flatten ...
Making results/sklearn-scaley-noloader_layer-onehot/proeng/gb1/low_vs_high ...
Making results/sklearn-scaley-noloader_layer-onehot/proeng/gb1/low_vs_high/onehot ...
Making results/sklearn-scaley-noloader_layer-onehot/proeng/gb1/low_vs_high/onehot/flatten ...
Making results/sklearn-scaley-noloader_layer-oneh