# import

In [27]:

import numpy as np
import scanpy as sc
import torch
import matplotlib.pyplot as plt
import time
import os
from typing import Union

from torchtext.vocab import Vocab
from torchtext._torchtext import (
    Vocab as VocabPybind,
)
from scgpt import logger

from scgpt.trainer import train as scgpt_train
from scgpt.trainer import evaluate as scgpt_evaluate
from scgpt.trainer import eval_testdata as scgpt_test
from scgpt.preprocess import Preprocessor
from scgpt.model import TransformerModel
from scgpt.utils import eval_scib_metrics, load_pretrained
from scgpt.loss import (
    masked_mse_loss,
    masked_relative_error,
    criterion_neg_log_bernoulli,
)
import sys
sys.path.append("./")
from scgpt_helper import *

import lamindb as ln

%load_ext autoreload
%autoreload 2

  import pkg_resources
Implementing implicit namespace packages (as specified in PEP 420) is preferred to `pkg_resources.declare_namespace`. See https://setuptools.pypa.io/en/latest/references/keywords.html#keyword-namespace-packages
  declare_namespace(pkg)
Implementing implicit namespace packages (as specified in PEP 420) is preferred to `pkg_resources.declare_namespace`. See https://setuptools.pypa.io/en/latest/references/keywords.html#keyword-namespace-packages
  declare_namespace(pkg)
Implementing implicit namespace packages (as specified in PEP 420) is preferred to `pkg_resources.declare_namespace`. See https://setuptools.pypa.io/en/latest/references/keywords.html#keyword-namespace-packages
  declare_namespace(parent)
Implementing implicit namespace packages (as specified in PEP 420) is preferred to `pkg_resources.declare_namespace`. See https://setuptools.pypa.io/en/latest/references/keywords.html#keyword-namespace-packages
  declare_namespace(pkg)
Implementing implicit namespac

In [4]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## create experiment

In [65]:
#experiment = scGExperiment()
batch_keys = [
    "self_reported_ethnicity_ontology_term_id",
    "assay_ontology_term_id",
]
special_tokens = ["<pad>", "<unk>", "<mask>"]
n_hvg=2000
filter_gene_by_counts = 3
data_is_raw = True
save_path = "../save/"


In [76]:
vocab = "../save/scGPT_human/vocab.json"
model_path = "../save/scGPT_human/best_model.pt"
batch_size = 8
epoch = 5
fast_transformer=True
mask_ratio=0.3

In [2]:
import json

with open("../save/scGPT_human/args.json", 'r') as f:
    config = json.load(f)


## prepare dataset

In [None]:
dataset_name = "retina"

In [61]:
adata = ln.File.filter(uid=ln.File.search("retina").index[0]).one().load()
# adata = ln.File.filter().first().load()
#ln.Dataset.using("laminlabs/cellxgene-census").one()
adata

## setup

AnnData object with n_obs × n_vars = 19694 × 37127
    obs: 'n_genes', 'n_counts', 'percent_mito', 'donor_id', 'assay_ontology_term_id', 'cell_type_ontology_term_id', 'development_stage_ontology_term_id', 'disease_ontology_term_id', 'self_reported_ethnicity_ontology_term_id', 'is_primary_data', 'organism_ontology_term_id', 'sex_ontology_term_id', 'tissue_ontology_term_id', 'author_cell_type', 'suspension_type', 'cell_type', 'assay', 'disease', 'organism', 'sex', 'tissue', 'self_reported_ethnicity', 'development_stage'
    var: 'chromosome', 'featureend', 'featurestart', 'n_cells', 'percent_cells', 'robust', 'highly_variable_features', 'mean', 'var', 'hvf_loess', 'hvf_rank', 'gene_symbols', 'feature_is_filtered', 'feature_name', 'feature_reference', 'feature_biotype'
    uns: 'cell_type_ontology_term_id_colors', 'default_embedding', 'schema_version', 'title'
    obsm: 'X_diffmap', 'X_diffmap_pca', 'X_fitsne', 'X_fle', 'X_pca', 'X_phi', 'X_umap'

In [63]:
adata.obs[batch]

AnnData object with n_obs × n_vars = 19694 × 37127
    obs: 'n_genes', 'n_counts', 'percent_mito', 'donor_id', 'assay_ontology_term_id', 'cell_type_ontology_term_id', 'development_stage_ontology_term_id', 'disease_ontology_term_id', 'self_reported_ethnicity_ontology_term_id', 'is_primary_data', 'organism_ontology_term_id', 'sex_ontology_term_id', 'tissue_ontology_term_id', 'author_cell_type', 'suspension_type', 'cell_type', 'assay', 'disease', 'organism', 'sex', 'tissue', 'self_reported_ethnicity', 'development_stage'
    var: 'chromosome', 'featureend', 'featurestart', 'n_cells', 'percent_cells', 'robust', 'highly_variable_features', 'mean', 'var', 'hvf_loess', 'hvf_rank', 'gene_symbols', 'feature_is_filtered', 'feature_name', 'feature_reference', 'feature_biotype', 'id_in_vocab'
    uns: 'cell_type_ontology_term_id_colors', 'default_embedding', 'schema_version', 'title'
    obsm: 'X_diffmap', 'X_diffmap_pca', 'X_fitsne', 'X_fle', 'X_pca', 'X_phi', 'X_umap'

In [None]:
save_dir = setup(dataset_name, save_path, config)

if type(vocab) is str:
    vocab = GeneVocab.from_file(vocab)
vocab.set_default_index(vocab["<pad>"])

dataset = load_dataset(adata, vocab)
dataset.obs["batch_id"] = dataset.obs[batch_keys].apply("_".join, axis=1)

In [None]:
# experiment.init_datamodule(dataset=, vocab=)

In [67]:
preprocessor = Preprocessor(
    use_key="X",  # the key in adata.layers to use as raw data
    filter_gene_by_counts=filter_gene_by_counts,  # step 1
    filter_cell_by_counts=False,  # step 2
    normalize_total=1e4,  # 3. whether to normalize the raw data and to what sum
    result_normed_key="X_normed",  # the key in adata.layers to store the normalized data
    log1p=data_is_raw,  # 4. whether to log1p the normalized data
    result_log1p_key="X_log1p",
    subset_hvg=n_hvg,  # 5. whether to subset the raw data to highly variable genes
    hvg_flavor="seurat_v3" if data_is_raw else "cell_ranger",
    binning=config['n_bins'],  # 6. whether to bin the raw data and to what number of bins
    result_binned_key="X_binned",  # the key in adata.layers to store the binned data
)

preprocessor(dataset)

scGPT - INFO - Filtering genes by counts ...
scGPT - INFO - Normalizing total counts ...
scGPT - INFO - Log1p transforming ...
scGPT - INFO - Subsetting highly variable genes ...




scGPT - INFO - Binning data ...


In [77]:
data_loader, valid_loader = prepare_dataset(
    dataset,
    vocab,
    batch_size,
    epoch=epoch,
    n_hvg=n_hvg,
    test_size=0.2,
    mask_ratio=mask_ratio
)

TypeError: string indices must be integers

In [None]:
model = TransformerModel(
    len(vocab),  # n_tokens
    # TODO:
    config['embsize'],
    config['nhead'],
    config['d_hid'],
    config['nlayers'],
    vocab=vocab,
    dropout=config['dropout'],
    pad_token=config['pad_token'],
    pad_value=config['pad_value'],
    do_mvc=,
    do_dab=True,
    use_batch_labels=True,
    num_batch_labels=len(set(dataset.obs["batch_id"])),
    domain_spec_batchnorm=config['DSBN'],
    n_input_bins=config['n_input_bins'],
    ecs_threshold=config['ecs_thres'],
    explicit_zero_prob=config['explicit_zero_prob'],
    use_fast_transformer=fast_transformer,
    pre_norm=config['pre_norm'],
)
if model_path is not None:
    load_pretrained(model, torch.load(model_path), verbose=False)
    # model_config['file'] = model_dir / "args.json"
    # model_file = model_dir / "best_model.pt"
model.to(device)
wandb.watch(model)

In [None]:
# TODO: init wandb
criterion = masked_mse_loss
criterion_dab = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=config['lr'],
    eps=1e-4 if config['amp'] else 1e-8,
)
scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer, 1, gamma=config['schedule_ratio']
)
scaler = torch.cuda.amp.GradScaler(enabled=config['amp'])

In [None]:
batch_ids = 

## fine tune and save

In [None]:
run = wandb.init(
    config=config,
    project="scGPT",
    reinit=True,
    settings=wandb.Settings(start_method="fork"),
)

In [None]:
fine_tune(model, )

In [None]:

torch.save(
    {
        "epoch": epoch,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "vocab": vocab,
        "config": config,
    },
    save_dir+"/best_model.pt",
)

wandb.use_artifact(
    save_dir + "/best_model.pt", type="model"
)

wandb.finish()
gc.collect()

## look at what we have

In [None]:
scgpt_test(
    model, adata_test, gene_ids, vocab, config
)