In [8]:
# Import Libraries
import os
import sys

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# Bioinformatics Libraries
import scanpy as sc

# Ignore warnings
import warnings
warnings.filterwarnings('ignore')

In [9]:
# Load the pbmc3k dataset
adata = sc.datasets.pbmc3k()

print(adata)

AnnData object with n_obs × n_vars = 2700 × 32738
    var: 'gene_ids'


In [10]:
# Basic filtering
sc.pp.filter_cells(adata, min_genes=200)
sc.pp.filter_genes(adata, min_cells=3)

print(adata)


AnnData object with n_obs × n_vars = 2700 × 13714
    obs: 'n_genes'
    var: 'gene_ids', 'n_cells'


In [11]:
# Normalize the data
sc.pp.normalize_total(adata, target_sum=1e4)

# Logarithmize the data
sc.pp.log1p(adata)

# Store the raw data
adata.raw = adata



In [12]:
# Identify highly variable genes
sc.pp.highly_variable_genes(adata, min_mean=0.0125, max_mean=3, min_disp=0.5)

# Filter the data
adata = adata[:, adata.var.highly_variable]

# Summarize the highly variable genes
print(adata.var.highly_variable.sum())


1872


In [13]:
adata

View of AnnData object with n_obs × n_vars = 2700 × 1872
    obs: 'n_genes'
    var: 'gene_ids', 'n_cells', 'highly_variable', 'means', 'dispersions', 'dispersions_norm'
    uns: 'log1p', 'hvg'

## Dimension Reduction

### scVAE - scvi-tools

In [15]:
import scvi

# Setup the adata object for scVI
scvi.data.setup_anndata(adata)


AttributeError: module 'jax' has no attribute 'Device'

In [None]:
from scvi.model import SCVI

# Initialize the model
model = SCVI(adata)

# Train the model
model.train()


In [None]:
# Extract the latent space representation
latent_representation = model.get_latent_representation()

# The result is a numpy array with the reduced representation
print(latent_representation)


In [None]:
# Visualize the latent space
latent_x = latent_representation[:, 0]
latent_y = latent_representation[:, 1]

plt.figure(figsize=(8, 6))
plt.scatter(latent_x, latent_y, s=5, alpha=0.7)
plt.title('scVAE - Latent Space')
plt.xlabel('Latent Dim 1')
plt.ylabel('Latent Dim 2')
plt.show()
