 # (Continue) Pretraining a GET Model on MCF-7 ATAC


 This tutorial demonstrates how to pretrain, or continue pretraining, a GET model on MCF-7 ATAC-seq peaks using the motif information in each peak. We'll cover:

 1. Loading and configuring the model

 2. Training without a pretrained checkpoint

 3. Training from a pretrained checkpoint, with and without LoRA

 4. Comparing the results



 ## Setup

 First, let's import the necessary modules and set up our configuration.
 
 Note:
 If you run this on a Mac, use the Jupyter notebook rather than the VS Code interactive Python window; the latter seems to have issues with multiple workers.
 If you run on Linux, both should work fine.

In [None]:
#%%
from get_model.config.config import load_config, pretty_print_config
from get_model.run_region import run_zarr as run

 ## Configuration



 We'll start by loading a predefined configuration and customizing it for our needs.

 The base configuration is in `get_model/config/finetune_tutorial.yaml`.

 This file has been modified to support multiple zarr directories.

### Training on All Chromosomes
The first pass trains on all chromosomes.
The trainer cannot register individual chromosomes or ATAC_only; it requires both.

In [None]:
import os

# Load the configuration for fine-tuning
cfg = load_config('finetune_tutorial')

# Set dataset parameters
cfg.dataset.zarr_path = "/project/home/p200469/get_BIO1018/get_preprocess_output.zarr/"
cfg.dataset.celltypes = "all_chrs"
cfg.dataset.leave_out_chromosomes = None  # Include all chromosomes in training

# Set project parameters
cfg.run.project_name = 'pretrain_all_chrs'
cfg.run.use_wandb = True  # Enable logging with Weights & Biases

# Training configuration
cfg.training.epochs = 50  # Number of training epochs
cfg.training.val_check_interval = 1.0  # Validate after each epoch

# Debugging information
print("Configuration Loaded Successfully!")
print(f"Zarr path: {cfg.dataset.zarr_path}")
print(f"Project Name: {cfg.run.project_name}")
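Setting `leave_out_chromosomes = None` means no chromosome is held out of training. This tutorial does not show how GET applies the setting internally, so the snippet below is only a sketch under that assumption: a chromosome-level hold-out in which the hypothetical helper `split_by_chromosome` and the toy `(chromosome, start, end)` tuples stand in for the real dataset.

```python
from typing import Iterable, List, Optional, Sequence, Tuple

# Hypothetical region record: (chromosome, start, end). Not GET's data format.
Region = Tuple[str, int, int]


def split_by_chromosome(
    regions: Iterable[Region],
    leave_out_chromosomes: Optional[Sequence[str]],
) -> Tuple[List[Region], List[Region]]:
    """Split regions into (train, held_out) by chromosome name.

    With leave_out_chromosomes=None nothing is held out, mirroring the
    'include all chromosomes' setting in the cell above.
    """
    held = set(leave_out_chromosomes or [])
    train: List[Region] = []
    held_out: List[Region] = []
    for region in regions:
        (held_out if region[0] in held else train).append(region)
    return train, held_out


regions = [("chr1", 100, 600), ("chr2", 50, 550), ("chr10", 0, 500), ("chrX", 10, 510)]

train, held_out = split_by_chromosome(regions, None)
print(len(train), len(held_out))  # 4 0

train, held_out = split_by_chromosome(regions, ["chr10"])
print(len(train), len(held_out))  # 3 1
```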

 ### Model Selection


 We'll use the GETRegionPretrain model, which learns to predict the masked motif (+ATAC) features of regions from the motif (+ATAC) features of the surrounding, unmasked regions.

 This model is particularly useful for understanding the relationship between motifs and chromatin accessibility.
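The masked objective can be sketched in plain Python. This is an illustration under stated assumptions, not GET's implementation: `make_region_mask` and `masked_mse` are hypothetical helpers, the positions stand for regions in a context window (the tutorial does not say whether whole regions or individual motif features are masked), and the loss is averaged over masked positions only, as is common in masked-autoencoder-style pretraining.

```python
import random
from typing import List, Sequence


def make_region_mask(n_positions: int, mask_ratio: float, seed: int = 0) -> List[bool]:
    """Mask round(n_positions * mask_ratio) randomly chosen positions (True = hidden)."""
    if not 0.0 <= mask_ratio <= 1.0:
        raise ValueError("mask_ratio must be in [0, 1]")
    rng = random.Random(seed)
    n_masked = round(n_positions * mask_ratio)
    masked = set(rng.sample(range(n_positions), n_masked))
    return [i in masked for i in range(n_positions)]


def masked_mse(pred: Sequence[float], target: Sequence[float], mask: Sequence[bool]) -> float:
    """Mean squared error over masked positions only; visible positions add no loss."""
    errors = [(p - t) ** 2 for p, t, m in zip(pred, target, mask) if m]
    if not errors:
        raise ValueError("mask has no masked positions")
    return sum(errors) / len(errors)


mask = make_region_mask(n_positions=200, mask_ratio=0.5)  # mirrors cfg.dataset.mask_ratio = 0.5
print(sum(mask))  # 100
print(masked_mse([1.0, 2.0, 3.0], [1.0, 0.0, 3.0], [False, True, False]))  # 4.0
```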

In [None]:
#%%
# Switch the model architecture to GETRegionPretrain (masked motif pretraining)
cfg.model = load_config('model/GETRegionPretrain').model.model
cfg.dataset.mask_ratio = 0.5  # Mask 50% of the motifs; the pretraining dataloader needs this to generate the mask

 ## Training Without Pretraining Checkpoint



 First, let's train the model from scratch (without using a pretrained checkpoint).

 This will give us a baseline for comparison.

In [None]:
#%%
# Set the output directory for checkpoints and logs on this machine
cfg.machine.output_dir = "/project/home/p200469/get_BIO1018/get_ML_output"

# First, train the model without initializing from a pretrained checkpoint
cfg.run.run_name = 'pretrain_MCF7_scratch'  # a unique name for this run
cfg.finetune.checkpoint = None
cfg.finetune.use_lora = False
cfg.run.use_wandb = True
trainer = run(cfg)

In [None]:
#%%
trainer.callback_metrics

 ## Continue Training From a Pretrained Checkpoint (With and Without LoRA)



 Now, let's train the model using a pretrained checkpoint. This checkpoint was trained on a large dataset and should help the model learn faster and potentially achieve better performance.
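LoRA (low-rank adaptation) keeps a pretrained weight matrix `W` frozen and learns a low-rank update, so a layer computes `y = W x + (alpha / r) * B (A x)` with `A` of shape `(r, d_in)` and `B` of shape `(d_out, r)`. The pure-Python `LoRALinear` below is a hypothetical sketch of that idea, not GET's LoRA code; how GET wraps the modules listed in `cfg.finetune.layers_with_lora` is not shown in this tutorial. Because `B` starts at zero, the adapted layer initially reproduces the pretrained layer exactly.

```python
from typing import List

Matrix = List[List[float]]
Vector = List[float]


def matvec(m: Matrix, v: Vector) -> Vector:
    """Multiply matrix m (rows x cols) by vector v (length cols)."""
    return [sum(a * b for a, b in zip(row, v)) for row in m]


class LoRALinear:
    """Hypothetical LoRA layer: y = W x + (alpha / r) * B (A x), with W frozen."""

    def __init__(self, weight: Matrix, r: int, alpha: float = 1.0) -> None:
        self.weight = weight  # frozen pretrained W, shape (d_out, d_in)
        d_out, d_in = len(weight), len(weight[0])
        self.r = r
        self.scale = alpha / r
        self.A = [[0.01] * d_in for _ in range(r)]  # trainable, (r, d_in); usually random init
        self.B = [[0.0] * r for _ in range(d_out)]  # trainable, (d_out, r); zero init

    def forward(self, x: Vector) -> Vector:
        base = matvec(self.weight, x)
        update = matvec(self.B, matvec(self.A, x))
        return [b + self.scale * u for b, u in zip(base, update)]

    def num_trainable(self) -> int:
        """Only A and B are trained: r * (d_in + d_out) parameters."""
        return self.r * (len(self.weight) + len(self.weight[0]))


layer = LoRALinear([[1.0, 2.0], [3.0, 4.0]], r=1)
print(layer.forward([1.0, 1.0]))  # [3.0, 7.0]: equals W x because B starts at zero

big = LoRALinear([[0.0] * 512 for _ in range(512)], r=8)
print(big.num_trainable(), 512 * 512)  # 8192 262144
```

With `r = 8` on a 512 × 512 layer, only 8 × (512 + 512) = 8,192 parameters are trained instead of 262,144, which is why LoRA fine-tuning is comparatively cheap.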



 Note: You'll need to download the checkpoint first:

In [None]:
# Download the public checkpoint file

!curl -O https://2023-get-xf2217.s3.amazonaws.com/get_demo/checkpoints/regulatory_inference_checkpoint_fetal_adult/pretrain_fetal_adult/checkpoint-799.pth

In [None]:
#%%
# Now train the model from the pretrained checkpoint, using LoRA

cfg.machine.output_dir = "/project/home/p200469/get_BIO1018/get_ML_output"
cfg.finetune.checkpoint = './checkpoint-799.pth'
cfg.run.run_name = 'pretrain_mcf7_from_pretrain_lora'
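cfg.finetune.model_key = "model"  # checkpoint entry that holds the model weights
# Rename checkpoint parameter names to the names used by GETRegionPretrain
cfg.finetune.rename_config = {
  "encoder.head.": "head_mask.",
  "encoder.region_embed": "region_embed",
  "region_embed.proj.": "region_embed.embed.",
  "encoder.cls_token": "cls_token",
}
cfg.finetune.strict = True  # require parameter names to match exactly when loading
cfg.finetune.use_lora = True  # enable LoRA adapters
cfg.finetune.layers_with_lora = ['region_embed', 'encoder']  # modules that receive LoRA adapters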
trainer = run(cfg)
trainer.callback_metrics
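`cfg.finetune.rename_config` maps parameter names in the checkpoint to the names GETRegionPretrain expects. The sketch below assumes the entries are applied as plain substring replacements in the listed order (GET may instead match prefixes or apply them differently); the helper `rename_state_dict_keys` and the example keys are hypothetical. Under that assumption the order matters: `encoder.region_embed.proj.weight` first becomes `region_embed.proj.weight` and then `region_embed.embed.weight`.

```python
from typing import Dict, Mapping, TypeVar

V = TypeVar("V")


def rename_state_dict_keys(state_dict: Mapping[str, V], rename: Mapping[str, str]) -> Dict[str, V]:
    """Apply each (old -> new) substring replacement to every key, in mapping order."""
    renamed: Dict[str, V] = {}
    for key, value in state_dict.items():
        new_key = key
        for old, new in rename.items():
            new_key = new_key.replace(old, new)
        renamed[new_key] = value
    return renamed


rename_config = {
    "encoder.head.": "head_mask.",
    "encoder.region_embed": "region_embed",
    "region_embed.proj.": "region_embed.embed.",
    "encoder.cls_token": "cls_token",
}
# Hypothetical checkpoint keys; tensors replaced by placeholder strings.
checkpoint = {
    "encoder.head.weight": "w0",
    "encoder.region_embed.proj.weight": "w1",
    "encoder.cls_token": "w2",
    "encoder.blocks.0.attn.qkv.weight": "w3",
}
for key in rename_state_dict_keys(checkpoint, rename_config):
    print(key)
# head_mask.weight
# region_embed.embed.weight
# cls_token
# encoder.blocks.0.attn.qkv.weight
```

With `strict = True`, any key that still does not match after renaming should make loading fail, assuming the flag is passed through to PyTorch's `load_state_dict(strict=True)`.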

In [None]:
#%%
# Now train from the pretrained checkpoint without LoRA (the other finetune settings carry over from the previous cell)
cfg.finetune.checkpoint = './checkpoint-799.pth'
cfg.run.run_name = 'pretrain_mcf7_from_pretrain_no_lora'
cfg.finetune.use_lora = False
trainer = run(cfg)
trainer.callback_metrics
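To compare the three runs (step 4 of the overview), one option is to collect each trainer's `callback_metrics` after it finishes and print them side by side. `compare_runs` is a hypothetical helper. In PyTorch Lightning, `callback_metrics` maps metric names to scalar tensors, so each value is converted with `float()`. The metric names GET logs are not listed in this tutorial, so the demo uses placeholder names and values that are not real results.

```python
from typing import Mapping


def compare_runs(runs: Mapping[str, Mapping[str, float]]) -> str:
    """Return a text table: one row per run, one column per metric logged by every run."""
    if not runs:
        raise ValueError("need at least one run")
    shared = sorted(set.intersection(*(set(metrics) for metrics in runs.values())))
    lines = ["run".ljust(40) + "".join(name.rjust(16) for name in shared)]
    for run_name, metrics in runs.items():
        values = "".join(f"{float(metrics[name]):16.4f}" for name in shared)
        lines.append(run_name.ljust(40) + values)
    return "\n".join(lines)


# In the notebook, create results = {} once, then after each training cell
# (assumes the run above completed):
# results[cfg.run.run_name] = {k: float(v) for k, v in trainer.callback_metrics.items()}

# Placeholder values for illustration only, not real results.
demo = {
    "run_a": {"val_loss": 1.0, "only_in_a": 3.0},
    "run_b": {"val_loss": 2.0},
}
print(compare_runs(demo))
```

Only metrics that every run logged are shown, so a metric missing from one run does not break the table.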