In [None]:
import numpy as np
import scanpy as sc
from sklearn.model_selection import train_test_split
from geneformer import TranscriptomeTokenizer
import scipy.sparse as sp
import datetime
from geneformer import Classifier
from collections import Counter
import os
import matplotlib.pyplot as plt
import pandas as pd
import urllib.request
import pickle
from huggingface_hub import snapshot_download
from utils import plot_confusion, plot_cell_type_distribution, get_high_fraction_celltype_indices

# Fetch the pretrained Geneformer weights from the Hugging Face Hub.
# allow_patterns restricts the download to the `gf-6L-30M-i2048` subfolder
# rather than the whole repository; snapshot_download places it under local_dir.
repo_id = "ctheodoris/Geneformer"
local_dir = "."
model_subdir = "gf-6L-30M-i2048"

snapshot_download(repo_id, local_dir=local_dir, allow_patterns=[f"{model_subdir}/*"])

print(f"Model files for {model_subdir} downloaded to: {local_dir}")

# Download the four 30M-series gene dictionary pickles that the tokenizer reads
# from ./gene_dictionaries_30m. Files already present are skipped.
base_url = "https://huggingface.co/ctheodoris/Geneformer/resolve/main/geneformer/gene_dictionaries_30m/"
files = [
    "ensembl_mapping_dict_gc30M.pkl",
    "gene_median_dictionary_gc30M.pkl",
    "gene_name_id_dict_gc30M.pkl",
    "token_dictionary_gc30M.pkl",
]

# Dedicated name so it is not confused with the run-specific `output_dir` defined later.
dict_dir = "./gene_dictionaries_30m"
os.makedirs(dict_dir, exist_ok=True)

for fname in files:
    target_path = os.path.join(dict_dir, fname)
    if os.path.exists(target_path):
        print(f"{fname} already exists.")
        continue
    print(f"Downloading {fname}...")
    # Write to a temporary name and rename only after success, so an interrupted
    # download never leaves a truncated file that the exists() check would accept
    # on the next run.
    partial_path = target_path + ".part"
    urllib.request.urlretrieve(base_url + fname, partial_path)
    os.replace(partial_path, target_path)
    print(f"Downloaded {fname}")

# Load the preprocessed single-cell data. np.load returns an object array that
# wraps the stored Python object (presumably a dict); .ravel()[0] unwraps it.
# NOTE(review): allow_pickle=True can execute arbitrary code on load — only use
# it with trusted files.
cell_file = "data/cells.npy"
cells = np.load(cell_file, allow_pickle=True).ravel()[0]

# Extract data. cells["UMI"] is presumably a scipy sparse matrix (n_cells x n_genes);
# keep it as CSR until after filtering so only the retained rows are densified.
umi_counts = sp.csr_matrix(cells["UMI"])
gene_names = cells["gene_ids"]  # one gene identifier per column
cell_types = cells["classes"]  # cell-type label per cell (n_cells,)

# Drop rare cell types. The exact keep rule lives in
# utils.get_high_fraction_celltype_indices; this is the fraction threshold it receives.
MIN_CELLTYPE_FRACTION = 0.05
plot_cell_type_distribution(cell_types)
high_fraction_indices = get_high_fraction_celltype_indices(cell_types, MIN_CELLTYPE_FRACTION)
expressions = umi_counts[np.asarray(high_fraction_indices)].toarray()  # dense (n_kept x n_genes)
cell_types = cell_types[high_fraction_indices]
assert expressions.shape == (len(cell_types), len(gene_names)), "expression matrix / labels / genes mismatch"

# Fraction of the (filtered) cells kept for fine-tuning. Sampling is stratified
# by cell type so class proportions are preserved.
SUBSAMPLE_FRACTION = 0.01  # 1% of cells; raise for a larger fine-tuning set

print(f"Original dataset size: {len(cell_types)}")

# train_test_split is used only for its stratified sampling: the "test" side is
# the subsample that is kept, the rest is discarded.
_, subsample_indices = train_test_split(
    np.arange(len(cell_types)),
    test_size=SUBSAMPLE_FRACTION,
    stratify=cell_types,
    random_state=42,  # reproducible subsample
)

expressions = expressions[subsample_indices, :]
cell_types = cell_types[subsample_indices]
assert expressions.shape[0] == len(cell_types)
print(f"Subsampled dataset size: {len(cell_types)}")

plot_cell_type_distribution(cell_types)

# Wrap the subsampled data in an AnnData object and save it for the tokenizer.
# X is stored as CSR from the start. np.asarray(...).ravel() keeps the per-cell
# total a 1-D vector whether X is dense or sparse (sparse .sum returns a 2-D matrix).
adata = sc.AnnData(X=sp.csr_matrix(expressions))
adata.obs["cell_types"] = cell_types
adata.var_names = gene_names
adata.var["ensembl_id"] = gene_names  # assumes cells["gene_ids"] are Ensembl ids — TODO confirm
adata.obs["n_counts"] = np.asarray(adata.X.sum(axis=1)).ravel()  # total count per cell
# Default obs_names are the string row positions; the train/eval/test split keys on them.
adata.obs["cell_id"] = adata.obs_names.values

# NOTE: the tokenizer is pointed at the whole ./data directory, so keep only the
# intended .h5ad file(s) there.
adata.write_h5ad("data/adata.h5ad")

# Tokenize the h5ad data in ./data into ./tokenized_data (read later as
# tokenized_data/my_dataset.dataset). custom_attr_name_dict keeps the
# cell_types and cell_id obs columns so they can serve as labels and split keys.
# TranscriptomeTokenizer is already imported in the first cell.
gene_dict_dir = "./gene_dictionaries_30m"

tokenizer = TranscriptomeTokenizer(
    custom_attr_name_dict={"cell_types": "cell_types", "cell_id": "cell_id"},
    model_input_size=2048,  # For 30M model series
    special_token=False,  # 30M models require this to be False
    gene_median_file=os.path.join(gene_dict_dir, "gene_median_dictionary_gc30M.pkl"),
    token_dictionary_file=os.path.join(gene_dict_dir, "token_dictionary_gc30M.pkl"),
    gene_mapping_file=os.path.join(gene_dict_dir, "ensembl_mapping_dict_gc30M.pkl"),
)

tokenizer.tokenize_data(
    data_directory="./data",
    output_directory="./tokenized_data",
    output_prefix="my_dataset",
    file_format="h5ad",
    use_generator=False,
)

# Save an integer-id -> cell-type mapping next to the tokenized data.
# Ids follow the order in which cell types first appear in adata.obs.
# NOTE(review): cc.prepare_data writes its own {output_prefix}_id_class_dict.pkl;
# this hand-built mapping is not guaranteed to match the integer labels it assigns.
cell_types_v2 = list(adata.obs["cell_types"].unique())
id_class_dict = dict(enumerate(cell_types_v2))

with open("./tokenized_data/my_dataset_id_class_dict.pkl", "wb") as fh:
    pickle.dump(id_class_dict, fh)

# Timestamp this run: datestamp is YYMMDDHHMMSS, datestamp_min is the date alone (YYMMDD).
current_date = datetime.datetime.now()
datestamp = current_date.strftime("%y%m%d%H%M%S")
datestamp_min = current_date.strftime("%y%m%d")

# Per-run folder that receives the prepared datasets and evaluation outputs.
output_prefix = "cm_classifier_test"
output_dir = f"output_directory/{datestamp}"
os.makedirs(output_dir, exist_ok=True)

# Fine-tuning hyperparameters (keys follow Hugging Face TrainingArguments names).
training_args = {
    "num_train_epochs": 5,
    "learning_rate": 0.000804,
    "lr_scheduler_type": "polynomial",
    "warmup_steps": 1812,
    "weight_decay": 0.258828,
    "per_device_train_batch_size": 12,
    "seed": 73,
}

# Cell-level classifier predicting every value of the "cell_types" attribute.
cc = Classifier(
    classifier="cell",
    cell_state_dict={"state_key": "cell_types", "states": "all"},
    filter_data=None,  # None = fine-tune with all input data
    training_args=training_args,
    max_ncells=None,
    freeze_layers=2,  # freeze the FIRST 2 encoder layers; the rest are fine-tuned
    num_crossval_splits=1,  # a single train/eval split, no cross-validation
    forward_batch_size=200,
    nproc=1,
)

# 70/15/15 split stratified by cell type: hold out 30% of the cells, then split
# that hold-out evenly into validation and test.
train_indices, temp_indices = train_test_split(
    np.arange(len(cell_types)), test_size=0.3, random_state=42, stratify=cell_types
)

cell_types_temp = cell_types[temp_indices]
eval_indices, test_indices = train_test_split(
    temp_indices, test_size=0.5, random_state=42, stratify=cell_types_temp
)

# Translate positional indices into the cell_id values that the tokenized dataset
# carries; positions are only meaningful if adata rows line up with cell_types.
assert len(adata) == len(cell_types), "adata and cell_types are out of sync"
cell_ids = adata.obs["cell_id"]
assert cell_ids.is_unique, "cell_id must uniquely identify cells"
train_ids = cell_ids.iloc[train_indices].tolist()
eval_ids = cell_ids.iloc[eval_indices].tolist()
test_ids = cell_ids.iloc[test_indices].tolist()

# The id lists become split_id_dicts, so they must be disjoint and exhaustive.
train_set, eval_set, test_set = set(train_ids), set(eval_ids), set(test_ids)
assert not (train_set & eval_set or train_set & test_set or eval_set & test_set), "splits overlap"
assert len(train_ids) + len(eval_ids) + len(test_ids) == len(cell_types)

# Output the sizes
print(f"Total samples: {len(cell_types)}")
print(f"Training samples: {len(train_ids)}")
print(f"Validation samples: {len(eval_ids)}")
print(f"Test samples: {len(test_ids)}")

# Verify that all classes are present in each split
def print_class_distribution(indices, split_name, labels=None):
    """Print per-class counts for one data split.

    Args:
        indices: Positions (or a boolean mask) into ``labels`` selecting the split.
        split_name: Name shown in the header line, e.g. ``'training'``.
        labels: Array of class labels to index. Defaults to the notebook-global
            ``cell_types`` so existing two-argument calls keep working.

    Classes are printed in order of first appearance within the split.
    """
    if labels is None:
        labels = cell_types  # backward-compatible fallback to the notebook global
    class_counts = Counter(labels[indices])
    print(f"Class distribution in {split_name} set:")
    for cls, count in class_counts.items():
        print(f"  {cls}: {count}")

# Report class balance for each split to confirm stratification kept every class.
for split_indices, split_label in (
    (train_indices, "training"),
    (eval_indices, "validation"),
    (test_indices, "test"),
):
    print_class_distribution(split_indices, split_label)


In [None]:

from datasets import load_from_disk

# Build the labeled datasets. Train = train + eval cells (eval is separated again
# by cc.validate via its own split dict); test = held-out cells. prepare_data
# writes {output_prefix}_labeled_train.dataset, _labeled_test.dataset and
# _id_class_dict.pkl into output_dir. It only needs to run once.
train_test_id_split_dict = {
    "attr_key": "cell_id",
    "train": train_ids + eval_ids,
    "test": test_ids,
}

cc.prepare_data(
    input_data_file="tokenized_data/my_dataset.dataset",
    output_directory=output_dir,
    output_prefix=output_prefix,
    split_id_dict=train_test_id_split_dict,
)

# Reload the labeled training set that prepare_data wrote and check it has labels.
labeled_dataset_path = f"{output_dir}/{output_prefix}_labeled_train.dataset"
labeled_dataset = load_from_disk(labeled_dataset_path)

print("Columns in labeled dataset:", labeled_dataset.column_names)
assert 'label' in labeled_dataset.column_names, "The 'label' column is missing in the labeled dataset."

print(f"Labeled dataset saved at {labeled_dataset_path}")


In [None]:
# Hyperparameter search + validation on the train/eval split. The eval cells are
# part of the labeled train dataset and are separated here by cell_id.
train_valid_id_split_dict = {
    "attr_key": "cell_id",
    "train": train_ids,
    "eval": eval_ids,
}

test_number = 5
output_directory = os.path.abspath(f"./results/{test_number}")
model_directory = os.path.abspath("gf-6L-30M-i2048")
prepared_input_data_file = os.path.abspath(f"{output_dir}/{output_prefix}_labeled_train.dataset")
# Use the class mapping that prepare_data wrote alongside the labeled datasets.
# It should match their integer labels, and evaluate_saved_model / plot_predictions
# use the same file. The hand-built ./tokenized_data/my_dataset_id_class_dict.pkl
# enumerates classes in obs order, which need not agree with those labels.
id_class_dict_file = os.path.abspath(f"{output_dir}/{output_prefix}_id_class_dict.pkl")
os.makedirs(output_directory, exist_ok=True)

# Validate using the labeled dataset
all_metrics = cc.validate(
    model_directory=model_directory,
    prepared_input_data_file=prepared_input_data_file,
    id_class_dict_file=id_class_dict_file,
    output_directory=output_directory,
    output_prefix="my_fine_tuned_model",
    split_id_dict=train_valid_id_split_dict,
    n_hyperopt_trials=10,
)

In [None]:
import glob
import os

from ray.tune import ExperimentAnalysis

# Locate the most recent Ray Tune experiment written by cc.validate instead of
# hard-coding a timestamped directory, which breaks on every re-run.
experiment_dirs = glob.glob(
    os.path.join(
        output_directory,
        "*_geneformer_cellClassifier_my_fine_tuned_model",
        "ksplit1",
        "_objective_*",
    )
)
if not experiment_dirs:
    raise FileNotFoundError(f"No Ray Tune experiment directory found under {output_directory}")
result_dir = os.path.abspath(max(experiment_dirs, key=os.path.getmtime))
print("Using experiment directory:", result_dir)

# Load experiment analysis
analysis = ExperimentAnalysis(result_dir)

# Get the best trial based on `eval_macro_f1` (maximize)
best_trial = analysis.get_best_trial(metric="eval_macro_f1", mode="max")
best_config = analysis.get_best_config(metric="eval_macro_f1", mode="max")
best_checkpoint = analysis.get_best_checkpoint(best_trial, metric="eval_macro_f1", mode="max")

print("Best trial hyperparameters:", best_config)
print("Best checkpoint path:", best_checkpoint)

# The Ray checkpoint folder wraps a "checkpoint-<step>" model folder whose step
# suffix changes with dataset size and training settings, so discover it rather
# than hard-coding it. If several exist, take the highest step (verify this is
# the intended one).
hf_checkpoints = glob.glob(os.path.join(best_checkpoint.path, "checkpoint-*"))
if not hf_checkpoints:
    raise FileNotFoundError(f"No checkpoint-* folder found in {best_checkpoint.path}")
checkpoint_files_path = max(hf_checkpoints, key=lambda p: int(p.rsplit("-", 1)[-1]))
print(checkpoint_files_path)

# Rebuild the classifier with the best hyperparameters for evaluation.
cc = Classifier(
    classifier="cell",
    cell_state_dict={"state_key": "cell_types", "states": "all"},
    forward_batch_size=200,
    nproc=1,
    training_args=best_config,
)

# Evaluate the saved model on the held-out test set
all_metrics = cc.evaluate_saved_model(
    model_directory=os.path.abspath(checkpoint_files_path),
    id_class_dict_file=f"{output_dir}/{output_prefix}_id_class_dict.pkl",
    test_data_file=f"{output_dir}/{output_prefix}_labeled_test.dataset",
    output_directory=output_dir,
    output_prefix=output_prefix,
    predict=True,  # Set to False if predictions are not required
)


In [None]:
# Dump each evaluation metric as "name:value", each followed by a blank line.
for metric_name, metric_value in all_metrics.items():
    print(f"{metric_name}:{metric_value}\n")

In [None]:
# Draw the confusion matrix for the test-set evaluation. Presumably
# utils.plot_confusion reads the confusion-matrix entry of the all_metrics dict
# returned by cc.evaluate_saved_model — confirm against utils.
plot_confusion(all_metrics)

In [None]:
# Plot the saved test-set predictions (presumably written by evaluate_saved_model
# with predict=True), with classes named via prepare_data's id -> class mapping.
predictions_file = os.path.join(output_dir, f"{output_prefix}_pred_dict.pkl")
id_class_dict_file = os.path.join(output_dir, f"{output_prefix}_id_class_dict.pkl")

# Fail early with a clear message if either input is missing.
for required_path, missing_message in (
    (predictions_file, f"Predictions file not found: {predictions_file}"),
    (id_class_dict_file, f"ID-Class Dictionary file not found: {id_class_dict_file}"),
):
    if not os.path.exists(required_path):
        raise FileNotFoundError(missing_message)

cc.plot_predictions(
    predictions_file=predictions_file,
    id_class_dict_file=id_class_dict_file,
    title="cell type",
    output_directory=output_dir,
    output_prefix=output_prefix,
)


In [1]:
import datetime

In [7]:
datetime.datetime.now().strftime("%y%m%d")  # Format: YYMMDD (e.g. '241120')


'241120'