In [4]:
%matplotlib inline
from matplotlib import pyplot as plt

In [5]:
import datajoint as dj
import numpy as np
import os
from stimulus import stimulus
from pipeline import fuse

Connecting dimitri@at-database.ad.bcm.edu:3306
Loading local settings from pipeline_config.json


In [6]:
import tqdm

In [7]:
import monet_trippy as mt

In [8]:
# sessions that have both Monet and Trippy from a few recent experiments
sessions = (fuse.Activity * stimulus.Sync & 'animal_id in (20505, 20322, 20457, 20210, 20892)'
            & (stimulus.Trial * stimulus.Monet2) & (stimulus.Trial * stimulus.Trippy)).fetch('KEY')
key = sessions[2]   # pick one

In [9]:
# load frame_times
pipe = (fuse.Activity() & key).module
num_frames = (pipe.ScanInfo() & key).fetch1('nframes')
num_depths = len(dj.U('z') & (pipe.ScanInfo.Field().proj('z', nomatch='field') & key))
frame_times = (stimulus.Sync() & key).fetch1('frame_times', squeeze=True) # one per depth
assert num_frames <= frame_times.size / num_depths <= num_frames + 1
frame_times = frame_times[:num_depths * num_frames:num_depths]  # one per volume

# load and cache soma traces
trace_hash = dj.hash.key_hash({k: v for k, v in key.items() if k not in {'stimulus_type'}})
archive = os.path.join('cache', trace_hash + '-traces.npz')
if os.path.isfile(archive):
    data = np.load(archive)
    trace_keys = data['trace_keys']
    traces = data['traces']
    ms_delay = data['ms_delay']
else:
    units = pipe.ScanSet.Unit * pipe.MaskClassification.Type & {'type': 'soma'}
    spikes = pipe.Activity.Trace * pipe.ScanSet.UnitInfo & units & key
    trace_keys, traces, ms_delay = spikes.fetch('KEY', 'trace', 'ms_delay')
    np.savez_compressed(archive, trace_keys=trace_keys, traces=traces, ms_delay=ms_delay)
frame_times = np.add.outer(ms_delay / 1000, frame_times)  # num_traces x num_frames

In [13]:
# create a trippy session and load trials
trippy_session = mt.VisualSession(np.stack(traces), frame_times)
for trial in (stimulus.Trial * stimulus.Condition * stimulus.Trippy & key).proj(..., '- movie'):
    trippy_session.add_trial(mt.Trippy.from_condition(trial), trial['flip_times'].flatten())

In [None]:
t.packed_phase_movie.shape

In [None]:
# Smoothe the traces
cutoff_freq = 4.0
sampling_freq = 1 / np.median(np.diff(frame_times))  # Hz
if sampling_freq > cutoff_freq:
    h = np.hamming(2 * int(sampling_freq / cutoff_freq) + 1)
    traces = [np.convolve(tr, h / sum(h), mode='same') for tr in traces]

# Create splines (scipy.interpolate.InterpolatedUnivariateSpline)
trace_spline = SplineCurve(frame_times, traces, k=1, ext='zeros')
ftmin, ftmax = frame_times.min(), frame_times.max()

In [21]:
os.mkdir('one')

FileExistsError: [Errno 17] File exists: 'one'

In [None]:
num_lags = 5
bin_size = 0.1
vmax = 0.4

# Iterate over every trial
total_duration = 0
trace_mean = np.zeros(len(trace_keys))
trace_meansq = np.zeros(len(trace_keys))
maps = 0 # num_traces x height x width x num_lags
movie_mean = 0 # 1 x height x width x num_lags
movie_meansq = 0 # 1 x height x width x num_lags
trial_rel = stimulus.Trial() * stimulus.Condition() & key

In [None]:
def compute_sta(traces, movie, num_lags):
    """ Spike-triggered average at diff lags."""
    num_timepoints = movie.shape[-1] - (num_lags - 1) # length of movie minus lag time
    weighted_sums = [np.tensordot(traces[..., lag:lag + num_timepoints], movie[..., :num_timepoints],
                                  axes=(-1, -1)) for lag in range(num_lags)]
    stas = np.stack(weighted_sums, -1) / num_timepoints # num_traces x height x width x num_lags
    return stas

In [None]:
condition_set = {
    'stimulus.Monet': stimulus.Monet(),
    'stimulus.Monet2': stimulus.Monet2(),
    'stimulus.Trippy': stimulus.Trippy().proj('condition_hash', 'fps', 'rng_seed', 'packed_phase_movie', 'tex_ydim',
 'tex_xdim', 'duration', 'xnodes', 'ynodes', 'up_factor', 'temp_freq', 'temp_kernel_length', 'spatial_freq'),
    'stimulus.Varma': stimulus.Varma()}[key['stimulus_type']]

In [None]:
# collect conditions
for k in tqdm.tqdm((condition_set & (trial_rel & key)).fetch('KEY')):
    p = os.path.join('cache', k['condition_hash'].replace('/','_') + '_' + key['stimulus_type'] + '.npz')
    if not os.path.isfile(p):
        np.savez_compressed(p, (condition_set & k).fetch1())

In [None]:
for trial_key, flip_times in tqdm.tqdm(zip(*trial_rel.fetch('KEY', 'flip_times', squeeze=True))):
    # Get sampling points at bin_size resolution
    sample_secs = np.arange(max(flip_times[0], ftmin), min(flip_times[-1], ftmax),
                             bin_size)

    # If trial is long enough (>= 3.5 secs)
    if (len(sample_secs) - 1) * bin_size >= 3.5:
        # Compute trial duration (after subtracting the lag)
        duration = sample_secs[-1] - sample_secs[0] - (num_lags - 1) * bin_size
        total_duration += duration

        # Interpolate movie
        movie = (condition_set & trial_key).fetch1('movie')
        movie = movie.astype('float32') / 127.5 - 1  # -1 to 1
        if movie.ndim == 4:  # ignore color in green/blue monet
            movie = movie.sum(axis=2) / np.sqrt(2)
        movie = interp1d(flip_times, movie)(sample_secs)

        # Interpolate traces
        snippets = trace_spline(sample_secs)

        # Compute trace statistics
        trace_mean += snippets.mean(axis=1) * duration
        trace_meansq += (snippets ** 2).mean(axis=1) * duration

        # Compute STA maps
        maps += compute_sta(snippets, movie, num_lags) * duration

        # Compute movie statistics
        ones = np.ones([1, len(sample_secs)])
        movie_mean += compute_sta(ones, movie, num_lags) * duration
        movie_meansq += compute_sta(ones, movie ** 2, num_lags) * duration

In [None]:
fir, ax = plt.subplots(16, 12, figsize=(18, 18))

for i, a in enumerate(ax.flatten()):
    a.imshow(maps[i,:,:,1]/total_duration, vmin=-.2, vmax=.2, cmap='gray')
    a.set_axis_off()
    a.set_title(str(i))
    
    

In [2]:
key

NameError: name 'key' is not defined

In [None]:
k = dict(keys[1])

In [None]:
k.pop('stimulus_type')

# Trippy Tune

In [None]:
import trippytune

In [None]:
cond_key = condition_set.head(limit=1, as_dict=True)[0]

In [None]:
cond_key

In [None]:
trippy = trippytune.Trippy(
    **{k: v for k, v in cond_key.items() if k in {
        'fps', 'rng_seed', 'packed_phase_movie', 'up_factor', 'temp_freq', 
        'temp_kernel_length', 'duration', 'spatial_freq'}},
    tex_size=(cond_key['tex_xdim'], cond_key['tex_ydim']),
    nodes=(cond_key['xnodes'], cond_key['ynodes']))

In [None]:
img = trippy.compute_phase_movie()