In [1]:
import os
os.environ['CUDA_VISIBLE_DEVICES']='0'
from omegaconf import OmegaConf
import diffusion_gosai_update
import flow_gosai
import fm_dna
from hydra import initialize, compose
from hydra.core.global_hydra import GlobalHydra
import dataloader_gosai
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import oracle
from scipy.stats import pearsonr
import torch
from tqdm import tqdm
import diffusion_gosai_cfg
import torch.nn.functional as F
from utils import set_seed
set_seed(0, use_cuda=True)
plt.rcParams['figure.dpi'] = 200

%load_ext autoreload
%autoreload 2

=> Seed of the run set to 0


In [2]:
base_path = '/home/ubuntu'

In [None]:
# our model 
CKPT_PATH = os.path.join(base_path, 'finetune_models/fm_rollout.ckpt')
NUM_SAMPLE_BATCHES = 10
NUM_SAMPLES_PER_BATCH = 64

In [4]:
# reinitialize Hydra
GlobalHydra.instance().clear()

# Initialize Hydra and compose the configuration|
initialize(config_path="configs_gosai", job_name="load_model")
cfg = compose(config_name="config_gosai.yaml")
cfg.eval.checkpoint_path = CKPT_PATH

The version_base parameter is not specified.
Please specify a compatability version level, or None.
Will assume defaults for version 1.1
  initialize(config_path="configs_gosai", job_name="load_model")


In [5]:
model = fm_dna.DiscreteFlowMatchingNet(cfg).cuda() 
model.load_state_dict(torch.load(cfg.eval.checkpoint_path))
model.eval()

DiscreteFlowMatchingNet(
  (model): ConvNet(
    (input_projection): Embedding(5, 128)
    (embed_timestep): Sequential(
      (0): Embedding(1024, 128)
      (1): Unsqueeze()
    )
    (blocks): ModuleList(
      (0-5): 6 x Sequential(
        (0): Transpose()
        (1): Conv1d(128, 128, kernel_size=(31,), stride=(1,), padding=same)
        (2): Transpose()
        (3): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (4): GELU(approximate='none')
        (5): Linear(in_features=128, out_features=128, bias=True)
        (6): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (7): GELU(approximate='none')
      )
    )
    (timestep_embedding_norms): ModuleList(
      (0-5): 6 x LayerNorm((128,), eps=1e-05, elementwise_affine=True)
    )
    (output_projection): Linear(in_features=128, out_features=5, bias=True)
  )
)

In [None]:
old_path = os.path.join(base_path, 'pretrained_models/pretrained_fm.ckpt')
ckpt = torch.load(old_path, map_location="cuda")
old_model = fm_dna.DiscreteFlowMatchingNet(cfg).cuda() 
old_model.load_state_dict(ckpt, strict=False)  
old_model.eval()

DiscreteFlowMatchingNet(
  (model): ConvNet(
    (input_projection): Embedding(5, 128)
    (embed_timestep): Sequential(
      (0): Embedding(1024, 128)
      (1): Unsqueeze()
    )
    (blocks): ModuleList(
      (0-5): 6 x Sequential(
        (0): Transpose()
        (1): Conv1d(128, 128, kernel_size=(31,), stride=(1,), padding=same)
        (2): Transpose()
        (3): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (4): GELU(approximate='none')
        (5): Linear(in_features=128, out_features=128, bias=True)
        (6): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (7): GELU(approximate='none')
      )
    )
    (timestep_embedding_norms): ModuleList(
      (0-5): 6 x LayerNorm((128,), eps=1e-05, elementwise_affine=True)
    )
    (output_projection): Linear(in_features=128, out_features=5, bias=True)
  )
)

In [7]:
'''
zero_alpha_path = os.path.join(base_path, 'mdlm/reward_bp_results_final/zero_alpha.ckpt')
zero_alpha_model = diffusion_gosai_update.Diffusion(cfg).cuda()
zero_alpha_model.load_state_dict(torch.load(zero_alpha_path))
zero_alpha_model.eval()
'''
zero_alpha_path = os.path.join(base_path, 'mdlm/reward_bp_results_final/zero_alpha.ckpt')
ckpt = torch.load(zero_alpha_path, map_location="cuda")
zero_alpha_model = diffusion_gosai_update.Diffusion(cfg).cuda()
zero_alpha_model.load_state_dict(ckpt, strict=False)  # 或 ckpt 本身就是 dicv
zero_alpha_model.eval()


Diffusion(
  (backbone): CNNModel(
    (linear): Conv1d(5, 128, kernel_size=(9,), stride=(1,), padding=(4,))
    (time_embedder): Sequential(
      (0): GaussianFourierProjection()
      (1): Linear(in_features=128, out_features=128, bias=True)
    )
    (convs): ModuleList(
      (0-7): 8 x Conv1d(128, 128, kernel_size=(9,), stride=(1,), padding=(4,))
      (8-11): 4 x Conv1d(128, 128, kernel_size=(9,), stride=(1,), padding=(16,), dilation=(4,))
      (12-15): 4 x Conv1d(128, 128, kernel_size=(9,), stride=(1,), padding=(64,), dilation=(16,))
      (16-19): 4 x Conv1d(128, 128, kernel_size=(9,), stride=(1,), padding=(256,), dilation=(64,))
    )
    (time_layers): ModuleList(
      (0-19): 20 x Dense(
        (dense): Linear(in_features=128, out_features=128, bias=True)
      )
    )
    (norms): ModuleList(
      (0-19): 20 x LayerNorm((128,), eps=1e-05, elementwise_affine=True)
    )
    (final_conv): Sequential(
      (0): Conv1d(128, 128, kernel_size=(1,), stride=(1,))
      (1): R

In [8]:
cfg_cfg = compose(config_name="config_gosai.yaml")
cfg_cfg.model.cls_free_guidance=True
cfg_cfg.model.cls_free_weight=10
cfg_cfg.model.cls_free_prob=0.1
cfg_path = os.path.join(base_path, 'mdlm/outputs_gosai/cfg.ckpt')
cfg_cfg.eval.checkpoint_path = cfg_path
cfg_model = diffusion_gosai_cfg.Diffusion(cfg_cfg, eval=False).cuda()
cfg_model.load_state_dict(torch.load(cfg_cfg.eval.checkpoint_path)['state_dict'])
cfg_model.eval()

Diffusion(
  (backbone): CNNModel(
    (linear): Conv1d(5, 128, kernel_size=(9,), stride=(1,), padding=(4,))
    (time_embedder): Sequential(
      (0): GaussianFourierProjection()
      (1): Linear(in_features=128, out_features=128, bias=True)
    )
    (convs): ModuleList(
      (0-7): 8 x Conv1d(128, 128, kernel_size=(9,), stride=(1,), padding=(4,))
      (8-11): 4 x Conv1d(128, 128, kernel_size=(9,), stride=(1,), padding=(16,), dilation=(4,))
      (12-15): 4 x Conv1d(128, 128, kernel_size=(9,), stride=(1,), padding=(64,), dilation=(16,))
      (16-19): 4 x Conv1d(128, 128, kernel_size=(9,), stride=(1,), padding=(256,), dilation=(64,))
    )
    (time_layers): ModuleList(
      (0-19): 20 x Dense(
        (dense): Linear(in_features=128, out_features=128, bias=True)
      )
    )
    (norms): ModuleList(
      (0-19): 20 x LayerNorm((128,), eps=1e-05, elementwise_affine=True)
    )
    (final_conv): Sequential(
      (0): Conv1d(128, 128, kernel_size=(1,), stride=(1,))
      (1): R

### Sample from the generative model

In [9]:
all_detoeknized_samples = []
all_raw_samples = []
_, val_loader, _ = dataloader_gosai.get_dataloaders_gosai(cfg)
batch = next(iter(val_loader))
seqs = batch["seqs"].cuda()
B, L = seqs.shape
step = 32

for _ in tqdm(range(NUM_SAMPLE_BATCHES)):
    generator, logits = model.sample(
        num_sampling_steps = 8,
        num_samples =  B,
        sequence_length = L,
        yield_intermediate=True,
    )

    for x_t in generator:
        samples = x_t 

    y_soft = F.softmax(logits / cfg.gumbel_temp, dim=-1)
    idx = y_soft.argmax(dim=-1)
    y_hard = F.one_hot(idx, 4).float()
    samples = y_hard.detach() - y_soft.detach() + y_soft
    samples = torch.argmax(samples, dim = -1)
    
    all_raw_samples.append(samples)
    detokenized_samples = dataloader_gosai.batch_dna_detokenize(samples.detach().cpu().numpy())
    all_detoeknized_samples.extend(detokenized_samples)
all_raw_samples = torch.concat(all_raw_samples)
all_raw_samples = F.one_hot(all_raw_samples, num_classes=4)

[debug] Using GosaiDataset defined in: /home/ubuntu/drakes_single/dataloader_gosai.py


100%|███████████████████████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 11.17it/s]


In [10]:
#reward_model_bs = oracle.get_gosai_oracle(mode='train')
#reward_model_bs.eval()

In [11]:
'''
tds_all_detoeknized_samples = []
tds_all_raw_samples = []
for _ in tqdm(range(NUM_SAMPLE_BATCHES)):
    samples = old_model.controlled_sample_TDS(reward_model=reward_model_bs, alpha=0.5, guidance_scale=1000, eval_sp_size=NUM_SAMPLES_PER_BATCH)
    tds_all_raw_samples.append(samples)
    detokenized_samples = dataloader_gosai.batch_dna_detokenize(samples.detach().cpu().numpy())
    tds_all_detoeknized_samples.extend(detokenized_samples)
tds_all_raw_samples = torch.concat(tds_all_raw_samples)
tds_model_logl = old_model.get_likelihood(tds_all_raw_samples, num_steps=128, n_samples=1)
'''

'\ntds_all_detoeknized_samples = []\ntds_all_raw_samples = []\nfor _ in tqdm(range(NUM_SAMPLE_BATCHES)):\n    samples = old_model.controlled_sample_TDS(reward_model=reward_model_bs, alpha=0.5, guidance_scale=1000, eval_sp_size=NUM_SAMPLES_PER_BATCH)\n    tds_all_raw_samples.append(samples)\n    detokenized_samples = dataloader_gosai.batch_dna_detokenize(samples.detach().cpu().numpy())\n    tds_all_detoeknized_samples.extend(detokenized_samples)\ntds_all_raw_samples = torch.concat(tds_all_raw_samples)\ntds_model_logl = old_model.get_likelihood(tds_all_raw_samples, num_steps=128, n_samples=1)\n'

In [12]:
'''
cg_all_detoeknized_samples = []
cg_all_raw_samples = []
for _ in tqdm(range(NUM_SAMPLE_BATCHES)):
    samples = old_model.controlled_sample_CG(reward_model=reward_model_bs, guidance_scale=300000, eval_sp_size=NUM_SAMPLES_PER_BATCH)
    cg_all_raw_samples.append(samples)
    detokenized_samples = dataloader_gosai.batch_dna_detokenize(samples.detach().cpu().numpy())
    cg_all_detoeknized_samples.extend(detokenized_samples)
cg_all_raw_samples = torch.concat(cg_all_raw_samples)
cg_model_logl = old_model.get_likelihood(cg_all_raw_samples, num_steps=128, n_samples=1)
'''

'\ncg_all_detoeknized_samples = []\ncg_all_raw_samples = []\nfor _ in tqdm(range(NUM_SAMPLE_BATCHES)):\n    samples = old_model.controlled_sample_CG(reward_model=reward_model_bs, guidance_scale=300000, eval_sp_size=NUM_SAMPLES_PER_BATCH)\n    cg_all_raw_samples.append(samples)\n    detokenized_samples = dataloader_gosai.batch_dna_detokenize(samples.detach().cpu().numpy())\n    cg_all_detoeknized_samples.extend(detokenized_samples)\ncg_all_raw_samples = torch.concat(cg_all_raw_samples)\ncg_model_logl = old_model.get_likelihood(cg_all_raw_samples, num_steps=128, n_samples=1)\n'

In [13]:

old_all_detoeknized_samples = []
old_all_raw_samples = []

_, val_loader, _ = dataloader_gosai.get_dataloaders_gosai(cfg)
batch = next(iter(val_loader))
seqs = batch["seqs"].cuda()
B, L = seqs.shape
step = 32

for _ in tqdm(range(NUM_SAMPLE_BATCHES)):
    generator, logits = old_model.sample(
        num_sampling_steps = 8,
        num_samples =  B,
        sequence_length = L,
        yield_intermediate=True,
    )
    samples = generator
    for x_t in generator:
        samples = x_t 

    y_soft = F.softmax(logits / cfg.gumbel_temp, dim=-1)
    idx = y_soft.argmax(dim=-1)
    y_hard = F.one_hot(idx, 4).float()
    samples = y_hard.detach() - y_soft.detach() + y_soft
    samples = torch.argmax(samples, dim = -1)
    old_all_raw_samples.append(samples)

    detokenized_samples = dataloader_gosai.batch_dna_detokenize(samples.detach().cpu().numpy())
    old_all_detoeknized_samples.extend(detokenized_samples)
old_all_raw_samples = torch.concat(old_all_raw_samples)
old_all_raw_samples = F.one_hot(old_all_raw_samples, num_classes=4)


[debug] Using GosaiDataset defined in: /home/ubuntu/drakes_single/dataloader_gosai.py


100%|███████████████████████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 37.16it/s]


In [14]:
'''
smc_all_detoeknized_samples = []
smc_all_raw_samples = []
for _ in tqdm(range(NUM_SAMPLE_BATCHES)):
    samples = old_model.controlled_sample_SMC(reward_model=reward_model_bs, alpha=0.5, eval_sp_size=NUM_SAMPLES_PER_BATCH)
    smc_all_raw_samples.append(samples)
    detokenized_samples = dataloader_gosai.batch_dna_detokenize(samples.detach().cpu().numpy())
    smc_all_detoeknized_samples.extend(detokenized_samples)
smc_all_raw_samples = torch.concat(smc_all_raw_samples)
smc_model_logl = old_model.get_likelihood(smc_all_raw_samples, num_steps=128, n_samples=1)
'''

'\nsmc_all_detoeknized_samples = []\nsmc_all_raw_samples = []\nfor _ in tqdm(range(NUM_SAMPLE_BATCHES)):\n    samples = old_model.controlled_sample_SMC(reward_model=reward_model_bs, alpha=0.5, eval_sp_size=NUM_SAMPLES_PER_BATCH)\n    smc_all_raw_samples.append(samples)\n    detokenized_samples = dataloader_gosai.batch_dna_detokenize(samples.detach().cpu().numpy())\n    smc_all_detoeknized_samples.extend(detokenized_samples)\nsmc_all_raw_samples = torch.concat(smc_all_raw_samples)\nsmc_model_logl = old_model.get_likelihood(smc_all_raw_samples, num_steps=128, n_samples=1)\n'

In [15]:
'''
zero_alpha_all_detoeknized_samples = []
zero_alpha_all_raw_samples = []
for _ in tqdm(range(NUM_SAMPLE_BATCHES)):
    samples = zero_alpha_model._sample(eval_sp_size=NUM_SAMPLES_PER_BATCH)
    zero_alpha_all_raw_samples.append(samples)
    detokenized_samples = dataloader_gosai.batch_dna_detokenize(samples.detach().cpu().numpy())
    zero_alpha_all_detoeknized_samples.extend(detokenized_samples)
zero_alpha_all_raw_samples = torch.concat(zero_alpha_all_raw_samples)
zero_alpha_model_logl = old_model.get_likelihood(zero_alpha_all_raw_samples, num_steps=128, n_samples=1)
'''

'\nzero_alpha_all_detoeknized_samples = []\nzero_alpha_all_raw_samples = []\nfor _ in tqdm(range(NUM_SAMPLE_BATCHES)):\n    samples = zero_alpha_model._sample(eval_sp_size=NUM_SAMPLES_PER_BATCH)\n    zero_alpha_all_raw_samples.append(samples)\n    detokenized_samples = dataloader_gosai.batch_dna_detokenize(samples.detach().cpu().numpy())\n    zero_alpha_all_detoeknized_samples.extend(detokenized_samples)\nzero_alpha_all_raw_samples = torch.concat(zero_alpha_all_raw_samples)\nzero_alpha_model_logl = old_model.get_likelihood(zero_alpha_all_raw_samples, num_steps=128, n_samples=1)\n'

In [16]:
'''
cfg_all_detoeknized_samples = []
cfg_all_raw_samples = []
for _ in tqdm(range(NUM_SAMPLE_BATCHES)):
    samples = cfg_model._sample(eval_sp_size=NUM_SAMPLES_PER_BATCH, w=10)
    cfg_all_raw_samples.append(samples)
    detokenized_samples = dataloader_gosai.batch_dna_detokenize(samples.detach().cpu().numpy())
    cfg_all_detoeknized_samples.extend(detokenized_samples)
cfg_all_raw_samples = torch.concat(cfg_all_raw_samples)
cfg_model_logl = old_model.get_likelihood(cfg_all_raw_samples, num_steps=128, n_samples=1)
'''

'\ncfg_all_detoeknized_samples = []\ncfg_all_raw_samples = []\nfor _ in tqdm(range(NUM_SAMPLE_BATCHES)):\n    samples = cfg_model._sample(eval_sp_size=NUM_SAMPLES_PER_BATCH, w=10)\n    cfg_all_raw_samples.append(samples)\n    detokenized_samples = dataloader_gosai.batch_dna_detokenize(samples.detach().cpu().numpy())\n    cfg_all_detoeknized_samples.extend(detokenized_samples)\ncfg_all_raw_samples = torch.concat(cfg_all_raw_samples)\ncfg_model_logl = old_model.get_likelihood(cfg_all_raw_samples, num_steps=128, n_samples=1)\n'

In [17]:
#compare = np.concatenate((old_model_logl.detach().cpu().numpy(), 
                          #zero_alpha_model_logl.detach().cpu().numpy(),
                          #model_logl.detach().cpu().numpy(),
                          #cfg_model_logl.detach().cpu().numpy(),
                          #cg_model_logl.detach().cpu().numpy(),
                          #smc_model_logl.detach().cpu().numpy(),
                          #tds_model_logl.detach().cpu().numpy(),
                         #), axis= 0)

In [18]:
#np.median(compare.reshape(-1, 80), axis=-1) #(-1,640)

In [None]:
config=OmegaConf.to_container(cfg, resolve=False)
highexp_kmers_99, n_highexp_kmers_99, highexp_kmers_999, n_highexp_kmers_999, highexp_set_sp_clss_999, highexp_preds_999, highexp_seqs_999 = oracle.cal_highexp_kmers(return_clss=True,config = cfg)

[debug] Using GosaiDataset defined in: /home/ubuntu/drakes_single/dataloader_gosai.py


### Pred-Activity based on Eval Oracle

In [None]:
# Calculate the predictions of the oracle model on the generated sequences
generated_preds = oracle.cal_gosai_pred_new(all_detoeknized_samples, mode='eval')
old_generated_preds = oracle.cal_gosai_pred_new(old_all_detoeknized_samples, mode='eval')
if all_detoeknized_samples == old_all_detoeknized_samples:
    print("Yes")
#zero_alpha_preds = oracle.cal_gosai_pred_new(zero_alpha_all_detoeknized_samples, mode='eval')
#cfg_preds = oracle.cal_gosai_pred_new(cfg_all_detoeknized_samples, mode='eval')
#cg_preds = oracle.cal_gosai_pred_new(cg_all_detoeknized_samples, mode='eval')
#smc_preds = oracle.cal_gosai_pred_new(smc_all_detoeknized_samples, mode='eval')
#tds_preds = oracle.cal_gosai_pred_new(tds_all_detoeknized_samples, mode='eval')

In [None]:
compare = np.concatenate((generated_preds[:,0],old_generated_preds[:,0]), axis = 0) 
#compare = old_generated_preds[:,0]
#np.concatenate((generated_preds[:,0]),axis = 0)
    #old_generated_preds[:,0],
                            #zero_alpha_preds[:,0],
                            #generated_preds[:,0],)
                            #cfg_preds[:,0],
                            #cg_preds[:,0],
                            #smc_preds[:,0],
                            #tds_preds[:,0]), 
                            #axis= 0)

In [None]:
np.median(compare.reshape(-1, 80), axis=-1)

### ATAC-Acc

In [None]:
generated_preds_atac = oracle.cal_atac_pred_new(all_detoeknized_samples)
old_generated_preds_atac = oracle.cal_atac_pred_new(old_all_detoeknized_samples)
#zero_alpha_preds_atac = oracle.cal_atac_pred_new(zero_alpha_all_detoeknized_samples)
#cfg_preds_atac = oracle.cal_atac_pred_new(cfg_all_detoeknized_samples)
#cg_preds_atac = oracle.cal_atac_pred_new(cg_all_detoeknized_samples)
#smc_preds_atac = oracle.cal_atac_pred_new(smc_all_detoeknized_samples)
#tds_preds_atac = oracle.cal_atac_pred_new(tds_all_detoeknized_samples)

In [None]:
compare = np.concatenate((generated_preds_atac[:,1], old_generated_preds_atac[:,1]), axis = 0) 
#compare = old_generated_preds_atac[:,1]
#np.concatenate((generated_preds_atac[:,1]),axis = 0)
    #old_generated_preds_atac[:,1],
                            #zero_alpha_preds_atac[:,1],
                           
                            #cfg_preds_atac[:,1],
                            #cg_preds_atac[:,1],
                            #smc_preds_atac[:,1],
                            #tds_preds_atac[:,1]), 
                            #axis= 0)

In [None]:
(old_generated_preds_atac[:,1]>0.5).sum()/80 #640

In [None]:
(generated_preds_atac[:,1]>0.5).sum()/80  #640

In [None]:
#(zero_alpha_preds_atac[:,1]>0.5).sum()/640

In [None]:
#(cfg_preds_atac[:,1]>0.5).sum()/640

In [None]:
#(cg_preds_atac[:,1]>0.5).sum()/640

In [None]:
#(smc_preds_atac[:,1]>0.5).sum()/640

In [None]:
#(tds_preds_atac[:,1]>0.5).sum()/640

### 3-mer Pearson Correlation

In [None]:
def compare_kmer(kmer1, kmer2, n_sp1, n_sp2, title):
    kmer_set = set(kmer1.keys()) | set(kmer2.keys())
    counts = np.zeros((len(kmer_set), 2))
    for i, kmer in enumerate(kmer_set):
        if kmer in kmer1:
            counts[i][1] = kmer1[kmer] * n_sp2 / n_sp1
        if kmer in kmer2:
            counts[i][0] = kmer2[kmer]
            
    fig, ax = plt.subplots(figsize=(2.5, 2.5))
    ax.scatter(counts[:, 0], counts[:, 1], alpha=0.5)
    ax.set_title(title)
    ax.set_xlabel("k-mer train top 0.1% count")
    ax.set_ylabel("k-mer generated count")
    ax.set_ylim((-5, np.max(counts) + 5))
    ax.set_xlim((-5, np.max(counts) + 5))
    ax.set_yticklabels([])
    ax.set_xticklabels([])
    ax.text(0.5, 0.5, f"Pearson Corr: {pearsonr(counts[:, 0], counts[:, 1])[0]:.3f}")
    plt.show()

    print(pearsonr(counts[:, 0], counts[:, 1]))

In [None]:
generated_kmer = oracle.count_kmers(all_detoeknized_samples)
old_generated_kmer = oracle.count_kmers(old_all_detoeknized_samples)
#zero_alpha_generated_kmer = oracle.count_kmers(zero_alpha_all_detoeknized_samples)
#cfg_generated_kmer = oracle.count_kmers(cfg_all_detoeknized_samples)
#cg_generated_kmer = oracle.count_kmers(cg_all_detoeknized_samples)
#smc_generated_kmer = oracle.count_kmers(smc_all_detoeknized_samples)
#tds_generated_kmer = oracle.count_kmers(tds_all_detoeknized_samples)

In [None]:
compare_kmer(highexp_kmers_999, generated_kmer, n_highexp_kmers_999, len(all_detoeknized_samples), title=r"Finetuned")

In [None]:
compare_kmer(highexp_kmers_999, old_generated_kmer, n_highexp_kmers_999, len(old_all_detoeknized_samples), title=r"Pretrained")

In [None]:
#compare_kmer(highexp_kmers_999, zero_alpha_generated_kmer, n_highexp_kmers_999, len(zero_alpha_all_detoeknized_samples), 
             #title=r"Finetuned ($\alpha=0$)")

In [None]:
#compare_kmer(highexp_kmers_999, cfg_generated_kmer, n_highexp_kmers_999, len(cfg_all_detoeknized_samples), title=r"CFG")

In [None]:
#compare_kmer(highexp_kmers_999, cg_generated_kmer, n_highexp_kmers_999, len(cg_all_detoeknized_samples), title=r"CG")

In [None]:
#compare_kmer(highexp_kmers_999, smc_generated_kmer, n_highexp_kmers_999, len(smc_all_detoeknized_samples), title=r"SMC")

In [None]:
#compare_kmer(highexp_kmers_999, tds_generated_kmer, n_highexp_kmers_999, len(tds_all_detoeknized_samples), title=r"TDS")

### JASPER Motif Analysis

In [None]:
from grelu.interpret.motifs import scan_sequences
from grelu.io.motifs import get_jaspar
motifs = get_jaspar(
    release='JASPAR2024',    
    collection='CORE',        
    tax_group='vertebrates',  
    # species='9606',         
)
motif_count = scan_sequences(all_detoeknized_samples, motifs=motifs, pthresh=1e-4, rc=True)


In [None]:
motif_count = scan_sequences(all_detoeknized_samples, motifs=motifs, pthresh=1e-4, rc=True)
print(motif_count)
motif_count_sum = motif_count['motif'].value_counts()
motif_count_old = scan_sequences(old_all_detoeknized_samples, motifs = motifs, pthresh = 1e-4, rc = True)
motif_count_old_sum = motif_count_old['motif'].value_counts()
#motif_count_zero_alpha = scan_sequences(zero_alpha_all_detoeknized_samples, 'jaspar')
#motif_count_zero_alpha_sum = motif_count_zero_alpha['motif'].value_counts()
#motif_count_cfg = scan_sequences(cfg_all_detoeknized_samples, 'jaspar')
#motif_count_cfg_sum = motif_count_cfg['motif'].value_counts()
#motif_count_cg = scan_sequences(cg_all_detoeknized_samples, 'jaspar')
#motif_count_cg_sum = motif_count_cg['motif'].value_counts()
#motif_count_smc = scan_sequences(smc_all_detoeknized_samples, 'jaspar')
#motif_count_smc_sum = motif_count_smc['motif'].value_counts()
#motif_count_tds = scan_sequences(tds_all_detoeknized_samples, 'jaspar')
#motif_count_tds_sum = motif_count_tds['motif'].value_counts()
motif_count_top = scan_sequences(highexp_seqs_999, motifs=motifs, pthresh=1e-4, rc=True)
motif_count_top_sum = motif_count_top['motif'].value_counts()

In [None]:
motifs_summary = pd.concat(
    [motif_count_top_sum, motif_count_sum, motif_count_old_sum], axis=1) #, join='inner'  # 只保留共同索引
#).dropna()
motifs_summary.columns = ['top_data', 'finetuned', 'pretrained']
motifs_summary.corr(method='spearman')
