# reduce to 256 using PCA, and then generate a df with its gene embeddings (w/ delta embs)

In [101]:
import torch
from transformers import AutoTokenizer, AutoModel
import pandas as pd
import pickle as pkl

from tqdm import tqdm

tqdm.pandas()

tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
model = AutoModel.from_pretrained("h4duan/PAIR-esm2").to("cuda")

In [24]:
def extract_feature(protein):
  ids = tokenizer(protein, return_tensors="pt", padding=True, max_length=1024, truncation=True, return_attention_mask=True)
  input_ids = torch.tensor(ids['input_ids']).to("cuda")
  attention_mask = torch.tensor(ids['attention_mask']).to("cuda")
  with torch.no_grad():
    embedding_repr = model(input_ids=input_ids,attention_mask=attention_mask).last_hidden_state
  return torch.mean(embedding_repr, dim=1)

def extract_features_batch(proteins):
  ids = tokenizer(proteins, return_tensors="pt", padding=True, max_length=1024, truncation=True, return_attention_mask=True)
  input_ids = torch.tensor(ids['input_ids']).to("cuda")
  attention_mask = torch.tensor(ids['attention_mask']).to("cuda")
  with torch.no_grad():
    embedding_repr = model(input_ids=input_ids,attention_mask=attention_mask).last_hidden_state
  attention_mask = attention_mask.unsqueeze(-1)
  attention_mask = attention_mask.expand(-1, -1, embedding_repr.size(-1))
  masked_embedding_repr = embedding_repr * attention_mask
  sum_embedding_repr = masked_embedding_repr.sum(dim=1)
  non_zero_count = attention_mask.sum(dim=1) 
  mean_embedding_repr = sum_embedding_repr / non_zero_count
  return mean_embedding_repr

In [10]:
protein = ["AETCZAO"]

feature = extract_feature(protein)

  input_ids = torch.tensor(ids['input_ids']).to("cuda")
  attention_mask = torch.tensor(ids['attention_mask']).to("cuda")


In [15]:
esm_mapping = torch.load('/work/magroup/shared/Heimdall/data/pretrained_embeddings/ESM2/protein_map_human_ensembl.pt')

In [18]:
path = "/work/magroup/kaileyhu/res/perturbed/gf_12L_30M_i2048_SL/ESM_df"

df = pd.read_hdf(f"{path}/ESM_emb_mat.h5", "table")
df.set_index("Unnamed: 0", inplace = True)

In [25]:
ensembl_path = "/work/magroup/kaileyhu/Geneformer/geneformer/ensembl_mapping_dict_gc95M.pkl"
        
with open(ensembl_path, "rb") as f:
    id_gene_dict = pkl.load(f)

In [44]:
gene_id_dict = {}
for key, value in id_gene_dict.items():
    if value in gene_id_dict:
        continue
    gene_id_dict[value] = key

In [64]:
pair_mapping = {}

genes = []
# ens_ids = []
for key in esm_mapping.keys():
    
    if key not in gene_id_dict:
        genes.append(key)
    else:
        gene = gene_id_dict[key]
        genes.append(gene)

features = extract_features_batch(genes)
feature_cpu = [x.cpu() for x in features]
pair_mapping = dict(zip(esm_mapping.keys(), feature_cpu))

  input_ids = torch.tensor(ids['input_ids']).to("cuda")
  attention_mask = torch.tensor(ids['attention_mask']).to("cuda")


In [65]:
pair_df = pd.DataFrame(pair_mapping).transpose()

In [67]:
from sklearn.decomposition import PCA

In [68]:
pca = PCA(n_components=256)

principalComponents = pca.fit_transform(pair_df)

In [69]:
pca_df = pd.DataFrame(data = principalComponents, columns = [f"pc {i}" for i in range(256)])

In [70]:
pca_df.index = pair_df.index

In [71]:
pca_df.to_csv("/work/magroup/kaileyhu/res/gene_embeddings/PAIR_pca_256.csv")

In [72]:
pca_df

Unnamed: 0,pc 0,pc 1,pc 2,pc 3,pc 4,pc 5,pc 6,pc 7,pc 8,pc 9,...,pc 246,pc 247,pc 248,pc 249,pc 250,pc 251,pc 252,pc 253,pc 254,pc 255
ENSG00000121410,-2.939927,0.235590,0.245271,1.313475,0.139536,0.443911,-0.536603,1.792289,0.575737,0.798791,...,0.131530,0.002528,-0.012133,0.133362,-0.264293,0.157491,0.093152,0.086953,0.075399,0.171969
ENSG00000148584,2.112873,-1.344812,0.445607,0.291106,-1.373380,0.755434,-0.878596,0.541939,-0.461665,-1.120564,...,-0.003471,-0.091927,0.273974,-0.097112,0.064122,-0.120927,-0.258579,0.309600,0.048188,-0.178412
ENSG00000175899,0.424642,-0.603142,-1.718002,-0.538989,-1.827814,1.658736,-1.618316,1.161420,-0.272878,-0.359440,...,0.126916,0.239926,-0.279099,0.267309,-0.184580,-0.251052,0.046126,0.301693,-0.233983,-0.106899
ENSG00000166535,4.252601,0.937075,-2.469287,0.488562,0.406097,0.782439,2.453494,0.021551,0.808722,0.142757,...,0.034944,0.025370,-0.008613,0.085839,0.006990,-0.123735,0.066340,0.036735,0.038398,0.027567
ENSG00000184389,-0.781904,-0.357991,0.616421,-1.250940,-0.266191,-0.874253,0.914282,0.107264,0.503933,-0.035998,...,-0.115956,0.010742,0.020842,0.109264,-0.090908,-0.141917,-0.073681,-0.153174,0.056722,0.097499
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
ENSG00000203995,3.115587,2.058678,0.800066,-0.084726,-0.825784,-0.154372,1.475193,0.135988,-0.246537,0.599952,...,-0.058871,-0.041181,-0.049788,-0.138405,0.003161,-0.077362,-0.268796,-0.067865,0.114491,-0.244308
ENSG00000162378,-1.432864,-1.578524,1.264478,-1.438484,-0.374076,-0.012406,1.622662,-0.210969,-0.952287,-0.723893,...,0.026867,-0.054563,0.081663,-0.094380,-0.077367,0.162979,0.211054,0.153263,-0.008993,0.012680
ENSG00000159840,-1.064318,0.294862,0.556401,-0.778084,0.378676,-0.399659,-0.548248,0.838446,-0.381778,0.636120,...,0.100690,-0.146181,0.025648,-0.097892,-0.253979,-0.038200,-0.027520,-0.011465,-0.018991,-0.073372
ENSG00000074755,-1.159131,0.117286,1.549862,-0.863698,0.811906,0.122840,0.700316,-0.419327,-0.514220,-0.677741,...,-0.091692,0.072423,0.118722,-0.106668,-0.068032,-0.199963,0.035639,-0.043847,-0.208945,-0.020802


In [76]:
gene_cols = [f"gene {i}" for i in range(256)]

In [81]:
new_values_list = [0.0 for _ in range(256)]

In [86]:
df.loc[:, gene_cols] = new_values_list

In [92]:
import sys
sys.path.append('/work/magroup/kaileyhu/synthetic_lethality/utils')

from extract_df_info import *

In [95]:
df['gene'] = list(map(get_genes_from_index, list(df.index)))

  df['gene'] = list(map(get_genes_from_index, list(df.index)))


In [97]:
df['gene'] = df['gene'].apply(lambda x : x[5:])

In [115]:
df[gene_cols] = df['gene'].progress_apply(lambda x : pd.Series(list(pca_df.loc[x])) if x in pca_df.index else pd.Series([-100.0 for _ in range(256)]))

100%|██████████| 2044819/2044819 [04:13<00:00, 8073.41it/s] 


In [117]:
df.drop(columns=['gene'], inplace = True)

In [121]:
df = df[df['gene 251'] != -100]

In [124]:
path = "/work/magroup/kaileyhu/res/perturbed/gf_12L_30M_i2048_SL/PAIR_df"

df.to_hdf(f"{path}/PAIR_emb_mat.h5", key = "table")

In [126]:
path = "/work/magroup/kaileyhu/res/perturbed/gf_12L_30M_i2048_SL/PAIR_df"
name = "PAIR_general"

df = pd.read_hdf(f"{path}/PAIR_emb_mat.h5", "table")