In [22]:
from transformers import AutoTokenizer, EsmModel
import torch
import sys
import pandas as pd
import numpy as np
from tqdm import tqdm

In [7]:
def load_appris(unique_transcripts=True):
    # generate doc string
    """
    Load the appris data
    :param unique_transcripts: whether to load only unique transcripts
    :return: the appris data
    """
    # ## load human appris
    dir = '/h/phil/Documents/01_projects/contrastive_rna_representation/'

    app_h = pd.read_csv(f'{dir}/data/appris_data_human.principal.txt', sep='\t')
    print(app_h['Gene ID'].duplicated().sum())
    app_h['numeric_value'] = app_h['APPRIS Annotation'].str.split(':').str[1]
    app_h['key_value'] = app_h['APPRIS Annotation'].str.split(':').str[0]
    app_h = app_h.sort_values(
        ['Gene ID', 'key_value', 'numeric_value', "Transcript ID"],
        ascending=[True, False, True, True],
    )
    if unique_transcripts:
        app_h = app_h[~app_h.duplicated('Gene ID')]
        app_h = app_h[~app_h.duplicated('Gene name')]
    return app_h


In [None]:
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")

model = EsmModel.from_pretrained("facebook/esm2_t6_8M_UR50D")

inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")

outputs = model(**inputs)

last_hidden_states = outputs.last_hidden_state

In [2]:
from genome_kit import Genome, Interval
genome = Genome("gencode.v29")
my_genes=[
    'PTGDR2', 'MERTK', 'CD46', 'ITGB3', 'TNFSF4', 'MSR1', 'ICOSLG', 'PROM1', 'FASLG',
    'NT5E', 'GYPA', 'CD72', 'CCR4', 'KLRB1', 'CD2', 'PDCD1LG2', 'CD19', 'SDC1', 
    'KIR3DL1', 'CR1', 'CD63', 'FCRL5', 'DPP4', 'CD28', 'CD209', 'LY75', 'CD177',
    'LAIR1', 'ITGA2B', 'CD200', 'ENTPD1', 'CD70', 'OLR1', 'CD27', 'LILRA4', 'CCR3',
    'SELP', 'MS4A1', 'NOTCH1', 'IFNGR1', 'CD34', 'FAS', 'ERBB2', 'THBD', 'TNFRSF18',
    'CD244', 'PDGFRA', 'TNFRSF14', 'CXCR5', 'PROCR', 'ITGAE', 'CDH5', 'CD9', 'CD44',
    'TNFSF10', 'NRP1', 'IL2RA', 'ANPEP', 'CD99', 'TNFRSF17', 'ENPP3', 'ITGA2', 'IL2RB',
    'PDGFRB', 'PDCD1', 'SIGLEC8', 'VCAM1', 'ITGA1', 'ICAM2', 'FCGR1A', 'CD1D', 'CD36',
    'ICAM1', 'NECTIN2', 'MPL', 'ABCB1', 'KIT', 'CDH17', 'CD8A', 'MRC1', 'ARHGEF5', 'TEK',
    'CD81', 'NCAM1', 'LRPPRC', 'CD93', 'LAMP1', 'CLEC10A', 'KLRK1', 'B3GAT1', 'TNFRSF4', 
    'CD55', 'ITGA7', 'CLEC12A', 'TFRC', 'FLT3', 'CD48', 'CD24', 'CRLF2', 'SELE', 'CD164', 
    'NGFR', 'FLT4', 'CD52', 'CXCR4', 'SLAMF7', 'FCRL4', 'CLEC9A', 'CLEC4C', 'IL4R', 'IL6R',
    'NCR3', 'SPN', 'ICOS', 'CD86', 'PTPRC', 'NOTCH2', 'CR2', 'IL7R', 'CD4', 'BTLA', 
    'TREM1', 'CD207', 'CD1A', 'CDH2', 'SIRPA', 'ITGAX', 'XCR1', 'F3', 'ABCG2', 'LGALS9', 
    'SIGLEC1', 'CCR6', 'CD22', 'FCGR3B', 'C5AR2', 'CD47', 'LAG3', 'CCR5', 'L1CAM', 
    'TNFRSF13B', 'CD40LG', 'ITGAM', 'CD1C', 'KDR', 'ACKR2', 'CD96', 'TNFRSF8', 'CD83', 
    'CTLA4', 'CD80', 'ENG', 'ITGB1', 'TNFRSF13C', 'PECAM1', 'VTCN1', 'CD68', 'ITGB2',
    'CX3CR1', 'CSF1R', 'IL3RA', 'TLR4', 'CD14', 'TIGIT', 'FUT4', 'CDH1', 'CEACAM8', 
    'CD38', 'TNFRSF9', 'CD226', 'CD79A', 'NCR1', 'CD163', 'CD109', 'CD274', 'ITGA4',
    'CLEC1B', 'CD40', 'CCR2', 'GP1BA', 'CD79B', 'HAVCR2', 'MCAM', 'THY1', 'SLC7A5', 
    'CD69', 'CXCR6', 'PVR'
]

In [8]:
app = load_appris()

15791


In [28]:
transcripts_dict = {}
for gene in tqdm(my_genes):
    transcript_name = app[app['Gene name'] == gene].iloc[0]['Transcript ID']
    gene_object = [x for x in genome.genes if x.name == gene][0]
    for transcript in gene_object.transcripts:
        if transcript.id.split('.')[0] == transcript_name:
            transcripts_dict[gene] = transcript

100%|██████████| 188/188 [00:05<00:00, 35.62it/s]


In [36]:
# DNA codon to amino acid lookup table
dna_to_amino_acid_table = {
    'TTT': 'F', 'TTC': 'F', 'TTA': 'L', 'TTG': 'L',
    'CTT': 'L', 'CTC': 'L', 'CTA': 'L', 'CTG': 'L',
    'ATT': 'I', 'ATC': 'I', 'ATA': 'I', 'ATG': 'M',
    'GTT': 'V', 'GTC': 'V', 'GTA': 'V', 'GTG': 'V',
    'TCT': 'S', 'TCC': 'S', 'TCA': 'S', 'TCG': 'S',
    'CCT': 'P', 'CCC': 'P', 'CCA': 'P', 'CCG': 'P',
    'ACT': 'T', 'ACC': 'T', 'ACA': 'T', 'ACG': 'T',
    'GCT': 'A', 'GCC': 'A', 'GCA': 'A', 'GCG': 'A',
    'TAT': 'Y', 'TAC': 'Y', 'TAA': '*', 'TAG': '*',
    'CAT': 'H', 'CAC': 'H', 'CAA': 'Q', 'CAG': 'Q',
    'AAT': 'N', 'AAC': 'N', 'AAA': 'K', 'AAG': 'K',
    'GAT': 'D', 'GAC': 'D', 'GAA': 'E', 'GAG': 'E',
    'TGT': 'C', 'TGC': 'C', 'TGA': '*', 'TGG': 'W',
    'CGT': 'R', 'CGC': 'R', 'CGA': 'R', 'CGG': 'R',
    'AGT': 'S', 'AGC': 'S', 'AGA': 'R', 'AGG': 'R',
    'GGT': 'G', 'GGC': 'G', 'GGA': 'G', 'GGG': 'G'
}

# Function to translate a DNA sequence to an amino acid sequence
def dna_sequence_to_amino_acids(dna_sequence):
    # Ensure the sequence length is a multiple of 3
    if len(dna_sequence) % 3 != 0:
        return "Error: DNA sequence length must be a multiple of 3."

    amino_acid_sequence = ""
    
    # Iterate over the DNA sequence in steps of 3 nucleotides (1 codon)
    for i in range(0, len(dna_sequence), 3):
        codon = dna_sequence[i:i+3].upper()
        amino_acid = dna_to_amino_acid_table.get(codon, "Invalid")
        
        if amino_acid == "Invalid":
            return f"Error: Invalid codon '{codon}' encountered."
        
        amino_acid_sequence += amino_acid
        
        # Stop translation if a stop codon is encountered
        if amino_acid == "*":
            break
    
    return amino_acid_sequence

# Example usage:
dna_sequence = "ATGGCTTACTAG"
amino_acids = dna_sequence_to_amino_acids(dna_sequence)
print(amino_acids)  # Output: MAY* (Methionine, Alanine, Tyrosine, Stop)


MAY*


In [50]:
def get_transcript_amino_acid_seq(transcript, genome):
    exons = transcript.cdss
    seq = ""
    for exon in exons:
        seq += genome.dna(exon)
    return dna_sequence_to_amino_acids(seq)

get_transcript_amino_acid_seq(transcripts_dict['PTGDR2'], genome)

'MSANATLKPLCPILEQMSRLQSHSNTSIRYIDHAAVLLHGLASLLGLVENGVILFVVGCRMRQTVVTTWVLHLALSDLLASASLPFFTYFLAVGHSWELGTTFCKLHSSIFFLNMFASGFLLSAISLDRCLQVVRPVWAQNHRTVAAAHKVCLVLWALAVLNTVPYFVFRDTISRLDGRIMCYYNVLLLNPGPDRDATCNSRQVALAVSKFLLAFLVPLAIIASSHAAVSLRLQHRGRRRPGRFVRLVAAVVAAFALCWGPYHVFSLLEARAHANPGLRPLVWRGLPFVTSLAFFNSVANPVLYVLTCPDMLRKLRRSLRTVLESVLVDDSELGGAGSSRRRRTSSTARSASPLALCSRPEEPRGPARLLGWLLGSCAASPQTGPLNRALSSTSS*'

In [74]:
esm_data = []
for gene, transcript in tqdm(transcripts_dict.items()):
    aa_seq = get_transcript_amino_acid_seq(transcript, genome)
    # * assert appears once in aa_seq
    assert aa_seq.count("*") == 1
    inputs = tokenizer(aa_seq, return_tensors="pt")
    output = model(**inputs)
    output = output['pooler_output'].detach().numpy()
    row = {
        'gene': gene,
        'transcript': transcript.id,
    }
    for i in range(output.shape[1]):
        row[f'esm_{i}'] = output[0, i]
    esm_data.append(row)

  0%|          | 0/186 [00:00<?, ?it/s]

100%|██████████| 186/186 [01:42<00:00,  1.82it/s]


In [75]:
pd.DataFrame(esm_data).to_csv('../data/esm_data.csv', index=False)

In [77]:
pd.read_csv('../data/esm_data.csv')

Unnamed: 0,gene,transcript,esm_0,esm_1,esm_2,esm_3,esm_4,esm_5,esm_6,esm_7,...,esm_310,esm_311,esm_312,esm_313,esm_314,esm_315,esm_316,esm_317,esm_318,esm_319
0,PTGDR2,ENST00000332539.4,0.194653,0.126780,0.182600,-0.029803,0.051965,0.026792,0.109676,-0.213097,...,0.128914,0.092471,0.115099,0.152762,0.111469,0.124649,0.054852,0.122809,-0.076046,-0.050964
1,MERTK,ENST00000295408.8,0.174655,0.162040,0.159292,-0.064161,0.037378,0.025869,0.214430,-0.168868,...,0.130357,0.049052,0.103070,0.135618,0.163095,-0.009734,0.029290,0.162163,-0.071493,0.131510
2,CD46,ENST00000358170.6,0.160673,0.180229,0.124241,-0.047515,0.121338,0.048503,0.204651,-0.230932,...,0.178895,0.039712,0.099157,-0.035398,0.227750,-0.011657,0.019348,0.188651,-0.027152,0.146914
3,ITGB3,ENST00000559488.5,0.033419,0.148301,0.171840,-0.122247,0.052923,0.050449,0.140216,-0.211933,...,0.101055,0.028358,0.171624,0.008692,0.149858,-0.025355,0.080045,0.208778,-0.035152,0.119216
4,TNFSF4,ENST00000281834.3,0.159453,0.045351,0.133093,-0.065311,-0.001200,0.090423,0.113904,-0.243310,...,0.225913,0.040134,0.083178,-0.037448,0.193334,-0.017796,0.067195,0.200707,0.061961,0.077091
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
181,THY1,ENST00000284240.9,0.108301,0.146954,0.171239,-0.001321,0.000676,-0.037172,0.190913,-0.116044,...,0.157813,0.002127,0.212422,0.027532,0.151119,0.039846,0.055663,0.189299,-0.071068,0.024650
182,SLC7A5,ENST00000261622.4,0.110616,0.120736,0.166708,-0.035123,0.000360,0.071963,0.124000,-0.362562,...,0.121050,0.019785,0.111488,0.042031,0.230152,-0.065543,0.002094,0.153272,-0.042976,0.102418
183,CD69,ENST00000228434.7,0.197017,0.137450,0.162856,0.019329,0.073427,0.057510,0.155821,-0.286048,...,0.216885,0.054472,0.062408,-0.029178,0.193458,0.010828,0.010842,0.188487,-0.011039,0.043911
184,CXCR6,ENST00000304552.4,0.182692,0.106208,0.124159,-0.011814,0.007098,0.064117,0.131658,-0.237451,...,0.229905,0.046099,0.081270,0.091075,0.220422,-0.002151,0.039243,0.161981,-0.020390,0.000847
