# scBERT Embedding Extraction and Visualization
This notebook demonstrates how to use the `scBERTExtractor` class to extract embeddings from an `.h5ad` file and visualize them by reducing them with PCA and then projecting to two dimensions with t-SNE.

In [1]:
from scbert_extractor import scBERTExtractor
import scanpy as sc
import numpy as np
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
import seaborn as sns

  from pkg_resources import get_distribution, DistributionNotFound


## Load Data and Model
Specify the paths to your `.h5ad` file, the pretrained scBERT checkpoint, and the gene2vec file, together with the model configuration passed to `scBERTExtractor`.

In [7]:
# Input dataset and pretrained checkpoint used throughout this notebook.
example_h5ad = '/home/jupyter/DATA/brca_full/brca_cells_only_3000cell_4096gene.h5ad'
checkpoint = '/home/jupyter/MODELS/scBERT/panglao_pretrain.pth'

# Configuration values handed to scBERTExtractor alongside the checkpoint path.
config_kwargs = {
    'num_tokens': 7,
    'dim': 200,
    'depth': 6,
    'heads': 10,
    'max_seq_len': 16906,
    'gene2vec_path': '/home/jupyter/MODELS/scBERT/gene2vec_16906.npy',
}

extractor = scBERTExtractor(checkpoint, config_kwargs)

## Preprocess the Data

In [8]:
# Run the extractor's preprocessing on the .h5ad path defined above; the result
# is an AnnData object (its summary is displayed in the next cell).
adata = extractor.preprocess(example_h5ad)



In [9]:
# Show the AnnData summary: dimensions plus the obs / var / uns / obsm / layers keys.
adata

AnnData object with n_obs × n_vars = 87326 × 4096
    obs: 'orig.ident', 'nCount_RNA', 'nFeature_RNA', 'cell_id', 'donor_id', 'timepoint', 'outcome', 'Cancer_type', 'cell_types', 'cohort', 'pre_post', 'donor_id_pre_post', 'donor_id_outcome', 'donor_id_cell_types', 'donor_id_cell_types_pre_post', 'sample_id_pre_post_outcome', 'enough_cells', 'Study_name', 'Primary_or_met', 'RNA_snn_res.0.8', 'seurat_clusters', 'ident', 'n_genes_by_counts', 'total_counts', 'n_genes'
    var: 'n_cells_by_counts', 'mean_counts', 'pct_dropout_by_counts', 'total_counts', 'n_cells', 'highly_variable', 'means', 'dispersions', 'dispersions_norm'
    uns: 'X_name', 'hvg', 'log1p'
    obsm: 'PCA', 'UMAP'
    layers: 'counts', 'logcounts', 'scaledata'

## Extract Embeddings

In [13]:
# Densify the expression matrix only while it still exposes `.toarray()`.
# The unguarded `adata.X = adata.X.toarray()` raises AttributeError when this
# cell is re-run, because the NumPy array produced by the first run has no
# `.toarray()` method. With the guard the cell is safe to execute repeatedly.
if hasattr(adata.X, 'toarray'):
    adata.X = adata.X.toarray()

In [14]:
# Extract embeddings from the preprocessed AnnData using method='cls'.
# NOTE(review): the saved output of this cell is a RuntimeError raised by an
# embedding lookup that received a torch.FloatTensor where Long/Int indices are
# required. Presumably the expression values need to be converted to integer
# token ids before the lookup — confirm inside scbert_extractor. The later
# cells use `embeddings`, so they cannot run until this call succeeds.
embeddings = extractor.extract_embeddings(adata, method='cls')

RuntimeError: Expected tensor for argument #1 'indices' to have one of the following scalar types: Long, Int; but got torch.FloatTensor instead (while checking arguments for embedding)

## Visualize with PCA and t-SNE

In [None]:
# Reduce the scBERT embeddings with PCA, then project to 2-D with t-SNE.
# The PCA component count is capped by the embedding matrix's own dimensions so
# the request can never exceed what PCA accepts, and PCA is seeded like t-SNE
# so the whole projection is repeatable when a randomized solver is selected.
n_pcs = min(50, *np.shape(embeddings))
pca = PCA(n_components=n_pcs, random_state=42).fit_transform(embeddings)
tsne = TSNE(n_components=2, random_state=42).fit_transform(pca)
# Stored under 'X_scBERT' so it can be plotted with basis='X_scBERT'; note that
# these are the 2-D t-SNE coordinates, not the embeddings themselves.
adata.obsm['X_scBERT'] = tsne

In [None]:
# Colour by a column that exists in adata.obs. 'louvain' is not among the obs
# columns listed in the AnnData summary shown earlier in this notebook, so the
# original call could not resolve it. 'seurat_clusters' is a cluster-label
# column that is present; pass 'cell_types' instead to colour by that column.
sc.pl.embedding(adata, basis='X_scBERT', color='seurat_clusters', title='scBERT Embeddings (t-SNE)')