# SUBMISSION NOTEBOOK

## SETUP

#### imports 

In [3]:
import pandas as pd
import os
import datetime
import subprocess
import re
from omegaconf import OmegaConf
import pathlib
from tqdm.auto import tqdm
import itertools

#### CONFIG:

In [4]:
# USER = 'YOURUSERNAME'
# BASE = f'/home/{USER}/RetinalRisk'
CODE_BASE = '/sc-projects/sc-proj-ukb-cvd/code/RetinalRisk'
SUBMISSION_BASE = '/sc-projects/sc-proj-ukb-cvd/submissions/RetinalRisk'

TAG = 230905
JOBNAME = f'fullrun_retina'

EXPERIMENT_NAME = f'22_retinalrisk_{TAG}_{JOBNAME}'   # name under which to store the generated .sh scripts and yamls
TEMPLATE_CONFIG = f'{CODE_BASE}/config/'   # template yaml to use
TRAIN_SCRIPT = f'{CODE_BASE}/retinalrisk/scripts/train_retina.py'     # python train script to use

# os.makedirs(f'/home/{USER}/tmp/{EXPERIMENT_NAME}/job_submissions', exist_ok=True)
# os.makedirs(f'/home/{USER}/tmp/{EXPERIMENT_NAME}/job_configs', exist_ok=True)

os.makedirs(f'{SUBMISSION_BASE}/{EXPERIMENT_NAME}/job_submissions', exist_ok=True)
os.makedirs(f'{SUBMISSION_BASE}/{EXPERIMENT_NAME}/job_configs', exist_ok=True)
os.makedirs(f'{SUBMISSION_BASE}/{EXPERIMENT_NAME}/job_outputs', exist_ok=True)

In [10]:
BASE_HYPERPARAMS = [
 f'setup.name={TAG}_{JOBNAME}',
    "training.gradient_checkpointing=False",
    "training.patience=40",
    "datamodule/covariates=no_covariates",
    "model=image",
    "setup.use_data_artifact_if_available=False",
    "head=mlp",
    "head.kwargs.num_hidden=512",
    "head.kwargs.num_layers=2",
    "head.dropout=0.5",
    "training.optimizer_kwargs.weight_decay=0.001",
    "training.optimizer_kwargs.lr=0.0001",
    "model.freeze_encoder=False",
    "model.encoder=convnext_small",
    "datamodule.batch_size=256",
    "training.warmup_period=8",
    "datamodule/augmentation=contrast_sharpness_posterize",
    "datamodule.img_size_to_gpu=420",
    "datamodule.num_workers=16",
    "model.pretrained=True",
 ]

RETAGESEX_HYPERPARAMS = [
 f'setup.name={TAG}_RetAgeSex',
    'training.gradient_checkpointing=False', 
    'training.patience=40', 
    'datamodule/covariates=agesex', 
    'model=image', 
    'setup.use_data_artifact_if_available=False', 
    'head=mlp', 
    'head.kwargs.num_hidden=512', 
    'head.kwargs.num_layers=2', 
    'head.dropout=0.5', 
    'training.optimizer_kwargs.weight_decay=0.001', 
    'training.optimizer_kwargs.lr=0.0001', 
    'model.freeze_encoder=False', 
    'model.encoder=convnext_small', 
    'datamodule.batch_size=256', 
    'training.warmup_period=8', 
    'datamodule/augmentation=contrast_sharpness_posterize', 
    'datamodule.img_size_to_gpu=420',
    'datamodule.num_workers=16',
    'model.pretrained=True',
 ]

In [11]:
parameters = {
    'datamodule.partition': [0, 4, 5, 7, 9, 10, 20], # Partitions with eye test centers
    # partition 0 should have no samples and should fail, included as a sanity check
    #'datamodule.partition': [i for i in range(0, 5)], # CHRISTINA
    #'datamodule.partition': [i for i in range(5, 10)], # PAUL 
    #'datamodule.partition': [i for i in range(10, 16)], # THORE 
    #'datamodule.partition': [i for i in range(16, 22)], # LUKAS 
}

parameters_retagesex = {
    'datamodule.partition': [20], # only with best partition
}

#### Functions

In [12]:
def make_job_script(job_name, base_params, hyperparams):
    
    params_str = ' '.join(base_params + hyperparams)

    job_script_str = f'''#!/bin/bash
#SBATCH --job-name={job_name}                # Specify job name
#SBATCH --partition=gpu                     # Specify partition name
#SBATCH --nodes=1-1                          # Specify number of nodes
#SBATCH --cpus-per-gpu=62
#SBATCH --mem=400GB                          # Use entire memory of node
#SBATCH --gres=gpu:nvidia_a100_80gb_pcie:1   # Generic resources; 1 80GB GPU
#SBATCH --time=48:00:00                      # Set a limit on the total run time
#SBATCH --error={SUBMISSION_BASE}/{EXPERIMENT_NAME}/job_outputs/slurm-%A_%a.err
#SBATCH --output={SUBMISSION_BASE}/{EXPERIMENT_NAME}/job_outputs/slurm-%A_%a.out


source ~/miniconda3/etc/profile.d/conda.sh
conda activate /sc-projects/sc-proj-ukb-cvd/environments/retina

python {TRAIN_SCRIPT} --config-path {TEMPLATE_CONFIG} ''' + params_str
    
    return job_script_str

In [13]:
def submit(path, job_name, job_script, time_stamp=None):
    if not time_stamp:
        time_stamp = datetime.datetime.now().strftime("%Y-%m-%d_%H:%M:%S")
        
    script_path_long = f'{path}/{job_name}_{time_stamp}.sh'

    with open(script_path_long, 'w') as outfile: 
        outfile.write(job_script)
    script_path = f'{path}/{job_name}.sh'
    try:
        os.unlink(script_path)
    except FileNotFoundError: # because we cannot overwrite symlinks directly
        pass
    os.symlink(os.path.realpath(script_path_long), script_path)

    print('\n\nSubmission:\n===========\n')
    sub_cmd = f'sbatch < {script_path}'
    print(sub_cmd)
    
    ret = subprocess.run(sub_cmd, shell=True, cwd=os.getcwd(), stdout=subprocess.PIPE, stderr=subprocess.PIPE)
#     print(ret.stdout.decode())

## RUN RETINA + AGE + SEX Training

In [14]:
jobids = []

In [15]:
for i, hp_vals in enumerate(itertools.product(*parameters_retagesex.values(), repeat=1)):
    hyperparams = [f"{p}={v}" for p, v in zip(parameters.keys(), hp_vals)]
    job_script = make_job_script(#user=USER,
                                 job_name=JOBNAME,
                                 base_params=RETAGESEX_HYPERPARAMS,
                                 hyperparams=hyperparams)
    print(job_script)

    # jobid = submit(path=f"/home/{USER}/tmp/{EXPERIMENT_NAME}/job_submissions",
    jobid = submit(path=f"{SUBMISSION_BASE}/{EXPERIMENT_NAME}/job_submissions",
                   job_name=JOBNAME+f'_{i}',
                   job_script=job_script)

    jobids.append(jobid)

#!/bin/bash
#SBATCH --job-name=fullrun_retina                # Specify job name
#SBATCH --partition=gpu                     # Specify partition name
#SBATCH --nodes=1-1                          # Specify number of nodes
#SBATCH --cpus-per-gpu=62
#SBATCH --mem=400GB                          # Use entire memory of node
#SBATCH --gres=gpu:nvidia_a100_80gb_pcie:1   # Generic resources; 1 80GB GPU
#SBATCH --time=48:00:00                      # Set a limit on the total run time
#SBATCH --error=/sc-projects/sc-proj-ukb-cvd/submissions/RetinalRisk/22_retinalrisk_230905_fullrun_retina/job_outputs/slurm-%A_%a.err
#SBATCH --output=/sc-projects/sc-proj-ukb-cvd/submissions/RetinalRisk/22_retinalrisk_230905_fullrun_retina/job_outputs/slurm-%A_%a.out


source ~/miniconda3/etc/profile.d/conda.sh
conda activate /sc-projects/sc-proj-ukb-cvd/environments/retina

python /sc-projects/sc-proj-ukb-cvd/code/RetinalRisk/retinalrisk/scripts/train_retina.py --config-path /sc-projects/sc-proj-ukb-cvd/code/Retinal

## RUN RETINA TRAINING

In [7]:
jobids = []

In [8]:
for i, hp_vals in enumerate(itertools.product(*parameters.values(), repeat=1)):
    hyperparams = [f"{p}={v}" for p, v in zip(parameters.keys(), hp_vals)]
    job_script = make_job_script(#user=USER,
                                 job_name=JOBNAME,
                                 base_params=BASE_HYPERPARAMS,
                                 hyperparams=hyperparams)
    print(job_script)

    # jobid = submit(path=f"/home/{USER}/tmp/{EXPERIMENT_NAME}/job_submissions",
    jobid = submit(path=f"{SUBMISSION_BASE}/{EXPERIMENT_NAME}/job_submissions",
                   job_name=JOBNAME+f'_{i}',
                   job_script=job_script)

    jobids.append(jobid)

#!/bin/bash
#SBATCH --job-name=fullrun_retina                # Specify job name
#SBATCH --partition=gpu                     # Specify partition name
#SBATCH --nodes=1-1                          # Specify number of nodes
#SBATCH --cpus-per-gpu=62
#SBATCH --mem=400GB                          # Use entire memory of node
#SBATCH --gres=gpu:nvidia_a100_80gb_pcie:1   # Generic resources; 1 80GB GPU
#SBATCH --time=50:00:00                      # Set a limit on the total run time
#SBATCH --error=/sc-projects/sc-proj-ukb-cvd/submissions/RetinalRisk/22_retinalrisk_230905_fullrun_retina/job_outputs/slurm-%A_%a.err
#SBATCH --output=/sc-projects/sc-proj-ukb-cvd/submissions/RetinalRisk/22_retinalrisk_230905_fullrun_retina/job_outputs/slurm-%A_%a.out


source ~/miniconda3/etc/profile.d/conda.sh
conda activate /sc-projects/sc-proj-ukb-cvd/environments/retina

python /sc-projects/sc-proj-ukb-cvd/code/RetinalRisk/retinalrisk/scripts/train_retina.py --config-path /sc-projects/sc-proj-ukb-cvd/code/Retinal

In [10]:
print(jobids)

[None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None]


# AUSFÜHRUNG BIS HIER REICHT, DANKE!

In [11]:
@@ halt.

SyntaxError: invalid syntax (244539343.py, line 1)

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import hydra
import numpy as np
import torch
from hydra import compose, initialize
from omegaconf import OmegaConf
import pandas as pd
import wandb

from tqdm.auto import tqdm

In [None]:
from ehrgraphs.data.datamodules import EHRGraphDataModule
from ehrgraphs.training import setup_training

In [None]:
api = wandb.Api()
runs = api.runs(path="cardiors/RecordGraphs", filters={"display_name": "220420_t0_ablation"})

In [None]:
for r in runs:
    print()

In [None]:
# %%
run_df = pd.DataFrame(
    [
        dict(
            run_id=r.id,
            buffer_years=eval(r.config["_content"]["datamodule"])["t0_mode"],
            val_mean_cindex=r.summary["valid/mean_CIndex_max"],
        )
        for r in runs if r.state == 'finished'
    ]
)

In [None]:
run_df

In [None]:
tmp = run_df.copy()
tmp = tmp.sort_values('val_mean_cindex', ascending=False)

In [None]:
from plotnine import *
%matplotlib inline

In [None]:
order = tmp['buffer_years'].values.tolist()

In [None]:
tmp['cat'] = pd.Categorical(tmp['buffer_years'], categories=order)

In [None]:
tmp.head()

In [None]:
(ggplot() 
 + geom_point(
     tmp,
     aes(x='val_mean_cindex', y='cat',
         fill='val_mean_cindex',
         color='val_mean_cindex'
        ),
 )
 + theme(figure_size=(5, 5))
#  + scale_fill_brewer(type='qual', palette=3)
#  + scale_color_brewer(type='qual', palette=3)
 + theme_classic()
)