In [None]:
import torch
import numpy as np
import anndata
import scanpy as sc
import pandas as pd
from transformers import AutoTokenizer, AutoModel

In [None]:

def cosine_similarity(a, b):
    """Cosine similarity between ``a`` and the reference vector ``b``.

    Parameters
    ----------
    a : array_like
        Either a single vector of shape (d,) or a matrix of shape (n, d);
        for a matrix, every row is compared to ``b`` independently.
    b : array_like
        Reference vector of shape (d,).

    Returns
    -------
    numpy.float64 or numpy.ndarray
        A scalar when ``a`` is 1-D, otherwise an array of shape (n,) with one
        similarity per row of ``a``. Zero-norm inputs divide by zero and yield
        ``nan``/``inf`` (no special handling).
    """
    a = np.asarray(a)
    b = np.asarray(b)
    dot_product = np.dot(a, b)
    # Norm over the last axis so each row of a 2-D ``a`` is normalised on its
    # own; a plain ``np.linalg.norm(a)`` would be the Frobenius norm of the
    # whole matrix and only rescale all scores by one shared constant.
    norm_a = np.linalg.norm(a, axis=-1)
    norm_b = np.linalg.norm(b)
    return dot_product / (norm_a * norm_b)

# Load the two AnnData objects (per their filenames, presumably ATAC peaks and
# RNA expression for the same cells — TODO confirm). ``anndata.read`` is a
# deprecated alias of ``anndata.read_h5ad``, so call the latter directly.
adata = anndata.read_h5ad("./motif_data/atac_hvp.h5ad")
adata_ge = anndata.read_h5ad("./motif_data/rna.h5ad")
# Normalise RNA counts per cell, then log1p-transform; both modify adata_ge in place.
sc.pp.normalize_total(adata_ge)
sc.pp.log1p(adata_ge)

In [None]:
# Transcription factor and its binding motif to score against the cells.
target_name = "MEF2C"
target_motif = "ATGCTAAAAATAGAA"
# k-mer length used to split the motif; the tokenizer checkpoint name
# suggests a 6-mer vocabulary — TODO confirm against the checkpoint.
KMER_SIZE = 6

tokenizer = AutoTokenizer.from_pretrained("../gfm_checkpoint/6-new-12w-0")

# Overlapping k-mers (stride 1) joined by single spaces; this space-separated
# string is what gets passed to the tokenizer below. Motifs shorter than
# KMER_SIZE produce an empty string.
processed = " ".join(
    target_motif[i:i + KMER_SIZE]
    for i in range(len(target_motif) - KMER_SIZE + 1)
)

# Token ids padded to the tokenizer's max length (no truncation is requested).
tokens = tokenizer.encode_plus(processed, padding="max_length")['input_ids']
# NOTE(review): torch.load unpickles a full model object, which can execute
# arbitrary code — only use trusted checkpoints. torch>=2.6 defaults to
# weights_only=True, which cannot restore a pickled module, so opt out
# explicitly (the keyword requires torch>=1.13).
model = torch.load("./motif_data/cortex_checkpoint", map_location='cpu', weights_only=False)
# Presumably an nn.Module (it exposes ``encode_peak`` and ``alphas.weight``);
# switch to inference mode so layers such as dropout don't make the motif
# embedding nondeterministic.
model.eval()

# Embed the tokenised motif as a batch of one; no autograd graph is needed.
with torch.no_grad():
    input_ids = torch.tensor(tokens, dtype=torch.long).reshape(1, -1)
    peak_embeds = model.encode_peak(input_ids).detach().cpu().numpy()

# Standardise per dimension using the mean/std of the background embeddings.
# A background dimension with zero std would yield inf/nan here.
background_embeds = np.loadtxt("./motif_data/motif_embeds_background.txt")
peak_embeds = (peak_embeds - np.mean(background_embeds, axis=0)) / np.std(background_embeds, axis=0)

# Weights of the model's ``alphas`` layer as a NumPy array.
alpha = model.alphas.weight.detach().cpu().numpy()
# Load the stored per-cell embeddings and project them through ``alpha``;
# presumably this maps them into the same space as ``peak_embeds`` for the
# cosine comparison later on — TODO confirm.
cell_embeds = np.loadtxt("./motif_data/cell_embeds.txt") @ alpha

In [None]:
# Wrap the projected cell embeddings in an AnnData and copy over cell metadata.
# NOTE(review): assumes cell_embeds, adata and adata_ge list cells in the same
# order — only the lengths are implicitly checked here; confirm upstream.
adata_new = anndata.AnnData(cell_embeds)
adata_new.obs['label'] = list(adata.obs['label'])
adata_new.obs['age'] = list(adata.obs['Sample.Age'])
sc.tl.pca(adata_new)
sc.pp.neighbors(adata_new)
sc.tl.umap(adata_new)
sc.pl.umap(adata_new, color='label', save=True)

# NOTE(review): not used in the cells shown here; kept in case later cells rely on them.
result_dict = dict()
all_score = []
# Per-cell cosine similarity between each cell embedding and the motif embedding.
score = cosine_similarity(cell_embeds, peak_embeds[0])
adata_new.obs['score'] = score

# Normalised/log1p expression of the target gene, one value per cell.
index = list(adata_ge.var['names']).index(target_name)
expression = adata_ge.X[:, index]
if hasattr(expression, "toarray"):  # sparse column -> dense (n, 1) array
    expression = expression.toarray()
adata_new.obs['expression'] = np.asarray(expression).ravel()

# File names follow target_name so they stay correct when the target changes.
sc.pl.umap(adata_new, color='score', save=f"clean_gfetm_{target_name}_score", cmap="bwr")
sc.pl.umap(adata_new, color='expression', save=f"clean_gfetm_{target_name}_expression", cmap='bwr')