 # Finetune a GET Model on MCF-7 Bulk Data (Leaving out chr1, on a uniform cell line)


 This tutorial demonstrates how to train a GET model to predict expression in ATAC-seq peaks using motif information. We'll cover:

 1. Loading and configuring the model

 2. Finetuning from a pretrained expression-prediction GET model

 3. Performing various analyses using the `gcell` package



 ## Setup

 First, let's import the necessary modules and set up our configuration.
 
 Note:
 If you run from a Mac, make sure you use the Jupyter notebook rather than the VSCode interactive Python editor, as the latter seems to have issues with multiple workers.
 If you run from Linux, both should work fine.

In [None]:
from pathlib import Path

import matplotlib.pyplot as plt
import seaborn as sns
from gcell.cell.celltype import GETHydraCellType

from get_model.config.config import load_config
from get_model.run_region import run_zarr as run


 ## Finetune Run 1



 We'll start by loading a predefined configuration and customizing it for our needs.

 The base configuration is in `get_model/config/finetune_tutorial_pbmc.yaml`

> Note: In the paper, we mainly used a model trained on binary ATAC signal for motif interpretation analysis, because it is hard to say whether transcription and accessibility are mutually causal. If accessibility is added to the model, it may absorb some of the TFs' effects, making interpretation more difficult. However, if the goal is to represent cell states as precisely as possible and to use the model for other downstream tasks (e.g. enhancer target prediction), adding the accessibility signal is probably better. This tutorial uses the quantitative signal (`cfg.dataset.quantitative_atac = True`).
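
To make the distinction concrete, the sketch below builds the input row of a single peak under each mode. It is only an illustration, assuming the accessibility value is appended as one extra feature after the motif scores; `build_region_features` and the toy numbers are hypothetical and not part of `get_model`.

```python
from typing import List


def build_region_features(
    motif_scores: List[float], atac_signal: float, quantitative_atac: bool
) -> List[float]:
    """Append an accessibility feature to one peak's motif scores (hypothetical helper)."""
    # Binary mode: every called peak gets the same value, so only motifs differ between peaks.
    # Quantitative mode: the measured signal is passed through unchanged.
    accessibility = atac_signal if quantitative_atac else 1.0
    return list(motif_scores) + [accessibility]


peak_motifs = [0.2, 0.0, 0.7]  # toy motif scores for one peak
peak_signal = 0.35             # toy normalized accessibility for the same peak

binary_row = build_region_features(peak_motifs, peak_signal, quantitative_atac=False)
quant_row = build_region_features(peak_motifs, peak_signal, quantitative_atac=True)
print("binary:      ", binary_row)
print("quantitative:", quant_row)
```

In binary mode the last feature carries no peak-specific information, which is why attributions fall on the motif features alone.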

In [None]:
cfg = load_config('finetune_tutorial_pbmc') # load the predefined finetune tutorial config
cfg.stage = 'fit'
cfg.run.run_name = 'training_from_finetune_lora_chr1_split_QATAC'
cfg.dataset.quantitative_atac = True # use the quantitative ATAC signal; set to False for the binary signal used in motif interpretation analysis

cfg.dataset.zarr_path = "/project/home/p200469/get_BIO1018/get_preprocess_output.zarr/"
cfg.dataset.celltypes = "all_chrs"

# Set a unique project name for this run (chr1 is held out below)
cfg.run.project_name = 'finetune_all_chrs'
cfg.finetune.checkpoint = "./checkpoint-best.pth" # set the path to the pretrained checkpoint we want to finetune from
cfg.dataset.leave_out_celltypes = '' # set the celltypes you want to leave out, '' here means no celltype leave out
cfg.dataset.leave_out_chromosomes = 'chr1' # set the chromosomes you want to leave out, '' here means no chromosome leave out
cfg.machine.num_devices=0 # use 0 for cpu training; >=1 for gpu training
cfg.machine.batch_size=8 # batch size for training; check `nvidia-smi` to see the available GPU memory

cfg.machine.output_dir = "/project/home/p200469/get_BIO1018/get_ML_output"
cfg.training.epochs = 50

cfg.training.val_check_interval = 5.0 

print(f"output path: {cfg.machine.output_dir}/{cfg.run.project_name}/{cfg.run.run_name}")
print(f"training for {cfg.training.epochs} epochs")
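
The leave-out settings above amount to partitioning regions by chromosome. As a rough illustration (not the `get_model` data loader), the sketch below parses a `leave_out_chromosomes` string and splits toy peaks into a training set and a held-out set. The helper names, the comma-separated format for several chromosomes, and the peak tuples are assumptions made for the example.

```python
from typing import List, Set, Tuple

Peak = Tuple[str, int, int]  # (chromosome, start, end)


def parse_leave_out(value: str) -> Set[str]:
    """Turn a string such as 'chr1' or 'chr1,chr2' into a set; '' means nothing is left out."""
    return {name.strip() for name in value.split(",") if name.strip()}


def split_by_chromosome(peaks: List[Peak], leave_out: str) -> Tuple[List[Peak], List[Peak]]:
    """Return (training peaks, held-out peaks) for the given leave-out string."""
    held_out_chroms = parse_leave_out(leave_out)
    train = [p for p in peaks if p[0] not in held_out_chroms]
    held_out = [p for p in peaks if p[0] in held_out_chroms]
    return train, held_out


toy_peaks = [
    ("chr1", 100, 600),
    ("chr1", 2000, 2500),
    ("chr2", 300, 800),
    ("chrX", 50, 550),
]

train_peaks, held_out_peaks = split_by_chromosome(toy_peaks, "chr1")
print(f"{len(train_peaks)} training peaks, {len(held_out_peaks)} held-out peaks")
```

Holding out a whole chromosome keeps neighboring, correlated peaks on the same side of the split, so the evaluation genes are never seen during finetuning.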

Now we can download the pretrained checkpoint and start the finetuning.

In [None]:
!curl -O https://2023-get-xf2217.s3.amazonaws.com/get_demo/checkpoints/regulatory_inference_checkpoint_fetal_adult/finetune_fetal_adult_leaveout_astrocyte/checkpoint-best.pth

In [None]:
trainer = run(cfg) # run the finetuning, takes around 2 hours on one RTX 3090
cfg.finetune.checkpoint = "./checkpoint-best.pth"
print("checkpoint path:", trainer.checkpoint_callback.best_model_path)

## Interpretation

After finetuning, we can use the checkpoint to predict the expression of all accessible genes and generate a Jacobian matrix of shape (peak x motif) for every predicted gene.
To start, we point the config at the checkpoint we produced and switch to the `predict` stage. This dataset contains a single sample, `all_chrs`, so we keep `cfg.dataset.celltypes = "all_chrs"` and leave `cfg.dataset.leave_out_celltypes` empty.
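
To illustrate what a (peak x motif) Jacobian means, here is a minimal sketch on a toy model. It assumes a made-up linear "expression" function and approximates the derivatives with central finite differences; the real pipeline presumably relies on automatic differentiation, and none of the names below come from `get_model` or `gcell`.

```python
from typing import Callable, List

Matrix = List[List[float]]  # rows are peaks, columns are motifs


def toy_expression(x: Matrix, w: Matrix) -> float:
    """Toy gene-expression model: weighted sum over all peak/motif entries."""
    return sum(w[i][j] * x[i][j] for i in range(len(x)) for j in range(len(x[0])))


def finite_difference_jacobian(f: Callable[[Matrix], float], x: Matrix, eps: float = 1e-5) -> Matrix:
    """Approximate d f / d x[i][j] for every peak i and motif j."""
    jac = [[0.0] * len(x[0]) for _ in x]
    for i in range(len(x)):
        for j in range(len(x[0])):
            up = [row[:] for row in x]
            down = [row[:] for row in x]
            up[i][j] += eps
            down[i][j] -= eps
            jac[i][j] = (f(up) - f(down)) / (2 * eps)
    return jac


weights = [[0.5, -1.0, 0.0], [2.0, 0.0, 0.3]]  # toy model parameters (2 peaks x 3 motifs)
inputs = [[0.1, 0.4, 0.0], [0.9, 0.2, 0.7]]    # toy motif scores per peak

jacobian = finite_difference_jacobian(lambda x: toy_expression(x, weights), inputs)
for peak_idx, row in enumerate(jacobian):
    print(f"peak {peak_idx}:", [round(v, 4) for v in row])
```

Each entry says how much the predicted expression of one gene changes when a single motif score in a single peak is nudged, which is the quantity the interpretation step stores per gene.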

In [None]:
use_checkpoint = '/project/home/p200469/get_BIO1018/get_ML_output/finetune_all_chrs/training_from_finetune_lora_chr1_split_QATAC/checkpoints/best.ckpt'
cfg.stage = 'predict'
cfg.finetune.resume_ckpt = use_checkpoint

cfg.dataset.celltypes = "all_chrs"

cfg.run.use_wandb = False # disable wandb logging when predicting
cfg.task.layer_names = [] # set to empty list to disable intermediate layer interpretation
cfg.task.gene_list = None # set to None to predict all genes; otherwise you can specify a list of genes as 'MYC,SOX10,SOX2,RET', only genes with promoter open will be used
# run the predict stage on the single `all_chrs` sample
cfg.run.run_name='interpret_training_from_finetune_lora_chr1_split_QATAC'

cfg.dataset.leave_out_celltypes = ''
trainer = run(cfg)

# Pearson 0.896
# R^2 0.793
# Spearman 0.803
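
For reference, the three metrics reported above can be computed from paired observed and predicted values as in the sketch below. The source does not state which definition of R^2 the pipeline uses, so the coefficient of determination here is an assumption, and the toy values are invented.

```python
import math
from typing import List


def pearson(x: List[float], y: List[float]) -> float:
    """Pearson correlation coefficient of two equally long sequences."""
    mean_x, mean_y = sum(x) / len(x), sum(y) / len(y)
    cov = sum((a - mean_x) * (b - mean_y) for a, b in zip(x, y))
    sd_x = math.sqrt(sum((a - mean_x) ** 2 for a in x))
    sd_y = math.sqrt(sum((b - mean_y) ** 2 for b in y))
    return cov / (sd_x * sd_y)


def average_ranks(values: List[float]) -> List[float]:
    """1-based ranks, with tied values sharing the average of their positions."""
    order = sorted(range(len(values)), key=lambda idx: values[idx])
    ranks = [0.0] * len(values)
    i = 0
    while i < len(order):
        j = i
        while j + 1 < len(order) and values[order[j + 1]] == values[order[i]]:
            j += 1
        shared_rank = (i + j) / 2 + 1
        for k in range(i, j + 1):
            ranks[order[k]] = shared_rank
        i = j + 1
    return ranks


def spearman(x: List[float], y: List[float]) -> float:
    """Spearman correlation: Pearson correlation of the ranks."""
    return pearson(average_ranks(x), average_ranks(y))


def r_squared(obs: List[float], pred: List[float]) -> float:
    """Coefficient of determination, 1 - SS_res / SS_tot (assumed definition)."""
    mean_obs = sum(obs) / len(obs)
    ss_res = sum((o - p) ** 2 for o, p in zip(obs, pred))
    ss_tot = sum((o - mean_obs) ** 2 for o in obs)
    return 1 - ss_res / ss_tot


toy_obs = [0.0, 1.0, 2.0, 3.0]
toy_pred = [0.1, 0.9, 2.2, 2.8]
print("Pearson :", round(pearson(toy_obs, toy_pred), 3))
print("Spearman:", round(spearman(toy_obs, toy_pred), 3))
print("R^2     :", round(r_squared(toy_obs, toy_pred), 3))
```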

As you can see, the results are now saved under `finetune_all_chrs/interpret_training_from_finetune_lora_chr1_split_QATAC/` in the output directory (as `all_chrs.zarr`, following the `<celltype>.zarr` naming). Now we can use the `GETHydraCellType` class from `gcell` to load them.

### Load interpretation result as `GETHydraCellType`

In [None]:
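# Load the configuration first; reloading resets every field to its default
cfg = load_config('finetune_tutorial_pbmc')

# Restore the settings used in the predict stage
# (assumed to be what `from_config` needs to locate the saved results)
cfg.dataset.celltypes = "all_chrs"
cfg.dataset.zarr_path = "/project/home/p200469/get_BIO1018/get_preprocess_output.zarr/"
cfg.machine.output_dir = "/project/home/p200469/get_BIO1018/get_ML_output"
cfg.run.project_name = 'finetune_all_chrs'

# The run name must match the one used in the predict stage
cfg.run.run_name = 'interpret_training_from_finetune_lora_chr1_split_QATAC'

# Create a gene annotation dictionary
gene_annot_dict = {}

try:
    # Since celltypes is "all_chrs", set `leave_out_celltypes` to None
    cfg.dataset.leave_out_celltypes = None
    
    # Load data for all chromosomes
    hydra_celltype = GETHydraCellType.from_config(cfg, celltype="all_chrs")
    
    # Save the gene annotations for "all_chrs"
    gene_annot_dict["all_chrs"] = hydra_celltype.gene_annot

except Exception as e:
    print(f"Error loading all_chrs: {e}")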

In [None]:
# For all genes in "all_chrs", collect the predicted and observed expression
import numpy as np
import pandas as pd

try:
    # Get the unique gene list for "all_chrs" (a list, since `.loc` cannot index with a set)
    common_gene_list = sorted(set(gene_annot_dict["all_chrs"].index))
    
    # Collect the expression for the gene list
    gene_annot_dict["all_chrs"] = gene_annot_dict["all_chrs"].loc[common_gene_list]
    
    # Convert the gene annotations into a DataFrame for easier manipulation and visualization
    gene_expression_df = pd.DataFrame(gene_annot_dict["all_chrs"])
    
    print("Successfully collected gene expressions for all_chrs.")
except Exception as e:
    print(f"Error processing gene annotations for all_chrs: {e}")

gene_annot_dict


In [None]:
# Create a DataFrame for "all_chrs" restricted to genes on the held-out chromosome (chr1)
df = gene_annot_dict["all_chrs"] \
    .query('Chromosome == "chr1"') \
    .reset_index()[['obs', 'pred', 'gene_name']] \
    .groupby('gene_name').mean().reset_index()

# `pivot` requires a `columns` argument; with a single sample we label it "all_chrs"
df['celltype'] = 'all_chrs'
df_pred = df.pivot(index='gene_name', columns='celltype', values='pred').dropna()
df_obs = df.pivot(index='gene_name', columns='celltype', values='obs').dropna()

# Correlation across all held-out genes (one value for the whole chromosome)
overall_corr = df['pred'].corr(df['obs'])
print(f"Pearson correlation across held-out chr1 genes: {overall_corr:.3f}")

# Per-gene correlation across samples. With only one sample per gene this is NaN,
# so the per-gene plots are informative only when several samples are available.
corrs = []
for gene in df_pred.index:
    try:
        corr = df_pred.loc[gene].corr(df_obs.loc[gene])
        corrs.append((corr, df_pred.loc[gene].mean(), df_obs.loc[gene].mean(), gene))
    except Exception as e:
        print(f"Error calculating correlation for {gene}: {e}")
        continue

# Create a DataFrame to store correlations
df_corr = pd.DataFrame(corrs, columns=['corr', 'pred', 'obs', 'gene_name'])

print("Correlation calculation completed.")


In [None]:
# plot the distribution of per-gene correlations and their relation to predicted expression
plt.rcParams['figure.dpi'] = 100

fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(11, 5))
df_corr['corr'].hist(bins=20, ax=ax[0])
# add total number of genes as title
ax[0].set_title(f'Total number of genes: {len(df_corr)}\nshared open genes across MCF-7 \non leave out chr1')
ax[0].set_xlabel('Uniform Clonal Cell Correlation')
ax[0].set_ylabel('Number of Genes')

sns.kdeplot(data=df_corr, x='pred', y='corr', ax=ax[1], fill=True)  # `fill` replaces the deprecated `shade`
ax[1].set_title('Predicted mean vs \nUniform Clonal Cell Correlation')
ax[1].set_xlabel('Predicted Expression')
ax[1].set_ylabel('Uniform Clonal Cell Correlation')
fig.tight_layout()


We can inspect the predicted expression table to check whether there are any issues.

In [None]:
df_pred
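
One simple way to look for issues is to list the genes where prediction and observation disagree most. The sketch below does this on plain Python records; `largest_residuals` and the placeholder gene names are hypothetical, and with a pandas DataFrame holding `gene_name`, `obs` and `pred` columns the same records can be obtained via `df.to_dict("records")`.

```python
from typing import Dict, List, Tuple, Union

Record = Dict[str, Union[str, float]]


def largest_residuals(records: List[Record], top_k: int = 3) -> List[Tuple[str, float]]:
    """Return the top_k genes ranked by |pred - obs|, with the signed residual."""
    ranked = sorted(records, key=lambda r: abs(r["pred"] - r["obs"]), reverse=True)
    return [(r["gene_name"], round(r["pred"] - r["obs"], 3)) for r in ranked[:top_k]]


# Toy records with placeholder gene names (not real results)
toy_records = [
    {"gene_name": "GENE_A", "obs": 0.5, "pred": 2.5},
    {"gene_name": "GENE_B", "obs": 1.2, "pred": 0.1},
    {"gene_name": "GENE_C", "obs": 2.0, "pred": 2.1},
    {"gene_name": "GENE_D", "obs": 0.0, "pred": 0.0},
]

for gene, residual in largest_residuals(toy_records, top_k=2):
    print(f"{gene}: residual {residual:+.3f}")
```

A positive residual marks a gene that is over-predicted and a negative one a gene that is under-predicted.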