# Gene Embedding Generation for GenePT

This notebook generates custom GenePT embeddings for genes using OpenAI's text-embedding-3-large model. The process includes:

1. Loading gene descriptions originally sourced from NCBI and UniProt
2. Generating enhanced gene descriptions using GPT-4o-mini (or any other chat model you prefer)
3. Creating embeddings for each gene from the combined descriptions
4. Handling duplicate genes by averaging their embeddings
5. Saving the final embeddings to a parquet file

By default, the prompt asks the model to cover the following, so the resulting embeddings capture information about:
- Gene associations
- Cell types
- Drug interactions
- Biological pathways

Output: A 3072-dimensional embedding vector for each gene

This notebook uses `python-dotenv` to load the OpenAI API key (and, later, the Hugging Face write token) from a `.env` file. To install it, run:
```
pip install python-dotenv
```
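`load_dotenv()` reads `KEY=VALUE` lines from `.env` into `os.environ`, where `OpenAI()` picks up `OPENAI_API_KEY` and the upload cells read `HF_WRITE_TOKEN`. A simplified stand-in showing that mechanism, assuming a plain `KEY=VALUE` file. `load_env_file` is hypothetical and is not python-dotenv's implementation, which also handles `export` prefixes, multiline values, and variable expansion. Like `load_dotenv()` with its default settings, it does not overwrite variables that are already set.

```python
import os
from pathlib import Path


def load_env_file(path, override=False):
    """Simplified .env loader: KEY=VALUE lines, '#' comments, optional quotes.

    Hypothetical helper for illustration only; it is not python-dotenv.
    Returns every parsed pair, whether or not it was written to os.environ.
    """
    parsed = {}
    for raw_line in Path(path).read_text().splitlines():
        line = raw_line.strip()
        # Skip blank lines, comments, and lines without a KEY=VALUE separator
        if not line or line.startswith("#") or "=" not in line:
            continue
        key, value = line.split("=", 1)
        key, value = key.strip(), value.strip().strip("'\"")
        # Keep variables that are already set unless override=True
        if override or key not in os.environ:
            os.environ[key] = value
        parsed[key] = value
    return parsed
```

A `.env` for this notebook would contain lines such as `OPENAI_API_KEY=...` and `HF_WRITE_TOKEN=...`; keep it out of version control.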

## Setup 

In [1]:
# Placeholders for variables defined by notebook_setup.ipynb;
# declaring them here keeps the linter from flagging them as undefined

repo_dir = None  # type: ignore
data_dir = None  # type: ignore

%run notebook_setup.ipynb

autoreload enabled
repo_dir set to /Users/rj/personal/GenePT-tools
File already exists at /Users/rj/personal/GenePT-tools/data/GenePT_emebdding_v2.zip
Extracting files...
Extracting GenePT_emebdding_v2/
Skipping GenePT_emebdding_v2/NCBI_UniProt_summary_of_genes.json - already exists with same size
Skipping GenePT_emebdding_v2/GenePT_gene_embedding_ada_text.pickle - already exists with same size
Skipping GenePT_emebdding_v2/GenePT_gene_protein_embedding_model_3_text.pickle. - already exists with same size
Skipping GenePT_emebdding_v2/NCBI_summary_of_genes.json - already exists with same size
Extraction complete!
Setup finished!
data_dir set to /Users/rj/personal/GenePT-tools/data


# Load `gene_info_table`

We use `gene_info_table` as the reference table. It comes from the original GenePT paper and lets us use the same set of genes for our custom embeddings.

In [2]:
import pandas as pd

gene_info_table = pd.read_parquet(data_dir / "gene_info_table.parquet")
gene_info_table

| index | ensembl_id | gene_type |
|---|---|---|
| TSPAN6 | ENSG00000000003 | protein_coding |
| TNMD | ENSG00000000005 | protein_coding |
| DPM1 | ENSG00000000419 | protein_coding |
| SCYL3 | ENSG00000000457 | protein_coding |
| C1orf112 | ENSG00000000460 | protein_coding |
| ... | ... | ... |
| LINC02481 | ENSG00000246526 | |
| LINC01856 | ENSG00000237574 | |
| LINC02698 | ENSG00000256717 | |
| BOLA2-SMG1P6 | ENSG00000261740 | |


In [3]:
embedding_dir = data_dir / "GenePT_emebdding_v2"
ncbi_summary_of_genes_path = embedding_dir / "NCBI_summary_of_genes.json"
ncbi_uniprot_summary_of_genes_path = (
    embedding_dir / "NCBI_UniProt_summary_of_genes.json"
)

print("embedding_dir exists:", embedding_dir.exists())
print("ncbi_summary_of_genes_path exists:", ncbi_summary_of_genes_path.exists())
print(
    "ncbi_uniprot_summary_of_genes_path exists:",
    ncbi_uniprot_summary_of_genes_path.exists(),
)

embedding_dir exists: True
ncbi_summary_of_genes_path exists: True
ncbi_uniprot_summary_of_genes_path exists: True


In [4]:
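import json

# Load the NCBI-only and NCBI + UniProt gene summaries, closing each file afterwards
with open(ncbi_summary_of_genes_path) as f:
    ncbi_summary_of_genes = json.load(f)
with open(ncbi_uniprot_summary_of_genes_path) as f:
    ncbi_uniprot_summary_of_genes = json.load(f)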

In [5]:
from src.prompt_templates import NCBI_UNIPROT_ASSOCIATED_CELL_TYPE_DRUG_PATHWAY_PROMPT_V1

prompt_template = NCBI_UNIPROT_ASSOCIATED_CELL_TYPE_DRUG_PATHWAY_PROMPT_V1

print(
    prompt_template.format(
        "LOC124907803", ncbi_uniprot_summary_of_genes["LOC124907803"]
    )
)

Tell me about the LOC124907803 gene.

Here is the NCBI and UniProt summary of the gene:

Gene Symbol LOC124907803

----

In addition to the provided information, please:

1. List any other genes that the gene is associated with, particularly those not mentioned in the summaries above.
2. List any cell types or cell classes that the gene is expressed in.
3. List any drug or drug classes that are known to interact with this gene. 
4. Pathways and biological processes that this gene is involved in.

Only include specific information about the gene or gene class. If information is not well documented, say so briefly and don't expound on general information.
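The `.format(...)` call above fills the template positionally, gene symbol first and summary second, which suggests the template is a plain Python string with two `{}` fields. A minimal sketch with a hypothetical, shortened template (`EXAMPLE_PROMPT` and `build_prompt` are illustrative names, not the real template in `src/prompt_templates.py`):

```python
# Hypothetical, shortened template; the real prompt lives in src/prompt_templates.py
EXAMPLE_PROMPT = """Tell me about the {} gene.

Here is the NCBI and UniProt summary of the gene:

{}
"""


def build_prompt(template, gene, summary):
    # Positional fields are filled in order: gene symbol first, then summary.
    # str.format does not re-parse substituted values, so braces inside
    # `summary` are inserted verbatim.
    return template.format(gene, summary)


prompt = build_prompt(EXAMPLE_PROMPT, "PRDM9", "Gene Symbol PRDM9 {example}")
print(prompt)
```

Only literal braces in the template itself would need escaping, by doubling them (`{{` and `}}`).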



# Prompt quality characterization

To characterize the quality of the prompt, let's look at a few example genes:

* **BRCA1** is a very well-documented breast cancer gene
* **LOC124907803** is completely undocumented
* **PRDM9** is an obscure but studied gene

In [10]:
from dotenv import load_dotenv
from openai import OpenAI

load_dotenv()  # load OPENAI_API_KEY (and other secrets) from .env

client = OpenAI()  # reads OPENAI_API_KEY from the environment


In [21]:

gene_completion_test = ["LOC124907803", "BRCA1", "PRDM9"]
for gene in sorted(gene_completion_test):
    completion = client.chat.completions.create(
        model="gpt-4o-mini",
        messages=[
            {
                "role": "user",
                "content": prompt_template.format(
                    gene, ncbi_uniprot_summary_of_genes[gene]
                ),
            }
        ],
        temperature=0.0,  # minimize sampling randomness so runs are comparable
    )
    # Print the source summary, then the model's answer, then a separator per gene
    print(ncbi_uniprot_summary_of_genes[gene])
    print()
    print(completion.choices[0].message.content)
    print("-" * 100)


Gene Symbol BRCA1 This gene encodes a 190 kD nuclear phosphoprotein that plays a role in maintaining genomic stability, and it also acts as a tumor suppressor. The BRCA1 gene contains 22 exons spanning about 110 kb of DNA. The encoded protein combines with other tumor suppressors, DNA damage sensors, and signal transducers to form a large multi-subunit protein complex known as the BRCA1-associated genome surveillance complex (BASC). This gene product associates with RNA polymerase II, and through the C-terminal domain, also interacts with histone deacetylase complexes. This protein thus plays a role in transcription, DNA repair of double-stranded breaks, and recombination. Mutations in this gene are responsible for approximately 40% of inherited breast cancers and more than 80% of inherited breast and ovarian cancers. Alternative splicing plays a role in modulating the subcellular localization and physiological function of this gene. Many alternatively spliced transcript variants, som

# Create request batches

We submit a small test batch first to make sure all of the machinery is working well and the prompts behave as expected. Then we submit a full batch containing all of the genes.
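Each request in an OpenAI batch is one JSON line with a `custom_id`, an HTTP `method`, a target `url`, and a request `body`; the embedding request printed later in this notebook has exactly that shape. A minimal sketch of building chat-completion batch requests from the gene summaries, assuming the same positional prompt template. `build_chat_batch_requests` and `to_jsonl` are hypothetical helpers, not the internals of `embeddings.get_gene_text_batch_requests`.

```python
import json


def build_chat_batch_requests(summaries, template, prefix, model="gpt-4o-mini"):
    """Turn {gene: summary} into Batch API request dicts, one per gene.

    Hypothetical helper; temperature=0.0 mirrors the single-request test above.
    """
    requests = []
    for i, (gene, summary) in enumerate(summaries.items()):
        requests.append(
            {
                # custom_id is how results are matched back to inputs later
                "custom_id": f"{prefix}-{i}",
                "method": "POST",
                "url": "/v1/chat/completions",
                "body": {
                    "model": model,
                    "messages": [
                        {"role": "user", "content": template.format(gene, summary)}
                    ],
                    "temperature": 0.0,
                },
            }
        )
    return requests


def to_jsonl(requests):
    # The Batch API input file holds one JSON object per line
    return "\n".join(json.dumps(r) for r in requests) + "\n"


reqs = build_chat_batch_requests(
    {"BRCA1": "Gene Symbol BRCA1", "PRDM9": "Gene Symbol PRDM9"},
    "Tell me about the {} gene.\n\n{}",
    "test-batch-request",
)
print(to_jsonl(reqs))
```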

In [90]:
from src import embeddings


test_batch_info = embeddings.BatchInfo(
    batch_name="test_batch",
    request_data=embeddings.get_gene_text_batch_requests(
        dict(list(ncbi_uniprot_summary_of_genes.items())[:10]),
        prompt_template,
        "test_batch_requests",
    ),
    batch_description="small batch of gene embeddings for testing the batch API",
)
test_batch_job = embeddings.create_batch_job(test_batch_info, "completion", client)
test_batch_job

Batch(id='batch_67ba5c756fa88190972de6bcf534d8a2', completion_window='24h', created_at=1740266613, endpoint='/v1/chat/completions', input_file_id='file-W8yKadtesgNbTiUGX4iv83', object='batch', status='validating', cancelled_at=None, cancelling_at=None, completed_at=None, error_file_id=None, errors=None, expired_at=None, expires_at=1740353013, failed_at=None, finalizing_at=None, in_progress_at=None, metadata={'description': 'small batch of gene embeddings for testing the batch API'}, output_file_id=None, request_counts=BatchRequestCounts(completed=0, failed=0, total=0))

In [92]:
test_batch_status = embeddings.monitor_batch_status(client, test_batch_job, check_interval=60, verbose=True)

2025-02-22 15:24:13 - Completed: 13, Failed: 0, Total: 100
2025-02-22 15:25:14 - Completed: 31, Failed: 0, Total: 100
2025-02-22 15:26:14 - Completed: 41, Failed: 0, Total: 100
2025-02-22 15:27:14 - Completed: 41, Failed: 0, Total: 100
2025-02-22 15:28:14 - Completed: 44, Failed: 0, Total: 100
2025-02-22 15:29:14 - Completed: 50, Failed: 0, Total: 100
2025-02-22 15:30:14 - Completed: 58, Failed: 0, Total: 100
2025-02-22 15:31:14 - Completed: 62, Failed: 0, Total: 100
2025-02-22 15:32:15 - Completed: 62, Failed: 0, Total: 100
2025-02-22 15:33:15 - Completed: 62, Failed: 0, Total: 100
2025-02-22 15:34:15 - Completed: 62, Failed: 0, Total: 100
2025-02-22 15:35:15 - Completed: 65, Failed: 0, Total: 100
2025-02-22 15:36:15 - Completed: 67, Failed: 0, Total: 100
2025-02-22 15:37:15 - Completed: 72, Failed: 0, Total: 100
2025-02-22 15:38:16 - Completed: 72, Failed: 0, Total: 100
2025-02-22 15:39:16 - Completed: 73, Failed: 0, Total: 100
2025-02-22 15:40:16 - Completed: 76, Failed: 0, Total: 1
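The log above comes from polling the batch job about once a minute. A minimal, self-contained sketch of such a polling loop, written against an injected `get_status` callable so it can be tested without network access. It stops on the job's status rather than on request counts, since a batch's output file is not ready while the job is still `finalizing` (the embedding batch further down reports `status='finalizing'` with `output_file_id=None`). `poll_until_done` and `TERMINAL_STATUSES` are illustrative names, not part of `src.embeddings`.

```python
import time

# Terminal states of an OpenAI batch job; anything else means "keep waiting"
TERMINAL_STATUSES = {"completed", "failed", "expired", "cancelled"}


def poll_until_done(get_status, check_interval=60, max_checks=None, sleep=time.sleep):
    """Call get_status() every check_interval seconds until a terminal status.

    Returns the terminal status; raises TimeoutError after max_checks polls.
    """
    checks = 0
    while True:
        status = get_status()
        checks += 1
        if status in TERMINAL_STATUSES:
            return status
        if max_checks is not None and checks >= max_checks:
            raise TimeoutError(f"batch still {status!r} after {checks} checks")
        sleep(check_interval)
```

With the OpenAI Python client, `get_status` could be `lambda: client.batches.retrieve(batch_id).status`.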

In [101]:
from src import embeddings
test_output_path = embeddings.save_batch_response(test_batch_info, test_batch_status, client)


In [20]:

# responses = embeddings.load_batch_responses(test_batch_info)
full_batch_info = embeddings.BatchInfo(
    batch_name = "full_batch",
    request_data = embeddings.get_gene_text_batch_requests(ncbi_uniprot_summary_of_genes, prompt_template, "full_batch_requests"),
    batch_description = "full batch of gene embeddings for testing the batch API",
    data_dir = data_dir
)
responses = embeddings.load_batch_responses(full_batch_info)

In [21]:
responses[:3]

[{'id': 'batch_req_67ba32235e708190b1afe19989de3b8d',
  'custom_id': 'initial-100-request-0',
  'response': {'status_code': 200,
   'request_id': 'efda5e9d55598fcea928e1604c0622fb',
   'body': {'id': 'chatcmpl-B3p1mQLwc8aN3E0IGzNmGbkzJxLAo',
    'object': 'chat.completion',
    'created': 1740250834,
    'model': 'gpt-4o-mini-2024-07-18',
    'choices': [{'index': 0,
      'message': {'role': 'assistant',
       'content': '**Gene Symbol:** LINC01409\n\n### Summary\nLINC01409 is classified as a long intergenic non-protein coding RNA (lincRNA) gene. The specific functions and mechanisms of LINC01409 are still under investigation, with limited detailed information available regarding its biological roles.\n\n### 1. Associated Genes\nInformation specific to genes directly associated with LINC01409 is limited. However, it may be co-expressed with other long non-coding RNAs (lncRNAs) or genes within the same genomic region, though specific genes are not consistently documented in available 
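Each line of the batch output wraps a normal chat-completion response, as the record above shows (`response` → `body` → `choices[0]` → `message` → `content`). The OpenAI Batch API documentation notes that output lines may not come back in input order, so keying results by `custom_id` is safer than relying on position. A minimal sketch under those assumptions; `extract_chat_contents` is a hypothetical helper, and it simply skips non-200 responses rather than retrying them.

```python
import json


def extract_chat_contents(jsonl_lines):
    """Map custom_id -> assistant message text from Batch API output lines."""
    contents = {}
    for line in jsonl_lines:
        if not line.strip():
            continue  # tolerate blank lines, e.g. a trailing newline
        record = json.loads(line)
        response = record.get("response") or {}
        if response.get("status_code") != 200:
            continue  # failed request: collect and retry these in practice
        body = response["body"]
        contents[record["custom_id"]] = body["choices"][0]["message"]["content"]
    return contents
```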

In [22]:

gene_descriptions_df = embeddings.create_gene_descriptions_dataframe(
    ncbi_uniprot_summary_of_genes,
    responses,
    gene_info_table
)
gene_descriptions_df

| gene_name | description | gpt_response | ensembl_id | gene_type |
|---|---|---|---|---|
| LINC01409 | Gene Symbol LINC01409 | `**Gene Symbol:** LINC01409\n\n### Summary\nLIN...` | ENSG00000237491 | |
| FAM87B | Gene Symbol FAM87B | `### FAM87B Gene Overview\n- **Gene Symbol:** F...` | ENSG00000177757 | lincRNA |
| LINC01128 | Gene Symbol LINC01128 | `### LINC01128 Gene Overview\n\n**Gene Symbol:*...` | ENSG00000228794 | lncRNA |
| LINC00115 | Gene Symbol LINC00115 | `### LINC00115 Gene Overview\n\n**Gene Symbol:*...` | ENSG00000225880 | lincRNA |
| FAM41C | Gene Symbol FAM41C | `### FAM41C Gene Overview\n\n**Gene Symbol:** F...` | ENSG00000230368 | lincRNA |
| ... | ... | ... | ... | ... |
| ZUP1 | Gene Symbol ZUP1 This gene encodes a protein c... | `### ZUP1 Gene Overview\n\n**Gene Symbol:** ZUP...` | ENSG00000153975 | |
| ZWILCH | Gene Symbol ZWILCH Involved in protein localiz... | `### ZWILCH Gene Overview\n\n- **Gene Symbol**:...` | ENSG00000174442 | protein_coding |
| ZXDA | Gene Symbol ZXDA This gene encodes one of two ... | `### ZXDA Gene Overview\n\n**Gene Symbol**: ZXD...` | ENSG00000198205 | protein_coding |
| ZXDB | Gene Symbol ZXDB Protein summary: Cooperates w... | `### ZXDB Gene Summary\n\n**Gene Symbol**: ZXDB...` | ENSG00000198455 | protein_coding |


In [120]:
gene_descriptions_df.loc["ZUP1"]

description     Gene Symbol ZUP1 This gene encodes a protein c...
gpt_response    ### ZUP1 Gene Overview\n\n**Gene Symbol:** ZUP...
ensembl_id                                        ENSG00000153975
gene_type                                                    None
Name: ZUP1, dtype: object

In [121]:
print(gene_descriptions_df.shape)
print(gene_info_table.shape)
print(gene_info_table.loc["SNORD112"].shape)

(37262, 4)
(84425, 2)
(51, 2)
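The shapes above show `gene_descriptions_df` with 37,262 rows, while `gene_info_table` lists SNORD112 51 times. Repeated gene names are what a join against a reference table with a non-unique gene-name index would produce. Since the internals of `create_gene_descriptions_dataframe` are not shown, treat this toy example as an assumption about the mechanism: it shows how a left join on such an index emits one row per matching reference row. The gene names and IDs are made up.

```python
import pandas as pd

# Toy stand-ins: one description per gene symbol, but the reference table
# maps one symbol to several (made-up) Ensembl IDs
descriptions = pd.DataFrame(
    {"description": ["desc A", "desc B"]},
    index=pd.Index(["GENE_A", "GENE_B"], name="gene_name"),
)
reference = pd.DataFrame(
    {"ensembl_id": ["ENSG_1", "ENSG_2", "ENSG_3"]},
    index=pd.Index(["GENE_A", "GENE_A", "GENE_B"], name="gene_name"),
)

# A left join on the index emits one row per matching reference row
joined = descriptions.join(reference, how="left")
print(joined)
print(joined.index.value_counts())
```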


In [39]:
# Create the dataset directory if it doesn't exist
dataset_dir = data_dir / "generated" / "huggingface_dataset"
dataset_dir.mkdir(parents=True, exist_ok=True)

# Save the gene descriptions (source summaries + GPT responses) to a parquet file
dataset_file_name = "generated_descriptions_gpt4o_mini_cell_type_drugs_pathways.parquet"
dataset_file_path = dataset_dir / dataset_file_name
gene_descriptions_df.to_parquet(dataset_file_path)


# Identify genes with missing Ensembl IDs

In [125]:
missing_ensembl_ids = gene_descriptions_df[
    gene_descriptions_df.ensembl_id.isna()
].index

print(f"Descriptions with missing ensembl ids: {len(gene_descriptions_df.loc[missing_ensembl_ids])}")

Descriptions with missing ensembl ids: 576
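Because `gene_descriptions_df` has repeated gene names in its index, `.loc[missing_ensembl_ids]` returns every row for those names, including any row for the same name that does have an Ensembl ID. Counting the boolean mask directly counts only the rows that are actually missing one. Whether the two numbers differ here depends on the data; the toy frame below (hypothetical gene names and IDs) shows a case where they do.

```python
import pandas as pd

df = pd.DataFrame(
    {"ensembl_id": ["ENSG_1", None, "ENSG_3", None]},
    index=pd.Index(["GENE_A", "GENE_A", "GENE_B", "GENE_C"], name="gene_name"),
)

missing_mask = df["ensembl_id"].isna()
missing_names = df[missing_mask].index  # GENE_A, GENE_C

# Counting the mask counts only rows that actually lack an Ensembl ID
rows_missing = int(missing_mask.sum())
# Selecting by label returns every row for those names, including GENE_A's
# row that does have an ID, so this count can be larger
rows_selected_by_label = len(df.loc[missing_names])
print(rows_missing, rows_selected_by_label)
```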


In [132]:
full_batch_of_embedding_requests = embeddings.get_gene_embedding_batch_requests(gene_descriptions_df, "full_batch_embedding_requests")

In [133]:
len(full_batch_of_embedding_requests)

37262

In [135]:
full_batch_of_embedding_requests[1000]

{'custom_id': 'full-batch-embedding-request-1000',
 'method': 'POST',
 'url': '/v1/embeddings',
 'body': {'model': 'text-embedding-3-large',
  'input': '\n    Gene Symbol SYCP1 Enables double-stranded DNA binding activity. Involved in protein homotetramerization. Predicted to be located in synaptonemal complex. Predicted to be active in central element; male germ cell nucleus; and transverse filament. Protein summary: Major component of the transverse filaments of synaptonemal complexes, formed between homologous chromosomes during meiotic prophase. Required for normal assembly of the central element of the synaptonemal complexes. Required for normal centromere pairing during meiosis. Required for normal meiotic chromosome synapsis during oocyte and spermatocyte development and for normal male and female fertility.\n    \n    ### SYCP1 Gene Overview\n\n**Gene Symbol:** SYCP1  \n**Function:** SYCP1 encodes a protein that is a major component of the transverse filaments of the synaptonem
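The `input` above starts with a newline and four spaces, and the same indentation appears before the generated text, which looks like an indented triple-quoted f-string leaking into the embedded text (an assumption, since the code of `get_gene_embedding_batch_requests` is not shown). If you regenerate the embeddings, this sketch builds the same request shape without that whitespace. `build_embedding_request` is a hypothetical helper; the request layout (`custom_id`, `method`, `url`, and a `body` with `model` and `input`) matches the request printed above.

```python
def build_embedding_request(custom_id, description, gpt_response,
                            model="text-embedding-3-large"):
    """One /v1/embeddings Batch API request for a gene (hypothetical helper).

    Joins the source summary and the generated text with a blank line and
    strips surrounding whitespace from each part.
    """
    text = "\n\n".join(part.strip() for part in (description, gpt_response))
    return {
        "custom_id": custom_id,
        "method": "POST",
        "url": "/v1/embeddings",
        "body": {"model": model, "input": text},
    }


request = build_embedding_request(
    "full-batch-embedding-request-0",
    "Gene Symbol SYCP1 Enables double-stranded DNA binding activity.",
    "### SYCP1 Gene Overview\n\n**Gene Symbol:** SYCP1",
)
print(request["body"]["input"])
```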

In [15]:
embedding_batch_info = embeddings.BatchInfo(
    batch_name="full_batch_embeddings",
    request_data=full_batch_of_embedding_requests,
    batch_description=(
        "full batch of gene embeddings based on text generated by GPT-4o-mini "
        "from the NCBI_UNIPROT_ASSOCIATED_CELL_TYPE_DRUG_PATHWAY_PROMPT_V1 prompt"
    ),
    data_dir=data_dir,
)

embedding_batch_job = embeddings.create_batch_job(embedding_batch_info, "embedding", client)

print(embedding_batch_job)

In [11]:
from src import embeddings

# # to monitor a batch job that you don't have the job object for, you can use the batch id
# final_status = embeddings.monitor_batch_status(
#     client, "batch_67bb84122cc48190b5c95ced84631145", check_interval=60, verbose=True
# )

final_status = embeddings.monitor_batch_status(
    client, embedding_batch_job, check_interval=60, verbose=True
)

2025-02-23 13:48:46 - Completed: 37108, Failed: 0, Total: 37262
2025-02-23 13:49:46 - Completed: 37108, Failed: 0, Total: 37262
2025-02-23 13:50:46 - Completed: 37108, Failed: 0, Total: 37262
2025-02-23 13:51:46 - Completed: 37108, Failed: 0, Total: 37262
2025-02-23 13:52:46 - Completed: 37108, Failed: 0, Total: 37262
2025-02-23 13:53:47 - Completed: 37108, Failed: 0, Total: 37262
2025-02-23 13:54:47 - Completed: 37262, Failed: 0, Total: 37262
Batch(id='batch_67bb84122cc48190b5c95ced84631145', completion_window='24h', created_at=1740342290, endpoint='/v1/embeddings', input_file_id='file-2rSyTX17pXwnfkXrt51f6x', object='batch', status='finalizing', cancelled_at=None, cancelling_at=None, completed_at=None, error_file_id=None, errors=None, expired_at=None, expires_at=1740428690, failed_at=None, finalizing_at=1740347649, in_progress_at=1740342297, metadata={'description': 'Embedding batch for prompt '}, output_file_id=None, request_counts=BatchRequestCounts(completed=37262, failed=0, total

# Save the embeddings to a file

We save the raw batch output to a file first, then load it into a DataFrame, so the results are preserved even if a later step fails.

In [16]:
full_embedding_batch_response_path = embeddings.save_batch_response(
    embedding_batch_info, final_status, client
)

In [23]:
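responses = []
with open(full_embedding_batch_response_path, "r") as f:
    for line in f:
        responses.append(json.loads(line))

# Batch output lines are not guaranteed to be in input order, so sort by the
# numeric suffix of custom_id (e.g. "full-batch-embedding-request-1000"),
# assuming it is the row position in gene_descriptions_df
responses.sort(key=lambda r: int(r["custom_id"].rsplit("-", 1)[-1]))
assert len(responses) == len(gene_descriptions_df)

embedding_df = pd.DataFrame(
    [response["response"]["body"]["data"][0]["embedding"] for response in responses],
    index=pd.Series(gene_descriptions_df.index),
)
embedding_df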

| gene_name | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | ... | 3062 | 3063 | 3064 | 3065 | 3066 | 3067 | 3068 | 3069 | 3070 | 3071 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| LINC01409 | -0.022597 | 0.022483 | -0.012700 | 0.000441 | 0.011071 | -0.016550 | 0.058676 | -0.000197 | -0.010964 | 0.041104 | ... | -0.012327 | 0.019163 | 0.028202 | -0.021991 | -0.001263 | -0.009683 | 0.023784 | -0.012725 | -0.002929 | -0.013684 |
| FAM87B | -0.008644 | 0.014819 | -0.011657 | -0.009391 | -0.001620 | -0.036544 | 0.020194 | -0.024062 | 0.011545 | 0.012233 | ... | -0.014475 | 0.003162 | 0.021357 | -0.004657 | -0.026791 | -0.011367 | 0.026910 | -0.015887 | -0.017643 | -0.020194 |
| LINC01128 | -0.021759 | 0.013620 | -0.007717 | 0.001720 | -0.014182 | -0.003041 | 0.028722 | 0.001704 | 0.008554 | -0.006631 | ... | -0.014463 | 0.008548 | 0.021440 | -0.029285 | -0.009308 | -0.020980 | 0.029693 | -0.022257 | -0.010253 | -0.044285 |
| LINC00115 | -0.025666 | -0.003462 | -0.008776 | -0.021912 | 0.008628 | -0.017488 | 0.072318 | -0.004751 | 0.000666 | 0.014042 | ... | -0.023365 | 0.012467 | 0.016511 | -0.020343 | -0.010281 | -0.016729 | 0.024483 | -0.019468 | -0.003094 | -0.037394 |
| FAM41C | -0.024241 | 0.027637 | -0.010726 | 0.004480 | -0.013979 | -0.009189 | 0.015590 | -0.015516 | 0.004523 | 0.004898 | ... | -0.023163 | -0.001465 | 0.032123 | -0.012691 | -0.002928 | -0.017896 | 0.008619 | -0.013508 | -0.028603 | -0.015169 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| ZUP1 | -0.001156 | 0.006608 | -0.005093 | 0.022191 | 0.014974 | -0.020785 | 0.027664 | -0.014472 | 0.061302 | -0.024463 | ... | 0.001357 | 0.011962 | 0.010487 | -0.015715 | 0.019970 | 0.003212 | 0.014058 | 0.010675 | -0.011165 | -0.004148 |
| ZWILCH | 0.028902 | 0.009019 | -0.007437 | 0.013152 | -0.028322 | -0.022507 | 0.005287 | -0.017457 | 0.015704 | -0.000555 | ... | -0.008010 | 0.029667 | 0.039820 | -0.007885 | 0.003956 | 0.008920 | -0.005119 | 0.014794 | -0.009573 | -0.012691 |
| ZXDA | 0.001503 | 0.023130 | -0.009815 | 0.026748 | 0.038371 | -0.016691 | 0.040565 | -0.013857 | 0.042707 | -0.032285 | ... | -0.006187 | 0.001316 | 0.013648 | -0.012616 | 0.015333 | -0.011500 | 0.010135 | 0.012557 | -0.012917 | -0.018010 |
| ZXDB | 0.025981 | 0.012188 | -0.009365 | 0.040015 | 0.035085 | -0.028250 | 0.041058 | -0.011341 | 0.025850 | -0.019720 | ... | -0.014060 | 0.010075 | 0.017621 | -0.003495 | 0.014869 | -0.023177 | 0.012051 | 0.000924 | 0.004441 | -0.020177 |


# Remove duplicates

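We average the embeddings of rows that share a gene name. These duplicates arise because `gene_info_table` maps some gene symbols to several Ensembl IDs (SNORD112, for example, has 51 rows), and rows for the same gene name should point in essentially the same direction. The SNORA31 statistics below bear this out: across its 26 copies, the per-dimension standard deviations are on the order of 1e-4.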
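The averaging in this section uses `groupby(level=0).mean()`. OpenAI documents its embeddings as normalized to unit length, and the mean of several unit vectors has a norm of at most 1, so if downstream code treats dot products as cosine similarities it may be worth rescaling the averaged rows. A minimal sketch with an optional renormalization step; `average_duplicate_embeddings` is a hypothetical helper, and the notebook itself does not renormalize.

```python
import numpy as np
import pandas as pd


def average_duplicate_embeddings(df, renormalize=True):
    """Average rows that share an index label, optionally rescaling to unit L2 norm."""
    averaged = df.groupby(level=0).mean()
    if renormalize:
        # Divide each row by its own L2 norm so every vector has length 1
        norms = np.linalg.norm(averaged.to_numpy(), axis=1)
        averaged = averaged.div(norms, axis=0)
    return averaged


# Toy example: two different unit vectors share the label "A"
toy = pd.DataFrame([[1.0, 0.0], [0.0, 1.0], [0.0, 1.0]], index=["A", "A", "B"])
print(average_duplicate_embeddings(toy))
```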

In [25]:
pd.Series(embedding_df.index).value_counts()

gene_name
SNORD112    51
SNORA31     26
SNORA40     24
SNORA48     21
SNORA25     20
            ..
EMX2OS       1
EMX2         1
EMSY         1
EMP2         1
ZYXP1        1
Name: count, Length: 33703, dtype: int64

In [30]:
embedding_df.loc["SNORA31"].describe()

| | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | ... | 3062 | 3063 | 3064 | 3065 | 3066 | 3067 | 3068 | 3069 | 3070 | 3071 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| count | 26.0 | 26.0 | 26.0 | 26.0 | 26.0 | 26.0 | 26.0 | 26.0 | 26.0 | 26.0 | ... | 26.0 | 26.0 | 26.0 | 26.0 | 26.0 | 26.0 | 26.0 | 26.0 | 26.0 | 26.0 |
| mean | -0.003282 | 0.001444 | -0.016643 | 0.00922 | -0.006747 | -0.015256 | 0.051763 | -0.003245 | 0.006273 | 0.005917 | ... | -0.00352 | 0.019362 | 0.029811 | -0.008472 | -0.019422 | -0.025269 | 0.018193 | 0.01205 | 0.008808 | 0.001293 |
| std | 0.000113 | 4.8e-05 | 1.7e-05 | 7e-05 | 0.00014 | 5.9e-05 | 5.1e-05 | 5.2e-05 | 7.5e-05 | 0.000141 | ... | 6.1e-05 | 5.1e-05 | 3.3e-05 | 3.8e-05 | 5.1e-05 | 6.7e-05 | 2.4e-05 | 3.1e-05 | 4.8e-05 | 5.8e-05 |
| min | -0.003507 | 0.00135 | -0.016671 | 0.009099 | -0.006929 | -0.015442 | 0.051663 | -0.003332 | 0.006139 | 0.005647 | ... | -0.003606 | 0.01929 | 0.029748 | -0.008547 | -0.019537 | -0.025389 | 0.018135 | 0.011978 | 0.008724 | 0.001181 |
| 25% | -0.003368 | 0.001408 | -0.016658 | 0.009147 | -0.006887 | -0.015284 | 0.051719 | -0.003282 | 0.006226 | 0.00588 | ... | -0.003567 | 0.019314 | 0.029794 | -0.008508 | -0.019459 | -0.025323 | 0.018181 | 0.012036 | 0.008771 | 0.001249 |
| 50% | -0.003287 | 0.001451 | -0.016637 | 0.00923 | -0.006749 | -0.01527 | 0.051766 | -0.00325 | 0.00625 | 0.005918 | ... | -0.003546 | 0.019364 | 0.029803 | -0.008465 | -0.019424 | -0.025268 | 0.018195 | 0.012052 | 0.008822 | 0.001288 |
| 75% | -0.003176 | 0.001479 | -0.016631 | 0.009271 | -0.006603 | -0.015202 | 0.051804 | -0.003194 | 0.006336 | 0.006006 | ... | -0.003462 | 0.019407 | 0.02984 | -0.008446 | -0.019386 | -0.025227 | 0.018208 | 0.012066 | 0.008838 | 0.001351 |
| max | -0.003132 | 0.001525 | -0.016615 | 0.009346 | -0.006511 | -0.015176 | 0.051857 | -0.00316 | 0.006416 | 0.006176 | ... | -0.003407 | 0.01944 | 0.029863 | -0.008402 | -0.019331 | -0.025153 | 0.018239 | 0.012123 | 0.008902 | 0.001388 |
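The per-dimension standard deviations above are tiny, but a more direct check that duplicate rows point the same way is their pairwise cosine similarity. A minimal sketch; `duplicate_cosine_similarities` is a hypothetical helper, and for the real data you would call it as, for example, `duplicate_cosine_similarities(embedding_df, "SNORA31").min()`.

```python
import numpy as np
import pandas as pd


def duplicate_cosine_similarities(df, label):
    """Pairwise cosine similarities between all rows sharing one index label."""
    # Double brackets keep a DataFrame even when the label occurs only once
    vectors = df.loc[[label]].to_numpy(dtype=float)
    unit = vectors / np.linalg.norm(vectors, axis=1, keepdims=True)
    sims = unit @ unit.T
    # Keep the upper triangle only: each unordered pair once, no self-pairs
    upper = np.triu_indices(len(unit), k=1)
    return sims[upper]


toy = pd.DataFrame(
    [[1.0, 0.0], [1.0, 0.01], [1.0, -0.01], [0.0, 1.0]],
    index=["X", "X", "X", "Y"],
)
print(duplicate_cosine_similarities(toy, "X"))
```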


In [26]:
embedding_df_averaged = embedding_df.groupby(level=0).mean()

In [27]:
embedding_df_averaged.loc["SNORA31"]

0      -0.003282
1       0.001444
2      -0.016643
3       0.009220
4      -0.006747
          ...   
3067   -0.025269
3068    0.018193
3069    0.012050
3070    0.008808
3071    0.001293
Name: SNORA31, Length: 3072, dtype: float64

In [231]:
embedding_df_averaged.shape

(33703, 3072)

In [34]:
# Create the output directory if it doesn't exist
# (separate from `embedding_dir`, which points at the downloaded GenePT files)
generated_embedding_dir = data_dir / "generated" / "embeddings"
generated_embedding_dir.mkdir(parents=True, exist_ok=True)

# Parquet requires string column names, so convert the integer dimension labels
embedding_df_averaged.columns = [str(col) for col in embedding_df_averaged.columns]
# Save the averaged embeddings to a parquet file
embedding_file_name = "embedding_associations_age_cell_type_drugs_pathways_openai_large.parquet"
embedding_file_path = generated_embedding_dir / embedding_file_name
embedding_df_averaged.to_parquet(embedding_file_path)



# Upload the embeddings to HuggingFace

In [33]:
from dotenv import load_dotenv

load_dotenv()

from huggingface_hub import HfApi
import os

# Initialize Hugging Face API
token = os.getenv("HF_WRITE_TOKEN")
api = HfApi(token=token)



## Upload the gene descriptions dataset


In [43]:
dataset_repo_id = "honicky/genept-composable-embeddings-source-data"

# Upload the gene descriptions parquet file directly
try:
    api.upload_file(
        path_or_fileobj=str(dataset_file_path),
        path_in_repo=dataset_file_path.name,
        repo_id=dataset_repo_id,
        repo_type="dataset",
    )

    print(f"Successfully uploaded {dataset_file_path}")

except Exception as e:
    print(f"Error uploading {dataset_file_path}: {e}")

generated_descriptions_gpt4o_mini_cell_type_drugs_pathways.parquet: 100%|██████████| 32.1M/32.1M [00:12<00:00, 2.60MB/s]


Successfully uploaded /Users/rj/personal/GenePT-tools/data/generated/huggingface_dataset/generated_descriptions_gpt4o_mini_cell_type_drugs_pathways.parquet


## Upload the embedding model

Don't forget to update the README.md to add the new embedding model to the list of models if needed!

In [35]:
model_repo_id = "honicky/genept-composable-embeddings"

try:
    api.upload_file(
        path_or_fileobj=embedding_file_path,
        path_in_repo=embedding_file_name,
        repo_id=model_repo_id,
        repo_type="model",
    )
    print(f"Successfully uploaded {embedding_file_name}")
except Exception as e:
    print(f"Error uploading {embedding_file_path}: {e}")



embedding_associations_age_cell_type_drugs_pathways_openai_large.parquet: 100%|██████████| 1.04G/1.04G [06:06<00:00, 2.83MB/s]


Successfully uploaded embedding_associations_age_cell_type_drugs_pathways_openai_large.parquet
