In [1]:
import os
import scanpy as sc
import muon as mu
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import anndata as ad

import torch
import torch.nn as nn
import torch.nn.functional as F

from torch_geometric.data import Data
from torch_geometric.transforms import RandomNodeSplit, RandomLinkSplit
from torch_geometric.loader import DataLoader, NeighborLoader

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Plotting settings
colors = ["#3B7EA1", "#FDB515", "#D9661F", "#859438", "#EE1F60", "#00A598"]
sns.set(context="notebook", font_scale=1.3, style="ticks")
sns.set_palette(sns.color_palette(colors))
plt.rcParams['svg.fonttype'] = 'none'   # keep text as editable text in SVG exports
plt.rcParams['pdf.fonttype'] = 42       # embed TrueType (Type 42) fonts in PDFs
plt.rcParams['savefig.transparent'] = True
# NOTE(review): private scanpy flag (public route is sc.set_figure_params(vector_friendly=True));
# it makes scanpy rasterise scatter points when saving vector formats.
sc.settings._vector_friendly = True
DPI = 300

# GPU settings
# These only take effect if CUDA has not been initialised yet in this kernel
# (torch initialises CUDA lazily, so importing torch above is fine).
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Reproducibility: seed the global numpy and torch RNGs. For example,
# RandomNodeSplit further down draws its permutations from torch's global RNG.
# torch.manual_seed also seeds all CUDA devices.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

In [7]:

# data_path
# Load the CITE-seq reference MuData from the totalVI_reproducibility data directory.
# NOTE: despite its name, `save_path` is the *input* directory for this file.
save_path = "/home/wuxinchao/data/st_cite_data/totalVI_reproducibility/data/"
# os.path.join avoids the doubled slash that `save_path + "/SNL_111.h5mu"` produced
# (save_path already ends in "/").
mdata = mu.read_h5mu(os.path.join(save_path, "SNL_111.h5mu"))
rna = mdata.mod["rna"]        # RNA modality
protein = mdata.mod["prot"]   # protein modality

# Preprocess

In [8]:
# RNA: deduplicate cell barcodes first, then library-size normalise and
# log1p-transform rna.X in place (scanpy defaults spelled out explicitly).
rna.obs_names_make_unique()
sc.pp.normalize_total(rna, inplace=True)
sc.pp.log1p(rna, copy=False)

In [9]:
from muon import prot as pt

# Protein: centred log-ratio (CLR) transform applied in place on protein.X,
# then resync the MuData container with its (modified) modalities.
pt.pp.clr(protein, inplace=True)
mdata.update()

# Data input: build the CITE-seq graph (RNA + protein node features, RNA kNN edges)

In [3]:
# Build a PyTorch Geometric graph for the CITE-seq data. Node features are the
# RNA columns followed by the protein columns; edges come from the RNA kNN graph.
# NOTE: needs `rna` / `protein` from the loading and preprocessing cells above.
# Run those first on a fresh kernel.


def _as_dense(matrix):
    """Return `matrix` as a dense ndarray (AnnData `.X` is frequently scipy-sparse)."""
    return matrix.toarray() if hasattr(matrix, "toarray") else np.asarray(matrix)


# The preprocessing cells above never call sc.pp.neighbors. Only build the kNN
# graph here if the loaded file did not already ship one.
if "connectivities" not in rna.obsp:
    sc.pp.neighbors(rna)

# np.concatenate cannot join scipy sparse matrices, so densify each modality first.
concat_data = np.concatenate([_as_dense(rna.X), _as_dense(protein.X)], axis=1)

# Read the non-zero (row, col) pairs straight from the (usually sparse) adjacency.
# This avoids materialising a dense n_cells x n_cells array with .toarray().
# .nonzero() works for both scipy sparse matrices and dense ndarrays.
adj_mtx = rna.obsp['connectivities']
edge_index = torch.from_numpy(np.vstack(adj_mtx.nonzero())).long().contiguous()

scCITEseq_data = Data(
    x=torch.tensor(concat_data, dtype=torch.float),
    edge_index=edge_index,
)

NameError: name 'rna' is not defined

# Spatial Data

In [4]:
# Spatial dataset inputs: a data directory plus per-modality CSV file names
# (presumably resolved relative to `data_path` by construct_spatial_adata — verify).
data_path, prot_data_path, rna_data_path = (
    "/home/wuxinchao/data/st_cite_data/",
    "B01825A4_protein_filter.csv",
    "B01825A4_rna_raw.csv",
)

In [5]:
from utils import construct_spatial_adata

# Assemble the spatial sample's MuData from its RNA and protein CSV inputs
# using the project helper, then display it.
spatial_inputs = dict(rna_data=rna_data_path, prot_data=prot_data_path)
sp_mudata = construct_spatial_adata(data_path, **spatial_inputs)
sp_mudata

  rna_adata = ad.AnnData(
  prot_adata = ad.AnnData(


In [9]:
# Grab the RNA and protein AnnData objects stored in sp_mudata.mod. These are the
# same objects, so the in-place preprocessing below is visible through sp_mudata.
rna_adata, prot_adata = (sp_mudata.mod[modality] for modality in ("rna", "prot"))

In [10]:
# Spatial RNA preprocessing: snapshot raw counts -> normalise/log1p -> HVGs ->
# scale -> PCA -> kNN graph -> UMAP.

# Copy X into the "counts" layer BEFORE normalize_total/log1p. flavor="seurat_v3"
# models raw counts, so the layer must not hold log-normalised values.
# Assumes X still holds raw counts at this point (input file is *_rna_raw.csv) — TODO confirm.
rna_adata.layers["counts"] = rna_adata.X.copy()

sc.pp.normalize_total(rna_adata)
sc.pp.log1p(rna_adata)
rna_adata.obs_names_make_unique()
rna_adata.var['mt'] = rna_adata.var_names.str.startswith('MT-')  # mitochondrial-gene flag

# Select 2000 HVGs from the raw-count layer; this writes var['highly_variable'].
sc.pp.highly_variable_genes(
    rna_adata,
    n_top_genes=2000,
    flavor="seurat_v3",
    layer="counts",
)

# NOTE: scale overwrites X (log-normalised values are not kept elsewhere).
sc.pp.scale(rna_adata, max_value=10)
# scanpy's PCA restricts itself to var['highly_variable'] by default when that column exists.
sc.tl.pca(rna_adata, svd_solver="arpack")
sc.pp.neighbors(rna_adata, n_pcs=50)
sc.tl.umap(rna_adata)



In [11]:
from muon import prot as pt

# Protein: centred log-ratio (CLR) transform applied in place on prot_adata.X,
# then resync the spatial MuData container with its modified modalities.
pt.pp.clr(prot_adata, inplace=True)
sp_mudata.update()

In [12]:
# Persist the preprocessed spatial MuData into the data directory.
sp_mudata.write_h5mu(f"{data_path}sp_mudata.h5mu")

In [13]:
# Display the updated spatial MuData object after preprocessing and saving.
sp_mudata

In [None]:
# Split the CITE-seq graph's nodes into train/val/test masks.
num_splits = 2   # independent random splits; masks get one column per split
num_val = 0.2    # float => fraction of nodes used for validation
num_test = 0.2   # float => fraction of nodes used for testing

# RandomNodeSplit draws its permutations from torch's global RNG. Seed it so
# the masks are reproducible across kernel restarts.
split_seed = 0
torch.manual_seed(split_seed)

# key=None: presumably split without requiring a label attribute (the graph has no `y`).
tsf = RandomNodeSplit(num_splits=num_splits, num_val=num_val, num_test=num_test, key=None)
training_data = tsf(scCITEseq_data)

In [None]:
from train import Trainer

# Prefer the GPU when one is visible to torch.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Hyperparameters forwarded unchanged to the project Trainer.
trainer_kwargs = dict(
    model_choice="Graph Cross Attention",
    rna_input_dim=rna.shape[1],       # RNA columns at the front of the node-feature matrix
    prot_input_dim=protein.shape[1],  # protein columns after them
    hidden_dim=32,
    embedding_dim=32,
    heads=4,
    num_blocks=2,
    batch_size=256,
    epochs=25,
    mask_ratio=0.75,
    permute=True,
    preserve_rate=0.2,
    num_splits=num_splits,
    device=device,
    alpha=0.4,
    beta=0.4,
)
trainer = Trainer(training_data, **trainer_kwargs)

train_losses, val_losses = trainer.train()