In [4]:
from pathlib import Path

import numpy as np
import tensorflow as tf
from neuralDecoder.datasets.handwritingDataset import CHAR_DEF

from nwb_utils import *

BIN_SIZE_S = 0.02

In [7]:
def convertToTFRecord(inputFeats, transcriptions, outputDir):
    nClasses = 31
    maxSeqLen = 500

    def _charToId(char):
        return CHAR_DEF.index(char)

    def _convert_to_ascii(text):
        return [ord(char) for char in text]

    def _floats_feature(value):
        return tf.train.Feature(float_list=tf.train.FloatList(value=value))

    def _ints_feature(value):
        return tf.train.Feature(int64_list=tf.train.Int64List(value=value))

    saveDir = Path(outputDir)
    saveDir.mkdir(parents=True, exist_ok=True)

    with tf.io.TFRecordWriter(str(saveDir.joinpath("chunk_0.tfrecord"))) as writer:
        for trialIdx in np.arange(len(inputFeats)):
            feats = inputFeats[trialIdx]

            classLabels = np.zeros([feats.shape[0], nClasses]).astype(np.float32)
            newClassSignal = np.zeros([feats.shape[0], 1]).astype(np.float32)
            seqClassIDs = np.zeros([maxSeqLen]).astype(np.int32)

            thisTranscription = transcriptions[trialIdx]

            seqLen = len(thisTranscription)
            seqClassIDs[0:seqLen] = [_charToId(c) + 1 for c in thisTranscription]

            ceMask = np.zeros([feats.shape[0]]).astype(np.float32)

            paddedTranscription = np.zeros([maxSeqLen]).astype(np.int32)
            paddedTranscription[0 : len(thisTranscription)] = np.array(
                _convert_to_ascii(thisTranscription)
            )

            feature = {
                "inputFeatures": _floats_feature(np.ravel(feats).tolist()),
                "classLabelsOneHot": _floats_feature(np.ravel(classLabels).tolist()),
                "newClassSignal": _floats_feature(np.ravel(newClassSignal).tolist()),
                "seqClassIDs": _ints_feature(seqClassIDs),
                "nTimeSteps": _ints_feature([feats.shape[0]]),
                "nSeqElements": _ints_feature([seqLen]),
                "ceMask": _floats_feature(np.ravel(ceMask).tolist()),
                "transcription": _ints_feature(paddedTranscription),
            }

            if trialIdx % 10 == 0:
                print(thisTranscription)
                print(seqClassIDs[0:10])
            example = tf.train.Example(features=tf.train.Features(feature=feature))
            writer.write(example.SerializeToString())

In [8]:
partitions = Path("nwb").glob("*")
for part in partitions:
    print(part)
    for fn in part.glob("*.nwb"):
        spikes, time, tinfo, eval_mask, _ = load_nwb(fn)
        print(f"Number of trials: {tinfo.shape[0]}")
    print()

nwb/held_out_calib
Loading nwb/held_out_calib/T5_2023.06.28_held_out_calib.nwb
Number of trials: 3
Loading nwb/held_out_calib/T5_2023.08.16_held_out_calib.nwb
Number of trials: 3
Loading nwb/held_out_calib/T5_2023.04.17_held_out_calib.nwb
Number of trials: 3
Loading nwb/held_out_calib/T5_2023.10.09_held_out_calib.nwb
Number of trials: 3
Loading nwb/held_out_calib/T5_2023.05.31_held_out_calib.nwb
Number of trials: 3

nwb/held_out_eval
Loading nwb/held_out_eval/T5_2023.06.28_held_out_eval.nwb
Number of trials: 16
Loading nwb/held_out_eval/T5_2023.05.31_held_out_eval.nwb
Number of trials: 16
Loading nwb/held_out_eval/T5_2023.08.16_held_out_eval.nwb
Number of trials: 16
Loading nwb/held_out_eval/T5_2023.04.17_held_out_eval.nwb
Number of trials: 16
Loading nwb/held_out_eval/T5_2023.10.09_held_out_eval.nwb
Number of trials: 16

nwb/held_in_eval
Loading nwb/held_in_eval/T5_2022.06.03_held_in_eval.nwb
Number of trials: 32
Loading nwb/held_in_eval/T5_2022.06.01_held_in_eval.nwb
Number of trials

# Convert seed model training data

In [14]:
train_nwbs = sorted(Path("nwb/held_in_calib").glob("*.nwb"))
test_nwbs = sorted(Path("nwb/held_in_eval").glob("*.nwb"))

for train_nwb, test_nwb in zip(train_nwbs, test_nwbs):
    train_spikes, train_time, train_tinfo, _, session_date = load_nwb(train_nwb)
    test_spikes, test_time, test_tinfo, _, _ = load_nwb(test_nwb)
    print(
        f"Number of train trials: {train_tinfo.shape[0]}, test trials: {test_tinfo.shape[0]}, val trials: {val_tinfo.shape[0]}"
    )

    # Extract trials
    train_feats, train_cues, train_block_stats = extract_trials(
        train_spikes, train_time, train_tinfo
    )
    test_feats, test_cues, _ = extract_trials(test_spikes, test_time, test_tinfo)

    cpm = np.sum([len(c) for c in train_cues]) / (
        np.sum([f.shape[0] for f in train_feats]) * BIN_SIZE_S / 60
    )
    print(f"Characters per minute: {cpm}")

    # Convert to tfrecords
    convertToTFRecord(
        train_feats[:-2],
        train_cues[:-2],
        f"data/held_in_tfrecords/{session_date}/train",
    )
    convertToTFRecord(
        test_feats,
        test_cues,
        f"data/held_in_tfrecords/{session_date}/eval",
    )
    convertToTFRecord(
        train_feats[-2:],
        train_cues[-2:],
        f"data/held_in_tfrecords/{session_date}/test",
    )

Loading nwb/held_in_calib/T5_2022.05.18_held_in_calib.nwb
Loading nwb/held_in_eval/T5_2022.05.18_held_in_eval.nwb
Number of train trials: 30, test trials: 20, val trials: 2
Characters per minute: 42.74138775252326
there>just>might>be>too>many>people>there~
[20  8  5 18  5 27 10 21 19 20]
three>weeks
[20  8 18  5  5 27 23  5  5 11]
why>did>i>have>to>fall>for>a>suicidal>maniac?
[23  8 25 27  4  9  4 27  9 27]
is>the>baby>coming>down>with>something>nasty?
[ 9 19 27 20  8  5 27  2  1  2]
she>was>a>terrible>woman~
[19  8  5 27 23  1 19 27  1 27]
hot>weather
[ 8 15 20 27 23  5  1 20  8  5]
Loading nwb/held_in_calib/T5_2022.05.23_held_in_calib.nwb
Loading nwb/held_in_eval/T5_2022.05.23_held_in_eval.nwb
Number of train trials: 48, test trials: 32, val trials: 2
Characters per minute: 34.80892222294418
stay>home
[19 20  1 25 27  8 15 13  5  0]
it>was>quite>a>sight~
[ 9 20 27 23  1 19 27 17 21  9]
i>know>you>have>ambitions>to>enter>politics~
[ 9 27 11 14 15 23 27 25 15 21]
last>month
[12  1 19 2

# Convert held-out data

In [13]:
train_nwbs = sorted(Path("nwb/held_out_calib").glob("*.nwb"))
test_nwbs = sorted(Path("nwb/held_out_eval").glob("*.nwb"))
oracle_nwbs = sorted(Path("nwb/held_out_oracle").glob("*.nwb"))

for train_nwb, test_nwb, oracle_nwb in zip(train_nwbs, test_nwbs, oracle_nwbs):
    train_spikes, train_time, train_tinfo, _, session_date = load_nwb(train_nwb)
    test_spikes, test_time, test_tinfo, _, _ = load_nwb(test_nwb)
    oracle_spikes, oracle_time, oracle_tinfo, _, _ = load_nwb(oracle_nwb)
    print(
        f"Number of train trials: {train_tinfo.shape[0]}, test trials: {test_tinfo.shape[0]}, oracle trials: {oracle_tinfo.shape[0]}"
    )

    # Extract trials
    train_feats, train_cues, train_block_stats = extract_trials(
        train_spikes, train_time, train_tinfo
    )
    test_feats, test_cues, _ = extract_trials(
        test_spikes, test_time, test_tinfo
    )
    oracle_feats, oracle_cues, _ = extract_trials(
        oracle_spikes, oracle_time, oracle_tinfo
    )

    # Convert to tfrecords
    # For notebook demo
    convertToTFRecord(
        train_feats,
        train_cues,
        f"data/held_out_demo_tfrecords/{session_date}/test",
    )

    # For few-shot learning
    convertToTFRecord(
        train_feats,
        train_cues,
        f"data/held_out_tfrecords/{session_date}/train",
    )
    convertToTFRecord(
        test_feats,
        test_cues,
        f"data/held_out_tfrecords/{session_date}/test",
    )

    # For oracle eval
    convertToTFRecord(
        oracle_feats,
        oracle_cues,
        f"data/held_out_oracle_tfrecords/{session_date}/train",
    )
    convertToTFRecord(
        test_feats,
        test_cues,
        f"data/held_out_oracle_tfrecords/{session_date}/test",
    )

Loading nwb/held_out_calib/T5_2023.04.17_held_out_calib.nwb
Loading nwb/held_out_eval/T5_2023.04.17_held_out_eval.nwb
Loading nwb/held_out_oracle/T5_2023.04.17_held_out_oracle.nwb
Number of train trials: 3, test trials: 16, oracle trials: 24
what>about>transfers?
[23  8  1 20 27  1  2 15 21 20]
what>about>transfers?
[23  8  1 20 27  1  2 15 21 20]
so>he>wound>up>with>a>dozen~
[19 15 27  8  5 27 23 15 21 14]
must>or>should>the>federal>government>help?
[13 21 19 20 27 15 18 27 19  8]
what>about>transfers?
[23  8  1 20 27  1  2 15 21 20]
maris>sleeps>on>a>green>studio>couch>in>the>living>room~
[13  1 18  9 19 27 19 12  5  5]
the>figure>halted,>and>watson>gasped~
[20  8  5 27  6  9  7 21 18  5]
so>he>wound>up>with>a>dozen~
[19 15 27  8  5 27 23 15 21 14]
must>or>should>the>federal>government>help?
[13 21 19 20 27 15 18 27 19  8]
Loading nwb/held_out_calib/T5_2023.05.31_held_out_calib.nwb
Loading nwb/held_out_eval/T5_2023.05.31_held_out_eval.nwb
Loading nwb/held_out_oracle/T5_2023.05.31_hel