In [None]:
import os
import json
import numpy as np
import pandas as pd

In [None]:
# Cluster locations of input data, pretrained resources, and output roots.
data_dir = "/clusterfs/nilah/ruchir/src/finetuning-enformer/finetuning/data/h5_bins_384_chrom_split/"
enformer_data_dir = "/global/scratch/users/aniketh/enformer_data/"
root_save_dir = "/global/scratch/users/aniketh/finetune-enformer/"
models_dir = "/global/scratch/users/aniketh/finetune-enformer/saved_models/"
test_preds_dir = "/global/scratch/users/aniketh/finetune-enformer/test_preds_final/"
rest_unseen_preds_dir = (
    "/global/scratch/users/aniketh/finetune-enformer/rest_unseen_preds_final/"
)
ISM_preds_dir = "/global/scratch/users/aniketh/finetune-enformer/ISM/"
code_dir = "/global/home/users/aniketh/finetuning-enformer/"
fasta_path = "/clusterfs/nilah/aniketh/hg19/hg19.fa"
malinois_data_path = (
    "/clusterfs/nilah/aniketh/Malinois/all_sequences_variant_effect_data.csv"
)

# Per-split HDF5 datasets inside data_dir.
train_h5_path = os.path.join(data_dir, "train.h5")
val_h5_path = os.path.join(data_dir, "val.h5")
test_h5_path = os.path.join(data_dir, "test.h5")
rest_unseen_h5_path = os.path.join(data_dir, "rest_unseen.h5")

# Inputs for the ISM jobs (passed to compute_ISM_scores.py further below).
counts_path = os.path.join(
    code_dir, "process_geuvadis_data", "log_tpm", "corrected_log_tpm.annot.csv.gz"
)
gene_class_path = os.path.join(
    code_dir, "finetuning", "data", "h5_bins_384_chrom_split", "gene_class.csv"
)
# Fail fast, and say exactly which input is missing.
assert os.path.exists(counts_path), f"counts file not found: {counts_path}"
assert os.path.exists(gene_class_path), f"gene class file not found: {gene_class_path}"

# Every submitted job script = contents of slurm_template + one command line,
# written to temp_script_path and then passed to sbatch.
slurm_template = "slurm_template.sh"
temp_script_path = "temp_script.sh"

In [None]:
# Training command template for every experiment, keyed by run alias. The
# alias doubles as the --model_type passed to the test/ISM scripts (minus any
# "_random_init" suffix). "baseline" has no training step, so its template is
# empty. Placeholders in braces are filled in by the submission cells below.
# Commented-out entries are experiments that are currently disabled.
all_main_run_names = {
    "baseline": "",
    "classification": "NCCL_P2P_DISABLE=1 python finetuning/train_pairwise_classification_parallel_h5_dataset_dynamic_sampling_dataset.py {train_h5_path} {val_h5_path} {run_name} {models_dir} --batch_size 1 --lr 0.0001 --weight_decay 0.001 --data_seed {data_seed} --resume_from_checkpoint",
    "regression": "NCCL_P2P_DISABLE=1 python finetuning/train_pairwise_regression_parallel_h5_dataset.py {train_h5_path} {val_h5_path} {run_name} {models_dir} --batch_size 1 --lr 0.0001 --weight_decay 0.001 --use_scheduler --warmup_steps 1000 --data_seed {data_seed} --resume_from_checkpoint",
    "single_regression_counts": "NCCL_P2P_DISABLE=1 python finetuning/train_single_counts_parallel_h5_dataset.py {train_h5_path} {val_h5_path} {run_name} {models_dir} --batch_size 2 --lr 0.0001 --weight_decay 0.001 --use_scheduler --warmup_steps 1000 --data_seed {data_seed} --resume_from_checkpoint",
    # Disabled: does not perform well, not used.
    #     "single_regression": "NCCL_P2P_DISABLE=1 python finetuning/train_single_parallel_h5_dataset.py {train_h5_path} {val_h5_path} {run_name} {models_dir} --batch_size 2 --lr 0.0001 --weight_decay 0.001 --use_scheduler --warmup_steps 1000 --data_seed {data_seed} --resume_from_checkpoint",
    "joint_classification": "NCCL_P2P_DISABLE=1 python finetuning/train_pairwise_classification_with_enformer_data_parallel_h5_dynamic_sampling_dataset.py {train_h5_path} {val_h5_path} {enformer_data_dir} {run_name} {models_dir} --batch_size 1 --lr 0.0005 --weight_decay 0.005 --use_scheduler --warmup_steps 1000 --data_seed {data_seed} --resume_from_checkpoint",
    "joint_regression": "NCCL_P2P_DISABLE=1 python finetuning/train_pairwise_regression_with_enformer_data_parallel_h5_dynamic_sampling_dataset.py {train_h5_path} {val_h5_path} {enformer_data_dir} {run_name} {models_dir} --batch_size 1 --lr 0.0005 --weight_decay 0.005 --use_scheduler --warmup_steps 1000 --data_seed {data_seed} --resume_from_checkpoint",
    "joint_regression_with_Malinois_MPRA": "NCCL_P2P_DISABLE=1 python finetuning/train_pairwise_regression_with_MPRA_data_parallel_h5_dynamic_sampling_dataset.py {train_h5_path} {val_h5_path} {malinois_data_path} {run_name} {models_dir} --batch_size 1 --lr 0.0001 --weight_decay 0.001 --use_scheduler --warmup_steps 1000 --data_seed {data_seed} --resume_from_checkpoint",
    "single_regression_counts_random_init": "NCCL_P2P_DISABLE=1 python finetuning/train_single_counts_parallel_h5_dataset.py {train_h5_path} {val_h5_path} {run_name} {models_dir} --batch_size 2 --lr 0.0001 --weight_decay 0.001 --use_scheduler --warmup_steps 1000 --data_seed {data_seed} --resume_from_checkpoint --use_random_init",
    "regression_random_init": "NCCL_P2P_DISABLE=1 python finetuning/train_pairwise_regression_parallel_h5_dataset.py {train_h5_path} {val_h5_path} {run_name} {models_dir} --batch_size 1 --lr 0.0001 --weight_decay 0.001 --use_scheduler --warmup_steps 1000 --data_seed {data_seed} --resume_from_checkpoint --use_random_init",
    "classification_random_init": "NCCL_P2P_DISABLE=1 python finetuning/train_pairwise_classification_parallel_h5_dataset_dynamic_sampling_dataset.py {train_h5_path} {val_h5_path} {run_name} {models_dir} --batch_size 1 --lr 0.0001 --weight_decay 0.001 --data_seed {data_seed} --resume_from_checkpoint --use_random_init",
    # "joint_classification_random_init": "NCCL_P2P_DISABLE=1 python finetuning/train_pairwise_classification_with_enformer_data_parallel_h5_dynamic_sampling_dataset.py {train_h5_path} {val_h5_path} {enformer_data_dir} {run_name} {models_dir} --batch_size 1 --lr 0.0005 --weight_decay 0.005 --use_scheduler --warmup_steps 1000 --data_seed {data_seed} --resume_from_checkpoint --use_random_init",
    # "joint_regression_random_init": "NCCL_P2P_DISABLE=1 python finetuning/train_pairwise_regression_with_enformer_data_parallel_h5_dynamic_sampling_dataset.py {train_h5_path} {val_h5_path} {enformer_data_dir} {run_name} {models_dir} --batch_size 1 --lr 0.0005 --weight_decay 0.005 --use_scheduler --warmup_steps 1000 --data_seed {data_seed} --resume_from_checkpoint --use_random_init",
    # "joint_regression_with_Malinois_MPRA_random_init": "NCCL_P2P_DISABLE=1 python finetuning/train_pairwise_regression_with_MPRA_data_parallel_h5_dynamic_sampling_dataset.py {train_h5_path} {val_h5_path} {malinois_data_path} {run_name} {models_dir} --batch_size 1 --lr 0.0001 --weight_decay 0.001 --use_scheduler --warmup_steps 1000 --data_seed {data_seed} --resume_from_checkpoint --use_random_init",
}

# Data seeds: each trained/evaluated configuration is repeated once per seed.
all_seeds = [42, 97, 7]
# Training-set fractions for the subsampled regression runs.
subsample_fracs = [0.2, 0.4, 0.6, 0.8]

# Thresholds passed as --rare_variant_af_threshold: a fine grid at low AFs,
# a coarser grid in the middle, and a fine grid in the low 0.4s.
# NOTE(review): np.arange with a float step may include or drop the nominal
# stop value because of rounding — the printed list is the source of truth.
all_afs = [
    af
    for start, stop, step in (
        (0.01, 0.1, 0.01),
        (0.1, 0.4, 0.05),
        (0.41, 0.49, 0.01),
    )
    for af in np.arange(start, stop, step).round(2)
]
print(all_afs)
# Cache file handed to test_models.py via --afs_cache_path.
afs_cache_path = os.path.join(root_save_dir, "train.h5.afs.pkl")

In [None]:
# Create the prediction output roots up front. exist_ok=True already makes
# this a no-op when the directory exists, so no separate existence check is needed.
os.makedirs(test_preds_dir, exist_ok=True)
os.makedirs(rest_unseen_preds_dir, exist_ok=True)

# Evaluation command: run a model (or the untrained "baseline") on an HDF5
# split and write predictions into {cur_test_preds_dir}.
main_cmd = "NCCL_P2P_DISABLE=1 python finetuning/test_models.py {test_data_path} {cur_test_preds_dir} {model_type} {cur_checkpoints_dir} --use_reverse_complement --create_best_ckpt_copy"
# In-silico mutagenesis command: one job per (model, gene).
ISM_cmd = "NCCL_P2P_DISABLE=1 python finetuning/compute_ISM_scores.py {counts_path} {gene_class_path} {fasta_path} {predictions_dir} {model_type} {cur_checkpoints_dir} --use_reverse_complement --gene_name {gene_name}"

In [None]:
# MAIN TRAIN RUNS
# Submit one SLURM training job per (run, data seed). Each job script is the
# SLURM template followed by the fully substituted training command.
# NOTE: only the "*_random_init" runs are currently (re)submitted; remove the
# random_init filter below to launch the pretrained-init runs as well.
with open(slurm_template, "r") as template_file:
    slurm_header = template_file.read()

for run, cmd_template in all_main_run_names.items():
    # "baseline" has no training command (its template is empty).
    if run == "baseline":
        continue
    if "random_init" not in run:
        continue
    for seed in all_seeds:
        run_name = run
        cmd = (
            cmd_template.replace("{run_name}", run_name)
            .replace("{data_seed}", str(seed))
            .replace("{train_h5_path}", train_h5_path)
            .replace("{val_h5_path}", val_h5_path)
            .replace("{enformer_data_dir}", enformer_data_dir)
            .replace("{malinois_data_path}", malinois_data_path)
            .replace("{models_dir}", models_dir)
        )

        print(cmd)

        # Closed before sbatch reads it.
        with open(temp_script_path, "w+") as temp_script:
            temp_script.write(slurm_header)
            temp_script.write("\n")
            temp_script.write(cmd)
            temp_script.write("\n")

        os.system(f"sbatch --requeue {temp_script_path}")

In [None]:
# SUBSAMPLED TRAIN SET RUNS FOR REGRESSION
# Retrain the regression model on a fraction of the training set, once per
# (data seed, subsample fraction). The fraction is appended to the training
# command as --train_set_subsample_ratio.
with open(slurm_template, "r") as template_file:
    slurm_header = template_file.read()

for run in ["regression"]:
    cmd_template = all_main_run_names[run]
    for seed in all_seeds:
        for frac in subsample_fracs:
            run_name = run
            # Same placeholder set as the main training cell, so any run alias
            # can be listed above without leaving unfilled placeholders.
            cmd = (
                cmd_template.replace("{run_name}", run_name)
                .replace("{data_seed}", str(seed))
                .replace("{train_h5_path}", train_h5_path)
                .replace("{val_h5_path}", val_h5_path)
                .replace("{enformer_data_dir}", enformer_data_dir)
                .replace("{malinois_data_path}", malinois_data_path)
                .replace("{models_dir}", models_dir)
            )
            cmd = cmd + f" --train_set_subsample_ratio {frac}"

            print(cmd)

            with open(temp_script_path, "w+") as temp_script:
                temp_script.write(slurm_header)
                temp_script.write("\n")
                temp_script.write(cmd)
                temp_script.write("\n")

            os.system(f"sbatch --requeue {temp_script_path}")

In [None]:
# MAIN TEST RUNS
# Evaluate every main run on the test split. "baseline" needs no checkpoint;
# every other run is evaluated once per data seed from the checkpoints under
# models_dir. Runs whose test_preds.npz already exists are skipped.
# Non-baseline submissions are disabled by default (dry run: the command is
# printed and the job script written, but nothing is submitted). Set this to
# True to actually submit them.
submit_main_test_jobs = False

with open(slurm_template, "r") as template_file:
    slurm_header = template_file.read()

for run in all_main_run_names:
    if run == "baseline":
        model_name = "baseline"
        cmd = main_cmd.replace("{test_data_path}", test_h5_path)
        cmd = cmd.replace(
            "{cur_test_preds_dir}", os.path.join(test_preds_dir, model_name)
        )
        cmd = cmd.replace("{model_type}", run)  # run aliases are same as model_type
        cmd = cmd.replace("{cur_checkpoints_dir}", "dummy")

        if os.path.exists(os.path.join(test_preds_dir, model_name, "test_preds.npz")):
            print(f"{model_name} predictions done")
            continue

        print(cmd)

        with open(temp_script_path, "w+") as temp_script:
            temp_script.write(slurm_header)
            temp_script.write("\n")
            temp_script.write(cmd)
            temp_script.write("\n")

        os.system(f"sbatch --requeue {temp_script_path}")
    else:
        # Recover the hyperparameters from the training command so the model
        # directory name can be rebuilt (presumably matching the name the
        # training script created — verify if that naming changes).
        train_cmd_template = all_main_run_names[run]
        lr_used_during_training = train_cmd_template.split("--lr ")[-1].split(" ")[0]
        wd_used_during_training = train_cmd_template.split("--weight_decay ")[
            -1
        ].split(" ")[0]
        rcprob_used_during_training = 0.5
        rsmax_used_during_training = 3

        # run aliases are the model_type, minus any "_random_init" suffix
        if "random_init" in run:
            model_type = run[: -len("_random_init")]
        else:
            model_type = run

        for seed in all_seeds:
            model_name = f"{run}_data_seed_{seed}_lr_{lr_used_during_training}_wd_{wd_used_during_training}_rcprob_{rcprob_used_during_training}_rsmax_{rsmax_used_during_training}"
            if "random_init" in run:
                model_name = model_name + "_random_init"

            checkpoints_dir = os.path.join(models_dir, model_name, "checkpoints")
            cmd = main_cmd.replace("{test_data_path}", test_h5_path)
            cmd = cmd.replace(
                "{cur_test_preds_dir}", os.path.join(test_preds_dir, model_name)
            )
            cmd = cmd.replace("{model_type}", model_type)
            cmd = cmd.replace("{cur_checkpoints_dir}", checkpoints_dir)
            assert os.path.exists(checkpoints_dir), f"missing checkpoints: {checkpoints_dir}"

            if os.path.exists(
                os.path.join(test_preds_dir, model_name, "test_preds.npz")
            ):
                # Report the full model name so individual seeds are distinguishable.
                print(f"{model_name} predictions done")
                continue

            print(cmd)

            with open(temp_script_path, "w+") as temp_script:
                temp_script.write(slurm_header)
                temp_script.write("\n")
                temp_script.write(cmd)
                temp_script.write("\n")

            if submit_main_test_jobs:
                os.system(f"sbatch --requeue {temp_script_path}")

In [None]:
# REST UNSEEN RUNS
# Same as the main test runs, but evaluated on the rest_unseen split, with
# runs and seeds iterated in reverse order. Runs whose test_preds.npz already
# exists under rest_unseen_preds_dir are skipped.
# Non-baseline submissions are disabled by default (dry run: the command is
# printed and the job script written, but nothing is submitted). Set this to
# True to actually submit them.
submit_rest_unseen_jobs = False

with open(slurm_template, "r") as template_file:
    slurm_header = template_file.read()

for run in reversed(all_main_run_names):
    if run == "baseline":
        model_name = "baseline"
        cmd = main_cmd.replace("{test_data_path}", rest_unseen_h5_path)
        cmd = cmd.replace(
            "{cur_test_preds_dir}", os.path.join(rest_unseen_preds_dir, model_name)
        )
        cmd = cmd.replace("{model_type}", run)  # run aliases are same as model_type
        cmd = cmd.replace("{cur_checkpoints_dir}", "dummy")

        if os.path.exists(
            os.path.join(rest_unseen_preds_dir, model_name, "test_preds.npz")
        ):
            print(f"{model_name} predictions done")
            continue

        print(cmd)

        with open(temp_script_path, "w+") as temp_script:
            temp_script.write(slurm_header)
            temp_script.write("\n")
            temp_script.write(cmd)
            temp_script.write("\n")

        os.system(f"sbatch --requeue {temp_script_path}")
    else:
        # Recover the hyperparameters from the training command so the model
        # directory name can be rebuilt.
        train_cmd_template = all_main_run_names[run]
        lr_used_during_training = train_cmd_template.split("--lr ")[-1].split(" ")[0]
        wd_used_during_training = train_cmd_template.split("--weight_decay ")[
            -1
        ].split(" ")[0]
        rcprob_used_during_training = 0.5
        rsmax_used_during_training = 3

        # run aliases are the model_type, minus any "_random_init" suffix
        if "random_init" in run:
            model_type = run[: -len("_random_init")]
        else:
            model_type = run

        for seed in reversed(all_seeds):
            model_name = f"{run}_data_seed_{seed}_lr_{lr_used_during_training}_wd_{wd_used_during_training}_rcprob_{rcprob_used_during_training}_rsmax_{rsmax_used_during_training}"
            if "random_init" in run:
                model_name = model_name + "_random_init"

            checkpoints_dir = os.path.join(models_dir, model_name, "checkpoints")
            cmd = main_cmd.replace("{test_data_path}", rest_unseen_h5_path)
            cmd = cmd.replace(
                "{cur_test_preds_dir}", os.path.join(rest_unseen_preds_dir, model_name)
            )
            cmd = cmd.replace("{model_type}", model_type)
            cmd = cmd.replace("{cur_checkpoints_dir}", checkpoints_dir)
            assert os.path.exists(checkpoints_dir), f"missing checkpoints: {checkpoints_dir}"

            if os.path.exists(
                os.path.join(rest_unseen_preds_dir, model_name, "test_preds.npz")
            ):
                print(f"{model_name} predictions done")
                continue

            print(cmd)

            with open(temp_script_path, "w+") as temp_script:
                temp_script.write(slurm_header)
                temp_script.write("\n")
                temp_script.write(cmd)
                temp_script.write("\n")

            if submit_rest_unseen_jobs:
                os.system(f"sbatch --requeue {temp_script_path}")

In [None]:
# TEST RUNS FOR SUBSAMPLED TRAIN SET RUNS FOR REGRESSION
# Evaluate each subsampled-training regression model on the test split, once
# per (data seed, subsample fraction). Models whose test_preds.npz already
# exists are skipped, so re-running this cell does not resubmit finished jobs.
with open(slurm_template, "r") as template_file:
    slurm_header = template_file.read()

for run in ["regression"]:
    # Recover the hyperparameters from the training command so the model
    # directory name can be rebuilt.
    train_cmd_template = all_main_run_names[run]
    lr_used_during_training = train_cmd_template.split("--lr ")[-1].split(" ")[0]
    wd_used_during_training = train_cmd_template.split("--weight_decay ")[
        -1
    ].split(" ")[0]
    rcprob_used_during_training = 0.5
    rsmax_used_during_training = 3

    for seed in all_seeds:
        for frac in subsample_fracs:
            model_name = f"{run}_data_seed_{seed}_lr_{lr_used_during_training}_wd_{wd_used_during_training}_rcprob_{rcprob_used_during_training}_rsmax_{rsmax_used_during_training}_subsample_ratio_{frac}"

            checkpoints_dir = os.path.join(models_dir, model_name, "checkpoints")
            cmd = main_cmd.replace("{test_data_path}", test_h5_path)
            cmd = cmd.replace(
                "{cur_test_preds_dir}", os.path.join(test_preds_dir, model_name)
            )
            cmd = cmd.replace("{model_type}", run)  # run aliases are same as model_type
            cmd = cmd.replace("{cur_checkpoints_dir}", checkpoints_dir)
            assert os.path.exists(checkpoints_dir), f"missing checkpoints: {checkpoints_dir}"

            # Same "already done" guard as the other evaluation cells.
            if os.path.exists(
                os.path.join(test_preds_dir, model_name, "test_preds.npz")
            ):
                print(f"{model_name} predictions done")
                continue

            print(cmd)

            with open(temp_script_path, "w+") as temp_script:
                temp_script.write(slurm_header)
                temp_script.write("\n")
                temp_script.write(cmd)
                temp_script.write("\n")

            os.system(f"sbatch --requeue {temp_script_path}")

In [None]:
# REST UNSEEN RUNS FOR SUBSAMPLED TRAIN SET RUNS FOR REGRESSION
# Evaluate each subsampled-training regression model on the rest_unseen split,
# once per (data seed, subsample fraction), skipping models whose
# test_preds.npz already exists.
with open(slurm_template, "r") as template_file:
    slurm_header = template_file.read()

for run in ["regression"]:
    # Recover the hyperparameters from the training command so the model
    # directory name can be rebuilt.
    train_cmd_template = all_main_run_names[run]
    lr_used_during_training = train_cmd_template.split("--lr ")[-1].split(" ")[0]
    wd_used_during_training = train_cmd_template.split("--weight_decay ")[
        -1
    ].split(" ")[0]
    rcprob_used_during_training = 0.5
    rsmax_used_during_training = 3

    for seed in all_seeds:
        for frac in subsample_fracs:
            model_name = f"{run}_data_seed_{seed}_lr_{lr_used_during_training}_wd_{wd_used_during_training}_rcprob_{rcprob_used_during_training}_rsmax_{rsmax_used_during_training}_subsample_ratio_{frac}"

            checkpoints_dir = os.path.join(models_dir, model_name, "checkpoints")
            cmd = main_cmd.replace("{test_data_path}", rest_unseen_h5_path)
            cmd = cmd.replace(
                "{cur_test_preds_dir}", os.path.join(rest_unseen_preds_dir, model_name)
            )
            cmd = cmd.replace("{model_type}", run)  # run aliases are same as model_type
            cmd = cmd.replace("{cur_checkpoints_dir}", checkpoints_dir)
            assert os.path.exists(checkpoints_dir), f"missing checkpoints: {checkpoints_dir}"

            if os.path.exists(
                os.path.join(rest_unseen_preds_dir, model_name, "test_preds.npz")
            ):
                print(f"{model_name} predictions done")
                continue

            print(cmd)

            with open(temp_script_path, "w+") as temp_script:
                temp_script.write(slurm_header)
                temp_script.write("\n")
                temp_script.write(cmd)
                temp_script.write("\n")

            os.system(f"sbatch --requeue {temp_script_path}")

In [None]:
# TEST RUNS USING SUBSAMPLED VARIANTS FOR REGRESSION
# Re-evaluate the full-data regression models on the test split once per
# (data seed, AF threshold), passing the threshold and the training set used
# to compute allele frequencies to test_models.py.
# NOTE(review): cur_test_preds_dir does not encode the AF threshold, so every
# threshold (and the main test run for the same seed) points test_models.py at
# the same output directory. Unless test_models.py names its output file by
# threshold, these jobs overwrite each other. Confirm before relying on the
# results. For the same reason, no "already done" skip is applied here.
with open(slurm_template, "r") as template_file:
    slurm_header = template_file.read()

for run in ["regression"]:
    # Recover the hyperparameters from the training command so the model
    # directory name can be rebuilt.
    train_cmd_template = all_main_run_names[run]
    lr_used_during_training = train_cmd_template.split("--lr ")[-1].split(" ")[0]
    wd_used_during_training = train_cmd_template.split("--weight_decay ")[
        -1
    ].split(" ")[0]
    rcprob_used_during_training = 0.5
    rsmax_used_during_training = 3

    for seed in all_seeds:
        model_name = f"{run}_data_seed_{seed}_lr_{lr_used_during_training}_wd_{wd_used_during_training}_rcprob_{rcprob_used_during_training}_rsmax_{rsmax_used_during_training}"
        checkpoints_dir = os.path.join(models_dir, model_name, "checkpoints")
        assert os.path.exists(checkpoints_dir), f"missing checkpoints: {checkpoints_dir}"

        for af in all_afs:
            cmd = main_cmd.replace("{test_data_path}", test_h5_path)
            cmd = cmd.replace(
                "{cur_test_preds_dir}", os.path.join(test_preds_dir, model_name)
            )
            cmd = cmd.replace("{model_type}", run)  # run aliases are same as model_type
            cmd = cmd.replace("{cur_checkpoints_dir}", checkpoints_dir)
            cmd = (
                cmd
                + f" --rare_variant_af_threshold {af} --train_h5_path_for_af_computation {train_h5_path} --afs_cache_path {afs_cache_path}"
            )

            print(cmd)

            with open(temp_script_path, "w+") as temp_script:
                temp_script.write(slurm_header)
                temp_script.write("\n")
                temp_script.write(cmd)
                temp_script.write("\n")

            os.system(f"sbatch --requeue {temp_script_path}")

In [None]:
# Get ISM scores for every gene
# Submit one ISM job per (model, gene) for the genes whose class is
# "yri_split". This covers the baseline and each data seed of the "regression"
# model. Genes whose ism_scores.npz already exists are skipped. `cnt` ends up
# as the number of jobs submitted by this cell.
gene_class_df = pd.read_csv(gene_class_path)
population_split_genes = gene_class_df[
    gene_class_df["class"] == "yri_split"
].reset_index(drop=True)
print(f"Number of population-split genes: {len(population_split_genes)}")
ism_gene_names = population_split_genes["gene"].unique()

with open(slurm_template, "r") as template_file:
    slurm_header = template_file.read()

cnt = 0

for run in ["baseline", "regression"]:
    if run == "regression":
        # Recover the hyperparameters from the training command so the model
        # directory name can be rebuilt.
        train_cmd_template = all_main_run_names[run]
        lr_used_during_training = train_cmd_template.split("--lr ")[-1].split(" ")[0]
        wd_used_during_training = train_cmd_template.split("--weight_decay ")[
            -1
        ].split(" ")[0]
        rcprob_used_during_training = 0.5
        rsmax_used_during_training = 3

        for seed in all_seeds:
            model_name = f"{run}_data_seed_{seed}_lr_{lr_used_during_training}_wd_{wd_used_during_training}_rcprob_{rcprob_used_during_training}_rsmax_{rsmax_used_during_training}"
            checkpoints_dir = os.path.join(models_dir, model_name, "checkpoints")

            for gene in ism_gene_names:
                if os.path.exists(
                    os.path.join(ISM_preds_dir, model_name, gene, "ism_scores.npz")
                ):
                    continue

                cmd = ISM_cmd.replace("{counts_path}", counts_path)
                cmd = cmd.replace("{gene_class_path}", gene_class_path)
                cmd = cmd.replace("{fasta_path}", fasta_path)
                cmd = cmd.replace(
                    "{predictions_dir}", os.path.join(ISM_preds_dir, model_name)
                )
                cmd = cmd.replace("{model_type}", run)  # run aliases are same as model_type
                cmd = cmd.replace("{cur_checkpoints_dir}", checkpoints_dir)
                cmd = cmd.replace("{gene_name}", gene)

                cnt += 1
                print(seed, gene)
                print(cmd)

                # Only write the job script when it is actually submitted.
                with open(temp_script_path, "w+") as temp_script:
                    temp_script.write(slurm_header)
                    temp_script.write("\n")
                    temp_script.write(cmd)
                    temp_script.write("\n")
                os.system(f"sbatch --requeue {temp_script_path}")
    elif run == "baseline":
        model_name = "baseline"
        for gene in ism_gene_names:
            if os.path.exists(
                os.path.join(ISM_preds_dir, model_name, gene, "ism_scores.npz")
            ):
                continue

            cmd = ISM_cmd.replace("{counts_path}", counts_path)
            cmd = cmd.replace("{gene_class_path}", gene_class_path)
            cmd = cmd.replace("{fasta_path}", fasta_path)
            cmd = cmd.replace(
                "{predictions_dir}", os.path.join(ISM_preds_dir, model_name)
            )
            cmd = cmd.replace("{model_type}", run)  # run aliases are same as model_type
            cmd = cmd.replace("{cur_checkpoints_dir}", "dummy")
            cmd = cmd.replace("{gene_name}", gene)

            cnt += 1
            print(gene)
            print(cmd)

            with open(temp_script_path, "w+") as temp_script:
                temp_script.write(slurm_header)
                temp_script.write("\n")
                temp_script.write(cmd)
                temp_script.write("\n")
            os.system(f"sbatch --requeue {temp_script_path}")

print(f"Submitted {cnt} ISM jobs")