In [None]:
import json
import os
import sys
from copy import deepcopy
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import skimage
import tifffile
import yaml

In [None]:
# Enable IPython's autoreload extension; the reload mode is set in the next cell.
%load_ext autoreload

In [None]:
# Mode 2: reload modules before executing each cell, so edits to the local
# `mushroom` package take effect without restarting the kernel.
%autoreload 2

In [None]:
from mushroom.mushroom import Mushroom, DEFAULT_CONFIG
import mushroom.data.he as he
import mushroom.data.multiplex as multiplex
import mushroom.utils as utils
import mushroom.visualization.utils as vis_utils

In [None]:
# Filesystem prefixes: paths under `source_root` in the generated configs are
# rewritten to live under `target_root` by `alter_filesystem`.
source_root, target_root = (
    '/diskmnt/Projects/Users/estorrs/mushroom/data',
    '/data/estorrs/mushroom/data',
)

In [None]:
# Project directory holding one subdirectory per lightsheet case
# (each with `registered/metadata.yaml` and, after training, `mushroom/<name>` outputs).
run_dir = str(Path('/data/estorrs/mushroom/data') / 'projects' / 'submission_v1' / 'lightsheet')

In [None]:
# Peek at one case's registered metadata to see the schema the config cells rely on
# (its `sections` list and `resolution`). Built from `run_dir` rather than a second
# hard-coded copy of the path, and read via `with` so the file handle is closed.
example_metadata_fp = os.path.join(run_dir, 'HT427PI-A1_20x', 'registered', 'metadata.yaml')
with open(example_metadata_fp) as f:
    example_metadata = yaml.safe_load(f)
example_metadata

## define inputs

In [None]:
# Registered metadata for every 20x case under `run_dir`; each file yields one
# config per spacing in the config-building cell.
# The dot is escaped so the regex matches a literal `metadata.yaml` only.
fps = sorted(utils.listfiles(run_dir, regex=r'registered/metadata\.yaml$'))
fps = [fp for fp in fps if '_20x' in fp]
# Fail loudly here instead of silently building zero configs downstream.
assert fps, f'no registered 20x metadata.yaml found under {run_dir}'
len(fps), fps

In [None]:
# Show the library defaults that every per-case config below starts from
# (deep-copied, then overridden via `utils.recursive_update`).
DEFAULT_CONFIG

In [None]:
def alter_filesystem(config, source_root, target_root):
    """Rewrite data paths in a Mushroom config from `source_root` to `target_root`.

    Updates, in place, every `filepath` under `config['sections'][*]['data']`
    and `config['trainer_kwargs']['data_mask']` when it is set. Only a leading
    `source_root` prefix is rewritten; an occurrence of `source_root` elsewhere
    in a path is left untouched. A config with no `trainer_kwargs` or no
    `data_mask` key is accepted.

    Args:
        config: Mushroom config dict; mutated in place.
        source_root: path prefix to replace.
        target_root: replacement prefix.

    Returns:
        The same `config` object, for convenient chaining.
    """
    def _reroot(path):
        # Swap only a leading prefix; `str.replace` would also rewrite mid-path matches.
        if path.startswith(source_root):
            return target_root + path[len(source_root):]
        return path

    for entry in config['sections']:
        for mapping in entry['data']:
            mapping['filepath'] = _reroot(mapping['filepath'])

    trainer_kwargs = config.get('trainer_kwargs', {})
    if trainer_kwargs.get('data_mask') is not None:
        trainer_kwargs['data_mask'] = _reroot(trainer_kwargs['data_mask'])

    return config

In [None]:
def select_centered_sections(sections, total_sections, spacing):
    """Return every `spacing`-th section from a centered window of `total_sections` sections.

    Before subsampling, each section in the window gets `entry['position'] = i`
    (its index within the window, written in place on the section dicts), so the
    kept sections remember where they sat in the window.

    When there are fewer than `total_sections` sections, the window is the whole
    list rather than a slice starting at a negative (wrapped-around) index.
    """
    start = max(0, len(sections) // 2 - total_sections // 2)
    window = sections[start:start + total_sections]
    for i, entry in enumerate(window):
        entry['position'] = i  # makes reasoning about position in volume easier
    return window[::spacing]


# One config per (spacing, case), keyed '<case>_keepevery<spacing>'.
# Spacing is the outer loop, which fixes the iteration order of the training cell.
case_to_config = {}
spacings = [1, 2, 4, 8]
total_sections = 100

for spacing in spacings:
    for fp in fps:
        # Re-read per spacing: section selection and `alter_filesystem` mutate the
        # section dicts, so each config needs its own fresh copy.
        with open(fp) as f:
            metadata = yaml.safe_load(f)

        metadata['sections'] = select_centered_sections(metadata['sections'], total_sections, spacing)

        case = fp.split('/')[-3]  # .../<case>/registered/metadata.yaml
        name = f'{case}_keepevery{spacing}'

        config = deepcopy(DEFAULT_CONFIG)
        config = utils.recursive_update(config, {
            'sections': metadata['sections'],
            'trainer_kwargs': {
                'input_resolution': metadata['resolution'],
                'target_resolution': .02,
                'out_dir': os.path.join(run_dir, case, 'mushroom', name),
                'accelerator': 'gpu',
                'steps_per_epoch': 1000,
            }
        })

        config['dtype_specific_params'] = utils.recursive_update(config['dtype_specific_params'], {
            'visium': {
                'trainer_kwargs': {
                    'target_resolution': .01  # lower resolution for visium
                }
            }
        })

        config = alter_filesystem(config, source_root, target_root)

        case_to_config[name] = config

In [None]:
# Config names, one per (case, spacing) pair, formatted '<case>_keepevery<spacing>'.
case_to_config.keys()

In [None]:
# Inspect the first (name, config) pair as a sanity check before training.
next(iter(case_to_config.items()))

In [None]:
%%time
for case, config in case_to_config.items():
    print(case)
    mushroom = Mushroom.from_config(config)
    mushroom.train()
    mushroom.embed_sections()
    
    mushroom.save()
    z_scaler = 1.
    for level in range(3):
        dtype_to_volume = mushroom.generate_interpolated_volumes(z_scaler=z_scaler, level=level, integrate=False)
        volume = dtype_to_volume['multiplex']
        np.save(os.path.join(mushroom.trainer_kwargs['out_dir'], f'volume_l{level}.npy'), volume)
    
    del(mushroom)
    