# Run ECLARE on sample paired data

Import environment variables from YAML file

In [None]:
import os
from config.export_env_variables import export_env_variables

# Export the project's environment variables, using the configuration located under 'config'.
# Later cells read DATAPATH, ECLARE_ROOT and OUTPATH from os.environ;
# presumably this call is what defines them — TODO confirm.
export_env_variables(config_path='config')


In [None]:
# Go to DATAPATH
os.chdir(os.environ["DATAPATH"])
# Keep a copy of the current DATAPATH: a later cell overwrites the environment
# variable with a path built from this value.
# NOTE(review): re-running this cell after that overwrite captures the
# already-overwritten path instead of the original one.
DATAPATH_TMP = os.environ["DATAPATH"]

Download sample data from Zenodo (uncomment to run, only run once)

In [None]:
# The download commands are wrapped in a string literal so that they do not run
# by default; remove the surrounding triple quotes to execute them (only once).
# NOTE(review): the second unzip extracts into the current working directory,
# while the following cell expects the data under eclare_data/eclare_sample_zenodo
# — confirm this matches the archive layout.
'''
# Download the data from the DOI link
# (the URL is quoted so that the shell does not treat "?" as a glob wildcard)
!wget "https://zenodo.org/records/14799100/files/eclare_sample_zenodo.zip?download=1" -O eclare_data.zip

# Unzip the downloaded data
!unzip eclare_data.zip -d eclare_data
!unzip eclare_data/eclare_sample_zenodo.zip  # takes about 15 minutes @ 5.67 Mb/s
'''

Overwrite the DATAPATH environment variable with the path of the downloaded data


In [None]:
# Point DATAPATH at the unpacked sample data, located relative to the DATAPATH
# value saved in DATAPATH_TMP.
# More generally, build the path from wherever the sample data lives:
#   os.environ["DATAPATH"] = os.path.join("/path/to/sample/data", "eclare_sample_zenodo")
sample_data_subdirs = ("eclare_data", "eclare_sample_zenodo")
os.environ["DATAPATH"] = os.path.join(DATAPATH_TMP, *sample_data_subdirs)

# Show the value that subsequent cells and scripts will see
print("DATAPATH: ", os.environ["DATAPATH"])

### Step 1: train CLIP teacher models

In [None]:
# Go to ECLARE_ROOT (read from the environment) before launching the training script
os.chdir(os.environ["ECLARE_ROOT"])

In [None]:
# Run clip_samples.sh

# Number of epochs, stored as an environment variable so that the shell command below can reference it
os.environ['N_EPOCHS'] = '5'

# Launch the CLIP script from ECLARE_ROOT, passing the number of epochs as its only argument
!${ECLARE_ROOT}/scripts/clip_scripts/clip_samples.sh $N_EPOCHS

### Step 2: perform multi-teacher distillation (ECLARE)

In [None]:
# Go to ECLARE_ROOT (in case the working directory is not already there)
os.chdir(os.environ["ECLARE_ROOT"])

Identify the job ID of the CLIP teacher models. It should appear in the first line of output from clip_samples.sh, e.g.:<br>

Job ID: clip_03173230

Alternatively, run the code below to identify the most recent CLIP job directory in OUTPATH:


In [None]:
# Get the most recently modified directory in OUTPATH whose name starts with "clip_"
from glob import glob

# Keep directories only, so that a stray file matching "clip_*" can never be
# mistaken for a job directory.
clip_dirs = [
    path
    for path in glob(os.path.join(os.environ["OUTPATH"], "clip_*"))
    if os.path.isdir(path)
]
if clip_dirs:
    # The directory name itself (e.g. "clip_03173230") is used as the job ID
    latest_clip_dir = max(clip_dirs, key=os.path.getmtime)
    clip_job_id = os.path.basename(latest_clip_dir)
    print(f"Most recent CLIP job directory, assigned to clip_job_id: {clip_job_id}")
else:
    # clip_job_id is not assigned in this branch; set it manually before it is used
    print("No CLIP job directories found in OUTPATH")


Run ECLARE

In [None]:
# Run eclare_samples.sh

# Both values are stored as environment variables so that the shell command below can reference them
os.environ['N_EPOCHS'] = '5'
os.environ['CLIP_JOB_ID'] = clip_job_id.split('_')[1]  # keep the second '_'-separated field, e.g. "03173230" from "clip_03173230"

# Launch the ECLARE script with the number of epochs and the CLIP job ID as arguments
!${ECLARE_ROOT}/scripts/eclare_scripts/eclare_samples.sh $N_EPOCHS $CLIP_JOB_ID

Get most recent ECLARE job ID

In [None]:
# Get the most recently modified directory in OUTPATH whose name starts with "eclare_"
from glob import glob

# Keep directories only, so that a stray file matching "eclare_*" can never be
# mistaken for a job directory.
eclare_dirs = [
    path
    for path in glob(os.path.join(os.environ["OUTPATH"], "eclare_*"))
    if os.path.isdir(path)
]
if eclare_dirs:
    # The directory name itself is used as the job ID
    latest_eclare_dir = max(eclare_dirs, key=os.path.getmtime)
    eclare_job_id = os.path.basename(latest_eclare_dir)
    # Report only when a directory was found: eclare_job_id is otherwise
    # undefined (fresh kernel) or left over from a previous run.
    print(f"Most recent ECLARE job directory, assigned to eclare_job_id: {eclare_job_id}")
else:
    # eclare_job_id is not assigned in this branch; set it manually before it is used
    print("No ECLARE job directories found in OUTPATH")

### Assess model performance

Import the functions used to load the models and retrieve the metrics

In [None]:
import os
import matplotlib.pyplot as plt

from eclare.post_hoc_utils import get_metrics
from eclare.models import load_CLIP_and_ECLARE_model

In [None]:
# Get metrics
# get_metrics is called with the second '_'-separated field of each job ID
# (the part after the prefix for IDs such as "clip_03173230").
clip_job_id_split = clip_job_id.split('_')[1]
eclare_job_id_split = eclare_job_id.split('_')[1]

source_df_clip, target_df_clip, source_only_df_clip = get_metrics('clip', clip_job_id_split)   # may need to rename 'triplet_align_<job_id>' to 'clip_<job_id>'
target_df_multiclip = get_metrics('eclare', eclare_job_id_split, target_only=True) # may need to rename 'multisource_align_<job_id>' to 'multiclip_<job_id>'


In [None]:
# Load teacher CLIP and student ECLARE models

# Select the student run with the largest 'ilisis' value.
# NOTE(review): assuming this is a pandas Series, argmax yields a position rather
# than an index label — confirm that positions coincide with the run sub-directory names.
ilisis_per_run = target_df_multiclip['ilisis'].droplevel(0)
best_multiclip_idx = str(ilisis_per_run.argmax())

# Student checkpoint path: <OUTPATH>/<eclare_job_id>/PFC_Zhu/<best run>/student_model.pt
paths_root = os.path.join(os.environ['OUTPATH'], eclare_job_id)
student_model_path = os.path.join(
    paths_root,
    'PFC_Zhu',
    best_multiclip_idx,
    'student_model.pt',
)

teacher_models, student_model = load_CLIP_and_ECLARE_model(
    student_model_path,
    best_multiclip_idx,
)

In [None]:
# Get nuclei and latents
from eclare.setup_utils import pfc_zhu_setup
from eclare.post_hoc_utils import get_latents

# Teacher data: one entry per source dataset, filled in by the loop below
teacher_rnas = {}
teacher_atacs = {}
teacher_rna_latents_dict = {}
teacher_atac_latents_dict = {}

for source_dataset, teacher_model in teacher_models.items():

    # Set up the data using this teacher's own arguments; only the first three
    # returned values are used, the remaining four are discarded.
    teacher_rna, teacher_atac, cell_group, _, _, _, _ = pfc_zhu_setup(
        teacher_model.args,
        pretrain=None,
        return_type='data',
    )
    teacher_rnas[source_dataset] = teacher_rna
    teacher_atacs[source_dataset] = teacher_atac

    # Latents produced by this teacher for both modalities
    teacher_rna_latents, teacher_atac_latents = get_latents(
        teacher_model,
        teacher_rna,
        teacher_atac,
        return_tensor=True,
    )
    teacher_rna_latents_dict[source_dataset] = teacher_rna_latents
    teacher_atac_latents_dict[source_dataset] = teacher_atac_latents

# Student data: same procedure with the student's arguments.
# Note that cell_group is reassigned here, after the loop.
student_rna, student_atac, cell_group, _, _, _, _ = pfc_zhu_setup(
    student_model.args,
    pretrain=None,
    return_type='data',
)
student_rna_latents, student_atac_latents = get_latents(
    student_model,
    student_rna,
    student_atac,
    return_tensor=True,
)

In [None]:
# Plot UMAP embeddings for teachers and student
from eclare.post_hoc_utils import plot_umap_embeddings, create_celltype_palette

# Colour map shared by every figure below.
# teacher_rna, teacher_atac and cell_group are not defined in this cell: they
# come from the notebook state left behind by earlier cells.
color_map_ct = create_celltype_palette(
    teacher_rna.obs[cell_group].values,
    teacher_atac.obs[cell_group].values,
    plot_color_palette=False,
)

# teachers: one figure per source dataset
for source_dataset in teacher_rnas:
    plot_umap_embeddings(
        teacher_rna_latents_dict[source_dataset],
        teacher_atac_latents_dict[source_dataset],
        teacher_rnas[source_dataset].obs[cell_group].values,
        teacher_atacs[source_dataset].obs[cell_group].values,
        None,
        None,
        color_map_ct,
    )
    plt.suptitle(f"PFC_Zhu embeddings using teacher model (source: {source_dataset})")
    plt.tight_layout()
    plt.show()

# student: a single figure
plot_umap_embeddings(
    student_rna_latents,
    student_atac_latents,
    student_rna.obs[cell_group].values,
    student_atac.obs[cell_group].values,
    None,
    None,
    color_map_ct,
)
plt.suptitle("PFC_Zhu embeddings using student model")
plt.tight_layout()
plt.show()