In [1]:
import os
import scanpy as sc
import muon as mu
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import anndata as ad

import torch
import torch.nn as nn
import torch.nn.functional as F

from torch_geometric.data import Data
from torch_geometric.transforms import RandomNodeSplit, RandomLinkSplit
from torch_geometric.loader import DataLoader, NeighborLoader

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from train import Trainer
from dataset import GeneVocab

In [3]:
# Plotting settings

colors = ["#3B7EA1", "#FDB515", "#D9661F", "#859438", "#EE1F60", "#00A598"]
sns.set(context="notebook", font_scale=1.3, style="ticks")
sns.set_palette(sns.color_palette(colors))
plt.rcParams['svg.fonttype'] = 'none'
plt.rcParams['pdf.fonttype'] = 42
plt.rcParams['savefig.transparent'] = True
sc.settings._vector_friendly = True
DPI = 300
# GPU settings

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Data input

## CITE-seq SNL data

In [4]:

# data_path
save_path = "/home/wuxinchao/data/st_cite_data/totalVI_reproducibility/data/"
mdata = mu.read_h5mu(save_path + "/SNL_111.h5mu")
rna = mdata.mod["rna"]
protein = mdata.mod["prot"]

In [5]:
rna, protein

(AnnData object with n_obs × n_vars = 9264 × 4005
     obs: 'n_protein_counts', 'n_proteins', 'seurat_hash_id', 'batch_indices', 'hash_id', 'n_genes', 'percent_mito', 'leiden_subclusters', 'cell_types', 'leiden_totalVI', 'combined_cell_types'
     var: 'gene_ids', 'feature_types', 'highly_variable', 'highly_variable_mean_variance', 'encode', 'hvg_encode'
     uns: 'cell_types_colors', 'combined_cell_types_colors', 'leiden', 'leiden_totalVI_colors', 'neighbors', 'protein_names', 'umap', 'version'
     obsm: 'X_totalVI', 'X_umap', 'protein_expression'
     layers: 'counts', 'denoised_rna'
     obsp: 'connectivities', 'distances',
 AnnData object with n_obs × n_vars = 9264 × 110
     obs: 'n_protein_counts', 'n_proteins', 'seurat_hash_id', 'batch_indices', 'hash_id', 'n_genes', 'percent_mito', 'leiden_subclusters', 'cell_types', '_scvi_batch'
     var: 'clean_names'
     layers: 'denoised_protein', 'protein_foreground_prob')

# Preprocess

In [6]:
# RNA
sc.pp.normalize_total(rna)
sc.pp.log1p(rna)
rna.obs_names_make_unique()

In [7]:
from muon import prot as pt

pt.pp.clr(protein)
mdata.update()

# Pre-train model

In [8]:
concat_data = np.concatenate(
    [rna.X, protein.X], axis=1
)

adj_mtx = rna.obsp['connectivities'].toarray()
edge_index = adj_mtx.nonzero()
edge_index = torch.tensor(edge_index, 
        dtype=torch.long).contiguous()

scCITEseq_data = Data(x=torch.tensor(
    concat_data, dtype=torch.float), 
    edge_index=edge_index)

geneVocab = GeneVocab(rna)
proteinVocab = GeneVocab(protein)

  edge_index = torch.tensor(edge_index,


In [9]:
num_splits = 2
num_val = 0.2
num_test = 0.2

tsf = RandomNodeSplit(num_splits=num_splits, 
                      num_val=num_val, 
                      num_test=num_test, 
                      key=None
                      )
training_data = tsf(scCITEseq_data)

In [17]:
# setting for the model
rna_input_dim = rna.shape[1]
prot_input_dim = protein.shape[1]
hidden_dim = 32
embedding_dim = 32
heads = 4
num_blocks = 2
permute = False
preserve_rate = 0.2
alpha = 0.4
beta = 0.4

# setting for the trainer
batch_size = 256
lr = 5e-5
epochs = 50
mask_ratio = 0.85
num_splits = 2
device = torch.device(
    'cuda' if torch.cuda.is_available() else 'cpu')

In [18]:
model_choice = "Graph Cross Attention"

trainer = Trainer(
    training_data,
    model_choice=model_choice, 
    rna_input_dim=rna_input_dim, 
    prot_input_dim=prot_input_dim,
    hidden_dim=hidden_dim,
    embedding_dim=embedding_dim,
    heads=heads,
    num_blocks=num_blocks, 
    batch_size=batch_size,
    lr=lr,
    epochs=epochs,
    mask_ratio=mask_ratio,
    permute=permute,
    preserve_rate=preserve_rate,
    num_splits=num_splits,
    device=device,
    alpha=alpha,
    beta=beta,
    )

train_losses, val_losses = trainer.train()

Epoch 1/50 train_loss: 21.46366 val_loss: 7.09633
Epoch 2/50 train_loss: 20.43760 val_loss: 6.71513
Epoch 3/50 train_loss: 18.93721 val_loss: 6.07844
Epoch 4/50 train_loss: 16.84409 val_loss: 5.32738
Epoch 5/50 train_loss: 14.59727 val_loss: 4.54053
Epoch 6/50 train_loss: 12.57584 val_loss: 3.94231
Epoch 7/50 train_loss: 11.00732 val_loss: 3.52841
Epoch 8/50 train_loss: 9.87183 val_loss: 3.16644
Epoch 9/50 train_loss: 8.99064 val_loss: 2.91790
Epoch 10/50 train_loss: 8.27592 val_loss: 2.69729
Epoch 11/50 train_loss: 7.66485 val_loss: 2.49915
Epoch 12/50 train_loss: 7.16286 val_loss: 2.33356
Epoch 13/50 train_loss: 6.68177 val_loss: 2.18289
Epoch 14/50 train_loss: 6.28247 val_loss: 2.04852
Epoch 15/50 train_loss: 5.94214 val_loss: 1.95897
Epoch 16/50 train_loss: 5.63278 val_loss: 1.87486
Epoch 17/50 train_loss: 5.40493 val_loss: 1.78595
Epoch 18/50 train_loss: 5.19028 val_loss: 1.72692
Epoch 19/50 train_loss: 5.01122 val_loss: 1.67483
Epoch 20/50 train_loss: 4.87185 val_loss: 1.62662
Ep

RuntimeError: Parent directory .save_model does not exist.

In [19]:
# save the best model parameters
save_dict = {
    "model": trainer.best_model.state_dict(),
    "optimizer": trainer.optimizer.state_dict(),
}
torch.save(save_dict, "../save_model/best_model.pt")

## Spatial Data

In [20]:
data_path = "/home/wuxinchao/data/st_cite_data/"
prot_data_path = "B01825A4_protein_filter.csv"
rna_data_path = "B01825A4_rna_raw.csv"

In [21]:
from utils import construct_spatial_adata

sp_mudata = construct_spatial_adata(
    data_path, 
    rna_data=rna_data_path, 
    prot_data=prot_data_path
)
sp_mudata

In [None]:
rna_adata = sp_mudata.mod["rna"]
prot_adata = sp_mudata.mod["prot"]

In [None]:
sc.pp.normalize_total(rna_adata)
sc.pp.log1p(rna_adata)
rna_adata.obs_names_make_unique()
rna_adata.var['mt'] = rna_adata.var_names.str.startswith('MT-')
rna_adata.layers["counts"] = rna_adata.X.copy()
sc.pp.highly_variable_genes(
    rna_adata,
    n_top_genes=2000,
    flavor="seurat_v3",
    layer="counts",
)

sc.pp.scale(rna_adata, max_value=10)
sc.tl.pca(rna_adata, svd_solver="arpack")
# sc.pl.pca_variance_ratio(rna_adata, log=True)
sc.pp.neighbors(rna_adata, n_pcs=50)
sc.tl.umap(rna_adata)



In [None]:
from muon import prot as pt

pt.pp.clr(prot_adata)
sp_mudata.update()

In [12]:
sp_mudata.write_h5mu(data_path + "sp_mudata.h5mu")

In [5]:
# Loading data
sp_mudata = mu.read_h5mu(data_path + "sp_mudata.h5mu")
rna_adata = sp_mudata.mod["rna"]
prot_adata = sp_mudata.mod["prot"]
sp_mudata

In [None]:
geneVocab.update_gene_dict(rna_adata)
proteinVocab.update_gene_dict(prot_adata)

In [7]:
concat_data = np.concatenate(
    [rna_adata.X, prot_adata.X], axis=1
)

concat_spatial_encoding_data = np.concatenate(
    [rna_adata.X, prot_adata.X, sp_mudata.obsm["spatial"]], 
    axis=1
)

adj_mtx = rna_adata.obsp['connectivities'].toarray()
edge_index = adj_mtx.nonzero()
edge_index = torch.tensor(edge_index, 
        dtype=torch.long).contiguous()

scCITEseq_data = Data(x=torch.tensor(
    concat_data, dtype=torch.float), 
    edge_index=edge_index)

spCITEseq_data = Data(x=torch.tensor(
    concat_spatial_encoding_data, dtype=torch.float), 
    edge_index=edge_index)

  edge_index = torch.tensor(edge_index,


In [8]:
num_splits = 2
num_val = 0.2
num_test = 0.2

tsf = RandomNodeSplit(num_splits=num_splits, 
                      num_val=num_val, 
                      num_test=num_test, 
                      key=None
                      )
training_data = tsf(scCITEseq_data)
spatial_training_data = tsf(spCITEseq_data)

In [12]:
# setting for the model
rna_input_dim = rna_adata.shape[1]
prot_input_dim = prot_adata.shape[1]
hidden_dim = 32
embedding_dim = 32
heads = 4
num_blocks = 2
permute = True
preserve_rate = 0.2
alpha = 0.4
beta = 0.4

# setting for the trainer
batch_size = 256
lr = 1e-6
epochs = 100
mask_ratio = 0.85
num_splits = 2
device = torch.device(
    'cuda' if torch.cuda.is_available() else 'cpu')

# Fine-tune with spatial cite-seq

In [None]:
# align features of fine-tuning data to the pre-trained model
geneVocab.align_features(rna_adata)
proteinVocab.align_features(prot_adata)

In [16]:
model_choice = "Spatial Graph Cross Attention"

spatial_trainer = Trainer(
    spatial_training_data, 
    geneVocab,
    model_choice=model_choice, 
    rna_input_dim=rna_input_dim, 
    prot_input_dim=prot_input_dim,
    hidden_dim=hidden_dim,
    embedding_dim=embedding_dim,
    heads=heads,
    num_blocks=num_blocks, 
    batch_size=batch_size,
    lr=lr,
    epochs=epochs,
    mask_ratio=mask_ratio,
    permute=permute,
    preserve_rate=preserve_rate,
    num_splits=num_splits,
    device=device,
    alpha=alpha,
    beta=beta,
    )

spatial_train_losses, spatial_val_losses = spatial_trainer.fine_tune()

Epoch 1/25 train_loss: nan val_loss: nan
Epoch 2/25 train_loss: nan val_loss: nan
Epoch 3/25 train_loss: nan val_loss: nan
Epoch 4/25 train_loss: nan val_loss: nan
Epoch 5/25 train_loss: nan val_loss: nan
Epoch 6/25 train_loss: nan val_loss: nan
Epoch 7/25 train_loss: nan val_loss: nan
Epoch 8/25 train_loss: nan val_loss: nan
Epoch 9/25 train_loss: nan val_loss: nan
Epoch 10/25 train_loss: nan val_loss: nan
Epoch 11/25 train_loss: nan val_loss: nan
Epoch 12/25 train_loss: nan val_loss: nan
Epoch 13/25 train_loss: nan val_loss: nan
Epoch 14/25 train_loss: nan val_loss: nan
Epoch 15/25 train_loss: nan val_loss: nan
Epoch 16/25 train_loss: nan val_loss: nan
Epoch 17/25 train_loss: nan val_loss: nan
Epoch 18/25 train_loss: nan val_loss: nan
Epoch 19/25 train_loss: nan val_loss: nan
Epoch 20/25 train_loss: nan val_loss: nan
Epoch 21/25 train_loss: nan val_loss: nan
Epoch 22/25 train_loss: nan val_loss: nan
Epoch 23/25 train_loss: nan val_loss: nan
Epoch 24/25 train_loss: nan val_loss: nan
E