# Environment Setup

In [1]:
!mkdir brain-to-text/
!git clone https://github.com/fwillett/speechBCI.git
!pip install -q condacolab

# for colab
!pip install -q condacolab
import condacolab
condacolab.install()

!conda env create -f /content/speechBCI_2024/environment.yml

In [None]:
%%bash
# NOTE(review): %%bash runs its script in a separate subprocess, so `exec bash`
# here does not change the shell used by later cells; this cell appears to be a no-op.
exec bash

In [None]:
%%shell
# Activate the conda env created in the setup cell.
# NOTE(review): %%shell runs in its own subprocess, so this activation likely does not
# carry over to later `!` commands or to the Python kernel — confirm.
eval "$(conda shell.bash hook)"
conda activate speech-BCI

In [None]:
# Editable install of the NeuralDecoder package from the speechBCI clone (path is relative
# to the current working directory).
# NOTE(review): the training code below imports `neural_decoder` from the neural_seq_decoder
# clone instead — confirm this install is still needed.
!pip install -e speechBCI/NeuralDecoder
# g2p_en provides G2p, which converts transcriptions to phonemes in getDataset.
!pip install g2p-en

In [None]:
import nltk
# NLTK POS-tagger resource, presumably required by g2p_en's G2p (used in getDataset) — confirm.
nltk.download('averaged_perceptron_tagger_eng')
import os
import pandas as pd
import numpy as np
import re
from g2p_en import G2p
# NOTE(review): scipy.io.loadmat/savemat are used later; on older SciPy releases a bare
# `import scipy` does not load the io submodule — `import scipy.io` is the safer form.
import scipy
import pickle


# Neural Decoder - Pretraining

## Prepare Pretraining Data

In [None]:
# Input: competition .mat files, organised in train/, test/ and competitionHoldOut/ sub-folders.
dataDir = "/content/drive/MyDrive/coding for fun/data/competitionData"
# Output: path of the preprocessed-dataset pickle (opened as a file, not a folder).
saveDir = "/content/drive/MyDrive/TreeHacks/ptDecoder_ctc"

In [None]:

# Session IDs of the form t12.YYYY.MM.DD; each names a <session>.mat file under
# dataDir/train, dataDir/test and (for some sessions) dataDir/competitionHoldOut.
# Sorting the zero-padded dates lexicographically puts them in chronological order.
sessionNames = sorted([
    't12.2022.04.28', 't12.2022.05.26', 't12.2022.06.21', 't12.2022.07.21', 't12.2022.08.13',
    't12.2022.05.05', 't12.2022.06.02', 't12.2022.06.23', 't12.2022.07.27', 't12.2022.08.18',
    't12.2022.05.17', 't12.2022.06.07', 't12.2022.06.28', 't12.2022.07.29', 't12.2022.08.23',
    't12.2022.05.19', 't12.2022.06.14', 't12.2022.07.05', 't12.2022.08.02', 't12.2022.08.25',
    't12.2022.05.24', 't12.2022.06.16', 't12.2022.07.14', 't12.2022.08.11',
])


# Shared grapheme-to-phoneme converter; getDataset calls it on every cleaned transcription.
g2p = G2p()
# The 39 ARPAbet phonemes (stress digits removed), in the fixed order that defines
# their class IDs.
PHONE_DEF = (
    'AA AE AH AO AW AY B CH D DH EH ER EY F G HH IH IY JH K '
    'L M N NG OW OY P R S SH T TH UH UW V W Y Z ZH'
).split()
# Same inventory plus 'SIL', the word-boundary / end-of-sentence marker.
PHONE_DEF_SIL = [*PHONE_DEF, 'SIL']


def phoneToId(p):
    """Return the 0-based position of phoneme ``p`` in PHONE_DEF_SIL.

    Raises ValueError for a symbol outside the inventory.
    """
    position = PHONE_DEF_SIL.index(p)
    return position

import scipy.io  # also binds `scipy`; guarantees the io submodule is loaded on older SciPy


def loadFeaturesAndNormalize(sessionPath, type='train'):
    """Load one session MAT file and z-score its features block by block.

    Args:
        sessionPath: path of the MAT file (read with scipy.io.loadmat).
        type: 'train' builds each trial's features from the first 128 columns of
            'tx1' and of 'spikePow' (256 columns in total); any other value uses
            the file's 'data' matrix for every trial.

    Returns:
        dict with per-trial lists 'inputFeatures' (normalized feature arrays),
        'transcriptions' (stripped 'sentenceText' entries) and 'frameLens'
        (number of rows / frames per trial).

    Trials sharing a 'blockIdx' value are normalized with the mean and standard
    deviation of all their rows pooled together.
    """
    dat = scipy.io.loadmat(sessionPath)

    input_features = []
    transcriptions = []
    frame_lens = []
    n_trials = dat['sentenceText'].shape[0]

    # collect area 6v tx1 and spikePow features
    for i in range(n_trials):
        if type == 'train':
            # first 128 columns = area 6v only
            features = np.concatenate(
                [dat['tx1'][0, i][:, 0:128], dat['spikePow'][0, i][:, 0:128]], axis=1
            )
        else:
            # NOTE(review): the same 'data' matrix is used for every trial.
            features = dat['data']

        input_features.append(features)
        transcriptions.append(dat['sentenceText'][i].strip())
        frame_lens.append(features.shape[0])

    # Block-wise feature normalization. np.ravel (not np.squeeze) keeps a
    # single-trial blockIdx 1-D; squeeze would turn it into a 0-d array.
    blockNums = np.ravel(dat['blockIdx'])
    for block in np.unique(blockNums):
        trialIdx = np.flatnonzero(blockNums == block)
        # Pool exactly this block's trials; a first..last slice would also pick up
        # trials of other blocks whenever blocks are interleaved.
        feats = np.concatenate([input_features[i] for i in trialIdx], axis=0)
        feats_mean = np.mean(feats, axis=0, keepdims=True)
        feats_std = np.std(feats, axis=0, keepdims=True)
        for i in trialIdx:
            # 1e-8 guards against constant (zero-variance) columns.
            input_features[i] = (input_features[i] - feats_mean) / (feats_std + 1e-8)

    return {
        'inputFeatures': input_features,
        'transcriptions': transcriptions,
        'frameLens': frame_lens,
    }


def getDataset(fileName, type='train'):
    """Load one session and build padded CTC phoneme targets for every trial.

    Args:
        fileName: path of the session MAT file (passed to loadFeaturesAndNormalize).
        type: 'train' reads tx1/spikePow features; any other value reads the
            file's 'data' variable (see loadFeaturesAndNormalize).

    Returns:
        dict with per-trial lists 'sentenceDat' (normalized features),
        'transcriptions' (raw sentences) and 'phonemes' (int32 arrays of length
        500 holding phoneme IDs shifted by +1, zero-padded), plus numpy arrays
        'timeSeriesLens', 'phoneLens' and 'phonePerTime' (phoneLens / timeSeriesLens).

    Raises:
        ValueError: if a transcription yields more than 500 phoneme tokens.
    """
    session_data = loadFeaturesAndNormalize(fileName, type)

    # Fixed padded length of every phoneme-target array.
    maxSeqLen = 500
    # Emit 'SIL' at each word boundary and once at the end of the sentence.
    addInterWordSymbol = True

    allDat = []
    trueSentences = []
    seqElements = []
    phoneLens = []

    for features, transcription in zip(session_data['inputFeatures'],
                                       session_data['transcriptions']):
        allDat.append(features)
        trueSentences.append(transcription)

        # Keep letters, hyphens, spaces and apostrophes; drop '--'; lowercase.
        thisTranscription = str(transcription).strip()
        thisTranscription = re.sub(r'[^a-zA-Z\- \']', '', thisTranscription)
        thisTranscription = thisTranscription.replace('--', '').lower()

        phonemes = []
        for p in g2p(thisTranscription):
            if addInterWordSymbol and p == ' ':
                phonemes.append('SIL')
            p = re.sub(r'[0-9]', '', p)  # Remove stress
            if re.match(r'[A-Z]+', p):  # Only keep phonemes
                phonemes.append(p)

        # add one SIL symbol at the end so there's one at the end of each word
        if addInterWordSymbol:
            phonemes.append('SIL')

        seqLen = len(phonemes)
        if seqLen > maxSeqLen:
            raise ValueError(
                f"{fileName}: transcription {thisTranscription!r} has {seqLen} phonemes, "
                f"more than maxSeqLen={maxSeqLen}"
            )
        # IDs are shifted by +1 so that 0 stays free for padding / the CTC blank.
        seqClassIDs = np.zeros([maxSeqLen], dtype=np.int32)
        seqClassIDs[0:seqLen] = [phoneToId(p) + 1 for p in phonemes]
        seqElements.append(seqClassIDs)
        # Record the true length directly; searching for the first 0 failed when a
        # sequence filled all maxSeqLen slots.
        phoneLens.append(seqLen)

    timeSeriesLens = [feats.shape[0] for feats in allDat]

    newDataset = {
        'sentenceDat': allDat,
        'transcriptions': trueSentences,
        'phonemes': seqElements,
        'timeSeriesLens': np.array(timeSeriesLens),
        'phoneLens': np.array(phoneLens),
    }
    newDataset['phonePerTime'] = (
        newDataset['phoneLens'].astype(np.float32)
        / newDataset['timeSeriesLens'].astype(np.float32)
    )
    return newDataset

trainDatasets, testDatasets, competitionDatasets = [], [], []

# Per session: the train and test splits always exist; the competition hold-out file
# is only present for some days.
for dayIdx, sessionName in enumerate(sessionNames):
    print(dayIdx)
    trainDataset = getDataset(f'{dataDir}/train/{sessionName}.mat')
    testDataset = getDataset(f'{dataDir}/test/{sessionName}.mat')

    trainDatasets.append(trainDataset)
    testDatasets.append(testDataset)

    holdOutPath = f'{dataDir}/competitionHoldOut/{sessionName}.mat'
    if os.path.exists(holdOutPath):
        dataset = getDataset(holdOutPath)
        competitionDatasets.append(dataset)

In [None]:

# Indices into sessionNames of the days that have a competition hold-out file.
competitionDays = [
    dayIdx for dayIdx in range(len(sessionNames))
    if os.path.exists(dataDir + '/competitionHoldOut/' + sessionNames[dayIdx] + '.mat')
]
print(competitionDays)

# One list entry per session (competition: only the days listed above).
allDatasets = {
    'train': trainDatasets,
    'test': testDatasets,
    'competition': competitionDatasets,
}

# saveDir is the pickle file itself; open() does not create missing parent folders.
saveParent = os.path.dirname(saveDir)
if saveParent:
    os.makedirs(saveParent, exist_ok=True)
with open(saveDir, 'wb') as handle:
    pickle.dump(allDatasets, handle)

## Build & Train Model

In [None]:
# Passed as args['outputDir'] to the pretraining run below.
modelSaveDir = "/content/drive/MyDrive/TreeHacks/NeuralDecoderModel"

In [None]:
os.chdir("/content/")
!git clone https://github.com/cffan/neural_seq_decoder.git
os.chdir("/content/neural_seq_decoder/src/")
!pip install -e ..

In [None]:

modelName = 'speechBaseline4'  # NOTE(review): not referenced elsewhere in this notebook

# Hyper-parameters for the pretraining run, passed to trainModel below.
args = {
    'outputDir': modelSaveDir,
    'datasetPath': saveDir,  # the pickle written by the preprocessing cell
    'seqLen': 150,
    'maxTimeSeriesLen': 1200,
    'batchSize': 64,
    'lrStart': 0.02,
    'lrEnd': 0.02,
    'nUnits': 1024,
    'nBatch': 10000,
    'nLayers': 5,
    'seed': 0,
    'nClasses': 40,  # equals len(PHONE_DEF_SIL); presumably the model adds the CTC blank — confirm
    'nInputFeatures': 256,  # 128 tx1 + 128 spikePow columns from loadFeaturesAndNormalize
    'dropout': 0.4,
    'whiteNoiseSD': 0.8,
    'constantOffsetSD': 0.2,
    'gaussianSmoothWidth': 2.0,
    'strideLen': 4,
    'kernelLen': 32,
    'bidirectional': True,
    'l2_decay': 1e-5,
}

from neural_decoder.neural_decoder_trainer import trainModel, getDatasetLoaders

# NOTE(review): fine_tune_model later calls model.to(...) on this value; confirm that
# trainModel returns the trained module rather than only saving weights to outputDir.
model = trainModel(args)

# Neural Decoder - Fine-Tuning

## Prepare Fine-Tuning Data

In [None]:
# CSV recording used for fine-tuning. The same path is later handed to scipy.io.savemat,
# getDataset, and fine_tune_model (as session_path).
fine_tuning_dataset_path = '/content/fine-tuning-data.csv'

In [None]:
# Raw fine-tuning recording as a DataFrame; below it is transposed, widened with
# inflate_dims and written out with scipy.io.savemat.
bci_data = pd.read_csv(fine_tuning_dataset_path)

def inflate_dims(arr, group_size=16):
    """Widen a 2-D array by repeating every column ``group_size`` times in place.

    Column 0 is emitted ``group_size`` times, then column 1, and so on, so the first
    half of the input columns maps onto the first half of the output and the second
    half onto the second half (e.g. 16 columns x 16 -> 256 columns).

    Args:
        arr: 2-D array-like of shape (rows, columns).
        group_size: number of consecutive copies of each column.

    Returns:
        numpy array of shape (rows, columns * group_size), same dtype as ``arr``.

    Raises:
        ValueError: if ``arr`` is not 2-D.
    """
    arr = np.asarray(arr)
    if arr.ndim != 2:
        raise ValueError(f"inflate_dims expects a 2-D array, got shape {arr.shape}")
    # Repeating each half separately and concatenating is identical to repeating
    # the whole array along axis 1, so no split is needed.
    return np.repeat(arr, group_size, axis=1)

# MAT payload in the layout loadFeaturesAndNormalize reads when type != 'train':
# 'data' (one feature matrix, reused for every trial), 'sentenceText' (one
# transcription per trial) and 'blockIdx' (block label per trial).
# NOTE(review): 'hello there' is a hard-coded label, presumably a placeholder for
# what was spoken during the CSV recording — confirm.
# NOTE(review): np.arange(0) is empty, so blockIdx has no entries and the block-wise
# z-scoring in loadFeaturesAndNormalize is skipped for this data, whereas the
# pretraining data is normalized block-wise by the same loader. Confirm whether
# normalization is intended here.
dataset = {
        'data': inflate_dims(bci_data.T.values),
        'sentenceText': np.array(['hello there',]),
        'blockIdx': np.arange(0).repeat(20).reshape(-1,1)
    }

# NOTE(review): this writes MAT data to the same path the CSV was read from. Depending
# on the SciPy version it either overwrites the CSV (so re-running the read_csv cell
# fails) or, via appendmat, writes '<path>.mat' while getDataset below still opens the
# CSV. A dedicated '.mat' path, also passed to fine_tune_model, would avoid both.
scipy.io.savemat(fine_tuning_dataset_path, dataset)

# type='val' selects the loader branch that reads the 'data' variable.
fine_tune_dataset=getDataset(fine_tuning_dataset_path,type='val')

# temporary fix: the single fine-tuning dataset is used as both the 'train' and the
# 'test' split (same layout as the pretraining pickle), so any test metric is on
# training data.
d = {"train":[fine_tune_dataset],'test':[fine_tune_dataset]}

# fine_tune_model builds its training loader from this path.
with open('/content/local_tmp', 'wb') as handle:
    pickle.dump(d, handle)

## Fine-Tune the Model

In [None]:
import torch
from torch.utils.data import Dataset


class SpeechDataset(Dataset):
    """Flatten per-day trial dicts into one indexable torch Dataset.

    Args:
        data: sequence with one dict per day; each dict provides parallel per-trial
            sequences "sentenceDat" (feature arrays), "phonemes" (padded phoneme-ID
            arrays) and "phoneLens".
        transform: optional callable applied to the float32 feature tensor.

    Each item is (features, phonemes, n_time_bins, n_phonemes, day_index), with the
    middle three as int32 tensors and day_index as int64.

    NOTE(review): not instantiated anywhere in this notebook; fine_tune_model uses
    getDatasetLoaders instead.
    """

    def __init__(self, data, transform=None):
        self.data = data
        self.transform = transform
        self.n_days = len(data)
        self.n_trials = sum(len(day_data["sentenceDat"]) for day_data in data)

        self.neural_feats = []
        self.phone_seqs = []
        self.neural_time_bins = []
        self.phone_seq_lens = []
        self.days = []
        for day_idx, day_data in enumerate(data):
            for trial_idx, trial_feats in enumerate(day_data["sentenceDat"]):
                self.neural_feats.append(trial_feats)
                self.phone_seqs.append(day_data["phonemes"][trial_idx])
                self.neural_time_bins.append(trial_feats.shape[0])
                self.phone_seq_lens.append(day_data["phoneLens"][trial_idx])
                self.days.append(day_idx)

    def __len__(self):
        return self.n_trials

    def __getitem__(self, idx):
        features = torch.tensor(self.neural_feats[idx], dtype=torch.float32)
        if self.transform:
            features = self.transform(features)

        phone_seq = torch.tensor(self.phone_seqs[idx], dtype=torch.int32)
        time_bins = torch.tensor(self.neural_time_bins[idx], dtype=torch.int32)
        phone_len = torch.tensor(self.phone_seq_lens[idx], dtype=torch.int32)
        day = torch.tensor(self.days[idx], dtype=torch.int64)
        return features, phone_seq, time_bins, phone_len, day


In [None]:

def prepare_fine_tune_data(session_path: str):
    """Load a fine-tuning session with getDataset using the non-'train' branch.

    With type 'val', loadFeaturesAndNormalize takes features from the file's
    'data' variable rather than from 'tx1'/'spikePow'.
    """
    return getDataset(session_path, 'val')


from typing import Dict, Optional


def fine_tune_model(
    model: torch.nn.Module,
    session_path: str,
    config: Optional[Dict] = None,
    device: str = "cuda"
):
    """Fine-tune a pretrained CTC phoneme decoder and keep its best epoch.

    Args:
        model: pretrained decoder, called as ``model(features, day_indices)``; it must
            expose ``kernelLen`` and ``strideLen`` (used for the CTC input lengths).
        session_path: MAT file of the fine-tuning session. It is loaded through
            prepare_fine_tune_data only to report the mean phonemes-per-time-bin.
        config: optional overrides of the defaults below. Keys read:
            learning_rate, num_epochs, batch_size, weight_decay, freeze_layers
            (substrings of parameter names to freeze), gradient_clip and
            dataset_path (pickle handed to getDatasetLoaders; default
            '/content/local_tmp'). Other keys are ignored.
        device: torch device the model and every batch are moved to.

    Returns:
        (model, training_stats). The model holds the weights from the epoch with the
        lowest average training loss. training_stats has keys "epoch_losses",
        "best_epoch" and "final_phone_per_time".

    Raises:
        ValueError: if the training loader yields no batches.
    """
    default_config = {
        "learning_rate": 1e-4,
        "num_epochs": 10,
        "batch_size": 32,
        "weight_decay": 0.01,
        "freeze_layers": [],
        "gradient_clip": 1.0,
        "dataset_path": "/content/local_tmp",
    }
    if config is not None:
        default_config.update(config)
    config = default_config

    # Only used for the phone-per-time statistic; batches come from config["dataset_path"].
    fine_tune_data = prepare_fine_tune_data(session_path)

    model = model.to(device)

    if config["freeze_layers"]:
        for name, param in model.named_parameters():
            if any(layer in name for layer in config["freeze_layers"]):
                param.requires_grad = False

    # Use the configured batch size rather than the pretraining args.
    dataloader, _, _ = getDatasetLoaders(
        config["dataset_path"],
        config["batch_size"]
    )

    # Blank index 0 matches the +1 shift applied to phoneme IDs in getDataset.
    criterion = torch.nn.CTCLoss(blank=0, reduction="mean", zero_infinity=True)
    optimizer = torch.optim.AdamW(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=config["learning_rate"],
        weight_decay=config["weight_decay"]
    )

    model.train()
    best_loss = float('inf')
    best_state = None
    training_stats = {
        "epoch_losses": [],
        "best_epoch": 0,
        "final_phone_per_time": float(fine_tune_data['phonePerTime'].mean())
    }

    for epoch in range(config["num_epochs"]):
        epoch_loss = 0.0
        num_batches = 0

        for features, phonemes, time_lens, phone_lens, indices in dataloader:
            # Move batch to device
            features = features.to(device)
            phonemes = phonemes.to(device)
            time_lens = time_lens.to(device)
            phone_lens = phone_lens.to(device)
            indices = indices.to(device)

            # Forward pass
            pred = model(features, indices)

            # CTC expects (time, batch, classes) log-probs; the input lengths are
            # presumably the frame counts after the model's kernelLen/strideLen stage.
            loss = criterion(
                torch.permute(pred.log_softmax(2), [1, 0, 2]),
                phonemes,
                ((time_lens - model.kernelLen) / model.strideLen).to(torch.int32),
                phone_lens
            )

            # Backward pass
            optimizer.zero_grad()
            loss.backward()

            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(
                model.parameters(),
                config["gradient_clip"]
            )

            optimizer.step()

            epoch_loss += loss.item()
            num_batches += 1

        if num_batches == 0:
            raise ValueError(
                f"No training batches produced from {config['dataset_path']!r}"
            )

        avg_epoch_loss = epoch_loss / num_batches
        training_stats["epoch_losses"].append(avg_epoch_loss)

        print(f"Epoch {epoch+1}/{config['num_epochs']}, "
              f"Average Loss: {avg_epoch_loss:.4f}")

        if avg_epoch_loss < best_loss:
            best_loss = avg_epoch_loss
            # state_dict() holds references to the live tensors that later optimizer
            # steps update in place; clone them so this snapshot stays the best one.
            best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
            training_stats["best_epoch"] = epoch

    # No snapshot when num_epochs == 0 (or every loss was NaN): keep current weights.
    if best_state is not None:
        model.load_state_dict(best_state)

    return model, training_stats

# temp
# NOTE(review): "dtype" is not a key fine_tune_model reads, so this override has no
# effect and the function's defaults are used.
config = dict(dtype='float')


In [None]:
# Folder receiving the fine-tuned weights (saved as <folder>/modelWeights).
FTModelSaveDir = '/content/drive/MyDrive/TreeHacks/FTModel'

In [None]:
# Fine-tune the pretrained model on the new session, then persist the resulting weights.
fine_tuned_model, stats = fine_tune_model(
    model=model,
    session_path=fine_tuning_dataset_path,
    config=config
)

# torch.save does not create missing directories, so make sure the target folder exists.
os.makedirs(FTModelSaveDir, exist_ok=True)
torch.save(fine_tuned_model.state_dict(), FTModelSaveDir + "/modelWeights")