In [3]:
# Work from the repo root: later cells use repo-relative paths
# (e.g. data/proeng/..., results/dataset_vis).
# NOTE(review): home-relative hard-coded path; consider making it configurable.
%cd ~/repo/protein-transfer

/home/t-fli/repo/protein-transfer


In [4]:
%load_ext blackcellmagic

In [3]:
from scr.preprocess.data_process import TaskProcess

In [None]:
# Inspect the `sum_file_df` attribute of a freshly constructed TaskProcess.
task_process = TaskProcess()
task_process.sum_file_df

In [8]:
import os

from glob import glob

In [6]:
import pandas as pd

In [9]:
from scr.utils import pickle_load, get_task_data_split

In [73]:
# Import the project helpers this cell needs so it also runs on a fresh kernel
# (in this notebook they are otherwise only imported by later cells).
from scr.params.vis import PLOT_EXTS
from scr.vis.vis_utils import BokehSave
from scr.vis.iqplot_striphis import striphistogram


# NOTE(review): DatasetStripHistogram is defined twice in this notebook; whichever
# cell ran last wins. Consolidate into a single definition.
class DatasetStripHistogram(BokehSave):
    """Strip + histogram plot of proeng fitness targets per split and set."""

    def __init__(
        self,
        dataset_folder: str,
        split_order: "list[str] | None" = None,
        path2folder: str = "results/dataset_vis",
        plot_exts: list = PLOT_EXTS,
        plot_height: int = 300,
        plot_width: int = 450,
        axis_font_size: str = "10pt",
        title_font_size: str = "10pt",
        x_name: str = "",
        y_name: str = "fitness",
        gridoff: bool = True,
    ) -> None:
        """
        Load every ``*.pkl`` split in ``dataset_folder`` and plot ``target``
        grouped by (split, set), or by set alone when ``split_order`` is None.
        Rows with ``validation == True`` are relabelled as the ``"val"`` set.

        Args:
        - dataset_folder: str, ie. data/proeng/gb1
        - split_order: list[str], ie. ["low_vs_high", "two_vs_rest", "sampled"]
        - path2folder, plot_exts, plot_height, plot_width, axis_font_size,
            title_font_size, x_name, y_name, gridoff: forwarded to BokehSave

        Raises:
        - AssertionError: if ``dataset_folder`` is not a proeng dataset
        - FileNotFoundError: if no ``*.pkl`` files are found in the folder
        """

        # Glob pattern for the split pickles; also used below to derive the plot name.
        self._dataset_folder = os.path.normpath(f"{dataset_folder}/*.pkl")

        assert "proeng" in self._dataset_folder, "only support proeng datasets"

        # sorted(): glob order depends on the filesystem
        pkl_paths = sorted(glob(self._dataset_folder))
        if not pkl_paths:
            raise FileNotFoundError(f"no *.pkl files match {self._dataset_folder}")

        dfs = []
        for pkl in pkl_paths:
            _, _, split = get_task_data_split(pkl)
            df = pickle_load(pkl)
            df["split"] = split
            df.loc[df["validation"] == True, "set"] = "val"
            dfs.append(df)

        self._cat_dfs = pd.concat(dfs, ignore_index=True, axis=0)

        set_order = ["train", "val", "test"]

        # Group by set only, or by (split, set) with tuple-valued order entries,
        # so that `cats` and `order` always describe the same grouping.
        if split_order is None:
            cat_orders = set_order
            cats_list = ["set"]
        else:
            cat_orders = [(i, j) for i in split_order for j in set_order]
            cats_list = ["split", "set"]

        self.bokeh_plot = striphistogram(
            self._cat_dfs,
            q="target",
            cats=cats_list,
            spread="jitter",
            color_column="set",
            top_level="histogram",
            marker_kwargs={"alpha": 0.1},
            fill_kwargs={"fill_alpha": 0.1},
            order=cat_orders,
            q_axis="y",
        )

        super(DatasetStripHistogram, self).__init__(
            bokeh_plot=self.bokeh_plot,
            path2folder=path2folder,
            plot_name="-".join(get_task_data_split(self._dataset_folder)[:-1]),
            plot_exts=plot_exts,
            plot_height=plot_height,
            plot_width=plot_width,
            axis_font_size=axis_font_size,
            title_font_size=title_font_size,
            x_name=x_name,
            y_name=y_name,
            gridoff=gridoff,
        )

In [51]:
# Load each gb1 split pickle, tag its rows with the split name and relabel
# validation rows as the "val" set.
# NOTE: `df` stays bound to the last loaded frame after the loop; a later cell reads it.
dfs = []
for pkl_path in glob("data/proeng/gb1/*.pkl"):
    task_name, dataset_name, split_name = get_task_data_split(pkl_path)
    df = pickle_load(pkl_path)
    df["split"] = split_name
    is_val = df["validation"] == True
    df.loc[is_val, "set"] = "val"
    dfs.append(df)

In [52]:
# Stack all splits into one frame with a fresh 0..n-1 index.
cat_dfs = pd.concat(dfs, axis=0).reset_index(drop=True)
cat_dfs.head()

Unnamed: 0,sequence,target,set,validation,mut_name,mut_numb,split
0,MQYKLILNGKTLKGETTTEAVDAATAEKVFKQYANDNGVDGEWTYD...,1.0,test,,parent,0,low_vs_high
1,MQYKLILNGKTLKGETTTEAVDAATAEKVFKQYANDNGIDGEWTYD...,1.445905,test,,V39I,1,low_vs_high
2,MQYKLILNGKTLKGETTTEAVDAATAEKVFKQYANDNGLDGEWTYD...,1.690164,test,,V39L,1,low_vs_high
3,MQYKLILNGKTLKGETTTEAVDAATAEKVFKQYANDNGMDGEWTYD...,1.17055,test,,V39M,1,low_vs_high
4,MQYKLILNGKTLKGETTTEAVDAATAEKVFKQYANDNGVAGEWTYD...,2.401243,test,,D40A,1,low_vs_high


In [24]:
"""For dataset vis"""

from __future__ import annotations

import os

from glob import glob

import pandas as pd

import iqplot

from scr.params.vis import PLOT_EXTS
from scr.vis.vis_utils import BokehSave
from scr.vis.iqplot_striphis import striphistogram
from scr.utils import get_task_data_split, read_std_csv, pickle_load, read_std_csv


class DatasetECDF(BokehSave):
    """ECDF of a dataset's ``target`` column, one curve per train/val/test set."""

    def __init__(
        self,
        dataset_path: str,
        path2folder: str = "results/dataset_vis",
        plot_exts: list = PLOT_EXTS,
        plot_height: int = 300,
        plot_width: int = 450,
        axis_font_size: str = "10pt",
        title_font_size: str = "10pt",
        x_name: str = "fitness",
        y_name: str = "ecdf",
        gridoff: bool = True,
    ) -> None:
        """
        Args:
        - dataset_path: str, path to a dataset file loaded with read_std_csv
        - path2folder, plot_exts, plot_height, plot_width, axis_font_size,
            title_font_size, x_name, y_name, gridoff: forwarded to BokehSave
        """
        dataset_df = read_std_csv(dataset_path)

        # rows flagged in the "validation" column are plotted as their own "val" set
        is_val = dataset_df["validation"] == True
        dataset_df.loc[is_val, "set"] = "val"

        self.bokeh_plot = iqplot.ecdf(
            dataset_df,
            q="target",
            cats="set",
            conf_int=True,
            order=["train", "val", "test"],
            legend_location="bottom_right",
            marker_kwargs={"alpha": 0.5},
            fill_kwargs={"fill_alpha": 0.1},
        )

        plot_name = "-".join(get_task_data_split(dataset_path))

        super().__init__(
            bokeh_plot=self.bokeh_plot,
            path2folder=path2folder,
            plot_name=plot_name,
            plot_exts=plot_exts,
            plot_height=plot_height,
            plot_width=plot_width,
            axis_font_size=axis_font_size,
            title_font_size=title_font_size,
            x_name=x_name,
            y_name=y_name,
            gridoff=gridoff,
        )


class DatasetStripHistogram(BokehSave):
    """Strip + histogram plot of proeng fitness targets per split and train/val/test set."""

    def __init__(
        self,
        dataset_folder: str,
        split_order: list[str] | None = None,
        path2folder: str = "results/dataset_vis",
        plot_exts: list = PLOT_EXTS,
        plot_height: int = 400,
        plot_width: int = 600,
        axis_font_size: str = "10pt",
        title_font_size: str = "10pt",
        x_name: str = "",
        y_name: str = "fitness",
        gridoff: bool = True,
    ) -> None:
        """
        Load a proeng dataset folder and plot ``target`` per set (and per split).

        Every ``*.pkl`` in the folder is loaded as one split and stacked with a
        ``split`` column; if there are no pickles, the first ``*.csv`` is used.
        Rows with ``validation == True`` are relabelled as the ``"val"`` set.

        Args:
        - dataset_folder: str, ie. data/proeng/gb1
        - split_order: list[str], ie. ["low_vs_high", "two_vs_rest", "sampled"];
            None groups by set only (required for CSV-only folders, which
            have no ``split`` column)
        - path2folder, plot_exts, plot_height, plot_width, axis_font_size,
            title_font_size, x_name, y_name, gridoff: forwarded to BokehSave

        Raises:
        - AssertionError: if ``dataset_folder`` is not a proeng dataset
        - FileNotFoundError: if the folder has neither ``*.pkl`` nor ``*.csv`` files
        - ValueError: if ``split_order`` is given but the data has no ``split`` column
        """

        self._dataset_folder = os.path.normpath(dataset_folder)
        # sorted(): glob order depends on the filesystem
        self._dataset_paths = sorted(glob(f"{self._dataset_folder}/*.pkl"))

        assert "proeng" in self._dataset_folder, "only support proeng datasets"

        if self._dataset_paths:
            self._cat_dfs = self._load_pkl_splits(self._dataset_paths)
        else:
            self._cat_dfs = self._load_first_csv(self._dataset_folder)

        set_order = ["train", "val", "test"]

        if split_order is None:
            cat_orders = set_order
            cats_list = ["set"]
        else:
            if "split" not in self._cat_dfs.columns:
                raise ValueError(
                    f"split_order given but {self._dataset_folder} has no per-split data"
                )
            # two-level grouping: order entries are (split, set) tuples
            cat_orders = [(i, j) for i in split_order for j in set_order]
            cats_list = ["split", "set"]

        self.bokeh_plot = striphistogram(
            self._cat_dfs,
            q="target",
            cats=cats_list,
            spread="jitter",
            color_column="set",
            top_level="histogram",
            marker_kwargs={"alpha": 0.1},
            fill_kwargs={"fill_alpha": 0.1},
            order=cat_orders,
            q_axis="y",
        )

        super(DatasetStripHistogram, self).__init__(
            bokeh_plot=self.bokeh_plot,
            path2folder=path2folder,
            # NOTE(review): confirm get_task_data_split on a bare folder path (no file
            # name) yields (task, dataset, ...) so that [:-1] gives "<task>-<dataset>".
            plot_name="-".join(get_task_data_split(self._dataset_folder)[:-1]),
            plot_exts=plot_exts,
            plot_height=plot_height,
            plot_width=plot_width,
            axis_font_size=axis_font_size,
            title_font_size=title_font_size,
            x_name=x_name,
            y_name=y_name,
            gridoff=gridoff,
        )

    @staticmethod
    def _load_pkl_splits(pkl_paths: list[str]) -> pd.DataFrame:
        """Load each split pickle, tag it with its split name, and stack them all."""
        dfs = []
        for pkl in pkl_paths:
            _, _, split = get_task_data_split(pkl)
            df = pickle_load(pkl)
            df["split"] = split
            df.loc[df["validation"] == True, "set"] = "val"
            # append inside the loop so every split is kept, not only the last one
            dfs.append(df)
        return pd.concat(dfs, ignore_index=True, axis=0)

    @staticmethod
    def _load_first_csv(dataset_folder: str) -> pd.DataFrame:
        """Fallback for folders without pickles: load the first CSV (by sorted name)."""
        csv_paths = sorted(glob(f"{dataset_folder}/*.csv"))
        if not csv_paths:
            raise FileNotFoundError(
                f"no *.pkl or *.csv files found in {dataset_folder}"
            )
        df = read_std_csv(csv_paths[0])
        df.loc[df["validation"] == True, "set"] = "val"
        return df

In [None]:
# Strip histograms for the multi-split proeng datasets, one column group per split.
DatasetStripHistogram(
    "data/proeng/gb1",
    split_order=["low_vs_high", "two_vs_rest", "sampled"],
    plot_width=800,
)
DatasetStripHistogram(
    "data/proeng/aav",
    split_order=["one_vs_many", "two_vs_many"],
)

In [20]:
# Which values does the thermo "validation" column take? (shown below: NaN or True)
pd.read_csv("data/proeng/thermo/mixed_split.csv")["validation"].unique()

array([nan, True], dtype=object)

In [23]:
# No split_order: the thermo data is grouped by train/val/test set only.
DatasetStripHistogram("data/proeng/thermo", split_order=None, plot_width=400)

<__main__.DatasetStripHistogram at 0x7f2fbb21d9a0>

In [14]:
import iqplot

In [15]:
import pandas as pd

from bokeh.io import push_notebook, show, output_notebook
from bokeh.layouts import row 
from bokeh.plotting import figure

# Send Bokeh output to this notebook so `show(...)` renders inline.
output_notebook() 

In [16]:
from __future__ import annotations

import os

import pandas as pd

import iqplot

import bokeh
from bokeh.io import show, export_svg, export_png
from bokeh.plotting import show, figure
from bokeh.models.annotations import Title

# Send Bokeh output to this notebook so figures render inline.
bokeh.io.output_notebook()

import holoviews as hv
from holoviews import dim

# Select Bokeh as the HoloViews plotting backend.
hv.extension("bokeh")



In [17]:
from scr.vis.iqplot_striphis import striphistogram

In [28]:
# (split, set) pairs in plotting order: every set within each split.
[
    (split_name, set_name)
    for split_name in ["low_vs_high", "two_vs_rest", "sampled"]
    for set_name in ["train", "val", "test"]
]

[('low_vs_high', 'train'),
 ('low_vs_high', 'val'),
 ('low_vs_high', 'test'),
 ('two_vs_rest', 'train'),
 ('two_vs_rest', 'val'),
 ('two_vs_rest', 'test'),
 ('sampled', 'train'),
 ('sampled', 'val'),
 ('sampled', 'test')]

In [27]:
# Splits present in the stacked frame.
cat_dfs["split"].unique()

array(['low_vs_high', 'two_vs_rest', 'sampled'], dtype=object)

In [47]:
from pandas.api.types import CategoricalDtype

In [48]:
# Make "set" an ordered categorical (train < val < test) and sort rows by it.
cat_dfs = cat_dfs.assign(
    set=cat_dfs["set"].astype(
        CategoricalDtype(["train", "val", "test"], ordered=True)
    )
).sort_values("set")
cat_dfs.head()

Unnamed: 0,sequence,target,set,validation,mut_name,mut_numb,split
7927,MQYKLILNGKTLKGETTTEAVDAATAEKVFKQYANDNGWLPEWTYD...,0.001794,train,,V39W:D40L:G41P:V54Y,4,low_vs_high
8629,MQYKLILNGKTLKGETTTEAVDAATAEKVFKQYANDNGYTQEWTYD...,0.005123,train,,V39Y:D40T:G41Q:V54Q,4,low_vs_high
8630,MQYKLILNGKTLKGETTTEAVDAATAEKVFKQYANDNGYTVEWTYD...,0.002236,train,,V39Y:D40T:G41V:V54P,4,low_vs_high
8631,MQYKLILNGKTLKGETTTEAVDAATAEKVFKQYANDNGYTYEWTYD...,0.013052,train,,V39Y:D40T:G41Y:V54C,4,low_vs_high
8632,MQYKLILNGKTLKGETTTEAVDAATAEKVFKQYANDNGYTYEWTYD...,0.001218,train,,V39Y:D40T:G41Y:V54F,4,low_vs_high


In [53]:
# NOTE(review): this repeats the ordered-categorical cell above; one of them is enough.
# Assign the column with [] instead of attribute access: `cat_dfs.set = ...` only
# updates the column because "set" already exists; for a new name it would silently
# create a plain Python attribute instead of a column.
cat_dfs["set"] = pd.Categorical(
    cat_dfs["set"], categories=["train", "val", "test"], ordered=True
)
cat_dfs = cat_dfs.sort_values("set")
cat_dfs.head()

Unnamed: 0,sequence,target,set,validation,mut_name,mut_numb,split
7927,MQYKLILNGKTLKGETTTEAVDAATAEKVFKQYANDNGWLPEWTYD...,0.001794,train,,V39W:D40L:G41P:V54Y,4,low_vs_high
8629,MQYKLILNGKTLKGETTTEAVDAATAEKVFKQYANDNGYTQEWTYD...,0.005123,train,,V39Y:D40T:G41Q:V54Q,4,low_vs_high
8630,MQYKLILNGKTLKGETTTEAVDAATAEKVFKQYANDNGYTVEWTYD...,0.002236,train,,V39Y:D40T:G41V:V54P,4,low_vs_high
8631,MQYKLILNGKTLKGETTTEAVDAATAEKVFKQYANDNGYTYEWTYD...,0.013052,train,,V39Y:D40T:G41Y:V54C,4,low_vs_high
8632,MQYKLILNGKTLKGETTTEAVDAATAEKVFKQYANDNGYTYEWTYD...,0.001218,train,,V39Y:D40T:G41Y:V54F,4,low_vs_high


In [54]:
cat_dfs.head(n=5)

Unnamed: 0,sequence,target,set,validation,mut_name,mut_numb,split
7927,MQYKLILNGKTLKGETTTEAVDAATAEKVFKQYANDNGWLPEWTYD...,0.001794,train,,V39W:D40L:G41P:V54Y,4,low_vs_high
8629,MQYKLILNGKTLKGETTTEAVDAATAEKVFKQYANDNGYTQEWTYD...,0.005123,train,,V39Y:D40T:G41Q:V54Q,4,low_vs_high
8630,MQYKLILNGKTLKGETTTEAVDAATAEKVFKQYANDNGYTVEWTYD...,0.002236,train,,V39Y:D40T:G41V:V54P,4,low_vs_high
8631,MQYKLILNGKTLKGETTTEAVDAATAEKVFKQYANDNGYTYEWTYD...,0.013052,train,,V39Y:D40T:G41Y:V54C,4,low_vs_high
8632,MQYKLILNGKTLKGETTTEAVDAATAEKVFKQYANDNGYTYEWTYD...,0.001218,train,,V39Y:D40T:G41Y:V54F,4,low_vs_high


In [55]:
# gb1 strip histogram: targets on the y axis, grouped by (split, set).
sh = striphistogram(
    cat_dfs,
    q="target",
    q_axis="y",
    cats=["split", "set"],
    color_column="set",
    order=[
        (split_name, set_name)
        for split_name in ["low_vs_high", "two_vs_rest", "sampled"]
        for set_name in ["train", "val", "test"]
    ],
    spread="jitter",
    top_level="histogram",
    marker_kwargs={"alpha": 0.1},
    fill_kwargs={"fill_alpha": 0.1},
)
show(sh)

In [56]:
from scr.vis.vis_utils import BokehSave

In [58]:
# Save the gb1 strip histogram under results/dataset_vis.
BokehSave(
    sh,
    path2folder="results/dataset_vis",
    plot_name="proeng gb1",
    plot_height=400,
    plot_width=600,
    x_name="",
    y_name="fitness",
)

<scr.vis.vis_utils.BokehSave at 0x7f4d4de1e9a0>

In [24]:
# Render a HoloViews violin of target per set into a Bokeh object.
# NOTE(review): `df` is whatever frame an earlier cell left behind (the last pickle
# loaded in a loop); this cell is not reproducible on a fresh kernel.
test = hv.render(hv.Violin(df, kdims="set", vdims="target").opts())