In [2]:
import json
import os
import sys
import shutil
from copy import deepcopy
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import skimage
import tifffile
import yaml

In [3]:
%load_ext autoreload

In [4]:
%autoreload 2

In [5]:
from mushroom.mushroom import Mushroom, DEFAULT_CONFIG
import mushroom.data.he as he
import mushroom.data.multiplex as multiplex
import mushroom.utils as utils
import mushroom.visualization.utils as vis_utils

In [6]:
# Path prefixes swapped by alter_filesystem(): filepaths recorded in the
# registered metadata start with source_root, and are rewritten to start with
# target_root before a config is used.
source_root = '/diskmnt/Projects/Users/estorrs/mushroom/data'
target_root = '/data/estorrs/mushroom/data'

In [7]:
# Project directory that is searched for registered metadata files and under
# which each run's output directory is created.
run_dir = '/data/estorrs/mushroom/data/projects/submission_v1/lightsheet'

In [8]:
# Preview one registered metadata file to show the schema the cells below rely on
# ('resolution' plus a list of 'sections', each with 'data', 'position' and 'sid').
# The file is opened in a context manager so the handle is closed after parsing.
with open('/data/estorrs/mushroom/data/projects/submission_v1/lightsheet/HT427PI-A1_20x/registered/metadata.yaml') as f:
    example_metadata = yaml.safe_load(f)
example_metadata

{'resolution': 0.1273782615616727,
 'sections': [{'data': [{'dtype': 'multiplex',
     'filepath': '/diskmnt/Projects/Users/estorrs/mushroom/data/projects/submission_v1/lightsheet/HT427PI-A1_20x/registered/HT427PI-A1-U0_multiplex.ome.tiff'}],
   'position': 0,
   'sid': 'HT427PI-A1-U0'},
  {'data': [{'dtype': 'multiplex',
     'filepath': '/diskmnt/Projects/Users/estorrs/mushroom/data/projects/submission_v1/lightsheet/HT427PI-A1_20x/registered/HT427PI-A1-U1_multiplex.ome.tiff'}],
   'position': 7,
   'sid': 'HT427PI-A1-U1'},
  {'data': [{'dtype': 'multiplex',
     'filepath': '/diskmnt/Projects/Users/estorrs/mushroom/data/projects/submission_v1/lightsheet/HT427PI-A1_20x/registered/HT427PI-A1-U2_multiplex.ome.tiff'}],
   'position': 15,
   'sid': 'HT427PI-A1-U2'},
  {'data': [{'dtype': 'multiplex',
     'filepath': '/diskmnt/Projects/Users/estorrs/mushroom/data/projects/submission_v1/lightsheet/HT427PI-A1_20x/registered/HT427PI-A1-U3_multiplex.ome.tiff'}],
   'position': 23,
   'sid': '

## define inputs

In [9]:
# Collect the registered metadata file of every 20x run under run_dir.
# The dot in the regex is escaped so it only matches a literal '.yaml' suffix.
fps = sorted(
    fp
    for fp in utils.listfiles(run_dir, regex=r'registered/metadata\.yaml$')
    if '_20x' in fp
)
len(fps), fps

(50,
 ['/data/estorrs/mushroom/data/projects/submission_v1/lightsheet/HT427PI-A1_20x/registered/metadata.yaml',
  '/data/estorrs/mushroom/data/projects/submission_v1/lightsheet/HT427PI-A2_20x/registered/metadata.yaml',
  '/data/estorrs/mushroom/data/projects/submission_v1/lightsheet/HT427PI-A3_20x/registered/metadata.yaml',
  '/data/estorrs/mushroom/data/projects/submission_v1/lightsheet/HT427PI-A4_20x/registered/metadata.yaml',
  '/data/estorrs/mushroom/data/projects/submission_v1/lightsheet/HT442PI-A1_20x/registered/metadata.yaml',
  '/data/estorrs/mushroom/data/projects/submission_v1/lightsheet/HT442PI-A2_20x/registered/metadata.yaml',
  '/data/estorrs/mushroom/data/projects/submission_v1/lightsheet/HT442PI-A3_20x/registered/metadata.yaml',
  '/data/estorrs/mushroom/data/projects/submission_v1/lightsheet/HT442PI-A4_20x/registered/metadata.yaml',
  '/data/estorrs/mushroom/data/projects/submission_v1/lightsheet/HT460P1-A1_20x/registered/metadata.yaml',
  '/data/estorrs/mushroom/data/p

In [10]:
# Display the default config; each per-run config built below starts as a deep
# copy of it and overrides only a few keys.
DEFAULT_CONFIG

{'sections': None,
 'dtype_to_chkpt': None,
 'dtype_specific_params': {'visium': {'trainer_kwargs': {'tiling_method': 'radius'}}},
 'sae_kwargs': {'size': 8,
  'patch_size': 1,
  'encoder_dim': 128,
  'codebook_dim': 64,
  'num_clusters': (8, 4, 2),
  'dtype_to_decoder_dims': {'multiplex': (256, 128, 64),
   'he': (256, 128, 10),
   'visium': (256, 512, 2048),
   'xenium': (256, 256, 256),
   'cosmx': (256, 512, 1024),
   'points': (256, 512, 1024)},
  'recon_scaler': 1.0,
  'neigh_scaler': 0.01},
 'trainer_kwargs': {'input_resolution': 1.0,
  'target_resolution': 0.02,
  'pct_expression': 0.05,
  'log_base': 2.718281828459045,
  'tiling_method': 'grid',
  'tiling_radius': 1.0,
  'batch_size': 128,
  'num_workers': 0,
  'devices': 1,
  'accelerator': 'cpu',
  'max_epochs': 1,
  'steps_per_epoch': 1000,
  'lr': 0.0001,
  'out_dir': './outputs',
  'save_every': 1,
  'log_every_n_steps': 10,
  'logger_type': 'tensorboard',
  'logger_project': 'portobello',
  'channel_mapping': {},
  'data

In [11]:
def alter_filesystem(config, source_root, target_root):
    """Rewrite file paths in a config from one filesystem root to another.

    Every occurrence of ``source_root`` in each section's data ``filepath`` and
    in ``trainer_kwargs['data_mask']`` (when set) is replaced by ``target_root``.

    Parameters
    ----------
    config : dict
        Config holding a ``'sections'`` list (each entry has a ``'data'`` list of
        mappings with a ``'filepath'`` string) and a ``'trainer_kwargs'`` dict.
        ``'sections'`` may be ``None`` or missing, and ``'data_mask'`` may be
        ``None`` or missing; those parts are then left untouched.
    source_root : str
        Path prefix to replace.
    target_root : str
        Path prefix to substitute.

    Returns
    -------
    dict
        The same ``config`` object, modified in place.
    """
    # `or []` tolerates configs whose sections have not been filled in yet.
    for entry in config.get('sections') or []:
        for mapping in entry['data']:
            mapping['filepath'] = mapping['filepath'].replace(source_root, target_root)

    # Look the mask up defensively so a config without the key does not raise.
    trainer_kwargs = config.get('trainer_kwargs') or {}
    data_mask = trainer_kwargs.get('data_mask')
    if data_mask is not None:
        trainer_kwargs['data_mask'] = data_mask.replace(source_root, target_root)

    return config

In [12]:
# Build one training config per (spacing, case) pair.
# For every case the central `total_sections` sections are kept, renumbered from
# 0, and then subsampled by keeping every `spacing`-th section.
case_to_config = {}
spacings = [1, 2, 4, 8]
total_sections = 100
# Training steps per epoch used for each spacing.
spacing_to_steps = {
    1: 4000,
    2: 2000,
    4: 1500,
    8: 1000,
}

for spacing in spacings:
    for fp in fps:
        # Reload for every spacing: the section entries are modified in place
        # below, so each config needs its own freshly parsed copy.
        with open(fp) as f:
            metadata = yaml.safe_load(f)

        sections = metadata['sections']

        # Centre a window of `total_sections` sections on the middle of the stack.
        # Clamp at 0: with fewer than `total_sections` sections the offset would be
        # negative and the slice would silently keep only a tail of the stack.
        start = max(0, (len(sections) // 2) - (total_sections // 2))
        sections = sections[start:start + total_sections]

        for i, entry in enumerate(sections):
            entry['position'] = i # makes reasoning about position in volume easier

        # Keep every `spacing`-th section (indices 0, spacing, 2 * spacing, ...).
        sections = sections[::spacing]

        # The case name is the directory two levels above the metadata file.
        case = fp.split('/')[-3]
        name = f'{case}_keepevery{spacing}'
        steps_per_epoch = spacing_to_steps[spacing]

        config = deepcopy(DEFAULT_CONFIG)
        config = utils.recursive_update(config, {
            'sections': sections,
            'trainer_kwargs': {
                'input_resolution': metadata['resolution'],
                'target_resolution': .01,
                'out_dir': os.path.join(run_dir, case, 'mushroom', name),
                'accelerator': 'gpu',
                'steps_per_epoch': steps_per_epoch,
            }
        })

        # Point the file paths at the local filesystem.
        config = alter_filesystem(config, source_root, target_root)

        case_to_config[name] = config

In [13]:
# List the run names that were generated (one per case and spacing).
case_to_config.keys()

dict_keys(['HT427PI-A1_20x_keepevery1', 'HT427PI-A2_20x_keepevery1', 'HT427PI-A3_20x_keepevery1', 'HT427PI-A4_20x_keepevery1', 'HT442PI-A1_20x_keepevery1', 'HT442PI-A2_20x_keepevery1', 'HT442PI-A3_20x_keepevery1', 'HT442PI-A4_20x_keepevery1', 'HT460P1-A1_20x_keepevery1', 'HT460P1-A2_20x_keepevery1', 'HT460P1-A3_20x_keepevery1', 'HT461B1-A2_20x_keepevery1', 'HT461B1-A3_20x_keepevery1', 'HT462P1-A2_20x_keepevery1', 'HT462P1-A3_20x_keepevery1', 'HT486B1-A1_20x_keepevery1', 'HT491P1-A1_20x_keepevery1', 'HT491P1-A2_20x_keepevery1', 'HT491P1-A4_20x_keepevery1', 'HT495-A2_20x_keepevery1', 'HT495-A3_20x_keepevery1', 'HT495-A4_20x_keepevery1', 'HT497P1-A1_20x_keepevery1', 'HT497P1-A2_20x_keepevery1', 'HT497P1-A3_20x_keepevery1', 'HT502P1-A2_20x_keepevery1', 'HT502P1-A3_20x_keepevery1', 'HT514B1-A2_20x_keepevery1', 'HT514B1-A3_20x_keepevery1', 'HT517B1-A2_20x_keepevery1', 'HT530P1-A1_20x_keepevery1', 'HT530P1-A2_20x_keepevery1', 'HT530P1-A3_20x_keepevery1', 'HT530P1-A4_20x_keepevery1', 'HT535P1-

In [14]:
# Inspect the first (name, config) pair to sanity-check the rewritten paths and
# the renumbered section positions.
next(iter(case_to_config.items()))

('HT427PI-A1_20x_keepevery1',
 {'sections': [{'data': [{'dtype': 'multiplex',
      'filepath': '/data/estorrs/mushroom/data/projects/submission_v1/lightsheet/HT427PI-A1_20x/registered/HT427PI-A1-U289_multiplex.ome.tiff'}],
    'position': 0,
    'sid': 'HT427PI-A1-U289'},
   {'data': [{'dtype': 'multiplex',
      'filepath': '/data/estorrs/mushroom/data/projects/submission_v1/lightsheet/HT427PI-A1_20x/registered/HT427PI-A1-U290_multiplex.ome.tiff'}],
    'position': 1,
    'sid': 'HT427PI-A1-U290'},
   {'data': [{'dtype': 'multiplex',
      'filepath': '/data/estorrs/mushroom/data/projects/submission_v1/lightsheet/HT427PI-A1_20x/registered/HT427PI-A1-U291_multiplex.ome.tiff'}],
    'position': 2,
    'sid': 'HT427PI-A1-U291'},
   {'data': [{'dtype': 'multiplex',
      'filepath': '/data/estorrs/mushroom/data/projects/submission_v1/lightsheet/HT427PI-A1_20x/registered/HT427PI-A1-U292_multiplex.ome.tiff'}],
    'position': 3,
    'sid': 'HT427PI-A1-U292'},
   {'data': [{'dtype': 'multip

In [15]:
# Find the pickled outputs that already exist under the project directory.
# The dot is escaped so only a literal '.pkl' extension matches.
completed_fps = sorted(utils.listfiles('/data/estorrs/mushroom/data/projects/submission_v1/lightsheet/', regex=r'\.pkl$'))
completed_fps

['/data/estorrs/mushroom/data/projects/submission_v1/lightsheet/HT427PI-A1_20x/mushroom/HT427PI-A1_20x_keepevery1/outputs.pkl',
 '/data/estorrs/mushroom/data/projects/submission_v1/lightsheet/HT427PI-A1_20x/mushroom/HT427PI-A1_20x_keepevery2/outputs.pkl',
 '/data/estorrs/mushroom/data/projects/submission_v1/lightsheet/HT427PI-A1_20x/mushroom/HT427PI-A1_20x_keepevery4/outputs.pkl',
 '/data/estorrs/mushroom/data/projects/submission_v1/lightsheet/HT427PI-A2_20x/mushroom/HT427PI-A2_20x_keepevery1/outputs.pkl',
 '/data/estorrs/mushroom/data/projects/submission_v1/lightsheet/HT427PI-A2_20x/mushroom/HT427PI-A2_20x_keepevery2/outputs.pkl',
 '/data/estorrs/mushroom/data/projects/submission_v1/lightsheet/HT427PI-A2_20x/mushroom/HT427PI-A2_20x_keepevery4/outputs.pkl',
 '/data/estorrs/mushroom/data/projects/submission_v1/lightsheet/HT427PI-A3_20x/mushroom/HT427PI-A3_20x_keepevery1/outputs.pkl',
 '/data/estorrs/mushroom/data/projects/submission_v1/lightsheet/HT427PI-A3_20x/mushroom/HT427PI-A3_20x_k

In [16]:
# Run names that the training loop skips. The lookup that would derive them from
# completed_fps is disabled below, so as written no run is treated as completed.
# NOTE(review): the saved output of the next cell omits many keepevery1/2 runs,
# which suggests it was produced with this lookup enabled - confirm which state
# is intended before re-running the training cell.
completed = []
# for fp in completed_fps:
#     case = fp.split('/')[-2]
#     completed.append(case)

In [17]:
# Run names that are configured but not listed in `completed`.
set(case_to_config.keys()) - set(completed)

{'HT427PI-A1_20x_keepevery8',
 'HT427PI-A2_20x_keepevery8',
 'HT427PI-A3_20x_keepevery8',
 'HT427PI-A4_20x_keepevery8',
 'HT442PI-A1_20x_keepevery8',
 'HT442PI-A2_20x_keepevery8',
 'HT442PI-A3_20x_keepevery8',
 'HT442PI-A4_20x_keepevery4',
 'HT442PI-A4_20x_keepevery8',
 'HT460P1-A1_20x_keepevery4',
 'HT460P1-A1_20x_keepevery8',
 'HT460P1-A2_20x_keepevery4',
 'HT460P1-A2_20x_keepevery8',
 'HT460P1-A3_20x_keepevery4',
 'HT460P1-A3_20x_keepevery8',
 'HT461B1-A2_20x_keepevery4',
 'HT461B1-A2_20x_keepevery8',
 'HT461B1-A3_20x_keepevery4',
 'HT461B1-A3_20x_keepevery8',
 'HT462P1-A2_20x_keepevery4',
 'HT462P1-A2_20x_keepevery8',
 'HT462P1-A3_20x_keepevery4',
 'HT462P1-A3_20x_keepevery8',
 'HT486B1-A1_20x_keepevery4',
 'HT486B1-A1_20x_keepevery8',
 'HT491P1-A1_20x_keepevery4',
 'HT491P1-A1_20x_keepevery8',
 'HT491P1-A2_20x_keepevery4',
 'HT491P1-A2_20x_keepevery8',
 'HT491P1-A4_20x_keepevery4',
 'HT491P1-A4_20x_keepevery8',
 'HT495-A2_20x_keepevery4',
 'HT495-A2_20x_keepevery8',
 'HT495-A3_20x

In [None]:
%%time
# Train, embed and export interpolated volumes for every run not in `completed`.
for case, config in case_to_config.items():
    if case in completed:
        continue

    print(case)
    model = Mushroom.from_config(config)
    model.train()
    model.embed_sections()

    model.save()

    # Write one multiplex volume per level into the run's output directory.
    z_scaler = 1.
    for level in range(3):
        volumes = model.generate_interpolated_volumes(z_scaler=z_scaler, level=level, integrate=False)
        volume_fp = os.path.join(model.trainer_kwargs['out_dir'], f'volume_l{level}.npy')
        np.save(volume_fp, volumes['multiplex'])

    # remove chkpts because we wont need them and they take up space
    for chkpt_fp in model.dtype_to_chkpt.values():
        os.remove(chkpt_fp)

    # Drop the reference before the next run is created.
    del model
    

In [20]:
# chkpt_fps = sorted(utils.listfiles('/data/estorrs/mushroom/data/projects/submission_v1/lightsheet/', regex=r'pt$'))
# chkpt_fps

In [21]:
# for fp in chkpt_fps:
#     os.remove(fp)