In [None]:
%matplotlib inline

import csv
import itertools
import os
import pprint
from xml.etree import ElementTree

import h5py
import numpy as np
import pandas as pd
import skimage.io

import matplotlib.pyplot as plt

In [None]:
base_dir = '/media/ssd/slmStimPipeline'

recording = 0

# recording id -> (session_name, recording_name, recording_name2)
_known_recordings = {
    0: ('20200124M74', 'stimL23-000', 'stim3dL23withBeam-001'),
    1: ('20200201M79', 'VRmm-000', 'VRmm-000'),
    2: ('20200202M79', 'stimL23-000', 'stimL23-001'),
}

if recording not in _known_recordings:
    raise ValueError('Bad recording %s' % recording)
session_name, recording_name, recording_name2 = _known_recordings[recording]

# Common prefix shared by every file of the selected recording.
base = os.path.join(base_dir, session_name, recording_name, recording_name2)
print(base)

In [None]:
def metadata(base):
    """Collect acquisition metadata from the XML files that belong to a recording.

    Two files are parsed: ``<base>.xml`` (scan description) and
    ``<base>_Cycle00001_VoltageRecording_001.xml`` (voltage recording setup).

    Parameters
    ----------
    base : str
        Path prefix of the recording, without any extension.

    Returns
    -------
    dict
        ``'size'``: ``frames`` (number of ``Sequence`` elements under the root),
        ``channels`` (``File`` elements in the first ``Sequence/Frame``),
        ``z_planes`` (``Frame`` elements in the first ``Sequence``), ``y_px``
        (``linesPerFrame``) and ``x_px`` (``pixelsPerLine``);
        ``'laser'``: ``power`` and ``wavelength`` read at index 0;
        ``'period'``: ``framePeriod`` as float;
        ``'optical_zoom'``: ``opticalZoom`` as float;
        ``'channels'``: mapping of voltage signal name to channel number.

    Raises
    ------
    KeyError
        If a requested ``PVStateValue`` entry is absent from the scan file.
    """
    fname_scan_xml = base + '.xml'
    # ElementTree.parse opens and closes the file itself and honours the
    # encoding declared in the XML header.
    mdata_root = ElementTree.parse(fname_scan_xml).getroot()

    def _value_at(xpath, type_fn):
        # Shared lookup so that a missing entry is reported with its path.
        element = mdata_root.find(xpath)
        if element is None:
            raise KeyError(f'No element matching {xpath!r} in {fname_scan_xml}')
        return type_fn(element.attrib['value'])

    def state_value(key, type_fn=str):
        return _value_at(f'.//PVStateValue[@key="{key}"]', type_fn)

    def indexed_value(key, index, type_fn=str):
        return _value_at(
            f'.//PVStateValue[@key="{key}"]/IndexedValue[@index="{index}"]', type_fn)

    num_frames = len(mdata_root.findall('Sequence'))
    num_channels = len(mdata_root.find('Sequence/Frame').findall('File'))
    num_z_planes = len(mdata_root.find('Sequence').findall('Frame'))
    num_y_px = state_value('linesPerFrame', int)
    num_x_px = state_value('pixelsPerLine', int)

    laser_power = indexed_value('laserPower', 0, float)
    laser_wavelength = indexed_value('laserWavelength', 0, int)

    frame_period = state_value('framePeriod', float)
    optical_zoom = state_value('opticalZoom', float)

    fname_voltage_xml = base + '_Cycle00001_VoltageRecording_001.xml'
    voltage_root = ElementTree.parse(fname_voltage_xml).getroot()

    # Signal name -> channel number, as listed in the voltage recording setup.
    channels = {}
    for signal in voltage_root.findall('Experiment/SignalList/VRecSignal'):
        channel = int(signal.find('Channel').text)
        name = signal.find('Name').text
        channels[name] = channel

    return {
        'size': {'frames': num_frames,
                 'channels': num_channels,
                 'z_planes': num_z_planes,
                 'y_px': num_y_px,
                 'x_px': num_x_px},
        'laser': {'power': laser_power, 'wavelength': laser_wavelength},
        'period': frame_period,
        'optical_zoom': optical_zoom,
        'channels': channels,
    }

# Parse the metadata of the selected recording and pretty-print it for inspection.
md_dct = metadata(base)
pprint.pprint(md_dct)

In [None]:
# Load the voltage recording CSV that sits next to the scan, indexed by its
# 'Time(ms)' column. skipinitialspace=True drops the spaces that follow each delimiter.
fname_voltage_csv = base + '_Cycle00001_VoltageRecording_001.csv'
df = pd.read_csv(fname_voltage_csv, index_col='Time(ms)', skipinitialspace=True)

In [None]:
# Acquisition geometry taken from the parsed metadata.
size = md_dct['size']
y_px = size['y_px']
shape = (size['frames'], size['z_planes'])

# Tunables used further down: line shift and extra margin applied to each
# artefact band, and the channel number that appears in the image file names.
shift = 3
buffer = 5
chn = 3  # Get this empirically

# Timestamps at which the frame trigger jumps up by more than 2.5 between samples.
frame = df['frame starts']
frame_rising = frame.diff() > 2.5
frame_start = frame[frame_rising].index

# Same for the stimulator signal: jumps up mark onsets, jumps down mark offsets.
stim = df['FieldStimulator']
stim_step = stim.diff()
stim_start = stim[stim_step > 2.5].index
stim_stop = stim[stim_step < -2.5].index

In [None]:
def get_loc(times, frame_times=None, grid_shape=None, lines_per_frame=None):
    """Locate timestamps within the sequence of acquired frames.

    Each time is interpolated onto the running frame counter; the integer part
    selects the frame and the fractional part is scaled to a line offset.

    Parameters
    ----------
    times : array-like
        Timestamps to locate, in the same units as ``frame_times``.
    frame_times : array-like, optional
        Increasing start time of every frame. Defaults to the notebook-level
        ``frame_start``.
    grid_shape : tuple of int, optional
        Shape used to unravel the flat frame counter. Defaults to the
        notebook-level ``shape``.
    lines_per_frame : int, optional
        Number of lines in one frame. Defaults to the notebook-level ``y_px``.

    Returns
    -------
    locations : numpy.ndarray
        One row per timestamp holding the unravelled frame index.
    y_offset : numpy.ndarray
        Fractional position inside the frame, scaled by ``lines_per_frame``.
    """
    if frame_times is None:
        frame_times = frame_start
    if grid_shape is None:
        grid_shape = shape
    if lines_per_frame is None:
        lines_per_frame = y_px

    interp = np.interp(times, frame_times, np.arange(len(frame_times)))
    # The builtin int replaces np.int, an alias that NumPy 1.24 removed.
    indices = interp.astype(int)
    y_offset = lines_per_frame * (interp - indices)
    return np.transpose(np.unravel_index(indices, grid_shape)), y_offset

# Frame index and in-frame line offset of every stimulation onset and offset.
ix_start, y_off_start = get_loc(stim_start)
ix_stop, y_off_stop = get_loc(stim_stop)

In [None]:
# Build one table row per (frame, z plane) touched by a stimulation interval,
# giving the first and last affected line offsets inside that frame.
index = 0  # not read by any cell visible in this notebook
frame = []
z_plane = []
y_px_start = []
y_px_stop = []
for (start_cyc, start_z), (stop_cyc, stop_z), y_min, y_max in zip(ix_start, ix_stop, y_off_start, y_off_stop):
    if (start_cyc == stop_cyc) and (start_z == stop_z):
        # Onset and offset fall into the same frame: a single band.
        spans = [(start_cyc, start_z, y_min, y_max)]
    else:
        # The interval crosses a frame boundary: one band to the end of the
        # first frame and one from the top of the frame holding the offset.
        # NOTE(review): frames lying strictly between the two are not
        # recorded -- confirm an interval never covers more than two frames.
        spans = [(start_cyc, start_z, y_min, y_px),
                 (stop_cyc, stop_z, 0, y_max)]
    for span_cyc, span_z, span_lo, span_hi in spans:
        frame.append(span_cyc)
        z_plane.append(span_z)
        y_px_start.append(span_lo)
        y_px_stop.append(span_hi)

# NOTE: this rebinds `df`, which until here held the voltage recording.
df = pd.DataFrame({'frame': frame,
                   'z_plane': z_plane,
                   'y_min': y_px_start,
                   'y_max': y_px_stop})

# Create the output directory first so that a fresh run does not fail on the write.
artefact_path = '/media/ssd/data/slm_test/artefact.h5'
os.makedirs(os.path.dirname(artefact_path), exist_ok=True)
df.to_hdf(artefact_path, key='data')

In [None]:
%%time

def read(frame, chn, z):
    """Load the image for one cycle, channel and plane of the recording.

    ``frame`` and ``z`` are zero-based; the file name carries them incremented
    by one (cycle padded to five digits, plane to six).
    """
    return skimage.io.imread(base + f'_Cycle{frame+1:05d}_Ch{chn}_{z+1:06d}.ome.tif')

# Shape of the whole movie: (frames, z planes, lines, pixels per line).
# It gets its own name because `shape` is the 2-tuple that `get_loc` unravels
# with; rebinding it here would break re-running the localisation cell.
data_shape = (size['frames'], size['z_planes'], size['y_px'], size['x_px'])
dtype = read(0, chn, 0).dtype  # take the pixel type from the first image
print((data_shape, dtype))
data = np.zeros(data_shape, dtype)

# Dedicated loop names keep the notebook-level `frame` and `z_plane` untouched.
for i_frame in range(size['frames']):
    for i_plane in range(size['z_planes']):
        data[i_frame, i_plane] = read(i_frame, chn, i_plane)

In [None]:
# Make sure the output directory for the HDF5 files exists (no error if it already does).
os.makedirs('/media/ssd/data/slm_test', exist_ok=True)

In [None]:
%%time
# Store the movie held in `data` as dataset 'data' of orig.h5 ('w' replaces an
# existing file). This runs before `data` is edited in place further down.
with h5py.File('/media/ssd/data/slm_test/orig.h5', 'w') as f_orig:
    f_orig.create_dataset('data', data=data)

In [None]:
# Replace every artefact band by the average of the same lines in the
# temporally neighbouring frames of the same z plane. `data` is edited in place.
n_frames = data.shape[0]
for row in df.itertuples():
    # The band is moved down by `shift` lines and widened by `buffer` lines.
    y_slice = slice(int(row.y_min) + shift,
                    int(row.y_max) + shift + buffer + 1)
    # Use only neighbours that exist: index -1 would silently wrap around to
    # the last frame, and one past the end raises IndexError.
    neighbours = [ix for ix in (row.frame - 1, row.frame + 1) if 0 <= ix < n_frames]
    if not neighbours:
        continue  # single-frame movie: nothing to interpolate from
    # Sum in float64 so that narrow integer pixel types cannot wrap around.
    total = sum(data[ix, row.z_plane, y_slice].astype(np.float64) for ix in neighbours)
    data[row.frame, row.z_plane, y_slice] = total / len(neighbours)

In [None]:
%%time
# Store the current contents of `data` as dataset 'data' of corrected.h5
# ('w' replaces an existing file).
with h5py.File('/media/ssd/data/slm_test/corrected.h5', 'w') as f_corr:
    f_corr.create_dataset('data', data=data)

In [None]:
# Reopen the saved original read-only. `data_orig` is an on-disk dataset, not an
# in-memory array, and the file handle is left open so it can still be indexed below.
data_orig = h5py.File('/media/ssd/data/slm_test/orig.h5' , 'r')['data']

In [None]:
# Same frame and plane side by side: panel 0 from `data_orig`, panel 1 from `data`.
_, axes = plt.subplots(ncols=2, figsize=(10, 5))
for ax, movie in zip(axes, (data_orig, data)):
    ax.imshow(movie[114, 2], vmin=0, vmax=500)