In [1]:
import os

# Remember (and show) where the kernel was started from.
start_dir = os.getcwd()
print("Current directory:", start_dir)

# Site-specific absolute path of the Geneformer package directory
# (replace with your actual path when running elsewhere).
package_dir = '/hpcfs/groups/phoenix-hpc-mangiola_laboratory/haroon/full_frozen_geneformer/Geneformer/geneformer'
os.chdir(package_dir)

# Echo the directory actually in effect; the printed path can differ from
# `package_dir` if the filesystem resolves it to another mount point.
print("New directory:", os.getcwd())

Current directory: /scratchdata1/users/a1841503/Geneformer/Jupyter
New directory: /scratchdata1/groups/phoenix-hpc-mangiola_laboratory/haroon/full_frozen_geneformer/Geneformer/geneformer


In [2]:
import os
import sys

# Work from the repository root so that `geneformer` is importable as a package.
repo_root = "/scratchdata1/groups/phoenix-hpc-mangiola_laboratory/haroon/full_frozen_geneformer/Geneformer"
os.chdir(repo_root)

# Put the (resolved) working directory at the front of the import path,
# but only once so that re-running this cell does not add duplicates.
working_dir = os.getcwd()
if working_dir not in sys.path:
    sys.path.insert(0, working_dir)

In [3]:
from geneformer.finetuner import FineTuner

In [4]:
from geneformer.finetuner_utils import get_train_valid_test_splits

In [5]:
import json

In [6]:
# Sanity check: print the working directory the shell sees after the os.chdir calls in the earlier cells.
!pwd

/scratchdata1/groups/phoenix-hpc-mangiola_laboratory/haroon/full_frozen_geneformer/Geneformer


In [None]:
# Sweep definition: which tasks/datasets to fine-tune on and which metacell
# aggregation levels to run for each (task, dataset) pair.
task_config = {
    "tasks": {
        "disease_classification": [
            "genecorpus_heart_disease",
            "cellnexus_blood_disease",
            "cellnexus_covid_disease",
        ],
        # "dosage_sensitivity": ["genecorpus_dosage_sensitivity"]
    },
    "aggregation_levels": [
        # "metacell_2",
        # "metacell_4",
        # "metacell_8",
        # "metacell_16",
        "metacell_32",
        "metacell_64",
        "metacell_128",
    ],
}

# aggregation_level="metacell_32"
# task="dosage_sensitivity"
# dataset="genecorpus_dosage_sensitivity"
model_version = "V1"
base_dir = "/hpcfs/groups/phoenix-hpc-mangiola_laboratory/haroon/geneformer"
model_variant = "30M"
crossval_splits = 1  # 1 or 5
freeze_num_encoder_layers = 6
freeze_entire_model = True

# Fail early with a clear message: with zero splits nothing would be trained
# and `finetuner` (needed for the metrics path below) would never be created.
if crossval_splits < 1:
    raise ValueError(f"crossval_splits must be >= 1, got {crossval_splits}")

# Loop over tasks
for task, datasets in task_config["tasks"].items():
    # Loop over each dataset associated with the task
    for dataset in datasets:
        # Loop over aggregation levels
        for aggregation_level in task_config["aggregation_levels"]:
            print(f"Running task={task}, dataset={dataset}, aggregation_level={aggregation_level}")

            # Collected once per (task, dataset, aggregation level) so that the
            # metrics of EVERY cross-validation split end up in metrics.json.
            # (Creating this dict inside the split loop would discard all but
            # the last split when crossval_splits > 1.)
            crossval_split_metrics = {}

            for crossval_split in range(1, crossval_splits + 1):

                # Built fresh for every run so each call receives its own dict.
                training_args = {
                    "num_train_epochs": 10,
                    "learning_rate": 0.000804,
                    "lr_scheduler_type": "polynomial",
                    "warmup_steps": 1812,
                    "weight_decay": 0.258828,
                    "per_device_train_batch_size": 128,
                    "seed": 73,
                    "evaluation_strategy": "epoch",        # Evaluate every epoch
                    "save_strategy": "epoch",              # Save checkpoint every epoch
                    "metric_for_best_model": "eval_loss",  # Metric to determine "best" model # Doc: https://huggingface.co/transformers/v3.5.1/main_classes/trainer.html#:~:text=after%20each%20evaluation.-,metric_for_best_model,-(str%2C
                    "greater_is_better": False,            # For loss, lower is better
                    "load_best_model_at_end": True,        # KEY: Load best model at the end
                    "save_total_limit": 3,                 # Keep only 3 best checkpoints
                    # "logging_dir": os.path.normpath("D:/geneformer_finetuning/trained_cell_classification_models/disease_classification/genecorpus_heart_disease/30M_metacell_8/250623_geneformer_cellClassifier_genecorpus_heart_disease_test/ksplit1/runs"),
                }

                # NOTE(review): none of these arguments depend on `crossval_split`;
                # the call could be hoisted out of this loop if the helper is
                # deterministic — confirm before moving it.
                (
                    input_data_file,
                    cell_state_dict,
                    filter_data_dict,
                    train_test_id_split_dict,
                    train_valid_id_split_dict,
                ) = get_train_valid_test_splits(
                    TASK=task,
                    DATASET=dataset,
                    MODEL_VARIANT=model_variant,
                    DATASET_PATH=os.path.join(base_dir, "datasets", task, dataset),
                    CROSSVAL_SPLITS=crossval_splits,
                )

                finetuner = FineTuner(
                    base_dir=base_dir,
                    aggregation_level=aggregation_level,
                    model_variant=model_variant,
                    task=task,
                    dataset=dataset,
                )

                # A single split is labelled "test"; otherwise one prefix per fold.
                output_prefix = "test" if crossval_splits == 1 else "ksplit" + str(crossval_split)

                all_metrics = finetuner.finetune_model(
                    training_args=training_args,
                    cell_state_dict=cell_state_dict,
                    filter_data_dict=filter_data_dict,
                    input_data_file=input_data_file,
                    output_prefix=output_prefix,
                    train_test_id_split_dict=train_test_id_split_dict,
                    train_valid_id_split_dict=train_valid_id_split_dict,
                    num_crossval_splits=1 if task != "dosage_sensitivity" else crossval_splits,
                    freeze_num_encoder_layers=freeze_num_encoder_layers,
                    freeze_entire_model=freeze_entire_model,
                )

                crossval_split_metrics[str(crossval_split)] = all_metrics

            # One metrics.json per (task, dataset, aggregation level), written
            # into the output directory reported by the fine-tuner.
            metrics_path = os.path.join(finetuner.output_dir, "metrics.json")

            # Save the metrics to that path
            with open(metrics_path, "w") as f:
                json.dump(crossval_split_metrics, f, indent=4)

Running task=disease_classification, dataset=genecorpus_heart_disease, aggregation_level=metacell_32
Cuda available: True
Using 4 GPU(s)
GENE_MEDIAN_FILE: /hpcfs/groups/phoenix-hpc-mangiola_laboratory/haroon/geneformer/Geneformer/geneformer/gene_dictionaries_30m/gene_median_dictionary_gc30M.pkl
TOKEN_DICTIONARY_FILE: /hpcfs/groups/phoenix-hpc-mangiola_laboratory/haroon/geneformer/Geneformer/geneformer/gene_dictionaries_30m/token_dictionary_gc30M.pkl
ENSEMBL_MAPPING_FILE: /hpcfs/groups/phoenix-hpc-mangiola_laboratory/haroon/geneformer/Geneformer/geneformer/gene_dictionaries_30m/ensembl_mapping_dict_gc30M.pkl
Output directory: /hpcfs/groups/phoenix-hpc-mangiola_laboratory/haroon/geneformer/trained_cell_classification_models/disease_classification/genecorpus_heart_disease/30M_metacell_32
Classifier type: cell
✓ Using final trained model: /hpcfs/groups/phoenix-hpc-mangiola_laboratory/haroon/geneformer/trained_foundation_models/models/30M_AGGmetacell_32_6_emb256_SL2048_E2_B12_LR0.001_LSline

mkdir: cannot create directory ‘/hpcfs/groups/phoenix-hpc-mangiola_laboratory/haroon/geneformer/trained_cell_classification_models/disease_classification/genecorpus_heart_disease/30M_metacell_32/_geneformer_cellClassifier_test/’: File exists


  0%|          | 0/1 [00:00<?, ?it/s]

mkdir: cannot create directory ‘/hpcfs/groups/phoenix-hpc-mangiola_laboratory/haroon/geneformer/trained_cell_classification_models/disease_classification/genecorpus_heart_disease/30M_metacell_32/_geneformer_cellClassifier_test/ksplit1’: File exists
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at /hpcfs/groups/phoenix-hpc-mangiola_laboratory/haroon/geneformer/trained_foundation_models/models/30M_AGGmetacell_32_6_emb256_SL2048_E2_B12_LR0.001_LSlinear_WU10000_Oadamw/final_trained_model/metacell_32 and are newly initialized: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


****** Validation split: 1/1 ******

Total parameters: 10,263,555
Trainable parameters before freezing model: 10,263,555
Non-trainable parameters before freezing model: 0
Trainable parameters after freezing encoder layers: 7,100,931
Non-trainable parameters after freezing encoder layers: 3,162,624
Freezing parameter: bert.embeddings.word_embeddings.weight
Freezing parameter: bert.embeddings.position_embeddings.weight
Freezing parameter: bert.embeddings.token_type_embeddings.weight
Freezing parameter: bert.embeddings.LayerNorm.weight
Freezing parameter: bert.embeddings.LayerNorm.bias
Freezing parameter: bert.encoder.layer.0.attention.self.query.weight
Freezing parameter: bert.encoder.layer.0.attention.self.query.bias
Freezing parameter: bert.encoder.layer.0.attention.self.key.weight
Freezing parameter: bert.encoder.layer.0.attention.self.key.bias
Freezing parameter: bert.encoder.layer.0.attention.self.value.weight
Freezing parameter: bert.encoder.layer.0.attention.self.value.bias
Freezi

  batch = {k: torch.tensor(v.clone().detach(), dtype=torch.int64) if isinstance(v, torch.Tensor) else torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}


Epoch,Training Loss,Validation Loss


KeyboardInterrupt: 

In [11]:
import os
from datasets import Dataset
import pandas as pd
from collections import Counter

# Read the HuggingFace dataset
dataset_path = "/hpcfs/groups/phoenix-hpc-mangiola_laboratory/haroon/geneformer/datasets/disease_classification/genecorpus_heart_disease/human_dcm_hcm_nf.dataset/"

print("Loading dataset...")
try:
    # Only the load itself is guarded: a failure here is a path/format problem,
    # which is what the troubleshooting hints below are about.
    dataset = Dataset.load_from_disk(dataset_path)
except Exception as e:
    print(f"Error loading dataset: {str(e)}")
    print()
    print("Troubleshooting steps:")
    print("1. Check if the path exists")
    print("2. Verify the dataset format")
    print("3. Check column names")

    # Try to list directory contents
    try:
        if os.path.exists(dataset_path):
            contents = os.listdir(dataset_path)
            print(f"Directory contents: {contents}")
        else:
            print("Dataset path does not exist")
    except Exception as e2:
        print(f"Cannot access directory: {str(e2)}")
else:
    # Runs only when the load succeeded. Errors raised by the analysis below
    # are intentionally NOT caught, so they surface with their real traceback
    # instead of being reported as a dataset-loading failure.
    print("Dataset loaded successfully!")
    print(f"Dataset length: {len(dataset)}")
    print(f"Dataset features: {list(dataset.features.keys())}")
    print()

    # Convert to pandas for easier analysis
    df = dataset.to_pandas()

    # Check if required columns exist
    if 'individual' not in df.columns or 'disease' not in df.columns:
        print("Available columns:", df.columns.tolist())
        print("Please check if 'individual' and 'disease' columns exist with these exact names")
    else:
        print("=== ANALYSIS RESULTS ===")
        print()

        # 1. Number of unique individuals
        unique_individuals = df['individual'].nunique()
        print(f"Number of unique individuals: {unique_individuals}")
        print()

        # 2. List all unique individuals
        individual_ids = df['individual'].unique()
        print("Unique individual IDs:")
        for i, ind_id in enumerate(sorted(individual_ids), 1):
            print(f"  {i}. {ind_id}")
        print()

        # 3. Overall disease distribution
        disease_counts = df['disease'].value_counts()
        print("Overall disease distribution:")
        for disease, count in disease_counts.items():
            print(f"  {disease}: {count} samples")
        print()

        # 4. Disease distribution per individual
        print("Disease distribution per individual:")
        print("=" * 50)

        for ind_id in sorted(individual_ids):
            individual_data = df[df['individual'] == ind_id]
            disease_dist = individual_data['disease'].value_counts()
            total_samples = len(individual_data)

            print(f"\nIndividual ID: {ind_id}")
            print(f"Total samples: {total_samples}")
            print("Disease distribution:")
            for disease, count in disease_dist.items():
                percentage = (count / total_samples) * 100
                print(f"  - {disease}: {count} samples ({percentage:.1f}%)")

        # 5. Summary statistics
        print("\n" + "=" * 50)
        print("SUMMARY STATISTICS")
        print("=" * 50)

        # Samples per individual
        samples_per_individual = df.groupby('individual').size()
        print("\nSamples per individual:")
        print(f"  Min: {samples_per_individual.min()}")
        print(f"  Max: {samples_per_individual.max()}")
        print(f"  Mean: {samples_per_individual.mean():.1f}")
        print(f"  Median: {samples_per_individual.median():.1f}")

        # Cross-tabulation
        print("\nCross-tabulation (Individual vs Disease):")
        crosstab = pd.crosstab(df['individual'], df['disease'], margins=True)
        print(crosstab)

        # 6. Check for any individuals with multiple diseases
        print("\nIndividuals with multiple disease types:")
        multi_disease_individuals = df.groupby('individual')['disease'].nunique()
        multi_disease = multi_disease_individuals[multi_disease_individuals > 1]

        if len(multi_disease) > 0:
            for ind_id, num_diseases in multi_disease.items():
                diseases = df[df['individual'] == ind_id]['disease'].unique()
                print(f"  {ind_id}: {num_diseases} diseases - {list(diseases)}")
        else:
            print("  No individuals have multiple disease types")

Loading dataset...
Dataset loaded successfully!
Dataset length: 579159
Dataset features: ['input_ids', 'length', 'cell_type', 'individual', 'age', 'sex', 'disease', 'lvef']

=== ANALYSIS RESULTS ===

Number of unique individuals: 42

Unique individual IDs:
  1. 1290
  2. 1300
  3. 1304
  4. 1358
  5. 1371
  6. 1422
  7. 1425
  8. 1430
  9. 1437
  10. 1447
  11. 1462
  12. 1472
  13. 1479
  14. 1504
  15. 1508
  16. 1510
  17. 1515
  18. 1516
  19. 1539
  20. 1540
  21. 1547
  22. 1549
  23. 1558
  24. 1561
  25. 1582
  26. 1600
  27. 1602
  28. 1603
  29. 1606
  30. 1610
  31. 1617
  32. 1622
  33. 1630
  34. 1631
  35. 1678
  36. 1685
  37. 1702
  38. 1707
  39. 1718
  40. 1722
  41. 1726
  42. 1735

Overall disease distribution:
  hcm: 230652 samples
  nf: 182317 samples
  dcm: 166190 samples

Disease distribution per individual:

Individual ID: 1290
Total samples: 6996
Disease distribution:
  - dcm: 6996 samples (100.0%)

Individual ID: 1300
Total samples: 16426
Disease distribution