<a href="https://colab.research.google.com/github/goldjunge3010/masterpraktikum/blob/main/RetentionTimeReport_CustomPlots.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Install and import packages

In [1]:
# install necessary packages
# dlomix is pinned to 0.0.4; wandb is NOT pinned, so the wandb version that
# gets installed depends on when this cell is run.
!python -m pip install -q dlomix==0.0.4
!python -m pip install -q wandb

  Preparing metadata (setup.py) ... [?25l[?25hdone
  Building wheel for fpdf (setup.py) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.1/2.1 MB[0m [31m9.2 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m188.5/188.5 kB[0m [31m10.4 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m215.6/215.6 kB[0m [31m9.2 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m62.7/62.7 kB[0m [31m4.9 MB/s[0m eta [36m0:00:00[0m
[?25h  Building wheel for pathtools (setup.py) ... [?25l[?25hdone


In [2]:
# import necessary packages
import numpy as np
import pandas as pd
import tensorflow as tf
import re

import wandb
from wandb.keras import WandbCallback
from wandb.keras import WandbMetricsLogger
import wandb.apis.reports as wr

import dlomix
from dlomix import constants, data, eval, layers, models, pipelines, reports, utils
from dlomix.data import RetentionTimeDataset
from dlomix.models import PrositRetentionTimePredictor
from dlomix.eval import TimeDeltaMetric

[34m[1mwandb[0m: Thanks for trying out the Report API!
[34m[1mwandb[0m: For a tutorial, check out https://colab.research.google.com/drive/1CzyJx1nuOS4pdkXa2XPaRQyZdmFmLmXV
[34m[1mwandb[0m: 
[34m[1mwandb[0m: Try out tab completion to see what's available.
[34m[1mwandb[0m:   ∟ everything:    `wr.<tab>`
[34m[1mwandb[0m:       ∟ panels:    `wr.panels.<tab>`
[34m[1mwandb[0m:       ∟ blocks:    `wr.blocks.<tab>`
[34m[1mwandb[0m:       ∟ helpers:   `wr.helpers.<tab>`
[34m[1mwandb[0m:       ∟ templates: `wr.templates.<tab>`
[34m[1mwandb[0m:       
[34m[1mwandb[0m: For bugs/feature requests, please create an issue on github: https://github.com/wandb/wandb/issues


# Initialize WANDB

In [3]:
# Initialize config
# The whole dict is passed to wandb.init below as the run config.
# Keys read by later cells: "seq_length", "batch_size", "val_ratio" (dataset),
# "seq_length" again (model), "lr" (Adam learning rate) and "loss" (compile).
# "optimizer" is not read anywhere in this notebook; the training cell
# constructs tf.keras.optimizers.Adam directly.
config = {
  "seq_length" : 64,
  "batch_size" : 32,
  "val_ratio" : 0.2,
  "lr" : 0.001,
  "optimizer" : "Adam",
  "loss" : "mse"
}

# Initialize WANDB
# Project and run name used for this W&B run.
PROJECT = 'rt_report_2'
RUN = "run_1"
wandb.init(project = PROJECT, name = RUN, config = config)

<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter, or press ctrl+c to quit:

 ··········


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


# Load data

In [4]:
# load small train dataset
TRAIN_DATAPATH = 'https://raw.githubusercontent.com/goldjunge3010/masterpraktikum/main/third_pool_tresh_1_0_train.csv'
#TRAIN_DATAPATH = 'https://raw.githubusercontent.com/wilhelm-lab/dlomix/develop/example_dataset/proteomTools_train_val.csv'

# create dataset
# Sequence length, batch size and validation ratio come from the config dict;
# the two *_col arguments name the CSV columns holding the peptide sequence
# and the regression target.
rtdata = RetentionTimeDataset(data_source=TRAIN_DATAPATH,
                              seq_length = config["seq_length"],
                              batch_size = config["batch_size"],
                              val_ratio = config["val_ratio"],
                              test = False,
                              sequence_col = "modified_sequence",
                              target_col = "indexed_retention_time")

# NOTE(review): the counts below are len(...) * batch_size. Presumably len()
# is the number of batches, in which case the product over-counts whenever the
# last batch is only partially filled -- TODO confirm.
print(f"Batch size: {rtdata.batch_size}")
print(f"Number training samples : {len(rtdata.train_data) * rtdata.batch_size}")
print(f"Number validation samples : {len(rtdata.val_data) * rtdata.batch_size}")

Batch size: 32
Number training samples : 1440
Number validation samples : 384


# Initialize Report

In [7]:
# Create a report
# NOTE: the Report class is defined in the "Report class" cell further down in
# this notebook; that cell has to be executed before this one, otherwise this
# line fails with a NameError.
# The project is taken from PROJECT (the same constant that is passed to
# wandb.init) so the report always reads the project the run was logged to.
report = Report(project = PROJECT,
                title = "Comparison of learning rates with custom plots :)",
                description = "A quick comparison of the influence of  learning rates using the ADAM optimizer",
                dataset = rtdata)



# Log data to WANDB

In [8]:
# Log the sequence-length and retention-time histogram tables of the dataset
# via wandb.log. This cell runs before the training cell, which ends the run
# with wandb.finish().
report.log_data()

log_table_func: counts_hist_table


# Train model

In [9]:
# create Prosit retention time predictor
model = PrositRetentionTimePredictor(seq_length = config["seq_length"])

# create the optimizer object
# Adam is constructed directly here; config["optimizer"] is not consulted.
optimizer = tf.keras.optimizers.Adam(learning_rate = config["lr"])

# compile the model with the optimizer and the metrics we want to use; we can add our custom timedelta metric
model.compile(optimizer = optimizer,
              loss = config["loss"],
              metrics=['mean_absolute_error', TimeDeltaMetric()])

# train the model
# The number of epochs (3) is hard-coded and not part of the config dict
# passed to wandb.init. WandbMetricsLogger is set to log at batch frequency;
# the run summary printed below this cell lists both batch/* and epoch/* metrics.
history = model.fit(rtdata.train_data,
                    validation_data=rtdata.val_data,
                    epochs=3, callbacks=[WandbMetricsLogger(log_freq = "batch")] )

# finish wandb run
wandb.finish()

Epoch 1/3
Epoch 2/3
Epoch 3/3


VBox(children=(Label(value='0.099 MB of 0.099 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
batch/batch_step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
batch/learning_rate,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
batch/loss,▂▁▁▁▁▁▁▁▂▂▂▂▂▆▇▇▄▄▃▃▃▃▃▃▃▃▂██▅▄▃▃▃▃▃▃▃▃▃
batch/mean_absolute_error,▂▁▁▁▁▁▁▁▂▂▂▂▂▇█▇▅▄▃▃▃▃▃▃▃▃▃██▅▄▃▃▃▃▃▃▃▃▃
batch/timedelta,▅▄▃▂▂▁▁▁▁▂▂▃▃▂▅▇▄▄▃▃▃▃▃▃▃▃▃▄█▄▃▃▃▂▂▂▃▃▃▃
epoch/epoch,▁▅█
epoch/learning_rate,▁▁▁
epoch/loss,▁▇█
epoch/mean_absolute_error,▁▇█
epoch/timedelta,▁██

0,1
batch/batch_step,134.0
batch/learning_rate,0.001
batch/loss,1446.35828
batch/mean_absolute_error,31.44623
batch/timedelta,31.51466
epoch/epoch,2.0
epoch/learning_rate,0.001
epoch/loss,1446.35828
epoch/mean_absolute_error,31.44623
epoch/timedelta,31.51466


# Create Report

In [10]:
# Build and save the W&B report. add_config_section and add_train_val_section
# are not passed here and therefore keep their default value of True.
report.create_report(add_data_section = True, add_train_section = True, add_val_section = True)

In [None]:
# Scratch check: after sorting, the validation and training metric names line
# up index by index, so zipping the two lists pairs each validation metric
# with its training counterpart.
v = sorted(['epoch/val_timedelta', 'epoch/val_loss', 'epoch/val_mean_absolute_error'])
t = sorted(['epoch/loss', 'epoch/mean_absolute_error', 'epoch/timedelta'])
print(v)
print(t)
z = [*zip(v, t)]

['epoch/val_loss', 'epoch/val_mean_absolute_error', 'epoch/val_timedelta']
['epoch/loss', 'epoch/mean_absolute_error', 'epoch/timedelta']


In [None]:
# Print every (validation, training) pair from z as a two-element list.
for pair in z:
  print([*pair])

['epoch/val_loss', 'epoch/loss']
['epoch/val_mean_absolute_error', 'epoch/mean_absolute_error']
['epoch/val_timedelta', 'epoch/timedelta']


# Report class (run this cell before the "Initialize Report" cell)

In [6]:
class Report():
  """Assemble a Weights & Biases report for one project from logged run data.

  Usage in this notebook: create the instance, call ``log_data()`` to log the
  dataset histograms with ``wandb.log``, train the model, then call
  ``create_report()`` to build and save the report.

  Attributes:
    entity: default W&B entity of the logged-in user.
    project: name of the W&B project the report reads its runs from.
    title: report title.
    description: report description.
    api: W&B public API client used to query runs.
    dataset: dataset whose ``data_source``, ``sequence_col`` and ``target_col``
      are read by ``log_data()`` and ``data_section()``.
    table_key_len: table key of the logged sequence-length histogram
      ("" until ``log_data()`` has run).
    table_key_rt: table key of the logged retention-time histogram
      ("" until ``log_data()`` has run).
  """

  # Placeholder body text shown in the report sections until real
  # descriptions are written (see the TODO list below the class).
  _PLACEHOLDER_TEXT = "Lorem ipsum dolor sit amet. Aut laborum perspiciatis sit odit omnis aut aliquam voluptatibus ut rerum molestiae sed assumenda nulla ut minus illo sit sunt explicabo? Sed quia architecto est voluptatem magni sit molestiae dolores. Non animi repellendus ea enim internos et iste itaque quo labore mollitia aut omnis totam."

  # Matches one "[UNIMOD:...]" modification tag. The quantifier is non-greedy
  # so that two tags in the same sequence are matched separately; a greedy
  # ".*" would also swallow the residues between the first and the last tag.
  _UNIMOD_PATTERN = re.compile(r"\[UNIMOD:.*?\]", re.IGNORECASE)

  # Epoch-level entries that are bookkeeping rather than training metrics.
  _EPOCH_BOOKKEEPING = ('epoch/learning_rate', 'epoch/epoch')

  def __init__(self, project:str, title: str, description: str, dataset: "RetentionTimeDataset"):
    """Store the report settings and create the W&B API client.

    Args:
      project: W&B project name.
      title: title of the report.
      description: description of the report.
      dataset: the dataset used for training; see the class docstring for the
        attributes that are read from it.
    """
    # a single client serves both the entity lookup and the later run queries
    self.api = wandb.Api()
    self.entity = self.api.default_entity
    self.project = project
    self.title = title
    self.description = description
    self.dataset = dataset
    self.table_key_len = ""
    self.table_key_rt = ""

  def create_report(self, add_config_section = True, add_data_section = True, add_train_section = True, add_val_section = True, add_train_val_section = True):
    """Build the report from the selected sections and save it to W&B.

    The report always starts with a table of contents; each ``add_*`` flag
    controls whether the corresponding section is appended.
    """
    report = wr.Report(
        project = self.project,
        title = self.title,
        description = self.description
    )

    blocks = [wr.TableOfContents()]
    if add_config_section:
      blocks += self.config_section()
    if add_data_section:
      blocks += self.data_section()
    if add_train_section:
      blocks += self.train_section()
    if add_val_section:
      blocks += self.val_section()
    if add_train_val_section:
      blocks += self.train_val_section()

    report.blocks = blocks
    report.save()

  # get metrics of last run in project or from specified run_id
  def get_metrics(self, run_id = None):
    """Return the logged history of one run.

    Args:
      run_id: id of the run to read; if falsy, the most recently created run
        of the project is used.

    Returns:
      The object returned by ``run.history()`` for the selected run.
    """
    if run_id:
      # run is specified by <entity>/<project>/<run_id>
      run = self.api.run(path = f"{self.entity}/{self.project}/{run_id}")
    else:
      # get metrics of latest run
      # api.runs seems to have a delay
      # The ordering is requested explicitly (newest first) so that runs[0]
      # does not depend on the default ordering of the installed wandb version.
      runs = self.api.runs(path = f"{self.entity}/{self.project}", order = "-created_at")
      run = runs[0]
    return run.history()

  # get metric names split into train/val, train is further split into batch/epoch
  def get_metrics_names(self):
    """Split the logged metric names into batch-train, epoch-train and validation.

    Returns:
      A tuple ``(batch_train, epoch_train, epoch_val)`` of lists of names:
      * batch_train: names not starting with "_" that contain none of
        "val", "epoch", "table" (case-insensitive);
      * epoch_train: names not starting with "_" that contain none of
        "val", "batch", "table", without 'epoch/learning_rate' and 'epoch/epoch';
      * epoch_val: every name containing "val".
    """
    metrics = self.get_metrics()
    # names starting with "_" are skipped for the train lists
    visible = [name for name in metrics if not name.startswith("_")]

    def is_train_metric(name):
      lowered = name.lower()
      return "val" not in lowered and "table" not in lowered

    batch_train_metrics_names = [
        name for name in visible
        if is_train_metric(name) and "epoch" not in name.lower()
    ]
    epoch_train_metrics_names = [
        name for name in visible
        if is_train_metric(name)
        and "batch" not in name.lower()
        and name not in self._EPOCH_BOOKKEEPING
    ]
    # the validation list is taken from all names, as before
    epoch_val_metrics_names = [name for name in metrics if "val" in name.lower()]
    return batch_train_metrics_names, epoch_train_metrics_names, epoch_val_metrics_names

  def get_train_val_metrics_names(self):
    """Pair every epoch-level train metric with its validation counterpart.

    The counterpart of ``<prefix>/<metric>`` is ``<prefix>/val_<metric>``.
    Pairs are matched by name rather than by list position, so a metric that
    exists on only one side is left out instead of shifting all other pairs.

    Returns:
      List of ``(train_name, val_name)`` tuples ordered by train name.
    """
    _, epoch_train_metrics_names, epoch_val_metrics_names = self.get_metrics_names()
    available_val_names = set(epoch_val_metrics_names)
    pairs = []
    for train_name in sorted(epoch_train_metrics_names):
      prefix, separator, metric = train_name.rpartition("/")
      val_name = f"{prefix}{separator}val_{metric}"
      if val_name in available_val_names:
        pairs.append((train_name, val_name))
    return pairs

  def _runsets(self):
    """Return the runset list used by every panel grid: all runs of the project."""
    return [wr.Runset(self.entity, self.project)]

  def _line_plot_grid(self, metric_groups):
    """Return a panel grid with one line plot per group of metric names."""
    return wr.PanelGrid(
      runsets = self._runsets(),
      panels = [wr.LinePlot(x='Step', y=list(group)) for group in metric_groups]
    )

  def config_section(self):
    """Return the blocks of the "Config" section (a run comparer)."""
    config_block = [
        wr.H1(text = "Config"),
        wr.PanelGrid(
          runsets = self._runsets(),
          panels=[
            wr.RunComparer(layout = {'w': 24})
          ],
        ),
        wr.HorizontalRule(),
    ]
    return config_block

  def data_section(self):
    """Return the blocks of the "Data" section.

    The two custom charts read the tables whose keys were stored by
    ``log_data()`` in ``table_key_len`` and ``table_key_rt``.
    """
    data_block = [
        wr.H1(text = "Data"),
        wr.P(self._PLACEHOLDER_TEXT),
        wr.PanelGrid(
          runsets = self._runsets(),
          panels=[
            wr.CustomChart(
              query = {'summaryTable': {"tableKey" : self.table_key_len}},
              chart_name='master_praktikum/hist_pep_len',
              chart_fields={'value': self.dataset.sequence_col}
            ),
            wr.CustomChart(
              query = {'summaryTable': {"tableKey" : self.table_key_rt}},
              chart_name='master_praktikum/hist_ret_time',
              chart_fields={'value': self.dataset.target_col}
            )
          ]
        ),
        wr.HorizontalRule(),
    ]
    return data_block

  def train_section(self):
    """Return the blocks of the "Training metrics" section (per batch and per epoch)."""
    batch_train_metrics_names, epoch_train_metrics_names, _ = self.get_metrics_names()
    train_block = [
        wr.H1(text = "Training metrics"),
        wr.P(self._PLACEHOLDER_TEXT),
        wr.H2(text = "per batch"),
        self._line_plot_grid([name] for name in batch_train_metrics_names),
        wr.H2(text = "per epoch"),
        self._line_plot_grid([name] for name in epoch_train_metrics_names),
        wr.HorizontalRule(),
    ]
    return train_block

  def val_section(self):
    """Return the blocks of the "Validation metrics" section (per epoch)."""
    _, _, epoch_val_metrics_names = self.get_metrics_names()
    val_block = [
        wr.H1(text = "Validation metrics"),
        wr.P(self._PLACEHOLDER_TEXT),
        wr.H2(text = "per epoch"),
        self._line_plot_grid([name] for name in epoch_val_metrics_names),
        wr.HorizontalRule(),
    ]
    return val_block

  def train_val_section(self):
    """Return the blocks of the section plotting each train metric with its validation metric."""
    train_val_metrics_names = self.get_train_val_metrics_names()
    train_val_block = [
        wr.H1(text = "Train - Validation metrics"),
        wr.P(self._PLACEHOLDER_TEXT),
        wr.H2(text = "per epoch"),
        # every pair (train, val) becomes one plot with two lines
        self._line_plot_grid(train_val_metrics_names),
        wr.HorizontalRule(),
    ]
    return train_val_block

  # function to log sequence length table to wandb
  def log_sequence_length_table(self, data: pd.DataFrame, seq_col:str = "modified_sequence"):
    """Log the sequence lengths of ``data[seq_col]`` as a histogram.

    Args:
      data: frame containing the sequence column.
      seq_col: name of the sequence column.

    Returns:
      The string "counts_hist_table", stored by ``log_data()`` as the table
      key for the sequence-length chart.
    """
    name_hist = "counts_hist"
    counts = self.count_seq_length(data, seq_col)
    # convert to df for easier handling
    counts_df = counts.to_frame()
    table = wandb.Table(dataframe = counts_df)
    # log to wandb
    hist = wandb.plot_table(
      vega_spec_name="master_praktikum/hist_pep_len",
      data_table = table,
      fields = {"value" : seq_col}
    )
    wandb.log({name_hist: hist})
    name_hist_table = name_hist + "_table"
    print(f"log_table_func: {name_hist_table}")
    return name_hist_table

  # function to count sequence length
  def count_seq_length(self, data: pd.DataFrame, seq_col: str) -> pd.Series:
      """Return the length of every sequence with "[UNIMOD:...]" tags removed.

      The input frame is left untouched; the tags are stripped on a copy of
      the column.

      Args:
        data: frame containing the sequence column.
        seq_col: name of the sequence column.

      Returns:
        Series (named ``seq_col``) with the number of characters that remain
        in each sequence after removing the tags.
      """
      stripped = data[seq_col].str.replace(self._UNIMOD_PATTERN, "", regex=True)
      return stripped.str.len()

  # function to log retention time table to wandb
  def log_rt_table(self, data: pd.DataFrame, rt_col:str = "indexed_retention_time"):
    """Log the values of ``data[rt_col]`` as a histogram.

    Args:
      data: frame containing the retention-time column.
      rt_col: name of the retention-time column.

    Returns:
      The string "rt_hist_table", stored by ``log_data()`` as the table key
      for the retention-time chart.
    """
    name_hist = "rt_hist"
    rt = data.loc[:,rt_col]
    # convert to df for easier handling
    rt_df = rt.to_frame()
    table = wandb.Table(dataframe = rt_df)
    # log to wandb
    hist = wandb.plot_table(
      vega_spec_name="master_praktikum/hist_ret_time",
      data_table = table,
      fields = {"value" : rt_col}
    )
    wandb.log({name_hist: hist})
    name_hist_table = name_hist + "_table"
    return name_hist_table

  def _read_data_file(self, path: str) -> pd.DataFrame:
    """Read a csv, json or parquet file (chosen by file extension) into a frame.

    Raises:
      ValueError: if the extension is not one of csv, json, parquet.
    """
    readers = {
        "csv": pd.read_csv,
        "json": pd.read_json,
        "parquet": lambda source: pd.read_parquet(source, engine='fastparquet'),
    }
    file_extension = path.split(".")[-1].lower()
    if file_extension not in readers:
      # previously this case fell through and failed later on an unbound variable
      raise ValueError(
          f"Unsupported data source extension '{file_extension}'; "
          f"expected one of: {', '.join(readers)}"
      )
    return readers[file_extension](path)

  def log_data(self):
    """Log histograms of the dataset to W&B and remember their table keys.

    Supported forms of ``self.dataset.data_source``:
      * str: path/URL of a csv, json or parquet file; both histograms are logged;
      * tuple of two ndarrays/lists: (sequences, targets); both histograms are logged;
      * single ndarray/list: sequences only; only the length histogram is logged.
    Any other data source is ignored.

    Raises:
      ValueError: if a string data source has an unsupported file extension.
    """
    source = self.dataset.data_source
    sequence_col = self.dataset.sequence_col
    target_col = self.dataset.target_col

    # check if datasource is a string
    if isinstance(source, str):
      # read corresponding file
      data = self._read_data_file(source)
      self.table_key_len = self.log_sequence_length_table(data, sequence_col)
      self.table_key_rt = self.log_rt_table(data, target_col)

    # check if datasource is a tuple of two ndarrays or two lists
    elif isinstance(source, tuple) and all(isinstance(item, (np.ndarray, list)) for item in source) and len(source) == 2:
      data = pd.DataFrame({sequence_col: source[0], target_col: source[1]})
      self.table_key_len = self.log_sequence_length_table(data, sequence_col)
      self.table_key_rt = self.log_rt_table(data, target_col)

    # check if datasource is a single ndarray or list
    # does not work? maybe error in RetentionTimeDataset
    elif isinstance(source, (np.ndarray, list)):
      data = pd.DataFrame({sequence_col: source})
      self.table_key_len = self.log_sequence_length_table(data, sequence_col)




#TODO:
# - add custom plots for data - done!
# - add custom plot for residuals - needs to be done in a different notebook - comparing models
# - log data only once after initialising the wandb run
#   -> Workflow: load data -> initialize wandb run -> log_data(yes/no) -> train model -> create report - done! not sure if ideal, but not bad at least
# - increase robustness
# - implement possibility to delete certain plots programmatically
# - pass the RetentionTimeDataset to the report - done - a single ndarray does not work, though
# - train and val in one plot
# - add text in sections

In [None]:
# Final wandb.finish() call of the notebook; the training cell above already
# calls wandb.finish() once after model.fit.
wandb.finish()

VBox(children=(Label(value='0.002 MB of 0.013 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.139411…