# Copy Task dataset survey


This notebook explores the Copy Task neural recordings that power the Brain-to-Text model. It performs the following steps:

1. Load block-level metadata from `data/t15_copyTaskData_description.csv` to annotate every trial with its corpus, speaking strategy, and data split.
2. Iterate over the `.hdf5` session files with `model_training.evaluate_model_helpers.load_h5py_file` to extract neural features, labels, and phoneme sequences.
3. Summarize the dataset via channel-level statistics, sentence length histograms, and baseline phoneme error rate (PER) calculations that can guide future modeling work.


In [None]:

from pathlib import Path
from collections import defaultdict
from typing import Dict, List, Tuple, Optional

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

from model_training.evaluate_model_helpers import load_h5py_file, LOGIT_TO_PHONEME
from nejm_b2txt_utils.general_utils import calculate_error_rate

sns.set_theme(style="whitegrid", context="talk")


In [None]:

# All inputs are resolved relative to the directory the notebook is run from.
REPO_ROOT = Path.cwd()
DATA_DIR = REPO_ROOT.joinpath("data")
METADATA_PATH = DATA_DIR.joinpath("t15_copyTaskData_description.csv")
HDF5_ROOT = DATA_DIR.joinpath("hdf5_data_final")

# Subsampling knobs for quick experiments; None means "process everything".
MAX_FILES: Optional[int] = None
MAX_TRIALS_PER_FILE: Optional[int] = None

# Echo the resolved locations so a wrong working directory is easy to spot.
print(f"Using data directory: {DATA_DIR}")
print(f"Expecting HDF5 files under: {HDF5_ROOT}")


## Load Copy Task metadata

In [None]:

# Read the block-level description table; backfill the speaking-strategy
# column with a placeholder when the CSV does not provide it.
raw_metadata_df = pd.read_csv(METADATA_PATH)
if 'Speaking strategy' not in raw_metadata_df.columns:
    raw_metadata_df['Speaking strategy'] = 'Unknown'

# Normalise the key columns in place (the raw frame is also handed to
# `load_h5py_file` later): integer block numbers and YYYY-MM-DD date strings.
block_numbers = raw_metadata_df['Block number'].astype(int)
iso_dates = pd.to_datetime(raw_metadata_df['Date']).dt.strftime('%Y-%m-%d')
raw_metadata_df['Block number'] = block_numbers
raw_metadata_df['Date'] = iso_dates

# Snake-case view of the table, with a lowercase `date` copy of `Date`.
snake_case_names = {
    'Post-implant day': 'post_implant_day',
    'Block number': 'block_number',
    'Number of sentences': 'n_sentences',
    'Corpus': 'corpus',
    'Split': 'split',
    'Speaking strategy': 'speaking_strategy',
}
metadata_df = (
    raw_metadata_df
    .rename(columns=snake_case_names)
    .assign(date=lambda frame: frame['Date'])
)

# (date, block_number) -> dict of the remaining columns for that block.
blocks_by_key = metadata_df.set_index(['date', 'block_number'])
metadata_lookup = blocks_by_key.to_dict('index')

print(f"Metadata rows: {len(metadata_df):,}")
metadata_df.head()


In [None]:

# Total sentence count for each (split, corpus) pair, reshaped to one row per
# split and one column per corpus; absent combinations are shown as 0.
sentences_per_group = metadata_df.groupby(['split', 'corpus'])['n_sentences'].sum()
metadata_pivot = sentences_per_group.unstack(fill_value=0)
metadata_pivot


## Locate neural session files

In [None]:

# Collect every `.hdf5` file that sits one level below a `t15.*` directory,
# in sorted order, optionally keeping only the first MAX_FILES of them.
all_session_files = sorted(HDF5_ROOT.glob('t15.*/*.hdf5'))
hdf5_files = all_session_files[:MAX_FILES] if MAX_FILES else all_session_files

print(f"Discovered {len(hdf5_files)} HDF5 files")
if len(hdf5_files) == 0:
    # Nothing to analyse: point the user at the expected download and layout.
    raise FileNotFoundError(
        "No .hdf5 files were found. Download `t15_copyTask_neuralData.zip`, "
        "unzip it, and place the `hdf5_data_final` directory inside `data/`."
    )

hdf5_files[:5]


## Helper functions for decoding and aggregation

In [None]:

# Eight contiguous, half-open 64-channel index ranges covering 0..512.
# The four array labels are listed once per feature suffix: first all "(TC)"
# groups, then all "(SBP)" groups, each in the same array order.
CHANNEL_GROUPS: List[Tuple[str, Tuple[int, int]]] = [
    (
        f"{array_label} ({feature_label})",
        (64 * (4 * feature_pos + array_pos), 64 * (4 * feature_pos + array_pos + 1)),
    )
    for feature_pos, feature_label in enumerate(("TC", "SBP"))
    for array_pos, array_label in enumerate(("ventral 6v", "area 4", "55b", "dorsal 6v"))
]

def session_to_date(session_name: str) -> str:
    """Turn a dot-separated session name into a dash-separated date string.

    When the name has at least four dot-separated fields, fields 2-4 are
    joined with dashes (e.g. ``"t15.2023.08.11"`` -> ``"2023-08-11"``).
    Shorter names fall back to their last field unchanged.
    """
    fields = session_name.split('.')
    if len(fields) < 4:
        return fields[-1]
    return '-'.join(fields[1:4])


def decode_sentence_label(value) -> Optional[str]:
    """Convert a stored sentence label into a plain Python string.

    Args:
        value: ``None``, a ``bytes``/``str`` scalar, or an array-like of
            string or byte-string pieces.

    Returns:
        ``None`` when ``value`` is ``None``. ``bytes`` are decoded as UTF-8 and
        ``str`` is returned unchanged. Arrays with a unicode or byte-string
        dtype are flattened, concatenated, and stripped of surrounding
        whitespace. Anything else falls back to ``str(value)``.
    """
    if value is None:
        return None
    if isinstance(value, bytes):
        return value.decode('utf-8')
    if isinstance(value, str):
        return value
    arr = np.array(value).flatten()
    if arr.dtype.kind in {'U', 'S'}:
        # Byte-string arrays yield `bytes` items, which `str.join` rejects;
        # decode them first so both dtypes take the same path.
        pieces = [
            item.decode('utf-8') if isinstance(item, bytes) else item
            for item in arr.tolist()
        ]
        return ''.join(pieces).strip()
    return str(value)


def decode_transcription(encoded) -> Optional[str]:
    """Decode a zero-terminated array of character codes into text.

    The input is flattened and read code by code; decoding stops at the first
    code whose integer value is 0 (or at the end of the data). ``None`` is
    passed through unchanged.
    """
    if encoded is None:
        return None
    code_points = (int(raw) for raw in np.array(encoded).flatten())
    # Two-argument `iter` keeps pulling codes until it sees the sentinel 0;
    # `next(..., 0)` makes running out of data look like a terminator too.
    until_terminator = iter(lambda: next(code_points, 0), 0)
    return ''.join(chr(point) for point in until_terminator)


def decode_phoneme_sequence(seq_ids, seq_len) -> List[str]:
    """Map the first ``seq_len`` class ids of ``seq_ids`` to phoneme labels.

    Each id is converted to ``int`` and used to index ``LOGIT_TO_PHONEME``.
    An empty list is returned when either argument is ``None``.
    """
    if seq_ids is not None and seq_len is not None:
        # Convert every id up front, then look the labels up.
        class_ids = list(map(int, seq_ids[:seq_len]))
        return [LOGIT_TO_PHONEME[class_id] for class_id in class_ids]
    return []


def assign_channel_group(channel_idx: int) -> str:
    """Return the name of the channel group whose range contains ``channel_idx``.

    Groups are taken from ``CHANNEL_GROUPS`` in order and each range is
    half-open (``start <= channel_idx < end``); the first match wins.
    Indices outside every range are reported as ``'unknown'``.
    """
    matching_names = (
        name
        for name, (first, past_last) in CHANNEL_GROUPS
        if first <= channel_idx < past_last
    )
    return next(matching_names, 'unknown')


def update_channel_stats(stats: Dict[str, Optional[np.ndarray]], features: np.ndarray) -> None:
    """Fold one trial's feature matrix into the running per-channel statistics.

    Args:
        stats: Accumulator dict with keys ``'sum'``, ``'sum_sq'``, ``'min'``,
            ``'max'`` and ``'count'``. The array entries start as ``None`` and
            are allocated on the first non-empty trial; the dict is updated
            in place.
        features: 2D array-like whose second axis indexes channels. Rows are
            the samples that get reduced.

    Raises:
        ValueError: If ``features`` is not 2D, or if its channel count differs
            from the one the accumulators were allocated with.
    """
    feats = np.asarray(features, dtype=np.float64)
    if feats.ndim != 2:
        raise ValueError(f"Expected 2D features, received shape {feats.shape}")

    n_rows, n_channels = feats.shape
    if n_rows == 0:
        # A trial without samples contributes nothing, and min/max reductions
        # over a zero-length axis would raise. Leave the accumulators as-is.
        return

    if stats['sum'] is None:
        stats['sum'] = np.zeros(n_channels, dtype=np.float64)
        stats['sum_sq'] = np.zeros(n_channels, dtype=np.float64)
        stats['min'] = np.full(n_channels, np.inf, dtype=np.float64)
        stats['max'] = np.full(n_channels, -np.inf, dtype=np.float64)
        stats['count'] = 0
    elif stats['sum'].shape[0] != n_channels:
        # Fail with a readable message instead of a raw broadcasting error.
        raise ValueError(
            f"Expected {stats['sum'].shape[0]} channels, received shape {feats.shape}"
        )

    stats['sum'] += feats.sum(axis=0)
    stats['sum_sq'] += np.square(feats).sum(axis=0)
    stats['min'] = np.minimum(stats['min'], feats.min(axis=0))
    stats['max'] = np.maximum(stats['max'], feats.max(axis=0))
    stats['count'] += n_rows


## Load sessions with `load_h5py_file`

In [None]:

# Accumulators filled by the loop below: running per-channel statistics, one
# metadata record per trial, and one phoneme record per labeled trial.
channel_stats = {'sum': None, 'sum_sq': None, 'min': None, 'max': None, 'count': 0}
trial_records: List[Dict[str, object]] = []
phoneme_records: List[Dict[str, object]] = []

# Most recently seen phoneme sequence for each (corpus, sentence label) pair.
sentence_lookup: Dict[Tuple[str, str], List[str]] = {}

for file_idx, file_path in enumerate(hdf5_files):
    session_payload = load_h5py_file(str(file_path), raw_metadata_df)
    n_trials = len(session_payload['session'])
    if MAX_TRIALS_PER_FILE is not None:
        n_trials = min(n_trials, MAX_TRIALS_PER_FILE)

    for trial_idx in range(n_trials):
        session_name = session_payload['session'][trial_idx]
        block_num = int(session_payload['block_num'][trial_idx])
        trial_num = int(session_payload['trial_num'][trial_idx])
        n_time_steps = int(session_payload['n_time_steps'][trial_idx])

        # Prefer the CSV's block-level annotations; fall back to the corpus
        # stored in the file (and 'Unknown' for the rest) when the
        # (date, block) pair is not listed in the metadata table.
        session_date = session_to_date(session_name)
        meta = metadata_lookup.get((session_date, block_num), {})
        corpus_name = meta.get('corpus', session_payload['corpus'][trial_idx])
        split_name = meta.get('split', 'Unknown')
        speaking_strategy = meta.get('speaking_strategy', 'Unknown')

        sentence_label = decode_sentence_label(session_payload['sentence_label'][trial_idx])
        transcription = decode_transcription(session_payload['transcriptions'][trial_idx])
        seq_len = session_payload['seq_len'][trial_idx]

        record = {
            'session': session_name,
            'session_date': session_date,
            'block_num': block_num,
            'trial_num': trial_num,
            'split': split_name,
            'corpus': corpus_name,
            'speaking_strategy': speaking_strategy,
            'n_time_steps': n_time_steps,
            'seq_len': int(seq_len) if seq_len is not None else None,
            'sentence_label': sentence_label,
            'transcription': transcription,
            # Always emit both length fields (NaN when the label is missing or
            # empty) so the columns exist in `trials_df` even if no loaded
            # trial carries a label; the sentence-length cell indexes them
            # unconditionally.
            'sentence_word_len': len(sentence_label.split()) if sentence_label else np.nan,
            'sentence_char_len': len(sentence_label) if sentence_label else np.nan,
        }

        trial_records.append(record)

        neural_features = session_payload['neural_features'][trial_idx]
        update_channel_stats(channel_stats, neural_features)

        # Phoneme targets are optional; only trials that provide both the ids
        # and their length contribute to the PER baselines.
        seq_ids = session_payload['seq_class_ids'][trial_idx]
        if seq_ids is not None and seq_len is not None:
            phoneme_seq = decode_phoneme_sequence(seq_ids, seq_len)
            phoneme_records.append({
                'session': session_name,
                'block_num': block_num,
                'trial_num': trial_num,
                'split': split_name,
                'corpus': corpus_name,
                'speaking_strategy': speaking_strategy,
                'sentence_label': sentence_label or f"trial_{trial_idx}",
                'phonemes': phoneme_seq,
            })
            sentence_lookup[(corpus_name, sentence_label)] = phoneme_seq

print(f"Loaded {len(trial_records):,} trials across {len(hdf5_files)} files")
trials_df = pd.DataFrame(trial_records)
trials_df.head()


## Channel statistics

In [None]:

# The accumulator arrays are only allocated once a trial has been folded in.
if channel_stats['sum'] is None:
    raise RuntimeError('Channel statistics are unavailable. Make sure at least one trial was loaded.')

# Per-channel moments from the running sums: Var[x] = E[x^2] - E[x]^2.
count = channel_stats['count']
means = channel_stats['sum'] / count
mean_of_squares = channel_stats['sum_sq'] / count
variances = mean_of_squares - np.square(means)
# Rounding can leave a tiny negative variance; floor it at zero before sqrt.
stds = np.sqrt(np.maximum(variances, 0))

channel_indices = np.arange(len(means))
channel_df = pd.DataFrame({
    'channel': channel_indices,
    'mean': means,
    'std': stds,
    'min': channel_stats['min'],
    'max': channel_stats['max'],
    'array': [assign_channel_group(idx) for idx in channel_indices],
})
channel_df.head()


In [None]:

# Two stacked panels sharing the channel axis: per-channel mean on top,
# per-channel standard deviation below, both coloured by channel group.
fig, axes = plt.subplots(2, 1, figsize=(12, 10), sharex=True)
panel_specs = [
    ('mean', 'Channel mean activity (20 ms bins)', 'Mean feature value'),
    ('std', 'Channel standard deviation', 'Std. dev.'),
]
for panel_ax, (stat_column, panel_title, y_label) in zip(axes, panel_specs):
    sns.lineplot(data=channel_df, x='channel', y=stat_column, hue='array', ax=panel_ax)
    panel_ax.set_title(panel_title)
    panel_ax.set_xlabel('Channel index')
    panel_ax.set_ylabel(y_label)
    # Park the legend outside the plotting area, to the right of the panel.
    panel_ax.legend(title='Array', bbox_to_anchor=(1.02, 1), loc='upper left')

plt.tight_layout()


## Sentence length distributions

In [None]:

if trials_df.empty:
    raise RuntimeError('Trial metadata is empty. Verify that the HDF5 files contain labeled trials.')

# Length column -> axis/title text. Trials without a length are counted as 0.
length_labels = {
    'sentence_word_len': 'Words per sentence',
    'sentence_char_len': 'Characters per sentence',
}
length_cols = list(length_labels)
for col in length_cols:
    trials_df[col] = trials_df[col].fillna(0)

# Side-by-side stacked histograms (one per length measure), split by data split.
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
for ax, (col, label) in zip(axes, length_labels.items()):
    sns.histplot(data=trials_df, x=col, hue='split', multiple='stack', bins=30, ax=ax)
    ax.set(title=label, xlabel=label, ylabel='Trials')
plt.tight_layout()

# One word-length facet per corpus, each with its own y-axis scale.
sns.displot(
    data=trials_df,
    x='sentence_word_len',
    hue='split',
    col='corpus',
    col_wrap=3,
    bins=25,
    facet_kws={'sharey': False},
)
# Leave headroom above the facets for the figure-level title.
plt.subplots_adjust(top=0.85)
plt.suptitle('Sentence word-length distributions by corpus')


## PER baselines

In [None]:

def aggregate_per(targets: List[List[str]], predictions: List[List[str]]):
    """Compute one corpus-level error rate over paired phoneme sequences.

    The value returned by ``calculate_error_rate`` for each pair is summed and
    divided by the summed target lengths, so longer sentences weigh more than
    shorter ones.

    Args:
        targets: Reference phoneme sequences.
        predictions: Predicted phoneme sequences, paired with ``targets`` by
            position.

    Returns:
        Total errors divided by total target length, or ``np.nan`` when there
        is no non-empty target to score.
    """
    total_err = 0
    total_len = 0
    for true_seq, pred_seq in zip(targets, predictions):
        if not true_seq:
            # Empty references add nothing to the denominator; skip them.
            continue
        # NOTE(review): `calculate_error_rate` may return a narrow NumPy
        # integer scalar (confirm in nejm_b2txt_utils). Adding such scalars
        # keeps their fixed width under NumPy 2 promotion rules and can wrap,
        # so accumulate in an unbounded Python int instead.
        total_err += int(calculate_error_rate(true_seq, pred_seq))
        total_len += len(true_seq)
    return total_err / total_len if total_len else np.nan

# Build three naive PER baselines from the labeled trials collected earlier.
if not phoneme_records:
    raise RuntimeError('Phoneme annotations were not found. Ensure you are using training/validation files with labels.')

# Parallel lists: reference phoneme sequence and corpus name for every labeled trial.
phoneme_targets = [rec['phonemes'] for rec in phoneme_records]
corpus_labels = [rec['corpus'] for rec in phoneme_records]

# Baseline 1: predict the token 'SIL' at every position, matching each target's length.
# NOTE(review): assumes 'SIL' is a label that can appear in the decoded targets —
# confirm against LOGIT_TO_PHONEME; otherwise every position counts as an error.
silence_predictions = [['SIL'] * len(seq) for seq in phoneme_targets]

# Baseline 2: for each corpus, always predict its most frequent sentence.
# First count how often each (corpus, sentence_label) pair occurs...
corpus_sentence_counts = (
    pd.DataFrame(phoneme_records)
    .groupby(['corpus', 'sentence_label'])
    .size()
    .reset_index(name='count')
)
# ...then keep, per corpus, the row with the highest count.
corpus_mode = corpus_sentence_counts.loc[
    corpus_sentence_counts.groupby('corpus')['count'].idxmax()
][['corpus', 'sentence_label']]
mode_lookup = dict(zip(corpus_mode['corpus'], corpus_mode['sentence_label']))

# (corpus, sentence_label) -> phoneme sequence; when a pair occurs more than
# once, the last record in `phoneme_records` wins.
corpus_sentence_map = {
    (rec['corpus'], rec['sentence_label']): rec['phonemes']
    for rec in phoneme_records
}
mode_predictions = [
    corpus_sentence_map[(corpus, mode_lookup[corpus])]
    for corpus in corpus_labels
]

# Baseline 3: predict a uniformly drawn sentence from the same corpus.
# The generator is seeded so the draw sequence is reproducible; a draw may
# select the target trial itself.
rng = np.random.default_rng(0)
corpus_sequences: Dict[str, List[List[str]]] = defaultdict(list)
for rec in phoneme_records:
    corpus_sequences[rec['corpus']].append(rec['phonemes'])
random_predictions = [
    corpus_sequences[corpus][rng.integers(len(corpus_sequences[corpus]))]
    for corpus in corpus_labels
]

# Score every baseline with the same aggregate PER and list the lowest first.
baseline_results = pd.DataFrame([
    {'baseline': 'All SIL', 'per': aggregate_per(phoneme_targets, silence_predictions)},
    {'baseline': 'Corpus mode sentence', 'per': aggregate_per(phoneme_targets, mode_predictions)},
    {'baseline': 'Random sentence (same corpus)', 'per': aggregate_per(phoneme_targets, random_predictions)},
]).sort_values('per')

baseline_results


In [None]:

# Bar chart of the aggregate PER for each baseline, in the table's row order.
fig, ax = plt.subplots(figsize=(8, 4))
sns.barplot(data=baseline_results, x='baseline', y='per', ax=ax)
ax.set(
    ylabel='Aggregate PER',
    xlabel='Baseline strategy',
    title='Phoneme error rate baselines',
)
# Tilt the baseline names slightly so the longer labels do not collide.
plt.xticks(rotation=20)
plt.tight_layout()
