# Embedding Novel Species

This notebook will create the files you need to embed a novel species that wasn't included in the training data.

To start, you will need to download the ESM2 protein embeddings and the reference proteome for the species.

You can find precalculated ESM2 protein embeddings for many species [here](https://drive.google.com/drive/folders/1_Dz7HS5N3GoOAG6MdhsXWY1nwLoN13DJ?usp=drive_link)

For reference proteomes, you can download them from [here](https://useast.ensembl.org/info/about/species.html).

If there is no protein embedding for the species you are interested in, you can request to have it made via Github or email, or you can create it yourself following instructions [here](https://github.com/snap-stanford/SATURN/tree/main/protein_embeddings).

In [43]:
import numpy as np
import pickle
import pandas as pd
import h5py
import torch
import scanpy as sc

## Convert ESM2 protien embeddings to UCE format


In [46]:
def load_from_hdf5(file_path):
    """Load dictionary from HDF5 file.

    Args:
        file_path (str): Path to HDF5 file containing embeddings data. The file should have
            a 'keys' dataset containing gene names and an 'arrays' group containing the
            corresponding embedding arrays.

    Returns:
        dict: Dictionary mapping gene names (str) to their embedding arrays (numpy.ndarray).
            The keys are decoded from bytes to UTF-8 strings.
    """
    data_dict = {}
    with h5py.File(file_path, "r") as f:
        # Get the keys
        keys = [k.decode("utf-8") for k in f["keys"][:]]

        # Load the arrays
        arrays_group = f["arrays"]
        for key in keys:
            data_dict[key] = arrays_group[str(key)][:]

    return data_dict

def filter_out_ensembl(PE):
    total_before = len(PE.keys())

    # Count keys that start with "ENS"
    ens_count = sum(1 for key in PE.keys() if key.startswith("ENS"))

    # Calculate remaining after filtering
    remaining = total_before - ens_count

    print(f"Total keys before filtering: {total_before}")
    print(f"Keys starting with 'ENS' to be filtered: {ens_count}")
    print(f"Keys remaining after filtering: {remaining}")

    # Remove all key-value pairs that start with "ENS"
    PE = {k: v for k, v in PE.items() if not k.startswith("ENS")}

    return PE

def convert_hdf5_to_pt(hdf5_path, output_path):
    emb = load_from_hdf5(hdf5_path)
    emb_tensors = {k: torch.from_numpy(v) for k, v in emb.items()}

    # Filter out ENSP IDs
    emb_tensors = filter_out_ensembl(emb_tensors)

    # check dim of first tensor
    print(emb_tensors[list(emb_tensors.keys())[0]].shape)
    torch.save(emb_tensors, output_path)


In [61]:
SPECIES_NAME = "platypus" # short hand name for this species, will be used in arguments and files
H5_FILE_PATH = "/mnt/czi-sci-ai/generate-cross-species-secondary/protein_embedding_data/pkl/ornithorhynchus_anatinus_gene_large.h5"

OUTPUT_DIR = "/mnt/czi-sci-ai/generate-cross-species-secondary/eval/baselines/uce/33l_8ep_1024t_1280"

MODEL_FILES_PATH = "/mnt/czi-sci-ai/generate-cross-species-secondary/eval/baselines/uce/33l_8ep_1024t_1280/model_files"
PROTEIN_EMBEDDINGS_PATH = f"{MODEL_FILES_PATH}/protein_embeddings"
NEW_SPECIES_CSV_PATH = f"{MODEL_FILES_PATH}/new_species_protein_embeddings.csv"

# Path to the species proteome
SPECIES_PROTEIN_FASTA_PATH = f"/mnt/czi-sci-ai/generate-cross-species-secondary/eval/baselines/uce/33l_8ep_1024t_1280/fasta/Ornithorhynchus_anatinus.mOrnAna1.p.v1.pep.all.fa"

# Path to the ESM2 Embeddings
SPECIES_PROTEIN_EMBEDDINGS_PATH = f"{PROTEIN_EMBEDDINGS_PATH}/Ornithorhynchus_anatinus_ESM2.pt"

# primary_assembly name, this needs to be matched to the FASTA file
ASSEMBLY_NAME = "mOrnAna1.p.v1"
# NCBI Taxonomy ID, please set this so that if someone else also embeds the same species,
# randomly generated chromosome tokens will be the same
TAXONOMY_ID = 9258

In [47]:
convert_hdf5_to_pt(H5_FILE_PATH, SPECIES_PROTEIN_EMBEDDINGS_PATH)

Total keys before filtering: 17418
Keys starting with 'ENS' to be filtered: 6089
Keys remaining after filtering: 11329
torch.Size([5120])


You can view the FASTA format here, please confirm the primary_assembly name is correct.

In [48]:
!head {SPECIES_PROTEIN_FASTA_PATH}

>ENSOANP00000024997.1 pep primary_assembly:mOrnAna1.p.v1:MT:2807:3763:1 gene:ENSOANG00000019388.1 transcript:ENSOANT00000028512.1 gene_biotype:protein_coding transcript_biotype:protein_coding gene_symbol:ND1 description:NADH dehydrogenase subunit 1 [Source:NCBI gene (formerly Entrezgene);Acc:808708]
MFLVNLLILIIPVLLAVAFLTLLERKILGYMQFRKGPNIVGAHGLLQPIADAVKLFTKEP
LRPLTSSIYMFILAPILALSLALTIWIPLPMPLPLIDLNLGLLFVLSVSGLSVYSILWSG
WASNSKYALTGALRAVAQTISYEVTLAIILLSIMLINGSFTLTTLNLTQEYMWLIVPTWP
LMLMWFISTLAETNRAPFDLTEGESELVSGFNVEYAAGPFAMFFLAEYANIIIMNALTVI
LFFGTYHLIFLPEMSTTTFMIKTMLLTSLFLWIRASYPRFRYDQLMHLLWKNFLPITLVT
CLWYIMLPTTLSGLPPQM
>ENSOANP00000024996.2 pep primary_assembly:mOrnAna1.p.v1:MT:3971:5014:1 gene:ENSOANG00000019384.2 transcript:ENSOANT00000028508.2 gene_biotype:protein_coding transcript_biotype:protein_coding gene_symbol:ND2 description:NADH dehydrogenase subunit 2 [Source:NCBI gene (formerly Entrezgene);Acc:808700]
MTPMTTLIMLFSLLLGTTLTLTSSHWLLMWMGLEVSTLAIIPLLTYTNHPRSIESAIKYF
LTQATASMLLMFA

In [49]:
species_to_paths = {
    SPECIES_NAME: SPECIES_PROTEIN_FASTA_PATH,
}

species_to_ids = {
    SPECIES_NAME: ASSEMBLY_NAME,
}

In [50]:
all_pos_def = []

missing_genes = {}
for species in species_to_ids.keys():
    missing_genes[species] = []
    proteome_path = species_to_paths[species]
    species_id = species_to_ids[species]

    with open(proteome_path) as f:
        proteome_lines = f.readlines()

    gene_symbol_to_location = {}
    gene_symbol_to_chrom = {}

    for line in proteome_lines:
        if line.startswith(">"):
            split_line = line.split()
            gene_symbol = [token for token in split_line if token.startswith("gene_symbol")]
            if len(gene_symbol) > 0:
                gene_symbol = gene_symbol[0].split(":")
                
                if len(gene_symbol) == 2:
                    gene_symbol = gene_symbol[1]
                elif len(gene_symbol) > 2:
                    gene_symbol = ":".join(gene_symbol[1:]) # fix for annoying zebrafish gene names with colons in them
                else:
                    1/0 # something weird happening, throw an error
                
                
                chrom = None
                
                chrom_arr = [token for token in split_line if token.startswith("chromosome:")]
                if len(chrom_arr) > 0:
                    chrom = chrom_arr[0].replace("chromosome:", "")
                else:
                    chrom_arr = [token for token in split_line if token.startswith("primary_assembly:")]
                    if len(chrom_arr) > 0:
                        chrom = chrom_arr[0].replace("primary_assembly:", "")
                    else:
                        chrom_arr = [token for token in split_line if token.startswith("scaffold:")] 
                        if len(chrom_arr) > 0:
                            chrom = chrom_arr[0].replace("scaffold:", "")
                if chrom is not None:
                    gene_symbol_to_location[gene_symbol] = chrom.split(":")[2]
                    gene_symbol_to_chrom[gene_symbol] = chrom.split(":")[1]
                else:
                    missing_genes[species].append(gene_symbol)
                    

    positional_df = pd.DataFrame()
    positional_df["gene_symbol"] = [gn.upper() for gn in list(gene_symbol_to_chrom.keys())]
    positional_df["chromosome"] = list(gene_symbol_to_chrom.values())
    positional_df["start"] = list(gene_symbol_to_location.values())
    positional_df = positional_df.sort_values(["chromosome", "start"])
    #positional_df = positional_df.set_index("gene_symbol")
    positional_df["species"] = species
    all_pos_def.append(positional_df)

In [51]:
master_pos_def = pd.concat(all_pos_def)
master_pos_def

Unnamed: 0,gene_symbol,chromosome,start,species
62,MTLN,1,100009143,platypus
82,SLC19A3,1,100065503,platypus
117,DAW1,1,100217380,platypus
135,EIF2A,1,100455816,platypus
152,TSC22D2,1,100555135,platypus
...,...,...,...,...
4571,ORNANAV1R3129,X5,9703484,platypus
8304,CHRNB2,X5,971487,platypus
8351,UBE2Q1,X5,983164,platypus
4620,ORNANAV1R3193,X5,9924327,platypus


In [52]:
master_pos_def["species"].value_counts() # double check how many genes are mapped

species
platypus    11330
Name: count, dtype: int64

In [53]:
for k, v in missing_genes.items():
    print(f"{k}: {len(v)}") # are any genes missing?

platypus: 0


In [54]:
# Count genes per chromosome
for species in species_to_ids.keys():
    print("*********")
    print(species)
    display(master_pos_def[master_pos_def["species"] == species]["chromosome"].value_counts().head(50))
    print("*********")

*********
platypus


chromosome
2                 808
3                 761
1                 732
5                 717
X1                660
4                 595
11                591
7                 501
10                499
X5                486
16                473
14                421
8                 416
13                371
9                 363
17                353
X2                333
12                277
19                275
6                 267
18                261
21                251
20                245
15                210
X3                158
RZJT01000103.1    117
X4                106
RZJT01000302.1     16
MT                 13
RZJT01000086.1      8
RZJT01000295.1      7
RZJT01000087.1      6
RZJT01000091.1      4
RZJT01000067.1      4
RZJT01000185.1      4
RZJT01000140.1      4
RZJT01000231.1      3
RZJT01000030.1      2
RZJT01000029.1      1
RZJT01000046.1      1
RZJT01000084.1      1
RZJT01000072.1      1
RZJT01000125.1      1
RZJT01000090.1      1
RZJT01000256.1      1

*********


## Filter ESM2 embeddings to only include genes in the master_pos_def



In [55]:
PE = torch.load(SPECIES_PROTEIN_EMBEDDINGS_PATH)
print(f"Number of genes in PE before filtering: {len(PE)}")
# Find intersection of gene symbols between PE and master_pos_def
pe_genes = {k.upper() for k in PE.keys()}
master_genes = set(master_pos_def["gene_symbol"].unique())
common_genes = pe_genes & master_genes
# Filter both PE and master_pos_def to only include common genes
PE = {k.upper(): v for k, v in PE.items() if k.upper() in common_genes}
print(f"Number of genes in PE after filtering: {len(PE)}")
master_pos_def = master_pos_def[master_pos_def["gene_symbol"].isin(common_genes)]
# Keep only first occurrence of each gene symbol to ensure uniqueness
master_pos_def = master_pos_def.drop_duplicates(subset=['gene_symbol'], keep='first')

print(f"Number of genes in master_pos_def after filtering: {len(master_pos_def)}")

if len(PE) != len(master_pos_def["gene_symbol"].unique()):
    print(f"Number of genes in PE: {len(PE)}")
    print(f"Number of genes in master_pos_def: {len(master_pos_def['gene_symbol'].unique())}")
    raise ValueError("Number of genes in PE and master_pos_def are not the same")

FILTERED_SPECIES_PROTEIN_EMBEDDINGS_PATH = SPECIES_PROTEIN_EMBEDDINGS_PATH[:-4] + "_filtered.pt"
torch.save(PE, FILTERED_SPECIES_PROTEIN_EMBEDDINGS_PATH)

  PE = torch.load(SPECIES_PROTEIN_EMBEDDINGS_PATH)


Number of genes in PE before filtering: 11329
Number of genes in PE after filtering: 11329
Number of genes in master_pos_def after filtering: 11329


In [56]:
chrom_file = f"{OUTPUT_DIR}/{SPECIES_NAME}_to_chrom_pos.csv"
master_pos_def.to_csv(chrom_file, index=False) # Save the DF
# The chromosome file path will be:
print(chrom_file)

/mnt/czi-sci-ai/generate-cross-species-secondary/eval/baselines/uce/33l_8ep_1024t_1280/platypus_to_chrom_pos.csv


# Generate token file

This will create the token file. Please note the offset value.

In [57]:
token_dim = 5120
species_to_offsets = {}

N_UNIQ_CHROM = len(master_pos_def[master_pos_def["species"] == species]["chromosome"].unique())

all_pe = torch.load("/mnt/czi-sci-ai/generate-cross-species-secondary/eval/baselines/uce/33l_8ep_1024t_1280/all_tokens.torch")[0:4] # read in existing token file to make sure 
# that special vocab tokens are the same for different seeds

offset = len(all_pe) # special tokens at the top!

PE = torch.load(FILTERED_SPECIES_PROTEIN_EMBEDDINGS_PATH)

pe_stacked = torch.stack(list(PE.values()))
all_pe = torch.vstack((all_pe, pe_stacked))
species_to_offsets[species] = offset

CHROM_TOKEN_OFFSET = all_pe.shape[0]
print("CHROM_TOKEN_OFFSET:", CHROM_TOKEN_OFFSET)
torch.manual_seed(TAXONOMY_ID)
CHROM_TENSORS = torch.normal(mean=0, std=1, size=(N_UNIQ_CHROM, 5120)) 
# N_UNIQ_CHROM is the total number of chromosome choices, it is hardcoded for now (for species in the training data)
all_pe = torch.vstack(
    (all_pe, CHROM_TENSORS))  # Add the chrom tensors to the end
all_pe.requires_grad = False

assert all_pe.size(0) == master_pos_def.shape[0] + 4 + N_UNIQ_CHROM, f"all_pe.size(0): {all_pe.size(0)}, master_pos_def.shape[0]: {master_pos_def.shape[0]}, N_UNIQ_CHROM: {N_UNIQ_CHROM}"

pe_tokens_file = f"{OUTPUT_DIR}/{SPECIES_NAME}_pe_tokens.torch"
print(f"Saving PE tokens to {pe_tokens_file}")
torch.save(all_pe, pe_tokens_file)

offsets_file = f"{OUTPUT_DIR}/{SPECIES_NAME}_offsets.pkl"
print(f"Saving offsets to {offsets_file}")
with open(offsets_file, "wb+") as f:
    pickle.dump(species_to_offsets, f)
print("Saved PE, offsets file")

  all_pe = torch.load("/mnt/czi-sci-ai/generate-cross-species-secondary/eval/baselines/uce/33l_8ep_1024t_1280/all_tokens.torch")[0:4] # read in existing token file to make sure
  PE = torch.load(FILTERED_SPECIES_PROTEIN_EMBEDDINGS_PATH)


CHROM_TOKEN_OFFSET: 11333
Saving PE tokens to /mnt/czi-sci-ai/generate-cross-species-secondary/eval/baselines/uce/33l_8ep_1024t_1280/platypus_pe_tokens.torch
Saving offsets to /mnt/czi-sci-ai/generate-cross-species-secondary/eval/baselines/uce/33l_8ep_1024t_1280/platypus_offsets.pkl
Saved PE, offsets file


In [58]:
N_UNIQ_CHROM

50

In [59]:
with open(offsets_file, "rb") as f:
    offsets = pickle.load(f)
print(offsets)


{'platypus': 4}


In [62]:
import pandas as pd

# Load existing CSV or create new one if it doesn't exist
try:
    embeddings_df = pd.read_csv(NEW_SPECIES_CSV_PATH)
except FileNotFoundError:
    embeddings_df = pd.DataFrame(columns=['species', 'path', 'chrom_token_offset'])

# Add new row
new_row = pd.DataFrame({
    'species': [SPECIES_NAME],
    'path': [FILTERED_SPECIES_PROTEIN_EMBEDDINGS_PATH],
    'chrom_token_offset': [CHROM_TOKEN_OFFSET]
})

# Combine and remove duplicates based on species
embeddings_df = pd.concat([embeddings_df, new_row])
embeddings_df = embeddings_df.drop_duplicates(subset=['species'], keep='last')

# Save updated CSV
embeddings_df.to_csv(NEW_SPECIES_CSV_PATH, index=False)
print(f"Updated {NEW_SPECIES_CSV_PATH}")


Updated /mnt/czi-sci-ai/generate-cross-species-secondary/eval/baselines/uce/33l_8ep_1024t_1280/model_files/new_species_protein_embeddings.csv


# Example evaluation of new species

**Note: when you evaluate a new species, you need to change some arguments and modify some files:**

You will  need to modify the csv in `model_files/new_species_protein_embeddings.csv` to include the new protein embeddings file you downloaded.

In the file add a row for the new species with the format:
`species name,full path to protein embedding file`

Please also add this line to the dictionary created on line 247 in the file `data_proc/data_utils.py`.

When you want to embed this new species, you will need to specify these newly created files as arguments.
- `CHROM_TOKEN_OFFSET`: This tells UCE when the rows corresponding to chromosome tokens starts.
- `spec_chrom_csv_path`: This is a new csv, created by this script, which maps genes to chromosomes and genomic positions
- `token_file`: This is a new token file that will work just for this species. The embeddings generated will still be universal though!
- `offset_pkl_path`: This is another file that maps genes to tokens


```

accelerate launch eval_single_anndata.py chicken_heart.h5ad --species=chicken --CHROM_TOKEN_OFFSET=13275 --spec_chrom_csv_path=data_proc/chicken_to_chrom_pos.csv --token_file=data_proc/chicken_pe_tokens.torch --offset_pkl_path=data_proc/chicken_offsets.pkl --dir=... --multi_gpu=True

```