# Interactive Cell2Sentence Exploration

This notebook helps you understand how Cell2Sentence works by allowing you to:
1. Load a pretrained C2S model
2. Query it interactively with gene lists
3. Generate new cells conditioned on cell types
4. Understand the prompt structure

## Key Concepts
- **Cell Sentences**: Space-separated gene names ordered by descending expression level
- **Cell Type Prediction**: Given genes → predict cell type
- **Cell Generation**: Given cell type → generate gene list
- **Embeddings**: Convert cells to numerical representations
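
As a plain-Python illustration of the **Cell Sentences** idea (not the cell2sentence library's own conversion code), the sketch below ranks genes by descending expression and joins their names with spaces. The function name `to_cell_sentence`, the dict input, dropping zero-expression genes, and alphabetical tie-breaking are all assumptions made for this example.

```python
from __future__ import annotations


def to_cell_sentence(expression: dict[str, float], top_k: int | None = None) -> str:
    """Rank genes by descending expression and join their names with spaces.

    Assumptions for this sketch: genes with zero expression are dropped and
    ties are broken alphabetically so the output is deterministic.
    """
    expressed = [(gene, value) for gene, value in expression.items() if value > 0]
    # Sort by expression (descending), then by gene name (ascending) for ties
    ranked = sorted(expressed, key=lambda item: (-item[1], item[0]))
    if top_k is not None:
        ranked = ranked[:top_k]
    return " ".join(gene for gene, _ in ranked)


# Toy expression values, made up for illustration
toy_cell = {"CD3E": 5.0, "MALAT1": 42.0, "CD19": 0.0, "IL7R": 5.0, "B2M": 30.0}
print(to_cell_sentence(toy_cell))           # MALAT1 B2M CD3E IL7R
print(to_cell_sentence(toy_cell, top_k=2))  # MALAT1 B2M
```

Truncating with `top_k` mirrors the `top_k_genes` argument passed to the prompt formatters below, which only keep the most highly expressed genes.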

In [1]:
# Print the conda environment name
import os
print(os.environ.get('CONDA_DEFAULT_ENV', 'No conda environment found'))

cell2sentence


In [2]:
# Import necessary libraries
import os
import torch
from transformers import AutoModelForCausalLM
import cell2sentence as cs
from cell2sentence.prompt_formatter import C2SPromptFormatter, C2SMultiCellPromptFormatter



In [3]:
# Print available devices
print("Available devices:", torch.cuda.device_count(), "GPUs")

Available devices: 2 GPUs


In [None]:
# Initialize CSModel wrapper (inference mode - no local save needed)
model_id = "vandijklab/C2S-Scale-Gemma-2-27B"

csmodel = cs.CSModel(
    model_name_or_path=model_id
)

print(f"CSModel wrapper initialized")


Using device: cuda
Inference mode: model will be loaded from HuggingFace cache
CSModel initialized. Will use device: cuda


In [5]:
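# Load the model weights, sharded across all visible GPUs (device_map="balanced").
# model.device only reports the device holding the first parameters (cuda:0), not every shard.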

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="balanced",
    torch_dtype=torch.bfloat16,
    trust_remote_code=True
)

print(f"Model loaded on device: {model.device}")


Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 12/12 [00:19<00:00,  1.62s/it]

Model loaded on device: cuda:0


## Understanding Prompt Templates

Let's first look at how Cell2Sentence structures its prompts in the single-cell and multi-cell settings:

In [6]:
# Examine cell type prediction prompts
cell_type_formatter = C2SPromptFormatter(task="cell_type_prediction", top_k_genes=100)

print("Cell Type Prediction Prompts:")
print("=" * 50)
for i, template in enumerate(cell_type_formatter.prompts_dict["model_input"][:3]):  # Show first 3
    print(f"Template {i+1}:")
    print(template)
    print()

print("Response format:")
print(cell_type_formatter.prompts_dict["response"][0])

Cell Type Prediction Prompts:
Template 1:
The following is a list of {num_genes} gene names ordered by descending expression level in a {organism} cell. Your task is to give the cell type which this cell belongs to based on its gene expression.
Cell sentence: {cell_sentence}.
The cell type corresponding to these genes is:

Template 2:
Below is a list of {num_genes} gene names in order of descending expression level from a {organism} cell. Based on this, predict what the cell type of this cell is.
Cell sentence: {cell_sentence}.
These genes are most likely associated with cell type:

Template 3:
Given the list of {num_genes} gene names ordered by descending expression level from a {organism} cell, identify the cell type.
Cell sentence: {cell_sentence}.
The probable cell type for these genes is:

Response format:
{cell_type}.
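
The templates printed above are ordinary Python format strings, which is why the helper functions in this notebook fill them with `str.format`. As a hedged sketch, the standard-library `string.Formatter().parse` can list the placeholders a template expects; `template_fields` is a hypothetical helper, not part of cell2sentence.

```python
from __future__ import annotations

import string


def template_fields(template: str) -> list[str]:
    """List the placeholder names in a str.format template, in order of first use."""
    fields: list[str] = []
    for _literal, field_name, _spec, _conversion in string.Formatter().parse(template):
        if field_name is not None and field_name not in fields:
            fields.append(field_name)
    return fields


# Template 1 of the cell type prediction prompts, copied from the output above
template = (
    "The following is a list of {num_genes} gene names ordered by descending expression "
    "level in a {organism} cell. Your task is to give the cell type which this cell "
    "belongs to based on its gene expression.\n"
    "Cell sentence: {cell_sentence}.\n"
    "The cell type corresponding to these genes is:"
)
print(template_fields(template))  # ['num_genes', 'organism', 'cell_sentence']

prompt = template.format(num_genes=3, organism="Homo sapiens", cell_sentence="CD3D CD3E CD3G")
print(prompt.splitlines()[1])  # Cell sentence: CD3D CD3E CD3G.
```

Checking the fields first makes it easy to see that the generation templates need `cell_type` instead of `cell_sentence`.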


In [7]:
# Examine cell generation prompts
cell_gen_formatter = C2SPromptFormatter(task="cell_type_generation", top_k_genes=200)

print("Cell Generation Prompts:")
print("=" * 50)
for i, template in enumerate(cell_gen_formatter.prompts_dict["model_input"][:2]):  # Show first 2
    print(f"Template {i+1}:")
    print(template)
    print()

print("Response format:")
print(cell_gen_formatter.prompts_dict["response"][0])

Cell Generation Prompts:
Template 1:
Generate a list of {num_genes} genes in order of descending expression which represent a {organism} cell of cell type {cell_type}.
Cell sentence:

Template 2:
Produce a list of {num_genes} gene names in descending order of expression which represent the expressed genes of a {organism} {cell_type} cell.
Cell sentence:

Response format:
{cell_sentence}.


In [8]:
# Examine tissue prediction with multi-cell prompts
multi_cell_type_formatter = C2SMultiCellPromptFormatter(task="tissue_prediction", top_k_genes=100)

print("Multi cell tissue prediction prompts:")
print("=" * 50)
for i, template in enumerate(multi_cell_type_formatter.prompts_dict["model_input"][:2]):  # Show first 2
    print(f"Template {i+1}:")
    print(template)
    print()

print("Response format:")
print(multi_cell_type_formatter.prompts_dict["response"][0])

Multi cell tissue prediction prompts:
Template 1:
The following is a list of {num_genes} gene names ordered by descending expression level for {num_cells} different {organism} cells. Your task is to give the tissue which these cells belong to based on their gene expression.
Cell sentences:
{multi_cell_sentences}.
The tissue which these cells belong to is:

Template 2:
Below is a list of {num_genes} genes in decreasing order of expression in {num_cells} {organism} cells. Given this, predict the tissue which these cells belongs to.
Cell sentences:
{multi_cell_sentences}.
The tissue which these cells originate from is:

Response format:
{tissue}.


## Cell Type Prediction

Now let's try predicting cell types from gene lists. Here are some example gene signatures:

In [9]:
# Example gene signatures for different cell types
example_signatures = {
    "T cell signature": "CD3D CD3E CD3G IL7R CCR7 LEF1 TCF7 LTB",
    "B cell signature": "CD19 MS4A1 CD79A CD79B IGHM IGKC IGLC2 PAX5",
    "Monocyte signature": "CD14 FCGR3A S100A8 S100A9 LYZ VCAN FCN1 CTSS",
    "NK cell signature": "KLRD1 KLRF1 NCR1 GZMA GZMB PRF1 GNLY NKG7",
    "Dendritic cell signature": "FCER1A CLEC4C IRF7 IRF8 CLEC10A CD1C CLEC9A"
}

for cell_type, genes in example_signatures.items():
    print(f"{cell_type}: {genes}")

T cell signature: CD3D CD3E CD3G IL7R CCR7 LEF1 TCF7 LTB
B cell signature: CD19 MS4A1 CD79A CD79B IGHM IGKC IGLC2 PAX5
Monocyte signature: CD14 FCGR3A S100A8 S100A9 LYZ VCAN FCN1 CTSS
NK cell signature: KLRD1 KLRF1 NCR1 GZMA GZMB PRF1 GNLY NKG7
Dendritic cell signature: FCER1A CLEC4C IRF7 IRF8 CLEC10A CD1C CLEC9A


In [10]:
def predict_cell_type(gene_list, organism="Homo sapiens", verbose=True):
    """Predict cell type from a list of genes."""
    n_genes = len(gene_list.split())
    
    # Create prompt
    formatter = C2SPromptFormatter(task="cell_type_prediction", top_k_genes=n_genes)
    prompt_template = formatter.prompts_dict["model_input"][0]
    prompt = prompt_template.format(
        num_genes=n_genes,
        organism=organism,
        cell_sentence=gene_list
    )
    
    if verbose:
        print(f"Prompt:")
        print("-" * 50)
        print(prompt)
        print("-" * 50)
    
    # Generate prediction
    response = csmodel.generate_from_prompt(
        model=model,
        prompt=prompt,
        max_num_tokens=50,
        temperature=0.1,  # Low temperature for consistent predictions
        do_sample=True
    )
    
    return response.strip()

# Test with T cell signature
t_cell_genes = "CD3D CD3E CD3G IL7R CCR7 LEF1 TCF7 LTB"
prediction = predict_cell_type(t_cell_genes)
print(f"\nPrediction: {prediction}")

Prompt:
--------------------------------------------------
The following is a list of 8 gene names ordered by descending expression level in a Homo sapiens cell. Your task is to give the cell type which this cell belongs to based on its gene expression.
Cell sentence: CD3D CD3E CD3G IL7R CCR7 LEF1 TCF7 LTB.
The cell type corresponding to these genes is:
--------------------------------------------------

Prediction: thymocyte.<ctrl100>


In [11]:
# Test with a short, arbitrary gene list (these are not B cell markers)
arbitrary_genes = "LDHA H2BC5 SORCS2"
prediction = predict_cell_type(arbitrary_genes)
print(f"\nPrediction: {prediction}")

Prompt:
--------------------------------------------------
The following is a list of 3 gene names ordered by descending expression level in a Homo sapiens cell. Your task is to give the cell type which this cell belongs to based on its gene expression.
Cell sentence: LDHA H2BC5 SORCS2.
The cell type corresponding to these genes is:
--------------------------------------------------

Prediction: native cell.
.<ctrl100>
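
The raw predictions above carry a trailing period and, for this Gemma-based checkpoint, a `<ctrl100>` token (the `query_model` helper later in the notebook strips that token as well). A small post-processing sketch, assuming only the first non-empty line of the response is the answer; `clean_prediction` is a hypothetical helper.

```python
from __future__ import annotations


def clean_prediction(raw: str, special_tokens: tuple[str, ...] = ("<ctrl100>",)) -> str:
    """Remove special tokens, keep the first non-empty line, and drop trailing periods."""
    text = raw
    for token in special_tokens:
        text = text.replace(token, "")
    lines = [line.strip() for line in text.splitlines() if line.strip()]
    return lines[0].rstrip(".").strip() if lines else ""


print(clean_prediction("thymocyte.<ctrl100>"))       # thymocyte
print(clean_prediction("native cell.\n.<ctrl100>"))  # native cell
```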


In [12]:
# Try your own gene list!
# Modify this cell to test different gene combinations
custom_genes = "MALAT1 B2M TMSB4X RPS27 TMSB10 RPLP1 RPL13 RPL10 RPS3 MT-CO1 RPL21 RPS19 RPS23 RPL31 IL32 RPL13A EIF1 RPS3A RPL5 RPL18A RPS2 RPS15A RPL19 RPS4X RPS28 RPLP2 RPS27A RPS14 RPL35 RPL30 RPL15 RPS6 ACTB RPL23A RPL3 RPS15 RPL34 RPS18 RPS29 RPL27A RPL28 RPL26 RPL14 RPL39 RPS12 RPS7 RPL7 RPS8 RPL11 RPL18 MT-ND1 RPL32 RPL29 RPS16 RPL10A RPL9 GPR183 RPL36 PFN1 MT-ND4 RACK1 RPL7A RPS24 ATP5E ACTG1 MT-CYB RPS13 HLA-A RPS20 RPL37A RPS26 HLA-B ACTR2 RPL12 RPL37 RPL41 RPL38 RPLP0 RPL6 FTH1 RPL4 RPL27 RPL36AL RPL8 S100A6 S100A4 RPSA MT-CO2 RPL22 RPL23 ZFP36L2 EEF1A1 RPL24 PTMA UBA52 NACA RPS25 RPL17 HINT1 RPS9 MYL6 TRBC2 NOP53 CD52 PFDN5 EEF1B2 ITM2B MT-RNR2 CORO1A RPL35A DDX5 EEF1D LTB MT-ATP6 RPL36A TOMM7 FTL HLA-C MT-ND2 TPT1 RPS10 BTG1 MT-ND3 RPS5 FAU MT-CO3 RPS11 RPS17 CD3D EEF2 RPL9P9 BTF3 COX7C HLA-E PPIA EIF3E CORO1B MYL12A ARHGDIB RPL7P9 ITGB1 LDHB FXYD5 CD3E UQCRB ARHGEF1 PPDPF HNRNPA1 HNRNPA2B1 CD2 PABPC1 MT-RNR1"  # Change this
prediction = predict_cell_type(custom_genes)
print(f"\nPrediction: {prediction}")

Prompt:
--------------------------------------------------
The following is a list of 152 gene names ordered by descending expression level in a Homo sapiens cell. Your task is to give the cell type which this cell belongs to based on its gene expression.
Cell sentence: MALAT1 B2M TMSB4X RPS27 TMSB10 RPLP1 RPL13 RPL10 RPS3 MT-CO1 RPL21 RPS19 RPS23 RPL31 IL32 RPL13A EIF1 RPS3A RPL5 RPL18A RPS2 RPS15A RPL19 RPS4X RPS28 RPLP2 RPS27A RPS14 RPL35 RPL30 RPL15 RPS6 ACTB RPL23A RPL3 RPS15 RPL34 RPS18 RPS29 RPL27A RPL28 RPL26 RPL14 RPL39 RPS12 RPS7 RPL7 RPS8 RPL11 RPL18 MT-ND1 RPL32 RPL29 RPS16 RPL10A RPL9 GPR183 RPL36 PFN1 MT-ND4 RACK1 RPL7A RPS24 ATP5E ACTG1 MT-CYB RPS13 HLA-A RPS20 RPL37A RPS26 HLA-B ACTR2 RPL12 RPL37 RPL41 RPL38 RPLP0 RPL6 FTH1 RPL4 RPL27 RPL36AL RPL8 S100A6 S100A4 RPSA MT-CO2 RPL22 RPL23 ZFP36L2 EEF1A1 RPL24 PTMA UBA52 NACA RPS25 RPL17 HINT1 RPS9 MYL6 TRBC2 NOP53 CD52 PFDN5 EEF1B2 ITM2B MT-RNR2 CORO1A RPL35A DDX5 EEF1D LTB MT-ATP6 RPL36A TOMM7 FTL HLA-C MT-ND2 TPT1 RPS10 B

## Cell Generation

Now let's try generating cells conditioned on cell types:

In [13]:
def generate_cell(cell_type, n_genes=50, organism="Homo sapiens", verbose=True):
    """Generate a cell sentence given a cell type."""
    
    # Create prompt
    formatter = C2SPromptFormatter(task="cell_type_generation", top_k_genes=n_genes)
    prompt_template = formatter.prompts_dict["model_input"][0]
    prompt = prompt_template.format(
        num_genes=n_genes,
        organism=organism,
        cell_type=cell_type
    )
    
    if verbose:
        print(f"Prompt:")
        print("-" * 50)
        print(prompt)
        print("-" * 50)
    
    # Generate cell
    response = csmodel.generate_from_prompt(
        model=model,
        prompt=prompt,
        max_num_tokens=512,
        temperature=0.8,
        do_sample=True
    )
    
    return response.strip()

# Generate a T cell
generated_cell = generate_cell("T cell", n_genes=30)
print(f"\nGenerated T cell: {generated_cell}")

Prompt:
--------------------------------------------------
Generate a list of 30 genes in order of descending expression which represent a Homo sapiens cell of cell type T cell.
Cell sentence:
--------------------------------------------------

Generated T cell: 1:
MALAT1 B2M MT-CO1 MT-ATP6 MT-CO3 MT-CO2 MT-ND4 TMSB4X RPLP1 HLA-A HLA-B RPL41 RPS27 MT-ND3 MT-CYB TMSB10 RPS29 MT-ND1 RPL28 RPS19 RPL13 ACTB RPS24 TPT1 RPS18 HLA-C MT-ND2 CCL5 RPS2 RPL32 RPS28 RPL18A EEF1A1 RPL35 RPL10 RPL34 RPS12 RPL37A ZFP36L2 RPL11 RPS16 PTMA RPS15 ATP5F1E RPL37 RPL6 RPL12 RPL21 RPS7 RPL19 HLA-E S100A6 ACTG1 RPS14 RPS4X RPS27A RPL36 RPL30 RPLP2 CD52 RPS15A RPL14 RPL10A RPL35A RPL29 RPS3 RPL8 RPS21 RPL39 RPL26 RPL13A CD7 RPS23 RPL3 SH3BGRL3 RPL15 RPL23A RPS5 RPL18 RPS25 CXCR4 PTPRC RPS6 SRGN RPL27 GAPDH RPL7A MT-ND5 CD63 RPS8 CFL1 RPS9 MYL6 XCL1 RPL9 FAU CD8A H3-3A IL32 UBA52 RACK1 RPL22 RPS26 UBC RPLP0 RPS3A KRT81 CTSW SERF2 RPL36AL RPL5 CD74 RPSA PFN1 EIF1 HSP90AA1 FTH1 PSME1 RPS13 NACA VIM RPL38 MT2A S1
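
With `n_genes=30`, the model still returned far more than 30 genes, prefixed by `1:`; generation appears to run until the `max_num_tokens=512` budget is used rather than stopping at the requested count. One hedged way to get at most `n_genes` names is to post-process the text. `parse_generated_cell` is a hypothetical helper, and its rules for labels and punctuation are assumptions based on the output above.

```python
from __future__ import annotations


def parse_generated_cell(raw: str, n_genes: int) -> list[str]:
    """Return up to n_genes unique gene names from a generated cell sentence.

    Assumptions (not library behavior): tokens ending in ":" such as a leading
    "1:" are labels rather than genes, and a trailing "." or "<ctrl100>" is
    not part of a gene name.
    """
    genes: list[str] = []
    seen: set[str] = set()
    for token in raw.replace("<ctrl100>", " ").split():
        token = token.rstrip(".")
        if not token or token.endswith(":") or token in seen:
            continue
        genes.append(token)
        seen.add(token)
        if len(genes) == n_genes:
            break
    return genes


raw = "1:\nMALAT1 B2M MT-CO1 B2M MT-ATP6 MT-CO3."
print(parse_generated_cell(raw, n_genes=3))   # ['MALAT1', 'B2M', 'MT-CO1']
print(parse_generated_cell(raw, n_genes=10))  # ['MALAT1', 'B2M', 'MT-CO1', 'MT-ATP6', 'MT-CO3']
```

If the token budget cuts generation off, the last name may be truncated (the T cell output above appears to end mid-name), so the final token of a long generation is worth treating with suspicion.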

In [None]:
# Generate different cell types
cell_types_to_generate = ["B cell", "monocyte", "NK cell", "dendritic cell"]

for cell_type in cell_types_to_generate:
    print(f"\n{'='*50}")
    print(f"Generating {cell_type}:")
    generated = generate_cell(cell_type, n_genes=20, verbose=False)
    print(f"Generated: {generated}")


Generating B cell:
Generated: MALAT1 MT-CYB MT-CO1 MT-CO3 MT-ND5 MT-ND2 MT-ND3 MT-ND4L MT-ATP6 TPT1 MT-CO2 RPS3A RPLP1 EEF1A1 RPL11 B2M IL7R RPL10 RPL30 RPS2 RPS3 RPS7 RPL41 RPL13 RPS19 RPL5 ZFP36L2 RPL28 MT-ND1 TXNIP RPL19 RPS27A RPS15 RPS28 RPL37 RPL7A RPS8 RPS27 RPS26 RPL12 RPL15 RPL18 RPL34 RPS4X HLA-B TMSB4X RPL14 RPL37A RPL3 RPS23 RPL35 RPL32 RPS14 RPL29 RPS24 RPL39 RPS15A RPL8 RPLP0 RPL6 RPS25 RPL18A RPS13 RPS12 FTL RPL24 RPS6 RPL9 RPL26 RPSA PTMA RACK1 PABPC1 MT-ND4 RPL36 BTG1 RPL13A NACA ANXA1 RPL21 FAU RPL10A RPL35A EEF1D RPS18 HLA-A BTF3 RPL22 RPL7 HLA-C RPS16 RPS5 ATP5F1E RPL38 PFN1 RPS29 RPL31 CD3D RPS9 RPS21 RPL23A FTH1 LTB UBA52 ZFAS1 TOMM7 SON UQCRB TMSB10 CFL1 COX4I1 HSP90AB1 RPL4 EEF1B2 ITM2B RPL27A RPS20 FXYD5 ACTB RPL36AL SRSF5 RPL23 EIF3E HNRNPA1 COX7C CD48 RPS11 H1-10 NOP53 TCF7 RPL27 PFDN5 UBC CD2 GIMAP7 VIM KLF2 ITM2A PTPRC H3-3A DDX5 RPL36A COMMD6 ARPC2 TMA7 SNHG8 HSP90AA1 S100A10 SPCS1 EIF1 S100A6 CHURC1 ZFP36L

Generating monocyte:
Generated: MALAT1 MT-CYB M


KeyboardInterrupt: 

## Multi-cell tissue prediction

In [14]:
task_name = "tissue_prediction"
multi_cell_tissue_formatter = C2SMultiCellPromptFormatter(
    task=task_name,
    top_k_genes=100
)

multi_cell_tissue_prompt = multi_cell_tissue_formatter.prompts_dict["model_input"][0]
print(multi_cell_tissue_prompt)

The following is a list of {num_genes} gene names ordered by descending expression level for {num_cells} different {organism} cells. Your task is to give the tissue which these cells belong to based on their gene expression.
Cell sentences:
{multi_cell_sentences}.
The tissue which these cells belong to is:


In [15]:
# Control and perturbed (CEBPE+CEBPB) cells from the Norman dataset
control_cell = (
    "MT-CO2 MT-CO3 RPLP1 RPS27 RPS18 MT-CO1 RPS12 RPS8 RPS2 RPL13A MT-ND4 EEF1A1 RPL34 RPS23 RPS19 FTL RPS29 RPL32 RPS14 RPL19 RPS28 RPL37A MT-ATP6 RPS16 RPL21 RPL7 RPS3 MT-ND2 RPL3 RPS27A YBX1 RPL28 RPL39 RPL15 RPLP0 RPL35A RPS4X RPL13 RPL37 RPLP2 RPL9 RPS3A FTH1 RPS15 RPS15A RPL18A RPL29 RPL12 RPL8 RPS25 RPS7 RPS24 RPS17 RPL27A RPS13 RPS21 MT-ND1 RPL4 RPL11 RPL36 RPS6 RPSA RPS5 RPS20 RPL38 RPL7A MT-CYB RPL10 RPL18 RPL10A RPL14 RPL30 RPL24 UQCRH RPL6 RPL35 RPL31 TMA7 RPL22 FAU MT-ND3 RPL23A GAPDH RPL41 RPL26 TMSB4X RPL23 PPIA UBA52 OAZ1 HINT1 TPT1 RPS11 SERF2 EEF2 SDF2L1 RPS9 NACA RPL5 HMGN1 HSPE1 RANBP1 PTMA HMGA1 COX7A2 RPL27 APOC1 BTF3 NDUFB10 SNRPG HNRNPA1 PABPC1 EEF1B2 YDJC PPP1R14B PFN1 HSPD1 CALR DDX1 CHCHD2 PITX1 HSP90AB1 MT-ND5 COX7C GPX4 EIF3K UBL5 SET UBAC1 B2M SERBP1 UQCR11 UBB EIF5B PGP RPS26 NPM1 NEDD8 APRT TMEM14B ENO1 COX4I1 COX5B NDUFAB1 SLIRP GYPC CST3 HNRNPC HSP90B1 APOBEC3C DDX21 NCL TUBA1B PSMA4 JTB LDHA AP2B1 SRM COX6B1 EIF5A CCDC28B MT-ND4L EIF3I COX7B GADD45GIP1 SERPINB6 TRMT2A NME4 UBE2L3 UQCRQ NDUFB9 GOT2 COX5A SOD1 SON PEBP1 NAP1L1 NSA2 ATP1B3 NDUFA11 ACTG1 RPL36AL TMEM258 MTPN SRSF9 TUBA1C C19orf53 CCNI BCAT1 ATP6V0B SNRPE TMSB10 RBBP7 HMGN3 TOMM20 OST4 PA2G4 ROMO1 COMT EIF3H POMP IMMP1L ACTB PRELID1 ENY2 PARK7 MRPL36 EIF3F PPA1 GSTP1 ANP32B FADS1 GTF3A TMEM160 SUMO1 METTL5 SNX5 SNX3 NDUFB1 NDUFAF2 EIF3E GMPR CCT8 SNRPF SNRPD1 PGK1 EMC6 NDUFS5 ARL15 GTF3C6 GTPBP4 CFL1 GYPB SRRM1 RAN SRSF5 HBZ ZNF22 HDDC2 ATP6V1F DNAJA1 HNRNPA2B1 HES6 SDHC HMGB1 SUB1 SUMO3 NDUFS6 CMSS1 NDUFS7 NDUFS8 PFDN5 UQCRFS1 UQCRC1 FAM200B COA1 TADA3 LYL1 YIF1B AKIRIN1 CMBL RGS10 IMP3 PCBD1 YBX3 LUC7L3 SNRPD2 CCT4 KRT18 RPS27L CUTA MRPL52 MRPL57 DDT MRPS21 MRPS34 SMARCA4 SLC25A6 MRPL35 LSMEM2 PTBP1 ACP1 POLR1D CBX3 LSM12 SLC25A5 SLC25A37 CAPZA1 MARCKSL1 MTDH TOR1AIP2 METAP2 TOP2A C1QBP PSMA5 C4orf3 LDHB PSMA7 MIEN1 PSMB2 CAP1 C9orf78 LAPTM4B LAMTOR5 PPP1CA TPM3 ZNHIT1 PSMD11 PRAME ZBTB8OS POLR2L ATF7IP2 NAA10 RNF126 YWHAZ KIFAP3 THOC2 SNF8 PTP4A2 "
    "KIAA0319L SNRNP25 PKIG COQ4 SSU72 BNIP3L MYC SFPQ SNRPB PICALM DSTN COX6C SRP14 PRPF40A PSMC6 CTAG2 SRSF1 PAFAH1B2 CTU1 SRRM2 OSBPL9 CISD2 DDB1 DCUN1D5 DAZAP1 SEC62 YWHAB DIS3 SIVA1 DAP SRSF10 PHF20 GUK1 ZC3HAV1 SLC7A5 FBXO9 DCTN3 SRSF11 FBXL20 HIGD2A HGS PPP1R18 ANAPC16 CSTF3 YRDC BID PIM2 ANAPC11 SFXN1 ASH1L XPR1 EZH2 RANBP10 ANXA2 PTS HDDC3 ALDOA HDAC6 SRSF7 HDAC1 VDAC1 SLC38A5 CKB PHKB CSNK2A1 SRSF3 RAD21 PGAM1 GTF2A2 PDCD10 DEK PDE4DIP POLR2E CMTM6 UPF3B CNBP CNOT7 PDLIM7 CNPY2 FKBP8 ARPC1A QPRT PLTP VBP1 PTGES3 ZNF131 SMIM4 FKBP3 FIS1 UQCRB PRKAR2A COX7A2L GALNT1 PRR13 COX8A CYCS GOLGB1 PFDN2 PNN COX20 FAM120AOS SLC27A5 GNG5 PPDPF PPARGC1B PAK2 GAL PRRC2B PWP1 GSPT1 DDX46 PRDX1 SPN PARP1 PRDX2 SMYD3 PRDX3 COX14 PRMT2 FUS ATF4 PSME3 FBL PSMG3 FAM222B SNRPC CLPX GPS1 POLR2J CLTA PCGF6 PCNP COX6A1 UBL7 EIF1AX TIMM8B CCDC59 APOE RNF181 TXNL4A THRAP3 DYNLL1 CCDC137 TIMM13 TIMM44 ZNF664 MTCH2 EEF1D MT1X MRPS12 TIMMDC1 CBX5 ACP5 APPL1 KPNA6 KPNB1 KRT10 TMA16 ZNF706 MRPS33 TMED4 MRPS15 MYL12A RSF1 MYL6 MYL6B CDC42 EIF4B CDC16 EIF3M NDUFC1 NDUFB7 EIF3J NDUFB4 NDUFA8 EIF3D NDUFA3 S100A4 S100A13 CCT7 KCNH2"
)
perturbed_cell = (
    "MT-CO3 TMSB4X FTL MT-CO2 RPS27 MT-CO1 EEF1A1 RPS14 RPL34 FTH1 RPS18 MT-ND4 RPLP1 RPL37 RPS29 MT-ATP6 RPS4X RPL13A MT-CYB RPS8 RPS2 RPS19 TPT1 RPL28 RPL32 RPS3 RPS7 RPS27A MT-ND1 RPL37A MT-ND2 RPS12 RPS3A RPL21 RPL13 ACTB RPL39 RPL6 RPL9 RPS24 RPL27A RPL11 RPS28 RPS15 RPS11 RPL15 RPLP0 OAZ1 RPL10 RPL30 MT-ND3 RPS20 RPS15A RPL35 RPL41 YBX1 RPL5 RPS23 RPL29 RPL8 RPL19 RPL14 TMA7 RPL4 RPL10A RPL18A TYROBP RPLP2 RPS17 RPL7 UBA52 GAPDH UQCRH PTMA CFD HMGB1 RPL12 PFN1 RPS25 RPL35A RPSA B2M RPL24 COX7C UQCR11 SUB1 RPL36 HMGN1 RPL18 COX5B SRP14 VIM LST1 PTP4A2 CFL1 SUMO2 NACA RPL7A SH3BGRL3 PRAME SERF2 TMSB10 COX8A FAU CEBPE RPL23 EDF1 CSF3R CHCHD2 SNX3 MPC2 LMO4 PRR13 MSN DBI FAM118A NDUFS5 UBE2L3 MTCH1 COX7A2 FKBP1A RPS6 EIF2S2 RPS21 RPS26 GPX4 LSM6 CAP1 RPL22 HMGN2 MS4A3 PET100 PFDN5 UBL5 RPL23A LSM2 CDCA7 NUCKS1 ITM2B LDHB HMGN3 AURKAIP1 LAMTOR4 RPL27 RPS16 COTL1 PGAM1 ZNF207 UBB UBC COX4I1 GTF2I SLC25A3 PPIA RPS13 PGP TMBIM6 COX7B MYL6 ARHGDIB LGALS1 AGO1 HSP90AA1 SLC25A6 TMED9 MT-ND5 CLIC1 RPL3 SNRPG YWHAH SRSF9 BSG USP20 SRSF3 ARPC5 SRRM2 ARPC3 MYO1F ARPC2 ARPC1B MORF4L1 CIRBP MSRB1 CLC CALM2 DEK MARCKSL1 COMMD6 IQGAP2 CHAF1A RABL6 SAT1 CD63 NDUFA4 FAM89A POMP EIF3K COX6C PNRC1 TMEM219 MRPS16 TUBA1B SERBP1 BUD31 MRPL13 SLC39A3 CCT5 TEX30 COX6A1 SVIP PKM BANF1 UQCRFS1 PHF19 EIF3F SON CELF2 COX17 METTL9 PPIL2 RPS5 SET MYL12A RAB4A CMTM7 RPL38 COX6B1 PCBP2 CIZ1 UCP2 DGCR6L PGLS GPBP1 LIMD2 PEBP1 SOX4 PPP1CA HINT1 NCL PPP6R1 RPL26 BTG1 CCNI HNRNPA1 RSF1 PPP1CB RMI2 PRDX5 ALDOA PYCARD MBOAT7 PPP1R18 TNNT1 TOMM7 ATP6V1F CKLF ATP6V1G1 ACTG1 EEF2 RPL31 WASF2 ELF1 SRSF11 DGUOK SRSF10 POLR2J3 METTL23 HPRT1 FAM32A COPS7A GLIPR1 TKT ANKRD28 CORO1A MFHAS1 NOTCH1 RAB1A CTSC QPRT TOR1A TRA2B TIMM17A SLC35E2B C14orf119 MRPL20 MIB2 BEST1 ANXA2 PPDPF ANP32B MRPL47 MRPL10 PPHLN1 SNX5 CCDC124 YTHDC1 LMNA TPM3 POLR2L ADRM1 COMMD9 EEF1B2 RPL36AL LARP1B SNRPE CANX CALM1 DYRK1A TRPM4 USP32 HVCN1 PPP3R1 NUDT22 LCP1 TMEM167A ISY1 RSL24D1 BTF3 PLEK TMEM14C TMEM14B CAPNS1 "
    "DNMT1 PLD3 ICMT SMC4 ZMAT2 BRI3 ZDHHC12 ANAPC13 ANAPC11 PMVK MORF4L2 SCAND1 SLIRP BRK1 RTN4 IDH2 RTN3 CRLS1 TRMT112 SMAD4 LAPTM5 OAZ2 SMAD5 FAM89B NUDT1 NUDCD2 AIF1 SNRNP25 BIRC2 CSTB PPM1G S100A11 ANKRD11 NONO PTPN7 MRPS33 NSRP1 SLBP NOP10 CYBA PPIG WDR62 SNRPB2 TLN1 SNRPD1 CBX3 CBFA2T3 TNFAIP8 BMI1 CASP8AP2 C4orf48 HSPB11 HSPB1 HSPA8 LENG8 ZMYND8 HSP90AB1 TMED5 WDR83OS RUNX1 RANBP1 C4orf3 RPS9 ALAS1 ZNF138 ZC3H15 PPM1N FUS TMA16 EEF1D COMMD3 PARK7 SPG21 SPI1 PRMT2 PARL CD37 GOLIM4 CLPX DDX52 HDDC2 UBE2D2 UBE2M MCCC2 FLCN ATP11B UBE2I PSMA3 EIF5A ETFB HDAC3 VAMP8 CLU RNF130 UQCR10 HEPH BAG6 CHP1 KCNE3 NAA38 RBM26 STAU2 CNIH4 NDUFA3 NDUFA2 SETD2 CCT4 MTRF1L ARIH2 EIF3I EIF3J RNF145 HES6 ATOX1 CMTM6 DECR2 SPN TXNIP RING1 CCDC174 RCC2 SRA1 PAIP2 LSM3 ATRX SRI TAOK3 EMP3 SSR2 RBMX SSNA1 SRSF2 P4HB SSH2 SSBP1 MYB ERH ATP6AP1 TAF10 OSTC OST4 OSBPL9 CDCA7L ACTN1 PAPOLA RNASEH2C PGK1 RHEB TBL1XR1 PSMA7 PGD ARL15 TBCA VPS16 PSMB1 VPS13C UBE2S RHOA RMND5B GYG1 ZNF83 TBC1D20 NDUFAF3 KIAA0100 PIP4K2A MZT2B ATG16L2 RGS10 NDUFS7 TFDP2 HMGA1 HMGB2 NDUFV2 PCBP1 EPM2AIP1 PRKAG1 PIM1 DHFR"
)
n_genes = len(control_cell.split())
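
Because a cell sentence encodes only gene order, one simple way to see what the CEBPE+CEBPB perturbation changed is to compare ranks directly. A minimal sketch with 0-based ranks; `gene_ranks` and `compare_cell_sentences` are hypothetical helpers, shown here on toy sentences.

```python
from __future__ import annotations


def gene_ranks(sentence: str) -> dict[str, int]:
    """Map each gene to its 0-based rank; if a gene repeats, keep its first (highest) rank."""
    ranks: dict[str, int] = {}
    for i, gene in enumerate(sentence.split()):
        ranks.setdefault(gene, i)
    return ranks


def compare_cell_sentences(control: str, perturbed: str):
    """Return (rank_shift, only_control, only_perturbed) for two cell sentences.

    rank_shift maps each shared gene to control_rank - perturbed_rank, so a
    positive value means the gene moved toward the top after perturbation.
    """
    control_ranks = gene_ranks(control)
    perturbed_ranks = gene_ranks(perturbed)
    shared = control_ranks.keys() & perturbed_ranks.keys()
    rank_shift = {gene: control_ranks[gene] - perturbed_ranks[gene] for gene in shared}
    only_control = sorted(control_ranks.keys() - perturbed_ranks.keys())
    only_perturbed = sorted(perturbed_ranks.keys() - control_ranks.keys())
    return rank_shift, only_control, only_perturbed


shift, lost, gained = compare_cell_sentences("A B C D", "C A B E")
print(shift["C"], lost, gained)  # 2 ['D'] ['E']
```

Applied to `control_cell` and `perturbed_cell`, sorting `shift` by value highlights the genes that moved most, and `gained` lists genes that appear only in the perturbed cell's ranked list.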


In [16]:
multi_cell_sentences = f"Cell 1:\n{control_cell}\nCell 2:\n{perturbed_cell}\n"

filled_prompt = multi_cell_tissue_prompt.format(
    num_genes=n_genes,
    num_cells=2,
    organism='human',
    multi_cell_sentences=multi_cell_sentences
)
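
The cell above hard-codes two cells. A hedged generalization to any number of cells reproduces the same `Cell i:` labelling; with `[control_cell, perturbed_cell]` it returns exactly the `multi_cell_sentences` string built above. `build_multi_cell_sentences` is a hypothetical helper, and whether the model handles many more cells with this layout is not tested here.

```python
from __future__ import annotations


def build_multi_cell_sentences(cell_sentences: list[str]) -> str:
    """Label each sentence "Cell i:" (1-based) on its own line, followed by its genes."""
    return "".join(
        f"Cell {i}:\n{sentence}\n" for i, sentence in enumerate(cell_sentences, start=1)
    )


# Toy sentences for illustration
toy = build_multi_cell_sentences(["CD3D CD3E", "CD19 MS4A1", "CD14 LYZ"])
print(toy)
```

The result can be passed as `multi_cell_sentences` to the tissue prediction template, with `num_cells` set to the length of the input list.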


In [23]:
def query_model(prompt, max_tokens=50, temperature=0.1):
    """Send a custom prompt to the model."""
    response = csmodel.generate_from_prompt(
        model=model,
        prompt=prompt,
        max_num_tokens=max_tokens,
        temperature=temperature,
        do_sample=True
    )

    # if model id contains "Gemma", remove the <ctrl100> token
    if "Gemma" in model_id:
        response = response.replace("<ctrl100>", "")

    return response.strip()

In [24]:
print(f"Prompt:")
print("-" * 50)
print(filled_prompt)
print("-" * 50)

print()
print()
response = query_model(filled_prompt, max_tokens=50, temperature=0.1)
print(f"Response:")
print("-" * 50)
print(response)

Prompt:
--------------------------------------------------
The following is a list of 500 gene names ordered by descending expression level for 2 different human cells. Your task is to give the tissue which these cells belong to based on their gene expression.
Cell sentences:
Cell 1:
MT-CO2 MT-CO3 RPLP1 RPS27 RPS18 MT-CO1 RPS12 RPS8 RPS2 RPL13A MT-ND4 EEF1A1 RPL34 RPS23 RPS19 FTL RPS29 RPL32 RPS14 RPL19 RPS28 RPL37A MT-ATP6 RPS16 RPL21 RPL7 RPS3 MT-ND2 RPL3 RPS27A YBX1 RPL28 RPL39 RPL15 RPLP0 RPL35A RPS4X RPL13 RPL37 RPLP2 RPL9 RPS3A FTH1 RPS15 RPS15A RPL18A RPL29 RPL12 RPL8 RPS25 RPS7 RPS24 RPS17 RPL27A RPS13 RPS21 MT-ND1 RPL4 RPL11 RPL36 RPS6 RPSA RPS5 RPS20 RPL38 RPL7A MT-CYB RPL10 RPL18 RPL10A RPL14 RPL30 RPL24 UQCRH RPL6 RPL35 RPL31 TMA7 RPL22 FAU MT-ND3 RPL23A GAPDH RPL41 RPL26 TMSB4X RPL23 PPIA UBA52 OAZ1 HINT1 TPT1 RPS11 SERF2 EEF2 SDF2L1 RPS9 NACA RPL5 HMGN1 HSPE1 RANBP1 PTMA HMGA1 COX7A2 RPL27 APOC1 BTF3 NDUFB10 SNRPG HNRNPA1 PABPC1 EEF1B2 YDJC PPP1R14B PFN1 HSPD1 CALR DDX1 C

In [25]:
# Same two cells as filled_prompt above, with a reworded task ("provide" instead of "give")
# and a different answer prefix ("These cells are from the:")
custom_prompt = f"""The following is a list of {n_genes} gene names ordered by descending expression level for 2 different human cells. Your task is to provide the tissue which these cells belong to based on their gene expression.
Cell sentences:
Cell 1:
{control_cell}
Cell 2:
{perturbed_cell}
.
These cells are from the:"""

In [26]:
print(f"Prompt:")
print("-" * 50)
print(custom_prompt)
print("-" * 50)

print()
print()
response = query_model(custom_prompt, max_tokens=50, temperature=0.1)
print(f"Response:")
print("-" * 50)
print(response)

Prompt:
--------------------------------------------------
The following is a list of 500 gene names ordered by descending expression level for 2 different human cells. Your task is to provide the tissue which these cells belong to based on their gene expression.
Cell sentences:
Cell 1:
MT-CO2 MT-CO3 RPLP1 RPS27 RPS18 MT-CO1 RPS12 RPS8 RPS2 RPL13A MT-ND4 EEF1A1 RPL34 RPS23 RPS19 FTL RPS29 RPL32 RPS14 RPL19 RPS28 RPL37A MT-ATP6 RPS16 RPL21 RPL7 RPS3 MT-ND2 RPL3 RPS27A YBX1 RPL28 RPL39 RPL15 RPLP0 RPL35A RPS4X RPL13 RPL37 RPLP2 RPL9 RPS3A FTH1 RPS15 RPS15A RPL18A RPL29 RPL12 RPL8 RPS25 RPS7 RPS24 RPS17 RPL27A RPS13 RPS21 MT-ND1 RPL4 RPL11 RPL36 RPS6 RPSA RPS5 RPS20 RPL38 RPL7A MT-CYB RPL10 RPL18 RPL10A RPL14 RPL30 RPL24 UQCRH RPL6 RPL35 RPL31 TMA7 RPL22 FAU MT-ND3 RPL23A GAPDH RPL41 RPL26 TMSB4X RPL23 PPIA UBA52 OAZ1 HINT1 TPT1 RPS11 SERF2 EEF2 SDF2L1 RPS9 NACA RPL5 HMGN1 HSPE1 RANBP1 PTMA HMGA1 COX7A2 RPL27 APOC1 BTF3 NDUFB10 SNRPG HNRNPA1 PABPC1 EEF1B2 YDJC PPP1R14B PFN1 HSPD1 CALR DDX

In [21]:
print(f"Prompt:")
print("-" * 50)
custom_prompt = "The capital of France is"
print(custom_prompt)
print("-" * 50)

print()
print()
response = query_model(custom_prompt, max_tokens=50, temperature=0.1)
print(f"Response:")
print("-" * 50)
print(response)

Prompt:
--------------------------------------------------
The capital of France is
--------------------------------------------------


Response:
--------------------------------------------------
EEF1A1 RPL10 RPS12 RPLP1 RPL13 RPL41 RPS27A RPS27 RPS18 RPS3A RPS23 RPL32 RPL30 RPS1


## Custom Prompts

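You can also write fully custom prompts, beyond the built-in C2S prompt templates, to explore the model's capabilities. The next cell asks the model to classify a perturbed cell as fast or slow growing.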
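The commented-out lines in the next cell fill a built-in template with `str.format`, using the placeholders `num_genes`, `organism`, and `cell_sentence`. The same pattern works for a hand-written template. In this sketch, `GROWTH_TEMPLATE` and `fill_template` are hypothetical names; the template text copies the custom fast/slow prompt used below, and it assumes the cell sentence is already ranked by descending expression.

```python
# Hand-written template with the same placeholders the C2S formatter fills
GROWTH_TEMPLATE = (
    "The following is a list of {num_genes} gene names ordered by descending "
    "expression level in a {organism} cell. Your task is to classify the cell "
    "as fast or slow growing.\nCell sentence: {cell_sentence}\nThe cell grows:"
)


def fill_template(template, cell_sentence, num_genes, organism="human"):
    """Fill a prompt template with the top `num_genes` genes of a cell sentence."""
    # The sentence is already ranked, so the first genes are the most expressed
    genes = cell_sentence.split()[:num_genes]
    # Report the number of genes actually included, so the header stays accurate
    # even when the cell sentence has fewer than num_genes genes
    return template.format(
        num_genes=len(genes), organism=organism, cell_sentence=" ".join(genes)
    )


prompt = fill_template(GROWTH_TEMPLATE, "MT-CO3 TMSB4X FTL MT-CO2", num_genes=3)
print(prompt)
```

Counting the genes that were actually inserted, rather than echoing the requested limit, keeps the stated gene count consistent with the sentence the model sees.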

In [22]:
# Alternative: build the prompt from the built-in cell-type-prediction template
# formatter = C2SPromptFormatter(task="cell_type_prediction", top_k_genes=n_genes)
# prompt_template = formatter.prompts_dict["model_input"][0]
# custom_prompt = prompt_template.format(
#     num_genes=n_genes,
#     organism='human',
#     cell_sentence=control_cell
# )

custom_prompt = f"""The following is a list of {n_genes} gene names ordered by descending expression level in a human cell. Your task is to classify the cell as fast or slow growing.
Cell sentence: {perturbed_cell}
The cell grows:"""

print("Prompt:")
print("-" * 50)
print(custom_prompt)
print("-" * 50)

print()
print()
response = query_model(custom_prompt)
print("Response:")
print("-" * 50)
print(response)

Prompt:
--------------------------------------------------
The following is a list of 500 gene names ordered by descending expression level in a human cell. Your task is to classify the cell as fast or slow growing.
Cell sentence: MT-CO3 TMSB4X FTL MT-CO2 RPS27 MT-CO1 EEF1A1 RPS14 RPL34 FTH1 RPS18 MT-ND4 RPLP1 RPL37 RPS29 MT-ATP6 RPS4X RPL13A MT-CYB RPS8 RPS2 RPS19 TPT1 RPL28 RPL32 RPS3 RPS7 RPS27A MT-ND1 RPL37A MT-ND2 RPS12 RPS3A RPL21 RPL13 ACTB RPL39 RPL6 RPL9 RPS24 RPL27A RPL11 RPS28 RPS15 RPS11 RPL15 RPLP0 OAZ1 RPL10 RPL30 MT-ND3 RPS20 RPS15A RPL35 RPL41 YBX1 RPL5 RPS23 RPL29 RPL8 RPL19 RPL14 TMA7 RPL4 RPL10A RPL18A TYROBP RPLP2 RPS17 RPL7 UBA52 GAPDH UQCRH PTMA CFD HMGB1 RPL12 PFN1 RPS25 RPL35A RPSA B2M RPL24 COX7C UQCR11 SUB1 RPL36 HMGN1 RPL18 COX5B SRP14 VIM LST1 PTP4A2 CFL1 SUMO2 NACA RPL7A SH3BGRL3 PRAME SERF2 TMSB10 COX8A FAU CEBPE RPL23 EDF1 CSF3R CHCHD2 SNX3 MPC2 LMO4 PRR13 MSN DBI FAM118A NDUFS5 UBE2L3 MTCH1 COX7A2 FKBP1A RPS6 EIF2S2 RPS21 RPS26 GPX4 LSM6 CAP1 RPL22 HMGN2
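Every prompt above describes its cell sentence as "gene names ordered by descending expression level". A minimal sketch of that conversion from an expression vector follows. It assumes that genes with zero expression are dropped and that ties keep their input order; `to_cell_sentence` is a hypothetical helper, and the cell2sentence library's own conversion (including any normalization applied beforehand) may differ.

```python
def to_cell_sentence(expression, gene_names, top_k=None):
    """Rank genes by descending expression and join their names (sketch)."""
    if len(expression) != len(gene_names):
        raise ValueError("expression and gene_names must have the same length")
    # sorted() is stable, so tied genes keep their original order
    order = sorted(range(len(expression)), key=lambda i: -expression[i])
    # Drop genes that are not expressed in this cell
    ranked = [gene_names[i] for i in order if expression[i] > 0]
    # Optionally keep only the top_k most highly expressed genes
    return " ".join(ranked if top_k is None else ranked[:top_k])


print(to_cell_sentence([0.0, 5.0, 2.0, 5.0], ["A", "B", "C", "D"]))
print(to_cell_sentence([0.0, 5.0, 2.0, 5.0], ["A", "B", "C", "D"], top_k=2))
```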

### Base language model (Pythia)

In [None]:
from transformers import GPTNeoXForCausalLM, AutoTokenizer
import torch

# Use the GPU if CUDA is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

model = GPTNeoXForCausalLM.from_pretrained(
    "EleutherAI/pythia-1b-deduped",
    revision="step143000",
    cache_dir="./models/pythia/pythia-1b-deduped/step143000",
    torch_dtype=torch.bfloat16 if device.type == "cuda" else torch.float32,
)

# Move the model to the selected device
model = model.to(device)
print(f"Model loaded on device: {model.device}")

tokenizer = AutoTokenizer.from_pretrained(
    "EleutherAI/pythia-1b-deduped",
    revision="step143000",
    cache_dir="./models/pythia/pythia-1b-deduped/step143000",
)
input_prompt = """The following is a list of around 200 gene names ordered by descending expression level in a human cell. Your task is to classify the cell as fast or slow growing. Respond either with "fast" or "slow".
Cell sentence: MT-CO3 TMSB4X FTL MT-CO2 RPS27 MT-CO1 EEF1A1 RPS14 RPL34 FTH1 RPS18 MT-ND4 RPLP1 RPL37 RPS29 MT-ATP6 RPS4X RPL13A MT-CYB RPS8 RPS2 RPS19 TPT1 RPL28 RPL32 RPS3 RPS7 RPS27A MT-ND1 RPL37A MT-ND2 RPS12 RPS3A RPL21 RPL13 ACTB RPL39 RPL6 RPL9 RPS24 RPL27A RPL11 RPS28 RPS15 RPS11 RPL15 RPLP0 OAZ1 RPL10 RPL30 MT-ND3 RPS20 RPS15A RPL35 RPL41 YBX1 RPL5 RPS23 RPL29 RPL8 RPL19 RPL14 TMA7 RPL4 RPL10A RPL18A TYROBP RPLP2 RPS17 RPL7 UBA52 GAPDH UQCRH PTMA CFD HMGB1 RPL12 PFN1 RPS25 RPL35A RPSA B2M RPL24 COX7C UQCR11 SUB1 RPL36 HMGN1 RPL18 COX5B SRP14 VIM LST1 PTP4A2 CFL1 SUMO2 NACA RPL7A SH3BGRL3 PRAME SERF2 TMSB10 COX8A FAU CEBPE RPL23 EDF1 CSF3R CHCHD2 SNX3 MPC2 LMO4 PRR13 MSN DBI FAM118A NDUFS5 UBE2L3 MTCH1 COX7A2 FKBP1A RPS6 EIF2S2 RPS21 RPS26 GPX4 LSM6 CAP1 RPL22 HMGN2 MS4A3 PET100 PFDN5 UBL5 RPL23A LSM2 CDCA7 NUCKS1 ITM2B LDHB HMGN3 AURKAIP1 LAMTOR4 RPL27 RPS16 COTL1 PGAM1 ZNF207 UBB UBC COX4I1 GTF2I SLC25A3 PPIA RPS13 PGP TMBIM6 COX7B MYL6 ARHGDIB LGALS1 AGO1 HSP90AA1
The cell grows:"""

inputs = tokenizer(input_prompt, return_tensors="pt", truncation=True, max_length=1024)
# Move inputs to the same device as the model
inputs = {k: v.to(device) for k, v in inputs.items()}

tokens = model.generate(
    **inputs,
    max_new_tokens=50,  # generate up to 50 new tokens
    temperature=0.7,
    do_sample=True,
    pad_token_id=tokenizer.eos_token_id,
)
# Decode only the newly generated tokens instead of echoing the whole prompt
prompt_length = inputs["input_ids"].shape[1]
print(tokenizer.decode(tokens[0][prompt_length:], skip_special_tokens=True))

Using device: cuda


KeyboardInterrupt: 
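Because `model.generate` is called with `do_sample=True` and `temperature=0.7`, the base model's continuation is free text that may not be a single word, even though the prompt asks for "fast" or "slow". A small post-processing step can map the text to a label. This is a sketch under the assumption that the first whole-word occurrence of either label is the answer; `extract_growth_label` is a hypothetical helper, and it returns `None` when neither word appears.

```python
import re


def extract_growth_label(generated_text):
    """Return 'fast' or 'slow' for the first whole-word match, else None."""
    # \b word boundaries reject partial matches such as "slowly" or "faster"
    match = re.search(r"\b(fast|slow)\b", generated_text, flags=re.IGNORECASE)
    return match.group(1).lower() if match else None


print(extract_growth_label(" Slow. The cell divides rarely."))
print(extract_growth_label("RPL10 RPS12"))
```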