# General Imports

In [2]:
# Reload imported modules automatically before each cell executes, so edits to
# local packages are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [3]:
import os

import numpy as np
import pandas as pd
import scipy
import torch
import tqdm
from accelerate import Accelerator
import polars as pl
import scanpy as sc

from torch.utils.data import DataLoader

from enformer_pytorch.data import GenomeIntervalDataset

from scooby.modeling import Scooby
from scooby.data import onTheFlyDataset, onTheFlyPseudobulkDataset
from scooby.utils.utils import fix_rev_comp_multiome, undo_squashed_scale, get_pseudobulk_count_pred,get_gene_slice_and_strand, get_cell_count_pred
from scooby.utils.transcriptome import Transcriptome

In [4]:
# Root directory of the submission data. Set the SCOOBY_DATA_PATH environment
# variable to run the notebook on another machine; the default is the original
# cluster location.
data_path = os.environ.get(
    'SCOOBY_DATA_PATH',
    '/data/ceph/hdd/project/node_08/QNA/scborzoi/submission_data',
)

# scDog

### Functions for Pearson correlation and reverse-complement-averaged prediction

In [5]:
def stack_and_pearson(x, y):
    """Return the Pearson correlation coefficient between two equally shaped 1-D tensors."""
    paired = torch.stack([x, y])
    corr_matrix = torch.corrcoef(paired)
    # Off-diagonal entry of the 2x2 correlation matrix is corr(x, y).
    return corr_matrix[0, 1]


# Vectorised variant: maps stack_and_pearson over the leading dimension of both
# inputs, processing one pair at a time (chunk_size=1).
batched_pearson = torch.vmap(stack_and_pearson, chunk_size=1)

def predict(model, seqs, seqs_rev_comp, conv_weights, conv_biases, bins_to_predict = None):
    """Average the model output for a sequence with the output for its reverse complement.

    Args:
        model: object exposing ``forward_sequence_w_convs``.
        seqs: forward-strand input batch; its first dimension must be 1.
        seqs_rev_comp: reverse-complement counterpart of ``seqs``.
        conv_weights: forwarded unchanged to ``forward_sequence_w_convs``.
        conv_biases: forwarded unchanged to ``forward_sequence_w_convs``.
        bins_to_predict: optional bin indices. For the reverse-complement pass
            they are mirrored as ``6143 - bins_to_predict``.

    Returns:
        Element-wise mean of the forward output and the re-oriented
        reverse-complement output.

    Raises:
        AssertionError: if the batch size is not 1.
    """
    assert seqs.shape[0] == 1

    with torch.no_grad(), torch.autocast("cuda"):
        forward_out = model.forward_sequence_w_convs(
            seqs, conv_weights, conv_biases, bins_to_predict = bins_to_predict
        )
        # Mirror the requested bins so both passes cover the same positions.
        rc_bins = None if bins_to_predict is None else (6143 - bins_to_predict)
        rc_out = model.forward_sequence_w_convs(
            seqs_rev_comp, conv_weights, conv_biases, bins_to_predict = rc_bins
        )

    # Flip the reverse-complement output back, then let the helper fix the
    # remaining strand-specific ordering.
    rc_out = fix_rev_comp_multiome(torch.flip(rc_out, (1, -3)))
    return (forward_out + rc_out) / 2

### Load the model

In [6]:
# Accelerator instance; used further down to prepare the model and the validation DataLoader.
accelerator = Accelerator(step_scheduler_with_optimizer = False)

Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


#### Load pretrained scooby weights from a local checkpoint directory

In [7]:
# Load the pretrained Scooby model ("no normoblast" checkpoint).
# NOTE(review): the checkpoint location is hard-coded and not derived from
# `data_path`; it must be adjusted when running on another machine.
csb = Scooby.from_pretrained(
    '/s/project/QNA/borzoi_saved_models/neurips-scooby-no-normoblast',
    cell_emb_dim=14,
    embedding_dim=1920,
    n_tracks=3,
    return_center_bins_only=True,
    disable_cache=False,
    use_transform_borzoi_emb=True,
)

In [8]:
# Soft-clipping parameter handed to onTheFlyDataset and get_pseudobulk_count_pred further down.
clip_soft = 5

In [9]:
# Let accelerate prepare the model.
# NOTE(review): `csb` is passed through accelerator.prepare again, together
# with val_loader, in the sequence-dataloader cell further down.
csb = accelerator.prepare(csb)

In [10]:
# Input sequence length (presumably in base pairs) passed to GenomeIntervalDataset as context_length.
context_length  = 524288

### Specify genome

In [11]:
# Gene annotation and reference genome, both resolved relative to data_path.
gtf_file = os.path.join(data_path, "gencode.v32.annotation.sorted.gtf.gz")
fasta_file = os.path.join(data_path, "scooby_training_data", "genome_human.fa")
# NOTE(review): bed_file is not referenced again in the visible part of this
# notebook; the test dataset below reads test_gene_hv_sequences.csv instead.
bed_file = os.path.join(data_path, "scooby_training_data", "sequences.bed")

### Load neighbors and embedding

In [12]:
# Pseudobulk directory; handed to onTheFlyPseudobulkDataset as base_path further down.
base_path = os.path.join(data_path, 'scooby_training_data', 'pseudobulks')

In [13]:
base_path

'/data/ceph/hdd/project/node_08/QNA/scborzoi/submission_data/scooby_training_data/pseudobulks'

In [15]:
# Sparse matrix loaded from no_neighbors.npz; passed to onTheFlyDataset below.
neighbors = scipy.sparse.load_npz(os.path.join(data_path, 'scooby_training_data', 'no_neighbors.npz'))
# Per-cell embedding table; its 'embedding' column is read positionally (via .iloc) when the
# per-cell-type embeddings are assembled.
embedding = pd.read_parquet(os.path.join(data_path, 'scooby_training_data',  'embedding_no_val_genes_no_normoblast.pq'))

In [16]:
# Cell-type table: each row holds a cell type and the positional indices of its cells.
cell_type_index = pd.read_parquet(
    os.path.join(data_path, 'scooby_training_data', 'celltype_fixed.pq')
)
# Number of cells per cell type.
cell_type_index['size'] = cell_type_index['cellindex'].apply(len)

In [17]:
# Normalise the labels: spaces become underscores, and the two progenitor
# labels containing '/' are rewritten with '+' (presumably so the labels are
# usable as file names - confirm).
cell_type_index['celltype'] = (
    cell_type_index['celltype']
    .str.replace(' ', '_')
    .replace({"G/M_prog": "G+M_prog", "MK/E_prog": "MK+E_prog"})
)
# Alphabetical order by label defines the cell-type order used downstream.
cell_type_index = cell_type_index.sort_values('celltype')

In [18]:
# Renumber the rows 0..n-1 in sorted order; since drop=True is not set, the
# previous index is kept as an 'index' column.
cell_type_index = cell_type_index.reset_index()

In [19]:
cell_type_index

Unnamed: 0,index,celltype,cellindex,size
0,4,B1_B,"[5, 9, 20, 32, 112, 128, 151, 265, 294, 360, 3...",1747
1,1,CD14+_Mono,"[1, 11, 13, 19, 30, 38, 49, 50, 51, 58, 62, 64...",10338
2,3,CD16+_Mono,"[4, 17, 94, 315, 329, 370, 698, 709, 928, 936,...",1762
3,7,CD4+_T_activated,"[8, 24, 28, 40, 45, 48, 55, 63, 68, 75, 76, 82...",5157
4,6,CD4+_T_naive,"[7, 44, 47, 54, 56, 59, 88, 116, 123, 132, 140...",4170
5,2,CD8+_T_activated,"[2, 3, 33, 35, 36, 41, 46, 67, 77, 84, 95, 103...",10846
6,20,CD8+_T_naive,"[11622, 11660, 11682, 11753, 11795, 11806, 119...",984
7,5,Early_Lymphoid,"[6, 298, 434, 677, 751, 757, 851, 987, 1090, 1...",1410
8,8,Erythroblast,"[10, 12, 14, 26, 31, 34, 43, 53, 69, 71, 79, 8...",4544
9,13,G+M_prog,"[91, 193, 199, 228, 288, 307, 389, 602, 621, 6...",1025


## Eval on Test

### Sequence dataloader 

This loader only provides the input sequences, in the same order as the target loader defined below.

In [20]:
# Gene annotation built from the GTF file; later passed to get_gene_slice_and_strand.
transcriptome = Transcriptome(gtf_file)

In [22]:
# Pass-through filter; not referenced anywhere else in the visible notebook.
filter_val = lambda df: df.filter(True)

# Interval dataset over the test-gene sequences. Augmentations are switched
# off (no shift, no reverse-complement) so every interval is read as given.
val_ds = GenomeIntervalDataset(
    bed_file=os.path.join(data_path, 'scooby_training_data', 'test_gene_hv_sequences.csv'),
    fasta_file=fasta_file,
    # keep only rows whose column_2 value is non-negative
    filter_df_fn=lambda df: df.filter(pl.col('column_2') >= 0),
    return_seq_indices=False,
    shift_augs=(0, 0),
    rc_aug=False,
    return_augs=True,
    context_length=context_length,
    chr_bed_to_fasta_map={},
)
len(val_ds)

417

In [23]:
# Dataset over the test sequences with target loading disabled (get_targets=False);
# the evaluation loop reads x[0] of each batch as the sequence tensor.
val_dataset = onTheFlyDataset(
    None,
    None,
    neighbors,
    embedding,
    val_ds,
    get_targets= False,
    random_cells = False,
    cells_to_run = None, 
    clip_soft = clip_soft,
    )
# batch_size=1 matches the batch-size assertion in predict(); shuffle=False keeps
# the batches in dataset order, which the evaluation loop relies on when it
# looks genes up by position.
val_loader = DataLoader(val_dataset, batch_size=1, shuffle = False, num_workers = 1)
# NOTE(review): csb was already passed through accelerator.prepare in an earlier
# cell; preparing it again here looks redundant - confirm.
csb, val_loader = accelerator.prepare(csb, val_loader)
clip_soft

5

### Target dataloader on bigwig pseudobulk files

In [24]:
# Target dataset over the same intervals (val_ds), built for the cell types of
# cell_type_index in their sorted order and reading from base_path.
val_dataset_targets = onTheFlyPseudobulkDataset(
    cell_types = cell_type_index['celltype'].values,
    ds = val_ds, 
    base_path = base_path,
)

### Get cell conv_weights for all cells

In [25]:
csb.eval()

# One tensor per cell type, in cell_type_index row order: the embeddings of all
# cells of that type stacked row-wise, with a leading dimension of size 1
# added (assumes each stored embedding is a 1-D vector).
cell_indices = [
    torch.from_numpy(
        np.vstack(embedding.iloc[cell_ids]['embedding'].values)
    ).unsqueeze(0)
    for cell_ids in cell_type_index['cellindex']
]
size_factors_per_ct = []  # never filled here; kept so the name stays defined

# Conv weights and biases per cell type, in the same order as cell_indices.
# This is pure inference, so run it under no_grad: otherwise every returned
# tensor keeps its autograd graph alive on the GPU.
cell_emb_conv_weights_and_biases = []
with torch.no_grad():
    for cell_emb_idx in cell_indices:
        cell_emb_idx = cell_emb_idx.cuda()
        conv_weights, conv_biases = csb.forward_cell_embs_only(cell_emb_idx)
        cell_emb_conv_weights_and_biases.append((conv_weights, conv_biases))


### Get counts over exons

In [26]:
# Number of neighbours forwarded to get_pseudobulk_count_pred (100 was used as an alternative).
num_neighbors = 1

all_outputs, all_targets = [], []
val_dataset_target_loader = iter(DataLoader(val_dataset_targets, batch_size=1, shuffle = False, num_workers = 4))

# '+' genes are read from the even target channels, '-' genes from the odd ones.
strand_to_channel_offset = {'+': 0, '-': 1}

# iterate over all val gene sequences
for i, x in tqdm.tqdm(enumerate(val_loader), disable = False, total=len(val_dataset)):
    gene_slice, strand = get_gene_slice_and_strand(
        transcriptome,
        val_dataset_targets.genome_ds.df[i, 'column_4'],
        val_dataset_targets.genome_ds.df[i, 'column_2'],
        span = False,
    )
    # Always advance the target iterator, even for genes skipped below, so
    # sequences and targets stay aligned.
    targets = next(val_dataset_target_loader)[2]
    if len(gene_slice) == 0:
        continue
    if strand not in strand_to_channel_offset:
        # Previously such a gene appended a prediction without a target,
        # silently shifting all following rows against each other.
        raise ValueError(f"Unexpected strand {strand!r} for sequence {i}; expected '+' or '-'")

    seqs = x[0].cuda().permute(0,2,1)

    stacked_outputs = get_pseudobulk_count_pred(
        csb = csb,
        seqs = seqs,
        cell_emb_conv_weights_and_biases = cell_emb_conv_weights_and_biases,
        gene_slice = gene_slice,
        strand = strand,
        model_type = "multiome",
        predict = predict,
        clip_soft = clip_soft,
        num_neighbors = num_neighbors,
    )
    all_outputs.append(stacked_outputs)

    # Sum the un-squashed target signal over the gene's bins on the matching strand.
    # NOTE(review): targets are un-squashed with clip_soft=384 while predictions
    # use the notebook-level clip_soft - confirm this is intended.
    targets = targets.float().cuda()
    strand_targets = targets[0, gene_slice, strand_to_channel_offset[strand]::2]
    all_targets.append(
        undo_squashed_scale(strand_targets, clip_soft=384).sum(axis=0).detach().clone().cpu().squeeze()
    )


all_outputs, all_targets = torch.vstack(all_outputs).clone().numpy(force=True),torch.vstack(all_targets).clone().numpy(force=True)

100%|██████████| 417/417 [02:48<00:00,  2.47it/s]


In [36]:
# NOTE(review): all_outputs is a NumPy array at this point and torch.save pickles it,
# so the file is not Parquet despite the .pq suffix; read it back with torch.load.
torch.save(all_outputs, os.path.join(data_path, "count_eval", "count_predicted_test_no_normoblast.pq"))

In [38]:
os.path.join(data_path, "count_eval", "count_predicted_test_no_normoblast.pq")

'/data/ceph/hdd/project/node_08/QNA/scborzoi/submission_data/count_eval/count_predicted_test_no_normoblast.pq'

In [37]:
# NOTE(review): all_targets is a NumPy array at this point and torch.save pickles it,
# so the file is not Parquet despite the .pq suffix; read it back with torch.load.
torch.save(all_targets, os.path.join(data_path, "count_eval", "count_target_test_no_normoblast.pq"))