In [1]:
import json
import os
import sys
from copy import deepcopy
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import skimage
import tifffile
import yaml

In [2]:
%load_ext autoreload

In [3]:
%autoreload 2

In [4]:
from mushroom.mushroom import Mushroom, DEFAULT_CONFIG
import mushroom.data.he as he
import mushroom.data.multiplex as multiplex
import mushroom.utils as utils
import mushroom.visualization.utils as vis_utils

In [5]:
source_root = '/diskmnt/Projects/Users/estorrs/mushroom/data'
target_root = '/data/estorrs/mushroom/data'

In [6]:
run_dir = '/data/estorrs/mushroom/data/projects/submission_v1/kidney/'

## define inputs

In [7]:
fps = sorted(utils.listfiles(run_dir, regex=r'registered/metadata.yaml$'))
fps

['/data/estorrs/mushroom/data/projects/submission_v1/kidney/NMK12F-Fp1/registered/metadata.yaml',
 '/data/estorrs/mushroom/data/projects/submission_v1/kidney/NMK12M-Fp1/registered/metadata.yaml',
 '/data/estorrs/mushroom/data/projects/submission_v1/kidney/NMK3F-Fp1/registered/metadata.yaml',
 '/data/estorrs/mushroom/data/projects/submission_v1/kidney/NMK3M-Fp1/registered/metadata.yaml',
 '/data/estorrs/mushroom/data/projects/submission_v1/kidney/NMK92F-Fp1/registered/metadata.yaml',
 '/data/estorrs/mushroom/data/projects/submission_v1/kidney/NMK92F2-Fc1U1Bs2/registered/metadata.yaml',
 '/data/estorrs/mushroom/data/projects/submission_v1/kidney/NMK92M-Fp1/registered/metadata.yaml',
 '/data/estorrs/mushroom/data/projects/submission_v1/kidney/NMK92M1-Fc1U1Bs2/registered/metadata.yaml',
 '/data/estorrs/mushroom/data/projects/submission_v1/kidney/P1F1MnR-Fp1/registered/metadata.yaml',
 '/data/estorrs/mushroom/data/projects/submission_v1/kidney/P1M3MnR-Fp1/registered/metadata.yaml',
 '/data/

In [8]:
DEFAULT_CONFIG

{'sections': None,
 'dtype_to_chkpt': None,
 'dtype_specific_params': {'visium': {'trainer_kwargs': {'tiling_method': 'radius'}}},
 'sae_kwargs': {'size': 8,
  'patch_size': 1,
  'encoder_dim': 128,
  'codebook_dim': 64,
  'num_clusters': (8, 4, 2),
  'dtype_to_decoder_dims': {'multiplex': (256, 128, 64),
   'he': (256, 128, 10),
   'visium': (256, 512, 2048),
   'xenium': (256, 256, 256),
   'cosmx': (256, 512, 1024),
   'points': (256, 512, 1024)},
  'recon_scaler': 1.0,
  'neigh_scaler': 0.01},
 'trainer_kwargs': {'input_resolution': 1.0,
  'target_resolution': 0.02,
  'pct_expression': 0.05,
  'log_base': 2.718281828459045,
  'tiling_method': 'grid',
  'tiling_radius': 1.0,
  'batch_size': 128,
  'num_workers': 0,
  'devices': 1,
  'accelerator': 'cpu',
  'max_epochs': 1,
  'steps_per_epoch': 1000,
  'lr': 0.0001,
  'out_dir': './outputs',
  'save_every': 1,
  'log_every_n_steps': 10,
  'logger_type': 'tensorboard',
  'logger_project': 'portobello',
  'channel_mapping': {},
  'data

In [9]:
def alter_filesystem(config, source_root, target_root):
    for entry in config['sections']:
        for mapping in entry['data']:
            mapping['filepath'] = mapping['filepath'].replace(source_root, target_root)
    
    if config['trainer_kwargs']['data_mask'] is not None:
        config['trainer_kwargs']['data_mask'] = config['trainer_kwargs']['data_mask'].replace(source_root, target_root)
        
    return config

In [10]:
case_to_config = {}

for fp in fps:
    metadata = yaml.safe_load(open(fp))
    
    case = fp.split('/')[-3]
    
    config = deepcopy(DEFAULT_CONFIG)
    config = utils.recursive_update(config, {
        'sections': metadata['sections'],
        'trainer_kwargs': {
            'input_resolution': metadata['resolution'],
            'target_resolution': .02,
            'out_dir': os.path.join(run_dir, case, 'mushroom'),
            'accelerator': 'gpu',
            'steps_per_epoch': 1000,
        }
    })
    
    config['dtype_specific_params'] = utils.recursive_update(config['dtype_specific_params'], {
        'visium': {
            'trainer_kwargs': {
                'target_resolution': .01  # lower resolution for visium
            }
        }
    })
    
    config = alter_filesystem(config, source_root, target_root)
    
    case_to_config[case] = config

In [11]:
case_to_config.keys()

dict_keys(['NMK12F-Fp1', 'NMK12M-Fp1', 'NMK3F-Fp1', 'NMK3M-Fp1', 'NMK92F-Fp1', 'NMK92F2-Fc1U1Bs2', 'NMK92M-Fp1', 'NMK92M1-Fc1U1Bs2', 'P1F1MnR-Fp1', 'P1M3MnR-Fp1', 'P1_F1LM3l', 'P21_F2RM6R'])

In [12]:
case_to_config

{'NMK12F-Fp1': {'sections': [{'data': [{'dtype': 'xenium',
      'filepath': '/data/estorrs/mushroom/data/projects/submission_v1/kidney/NMK12F-Fp1/registered/NMK12F-Fp1_xenium.h5ad'}],
    'position': 0,
    'sid': 'NMK12F-Fp1-U1'}],
  'dtype_to_chkpt': None,
  'dtype_specific_params': {'visium': {'trainer_kwargs': {'tiling_method': 'radius',
     'target_resolution': 0.01}}},
  'sae_kwargs': {'size': 8,
   'patch_size': 1,
   'encoder_dim': 128,
   'codebook_dim': 64,
   'num_clusters': (8, 4, 2),
   'dtype_to_decoder_dims': {'multiplex': (256, 128, 64),
    'he': (256, 128, 10),
    'visium': (256, 512, 2048),
    'xenium': (256, 256, 256),
    'cosmx': (256, 512, 1024),
    'points': (256, 512, 1024)},
   'recon_scaler': 1.0,
   'neigh_scaler': 0.01},
  'trainer_kwargs': {'input_resolution': 1.0,
   'target_resolution': 0.02,
   'pct_expression': 0.05,
   'log_base': 2.718281828459045,
   'tiling_method': 'grid',
   'tiling_radius': 1.0,
   'batch_size': 128,
   'num_workers': 0,
  

In [13]:
for case, config in case_to_config.items():
    print(case)
    mushroom = Mushroom.from_config(config)
    mushroom.train()
    mushroom.embed_sections()
    mushroom.save()
    
    del(mushroom)

INFO:root:loading spore for xenium
INFO:root:singleton section detected, creating temporary duplicate
INFO:root:data mask detected
INFO:root:starting xenium processing


NMK12F-Fp1


INFO:root:using 479 channels
INFO:root:2 sections detected: ['NMK12F-Fp1-U1', 'NMK12F-Fp1-U1_dup']
INFO:root:processing sections
INFO:root:generating image data for section NMK12F-Fp1-U1

KeyboardInterrupt

