# General Imports

In [1]:
# Enable IPython's autoreload extension in mode 2, so imported modules
# (e.g. the local utils/modeling/data packages) are reloaded before each cell runs.
%load_ext autoreload
%autoreload 2

In [2]:
import os

import numpy as np
import pandas as pd
import scipy
import torch
import tqdm
from accelerate import Accelerator
from peft import LoraConfig, get_peft_model

from torch.utils.data import DataLoader

from polya_project.data import GenomeIntervalDataset

from utils.utils import fix_rev_comp_multiome, undo_squashed_scale, get_pseudobulk_count_pred,get_gene_slice_and_strand
from modeling.scborzoi import ScBorzoi
from data.scdata import onTheFlyDataset, onTheFlyPseudobulkDataset

In [None]:
# Root directory that the data, checkpoint and output paths below are joined onto.
data_path = 'tmp'

# scDog

### Functions for Pearson correlation

In [3]:
def stack_and_pearson(x,y):
    """Return the Pearson correlation coefficient between two 1-D tensors.

    The two inputs are stacked as the rows of a 2-row matrix, the 2x2
    correlation matrix is computed with ``torch.corrcoef``, and its
    off-diagonal entry is returned as a 0-dim tensor.
    """
    paired_rows = torch.stack((x, y))
    corr_matrix = torch.corrcoef(paired_rows)
    return corr_matrix[0, 1]


# Vectorised variant: maps stack_and_pearson over the leading dimension of both
# inputs, processing one pair at a time (chunk_size=1).
batched_pearson = torch.vmap(stack_and_pearson, chunk_size=1)

def predict(model, seqs, seqs_rev_comp, conv_weights, conv_biases, bins_to_predict = None):
    """Average the model's predictions for a sequence and its reverse complement.

    Both passes run under ``torch.no_grad()`` and CUDA autocast through
    ``model.forward_sequence_w_convs`` with the same ``conv_weights`` and
    ``conv_biases``. Only a batch size of 1 is supported (asserted).

    Args:
        model: object exposing ``forward_sequence_w_convs``.
        seqs: forward-strand input batch; ``seqs.shape[0]`` must be 1.
        seqs_rev_comp: reverse-complement input batch.
        conv_weights, conv_biases: forwarded unchanged to the model.
        bins_to_predict: optional bin selection; when given, the
            reverse-complement pass is asked for ``6143 - bins_to_predict``.

    Returns:
        The element-wise mean of the forward output and the re-oriented
        reverse-complement output.
    """
    assert seqs.shape[0] == 1

    with torch.no_grad(), torch.autocast("cuda"):
        outputs = model.forward_sequence_w_convs(
            seqs, conv_weights, conv_biases, bins_to_predict = bins_to_predict
        )
        # Mirror the requested bins for the reverse-complement pass
        # (6143 is presumably the last bin index of the model output — TODO confirm).
        rev_comp_bins = None if bins_to_predict is None else (6143 - bins_to_predict)
        outputs_rev_comp = model.forward_sequence_w_convs(
            seqs_rev_comp, conv_weights, conv_biases, bins_to_predict = rev_comp_bins
        )

    # Flip the reverse-complement output along dims 1 and -3, then let the
    # project helper finish the re-orientation (presumably a strand/track swap —
    # verify against fix_rev_comp_multiome).
    mirrored = torch.flip(outputs_rev_comp, (1, -3))
    rev_comp_aligned = fix_rev_comp_multiome(mirrored)
    return (outputs + rev_comp_aligned) / 2

### Load the model

In [17]:
# Accelerate handle; used below via accelerator.prepare(...) on the model and the validation dataloader.
accelerator = Accelerator(step_scheduler_with_optimizer = False)



In [18]:
import pybigtools

# Build the scBorzoi model. NOTE(review): these hyperparameters presumably have
# to match the ones the checkpoint below was trained with — confirm.
csb = ScBorzoi(
    cell_emb_dim = 14,
    n_tracks = 3,
    embedding_dim = 1920,
    return_center_bins_only = True,
    disable_cache = False,
    use_transform_borzoi_emb = True,
)

# Wrap the model with LoRA adapters on the modules matched by this regex, so the
# parameter names line up with the PEFT checkpoint loaded below.
config = LoraConfig(
    target_modules=r"(?!separable\d+).*conv_layer|.*to_q|.*to_v|transformer\.\d+\.1\.fn\.1|transformer\.\d+\.1\.fn\.4",
)
csb = get_peft_model(csb, config)

checkpoint_path = os.path.join(data_path, 'borzoi_saved_models/csb_epoch_20_scDog-neurips-PMseq-4nodrop-softclip5-64cell-normalizeATAC-fixedemb-noneighbors-rightembeddingrightsplit-longer/pytorch_model.bin')
# map_location="cpu" deserialises the tensors on the host instead of on whatever
# device they were saved from; load_state_dict then copies them into the model.
# This avoids a transient second copy on the GPU and works on machines whose
# device layout differs from the training machine.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints
# (consider weights_only=True if the installed torch version supports it).
csb.load_state_dict(torch.load(checkpoint_path, map_location = "cpu"))


<All keys matched successfully>

In [19]:
# Soft-clip value forwarded to onTheFlyDataset and get_pseudobulk_count_pred below.
clip_soft = 5

In [20]:
# Hand the model to Accelerate. The LoRA merge (merge_and_unload) is left
# commented out, so the adapters stay as separate modules.
#csb = csb.merge_and_unload()
csb = accelerator.prepare(csb)

In [21]:
# Input window length passed to GenomeIntervalDataset below (524288 = 2**19).
context_length = 524288

### Specify genome

In [22]:
# Genome resources, all resolved under data_path.
# Only fasta_file is used by the cells visible in this notebook section;
# gtf_file and bed_file are not referenced below — NOTE(review): confirm they are still needed.
gtf_file = os.path.join(data_path, "gencode.v32.annotation.sorted.gtf.gz")
fasta_file = os.path.join(data_path, 'genome_human.fa')
bed_file = os.path.join(data_path, 'sequences.bed')

### Load neighbors and embedding

In [23]:
# Directory handed to onTheFlyPseudobulkDataset as base_path (location of the pseudobulk target files).
base_path = os.path.join(data_path, 'snapatac', 'pseudobulks_fixed/')

In [24]:
# NOTE(review): `sample` is not referenced by the cells visible here — confirm it is still needed.
sample = 'merged'
# Sparse neighbour matrix and per-cell embedding table; both are passed to onTheFlyDataset below,
# and `embedding` is also indexed per cell type to build the cell-embedding tensors.
neighbors = scipy.sparse.load_npz(os.path.join(data_path, 'borzoi_training_data_fixed', 'no_neighbors.npz'))
embedding = pd.read_parquet(os.path.join(data_path, 'borzoi_training_data_fixed',  'embedding_no_val_genes_new.pq'))

In [25]:
# Cell-type table: one row per cell type, with the member cell indices in 'cellindex'.
celltype_table_path = os.path.join(data_path,  'borzoi_training_data_fixed/celltype_fixed.pq')
cell_type_index = pd.read_parquet(celltype_table_path)
# Number of cells belonging to each cell type.
cell_type_index['size'] = cell_type_index['cellindex'].apply(len)

In [26]:
# Normalise the cell-type labels: spaces become underscores, and the two labels
# containing '/' are renamed (whole-value match) to a '+' spelling. Presumably this
# makes the labels usable as file names for the pseudobulk targets — TODO confirm.
celltype_renames = {"G/M_prog": "G+M_prog", "MK/E_prog": "MK+E_prog"}
cell_type_index['celltype'] = (
    cell_type_index['celltype']
    .str.replace(' ', '_')
    .replace(celltype_renames)
)
# Alphabetical order by label; later cells iterate the table in this order.
cell_type_index = cell_type_index.sort_values('celltype')

## Eval on Val

### Sequence dataloader 

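This dataloader yields only the input sequences (no targets), unshuffled, so that its index stays aligned with the target dataloader defined further below.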

In [27]:
import pickle

# Pre-built transcriptome object, later passed to get_gene_slice_and_strand.
# NOTE(review): pickle.load executes arbitrary code from the file — only use a trusted file here.
transcriptome_path = os.path.join(data_path, 'gencode.v32.annotation.gtf.transcriptome')
with open(transcriptome_path, 'rb') as transcriptome_file:
    transcriptome = pickle.load(transcriptome_file)

In [28]:
import polars as pl

filter_val = lambda df: df.filter(True)#


def _keep_non_negative_column_2(df):
    """Keep only the rows whose 'column_2' value is >= 0."""
    return df.filter((pl.col('column_2') >= 0))


# Interval dataset over the test-gene sequence table; augmentations are disabled
# (zero shift, no reverse-complement) so every interval is returned as-is.
test_sequences_file = os.path.join(data_path,'borzoi_training_data_fixed', 'test_gene_sequences.csv')
val_ds = GenomeIntervalDataset(
    bed_file = test_sequences_file,
    fasta_file = fasta_file,
    filter_df_fn = _keep_non_negative_column_2,
    return_seq_indices = False,
    shift_augs = (0,0),
    rc_aug = False,
    return_augs = True,
    context_length = context_length,
    chr_bed_to_fasta_map = {}
)
len(val_ds)

2550

In [29]:
# Sequence-only dataset (get_targets=False) over the validation intervals;
# cells are not sampled randomly (random_cells=False, cells_to_run=None).
val_dataset = onTheFlyDataset(
    None,
    None,
    neighbors,
    embedding,
    val_ds,
    get_targets = False,
    random_cells = False,
    cells_to_run = None,
    clip_soft = clip_soft,
)
# batch_size=1 because predict() asserts a batch size of 1; shuffle=False keeps the
# loader index aligned with the row index of the interval table.
val_loader = DataLoader(val_dataset, batch_size=1, shuffle = False, num_workers = 1)
# The model was already passed through accelerator.prepare in an earlier cell;
# preparing it a second time is redundant, so only the loader is prepared here.
val_loader = accelerator.prepare(val_loader)
clip_soft

5

### Target dataloader on bigwig pseudobulk files

In [30]:
# Target dataset built from the (sorted) cell-type labels over the same intervals as val_ds,
# reading from base_path. Its `genome_ds.df` table is indexed by row in the evaluation loops below.
val_dataset_targets = onTheFlyPseudobulkDataset(
    cell_types = cell_type_index['celltype'].values,
    ds = val_ds, 
    base_path = base_path,
)

### Get cell conv_weights for all cells

In [31]:
csb.eval()

# One tensor per cell type (in cell_type_index row order): the embeddings of all
# cells of that type stacked row-wise, with a leading batch dimension of 1.
cell_indices  = []
size_factors_per_ct = []  # not filled in this cell; kept in case later cells use it
for _, row in cell_type_index.iterrows():
    cell_type_embeddings = np.vstack(
        embedding.iloc[row['cellindex']]['embedding'].values # embeddings of all cells of the cell type
    )
    cell_indices.append(torch.from_numpy(cell_type_embeddings).unsqueeze(0))

# Conv weights and biases for every cell type, in the same order as cell_indices.
# This is inference only, so run it under no_grad: otherwise each result keeps an
# autograd graph alive for as long as the list exists.
cell_emb_conv_weights_and_biases = []
with torch.no_grad():
    for cell_emb_idx in cell_indices:
        cell_emb_idx = cell_emb_idx.cuda()
        conv_weights, conv_biases = csb.forward_cell_embs_only(cell_emb_idx)
        cell_emb_conv_weights_and_biases.append((conv_weights, conv_biases))


### Get counts over exons

In [18]:
#num_neighbors = 100
num_neighbors = 1

# clip_soft used to undo the squashing of the pseudobulk targets. It differs from the
# model's clip_soft above — NOTE(review): presumably the value the targets were
# squashed with; confirm against the target pipeline.
TARGET_CLIP_SOFT = 384

# Target channels are interleaved by strand: even channels are read for '+', odd for '-'.
strand_channels = {'+': slice(None, None, 2), '-': slice(1, None, 2)}

all_outputs, all_targets = [], []
val_dataset_target_loader = iter(DataLoader(val_dataset_targets, batch_size=1, shuffle = False, num_workers = 4))
gene_table = val_dataset_targets.genome_ds.df

# iterate over all val gene sequences
for i,x in tqdm.tqdm(enumerate(val_loader), disable = False, total=len(val_dataset)):
    gene_slice, strand = get_gene_slice_and_strand(transcriptome, gene_table[i, 'column_4'], gene_table[i, 'column_2'], span = False)
    # Advance the target loader on every iteration, including skipped genes,
    # so it stays aligned with the sequence loader.
    target_batch = next(val_dataset_target_loader)
    if len(gene_slice) == 0:
        continue
    # Fail loudly rather than append an output without a matching target,
    # which would silently misalign all_outputs and all_targets.
    if strand not in strand_channels:
        raise ValueError(f"Unexpected strand {strand!r} for gene {gene_table[i, 'column_4']!r} (row {i})")

    seqs = x[0].cuda().permute(0,2,1)
    stacked_outputs = get_pseudobulk_count_pred(
        csb = csb,
        seqs = seqs,
        cell_emb_conv_weights_and_biases = cell_emb_conv_weights_and_biases,
        gene_slice = gene_slice,
        strand = strand,
        model_type = "multiome",
        predict = predict,
        clip_soft = clip_soft,
        num_neighbors = num_neighbors,
    )
    all_outputs.append(stacked_outputs)

    # Sum the un-squashed target signal over the gene's bins for the gene's strand.
    targets = target_batch[2].float().cuda()
    strand_targets = targets[0, gene_slice, strand_channels[strand]]
    all_targets.append(
        undo_squashed_scale(strand_targets, clip_soft=TARGET_CLIP_SOFT).sum(axis=0).detach().clone().cpu().squeeze()
    )


# Stack the per-gene rows and convert both to numpy arrays.
all_outputs, all_targets = torch.vstack(all_outputs).clone().numpy(force=True),torch.vstack(all_targets).clone().numpy(force=True)

100%|██████████| 2550/2550 [19:44<00:00,  2.15it/s]


In [32]:
num_neighbors = 1

# Collect the names of the genes that the evaluation loop keeps (non-empty gene_slice),
# in the same order. Only the interval table is needed for this, so iterate over its
# row indices directly instead of pulling every sequence through the DataLoader.
gene_table = val_dataset_targets.genome_ds.df
gene_names = []
for i in tqdm.tqdm(range(len(val_dataset)), disable = False, total=len(val_dataset)):
    gene_slice, strand = get_gene_slice_and_strand(transcriptome, gene_table[i, 'column_4'], gene_table[i, 'column_2'], span = False)
    if len(gene_slice) == 0:
        continue
    gene_names.append(gene_table[i, 'column_4'])

100%|██████████| 2550/2550 [00:42<00:00, 59.76it/s]


In [None]:
# Persist predicted and target counts.
# NOTE(review): all_outputs/all_targets are numpy arrays written with torch.save, so these
# files are torch/pickle files despite the .pq extension — read them back with torch.load, not a parquet reader.
torch.save(all_outputs, os.path.join(data_path, "eval_fixed/count_predicted_test_no_neighbor.pq"))
torch.save(all_targets, os.path.join(data_path, "eval_fixed/count_target_test_no_neighbor.pq"))

In [None]:
# NOTE(review): this repeats the all_targets save from the previous cell verbatim (same object, same path); it looks redundant.
torch.save(all_targets, os.path.join(data_path, "eval_fixed/count_target_test_no_neighbor.pq"))

In [36]:
# Save the kept gene names as a one-column table; genes with an empty gene_slice were skipped
# when building gene_names, the same rule the count loop uses, so rows should align with
# all_outputs/all_targets — verify lengths match before relying on it.
pd.DataFrame(gene_names).to_parquet(os.path.join(data_path,"eval_fixed/gene_names.pq"))