In [2]:
import sys
from pathlib import Path

# Make the project root importable. The root is two directory levels above the
# hard-coded tutorial folder below (presumably so that the later `models.*`
# imports resolve -- confirm against the repository layout).
project_root = Path("/home/lxz/scmamba/KCellFM_tutorial/novel_cell_classification_bert").parent.parent
# Script alternative: derive the root from this file's location instead of a fixed path.
# project_root = Path(__file__).parent.parent
sys.path.append(str(project_root))

In [3]:
import os
import json
import pickle
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import autocast
from tqdm import tqdm
from sklearn.metrics import accuracy_score, classification_report, f1_score
from sklearn.metrics.pairwise import cosine_similarity
from scipy.stats import spearmanr
from scipy.spatial.distance import cdist
from scipy.stats import rankdata
from transformers import AutoModel, AutoTokenizer, BertModel, BertConfig
from models.model import MambaModel
from models.gene_tokenizer import GeneVocab

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
# configuration parameters
class Config:
    """Central settings for the novel-cell classification notebook.

    Every value is a class attribute and the class object itself (not an instance)
    is handed to the extractors. ``vocab`` is loaded from disk while this class body
    executes, so defining ``Config`` already requires ``gene_vocab_path`` to exist.
    """

    # Data path
    # NOTE: the next three paths are not referenced by any cell visible in this notebook.
    pretrain_subset_path = "/mnt/HHD16T/DATA/lxz/sctab/merlin_cxg_2023_05_15_sf-log1p/novel_cell_clssification/train_data_subset_0_01.h5ad"
    var_parquet_path = "/mnt/HHD16T/DATA/lxz/sctab/merlin_cxg_2023_05_15_sf-log1p/var.parquet"
    cell_type_parquet_path = "/mnt/HHD16T/DATA/lxz/sctab/merlin_cxg_2023_05_15_sf-log1p/categorical_lookup/cell_type.parquet"
    # Directory that is globbed for the novel-cell ``*.parquet`` files.
    converted_data_dir = Path(
        "/mnt/HHD16T/DATA/lxz/sctab/merlin_cxg_2023_05_15_sf-log1p/novel_cell_clssification/scCello_ood_celltype_data1/filtered_data_10_percent")
    json_relationship_path = Path("../data/celltype_relationship.json")

    # Text file of ontology triples; only lines containing " is_a " are used.
    triples_ontology_path = "/home/lxz/scmamba/novel_cell_classification_bert/data/triples.txt"

    # Output path
    # NOTE: the four relative "../data/..." paths are not referenced by any visible cell.
    ontology_graph_path = "../data/cell_ontology_graph.json"
    id_to_node_path = "../data/cell_id_to_node_id_mapping.json"
    id_to_name_path = "../data/cell_id_to_cell_name_mapping.json"
    cell_type_repr_path = "../data/cell_type_representations_mamba.json"
    # Pickle cache of (embedding, cell_type, label) tuples for the novel cells.
    novel_cell_emb_path = "/mnt/HHD16T/DATA/lxz/sctab/merlin_cxg_2023_05_15_sf-log1p/novel_cell_clssification/data_hard_disk/mamba_bert/novel_cell_embeddings_mamba_bert.pkl"
    # JSON mapping from cell-type name to integer label.
    cell_type_label_path = "/home/lxz/scmamba/novel_cell_classification_bert/data/novel_cell_name_to_label_mamba_bert.json"
    # CSV with one (ratio, acc, f1) row per difficulty level.
    results_output_path = "/home/lxz/scmamba/novel_cell_classification_bert/results/novel_cell_classification_results_mamba_bert.csv"

    # Pre-trained model configuration
    pretrained_model_path = "/home/lxz/scmamba/model_state/cell_cls_3loss_6layer_final.pth"
    gene_vocab_path = "/home/lxz/scmamba/vocab.json"
    # Pickled dict used to translate "ENSG..." identifiers into gene names.
    ensembl_ID_to_gene_name_path = "/home/lxz/scmamba/novel_cell_classification/data/ensembl_ID_to_gene_name_dict_gc30M.pkl"

    # Loaded eagerly at class-definition time; ``ntoken`` below depends on it.
    vocab = GeneVocab.from_file(gene_vocab_path)

    # PubMedbert model configuration
    pretrained_bert_model_path = "/home/lxz/PubMedbert"
    # Checkpoint expected to hold the keys 'bert_state' and 'projection_state'.
    pretrained_bert_checkpoint_path = "/home/lxz/PubMedbert/finetune_bert.pth"

    # Pre-trained model parameters
    max_seq_len = 4096  # Sequences are truncated or padded to exactly this length
    ntoken = len(vocab)  # Vocabulary size, taken from the vocabulary loaded above
    embsize = 512  # Embedding dimension
    nhead = 8  # Attention head count
    d_hid = 512  # Hidden layer dimension
    nlayers = 6  # layers
    dropout = 0.1  # dropout
    pad_token = "<pad>"  # <pad> token
    pad_value = -2  # Expression value written at padded positions
    input_emb_style = "continuous"  # Input embedding style
    cell_emb_style = "cls"  # Cell embedding style
    class_num = 164  # Number of categories (not referenced by any visible cell)

    # Device setting
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    batch_size = 64
    num_workers = 8

    # PPR parameters (not referenced by any visible cell)
    alpha = 0.9
    threshold = 1e-4

    # Difficulty grading parameters
    # Each ratio is the fraction of overlapping cell types kept as candidates in one evaluation.
    difficulty_ratios = [0.1, 0.25, 0.5, 0.75, 1.0]
    selected_ratios = difficulty_ratios
    # Number of random type subsets drawn for every ratio below 1.0.
    num_samplings = 20

In [5]:
def extract_is_a_cells(triplet_path):
    """Collect every cell type that appears on either side of an ``is_a`` triple.

    Args:
        triplet_path: Path of a UTF-8 text file with one relation per line.

    Returns:
        A set with the stripped left- and right-hand names of every line that
        contains `` is_a ``. An empty set is returned when the file is missing
        or any other error occurs while reading it (the error is printed).
    """
    relation = ' is_a '
    names = set()
    try:
        with open(triplet_path, 'r', encoding='utf-8') as handle:
            for raw_line in handle:
                text = raw_line.strip()
                # Blank lines and lines describing other relations contribute nothing.
                if not text or relation not in text:
                    continue
                # Split on the first occurrence only; anything after it stays on the right side.
                left, right = text.split(relation, 1)
                names.update((left.strip(), right.strip()))

        print(f"Successfully extracted{len(names)}unique is'a related cell types")
        return names

    except FileNotFoundError:
        print(f"Error: Triplet.txt file not found（Path：{triplet_path}）")
        return set()
    except Exception as e:
        print(f"Error: processing triplet.txt：{str(e)}")
        return set()

In [6]:
class MambaEmbeddingExtractor(nn.Module):
    """Frozen wrapper around the pre-trained ``MambaModel`` that returns cell embeddings.

    Args:
        config: Settings object (the notebook passes the ``Config`` class). It must
            provide ``gene_vocab_path``, ``pad_token``, ``embsize``, ``nhead``,
            ``d_hid``, ``nlayers``, ``dropout``, ``input_emb_style``,
            ``cell_emb_style``, ``pretrained_model_path``, ``device`` and
            ``max_seq_len``. ``config.ntoken`` is overwritten with the size of the
            vocabulary that was actually loaded.

    Raises:
        RuntimeError: if the checkpoint shares no tensor (same key and shape) with
            the model, i.e. nothing would actually be loaded.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        # Load gene vocabulary list
        self.vocab = GeneVocab.from_file(config.gene_vocab_path)
        config.ntoken = len(self.vocab)  # Dynamically set vocabulary size

        # Resolve the special-token ids once. The same pad id is given to the model
        # and used for the fallback padding mask in forward(), so the two cannot
        # disagree (previously the id was recomputed with a different rule afterwards).
        self.pad_token_id = self.vocab[config.pad_token]
        self.cls_token_id = self.vocab["<cls>"] if "<cls>" in self.vocab else 1
        self.max_seq_len = config.max_seq_len

        # Initialize pre-trained model
        self.model = MambaModel(
            ntoken=config.ntoken,
            embsize=config.embsize,
            nhead=config.nhead,
            d_hid=config.d_hid,
            nlayers=config.nlayers,
            dropout=config.dropout,
            pad_token_id=self.pad_token_id,
            input_emb_style=config.input_emb_style,
            cell_emb_style=config.cell_emb_style
        )

        # Load pre-trained weights
        self._load_pretrained_weights(config.pretrained_model_path, config.device)

        # Do not fine tune the model, freeze all parameters
        for param in self.model.parameters():
            param.requires_grad = False

    def _load_pretrained_weights(self, model_path, device):
        """Copy every checkpoint tensor whose key and shape match ``self.model``.

        Non-matching entries are skipped, but the number of skipped / uncovered
        tensors is reported, and loading fails loudly when nothing matches at all
        (for example when the keys carry a ``module.`` prefix). Any error is printed
        and re-raised.
        """
        try:
            checkpoint = torch.load(model_path, map_location=device)
            model_state = self.model.state_dict()

            # Keep only tensors that fit the current architecture.
            compatible = {
                key: tensor for key, tensor in checkpoint.items()
                if key in model_state and tensor.shape == model_state[key].shape
            }

            if model_state and not compatible:
                raise RuntimeError(
                    f"none of the {len(checkpoint)} checkpoint entries match the model "
                    f"(by key and shape); refusing to continue with untrained weights"
                )

            skipped = len(checkpoint) - len(compatible)
            uncovered = len(model_state) - len(compatible)

            model_state.update(compatible)
            self.model.load_state_dict(model_state)

            if skipped or uncovered:
                print(f"Warning: loaded {len(compatible)} tensors; skipped {skipped} checkpoint "
                      f"entries, {uncovered} model tensors keep their initial values")
            print("Successfully loaded pre-trained weights")
        except Exception as e:
            print(f"Failed to load pre-trained weights: {str(e)}")
            raise

    def forward(self, input_ids, values, attention_mask=None):
        """Return the model's ``cell_emb`` output for a batch.

        Args:
            input_ids: Token ids, forwarded to the model as ``src``.
            values: Expression values, forwarded as ``values``.
            attention_mask: Optional mask forwarded (cast to bool) as
                ``src_key_padding_mask``. When omitted, positions equal to the pad
                token id are masked.
        """
        if attention_mask is None:
            src_key_padding_mask = input_ids == self.pad_token_id
        else:
            src_key_padding_mask = attention_mask.bool()
        output = self.model(src=input_ids, values=values, src_key_padding_mask=src_key_padding_mask)
        return output["cell_emb"]

In [7]:
class finetune_BERT(nn.Module):
    """BERT encoder followed by a 768 -> 512 projection head.

    Args:
        bert_model: Module whose output exposes ``last_hidden_state``.
    """

    def __init__(self, bert_model):
        super().__init__()
        self.bert = bert_model
        # nn.Sequential names parameters by position ("0.*" for the Linear, "2.*"
        # for the LayerNorm), so the order of these layers must not change.
        head_layers = [
            nn.Linear(768, 512),
            nn.ReLU(),
            nn.LayerNorm(512),
            nn.Dropout(0),
        ]
        self.projection = nn.Sequential(*head_layers)

    def forward(self, **inputs):
        """Encode the inputs, average the hidden states over dim 1 and project the result."""
        token_states = self.bert(**inputs).last_hidden_state
        sentence_vector = token_states.mean(dim=1)
        return self.projection(sentence_vector)


# PubMedBERT embedding extractor
class BertEmbeddingExtractor(nn.Module):
    """Turns text (cell-type names) into embeddings with the fine-tuned BERT model.

    Args:
        config: Settings object providing ``pretrained_bert_model_path``,
            ``pretrained_bert_checkpoint_path`` and ``device``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        bert_dir = config.pretrained_bert_model_path
        self.tokenizer = AutoTokenizer.from_pretrained(bert_dir)
        backbone = AutoModel.from_pretrained(bert_dir)
        self.finetuned_model = finetune_BERT(backbone).to(config.device)

        # Restore the fine-tuned encoder and projection head, then switch to eval mode.
        # The loaded checkpoint stays referenced on the instance as ``self.checkpoint``.
        self.checkpoint = torch.load(config.pretrained_bert_checkpoint_path, map_location=config.device)
        encoder = self.finetuned_model.bert
        head = self.finetuned_model.projection
        encoder.load_state_dict(self.checkpoint['bert_state'])
        head.load_state_dict(self.checkpoint['projection_state'])
        self.finetuned_model.eval()

    def get_embedding(self, text):
        """Embed ``text`` and return the model output as a NumPy array.

        The text is tokenized with truncation at 25 tokens and run through the
        model without gradient tracking.
        """
        encoded = self.tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=25)
        encoded = encoded.to(self.config.device)
        with torch.no_grad():
            embedding = self.finetuned_model(**encoded)
        return embedding.cpu().numpy()


In [8]:
class NovelCellDataset(Dataset):
    """In-memory dataset of tokenized novel cells read from parquet files.

    Args:
        parquet_files: Iterable of parquet paths; every file must provide the
            columns ``gene_expression_nums``, ``gene_ensembl_ids`` and ``cell_type``.
        ensembl_to_gene_name_dict: Mapping from ``ENSG...`` identifiers to gene names.

    All files are parsed eagerly in the constructor. The vocabulary path, the
    maximum sequence length and the padding value are read from the global ``Config``.
    """

    def __init__(self, parquet_files, ensembl_to_gene_name_dict):
        self.parquet_files = parquet_files
        self.ensembl_to_gene_name_dict = ensembl_to_gene_name_dict
        self.data = []

        # Vocabulary and special-token ids (fall back to 0 / 1 when they are not defined)
        self.vocab = GeneVocab.from_file(Config.gene_vocab_path)
        self.pad_token_id = self.vocab[self.vocab.pad_token] if self.vocab.pad_token else 0
        self.cls_token_id = self.vocab["<cls>"] if "<cls>" in self.vocab else 1

        # Every item is truncated or padded to this length
        self.max_seq_length = Config.max_seq_len

        # Parse everything up front
        self._load_data()

    def _gene_to_vocab_name(self, gene):
        """Return the name to look up in the vocabulary for one gene identifier.

        Identifiers starting with "ENSG" are translated through the Ensembl
        dictionary (``None`` when unknown); anything else is used unchanged.
        """
        if gene.startswith("ENSG"):
            return self.ensembl_to_gene_name_dict.get(gene, None)
        return gene

    def _encode_row(self, row):
        """Convert one dataframe row into a record dict, or ``None`` if no gene survives."""
        # The first element of both arrays is dropped (the original code treats it as illegal data).
        expr_all = np.array(row['gene_expression_nums'][1:])
        gene_all = np.array(row['gene_ensembl_ids'][1:], dtype=object)

        # Keep only genes with a strictly positive expression value.
        expressed = expr_all > 0

        token_ids, expr_values = [], []
        for gene, value in zip(gene_all[expressed], expr_all[expressed]):
            if gene is None:
                continue
            name = self._gene_to_vocab_name(gene)
            if name is not None and name in self.vocab:
                token_ids.append(self.vocab[name])
                expr_values.append(value)

        if not token_ids:
            return None

        # Prepend the CLS token with an expression value of 0.0.
        return {
            "expr": [0.0] + list(expr_values),
            "tokens": [self.cls_token_id] + list(token_ids),
            "cell_type": row['cell_type']
        }

    def _load_data(self):
        """Read every parquet file and append one record per cell with at least one known gene.

        A failure while processing a file is printed and the remaining rows of that
        file are skipped; records appended before the failure are kept.
        """
        for file in tqdm(self.parquet_files, desc="加载新细胞数据"):
            try:
                frame = pd.read_parquet(file)
                for _, row in frame.iterrows():
                    record = self._encode_row(row)
                    if record is not None:
                        self.data.append(record)
            except Exception as e:
                print(f"Error: loading file {file.name}: {str(e)}")
                continue

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return one cell as fixed-length tensors plus its cell-type name.

        ``attention_mask`` uses 0 for real positions and 1 for padded positions.
        """
        record = self.data[idx]
        token_ids = record["tokens"]
        expr_values = record["expr"]
        limit = self.max_seq_length

        shortfall = limit - len(token_ids)
        if shortfall <= 0:
            # Long enough: cut to the limit, nothing is padded.
            input_ids = token_ids[:limit]
            values = expr_values[:limit]
            attention_mask = [0] * limit
        else:
            # Too short: append pad tokens / pad values and flag them in the mask.
            input_ids = token_ids + [self.pad_token_id] * shortfall
            values = expr_values + [Config.pad_value] * shortfall
            attention_mask = [0] * len(token_ids) + [1] * shortfall

        return {
            "input_ids": torch.tensor(input_ids),
            "values": torch.tensor(values, dtype=torch.float32),
            "attention_mask": torch.tensor(attention_mask),
            "cell_type": record["cell_type"]
        }

In [9]:
def generate_novel_cell_embeddings():
    """Compute (or load from cache) the Mamba embedding of every novel cell.

    If both ``Config.novel_cell_emb_path`` and ``Config.cell_type_label_path``
    exist they are loaded and returned directly. Otherwise the parquet files in
    ``Config.converted_data_dir`` are scanned for cell types, a sorted
    name -> label mapping is written as JSON, all cells are embedded with the
    frozen model, and the result is pickled.

    Returns:
        ``(novel_cell_data, cell_type_to_label)`` where ``novel_cell_data`` is a
        list of ``(embedding, cell_type, label)`` tuples, or ``(None, None)`` when
        no parquet file is found or the Ensembl dictionary cannot be read.
    """
    if os.path.exists(Config.novel_cell_emb_path) and os.path.exists(Config.cell_type_label_path):
        print("New cell embedding and label mapping already exist, load directly...")
        # NOTE: pickle can execute arbitrary code on load; only use cache files this notebook produced.
        with open(Config.novel_cell_emb_path, 'rb') as f:
            novel_cell_data = pickle.load(f)
        with open(Config.cell_type_label_path, 'r') as f:
            cell_type_to_label = json.load(f)
        return novel_cell_data, cell_type_to_label

    print("Start generating new cell embeddings...")
    # Generate mapping from cell types to numerical labels
    cell_types = set()
    parquet_files = sorted(Config.converted_data_dir.glob("*.parquet"))

    if not parquet_files:
        print(f"Error: Parquet file not found in{Config.converted_data_dir}")
        return None, None

    for file in tqdm(parquet_files, desc="Collect novel cell types"):
        try:
            df = pd.read_parquet(file)
            cell_types.update(df['cell_type'].unique())
        except Exception as e:
            print(f"Error: processing file {file.name}: {str(e)}")
            continue

    # Sort and generate label mappings
    cell_type_to_label = {cell_type: idx for idx, cell_type in enumerate(sorted(cell_types))}

    # Create the output directories up front: a missing directory would otherwise
    # only surface after the whole (expensive) embedding pass, losing its result.
    for out_path in (Config.cell_type_label_path, Config.novel_cell_emb_path):
        out_dir = os.path.dirname(out_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)

    # Save label mapping
    with open(Config.cell_type_label_path, 'w') as f:
        json.dump(cell_type_to_label, f, indent=2)

    try:
        # NOTE: pickle can execute arbitrary code on load; the dictionary file must be trusted.
        with open(Config.ensembl_ID_to_gene_name_path, 'rb') as f:
            ensembl_to_gene_name_dict = pickle.load(f)
    except Exception as e:
        print(f"Failed to read{Config.ensembl_ID_to_gene_name_path}：{str(e)}")
        return None, None

    # Create dataset and dataset loader
    dataset = NovelCellDataset(parquet_files, ensembl_to_gene_name_dict)
    print(f"The novel cell dataset has been loaded and contains {len(dataset)} valid cells")

    # shuffle=False keeps the embeddings in file/row order.
    dataloader = DataLoader(
        dataset,
        batch_size=Config.batch_size,
        shuffle=False,
        num_workers=Config.num_workers,
        pin_memory=True
    )

    # load pre-trained model
    model = MambaEmbeddingExtractor(Config).to(Config.device)
    model.eval()

    # Extract embedding
    novel_cell_data = []  # Save (embedding, cell_type, label)
    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Extract novel cell embeddings"):
            input_ids = batch["input_ids"].to(Config.device)
            values = batch["values"].to(Config.device)
            attention_mask = batch["attention_mask"].to(Config.device)
            # Separate name so the set of all cell types above is not shadowed.
            batch_cell_types = batch["cell_type"]

            with autocast():
                embeddings = model(input_ids, values, attention_mask=attention_mask)

            embeddings_np = embeddings.cpu().numpy()

            # Save embeddings and corresponding cell types
            for emb, cell_type in zip(embeddings_np, batch_cell_types):
                novel_cell_data.append((emb, cell_type, cell_type_to_label[cell_type]))

    # Save novel cell embeddings
    with open(Config.novel_cell_emb_path, 'wb') as f:
        pickle.dump(novel_cell_data, f)

    print(f"Novel cell embedding has been saved to {Config.novel_cell_emb_path}")
    return novel_cell_data, cell_type_to_label

In [10]:
def generate_novel_cell_type_embeddings():
    """Embed the name of every novel cell type with the fine-tuned BERT model.

    The cell-type names are the keys of the JSON mapping stored at
    ``Config.cell_type_label_path``.

    Returns:
        ``(emb_dict, emb_array)``: a dict from cell-type name to its embedding and
        an array stacking the embeddings in the mapping's key order.

    Raises:
        FileNotFoundError: if the label mapping file does not exist. (An exception
            is raised instead of calling ``exit()``, which inside a Jupyter kernel
            only requests a shutdown and lets the function keep running.)
    """
    if not os.path.exists(Config.cell_type_label_path):
        print("Novel cell type does not exist, program exits")
        raise FileNotFoundError(
            f"Novel cell type label mapping not found: {Config.cell_type_label_path}")

    print("The novel cell type already exists, load directly...")
    with open(Config.cell_type_label_path, 'r') as f:
        novel_cell_type_to_label = json.load(f)

    bertEmbeddingExtractor = BertEmbeddingExtractor(Config)

    novel_cell_type_emb = []
    novel_cell_type_emb_dict = {}
    for cell_type in tqdm(novel_cell_type_to_label.keys(), desc="Generate novel cell type embeddings"):
        # get_embedding returns a batch of one; drop the batch axis once and reuse it.
        emb = bertEmbeddingExtractor.get_embedding(cell_type).squeeze(0)
        novel_cell_type_emb.append(emb)
        novel_cell_type_emb_dict[cell_type] = emb

    novel_cell_type_emb = np.array(novel_cell_type_emb)
    return novel_cell_type_emb_dict, novel_cell_type_emb

In [11]:
def generate_ontology_cell_type_emb():
    """Embed the name of every ontology cell type with the fine-tuned BERT model.

    The cell types are those extracted from the ``is_a`` triples in
    ``Config.triples_ontology_path``.

    Returns:
        ``(emb_dict, emb_array)``: a dict from cell-type name to its embedding and
        an array stacking the embeddings. Names are processed in sorted order so
        the dict and row order are identical on every run.

    Raises:
        FileNotFoundError: if the triples file does not exist. (An exception is
            raised instead of calling ``exit()``, which inside a Jupyter kernel
            only requests a shutdown and lets the function keep running.)
    """
    if not os.path.exists(Config.triples_ontology_path):
        print("Knowledge graph cell type does not exist, program exits")
        raise FileNotFoundError(
            f"Ontology triples file not found: {Config.triples_ontology_path}")

    print("The knowledge graph cell type already exists, load directly...")
    ontology_cell_type_unique = extract_is_a_cells(Config.triples_ontology_path)

    bertEmbeddingExtractor = BertEmbeddingExtractor(Config)

    ontology_cell_type_emb = []
    ontology_cell_type_emb_dict = {}
    # A set has no stable iteration order across interpreter runs; sorting makes
    # the downstream candidate order (and argmax tie-breaking) reproducible.
    for cell_type in tqdm(sorted(ontology_cell_type_unique),
                          desc="Emb of cell types in generating knowledge graphs"):
        # get_embedding returns a batch of one; drop the batch axis once and reuse it.
        emb = bertEmbeddingExtractor.get_embedding(cell_type).squeeze(0)
        ontology_cell_type_emb.append(emb)
        ontology_cell_type_emb_dict[cell_type] = emb
    ontology_cell_type_emb = np.array(ontology_cell_type_emb)
    return ontology_cell_type_emb_dict, ontology_cell_type_emb

In [12]:
def get_upper_cell_name_to_lower_cell_name_mapping(cell_name_list):
    """Map each cell-type name to its lowercase form.

    Args:
        cell_name_list: Iterable of cell-type names.

    Returns:
        Dict whose keys are the original names and whose values are ``name.lower()``.
    """
    return {original: original.lower() for original in cell_name_list}

In [13]:
def classify_novel_cells(novel_cell_data, ontology_cell_type_emb_dict, novel_cell_type_to_label):
    """Assign every cell the ontology cell type whose embedding is most cosine-similar.

    Only ontology types whose lowercase name also occurs (lowercased) among the
    novel cell types are used as candidates.

    Args:
        novel_cell_data: Sequence of ``(embedding, cell_type, label)`` tuples.
        ontology_cell_type_emb_dict: Dict from ontology cell-type name to embedding.
        novel_cell_type_to_label: Dict from novel cell-type name to integer label.

    Returns:
        ``(true_labels, true_cell_types, pred_labels, similarities,
        filtered_ontology_types)``; ``similarities`` has one row per cell and one
        column per candidate type. Five ``None`` values are returned when no
        ontology type overlaps with the novel cell types.
    """
    # Unpack the per-cell tuples into parallel containers
    cell_embeddings = np.array([record[0] for record in novel_cell_data])
    true_cell_types = [record[1] for record in novel_cell_data]
    true_labels = np.array([record[2] for record in novel_cell_data])

    # Case-insensitive lookup: lowercase novel type name -> label
    label_by_lower_name = {name.lower(): label for name, label in novel_cell_type_to_label.items()}

    # Candidate ontology types, in the dict's iteration order
    filtered_ontology_types = [
        ont_type for ont_type in ontology_cell_type_emb_dict
        if ont_type.lower() in label_by_lower_name
    ]

    print(f"Select {len(filtered_ontology_types)} types from ontology that exist in novel cell data")
    if not filtered_ontology_types:
        print("Error: No overlapping cell types found, unable to classify")
        return None, None, None, None, None

    filtered_ontology_embs = np.array(
        [ontology_cell_type_emb_dict[ont_type] for ont_type in filtered_ontology_types])

    print("Calculate the cosine similarity between cell embeddings and filtered type embeddings...")
    # Shape: [n_cells, n_filtered_types]
    # (A Spearman-style alternative -- rank both matrices with rankdata and use
    # 1 - cdist(..., metric='correlation') -- was tried here and left disabled.)
    similarities = cosine_similarity(cell_embeddings, filtered_ontology_embs)

    # Best candidate per cell, translated back to the novel-type label
    best_candidate = np.argmax(similarities, axis=1)
    pred_labels = np.array(
        [label_by_lower_name[filtered_ontology_types[col].lower()] for col in best_candidate])

    return true_labels, true_cell_types, pred_labels, similarities, filtered_ontology_types

In [14]:
def evaluate_with_difficulty_levels(true_labels, true_cell_types, pred_labels, similarities,
                                    filtered_ontology_types, novel_cell_type_to_label, selected_ratios=None,
                                    random_state=None):
    """Evaluate nearest-type classification on random subsets of the candidate types.

    For every ratio a subset of the candidate types is drawn, only cells whose
    true type is in the subset are kept, and each cell is re-classified as the
    most similar type within the subset. Accuracy and macro F1 are averaged over
    the samplings and written to ``Config.results_output_path`` as CSV.

    Args:
        true_labels: Integer label of every cell (labels of ``novel_cell_type_to_label``).
        true_cell_types: Unused; kept for interface compatibility.
        pred_labels: Unused; predictions are recomputed per subset from ``similarities``.
        similarities: Array of shape [n_cells, n_candidate_types]; column ``j``
            belongs to ``filtered_ontology_types[j]``.
        filtered_ontology_types: Candidate ontology type names.
        novel_cell_type_to_label: Dict from novel cell-type name to integer label.
        selected_ratios: Fractions of candidate types to sample; defaults to
            ``Config.selected_ratios``.
        random_state: Optional seed for the subset sampling. ``None`` (default)
            keeps using NumPy's global random state.

    Returns:
        DataFrame with the columns ``ratio``, ``acc`` and ``f1`` (one row per ratio).
    """
    if selected_ratios is None:
        selected_ratios = Config.selected_ratios

    # The global np.random module and a RandomState both expose .permutation().
    rng = np.random if random_state is None else np.random.RandomState(random_state)

    true_labels = np.asarray(true_labels)

    # Case-insensitive bridge from ontology type name to the novel type name
    novel_type_lower = {k.lower(): k for k in novel_cell_type_to_label.keys()}
    filtered_to_novel_type = {
        ont_type: novel_type_lower[ont_type.lower()]
        for ont_type in filtered_ontology_types
        if ont_type.lower() in novel_type_lower
    }

    n_filtered_types = len(filtered_ontology_types)
    avg_result = []

    # Ensure that the output directory exists (dirname is '' for a bare file name)
    out_dir = os.path.dirname(Config.results_output_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    for ratio in selected_ratios:
        # Number of types for this level: round half up, but never fewer than one
        num_selected = max(1, int(n_filtered_types * ratio + 0.5))

        print(f"\nAssess difficulty level: {ratio} (Select {num_selected}/{n_filtered_types} ontology overlap types)")

        all_acc, all_f1 = [], []
        # With every type selected, resampling would only reorder the columns.
        num_samplings = 1 if ratio >= 1.0 else Config.num_samplings

        for _ in range(num_samplings):
            selected_filtered_indices = rng.permutation(n_filtered_types)[:num_selected]
            # Position i in the three lists below refers to the same type, and
            # it is also column i of `selected_similarities`.
            selected_filtered_types = [filtered_ontology_types[i] for i in selected_filtered_indices]
            selected_novel_types = [filtered_to_novel_type[t] for t in selected_filtered_types]
            selected_novel_labels = [novel_cell_type_to_label[t] for t in selected_novel_types]

            # Renumber the selected labels as 0..num_selected-1 (their position)
            new_label_mapping = {old_label: idx for idx, old_label in enumerate(selected_novel_labels)}

            # Keep only cells whose true type was selected
            used_indices = np.isin(true_labels, selected_novel_labels)

            if np.sum(used_indices) == 0:
                print("Warning: No cells of the selected type were found, skip this sampling")
                continue

            new_true_labels = [new_label_mapping[x] for x in true_labels[used_indices]]

            selected_similarities = similarities[used_indices][:, selected_filtered_indices]
            new_preds = np.argmax(selected_similarities, axis=1)

            all_acc.append(accuracy_score(new_true_labels, new_preds))
            all_f1.append(f1_score(new_true_labels, new_preds, average="macro"))

            # For a complete dataset, output a detailed report
            if ratio >= 1.0:
                print("\nComplete classification report:")
                # `labels` must be given explicitly: without it the report derives the
                # classes from the data and raises when a selected type has neither
                # true nor predicted cells (class count != len(target_names)).
                print(classification_report(
                    new_true_labels,
                    new_preds,
                    labels=list(range(len(selected_novel_types))),
                    target_names=selected_novel_types,
                    digits=4,
                    zero_division=0
                ))

        # Average over the samplings that were not skipped (0 if all were skipped)
        avg_acc = round(sum(all_acc) / len(all_acc), 4) if all_acc else 0
        avg_f1 = round(sum(all_f1) / len(all_f1), 4) if all_f1 else 0

        avg_result.append((ratio, avg_acc, avg_f1))

        print(f"Difficulty level {ratio} -  Average accuracy: {avg_acc}, Average F1 score: {avg_f1}")

    # Save results
    avg_result_df = pd.DataFrame(avg_result, columns=["ratio", "acc", "f1"])
    avg_result_df.to_csv(Config.results_output_path, index=False)
    print(f"\nThe evaluation results have been saved to {Config.results_output_path}")

    return avg_result_df

In [15]:
def main():
    """Run the whole pipeline: ontology types, cell embeddings, BERT type embeddings,
    nearest-type classification and the difficulty-level evaluation.

    Returns early (after printing the reason) when a step yields nothing usable.
    """
    # Step 1: the ontology must provide at least one is_a cell type
    if not extract_is_a_cells(Config.triples_ontology_path):
        print("Unable to continue with subsequent tasks as the cell types related to is_a were not extracted")
        return

    # Step 2: embeddings of the novel cells (computed or loaded from cache)
    novel_cell_data, cell_type_to_label = generate_novel_cell_embeddings()
    if novel_cell_data is None:
        print("Unable to generate novel cell embeddings, program exits")
        return

    # Step 3: BERT embeddings of the novel cell type names (result is not used further)
    print("Generate BERT embeddings for novel cell types...")
    generate_novel_cell_type_embeddings()

    # Step 4: BERT embeddings of the ontology cell type names; only the dict is needed
    print("BERT embedding for generating knowledge graph cell types...")
    ontology_cell_type_emb_dict = generate_ontology_cell_type_emb()[0]

    # Step 5: nearest-type classification of every novel cell
    print("Start classifying novel cells...")
    classification = classify_novel_cells(novel_cell_data, ontology_cell_type_emb_dict, cell_type_to_label)

    # The first element (true labels) is None when no candidate type overlapped
    if classification[0] is None:
        print("Classification failed, unable to conduct difficulty level assessment")
        return

    # Step 6: evaluation per difficulty level; the five classification results are
    # passed positionally, followed by the label mapping
    evaluate_with_difficulty_levels(*classification, cell_type_to_label)

In [16]:
# Entry point: runs the pipeline when executed as a script. Inside the notebook
# the guard is also true, so executing this cell runs main() as well.
if __name__ == "__main__":
    main()

Successfully extracted2686unique is'a related cell types
New cell embedding and label mapping already exist, load directly...
Generate BERT embeddings for novel cell types...
The novel cell type already exists, load directly...


Generate novel cell type embeddings: 100%|██████████████████████████████████████████████| 75/75 [00:04<00:00, 16.73it/s]


BERT embedding for generating knowledge graph cell types...
The knowledge graph cell type already exists, load directly...
Successfully extracted2686unique is'a related cell types


Emb of cell types in generating knowledge graphs: 100%|████████████████████████████| 2686/2686 [00:14<00:00, 180.34it/s]


Start classifying novel cells...
Select 72 types from ontology that exist in novel cell data
Calculate the cosine similarity between cell embeddings and filtered type embeddings...

Assess difficulty level: 0.1 (Select 7/72 ontology overlap types)
Difficulty level 0.1 -  Average accuracy: 0.6932, Average F1 score: 0.5631

Assess difficulty level: 0.25 (Select 18/72 ontology overlap types)
Difficulty level 0.25 -  Average accuracy: 0.5406, Average F1 score: 0.359

Assess difficulty level: 0.5 (Select 36/72 ontology overlap types)
Difficulty level 0.5 -  Average accuracy: 0.4251, Average F1 score: 0.2498

Assess difficulty level: 0.75 (Select 54/72 ontology overlap types)
Difficulty level 0.75 -  Average accuracy: 0.3767, Average F1 score: 0.1943

Assess difficulty level: 1.0 (Select 72/72 ontology overlap types)

Complete classification report:
                                                                                      precision    recall  f1-score   support

                 


The evaluation results have been saved to /home/lxz/scmamba/novel_cell_classification_bert/results/novel_cell_classification_results_mamba_bert.csv
