## An analysis of transcription data using Evo2

It is working okay, but need to train the NN at the end on larger data set


### Load data

In [18]:
import os, io, base64, json, zipfile, time
import numpy as np
import torch
import matplotlib.pyplot as plt
from Bio import SeqIO
import requests
import pandas as pd


FASTA_PATH = "../data/CLIB 122-chr.fasta"
GFF_PATH = "../data/CLIB 122-trans.gff3"
TRANSCRIPT_PATH = "../data/fpkm_counts.csv"

### Load Evo2 API helper functions

In [None]:
BASE_URL = "https://health.api.nvidia.com/v1/biology"
NVIDIA_API_KEY = os.getenv("NVCF_RUN_KEY") or os.getenv("NVIDIA_API_KEY")
if NVIDIA_API_KEY is None:
    raise RuntimeError("Missing NVIDIA API key. Set NVCF_RUN_KEY or NVIDIA_API_KEY.")

HEADERS = {
    "Authorization": f"Bearer {NVIDIA_API_KEY}",
    "Content-Type": "application/json",
    "Accept": "application/json, application/zip",
    "Prefer": "wait=30",
}

def evo2_forward(sequence, output_layers, model="evo2-7b", timeout_s=1800, max_polls=120):
    """Call Evo2 /forward and return (logits, acts_dict) as numpy arrays.
    No classes; just functions and dicts.
    """
    url = f"{BASE_URL}/arc/{model}/forward"
    # Always ask for 'unembed' so we can pop it later if present
    # if "unembed" not in output_layers:
    #     output_layers = ["unembed"] + list(output_layers)

    payload = {"sequence": sequence, "output_layers": output_layers}
    r = requests.post(url, headers=HEADERS, json=payload, timeout=timeout_s)
    print("Initial status:", r.status_code)

    # Poll 202 with POST (same payload) if needed
    polls = 0
    reqid = r.headers.get("Nvcf-Reqid")
    while r.status_code == 202:
        polls += 1
        retry_after = int(r.headers.get("Retry-After", 10))
        status_hdr  = r.headers.get("Nvcf-Status")
        print(f"[poll {polls}] 202 — status={status_hdr}, retry_after={retry_after}s, reqid={reqid}")
        if polls >= max_polls:
            raise TimeoutError(f"Gave up after {polls} polls (last status={status_hdr}).")
        time.sleep(retry_after)
        poll_headers = dict(HEADERS)
        if reqid:
            poll_headers["Nvcf-Reqid"] = reqid
        r = requests.post(url, headers=poll_headers, json=payload, timeout=timeout_s)

    r.raise_for_status()
    ct = (r.headers.get("Content-Type") or "").lower()
    if "application/zip" in ct or "application/octet-stream" in ct:
        with zipfile.ZipFile(io.BytesIO(r.content), "r") as z:
            names = z.namelist()
            if not names:
                raise RuntimeError("Empty ZIP from server.")
            inner = z.read(names[0])
        resp = json.loads(inner.decode("utf-8", "replace"))
        blob = base64.b64decode(resp["data"])
        npz = np.load(io.BytesIO(blob), allow_pickle=False)
    elif "application/json" in ct:
        resp = r.json()
        blob = base64.b64decode(resp["data"])
        npz = np.load(io.BytesIO(blob), allow_pickle=False)
    else:
        raise RuntimeError(f"Unsupported Content-Type: {ct}")

    def squeeze(x):
        return x[0] if (isinstance(x, np.ndarray) and x.ndim >= 3 and x.shape[0] == 1) else x

    tensors = {k: squeeze(npz[k]) for k in npz.files}
    logits = tensors.pop("unembed", None)
    return logits, tensors


### Load data

In [16]:
transcript_df = pd.read_csv(TRANSCRIPT_PATH)
transcript_df

Unnamed: 0,JGI_id,YALI0_id,glucose_1,glucose_2,glucose_3,glycerol_1,glycerol_2,glycerol_3,oleic_acid_1,oleic_acid_2,oleic_acid_3
0,jgi.p|Yarli1|64471,YALI0A00110g,4.88,2.64,3.23,1.21,1.11,3.61,6.77,13.32,7.67
1,jgi.p|Yarli1|64472,YALI0A00132g,1896.43,1845.77,1682.56,1234.99,1189.24,1120.35,864.64,1056.63,928.51
2,jgi.p|Yarli1|64473,YALI0A00154g,31.34,29.21,21.92,215.23,137.85,182.81,1867.59,2153.21,1958.68
3,jgi.p|Yarli1|64474,YALI0A00176g,3.67,3.85,3.15,3.07,1.74,2.09,4.13,2.43,4.34
4,jgi.p|Yarli1|64475,YALI0A00198g,0.00,0.17,0.00,0.00,0.50,1.29,1.55,0.81,1.75
...,...,...,...,...,...,...,...,...,...,...,...
6442,jgi.p|Yarli1|70913,YALI0F32043g,20.63,17.60,17.12,13.94,13.84,11.94,10.76,12.38,10.37
6443,jgi.p|Yarli1|70914,YALI0F32065g,192.58,151.54,270.76,206.89,262.66,280.82,310.55,384.22,313.24
6444,jgi.p|Yarli1|70915,YALI0F32131g,0.12,0.66,0.14,0.34,0.29,0.64,1.01,0.58,1.33
6445,jgi.p|Yarli1|70916,YALI0F32153g,26.16,27.73,28.77,39.22,36.95,39.18,26.98,22.05,23.58


### Make a data frame of gene sequences
Note: the genes can be on the forward or reverse strand and this needs to be accounted for when getting the promoter

In [None]:
promoter_length = 500
gene_data = []

# load genome
records = SeqIO.to_dict(SeqIO.parse(FASTA_PATH, "fasta"))

# open the gff file
with open(GFF_PATH, "r", newline="") as fh:
    for line in fh:
        if line.startswith("YALI") and "GRYC\tgene" in line:
            parts = line.strip().split("\t")
            seq_id = parts[0]
            start = int(parts[3])
            end = int(parts[4])
            strand = parts[6]  # "+" or "-"
            attrs = parts[8]

            gene_name = attrs.split("Name=")[1].split(";")[0]
            gene_seq = records[seq_id].seq[start - 1:end]

            if strand == "+":
                promoter_seq = records[seq_id].seq[
                    max(0, start - promoter_length - 1):start - 1
                ]
            else:
                promoter_seq = records[seq_id].seq[
                    end:end + promoter_length
                ].reverse_complement()

            gene_data.append({
                "seq_id": seq_id,
                "gene_name": gene_name,
                "start": start,
                "end": end,
                "strand": strand,
                "length": end - start,
                "promoter": str(promoter_seq),
                "sequence": str(gene_seq if strand == "+" else gene_seq.reverse_complement()),
            })

genes_df = pd.DataFrame(gene_data)

Unnamed: 0,seq_id,gene_name,start,end,length,promoter,sequence
0,YALI0A,YALI0A00110g,2659,5322,2663,AGAGTGATGCGTTACTCCATCGTCATTTAAAGTCCAAAACGCAAGG...,ATGAGCAAACACACTGAGGTTTTCTCTTCGGAGAAAGTCTCCTCTA...
1,YALI0A,YALI0A00132g,7045,8938,1893,TTTCAAACACTACGAGTGAATTCCAGCGTTCCAAATCCGTGTATAA...,ATGAGTGAAGGAACTTTTGCTGGAGCTGTCGGTATCGATCTTGGAA...
2,YALI0A,YALI0A00154g,11559,12754,1195,GTTAAGAAGAGCACAGGAATACACTCTGAAGAGTTTCGAAGCGAAT...,ATGAAGCTCTCCAATATCTTTGCCCTCGCAACAGTGGCTCTGGCTG...
3,YALI0A,YALI0A00176g,15861,18518,2657,TATCGCTCCGGGTTCCTGCTTTGTTCATGCCCCCCTGGACGCATGT...,ATGAAACTACCCATCATTGCCCTCGCTCTTCTTCTCTCCTCGGTGG...
4,YALI0A,YALI0A00198g,19990,20857,867,GAGATAATGCCATGGAGTGCATAATCGCAACTCTGTTGGGAAGTTA...,ATCATATAACCAATAAATTAACGCAATCTATGTTTTGGGACTTGAA...
...,...,...,...,...,...,...,...
6820,YALI0F,YALI0F32131g,3984292,3985642,1350,TTGGAAGCCTTGGCTGAAGATTCTTGGGAAGCAGTCGGTTTTGGAG...,AGTGAGTATGGAGAGATGGACTGGACTTCCATCATCGAGAAGATGT...
6821,YALI0F,YALI0F32153g,3985784,3989073,3289,TCACTCACTGGAACGACTTCTTTGTTGGATGGCCCGCTGGTTATTC...,CCTGTGCAATAATACATAATAATTTACTATTCCTCCATGCTATACA...
6822,YALI0F,YALI0F32175g,3989661,3992229,2568,GCGGATGAAGAGCATGCTATCTAGTGATAGATTGGATGATGACCAA...,ATGGAAGACAGATTACAGCAAAGGGTACGGGAGACTTTGTCACTAC...
6823,YALI0F,YALI0F32197r,3992635,3993877,1242,AGCACCACCACCTAGGAAATAGAAGCTATCAAACGGCTTGGAAATA...,GTGCATTTTATAGCCTGCTGAGATGCAGAGCTGATAAGCAGACCTG...


### Add sequence and promotor info the transcript dataframe

In [4]:
# quick function to get gene sequences from their name from the dataframe
def get_seq_and_promoter(gene_name):
    gene_row = genes_df[genes_df["gene_name"] == gene_name]
    if not gene_row.empty:
        return gene_row.iloc[0]["sequence"], gene_row.iloc[0]["promoter"]
    else:
        # print(f"Gene {gene_name} not found.")
        return None, None
    
# get_seq_and_promoter("YALI0E20207g")

In [None]:
# drop rows in the transcript dataframe where the YALI0_id is not in the genes dataframe
transcript_seq_df = transcript_df[transcript_df["YALI0_id"].isin(genes_df["gene_name"])].copy()

# add sequence and promoter columns to the transcript dataframe
sequences = []
promoters = []

# loop over the transcript dataframe and get the sequence and promoter for each gene
for _, row in transcript_df.iterrows():
    gene_name = row["YALI0_id"]
    sequence, promoter = get_seq_and_promoter(gene_name)

    # only add if both sequence and promoter are found
    if sequence is not None and promoter is not None:
        sequences.append(sequence)
        promoters.append(promoter)

# add the new columns to the dataframe
transcript_seq_df["sequence"] = sequences
transcript_seq_df["promoter"] = promoters

print(f'After adding sequences and promoters, the transcript dataframe is {len(transcript_df) - len(transcript_seq_df)} genes shorter.')

transcript_seq_df


After adding sequences and promoters, the transcript dataframe is 145 genes shorter


Unnamed: 0,JGI_id,YALI0_id,glucose_1,glucose_2,glucose_3,glycerol_1,glycerol_2,glycerol_3,oleic_acid_1,oleic_acid_2,oleic_acid_3,sequence,promoter
0,jgi.p|Yarli1|64471,YALI0A00110g,4.88,2.64,3.23,1.21,1.11,3.61,6.77,13.32,7.67,ATGAGCAAACACACTGAGGTTTTCTCTTCGGAGAAAGTCTCCTCTA...,AGAGTGATGCGTTACTCCATCGTCATTTAAAGTCCAAAACGCAAGG...
1,jgi.p|Yarli1|64472,YALI0A00132g,1896.43,1845.77,1682.56,1234.99,1189.24,1120.35,864.64,1056.63,928.51,ATGAGTGAAGGAACTTTTGCTGGAGCTGTCGGTATCGATCTTGGAA...,TTTCAAACACTACGAGTGAATTCCAGCGTTCCAAATCCGTGTATAA...
2,jgi.p|Yarli1|64473,YALI0A00154g,31.34,29.21,21.92,215.23,137.85,182.81,1867.59,2153.21,1958.68,ATGAAGCTCTCCAATATCTTTGCCCTCGCAACAGTGGCTCTGGCTG...,GTTAAGAAGAGCACAGGAATACACTCTGAAGAGTTTCGAAGCGAAT...
3,jgi.p|Yarli1|64474,YALI0A00176g,3.67,3.85,3.15,3.07,1.74,2.09,4.13,2.43,4.34,ATGAAACTACCCATCATTGCCCTCGCTCTTCTTCTCTCCTCGGTGG...,TATCGCTCCGGGTTCCTGCTTTGTTCATGCCCCCCTGGACGCATGT...
4,jgi.p|Yarli1|64475,YALI0A00198g,0.00,0.17,0.00,0.00,0.50,1.29,1.55,0.81,1.75,ATCATATAACCAATAAATTAACGCAATCTATGTTTTGGGACTTGAA...,GAGATAATGCCATGGAGTGCATAATCGCAACTCTGTTGGGAAGTTA...
...,...,...,...,...,...,...,...,...,...,...,...,...,...
6441,jgi.p|Yarli1|70912,YALI0F32021g,169.05,163.87,150.14,125.34,124.80,114.74,105.38,113.77,108.63,ATGATTCGAAGACTGGCTCTTGGAAGGGTGAGTATGGAAGGGACAA...,ACATACTGTACTCGTAGTAGTACTACAAATCATTTCTTCAGTAGGA...
6442,jgi.p|Yarli1|70913,YALI0F32043g,20.63,17.60,17.12,13.94,13.84,11.94,10.76,12.38,10.37,ATGGGATGGGGAAAGAAATCGTCGTCATCGGCGGTCCCTAAGCCGC...,GATATAGATGTATAACCAACCAATTATATTGAGCGGCAACAACAAC...
6444,jgi.p|Yarli1|70915,YALI0F32131g,0.12,0.66,0.14,0.34,0.29,0.64,1.01,0.58,1.33,AGTGAGTATGGAGAGATGGACTGGACTTCCATCATCGAGAAGATGT...,TTGGAAGCCTTGGCTGAAGATTCTTGGGAAGCAGTCGGTTTTGGAG...
6445,jgi.p|Yarli1|70916,YALI0F32153g,26.16,27.73,28.77,39.22,36.95,39.18,26.98,22.05,23.58,CCTGTGCAATAATACATAATAATTTACTATTCCTCCATGCTATACA...,TCACTCACTGGAACGACTTCTTTGTTGGATGGCCCGCTGGTTATTC...


### Make an encoding dataframe for block 26 outputs

In [20]:
SAE_LAYER_NAME = "blocks.26.mlp.l3"  # layer used in the paper examples
wanted_layers = [SAE_LAYER_NAME]

logits, acts = evo2_forward(promoter, output_layers=wanted_layers, model="evo2-7b")
print("Layers returned:", list(acts.keys()))
layer_key = f"{SAE_LAYER_NAME}.output"
layer_act = acts[layer_key]  # shape [T, d_hidden], expected T ~ len(sequence)
print("Layer activations shape:", layer_act.shape)

Initial status: 200
Layers returned: ['blocks.26.mlp.l3.output', 'unembed.output']
Layer activations shape: (500, 4096)


In [21]:
def last_token_embedding_from_layer(sequence, layer_name):
    """
    Fetch embeddings from Evo2 API and return last-token vector from the requested layer.
    Assumes returned tensor has shape [batch=1, T, hidden_dim] or [T, hidden_dim].
    """
    logits, acts = evo2_forward(sequence, output_layers=[layer_name])
    arr = acts[f'{layer_name}.output']
    print(arr.shape)
    # Shape normalization
    if arr.ndim == 3:   # [1, T, H]
        return arr[0, -1, :].astype(np.float32)
    elif arr.ndim == 2: # [T, H]
        return arr[-1, :].astype(np.float32)
    else:
        raise ValueError(f"Unexpected array shape for {layer_name}: {arr.shape}")
    

last_token_embedding_from_layer(promoter, 'blocks.26.mlp.l3')

Initial status: 200
(500, 4096)


array([ 0.20703125, -0.1796875 ,  0.08789062, ...,  0.06738281,
       -0.05493164,  0.04125977], shape=(4096,), dtype=float32)

### Add the layer 26 embedding for each promoter sequence

In [None]:
# make transcript_seq_df shorter for testing
transcript_seq_df = transcript_seq_df[:1000]

transcript_seq_df

Unnamed: 0,JGI_id,YALI0_id,glucose_1,glucose_2,glucose_3,glycerol_1,glycerol_2,glycerol_3,oleic_acid_1,oleic_acid_2,oleic_acid_3,sequence,promoter
0,jgi.p|Yarli1|64471,YALI0A00110g,4.88,2.64,3.23,1.21,1.11,3.61,6.77,13.32,7.67,ATGAGCAAACACACTGAGGTTTTCTCTTCGGAGAAAGTCTCCTCTA...,AGAGTGATGCGTTACTCCATCGTCATTTAAAGTCCAAAACGCAAGG...
1,jgi.p|Yarli1|64472,YALI0A00132g,1896.43,1845.77,1682.56,1234.99,1189.24,1120.35,864.64,1056.63,928.51,ATGAGTGAAGGAACTTTTGCTGGAGCTGTCGGTATCGATCTTGGAA...,TTTCAAACACTACGAGTGAATTCCAGCGTTCCAAATCCGTGTATAA...
2,jgi.p|Yarli1|64473,YALI0A00154g,31.34,29.21,21.92,215.23,137.85,182.81,1867.59,2153.21,1958.68,ATGAAGCTCTCCAATATCTTTGCCCTCGCAACAGTGGCTCTGGCTG...,GTTAAGAAGAGCACAGGAATACACTCTGAAGAGTTTCGAAGCGAAT...
3,jgi.p|Yarli1|64474,YALI0A00176g,3.67,3.85,3.15,3.07,1.74,2.09,4.13,2.43,4.34,ATGAAACTACCCATCATTGCCCTCGCTCTTCTTCTCTCCTCGGTGG...,TATCGCTCCGGGTTCCTGCTTTGTTCATGCCCCCCTGGACGCATGT...
4,jgi.p|Yarli1|64475,YALI0A00198g,0.0,0.17,0.0,0.0,0.5,1.29,1.55,0.81,1.75,ATCATATAACCAATAAATTAACGCAATCTATGTTTTGGGACTTGAA...,GAGATAATGCCATGGAGTGCATAATCGCAACTCTGTTGGGAAGTTA...
5,jgi.p|Yarli1|64476,YALI0A00212g,9.15,9.49,9.73,6.31,6.53,6.85,11.27,14.18,13.1,ATGAAACTGCCCACGGAGATTGTCGCCCAGATATGCGCGTCGCTCG...,GAGTAATTAAAGTAGTTCTTTCTACGTACTGTACTGAATAGATATT...
6,jgi.p|Yarli1|64477,YALI0A00264g,489.68,496.98,447.1,431.05,408.37,366.64,246.99,285.32,251.25,ACTTTCACAATCAATTCATTTACATCAATAGACGCTTGATCTAGAT...,AACCCGCAGCCTTTGGAGGACGTCGTTCAGAAGAACCGAAATCCAA...
7,jgi.p|Yarli1|64478,YALI0A00286g,244.82,272.97,249.23,168.19,185.59,158.59,161.13,191.51,173.14,ATGTCGACAACGGTGGAAAAGATCAAAAACATCGAGGAGGAGATGG...,CCGGGAGACATAAAAACCAGATGACTTTTTTGTTTTTGCGATACAT...
8,jgi.p|Yarli1|64479,YALI0A00330g,32.58,27.91,35.69,36.85,31.17,37.82,24.36,20.34,20.96,GCTCATAAAAATACATAGAAATACATCCGAGTGGTGCGGCAGCTTT...,GTCAATATGTCATGGGATTCACTGCAATGTTCCGATTTTCTACCAG...
9,jgi.p|Yarli1|64480,YALI0A00352g,1691.33,1682.06,1425.46,1262.91,1239.39,1130.14,885.43,982.31,880.89,ATGGGTGAGTACGAGCGGCAGCGGGAAGAAGAGGGCAGCGATGACA...,CATGATCTTCAGACACACGGTGGGTGCAATCATCTCTCGCTGATTT...


In [27]:
embedding_list = []

for _, row in transcript_seq_df.iterrows():
    emb = last_token_embedding_from_layer(row['promoter'], SAE_LAYER_NAME)
    embedding_list.append(emb)

# add the embeddings to the dataframe
transcript_seq_df['embedding'] = embedding_list

# save to a dataframe
transcript_seq_df.to_csv("../data/promotor_transcript_seq_with_embeddings.csv", index=False)

Initial status: 200
(500, 4096)
Initial status: 200
(500, 4096)
Initial status: 200
(500, 4096)
Initial status: 200
(500, 4096)
Initial status: 200
(500, 4096)
Initial status: 200
(500, 4096)
Initial status: 200
(500, 4096)
Initial status: 200
(500, 4096)
Initial status: 200
(500, 4096)
Initial status: 200
(500, 4096)


A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  transcript_seq_df['embedding'] = embedding_list


In [28]:
transcript_seq_df

Unnamed: 0,JGI_id,YALI0_id,glucose_1,glucose_2,glucose_3,glycerol_1,glycerol_2,glycerol_3,oleic_acid_1,oleic_acid_2,oleic_acid_3,sequence,promoter,embedding
0,jgi.p|Yarli1|64471,YALI0A00110g,4.88,2.64,3.23,1.21,1.11,3.61,6.77,13.32,7.67,ATGAGCAAACACACTGAGGTTTTCTCTTCGGAGAAAGTCTCCTCTA...,AGAGTGATGCGTTACTCCATCGTCATTTAAAGTCCAAAACGCAAGG...,"[-0.052734375, -0.057617188, -0.24511719, 0.06..."
1,jgi.p|Yarli1|64472,YALI0A00132g,1896.43,1845.77,1682.56,1234.99,1189.24,1120.35,864.64,1056.63,928.51,ATGAGTGAAGGAACTTTTGCTGGAGCTGTCGGTATCGATCTTGGAA...,TTTCAAACACTACGAGTGAATTCCAGCGTTCCAAATCCGTGTATAA...,"[0.16308594, -0.18652344, -0.2421875, -0.26367..."
2,jgi.p|Yarli1|64473,YALI0A00154g,31.34,29.21,21.92,215.23,137.85,182.81,1867.59,2153.21,1958.68,ATGAAGCTCTCCAATATCTTTGCCCTCGCAACAGTGGCTCTGGCTG...,GTTAAGAAGAGCACAGGAATACACTCTGAAGAGTTTCGAAGCGAAT...,"[-0.03515625, -0.12109375, -0.16308594, -0.052..."
3,jgi.p|Yarli1|64474,YALI0A00176g,3.67,3.85,3.15,3.07,1.74,2.09,4.13,2.43,4.34,ATGAAACTACCCATCATTGCCCTCGCTCTTCTTCTCTCCTCGGTGG...,TATCGCTCCGGGTTCCTGCTTTGTTCATGCCCCCCTGGACGCATGT...,"[0.04711914, -0.02758789, 0.010437012, 0.01458..."
4,jgi.p|Yarli1|64475,YALI0A00198g,0.0,0.17,0.0,0.0,0.5,1.29,1.55,0.81,1.75,ATCATATAACCAATAAATTAACGCAATCTATGTTTTGGGACTTGAA...,GAGATAATGCCATGGAGTGCATAATCGCAACTCTGTTGGGAAGTTA...,"[0.059570312, -0.005554199, -0.022705078, 0.01..."
5,jgi.p|Yarli1|64476,YALI0A00212g,9.15,9.49,9.73,6.31,6.53,6.85,11.27,14.18,13.1,ATGAAACTGCCCACGGAGATTGTCGCCCAGATATGCGCGTCGCTCG...,GAGTAATTAAAGTAGTTCTTTCTACGTACTGTACTGAATAGATATT...,"[0.25, -0.0071105957, 0.30859375, 0.037353516,..."
6,jgi.p|Yarli1|64477,YALI0A00264g,489.68,496.98,447.1,431.05,408.37,366.64,246.99,285.32,251.25,ACTTTCACAATCAATTCATTTACATCAATAGACGCTTGATCTAGAT...,AACCCGCAGCCTTTGGAGGACGTCGTTCAGAAGAACCGAAATCCAA...,"[0.14257812, -0.1171875, -0.11816406, 0.036132..."
7,jgi.p|Yarli1|64478,YALI0A00286g,244.82,272.97,249.23,168.19,185.59,158.59,161.13,191.51,173.14,ATGTCGACAACGGTGGAAAAGATCAAAAACATCGAGGAGGAGATGG...,CCGGGAGACATAAAAACCAGATGACTTTTTTGTTTTTGCGATACAT...,"[0.016845703, 0.016723633, -0.29101562, -0.324..."
8,jgi.p|Yarli1|64479,YALI0A00330g,32.58,27.91,35.69,36.85,31.17,37.82,24.36,20.34,20.96,GCTCATAAAAATACATAGAAATACATCCGAGTGGTGCGGCAGCTTT...,GTCAATATGTCATGGGATTCACTGCAATGTTCCGATTTTCTACCAG...,"[-0.020263672, 0.064453125, 0.008056641, 0.167..."
9,jgi.p|Yarli1|64480,YALI0A00352g,1691.33,1682.06,1425.46,1262.91,1239.39,1130.14,885.43,982.31,880.89,ATGGGTGAGTACGAGCGGCAGCGGGAAGAAGAGGGCAGCGATGACA...,CATGATCTTCAGACACACGGTGGGTGCAATCATCTCTCGCTGATTT...,"[-0.036865234, -0.09814453, -0.359375, -0.0649..."


### Set up data to train a NN to predict expression levels

In [31]:
df = transcript_seq_df.copy()  # columns shown in your sample

# Option A: use means per substrate (3 outputs)
df["glucose_mean"]  = df[["glucose_1","glucose_2","glucose_3"]].mean(axis=1)
df["glycerol_mean"] = df[["glycerol_1","glycerol_2","glycerol_3"]].mean(axis=1)
df["oleic_mean"]    = df[["oleic_acid_1","oleic_acid_2","oleic_acid_3"]].mean(axis=1)

# (Recommended) log-transform to stabilize scale
for col in ["glucose_mean","glycerol_mean","oleic_mean"]:
    df[col] = np.log1p(df[col])  # log(1+x)

# X: embeddings (list -> np.array); y: 3-dim target
X = np.vstack(df["embedding"].to_numpy())        # shape [N, 4096]
y = df[["glucose_mean","glycerol_mean","oleic_mean"]].to_numpy()  # [N, 3]

print("X shape:", X.shape)
print("y shape:", y.shape)

X shape: (10, 4096)
y shape: (10, 3)


### Standardize features + split

In [None]:
from sklearn.model_selection import train_val_split
from sklearn.preprocessing import StandardScaler

X_train, X_val, y_train, y_val = train_val_split(
    X, y, test_size=0.2, random_state=42
)

scaler = StandardScaler().fit(X_train)
X_train = scaler.transform(X_train)
X_val   = scaler.transform(X_val)

X_train.shape, y_train.shape, X_val.shape, y_val.shape

((8, 4096), (8, 3), (2, 4096), (2, 3))

### PyTorch dataset + model

In [34]:
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader

class ExprDataset(Dataset):
    def __init__(self, X, y):
        self.X = torch.tensor(X, dtype=torch.float32)
        self.y = torch.tensor(y, dtype=torch.float32)
    def __len__(self): return self.X.shape[0]
    def __getitem__(self, i): return self.X[i], self.y[i]

train_ds = ExprDataset(X_train, y_train)
val_ds   = ExprDataset(X_val, y_val)

train_dl = DataLoader(train_ds, batch_size=128, shuffle=True)
val_dl   = DataLoader(val_ds, batch_size=256, shuffle=False)

# MLP: 4096 -> 1024 -> 3 (glucose, glycerol, oleic)
class ExprMLP(nn.Module):
    def __init__(self, d_in=4096, hidden=1024, d_out=3, p_drop=0.1):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_in, hidden, bias=True),
            nn.ReLU(),
            nn.Dropout(p_drop),
            nn.Linear(hidden, d_out, bias=True),
        )
    def forward(self, x): return self.net(x)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ExprMLP(d_in=4096, hidden=1024, d_out=3, p_drop=0.1).to(device)
criterion = nn.MSELoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=1e-2)

### Training loop with early stopping + metrics

In [36]:
import math
from sklearn.metrics import r2_score

best_val = math.inf
patience, bad = 15, 0

for epoch in range(200):
    # train
    model.train()
    tr_loss = 0.0
    for xb, yb in train_dl:
        xb, yb = xb.to(device), yb.to(device)
        optimizer.zero_grad()
        pred = model(xb)
        loss = criterion(pred, yb)
        loss.backward()
        optimizer.step()
        tr_loss += loss.item() * xb.size(0)
    tr_loss /= len(train_ds)

    # validate
    model.eval()
    va_loss = 0.0
    preds, targs = [], []
    with torch.no_grad():
        for xb, yb in val_dl:
            xb, yb = xb.to(device), yb.to(device)
            pr = model(xb)
            va_loss += criterion(pr, yb).item() * xb.size(0)
            preds.append(pr.cpu().numpy()); targs.append(yb.cpu().numpy())
    va_loss /= len(val_ds)
    preds = np.vstack(preds); targs = np.vstack(targs)

    # R² per task (glucose/glycerol/oleic)
    r2_per = [r2_score(targs[:,i], preds[:,i]) for i in range(targs.shape[1])]
    print(f"Epoch {epoch:03d} | train {tr_loss:.4f} | val {va_loss:.4f} | R2 {r2_per}")

    # early stopping
    if va_loss < best_val - 1e-5:
        best_val, bad = va_loss, 0
        torch.save({"model": model.state_dict(),
                    "scaler_mean": scaler.mean_, "scaler_scale": scaler.scale_},
                   "expr_mlp_best.pt")
    else:
        bad += 1
        if bad >= patience:
            print("Early stopping.")
            break

Epoch 000 | train 19.6787 | val 24.0226 | R2 [-5.942276954650879, -6.985127925872803, -4.769136905670166]
Epoch 001 | train 9.7673 | val 18.9260 | R2 [-4.514968395233154, -5.384789943695068, -3.410644054412842]
Epoch 002 | train 4.6809 | val 14.7096 | R2 [-3.385404586791992, -4.078326225280762, -2.212418794631958]
Epoch 003 | train 1.9162 | val 11.7199 | R2 [-2.7631258964538574, -3.042250156402588, -1.2526299953460693]
Epoch 004 | train 0.7096 | val 9.7068 | R2 [-2.165656805038452, -2.3499584197998047, -0.8074996471405029]
Epoch 005 | train 0.5928 | val 8.6652 | R2 [-1.917978048324585, -1.8132781982421875, -0.662283182144165]
Epoch 006 | train 0.6237 | val 8.1499 | R2 [-1.723663330078125, -1.5655827522277832, -0.6577060222625732]
Epoch 007 | train 1.2982 | val 8.1350 | R2 [-1.5554914474487305, -1.590092420578003, -0.8173205852508545]
Epoch 008 | train 1.4502 | val 8.2067 | R2 [-1.6361165046691895, -1.626688003540039, -0.7543737888336182]
Epoch 009 | train 1.5863 | val 8.6313 | R2 [-1.8

### Inference

In [38]:
import numpy as np
from torch.serialization import safe_globals
import torch

with safe_globals([np.core.multiarray._reconstruct]):
    ckpt = torch.load("expr_mlp_best.pt", map_location="cpu", weights_only=False)

model.load_state_dict(ckpt["model"])
mean = ckpt["scaler_mean"]
scale = ckpt["scaler_scale"]

model.eval()

def predict_expr(emb_4096):
    x = (emb_4096 - mean) / scale
    x = torch.tensor(x[None, :], dtype=torch.float32)
    with torch.no_grad():
        y_log1p = model(x).numpy()[0]
    return np.expm1(y_log1p)

  with safe_globals([np.core.multiarray._reconstruct]):


### Predict the test set data

In [41]:
import numpy as np
import torch
from sklearn.metrics import r2_score, mean_squared_error

# Assume you already have:
# X_val, y_val  (NumPy arrays)
# model (loaded and eval mode)
# mean, scale  (NumPy arrays from your scaler)

# Normalize with training scaler
X_val_std = (X_val - mean) / scale
X_val_tensor = torch.tensor(X_val_std, dtype=torch.float32)

# Predict
model.eval()
with torch.no_grad():
    y_pred_log = model(X_val_tensor).cpu().numpy()  # predictions in log1p space

# Invert log transform if you used log1p earlier
y_pred = np.expm1(y_pred_log)
y_true = np.expm1(y_val)

# Evaluate performance
r2 = [r2_score(y_true[:, i], y_pred[:, i]) for i in range(y_true.shape[1])]
rmse = [np.sqrt(mean_squared_error(y_true[:, i], y_pred[:, i])) for i in range(y_true.shape[1])]

print("R² per condition:", np.round(r2, 3))
print("RMSE per condition:", np.round(rmse, 3))

# Combine predictions with gene IDs if you like
test_results = pd.DataFrame({
    # "YALI0_id": test_df["YALI0_id"].values,
    "YALI0_id": 'example_gene',
    "glucose_pred": y_pred[:, 0],
    "glycerol_pred": y_pred[:, 1],
    "oleic_pred": y_pred[:, 2],
})
test_results.head()

R² per condition: [-4.68980056e+27 -1.24014084e+25 -3.91975659e+24]
RMSE per condition: [6.08187412e+16 2.01829012e+15 9.18684205e+14]


Unnamed: 0,YALI0_id,glucose_pred,glycerol_pred,oleic_pred
0,example_gene,4.559548,5338.207,988.8033
1,example_gene,8.601069e+16,2854293000000000.0,1299216000000000.0
