In [None]:
import datetime
from geneformer import Classifier

import numpy as np
import torch
import os
import random

In [None]:
# http://p2-gpu-1:8144/lab  (access token redacted - never commit Jupyter tokens to a notebook)

In [None]:
torch.cuda.is_available()

In [None]:
# --- Run configuration ------------------------------------------------------
BASE_DIR = "/hpcfs/groups/phoenix-hpc-mangiola_laboratory/haroon/geneformer"
AGGREGATION_LEVEL = "metacell_8"  # Set aggregation level to 'singlecell' for single-cell data
MODEL_VARIANT = "30M"  # Specify the model variant, e.g., "30M" or "95M"
TASK = "cell_type_classification"  # Specify the task, e.g., "disease_classification" or "dosage_sensitivity"
DATASET = "cellnexus_cell_types"

# Datasets that may be paired with each task.
VALID_COMBINATIONS = {
    "disease_classification": ["genecorpus_heart_disease", "cellnexus_blood_disease"],
    "dosage_sensitivity": ["genecorpus_dosage_sensitivity"],
    "cell_type_classification": ["cellnexus_cell_types"]
}

# Fail early on a TASK/DATASET pair that the rest of the notebook cannot handle.
assert TASK in VALID_COMBINATIONS, f"Unknown TASK: '{TASK}'"
assert DATASET in VALID_COMBINATIONS[TASK], (
    f"For TASK='{TASK}', DATASET must be one of {VALID_COMBINATIONS[TASK]}, got '{DATASET}'"
)

# --- Reproducibility --------------------------------------------------------
# Python's `random` and NumPy are seeded with `seed_num`; PyTorch (CPU and all
# CUDA devices) is seeded with the separate value `seed_val`.
seed_num = 0
seed_val = 42
random.seed(seed_num)
np.random.seed(seed_num)
torch.manual_seed(seed_val)
torch.cuda.manual_seed_all(seed_val)
                         
# Locate the gene dictionaries that belong to the chosen model variant: the 30M
# files live in their own `gene_dictionaries_30m` folder, the 95M files sit
# directly under `Geneformer/geneformer`.
if MODEL_VARIANT == "30M":
    _dictionary_subdir = "Geneformer/geneformer/gene_dictionaries_30m"
elif MODEL_VARIANT == "95M":
    _dictionary_subdir = "Geneformer/geneformer"
else:
    raise ValueError("MODEL_VARIANT must be either '30M' or '95M'")

GENE_MEDIAN_FILE = os.path.join(BASE_DIR, f"{_dictionary_subdir}/gene_median_dictionary_gc{MODEL_VARIANT}.pkl")
TOKEN_DICTIONARY_FILE = os.path.join(BASE_DIR, f"{_dictionary_subdir}/token_dictionary_gc{MODEL_VARIANT}.pkl")
ENSEMBL_MAPPING_FILE = os.path.join(BASE_DIR, f"{_dictionary_subdir}/ensembl_mapping_dict_gc{MODEL_VARIANT}.pkl")

# Echo the resolved paths so they can be checked by eye.
for _label, _resolved_path in (
    ("GENE_MEDIAN_FILE", GENE_MEDIAN_FILE),
    ("TOKEN_DICTIONARY_FILE", TOKEN_DICTIONARY_FILE),
    ("ENSEMBL_MAPPING_FILE", ENSEMBL_MAPPING_FILE),
):
    print(f"{_label}: {_resolved_path}")


# Map the task onto the kind of Classifier to build: the two cell-level tasks
# use a "cell" classifier, dosage sensitivity uses a "gene" classifier.
if TASK in ("disease_classification", "cell_type_classification"):
    classifier_type = "cell"
elif TASK == "dosage_sensitivity":
    classifier_type = "gene"
else:
    # The message lists every task handled above (it used to omit
    # 'cell_type_classification').
    raise ValueError(
        "TASK must be one of 'disease_classification', 'cell_type_classification' "
        "or 'dosage_sensitivity'"
    )


# Placeholder kept for future compatibility; it currently has no effect.
# Geneformer's `model_version` option auto-selects settings for a model version
# other than the current default:
#   V1: models pretrained on ~30M cells, V2: models pretrained on ~104M cells
model_version = "V1"

# One output folder per task / dataset / model variant / aggregation level.
output_dir = os.path.join(
    BASE_DIR,
    "trained_cell_classification_models",
    TASK,
    DATASET,
    f"{MODEL_VARIANT}_{AGGREGATION_LEVEL}",
)
os.makedirs(output_dir, exist_ok=True)

In [None]:
output_dir

In [None]:
DATASET_PATH = os.path.join(BASE_DIR, "datasets", TASK, DATASET)

DATASET_PATH


In [None]:
# Select, per TASK/DATASET pair: the tokenized input dataset, the label column
# to classify (`cell_state_dict`), an optional cell filter (`filter_data_dict`)
# and - for the disease datasets - the id-based train/test and train/eval
# split dictionaries.
if TASK == "disease_classification" and DATASET == "genecorpus_heart_disease":
    input_data_file = os.path.join(DATASET_PATH, "human_dcm_hcm_nf.dataset")
    cell_state_dict = {"state_key": "disease", "states": "all"}
    filter_data_dict={"cell_type":["Cardiomyocyte1","Cardiomyocyte2","Cardiomyocyte3"]}

    # previously balanced splits with prepare_data and validate functions
    # argument attr_to_split set to "individual" and attr_to_balance set to ["disease","lvef","age","sex","length"]
    train_ids = ["1447", "1600", "1462", "1558", "1300", "1508", "1358", "1678", "1561", "1304", "1610", "1430", "1472", "1707", "1726", "1504", "1425", "1617", "1631", "1735", "1582", "1722", "1622", "1630", "1290", "1479", "1371", "1549", "1515"]
    eval_ids = ["1422", "1510", "1539", "1606", "1702"]
    test_ids = ["1437", "1516", "1602", "1685", "1718"]

    # Train+eval individuals versus held-out test individuals.
    train_test_id_split_dict = {"attr_key": "individual",
                                "train": train_ids+eval_ids,
                                "test": test_ids}

    # Train versus eval individuals.
    train_valid_id_split_dict = {"attr_key": "individual",
                            "train": train_ids,
                            "eval": eval_ids}
elif TASK == "disease_classification" and DATASET == "cellnexus_blood_disease":
    import json

    input_data_file = os.path.join(DATASET_PATH, "tokenized_" + str(MODEL_VARIANT), "cellnexus_singlecell.dataset")
    filter_data_dict = None
    cell_state_dict = {"state_key": "disease", "states": "all"}

    # Load train_test split dictionary
    train_test_file = os.path.join(DATASET_PATH, "train_test_id_split_dict.json")
    print(f"Loading train_test split from: {train_test_file}")

    with open(train_test_file, 'r') as f:
        train_test_id_split_dict = json.load(f)

    # Load train_valid split dictionary
    train_valid_file = os.path.join(DATASET_PATH, "train_valid_id_split_dict.json")
    print(f"Loading train_valid split from: {train_valid_file}")

    with open(train_valid_file, 'r') as f:
        train_valid_id_split_dict = json.load(f)

    # Verify the loaded dictionaries
    print("\n=== Train-Test Split Dictionary ===")
    print(f"Attribute key: {train_test_id_split_dict['attr_key']}")
    print(f"Train samples: {len(train_test_id_split_dict['train'])}")
    print(f"Test samples: {len(train_test_id_split_dict['test'])}")

    print("\n=== Train-Valid Split Dictionary ===")
    print(f"Attribute key: {train_valid_id_split_dict['attr_key']}")
    print(f"Train samples: {len(train_valid_id_split_dict['train'])}")
    print(f"Eval samples: {len(train_valid_id_split_dict['eval'])}")

    # Show a few sample IDs for verification
    print(f"\nSample train IDs: {train_test_id_split_dict['train'][:5]}")
    print(f"Sample test IDs: {train_test_id_split_dict['test'][:5]}")
    print(f"Sample eval IDs: {train_valid_id_split_dict['eval'][:5]}")

    print("\n✅ Dictionaries loaded successfully!")
elif TASK == "cell_type_classification" and DATASET == "cellnexus_cell_types":
    input_data_file = os.path.join(DATASET_PATH, "tokenized_" + str(MODEL_VARIANT), "cellnexus_singlecell.dataset")
    cell_state_dict = {"state_key": "cell_type", "states": "all"}
    filter_data_dict = None
    # No id-based split dictionaries are defined for this dataset. Reset the
    # names so dictionaries left in the kernel by a previous run with another
    # dataset cannot be picked up by accident.
    train_test_id_split_dict = None
    train_valid_id_split_dict = None
else:
    # Without this branch an unsupported pair only surfaced later as a
    # NameError on `input_data_file`.
    raise NotImplementedError(
        f"No input data configuration for TASK='{TASK}' with DATASET='{DATASET}'"
    )
    

In [None]:
input_data_file

In [None]:

output_prefix = DATASET + "_test"


In [None]:
output_dir

In [None]:
torch.cuda.device_count()

In [None]:

# Hyper-parameters handed to the Classifier through `training_args`.
training_args = {
    "num_train_epochs": 5,
    "learning_rate": 0.000804,
    "lr_scheduler_type": "polynomial",
    "warmup_steps": 1812,
    "weight_decay":0.258828,
    "per_device_train_batch_size": 128,
    "seed": 73,
    # "logging_dir": os.path.normpath("D:/geneformer_finetuning/trained_cell_classification_models/disease_classification/genecorpus_heart_disease/30M_metacell_8/250623_geneformer_cellClassifier_genecorpus_heart_disease_test/ksplit1/runs"),
}

# Settings that differ between the two supported tasks; every other Classifier
# argument is shared.
_task_specific_classifier_kwargs = {
    "disease_classification": {
        "freeze_layers": 2,
    },
    "cell_type_classification": {
        "freeze_layers": 4,
        "split_sizes": {"train": 0.8, "valid": 0.1, "test": 0.1},
        # cell_type (from cell_state_dict) is internally renamed to label,
        # so the split is stratified on the backup column instead
        "stratify_splits_col": "cell_type_backup",
    },
}

# OF NOTE: token_dictionary_file must be set to the gc-30M token dictionary if using a 30M series model
# (otherwise the Classifier will use the current default model dictionary)
# 30M token dictionary: https://huggingface.co/ctheodoris/Geneformer/blob/main/geneformer/gene_dictionaries_30m/token_dictionary_gc30M.pkl
if TASK in _task_specific_classifier_kwargs:
    cc = Classifier(
        classifier=classifier_type,
        # model_version=model_version,
        cell_state_dict=cell_state_dict,
        filter_data=filter_data_dict,
        training_args=training_args,
        max_ncells=None,
        num_crossval_splits=1,
        ngpu=torch.cuda.device_count(),
        forward_batch_size=200,
        token_dictionary_file=TOKEN_DICTIONARY_FILE,
        nproc=16,
        **_task_specific_classifier_kwargs[TASK],
    )

### num_crossval_splits : {0, 1, 5}
        #     | 0: train on all data without splitting
        #     | 1: split data into train and eval sets by designated split_sizes["valid"]
        #     | 5: split data into 5 folds of train and eval sets by designated split_sizes["valid"]
        # split_sizes : None, dict
        #     | Dictionary of proportion of data to hold out for train, validation, and test sets
        #     | {"train": 0.8, "valid": 0.1, "test": 0.1} if intending 80/10/10 train/valid/test split

In [None]:
input_data_file

In [None]:
output_dir

In [None]:
import os
import shutil
from pathlib import Path

def check_and_copy_datasets(output_dir, base_path=None):
    """
    Copy previously prepared datasets into ``output_dir`` when they exist.

    Looks under ``base_path`` for the labeled train dataset, the labeled test
    dataset and the ``*_id_class_dict.pkl`` file, and copies each one that is
    present into ``output_dir``. A destination that already exists is left
    untouched. Progress is reported with ``print``; a failed copy is reported
    and does not stop the remaining items from being processed.

    Args:
        output_dir (str): Destination directory (created if missing).
        base_path (str, optional): Directory that holds the prepared datasets.
            Defaults to the ``30M_metacell_16`` cell-type output folder this
            function has always copied from.

    Returns:
        None
    """
    if base_path is None:
        base_path = "/hpcfs/groups/phoenix-hpc-mangiola_laboratory/haroon/geneformer/trained_cell_classification_models/cell_type_classification/cellnexus_cell_types/30M_metacell_16/"

    paths_to_check = {
        "train": os.path.join(base_path, "cellnexus_cell_types_test_labeled_train.dataset"),
        "test": os.path.join(base_path, "cellnexus_cell_types_test_labeled_test.dataset"),
        "ids": os.path.join(base_path, "cellnexus_cell_types_test_id_class_dict.pkl"),
    }

    # Create output directory if it doesn't exist
    Path(output_dir).mkdir(parents=True, exist_ok=True)
    print(f"Output directory: {output_dir}")

    for dataset_type, path in paths_to_check.items():
        print(f"\nChecking {dataset_type} dataset...")
        print(f"Path: {path}")

        if not os.path.exists(path):
            print(f"❌ {dataset_type.capitalize()} dataset DOES NOT EXIST")
            continue

        print(f"✅ {dataset_type.capitalize()} dataset EXISTS")
        destination = os.path.join(output_dir, os.path.basename(path))
        print(f"Copying to: {destination}")

        # Never overwrite: an existing destination is kept as-is and is not
        # reported as a successful copy.
        if os.path.exists(destination):
            print(f"Destination already exists, skipping: {destination}")
            continue

        try:
            # copytree for directories, copy2 for single files
            if os.path.isdir(path):
                shutil.copytree(path, destination)
            else:
                shutil.copy2(path, destination)
        except Exception as e:
            # Best effort: report the failure and carry on with the next item.
            print(f"❌ Error copying {dataset_type} dataset: {str(e)}")
        else:
            print(f"✅ Successfully copied {dataset_type} dataset")

    print(f"\n📁 Contents of output directory '{output_dir}':")
    try:
        for item in os.listdir(output_dir):
            if os.path.isdir(os.path.join(output_dir, item)):
                print(f"  📂 {item}/")
            else:
                print(f"  📄 {item}")
    except Exception as e:
        print(f"Error listing output directory: {str(e)}")


In [None]:
# Prepare the data for training
# Example input_data_file for 30M model: https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/tree/main/example_input_files/cell_classification/disease_classification/human_dcm_hcm_nf.dataset
if TASK == "disease_classification":
    # Disease runs split by the predefined train/test id dictionary.
    cc.prepare_data(
        input_data_file=input_data_file,
        output_directory=output_dir,
        output_prefix=output_prefix,
        split_id_dict=train_test_id_split_dict,
    )
elif TASK == "cell_type_classification":
    # Labeled train/test datasets written by an earlier run; when both are
    # present they are copied into output_dir instead of being re-created.
    # NOTE(review): the folder is fixed to the 30M_metacell_16 run regardless of
    # AGGREGATION_LEVEL - presumably intentional reuse of one split; confirm.
    base_path = "/hpcfs/groups/phoenix-hpc-mangiola_laboratory/haroon/geneformer/trained_cell_classification_models/cell_type_classification/cellnexus_cell_types/30M_metacell_16/"
    train_dataset_path = os.path.join(base_path, "cellnexus_cell_types_test_labeled_train.dataset")
    test_dataset_path = os.path.join(base_path, "cellnexus_cell_types_test_labeled_test.dataset")

    if not (os.path.exists(train_dataset_path) and os.path.exists(test_dataset_path)):
        print("Preparing data for cell type classification...")
        cc.prepare_data(
            input_data_file=input_data_file,
            output_directory=output_dir,
            output_prefix=output_prefix,
            test_size=0.1,
            split_id_dict=None,
        )
    else:
        print("Existing data for cell type classification. Skipping preparation...")
        check_and_copy_datasets(output_dir)


In [None]:
import os
import glob

# Root folder of the pretrained foundation model for the chosen aggregation level.
pretrained_model_path = os.path.join(BASE_DIR, "trained_foundation_models", "models",
                                   f"30M_AGG{AGGREGATION_LEVEL}_6_emb256_SL2048_E2_B12_LR0.001_LSlinear_WU10000_Oadamw")

# Preferred model location: <root>/final_trained_model/<AGGREGATION_LEVEL>
final_model_path = os.path.join(pretrained_model_path, "final_trained_model", AGGREGATION_LEVEL)
final_model_path_empty = False
model_files = []

if os.path.isdir(final_model_path):
    # The folder only counts as a model if it holds weights or a config file.
    for model_file_pattern in ("*.bin", "*.safetensors", "config.json"):
        model_files.extend(glob.glob(os.path.join(final_model_path, model_file_pattern)))

    if model_files:
        pretrained_model_path = final_model_path
        print(f"✓ Using final trained model: {pretrained_model_path}")
    else:
        print(f"⚠ Final model folder exists but appears empty: {final_model_path}")
        final_model_path_empty = True
else:
    print(f"ℹ Final model folder not found: {final_model_path}")

# Fall back to the newest checkpoint whenever no final model was selected above.
# Keying this on `model_files` also covers a final-model path that exists but is
# not a directory; previously that case skipped the checkpoint search and left
# the bare root folder as the model path.
if not model_files:
    checkpoint_pattern = os.path.join(pretrained_model_path, "checkpoint-*")
    checkpoint_folders = glob.glob(checkpoint_pattern)

    # (step number, folder) for every folder named checkpoint-<int>...
    checkpoint_numbers = []
    for folder in checkpoint_folders:
        try:
            checkpoint_num = int(os.path.basename(folder).split("-")[1])
        except ValueError:
            continue
        checkpoint_numbers.append((checkpoint_num, folder))

    if not checkpoint_numbers:
        # Same two messages as before: "valid" only when folders matched the
        # pattern but none carried a numeric step.
        qualifier = "valid " if checkpoint_folders else ""
        print(f"❌ No {qualifier}checkpoint folders found in: {pretrained_model_path}")
        raise FileNotFoundError(f"No trained model or checkpoints found in {pretrained_model_path}")

    # Highest step number = most recent checkpoint
    most_recent_checkpoint = max(checkpoint_numbers, key=lambda x: x[0])
    pretrained_model_path = most_recent_checkpoint[1]
    print(f"✓ Using most recent checkpoint: {pretrained_model_path} (step {most_recent_checkpoint[0]})")

print(f"Final pretrained model path: {pretrained_model_path}")

In [None]:
pretrained_model_path

In [None]:
f"{output_dir}/{output_prefix}_labeled_train.dataset"

In [None]:

# Fine-tune and validate on the prepared training data. Both supported tasks
# make the same call; only disease classification supplies an id-based
# train/eval split, cell-type classification passes None.
if TASK in ("disease_classification", "cell_type_classification"):
    # Example 6 layer 30M Geneformer model: https://huggingface.co/ctheodoris/Geneformer/blob/main/gf-6L-30M-i2048/model.safetensors
    all_metrics = cc.validate(
        model_directory=pretrained_model_path,
        prepared_input_data_file=f"{output_dir}/{output_prefix}_labeled_train.dataset",
        id_class_dict_file=f"{output_dir}/{output_prefix}_id_class_dict.pkl",
        output_directory=output_dir,
        output_prefix=output_prefix,
        split_id_dict=(
            train_valid_id_split_dict if TASK == "disease_classification" else None
        ),
        # to optimize hyperparameters, set n_hyperopt_trials=100 (or alternative desired # of trials)
    )
                            # to optimize hyperparameters, set n_hyperopt_trials=100 (or alternative desired # of trials)



In [None]:
# Rebuild `cc` with only the arguments passed here (no training arguments, no
# data filter); this instance replaces the one used for fine-tuning above.
cc = Classifier(
    classifier=classifier_type,
    cell_state_dict=cell_state_dict,
    forward_batch_size=200,
    token_dictionary_file=TOKEN_DICTIONARY_FILE,
    nproc=16,
)

In [None]:
output_prefix

In [None]:
# Date prefix of the fine-tuned model folder that the cells below read from
# (`{training_date}_geneformer_cellClassifier_{output_prefix}`).
# NOTE(review): hard-coded - presumably the YYMMDD stamp of the day the
# fine-tuning run was executed; update it before re-running on another day.
training_date = "250625"


In [None]:
# Evaluate the saved fine-tuned model (first cross-validation split folder,
# `ksplit1`) on the labeled test dataset written to output_dir.
all_metrics_test = cc.evaluate_saved_model(
    model_directory=f"{output_dir}/{training_date}_geneformer_cellClassifier_{output_prefix}/ksplit1/",
    id_class_dict_file=f"{output_dir}/{output_prefix}_id_class_dict.pkl",
    test_data_file=f"{output_dir}/{output_prefix}_labeled_test.dataset",
    output_directory=output_dir,
    output_prefix=output_prefix,
)

In [None]:
from sklearn import preprocessing
from sklearn.metrics import (
    ConfusionMatrixDisplay,
    accuracy_score,
    auc,
    confusion_matrix,
    f1_score,
    roc_curve,
)
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from pathlib import Path


def plot_conf_mat(
    conf_mat_dict,
    output_directory,
    output_prefix,
    custom_class_order=None,
):
    """
    Plot the confusion matrix of every evaluated model.

    **Parameters**

    conf_mat_dict : dict
        | Dictionary of model_name : confusion_matrix_DataFrame
        | (all_metrics["conf_matrix"] from self.validate)
    output_directory : Path
        | Path to directory where plots will be saved
    output_prefix : str
        | Prefix for output file
    custom_class_order : None, list
        | List of classes in custom order for plots.
        | Same order will be used for all models.
    """
    # One plot per model; the model name is forwarded as the plot title argument.
    for model_name, conf_mat_df in conf_mat_dict.items():
        plot_confusion_matrix(
            conf_mat_df,
            model_name,
            output_directory,
            output_prefix,
            custom_class_order,
        )




# plot confusion matrix
def plot_confusion_matrix(
    conf_mat_df, title, output_dir, output_prefix, custom_class_order
):
    """
    Draw a row-normalized confusion matrix, save it as PDF and show it.

    **Parameters**

    conf_mat_df : pandas.DataFrame
        | Confusion-matrix counts, indexed and labeled by class.
    title : str
        | Currently unused: the plot title is fixed to
        | "Cell Type (<AGGREGATION_LEVEL>)".
    output_dir : Path
        | Directory the PDF is written to.
    output_prefix : str
        | Prefix of the output file (``<output_prefix>_conf_mat.pdf``).
    custom_class_order : None, list
        | Optional class order applied to both rows and columns.
    """
    sns.set(font_scale=0.8)
    sns.set_style("whitegrid", {"axes.grid": False})

    if custom_class_order is not None:
        conf_mat_df = conf_mat_df.reindex(
            index=custom_class_order, columns=custom_class_order
        )
    display_labels = generate_display_labels(conf_mat_df)
    # L1-normalize so that every row sums to 1
    conf_mat = preprocessing.normalize(conf_mat_df.to_numpy(), norm="l1")
    display = ConfusionMatrixDisplay(
        confusion_matrix=conf_mat, display_labels=display_labels
    )

    # Draw on an explicitly sized axes. Without `ax=`, ConfusionMatrixDisplay.plot
    # opens its own default-sized figure, so a separately created 20x16 figure
    # stayed empty and the matrix itself was never enlarged.
    fig, ax = plt.subplots(figsize=(20, 16))
    display.plot(ax=ax, cmap="Blues", values_format=".2f")

    # Vertical, small x labels so long class names do not overlap
    plt.setp(ax.get_xticklabels(), rotation=90, ha="center", fontsize=10)
    plt.setp(ax.get_yticklabels(), rotation=0, fontsize=10)
    ax.set_title(f"Cell Type ({AGGREGATION_LEVEL})", fontsize=16, pad=20)
    fig.tight_layout()

    output_file = (Path(output_dir) / f"{output_prefix}_conf_mat").with_suffix(".pdf")
    display.figure_.savefig(output_file, bbox_inches="tight")
    plt.show()


def generate_display_labels(conf_mat_df):
    """
    Build one axis label per row of the confusion matrix.

    Each label is the row's index value, a line break, and ``n=<row total>``
    with the total formatted without decimals.
    """
    return [
        f"{class_name}\nn={conf_mat_df.iloc[row_pos, :].sum():.0f}"
        for row_pos, class_name in enumerate(conf_mat_df.index)
    ]

In [None]:
output_dir

In [None]:
# Plot the test-set confusion matrix under the model name "Geneformer".
plot_conf_mat(
    conf_mat_dict={"Geneformer": all_metrics_test["conf_matrix"]},
    output_directory=output_dir,
    output_prefix=output_prefix,
    # custom_class_order=["nf","hcm","dcm"],
)

In [None]:
# cc.plot_predictions(
#     predictions_file=f"{output_dir}/{output_prefix}_pred_dict.pkl",
#     id_class_dict_file=f"{output_dir}/{output_prefix}_id_class_dict.pkl",
#     title="cell_type" if  TASK == "cell_type_classification" else "disease",
#     output_directory=output_dir,
#     output_prefix=output_prefix,
#     # custom_class_order=["nf","hcm","dcm"],
#     # ["Non-failing","Hypertrophic","Dilated Cardiomyopathy"]
# )

In [None]:
all_metrics_test


In [None]:
import json

# Write the test metrics next to the other outputs. `default=str` stringifies
# any value json cannot encode natively. The `with` block guarantees the file
# is flushed and closed (the handle used to be left open).
with open(Path(output_dir) / f"{output_prefix}_all_metrics_test.json", 'w') as metrics_file:
    json.dump(all_metrics_test, metrics_file, indent=2, default=str)

# Plot Embeddings

In [None]:
from geneformer import EmbExtractor


In [None]:
os.path.join(output_dir, "embeddings")

In [None]:
os.makedirs(os.path.join(output_dir, "embeddings"), exist_ok=True)

In [None]:
# Set up the EmbExtractor.
# OF NOTE: token_dictionary_file must be set to the gc-30M token dictionary if using a 30M series model
# (otherwise the EmbExtractor will use the current default model dictionary)
embex = EmbExtractor(
    model_type="CellClassifier",
    # 9 classes for the blood disease dataset, 51 for every other dataset
    num_classes=9 if DATASET == "cellnexus_blood_disease" else 51,
    filter_data=filter_data_dict,
    max_ncells=100000,  # 10000 if DATASET == "cellnexus_blood_disease" else 1000,
    # Embedding layer to extract.
    # | The last layer is most specifically weighted to optimize the given learning objective.
    # | Generally, it is best to extract the 2nd to last layer to get a more general representation.
    # | -1: 2nd to last layer
    # | 0: last layer
    emb_layer=0,
    emb_label=["cell_type"],
    labels_to_plot=["cell_type"],
    forward_batch_size=200,
    nproc=16,
    emb_mode="cell",
    # change from current default dictionary for 30M model series
    token_dictionary_file=TOKEN_DICTIONARY_FILE,
)

# Extract embeddings from the input data.
# input data is tokenized rank value encodings generated by Geneformer tokenizer (see tokenizing_scRNAseq_data.ipynb)
# example dataset for 30M model series: https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/tree/main/example_input_files/cell_classification/disease_classification/human_dcm_hcm_nf.dataset
embs = embex.extract_embs(
    # fine-tuned model folder (example 30M fine-tuned model)
    model_directory=f"{output_dir}/{training_date}_geneformer_cellClassifier_{output_prefix}/ksplit1/",
    input_data_file=input_data_file,
    output_directory=os.path.join(output_dir, "embeddings"),
    output_prefix="test",
)


In [None]:
input_data_file

In [None]:
default_kwargs_dict = {"size": 10}

In [None]:
from collections import Counter
def gen_heatmap_class_dict(classes, label_colors_series):
    """
    Map each distinct class to the color of its first occurrence.

    ``classes`` and ``label_colors_series`` are paired row by row; duplicate
    classes are dropped, keeping the first row of each.
    """
    paired = pd.DataFrame({"classes": classes, "color": label_colors_series})
    first_per_class = paired.drop_duplicates(subset=["classes"])
    return {
        class_name: color
        for class_name, color in zip(first_per_class["classes"], first_per_class["color"])
    }

def gen_heatmap_class_colors(labels, df):
    """
    Assign one cubehelix palette color to every row according to its label.

    **Parameters**

    labels : list
        | One label per row of ``df``.
    df : pandas.DataFrame
        | Frame whose index is reused for the returned Series.

    **Returns**

    pandas.Series
        | Color of each row, indexed like ``df``.
    """
    # Distinct labels in first-seen order
    unique_labels = list(Counter(labels))
    pal = sns.cubehelix_palette(
        len(unique_labels),
        light=0.9,
        dark=0.1,
        hue=1,
        reverse=True,
        start=1,
        rot=-2,
    )
    lut = dict(zip(map(str, unique_labels), pal))
    # The lookup table is keyed by str(label), so look the rows up by str(label)
    # as well; mapping the raw labels turned every non-string label into NaN.
    colors = pd.Series(list(map(str, labels)), index=df.index).map(lut)
    return colors

def make_colorbar(embs_df, label):
    """
    Build the per-row colors for ``label`` and the matching class legend.

    Returns a tuple ``(label_colors, label_color_dict)``: a one-column
    DataFrame (column ``label``) with the color of each row of ``embs_df``,
    and a dictionary mapping each class to its color.
    """
    row_labels = list(embs_df[label])

    row_colors = gen_heatmap_class_colors(row_labels, embs_df)
    label_colors = pd.DataFrame(row_colors, columns=[label])

    # class -> color lookup used for the legend
    return label_colors, gen_heatmap_class_dict(row_labels, label_colors[label])

In [None]:
import scanpy as sc
import seaborn as sns
import logging

logger = logging.getLogger(__name__)

def plot_heatmap(embs_df, emb_dims, label, output_file, kwargs_dict):
    """
    Plot a clustered heatmap of the embeddings, save it and show it.

    **Parameters**

    embs_df : pandas.DataFrame
        | Embedding dimensions in the first ``emb_dims`` columns; must also
        | contain the ``label`` column.
    emb_dims : int
        | Number of leading columns that hold embedding values.
    label : str
        | Column used to color the rows and to build the legend.
    output_file : Path
        | File the figure is saved to.
    kwargs_dict : None, dict
        | Extra keyword arguments for ``sns.clustermap``; they override the
        | defaults set below.
    """
    # `sns.set` resets the style, so it has to run before `set_style`
    # (same order as in plot_umap); the reverse order discarded "white".
    sns.set(font_scale=2)
    sns.set_style("white")
    label_colors, label_color_dict = make_colorbar(embs_df, label)

    clustermap_kwargs = {
        "row_cluster": True,
        "col_cluster": True,
        "row_colors": label_colors,
        "standard_scale": 1,
        "linewidths": 0,
        "xticklabels": False,
        "yticklabels": False,
        "figsize": (15, 15),
        "center": 0,
        "cmap": "magma",
    }
    if kwargs_dict is not None:
        clustermap_kwargs.update(kwargs_dict)

    # clustermap creates its own figure (sized by "figsize" above); no separate
    # plt.figure() is opened, which previously only produced a blank figure.
    g = sns.clustermap(
        embs_df.iloc[:, 0:emb_dims].apply(pd.to_numeric), **clustermap_kwargs
    )

    plt.setp(g.ax_row_colors.get_xmajorticklabels(), rotation=45, ha="right")

    # Zero-sized bars exist only to give the legend one handle per class.
    for class_name, class_color in label_color_dict.items():
        g.ax_col_dendrogram.bar(
            0, 0, color=class_color, label=class_name, linewidth=0
        )

    # Build the legend once, after all handles exist (it used to be rebuilt on
    # every loop iteration).
    if label_color_dict:
        g.ax_col_dendrogram.legend(
            title=f"{label}",
            loc="lower center",
            ncol=4,
            bbox_to_anchor=(0.5, 1),
            facecolor="white",
        )

    logger.info(f"Output file: {output_file}")
    plt.savefig(output_file, bbox_inches="tight")
    plt.show()



In [None]:
import anndata

def plot_umap(embs_df, emb_dims, label, output_file, kwargs_dict, seed=0):
    """
    Compute a UMAP of the embeddings, color it by ``label``, save and show it.

    **Parameters**

    embs_df : pandas.DataFrame
        | Embedding dimensions in the first ``emb_dims`` columns; must also
        | contain the ``label`` column.
    emb_dims : int
        | Number of leading columns that hold embedding values.
    label : str
        | Column used to color the points.
    output_file : Path
        | File the figure is saved to.
    kwargs_dict : None, dict
        | Extra keyword arguments for ``sc.pl.umap``; they override the
        | default point size of 200.
    seed : int
        | Random state for the neighbor graph and the UMAP.
    """
    # Embedding matrix with string row/column names 0..n-1
    only_embs_df = embs_df.iloc[:, :emb_dims]
    n_cells, n_dims = only_embs_df.shape
    only_embs_df.index = pd.RangeIndex(0, n_cells, name=None).astype(str)
    only_embs_df.columns = pd.RangeIndex(0, n_dims, name=None).astype(str)

    adata = anndata.AnnData(
        X=only_embs_df,
        obs={"cell_id": list(only_embs_df.index), f"{label}": list(embs_df[label])},
        var={"embs": only_embs_df.columns},
    )
    sc.tl.pca(adata, svd_solver="arpack")
    sc.pp.neighbors(adata, random_state=seed)
    sc.tl.umap(adata, random_state=seed)

    sns.set(rc={"figure.figsize": (30, 30)}, font_scale=2.3)
    sns.set_style("white")

    umap_kwargs = {"size": 200}
    if kwargs_dict is not None:
        umap_kwargs.update(kwargs_dict)

    # More legend columns when there are more categories to list
    n_categories = len(set(embs_df[label]))
    if n_categories <= 14:
        legend_columns = 3
    elif n_categories <= 30:
        legend_columns = 4
    else:
        legend_columns = 5

    with plt.rc_context():
        ax = sc.pl.umap(adata, color=label, show=False, **umap_kwargs)
        # Legend centered underneath the plot
        ax.legend(
            markerscale=2,
            frameon=False,
            loc="upper center",
            bbox_to_anchor=(0.5, -0.05),
            ncol=legend_columns,
        )

        plt.title(f"Cell Type ({AGGREGATION_LEVEL})", fontsize=16, pad=20)

        plt.savefig(output_file, bbox_inches="tight")
        plt.show()


In [None]:

def plot_embs(
    embs,
    plot_style,
    output_directory,
    output_prefix,
    max_ncells,
    max_ncells_to_plot=1000,
    kwargs_dict=None,
    emb_label=["cell_type"],
    labels_to_plot=["cell_type"],
):
    """
    Plot embeddings, coloring by provided labels.

    **Parameters:**

    embs : pandas.core.frame.DataFrame
        | Pandas dataframe containing embeddings output from extract_embs.
        | The label columns are expected after the embedding columns.
    plot_style : str
        | Style of plot: "heatmap" or "umap"
    output_directory : Path
        | Path to directory where plots will be saved as pdf
    output_prefix : str
        | Prefix for output file
    max_ncells : int
        | Upper bound for max_ncells_to_plot.
    max_ncells_to_plot : None, int
        | Maximum number of cells to plot.
        | Default is 1000 cells randomly sampled from embeddings.
        | If None, will plot embeddings from all cells.
    kwargs_dict : dict
        | Dictionary of kwargs to pass to plotting function.
    emb_label : None, list
        | Label columns contained in embs; their number determines how many
        | trailing columns are not embedding dimensions.
    labels_to_plot : list
        | Labels to plot; one output file is written per label.
        | Labels missing from embs are skipped with a warning.

    **Raises:**

    ValueError
        | If plot_style is not "heatmap" or "umap", or labels_to_plot is None.

    **Examples:**

    .. code-block :: python

        >>> plot_embs(embs=embs,
        ...           plot_style="heatmap",
        ...           output_directory="path/to/output_directory",
        ...           output_prefix="output_prefix",
        ...           max_ncells=1000)

    """

    # Raise real exceptions: a bare `raise` with no active exception only
    # produced "RuntimeError: No active exception to reraise".
    if plot_style not in ["heatmap", "umap"]:
        message = "Invalid option for 'plot_style'. " "Valid options: {'heatmap','umap'}"
        logger.error(message)
        raise ValueError(message)

    if labels_to_plot is None:
        message = f"Plotting {plot_style.upper()} requires 'labels_to_plot'. "
        logger.error(message)
        raise ValueError(message)

    if max_ncells_to_plot is not None:
        if max_ncells_to_plot > max_ncells:
            max_ncells_to_plot = max_ncells
            logger.warning(
                "max_ncells_to_plot must be <= max_ncells. "
                f"Changing max_ncells_to_plot to {max_ncells}."
            )
        elif max_ncells_to_plot < max_ncells:
            # Never ask for more rows than exist; sample() raises otherwise.
            embs = embs.sample(min(max_ncells_to_plot, len(embs)), axis=0)

    # Label columns sit after the embedding dimensions.
    label_len = 0 if emb_label is None else len(emb_label)
    emb_dims = embs.shape[1] - label_len
    # Without label columns every requested label is skipped with a warning
    # instead of failing on a membership test against None.
    emb_labels = [] if emb_label is None else embs.columns[emb_dims:]

    # plot_style was validated above, so anything that is not "umap" is "heatmap"
    plot_function = plot_umap if plot_style == "umap" else plot_heatmap

    for label in labels_to_plot:
        if label not in emb_labels:
            logger.warning(
                f"Label {label} from labels_to_plot "
                f"not present in provided embeddings dataframe."
            )
            continue
        output_prefix_label = output_prefix + f"_{plot_style}_{label}"
        output_file = (
            Path(output_directory) / output_prefix_label
        ).with_suffix(".pdf")
        plot_function(embs, emb_dims, label, output_file, kwargs_dict)


In [None]:

# Plot a UMAP of the cell embeddings; the PDF is written to
# <output_dir>/embeddings with the prefix "emb_plot".
plot_embs(
    embs=embs,
    plot_style="umap",
    max_ncells_to_plot=100000,  # 10000 if DATASET == "cellnexus_blood_disease" else 1000,  # Set to None to plot all cells
    output_directory=os.path.join(output_dir, "embeddings"),
    output_prefix="emb_plot",
    kwargs_dict=default_kwargs_dict,
    max_ncells=100000,
    emb_label=["cell_type"],
    labels_to_plot=["cell_type"],
)


In [None]:
# plot heatmap of cell embeddings
# embex.plot_embs(embs=embs, 
#                 plot_style="heatmap",
#                 max_ncells_to_plot=10000 if DATASET == "cellnexus_blood_disease" else 1000,
#                 output_directory=os.path.join(output_dir, "embeddings"),
#                 output_prefix="heatmap_plot")

# Check model frozen layers

In [None]:
import torch
from transformers import BertForSequenceClassification

# Directory of the saved model that is loaded for the frozen-layer inspection.
model_path = "/hpcfs/groups/phoenix-hpc-mangiola_laboratory/haroon/geneformer/trained_foundation_models/models/30M_AGGsinglecell_6_emb256_SL2048_E2_B12_LR0.001_LSlinear_WU10000_Oadamw/final_trained_model/singlecell/"
def_freeze_layers = 6  # number of leading encoder layers frozen in the cell below
mode = "train"
# Hidden states are only requested when the mode is "eval".
output_hidden_states = mode == "eval"

# Keyword arguments forwarded verbatim to from_pretrained.
model_args = dict(
    pretrained_model_name_or_path=model_path,
    output_hidden_states=output_hidden_states,
    output_attentions=False,
    num_labels=9,
)

model = BertForSequenceClassification.from_pretrained(**model_args)


In [None]:
def count_frozen_layers(model):
    """Count the leading run of fully frozen encoder layers.

    Walks ``model.bert.encoder.layer`` in order. A layer counts as frozen when
    none of its parameters has ``requires_grad`` set. Counting stops at the
    first layer that has any trainable parameter, so frozen layers after a
    trainable one are not included. Layers that expose no parameters are
    skipped without being counted and without ending the walk.

    Args:
        model: Object exposing ``bert.encoder.layer``, a sized iterable of
            layers that each provide ``parameters()``.

    Returns:
        Tuple ``(frozen_layers, total_layers)``.
    """
    encoder_layers = model.bert.encoder.layer
    layer_total = len(encoder_layers)
    leading_frozen = 0

    for block in encoder_layers:
        weights = list(block.parameters())
        if not weights:
            # Nothing to inspect on this layer; move on to the next one.
            continue
        if any(weight.requires_grad for weight in weights):
            # Layers are assumed to be frozen sequentially from the start.
            break
        leading_frozen += 1

    return leading_frozen, layer_total

def count_trainable_params(model):
    """Tally parameter element counts, split by whether they require gradients.

    Args:
        model: Object whose ``parameters()`` yields items providing
            ``numel()`` and ``requires_grad``.

    Returns:
        Tuple ``(trainable, frozen, total)`` of element counts, where
        ``frozen`` is ``total - trainable``.
    """
    n_trainable = 0
    n_all = 0
    for weight in model.parameters():
        size = weight.numel()
        n_all += size
        if weight.requires_grad:
            n_trainable += size
    return n_trainable, n_all - n_trainable, n_all

def _report_freeze_state(heading):
    """Print ``heading`` followed by the current freeze statistics of ``model``.

    Returns the tuple (frozen_layers, total_layers, trainable, frozen, total)
    so the caller can keep the numbers for the final summary.
    """
    print(heading)
    n_frozen_layers, n_layers = count_frozen_layers(model)
    n_trainable, n_frozen, n_total = count_trainable_params(model)

    print(f"Frozen layers: {n_frozen_layers}/{n_layers}")
    print(f"Trainable parameters: {n_trainable:,}")
    print(f"Frozen parameters: {n_frozen:,}")
    print(f"Total parameters: {n_total:,}")
    return n_frozen_layers, n_layers, n_trainable, n_frozen, n_total


# Snapshot of the model as loaded, before any layer is frozen.
(
    frozen_layers_before,
    total_layers,
    trainable_before,
    frozen_before,
    total_before,
) = _report_freeze_state("=== BEFORE FREEZING ===")

# Disable gradients on every parameter of the first `def_freeze_layers` encoder layers.
if def_freeze_layers > 0:
    modules_to_freeze = model.bert.encoder.layer[:def_freeze_layers]
    for weight in (w for block in modules_to_freeze for w in block.parameters()):
        weight.requires_grad = False

# Same snapshot again, now that the layers are frozen.
(
    frozen_layers_after,
    total_layers,
    trainable_after,
    frozen_after,
    total_after,
) = _report_freeze_state("\n=== AFTER FREEZING ===")

# Difference between the two snapshots.
print("\n=== SUMMARY ===")
print(f"Successfully froze {frozen_layers_after - frozen_layers_before} additional layers")
print(f"Reduced trainable parameters by {trainable_before - trainable_after:,}")

In [None]:
def analyze_trainable_components(model):
    """Sum trainable parameter counts per model component.

    Every parameter from ``model.named_parameters()`` with ``requires_grad``
    set is assigned to the first bucket whose substring occurs in its name,
    tested in this order: 'embeddings', 'encoder.layer', 'pooler',
    'classifier'. Names matching none of them are counted under 'other'.
    Parameters that do not require gradients are ignored.

    Returns:
        Dict with keys 'embeddings', 'encoder_layers', 'pooler', 'classifier'
        and 'other' (in that order) mapping to element counts.
    """
    # (substring searched in the parameter name, bucket it is counted under)
    bucket_rules = (
        ('embeddings', 'embeddings'),
        ('encoder.layer', 'encoder_layers'),
        ('pooler', 'pooler'),
        ('classifier', 'classifier'),
    )
    totals = {bucket: 0 for _, bucket in bucket_rules}
    totals['other'] = 0

    for param_name, weight in model.named_parameters():
        if not weight.requires_grad:
            continue
        # First matching rule wins, mirroring an if/elif chain.
        bucket = next(
            (target for needle, target in bucket_rules if needle in param_name),
            'other',
        )
        totals[bucket] += weight.numel()

    return totals

# Report every component that still has trainable parameters after freezing.
trainable_components = analyze_trainable_components(model)
print("Trainable parameters breakdown:")
for component, count in trainable_components.items():
    if not count > 0:
        # Components with nothing trainable are left out of the listing.
        continue
    print(f"  {component}: {count:,} parameters")

# Create JSON Dictionaries of Split IDs

In [None]:
import json
from pathlib import Path
from datasets import load_from_disk
import numpy as np

# Paths of the saved, labelled train/test datasets and the directory that
# receives the split-ID JSON files written further below.
# NOTE(review): this rebinds `output_dir`, which the embedding-plot cell earlier
# in the notebook also reads; re-running that cell after this one would point it
# at this directory. Confirm this is intended.
train_dataset_path = '/hpcfs/groups/phoenix-hpc-mangiola_laboratory/haroon/geneformer/trained_cell_classification_models/cell_type_classification/cellnexus_cell_types/30M_metacell_16/cellnexus_cell_types_test_labeled_train.dataset'
test_dataset_path = '/hpcfs/groups/phoenix-hpc-mangiola_laboratory/haroon/geneformer/trained_cell_classification_models/cell_type_classification/cellnexus_cell_types/30M_metacell_16/cellnexus_cell_types_test_labeled_test.dataset'
output_dir = '/hpcfs/groups/phoenix-hpc-mangiola_laboratory/haroon/geneformer/datasets/cell_type_classification/cellnexus_cell_types/'

# Create output directory if it doesn't exist
Path(output_dir).mkdir(parents=True, exist_ok=True)

print("Loading datasets...")
# Load the training dataset
train_data = load_from_disk(train_dataset_path)
print(f"Train dataset loaded: {len(train_data)} samples")

# Load the test dataset
test_data = load_from_disk(test_dataset_path)
print(f"Test dataset loaded: {len(test_data)} samples")

# Fraction of the training data held out for validation (0.1 = 10%); adjust as needed.
eval_size = 0.1

print(f"\nPerforming train-validation split with eval_size={eval_size}...")
# Split the training data into train/validation parts, stratified on the
# "cell_type_backup" column, with a fixed seed so the split is repeatable.
data_dict = train_data.train_test_split(
    test_size=eval_size,
    stratify_by_column="cell_type_backup",
    seed=42,
)

# Extract the split datasets
train_split = data_dict['train']
val_split = data_dict['test']  # The 'test' half of this split is used as validation data

# The two parts must account for every training sample exactly once.
assert len(train_split) + len(val_split) == len(train_data), \
    "train/validation split sizes do not add up to the training dataset size"

print(f"Train split: {len(train_split)} samples")
print(f"Validation split: {len(val_split)} samples")


In [None]:
# Preview the sample IDs in the validation split. Only the first rows are
# read, so the cell output stays small instead of dumping the whole column.
val_split[:10]['sample_id']

In [None]:
# NOTE(review): the preceding cell already displays this column; this repeat
# looks like a leftover and can probably be deleted. Confirm before removing.
val_split['sample_id']

In [None]:

# Extract input_ids from each dataset.
# NOTE(review): each "ID" is the str() of one whole `input_ids` entry, not a
# dedicated identifier column. Confirm that whatever consumes these JSON files
# compares against the same stringified form.
print("\nExtracting input_ids...")
train_ids = [str(sample) for sample in train_split['input_ids']]
eval_ids = [str(sample) for sample in val_split['input_ids']]
test_ids = [str(sample) for sample in test_data['input_ids']]

print(f"Train IDs: {len(train_ids)}")
print(f"Eval IDs: {len(eval_ids)}")
print(f"Test IDs: {len(test_ids)}")

# Check for overlaps between the three splits.
print("\nChecking for ID overlaps...")
train_set = set(train_ids)
eval_set = set(eval_ids)
test_set = set(test_ids)

train_eval_overlap = train_set & eval_set
train_test_overlap = train_set & test_set
eval_test_overlap = eval_set & test_set

# (label used in the count line, wording used in the warning, overlapping IDs)
overlap_checks = (
    ("Train-Eval", "train and eval", train_eval_overlap),
    ("Train-Test", "train and test", train_test_overlap),
    ("Eval-Test", "eval and test", eval_test_overlap),
)

for pair_label, _, shared_ids in overlap_checks:
    print(f"{pair_label} overlap: {len(shared_ids)} IDs")

for _, pair_words, shared_ids in overlap_checks:
    if shared_ids:
        print(f"WARNING: Found {len(shared_ids)} overlapping IDs between {pair_words}!")
        print(f"Sample overlapping IDs: {list(shared_ids)[:5]}")

if not any(shared_ids for _, _, shared_ids in overlap_checks):
    print("✓ No overlapping IDs found between splits!")

# Create the split dictionaries
print("\nCreating split dictionaries...")
# Train and eval IDs combined for the final training run; built once and reused
# for both the dictionary and the summary count below.
train_val_ids = train_ids + eval_ids
total_sample_count = len(train_ids) + len(eval_ids) + len(test_ids)

train_test_id_split_dict = {
    "attr_key": "input_ids",
    "train": train_val_ids,
    "test": test_ids
}

train_valid_id_split_dict = {
    "attr_key": "input_ids",
    "train": train_ids,
    "eval": eval_ids
}

# Save the dictionaries as JSON
print("Saving split dictionaries as JSON...")
train_test_path = Path(output_dir) / "train_test_id_split_dict.json"
train_valid_path = Path(output_dir) / "train_valid_id_split_dict.json"

for json_path, split_dict in (
    (train_test_path, train_test_id_split_dict),
    (train_valid_path, train_valid_id_split_dict),
):
    with open(json_path, 'w') as f:
        json.dump(split_dict, f, indent=2)

print(f"✓ Train-test split dictionary saved to: {train_test_path}")
print(f"✓ Train-validation split dictionary saved to: {train_valid_path}")

# Print summary statistics
print("\n" + "="*50)
print("SUMMARY")
print("="*50)
print(f"Total samples processed: {total_sample_count}")
print(f"Train samples: {len(train_ids)}")
print(f"Validation samples: {len(eval_ids)}")
print(f"Test samples: {len(test_ids)}")
print(f"Train+Val for final training: {len(train_val_ids)}")
print("\nSplit dictionaries saved successfully!")

# Optional: Save a human-readable summary
# NOTE(review): the stratify column and seed below are written as literals;
# keep them in sync with the cell that performs the split.
summary_path = Path(output_dir) / "split_summary.txt"
with open(summary_path, 'w') as f:
    f.writelines([
        "Dataset Split Summary\n",
        "===================\n\n",
        "Source datasets:\n",
        f"- Train dataset: {train_dataset_path}\n",
        f"- Test dataset: {test_dataset_path}\n\n",
        "Split configuration:\n",
        f"- Eval size: {eval_size}\n",
        "- Stratify by: cell_type_backup\n",
        "- Random seed: 42\n\n",
        "Final splits:\n",
        f"- Train samples: {len(train_ids)}\n",
        f"- Validation samples: {len(eval_ids)}\n",
        f"- Test samples: {len(test_ids)}\n",
        f"- Total samples: {total_sample_count}\n\n",
        "ID Overlaps:\n",
        f"- Train-Eval: {len(train_eval_overlap)}\n",
        f"- Train-Test: {len(train_test_overlap)}\n",
        f"- Eval-Test: {len(eval_test_overlap)}\n",
    ])

print(f"✓ Summary saved to: {summary_path}")