In [None]:
# Render matplotlib figures inline, directly below the cell that produces them.
%matplotlib inline

# Transfer a pretrained model to predict the tumor microenvironment

The spatial organization of the tumor microenvironment has long been a research hotspot. Here we collected spatial transcriptomics data from both 10x Xenium (DCIS, ductal carcinoma in situ) and NanoString CosMx (NSCLC, non-small cell lung cancer), and investigated how well a Bering model transfers between tumor spatial data generated by different technologies.

### Import packages & data

In [1]:
import random
from pathlib import Path

import numpy as np
import pandas as pd
import tifffile as tiff
import matplotlib.pyplot as plt

import Bering as br

### Load the model pretrained on CosMx NSCLC data

In [None]:
# Load the Bering model pretrained on NanoString CosMx NSCLC data.
# The returned object provides the settings (via use_settings) and the trained
# node/edge classifiers (trainer_node / trainer_edge) reused for the Xenium data.
# NOTE(review): the "_he" suffix presumably means H&E image features were used in
# pretraining — confirm against the Bering datasets documentation.
bg_pre = br.datasets.model_cosmx_nsclc_he()

### Load Xenium DCIS data

In [None]:
# Directory holding the Xenium DCIS demo spot tables (tab-separated, first column
# is the spot index). Change this one constant to point at your own copy.
XENIUM_DCIS_DIR = Path('/data/aronow/Kang/spatial/Bering/demo/bmx_xenium_dcis')
# Columns kept for every spot: coordinates plus the 'features' column.
SPOT_COLUMNS = ['x', 'y', 'z', 'features']

df_spots_seg = pd.read_csv(XENIUM_DCIS_DIR / 'spots_seg.txt', sep = '\t', header = 0, index_col = 0)
df_spots_unseg = pd.read_csv(XENIUM_DCIS_DIR / 'spots_unseg.txt', sep = '\t', header = 0, index_col = 0)

# Remove the labels of the segmented spots by keeping only SPOT_COLUMNS, then pool
# them with the unsegmented spots: every spot becomes unlabelled input and the
# segmented table is left empty, so the pretrained model must predict everything.
df_spots_unseg = pd.concat([df_spots_unseg, df_spots_seg[SPOT_COLUMNS]], axis = 0)
df_spots_seg = pd.DataFrame()

### Transfer model

In [None]:
# create a new Bering object
bg = br.BrGraph(df_spots_seg = df_spots_seg, df_spots_unseg = df_spots_unseg)
bg.use_settings(bg_pre) # transfer basic settings

# transfer models
bg.trainer_node = bg_pre.trainer_node
bg.trainer_edge = bg_pre.trainer_edge

In [None]:
# Inference settings for whole-slice node classification.
max_num_spots = 1500000
num_chunks = 25
nodeclf_prob_threshold = 0.3

# 1) Node classification on every spot of the slice (a copy of bg.spots_all is passed).
br.tl.node_classification(
    bg,
    bg.spots_all.copy(),
    n_neighbors = 30,
    prob_threshold = nodeclf_prob_threshold,
    max_num_spots = max_num_spots,
    num_chunks = num_chunks,
)

# 2) Cell segmentation from the classified spots.
pred_cells = br.tl.cell_segmentation(bg)

# 3) Ensemble results: per-spot table (with ensembled cell / label columns) plus AnnData objects.
df_results, adata_ensembl, adata_seg = br.tl.cell_annotation(bg)

### Visualize the results

In [None]:
# randomly select a cell
random_cell = cells = random.sample(bg.segmented.index.values.tolist(), 1)[0]
_,_,_ = br.pl.Plot_Classification(
    bg, 
    cell_name = random_cell,
    n_neighbors = 30, 
    zoomout_scale = 8,
)

### Self-distillation

We used the model pretrained on CosMx NSCLC data to predict the likely cells and their annotations in the Xenium DCIS data. However, the predicted labels are very coarse and not suitable for downstream analysis. Here, we use self-distillation: the coarse predicted labels serve as pseudo-labels from which the pretrained model is fine-tuned.

In [None]:
# original input data for self-distillation using the ensembled results
df_spots_all = df_results
df_spots_seg = df_spots_all[df_spots_all['ensembled_labels'] != 'Unknown'].copy()
df_spots_seg = df_spots_seg[['x', 'y', 'z', 'features', 'ensembled_cells', 'ensembled_labels']]
df_spots_seg.columns = ['x', 'y', 'z', 'features', 'segmented', 'labels']

df_spots_unseg = df_spots_all[df_spots_all['ensembled_labels'] == 'Unknown'].copy()
df_spots_unseg = df_spots_unseg[['x', 'y', 'z', 'features']]

In [None]:
# Rebuild the Bering graph from the pseudo-labelled spots: spots with an
# ensembled label form the segmented table, 'Unknown' spots the unsegmented one.
bg = br.BrGraph(
    df_spots_seg = df_spots_seg, 
    df_spots_unseg = df_spots_unseg,
)

In [None]:
# Start fine-tuning from the pretrained model: copy its basic settings and its
# trained node/edge classifiers onto the rebuilt graph.
bg.use_settings(bg_pre)
bg.trainer_node, bg.trainer_edge = bg_pre.trainer_node, bg_pre.trainer_edge

In [None]:
# Build graphs for GCN training purpose
br.graphs.BuildWindowGraphs(
    bg, 
    n_cells_perClass = 12, 
    window_width = 15.0, 
    window_height = 15.0, 
    n_neighbors = 30, 
)
br.graphs.CreateData(
    bg, 
    batch_size = 16, 
    training_ratio = 0.8, 
)

In [None]:
# Self-distillation step: train the node/edge models copied from bg_pre on the
# pseudo-labelled spots of the rebuilt graph.
# NOTE(review): retrain = True is what forces the already-trained transferred
# models to be retrained; confirm the behaviour of the default.
br.train.Training(bg, retrain = True) # set retrain = True to retrain the model

In [None]:
# visualize the results
_,_,_ = br.pl.Plot_Classification(
    bg, 
    cell_name = random_cell,
    n_neighbors = 30, 
    zoomout_scale = 8,
)

In [None]:
# some settings
max_num_spots = 1500000
num_chunks = 25
nodeclf_prob_threshold = 0.3

pos_thresh = 0.7
resolution = 0.20
num_edges_perSpot = 200

In [None]:
# Re-run whole-slice inference with the fine-tuned models, using the
# max_num_spots / num_chunks / nodeclf_prob_threshold settings defined above.

# 1) Node classification on every spot of the slice (a copy of bg.spots_all is passed).
br.tl.node_classification(
    bg,
    bg.spots_all.copy(),
    n_neighbors = 30,
    prob_threshold = nodeclf_prob_threshold,
    max_num_spots = max_num_spots,
    num_chunks = num_chunks,
)

# 2) Cell segmentation from the classified spots.
pred_cells = br.tl.cell_segmentation(bg)

# 3) Ensemble results: per-spot table (with ensembled cell / label columns) plus AnnData objects.
df_results, adata_ensembl, adata_seg = br.tl.cell_annotation(bg)