# Import packages

In [None]:
import nenrich
import pop_id

import pandas as pd
import scanpy as sc
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
import numpy as np
import anndata as ad

from pathlib import Path
import os

# Set up scanpy settings.
# NOTE: set_figure_params is called *before* the matplotlib figure size is set,
# because (per the scanpy documentation) it applies scanpy's own rcParams
# defaults, which include 'figure.figsize'; setting the size first would be
# silently overwritten.
sc.settings.verbosity = 3
sc.set_figure_params(dpi=100, dpi_save=200)  # Increase DPI for better resolution figures

# Set up output figure settings (default size for figures created afterwards)
plt.rcParams['figure.figsize'] = (5, 5)

In [None]:
# Reload the local helper modules so that edits made to their source files are
# picked up without restarting the kernel.
# `importlib.reload` replaces `imp.reload`: the `imp` module is deprecated and
# was removed in Python 3.12.
import importlib

import nenrich
import pop_id

importlib.reload(nenrich)
importlib.reload(pop_id)

# 1. Run Neighborhood Identification

## Import data
Import the AnnData from the second notebook

In [None]:
# Load the AnnData written out by the second notebook (expected in the working directory).
adata_pops = ad.read_h5ad(Path('adata_temp.h5ad'))

### <font color='red'> Test works on only a couple of ROIs first </font>
This will filter the AnnData to only the first few ROIs (set by `number_rois_to_use`). If this runs fine, skip this cell and run as normal.

In [None]:
# How many ROIs (taken in order of first appearance in obs) to keep for the test run
number_rois_to_use = 1

# Restrict the AnnData to cells whose ROI is among the first `number_rois_to_use`
roi_column = adata_pops.obs.ROI
rois_to_keep = roi_column.unique().tolist()[:number_rois_to_use]
adata_pops = adata_pops[roi_column.isin(rois_to_keep)]

## Run analysis

In [None]:
# Settings forwarded to nenrich.Neighborhood_Identification
# (see the nenrich module for the meaning of each option).
neighborhood_settings = {
    'cluster_col': 'population',
    'radius': 10,
    'keep_cols': 'all',
    'bootstrap': 2,
}

adata = nenrich.Neighborhood_Identification(data=adata_pops, **neighborhood_settings)

In [None]:
# Checkpoint the neighbourhood-identification result to disk
adata.write(Path('adata_nenrich.h5ad'))

# 2. Consensus clustering

### Picking the number of components for PCA

In [None]:
from sklearn.decomposition import PCA

# Fit a PCA with all components on adata.X so the cumulative explained
# variance can be inspected before choosing how many components to keep.
data = adata.X

pca = PCA()
pca.fit(data)

cumulative_variance = pca.explained_variance_ratio_.cumsum()

# Components are counted from 1: the first point is the variance explained by
# a single component, so the x value can be read off directly as `n_for_pca`.
component_counts = range(1, len(cumulative_variance) + 1)

plt.figure(figsize=(10, 8))
plt.plot(component_counts, cumulative_variance, marker='o', linestyle='--')
plt.title('Explained variance by components')
plt.xlabel('Number of components')
plt.ylabel('Cumulative explained variance')

In [None]:
# Pick the number of PCs from the explained-variance graph above, at the point
# where roughly 90 percent of the variability is accounted for
n_for_pca = 25

# Calculate PCA; this must be done before the optional Harmony step below,
# which uses 'X_pca' as its basis
sc.tl.pca(adata, n_comps=n_for_pca)

## OPTIONAL - Batch correction using Harmony
I'm not sure this is required; you can try it and see how it affects the identified neighbourhoods.

In [None]:
import scanpy.external as sce

# The obs column whose values define the batches to correct for
batch_correction_obs = 'Case'

# Compute Harmony correction.
# The corrected embedding is written back to 'X_pca', replacing the uncorrected one.
sce.pp.harmony_integrate(
    adata,
    key=batch_correction_obs,
    basis='X_pca',
    adjusted_basis='X_pca',
)

## Clustering

In [None]:
# Settings forwarded to pop_id.consensus (see the pop_id module for details)
consensus_settings = {
    'n_clusters': 10,
    'n_runs': 200,  # No reason not to do loads here, if you have time
    'save': False,
}

pop_id.consensus(adata, **consensus_settings)

## Plotting results

In [None]:
pop_obs = adata.obs.columns.tolist()[-1]  # Specify clustering, or just use last one

# Colour scale limits: every matrixplot is drawn on the range [-v, v]
v = .5

# The variables are plotted as three consecutive groups of equal size
# (any remainder ends up in the last group).
number_pops = adata.var_names.shape[0] // 3
first_group = adata.var_names[:number_pops]
second_group = adata.var_names[number_pops:2 * number_pops]
third_group = adata.var_names[2 * number_pops:]

# Use the same number of PCs that was computed by sc.tl.pca (n_for_pca) rather
# than a hard-coded value, so this keeps working when n_for_pca is changed.
sc.tl.dendrogram(adata, groupby=pop_obs, n_pcs=n_for_pca)

# NOTE(review): the first plot uses the default colour map while the other two
# use 'coolwarm' - kept as in the original; confirm whether this is intended.
sc.pl.matrixplot(adata, first_group, groupby=pop_obs, dendrogram=True, vmin=-v, vmax=v)
sc.pl.matrixplot(adata, second_group, groupby=pop_obs, dendrogram=True, cmap='coolwarm', vmin=-v, vmax=v)
sc.pl.matrixplot(adata, third_group, groupby=pop_obs, dendrogram=True, cmap='coolwarm', vmin=-v, vmax=v)

# 3. Save as AnnData_nenrich

In [None]:
# Save the AnnData, now including the clustering results, over the earlier checkpoint
adata.write(Path('adata_nenrich.h5ad'))

# 4. Plot Voronois

In [None]:
# The obs column holding the neighbourhood labels to colour the Voronoi regions by
enrich_pop_obs = 'sc3s_10'
########################

from nenrich import draw_voronoi_scatter

# Output folder for the per-ROI figures (created if it does not exist yet)
figure_dir = Path('Figures') / 'Voronoi'
os.makedirs(figure_dir, exist_ok=True)

# Ensure correct format of categories: integer-valued categorical labels
labels_as_int = adata.obs[enrich_pop_obs].astype('int')
adata.obs[enrich_pop_obs] = labels_as_int.astype('category')

# One Voronoi figure per ROI, saved as '<ROI>_voronoi.png'
for roi_name in adata.obs['ROI'].unique().tolist():

    roi_cells = adata.obs[adata.obs['ROI'] == roi_name]

    _ = draw_voronoi_scatter(
        spot=roi_cells,
        c=[],
        voronoi_palette=sc.pl.palettes.vega_20_scanpy,
        X='X_loc',
        Y='Y_loc',
        voronoi_hue=enrich_pop_obs,
    )

    output_file = figure_dir / f'{roi_name}_voronoi.png'
    plt.savefig(output_file, bbox_inches='tight', dpi=200)
    plt.close()  # close the figure so figures do not accumulate across ROIs
    


## View voronoi colourmap

In [None]:
# List the neighbourhood categories, then display the colour assigned to each:
# every category value is used as an index into the vega_20_scanpy palette.
region_categories = adata.obs[enrich_pop_obs].cat.categories
voronoi_palette = sc.pl.palettes.vega_20_scanpy

print('Voronoi regions:')
print(region_categories)

ListedColormap([voronoi_palette[category] for category in region_categories])

# <font color='orange'>OPTION 2 - Leiden Clustering</font>
I think the above option is currently better, but this is here if you want to have a look.

### Calculate PCA

In [None]:
# Define the 'obs' which defines the different cases (used as the BBKNN batch key)
batch_correction_obs = 'TMA'

# Number of principal components to compute and to pass on to BBKNN
n_for_pca = 25

# Calculate PCA, this must be done before BBKNN
sc.tl.pca(adata, n_comps=n_for_pca)

# BBKNN - it is used in place of the scanpy 'neighbors' command that calculates
# nearest neighbours in the feature space
sc.external.pp.bbknn(
    adata,
    batch_key=batch_correction_obs,
    n_pcs=n_for_pca,
)

### Calculate UMAP

In [None]:
# Compute the UMAP embedding with the chosen min_dist setting
umap_min_dist = 0.1
sc.tl.umap(adata, min_dist=umap_min_dist)

In [None]:
# Draw one UMAP panel per variable in the AnnData
variables_to_plot = adata.var_names
sc.pl.umap(adata, color=variables_to_plot)

### Calculate Leidens

In [None]:
# Leiden resolutions to try; each result is stored in obs as 'leiden_<resolution>'
leiden_resolutions = [0.2]

# Maximum number of Leiden groups drawn on the UMAP (groups '0' .. '19')
max_groups_shown = 20

for r in leiden_resolutions:
    leiden_key = f'leiden_{r}'

    sc.tl.leiden(adata, resolution=r, key_added=leiden_key)

    n_groups = adata.obs[leiden_key].cat.categories.shape[0]
    print(f'Number of groups in {leiden_key}: {n_groups}, only showing first {max_groups_shown}')

    # range(max_groups_shown) gives exactly the first 20 labels, matching the
    # printed message (the previous range(0, 21) selected 21 groups).
    shown_groups = [str(x) for x in range(max_groups_shown)]
    sc.pl.umap(adata[adata.obs[leiden_key].isin(shown_groups)],
               color=leiden_key)

    sc.tl.dendrogram(adata, groupby=leiden_key, n_pcs=n_for_pca)

    sc.pl.matrixplot(adata, adata.var_names, groupby=[leiden_key], dendrogram=True)

In [None]:
# Save the AnnData including the Leiden results. The '.h5ad' extension matches
# the other save cells in this notebook, so the same file is updated rather
# than a second, extension-less file being created.
adata.write('adata_nenrich.h5ad')

### Prune Leidens

In [None]:
# Name of the obs column that will receive the merged Leiden labels
new_leiden_name = 'leiden_merged'

# Merge Leiden groups with the nenrich helper; its return value is kept in
# `remap_dict` (see nenrich.prune_leiden_using_dendrogram for its contents).
remap_dict = nenrich.prune_leiden_using_dendrogram(
    adata,
    leiden_obs='leiden_0.2',
    new_obs=new_leiden_name,
    mode='max',
    max_leiden=10,
)

sc.pl.umap(adata, color=new_leiden_name)

sc.tl.dendrogram(adata, groupby=new_leiden_name, n_pcs=n_for_pca)

# Only cells whose merged label is one of '0' .. '49' are included in the matrixplot
labels_to_show = [str(x) for x in range(0, 50)]
cells_to_show = adata.obs[new_leiden_name].isin(labels_to_show)
sc.pl.matrixplot(adata[cells_to_show], adata.var_names, groupby=[new_leiden_name], dendrogram=True, vmax=2)