In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

import numpy as np
import pandas as pd
import scanpy as sc

import glob
import os
import scrna
import h5py

from micron2 import load_as_anndata
from micron2 import cluster_leiden_cu
from cuml import UMAP

import tqdm.auto as tqdm

from matplotlib import pyplot as plt
from matplotlib import rcParams
import time

from micron2.data import staining_border_nonzero

In [None]:
# Collect every preprocessed CODEX HDF5 file (one sub-directory per sample).
# Use glob instead of parsing `ls -lha | awk '{print $9}'`: that pipeline keeps
# only the 9th whitespace-separated field, so any path containing a space would
# be truncated. Sorting keeps the order deterministic; the order matters because
# it decides each slide's position in the coordinate layout built below.
DATA_GLOB = '/storage/codex/preprocessed_data/*/*.hdf5'
datasets = sorted(glob.glob(DATA_GLOB))
datasets

In [None]:
# Number of dataset files that the loading loop will attempt to read.
len(datasets)

In [None]:
# Load every dataset as AnnData and merge them into a single object, tagging
# each cell with the file stem of its source in `obs['sample_id']`.
ring_channels = ['CD45', 'CD20', 'CD3e', 'CD45RO', 'CD45RA', 'CD8', 'CD4']
adatas = []
sample_ids = []
for path in datasets:
    try:
        ad = load_as_anndata(path, 
                             recover_tile_nuclei=False, 
                             as_sparse = True
                            )
    except Exception as err:
        # Best-effort: skip unreadable files but report why. Catch Exception
        # rather than using a bare `except:`, so Ctrl-C (KeyboardInterrupt)
        # still stops the loop.
        print('failed to load', path, f'({type(err).__name__}: {err})')
        continue
    
    # DISABLED: per-cell "ring" (staining border) positivity for ring_channels,
    # computed with staining_border_nonzero on each cell's image crop and
    # nuclear mask, and appended to ad.obs as `{ch}_ringpct` columns.
#     # Apply staining border function
#     with h5py.File(ad.uns['source_data'], 'r') as h5f:
#         ncells = ad.shape[0]
#         h5_ncells = h5f['cell_intensity']['DAPI'][:].shape[0]
#         print(ncells, h5_ncells)
#         ring_positive_pct = pd.DataFrame(index=ad.obs_names,
#                                          columns=[f'{ch}_ringpct' for ch in ring_channels],
#                                          dtype=np.float32
#                                         )
#         tstart = time.time()
#         for i in tqdm.trange(ncells):
#             m = h5f['meta']['nuclear_masks'][i,...]
#             vect = []
#             for ch in ring_channels:
#                 x = h5f['cells'][ch][i,...]
#                 v = staining_border_nonzero(x,m)
#                 vect.append(v)
#             ring_positive_pct.loc[ad.obs_names[i],:] = vect
#         tend = time.time()
#         print(f'elapsed time: {tend-tstart:3.4f}s')
#     ad.obs = pd.concat([ad.obs, ring_positive_pct], axis=1)

    adatas.append(ad.copy()) 
    # The file stem (basename without extension) becomes the sample id.
    sample_id = os.path.splitext(os.path.basename(path))[0]
    sample_ids.append(sample_id)

# Fail with a clear message instead of an IndexError on adatas[0].
if not adatas:
    raise RuntimeError('no datasets could be loaded from `datasets`')
# Sample ids are used as batch categories, so they must be unique.
duplicated_ids = sorted({s for s in sample_ids if sample_ids.count(s) > 1})
if duplicated_ids:
    raise ValueError(f'duplicate sample ids (file stems): {duplicated_ids}')

adata = adatas[0].concatenate(adatas[1:], batch_key='sample_id', batch_categories=sample_ids, 
                              index_unique = '-')
print(sample_ids)
adata.raw = adata


In [None]:
# How many samples made it through loading, followed by their ids.
n_loaded = len(sample_ids)
print(n_loaded)
sample_ids

In [None]:
# Arrange the slides on a grid with N_LAYOUT_ROWS rows, filled row by row.
# If the count does not divide evenly, pad with the 'pass' sentinel, which
# shift_coordinates skips.
# Copy the list first. Aliasing it (`sample_ids_padded = sample_ids`) appended
# 'pass' to `sample_ids` itself: the "unpadded" shape printed below was wrong,
# and the padding accumulated every time the cell was re-run.
N_LAYOUT_ROWS = 2
sample_ids_padded = list(sample_ids)
while len(sample_ids_padded) % N_LAYOUT_ROWS != 0:
    sample_ids_padded.append('pass')
sample_layout = np.array(sample_ids_padded).reshape(N_LAYOUT_ROWS, -1)
print(np.array(sample_ids).shape)
print(sample_layout.shape)
print(sample_layout)

In [None]:
# sample_layout = np.array(sample_ids).reshape(2,-1)
def shift_coordinates(adata, sample_layout):
    """Tile per-slide spatial coordinates onto one shared canvas.

    Each entry of ``sample_layout`` names the slide (a value of
    ``adata.obs.sample_id``) to place at that grid position. The function
    works on a copy of ``adata.obsm['coordinates']`` with dim1 negated.
    Every slide except the one at (0, 0) is translated:

    * along dim0 by the maximum dim0 value of its left-hand neighbour in the
      same row. The neighbour is read from ``coords``, which is updated in
      place as the loop advances, so column offsets accumulate left to right.
      Column 0 is its own reference and gets no dim0 shift.
    * along dim1 by ``current_row_max`` carried over from the previous row,
      i.e. the largest (negated, shifted) dim1 value seen there, never below 0.

    Grid entries equal to ``'pass'`` or ``None`` are skipped. dim1 is negated
    back before returning. Progress is printed at every step.

    Parameters
    ----------
    adata : AnnData
        Must provide ``obsm['coordinates']`` (indexed as ``[:, 0]`` and
        ``[:, 1]``) and ``obs.sample_id``. It is not modified; only the copy
        is written to.
    sample_layout : 2-D array of sample ids
        Grid of slide ids, unpacked as ``(n_rows, n_cols)``.

    Returns
    -------
    Array of shifted coordinates, with rows in the same order as ``adata.obs``.

    NOTE(review):
    * dim1 is negated before the row maximum is taken, and ``current_row_max``
      starts at 0 and only grows. Slides whose dim1 values are all
      non-negative therefore never raise it, so the row shift may stay 0 and
      rows could overlap. Verify against the layout plot.
    * The (0, 0) slide is skipped before it can update ``current_row_max``,
      so it never contributes to row 0's maximum.
    * ``max()`` raises ``ValueError`` when a selection is empty, e.g. a layout
      id with no matching cells, or a left neighbour that is ``'pass'``.
    """
    # rows and columns are flipped again
    
    nr,nc = sample_layout.shape
    print('layout', nr,nc)
    
    coords = adata.obsm['coordinates'].copy()
    # Flip dim1 (vertical dimension)
    coords[:,1] = -coords[:,1]
        
    #nrange=nr+1 if nr==1 else nr
    #crange=nc+1 if nc==1 else nc
    print('ranges', nr, nc)
    for r2 in range(nr):
        print('row', r2)
        if r2>0:
            # ref_row is only printed; inside this branch it is always r2-1.
            ref_row = r2-1 if r2>0 else r2
            print('\treference row:', ref_row)
            #row_ref_slide = sample_layout[ref_row,c2] 
            # current_row_max still holds the value from the previous row's
            # column loop; it is reset to 0 just below.
            row_shift = current_row_max
            print('\trow shift:', row_shift)
        else: 
            row_shift = 0
        
        current_row_max = 0
        for c2 in range(nc):
            print('\tcolumn', c2)
            # The top-left slide is the anchor and is never moved.
            if r2==0 and c2==0: continue
            
            # curr_row=r2-1 if r2>=nr else r2
            # curr_col=c2-1 if c2>=nc else c2
            
            print('\t\tlocation', r2, c2)
            target_slide = sample_layout[r2,c2]
            # Padding entries hold no cells.
            if target_slide == 'pass':
                continue
            if target_slide is None:
                continue
            print('\t\ttarget slide:', target_slide)
            
            # Left neighbour; for c2 == 0 this is the slide itself.
            ref_col = c2-1 if c2>0 else c2
            print('\t\treference col:', ref_col)
            col_ref_slide = sample_layout[r2,ref_col] 
            #print('row reference:', row_ref_slide)
            print('\t\tcol reference:', col_ref_slide)
            
            target_idx = adata.obs.sample_id.values==target_slide
            target_coords = coords[target_idx].copy()
            print('\t\tstart:', max(target_coords[:,0]), max(target_coords[:,1]))
                
            print('\t\tshifting rows (dim1) by', row_shift)
            target_coords[:,1] += row_shift
            # Track the largest shifted dim1 value in this row; it becomes the
            # next row's shift.
            if max(target_coords[:,1]) > current_row_max:
                print('\t\tfound new row max')
                current_row_max = max(target_coords[:,1])
            
            if col_ref_slide != target_slide:
                # Reads the neighbour's already-shifted coordinates.
                col_ref = coords[adata.obs.sample_id==col_ref_slide]
                col_max = max(col_ref[:,0])
                print('\t\tshifting cols (dim0) by', col_max)
                target_coords[:,0] += col_max
            
            print('\t\tend:', max(target_coords[:,0]), max(target_coords[:,1]))
            # Write back so later slides in this row see the shifted values.
            coords[target_idx] = target_coords
            
    # Flip dim1 (vertical dimension)
    coords[:,1] = -coords[:,1]
    return coords
    
# Lay every slide out on the shared canvas. Store the shifted coordinates and
# the grid that produced them on the AnnData object so they are saved with it.
print(sample_layout)
shifted_coords = shift_coordinates(adata, sample_layout)
adata.uns['sample_layout'] = sample_layout
adata.obsm['coordinates_shift'] = shifted_coords



In [None]:
def _printable_sample_id(sample_id):
    """Return a compact plot label for a sample id.

    Drops everything up to and including the first 'Breast_' and puts each
    '_'-separated token on its own line. Using ``split(..., 1)[-1]`` keeps the
    whole id when the prefix is absent; ``split('Breast_')[1]`` raised
    IndexError there. It also no longer truncates ids that contain 'Breast_'
    more than once.
    """
    return sample_id.split('Breast_', 1)[-1].replace('_', '\n')


sample_id_printing = [_printable_sample_id(x) for x in adata.obs.sample_id]
adata.obs['sample_id_printing'] = sample_id_printing

In [None]:
# High-resolution figures on an opaque white background.
rcParams['figure.dpi'] = 300
rcParams['figure.facecolor'] = (1, 1, 1, 1)

# Match the figure's aspect ratio to the canvas extent (dim0 span / dim1 span),
# keeping the height fixed at 5 inches.
extent = np.abs(adata.obsm['coordinates_shift']).max(axis=0)
aspect = extent[0] / extent[1]
fig, ax = plt.subplots(figsize=(aspect * 5, 5))
sc.pl.embedding(adata, basis='coordinates_shift', color='sample_id_printing', ax=ax,
                legend_loc='on data')

In [None]:
# Size of the merged object as (n_obs, n_vars).
print(f'{adata.shape}')

In [None]:
# Copy the channel list onto the merged object. Take it from `adatas` rather
# than from `ad`, the loop variable left over from the loading cell, which is
# easily shadowed or stale. `adatas[-1]` is a copy of the dataset `ad` last
# referred to, so the stored value is unchanged. Warn if the datasets disagree,
# because only one list can be stored.
channel_lists = [list(a.uns['channels']) for a in adatas]
if any(ch != channel_lists[-1] for ch in channel_lists):
    print('warning: datasets do not share the same channel list; '
          'using the channels of the last loaded dataset')
adata.uns['channels'] = adatas[-1].uns['channels']

In [None]:
# Display the merged AnnData summary (obs/var/obsm/uns keys) before writing it.
adata

In [None]:
# List the output directory, with file sizes, before writing the merged file.
!ls -lha /storage/codex/datasets_v1/

In [None]:
# Persist the merged, layout-shifted AnnData as an .h5ad file.
MERGED_H5AD = "/storage/codex/datasets_v1/merged_v3.h5ad"
adata.write(MERGED_H5AD)