# Get cell batch embeddings

## Imports

In [4]:
import scanpy as sc
import logging
import sys
import os

In [11]:
import json
import os
from pathlib import Path
import numpy as np
import scanpy as sc
import torch
from torch.utils.data import DataLoader, SequentialSampler
from tqdm import tqdm

In [5]:
os.chdir('/Users/cecileherbermann/Library/CloudStorage/OneDrive-Persoonlijk/Documenten/0 Bioinformatics & BioComplexity/General Research Profile/Scripts')

In [6]:
logger = logging.getLogger("scGPT")
# check if logger has been initialized
if not logger.hasHandlers() or len(logger.handlers) == 0:
    logger.propagate = False
    logger.setLevel(logging.INFO)
    handler = logging.StreamHandler(sys.stdout)
    handler.setLevel(logging.INFO)
    formatter = logging.Formatter(
        "%(name)s - %(levelname)s - %(message)s", datefmt="%H:%M:%S"
    )
    handler.setFormatter(formatter)
    logger.addHandler(handler)

In [7]:
from data_collator import DataCollator
from model import TransformerModel
from gene_tokenizer import GeneVocab
from util import load_pretrained



In [None]:
from scgpt.data_collator import DataCollator
from scgpt.model import TransformerModel
from scgpt.tokenizer.gene_tokenizer import GeneVocab
from scgpt.utils import load_pretrained

## Get cell batch embeddings

#### Define Dataset Class

In [12]:
class Dataset(torch.utils.data.Dataset):
    def __init__(self, count_matrix, gene_ids, batch_ids=None, vocab=None, model_configs=None):
        self.count_matrix = count_matrix
        self.gene_ids = gene_ids
        self.batch_ids = batch_ids
        self.vocab = vocab
        self.model_configs = model_configs

    def __len__(self):
        return len(self.count_matrix)

    def __getitem__(self, idx):
        row = self.count_matrix[idx]
        nonzero_idx = np.nonzero(row)[0]
        values = row[nonzero_idx]
        genes = self.gene_ids[nonzero_idx]
        if self.vocab is not None and self.model_configs is not None:
            genes = np.insert(genes, 0, self.vocab["<cls>"])
            values = np.insert(values, 0, self.model_configs["pad_value"])
        genes = torch.from_numpy(genes).long()
        values = torch.from_numpy(values).float()
        output = {"id": idx, "genes": genes, "expressions": values}
        if self.batch_ids is not None:
            output["batch_labels"] = self.batch_ids[idx]
        return output


#### Settings

In [21]:
cell_embedding_mode: str = "cls"
model=None
vocab=None
max_length=1200
batch_size=64
model_configs=None
gene_ids=None
use_batch_labels=False

#### Load Data and Model

In [13]:
# Loading the file into an AnnData object
adata = sc.read_h5ad('/Users/cecileherbermann/Downloads/CFS_all_days_rawcount.h5ad')

# Create count matrix 
count_matrix = adata.X
count_matrix = count_matrix if isinstance(count_matrix, np.ndarray) else count_matrix.A

# Handling gene_ids
gene_ids = adata.var.index.to_numpy()

# Handling batch labels
use_batch_labels = False

if use_batch_labels:
    batch_ids = np.array(adata.obs["batch_id"].tolist())

  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)


In [31]:
# Load vocabulary
vocab = GeneVocab.from_file('vocab.json')

# Load model configs
with open('args.json', 'r') as f:
    model_configs = json.load(f)

In [37]:
if cell_embedding_mode == "cls":
        dataset = Dataset(
            count_matrix, gene_ids, batch_ids if use_batch_labels else None
        )

In [35]:
collator = DataCollator(
            do_padding=True,
            pad_token_id=vocab[model_configs["pad_token"]],
            pad_value=model_configs["pad_value"],
            do_mlm=False,
            do_binning=True,
            max_length=max_length,
            sampling=True,
            keep_first_n_tokens=1,
        )

In [41]:
import multiprocessing

data_loader = DataLoader(
    dataset,
    batch_size=batch_size,
    sampler=SequentialSampler(dataset),
    collate_fn=collator,  
    drop_last=False,
    num_workers=min(multiprocessing.cpu_count(), batch_size),
    pin_memory=True,
)


In [42]:
# Initialize your Dataset
dataset = Dataset(count_matrix, gene_ids, batch_ids if use_batch_labels else None)

In [49]:
device = next(model.parameters()).device
cell_embeddings = np.zeros((len(dataset), model_configs["embsize"]), dtype=np.float32)

AttributeError: 'NoneType' object has no attribute 'parameters'

In [59]:
cell_embeddings = np.zeros((len(dataset), model_configs["embsize"]), dtype=np.float32)
with torch.no_grad(), torch.cuda.amp.autocast(enabled=True):
    count = 0
    for data_dict in tqdm(data_loader, desc="Embedding cells"):
        input_gene_ids = data_dict["gene"]
        src_key_padding_mask = input_gene_ids.eq(
        vocab[model_configs["pad_token"]]
        )
        embeddings = model._encode(
            input_gene_ids,
            data_dict["expr"],
            src_key_padding_mask=src_key_padding_mask,
            batch_labels=data_dict["batch_labels"]
            if use_batch_labels
            else None,
        )

        embeddings = embeddings[:, 0, :]  # get the <cls> position embedding
        embeddings = embeddings.cpu().numpy()
        cell_embeddings[count : count + len(embeddings)] = embeddings
        count += len(embeddings)
cell_embeddings = cell_embeddings / np.linalg.norm(
    cell_embeddings, axis=1, keepdims=True
)

Embedding cells:   0%|          | 0/619 [00:00<?, ?it/s]Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/Users/cecileherbermann/opt/anaconda3/lib/python3.9/multiprocessing/spawn.py", line 116, in spawn_main
    exitcode = _main(fd, parent_sentinel)
  File "/Users/cecileherbermann/opt/anaconda3/lib/python3.9/multiprocessing/spawn.py", line 126, in _main
    self = reduction.pickle.load(from_parent)
AttributeError: Can't get attribute 'Dataset' on <module '__main__' (built-in)>
Embedding cells:   0%|          | 0/619 [04:19<?, ?it/s]


KeyboardInterrupt: 