<a href="https://colab.research.google.com/github/carteruh/Neural-Text-Decoding/blob/main/TextDecodingScript.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Text model training for the paper "A high-performance neuroprosthesis for speech decoding and avatar control"

### For my own work, I have uploaded a new dataset from the paper "A High Performance Speech Neuroprosthesis". I preprocessed, padded, and normalized the data so that it fits the training code of the original paper "A high-performance neuroprosthesis for speech decoding and avatar control".

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
!pip install flashlight-text
!pip install git+https://github.com/kpu/kenlm.git
!pip install g2p-en
!pip install -U pynwb

# Load packages
from sklearn.model_selection import train_test_split
import numpy as np
import pandas as pd
import argparse
from os.path import join
import torchaudio
import torch
from torchaudio.models import decoder
from torchaudio.models.decoder import download_pretrained_files

print('torch version', torch.__version__)
print('torch audio version', torchaudio.__version__)

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
from torch.utils.data import DataLoader, TensorDataset, Dataset
from torchvision import transforms
!pip install speechbrain
import speechbrain as sb

import copy
!pip install wandb
import wandb
import os

import scipy.io
import numpy as np
import tensorflow as tf
from pathlib import Path
import matplotlib.pyplot as plt
from g2p_en import G2p
import re

Collecting git+https://github.com/kpu/kenlm.git
  Cloning https://github.com/kpu/kenlm.git to /tmp/pip-req-build-1ep2_v9w
  Running command git clone --filter=blob:none --quiet https://github.com/kpu/kenlm.git /tmp/pip-req-build-1ep2_v9w
  Resolved https://github.com/kpu/kenlm.git to commit 35f145839eca742f2402716d17542fd0546efc9d
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
torch version 2.1.0+cu118
torch audio version 2.1.0+cu118


In [None]:
# Clone the original repository to obtain useful files such as the language models and corpus
!git clone https://github.com/UCSF-Chang-Lab-BRAVO/multimodal-decoding.git

fatal: destination path 'multimodal-decoding' already exists and is not an empty directory.


#### Load and Access the 50-word Speech Task Dataset (My Work)

In [None]:
# Get the dataset for sentences
!wget -O data.tar.gz https://datadryad.org/stash/downloads/file_stream/2547369

--2023-12-07 04:13:25--  https://datadryad.org/stash/downloads/file_stream/2547369
Resolving datadryad.org (datadryad.org)... 54.188.132.94, 54.185.209.108
Connecting to datadryad.org (datadryad.org)|54.188.132.94|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://dryad-assetstore-merritt-west.s3.us-west-2.amazonaws.com/ark%3A/13030/m53853vd%7C6%7Cproducer/competitionData.tar.gz?response-content-disposition=attachment%3B%20filename%3DcompetitionData.tar.gz&X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIA2KERHV5E3OITXZXC%2F20231207%2Fus-west-2%2Fs3%2Faws4_request&X-Amz-Date=20231207T041326Z&X-Amz-Expires=86400&X-Amz-SignedHeaders=host&X-Amz-Signature=44bd9c3921de55acdad23ee54258faad971227f965a218fad269c6a92ecb1400 [following]
--2023-12-07 04:13:26--  https://dryad-assetstore-merritt-west.s3.us-west-2.amazonaws.com/ark%3A/13030/m53853vd%7C6%7Cproducer/competitionData.tar.gz?response-content-disposition=attachment%3B%20filename%3DcompetitionData.t

In [None]:
import tarfile
!pip install torchvision

# Path of the archive written by the `wget -O data.tar.gz ...` cell above
file_path = 'data.tar.gz'

# Open the gzip-compressed tar archive; the `with` block closes it afterwards
with tarfile.open(file_path, 'r:gz') as file:
    # Extract its contents into the current working directory.
    # NOTE(review): extractall() trusts every member path stored in the archive;
    # only run this on an archive from a source you trust.
    file.extractall()

print("Extraction completed.")

Extraction completed.


In [None]:
# Phonemes
# Full phoneme inventory (39 symbols, no stress markers, no silence token).
PHONE_DEF = (
    "AA AE AH AO AW AY B CH D DH "
    "EH ER EY F G HH IH IY JH K "
    "L M N NG OW OY P R S SH "
    "T TH UH UW V W Y Z ZH"
).split()

# Same inventory with the silence token 'SIL' appended as the last entry.
PHONE_DEF_SIL = PHONE_DEF + ['SIL']

# Reduced inventory: PHONE_DEF without AO, CH, JH, OY, SH and ZH.
CHANG_PHONE_DEF = (
    "AA AE AH AW AY B D DH "
    "EH ER EY F G HH IH IY K "
    "L M N NG OW P R S "
    "T TH UH UW V W Y Z"
).split()

# Consonant / vowel partition of PHONE_DEF (the element order is kept as-is).
CONSONANT_DEF = (
    "CH SH JH R B M W V F P "
    "D N L S T Z TH G Y HH "
    "K NG ZH DH"
).split()
VOWEL_DEF = (
    "EY AE AY EH AA AW IY IH OY OW "
    "AO UH AH UW ER"
).split()

# Silence token on its own.
SIL_DEF = ['SIL']

In [None]:
import os
import re
import scipy.io
import numpy as np
from scipy import stats
from g2p_en import G2p

def processNeuralDataWithPhonemes(dataPath, max_seq_len=500):
  """Load one session .mat file and build features plus text/phoneme targets.

  Args:
    dataPath: path of a .mat file readable by `scipy.io.loadmat` that holds the
      keys 'sentenceText', 'tx1', 'spikePow' and 'blockIdx'.
    max_seq_len: length every padded target sequence is padded/truncated to.

  Returns:
    A 5-tuple of per-trial lists, all in trial order:
      input_features: 2-D arrays (time, 256) made of the first 128 columns of
        'tx1' and of 'spikePow', z-scored per recording block.
      all_transcriptions: cleaned, lower-cased sentence strings.
      all_phonemes: phoneme lists with '|' after every word
        (['SIL'] for an empty sentence).
      padded_ascii_transcriptions: character codes of each sentence,
        zero-padded/truncated to `max_seq_len`.
      padded_phonemes: phoneme lists padded with 'SIL' / truncated to
        `max_seq_len`.
  """
  def _convert_to_ascii(text):
    return [ord(char) for char in text]

  input_features = []
  all_transcriptions = []
  all_phonemes = []
  padded_ascii_transcriptions = []
  padded_phonemes = []
  g2p = G2p()

  dat = scipy.io.loadmat(dataPath)
  n_trials = dat['sentenceText'].shape[0]
  print(n_trials)
  # Collect area 6v tx1 and spikePow features
  for i in range(n_trials):
    features = np.concatenate([dat['tx1'][0, i][:, 0:128], dat['spikePow'][0, i][:, 0:128]], axis=1)
    sentence = dat['sentenceText'][i].strip()
    # Keep letters, hyphens, spaces and apostrophes only
    sentence = re.sub(r'[^a-zA-Z\- \']', '', sentence).replace('--', '').lower()

    if len(sentence) == 0:
      # Copy so the module-level SIL_DEF list is never shared with the results
      phonemes = list(SIL_DEF)
    else:
      phonemes = []
      for p in g2p(sentence):
        if p == ' ':
          phonemes.append('|')  # word boundary

        p = re.sub(r'[0-9]', '', p)  # Remove stress
        if re.match(r'[A-Z]+', p):  # Only keep phonemes
          phonemes.append(p)

      # Close the last word so that every word is followed by a '|' separator
      phonemes.append('|')

    # Pad phonemes to max_seq_len
    padded_phoneme = phonemes + ['SIL'] * (max_seq_len - len(phonemes))
    padded_phoneme = padded_phoneme[:max_seq_len]

    # Convert transcription to ASCII and pad
    ascii_transcription = _convert_to_ascii(sentence)
    padded_ascii_transcription = ascii_transcription + [0] * (max_seq_len - len(ascii_transcription))
    padded_ascii_transcription = padded_ascii_transcription[:max_seq_len]

    padded_ascii_transcriptions.append(padded_ascii_transcription)
    padded_phonemes.append(padded_phoneme)

    input_features.append(features)
    all_transcriptions.append(sentence)
    all_phonemes.append(phonemes)

  # Block-wise feature normalization: z-score every trial with the statistics
  # of all trials recorded in the same block.
  blockNums = np.atleast_1d(np.squeeze(dat['blockIdx']))
  for blockId in np.unique(blockNums):
    sentIdx = np.argwhere(blockNums == blockId)[:, 0].astype(np.int32)
    # Gather exactly the trials of this block (they need not be contiguous)
    feats = np.concatenate([input_features[i] for i in sentIdx], axis=0)
    feats_mean = np.mean(feats, axis=0, keepdims=True)
    feats_std = np.std(feats, axis=0, keepdims=True)
    for i in sentIdx:
      input_features[i] = (input_features[i] - feats_mean) / (feats_std + 1e-8)

  # Return the accumulated per-trial list, not just the last trial's codes
  return input_features, all_transcriptions, all_phonemes, padded_ascii_transcriptions, padded_phonemes


In [None]:
def pad_features(input_feature, max_time_sequences=919):
  """Zero-pad every trial along the time axis and stack the trials.

  Args:
    input_feature: iterable of 2-D arrays shaped (time, channels).
    max_time_sequences: number of time steps every trial is padded to.
      Defaults to 919 (the value previously hard-coded here); pass None to use
      the longest trial found in `input_feature`.

  Returns:
    Array of shape (n_trials, max_time_sequences, channels).

  Raises:
    ValueError: if a trial is longer than `max_time_sequences`.
  """
  input_feature = list(input_feature)
  longest = max((array.shape[0] for array in input_feature), default=0)
  if max_time_sequences is None:
    max_time_sequences = longest
  elif longest > max_time_sequences:
    # np.pad would fail with an opaque "negative pad width" error otherwise
    raise ValueError(
        'trial with %d time steps exceeds max_time_sequences=%d'
        % (longest, max_time_sequences))

  # Pad only at the end of the time axis; channels are left untouched
  padded_features = [np.pad(array, ((0, max_time_sequences - array.shape[0]), (0, 0)),
                        mode='constant', constant_values=0) for array in input_feature]

  # Stack the equally sized trials into one 3-D array
  return np.array(padded_features)

In [None]:
## Get Train data
# train_input_features, train_transcriptions, train_phonemes, padded_train_ascii_transcription, padded_train_phonemes = processNeuralDataWithPhonemes('/content/competitionData/train/t12.2022.04.28.mat')
# print(train_transcriptions)
# print(train_phonemes)
# print(train_input_features)
# print(padded_train_ascii_transcription)
# print(padded_train_phonemes)

In [None]:
# Get train data: process every session file below competitionData/train
curdir = '/content/'
train_input_features = []
train_transcriptions = []
train_phonemes = []
padded_train_ascii_transcription=[]
padded_train_phonemes = []
for root, dirs, files in os.walk(os.path.join(curdir, 'competitionData/train')):
  dirs.sort()  # os.walk order is arbitrary; sort for a reproducible trial order
  for file in sorted(files):
    if not file.endswith('.mat'):
      continue  # skip anything that is not a session file
    file_path = os.path.join(root, file)
    train_input_feature, train_transcription, train_phoneme, padded_train_ascii_transcription_obj, padded_train_phonemes_obj = processNeuralDataWithPhonemes(file_path)
    train_input_features.extend(train_input_feature)
    train_transcriptions.extend(train_transcription)
    train_phonemes.extend(train_phoneme)
    padded_train_ascii_transcription.extend(padded_train_ascii_transcription_obj)
    padded_train_phonemes.extend(padded_train_phonemes_obj)

# Print a summary and one example only: dumping the full lists exceeds the
# notebook's IOPub data rate limit and truncates the output.
print('train trials:', len(train_transcriptions))
print('example transcription:', train_transcriptions[:1])
print('example phonemes:', train_phonemes[:1])
print('example feature shape:', [f.shape for f in train_input_features[:1]])
print('padded phoneme sequences:', len(padded_train_phonemes))

400
360
320
280
360
480
400
400
360
360
320
520
200
400
360
320
360
440
320
320
520
180
420
400
[['T', 'R', 'AY', '|', 'N', 'AA', 'T', '|', 'T', 'UW', '|', 'Y', 'UW', 'Z', '|', 'EH', 'N', 'IY', '|', 'IH', 'N', 'S', 'EH', 'K', 'T', 'AH', 'S', 'AY', 'D', 'Z', '|', 'AE', 'T', '|', 'AO', 'L', '|'], ['JH', 'OY', 'N', '|', 'DH', 'AH', '|', 'F', 'EY', 'S', 'B', 'UH', 'K', '|', 'F', 'AE', 'N', '|', 'P', 'EY', 'JH', '|'], ['DH', 'EY', '|', 'JH', 'AH', 'S', 'T', '|', 'M', 'UW', 'V', 'D', '|', 'IH', 'N', 'T', 'UW', '|', 'DH', 'AH', '|', 'N', 'UW', '|', 'B', 'IH', 'L', 'D', 'IH', 'NG', '|'], ['AH', '|', 'B', 'IH', 'G', '|', 'M', 'EH', 'T', 'R', 'AH', 'P', 'AA', 'L', 'AH', 'T', 'AH', 'N', '|', 'EH', 'R', 'IY', 'AH', '|'], ['W', 'AH', 'T', '|', 'IH', 'Z', '|', 'DH', 'AE', 'T', '|', 'R', 'EH', 'S', 'T', 'ER', 'AA', 'N', 'T', '|'], ['Y', 'UW', '|', 'M', 'AY', 'T', '|', 'W', 'AA', 'N', 'T', '|', 'T', 'UW', '|', 'AH', 'T', 'AE', 'K', '|'], ['AY', '|', 'W', 'AA', 'Z', '|', 'R', 'IH', 'L', 'IY', '|', 'SH'

IOPub data rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_data_rate_limit`.

Current values:
NotebookApp.iopub_data_rate_limit=1000000.0 (bytes/sec)
NotebookApp.rate_limit_window=3.0 (secs)



In [None]:
# Sanity-check one training trial instead of dumping the whole dataset
# (printing the full lists exceeds the notebook's IOPub data rate limit).
print(len(train_input_features[1][1]))  # number of features per time step
print(len(train_transcriptions[1]))     # characters in the transcription
print(train_phonemes[1])
print(padded_train_phonemes[1][:40])    # start of the padded phoneme sequence

256
26
[['T', 'R', 'AY', '|', 'N', 'AA', 'T', '|', 'T', 'UW', '|', 'Y', 'UW', 'Z', '|', 'EH', 'N', 'IY', '|', 'IH', 'N', 'S', 'EH', 'K', 'T', 'AH', 'S', 'AY', 'D', 'Z', '|', 'AE', 'T', '|', 'AO', 'L', '|'], ['JH', 'OY', 'N', '|', 'DH', 'AH', '|', 'F', 'EY', 'S', 'B', 'UH', 'K', '|', 'F', 'AE', 'N', '|', 'P', 'EY', 'JH', '|'], ['DH', 'EY', '|', 'JH', 'AH', 'S', 'T', '|', 'M', 'UW', 'V', 'D', '|', 'IH', 'N', 'T', 'UW', '|', 'DH', 'AH', '|', 'N', 'UW', '|', 'B', 'IH', 'L', 'D', 'IH', 'NG', '|'], ['AH', '|', 'B', 'IH', 'G', '|', 'M', 'EH', 'T', 'R', 'AH', 'P', 'AA', 'L', 'AH', 'T', 'AH', 'N', '|', 'EH', 'R', 'IY', 'AH', '|'], ['W', 'AH', 'T', '|', 'IH', 'Z', '|', 'DH', 'AE', 'T', '|', 'R', 'EH', 'S', 'T', 'ER', 'AA', 'N', 'T', '|'], ['Y', 'UW', '|', 'M', 'AY', 'T', '|', 'W', 'AA', 'N', 'T', '|', 'T', 'UW', '|', 'AH', 'T', 'AE', 'K', '|'], ['AY', '|', 'W', 'AA', 'Z', '|', 'R', 'IH', 'L', 'IY', '|', 'SH', 'AA', 'K', 'T', '|'], ['N', 'AY', 'S', '|', 'HH', 'AW', 'S', '|', 'AH', 'N', 'D', '|', 

IOPub data rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_data_rate_limit`.

Current values:
NotebookApp.iopub_data_rate_limit=1000000.0 (bytes/sec)
NotebookApp.rate_limit_window=3.0 (secs)



In [None]:
# Get test data: process every session file below competitionData/test
curdir = '/content/'
test_input_features = []
test_transcriptions = []
test_phonemes = []
padded_test_ascii_transcription=[]
padded_test_phonemes = []
for root, dirs, files in os.walk(os.path.join(curdir, 'competitionData/test')):
  dirs.sort()  # os.walk order is arbitrary; sort for a reproducible trial order
  for file in sorted(files):
    if not file.endswith('.mat'):
      continue  # skip anything that is not a session file
    file_path = os.path.join(root, file)
    test_input_feature, test_transcription, test_phoneme, padded_test_ascii_transcription_obj, padded_test_phonemes_obj = processNeuralDataWithPhonemes(file_path)
    test_input_features.extend(test_input_feature)
    test_transcriptions.extend(test_transcription)
    test_phonemes.extend(test_phoneme)
    padded_test_ascii_transcription.extend(padded_test_ascii_transcription_obj)
    padded_test_phonemes.extend(padded_test_phonemes_obj)

# Print a summary and one example only: dumping the full lists exceeds the
# notebook's IOPub data rate limit and truncates the output.
print('test trials:', len(test_transcriptions))
print('example transcription:', test_transcriptions[:1])
print('example phonemes:', test_phonemes[:1])
print('example feature shape:', [f.shape for f in test_input_features[:1]])
print('padded ascii entries:', len(padded_test_ascii_transcription))
print('padded phoneme sequences:', len(padded_test_phonemes))

40
40
40
20
20
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
20
20
40
["i think that's something we might want", 'she thought for a few minutes', 'that is a sad situation', 'they do have capital punishment', 'it was nice talking to you dudley', 'my favorite team is the rangers', 'kind of goes along with what you were saying', "why they don't go with fourteen one", "reader's choice award", 'they are the detroit delegates', 'he went to vietnam one man and came back another', 'as any city grows up', 'so you did buy a pattern', "it's about sixteen and a half percent", 'i agree with the phone', "so i think it's going to be a lot easier", 'not too much soy sauce', 'we take a certain portion of your paycheck', 'who sponsors you', "and i don't know what's going to happen", 'i think we have about thirteen months left on it', 'i do particularly find it annoying', 'they have clubs and a swimming pool', 'our point of view', 'it really is worth while to sew', 'i felt that was a little insensitive

IOPub data rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_data_rate_limit`.

Current values:
NotebookApp.iopub_data_rate_limit=1000000.0 (bytes/sec)
NotebookApp.rate_limit_window=3.0 (secs)



In [None]:
# Zero-pad train and test trials to a common number of time steps
padded_train_features, padded_test_features = (
    pad_features(trials) for trials in (train_input_features, test_input_features)
)

# Shapes of the padded train and test feature arrays, one per line
print(padded_train_features.shape, padded_test_features.shape, sep='\n')

(8800, 919, 256)
(880, 919, 256)


In [None]:
# Calculate lengths of samples at each utterance
lengths_train = [len(timeblocks) for timeblocks in train_input_features]
lengths_test = [len(timeblocks) for timeblocks in test_input_features]
all_lengths = lengths_train + lengths_test
print(all_lengths)

[419, 279, 322, 297, 219, 248, 212, 239, 368, 376, 408, 388, 228, 218, 381, 196, 446, 244, 440, 440, 309, 293, 185, 148, 234, 259, 359, 327, 223, 202, 368, 364, 391, 254, 336, 267, 339, 434, 372, 351, 326, 381, 312, 380, 420, 424, 343, 393, 399, 277, 202, 308, 304, 318, 281, 240, 207, 218, 157, 205, 333, 495, 227, 506, 293, 236, 224, 448, 451, 328, 185, 232, 329, 312, 229, 226, 237, 423, 288, 243, 255, 381, 407, 387, 327, 383, 275, 229, 287, 237, 351, 411, 169, 151, 187, 197, 167, 468, 309, 547, 247, 509, 397, 292, 225, 177, 237, 213, 306, 211, 393, 198, 207, 209, 308, 317, 160, 250, 147, 280, 469, 203, 395, 204, 259, 255, 250, 265, 237, 229, 482, 442, 501, 458, 387, 463, 298, 202, 379, 151, 220, 382, 170, 396, 154, 183, 276, 241, 133, 245, 378, 310, 304, 338, 323, 346, 322, 351, 399, 274, 303, 263, 209, 296, 404, 489, 355, 349, 476, 249, 241, 184, 296, 235, 168, 306, 279, 228, 392, 250, 228, 143, 206, 286, 235, 225, 176, 220, 230, 262, 290, 362, 222, 444, 260, 285, 267, 280, 428, 279,

### Import Scripts from Python Files in the Repository

In [None]:
class basics2s(Dataset):
    """Dataset yielding `(sample, length, target, index)` tuples.

    `X`, `lens`, `Y` and `inds` are indexed with the same position; `transform`,
    when given, is applied to a deep copy of the selected `X` entry.
    """

    def __init__(self, X, lens, Y, inds, transform=None):
        self.X, self.Y = X, Y
        self.lens = lens
        self.inds = inds
        self.transform = transform

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        # Deep copy so that in-place transforms never touch the stored data
        sample = copy.deepcopy(self.X[idx])
        if self.transform is not None:
            sample = self.transform(sample)
        return (sample, self.lens[idx], self.Y[idx], self.inds[idx])


class hybridloader(Dataset):
    """Dataset yielding `(sample, length, target, index, ctc_target, ctc_length)`.

    All constructor sequences are indexed with the same position; `transform`,
    when given, is applied to a deep copy of the selected `X` entry.
    """

    def __init__(self, X, lens, Y, inds, Y_ctc, ctc_lens, transform=None):
        self.X, self.Y = X, Y
        self.lens = lens
        self.inds = inds
        self.Y_ctc, self.ctc_lens = Y_ctc, ctc_lens
        self.transform = transform

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        # Deep copy so that in-place transforms never touch the stored data
        sample = copy.deepcopy(self.X[idx])
        if self.transform is not None:
            sample = self.transform(sample)
        return (sample, self.lens[idx], self.Y[idx], self.inds[idx],
                self.Y_ctc[idx], self.ctc_lens[idx])


class CTCDataset(Dataset):
    """Dataset yielding `(sample, target, length, target_length, index)`.

    All constructor sequences are indexed with the same position; `transform`,
    when given, is applied to a deep copy of the selected `X` entry.
    `y_transforms` is accepted but not used.
    """

    def __init__(self, X, Y, lens, outlens, inds, transform=None, y_transforms=None):
        self.X, self.Y = X, Y
        self.lens, self.outlens = lens, outlens
        self.inds = inds
        self.transform = transform

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        # Deep copy so that in-place transforms never touch the stored data
        sample = copy.deepcopy(self.X[idx])
        if self.transform is not None:
            sample = self.transform(sample)
        return (sample, self.Y[idx], self.lens[idx], self.outlens[idx], self.inds[idx])


class CTCDataset_Wordct(Dataset):
    """Dataset yielding `(sample, target, length, target_length, index, word_count)`.

    All constructor sequences are indexed with the same position; `transform`,
    when given, is applied to a deep copy of the selected `X` entry.
    `y_transforms` is accepted but not used.
    """

    def __init__(self, X, Y, lens, outlens, inds, wordct, transform=None, y_transforms=None):
        self.X, self.Y = X, Y
        self.lens, self.outlens = lens, outlens
        self.inds = inds
        self.wordct = wordct
        self.transform = transform

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        # Deep copy so that in-place transforms never touch the stored data
        sample = copy.deepcopy(self.X[idx])
        if self.transform is not None:
            sample = self.transform(sample)
        return (sample, self.Y[idx], self.lens[idx], self.outlens[idx],
                self.inds[idx], self.wordct[idx])

class Jitter(object):
    """
    Crop a window of fixed size out of a sample.

    Training: draw a random start, scale its distance to the default start by
    `jitter_amt`, and crop `winsize` rows from there.
    Validation (`validate=True`): always crop the default window.

    `original_window` and `default_window` are (start, end) pairs that are
    converted to sample indices with the factor `sr/decimation`.
    """
    def __init__(self, original_window, default_window, jitter_amt, sr=200, decimation=6, validate=False):
        self.original_window = original_window
        self.default_window = default_window
        self.jitter_scale = jitter_amt
        self.validate = validate

        # Default window relative to the start of the original window, in samples
        offsets = np.asarray(default_window) - self.original_window[0]
        window_samples = np.asarray(offsets)*sr/decimation
        for edge in (0, 1):
            # Truncate both edges to whole sample indices
            window_samples[edge] = int(window_samples[edge])
        self.default_samples = window_samples

        # Window size is inclusive of both edges
        self.winsize = int(window_samples[1] - window_samples[0])+1
        total_samples = int((original_window[1] - original_window[0])*sr/decimation)
        self.max_start = int(total_samples - self.winsize)


    def __call__(self, sample):
        default_start = self.default_samples[0]
        if self.validate:
            default_end = self.default_samples[1]
            return sample[int(default_start):int(default_end)+1, :]

        start = np.random.randint(0, self.max_start)
        offset = start - default_start
        # Shrink the distance to the default start, then restore its direction
        scaled_offset = int(np.abs(offset)*self.jitter_scale)
        scaled_start = int(scaled_offset*np.sign(offset) + default_start)
        return sample[scaled_start:scaled_start+self.winsize]


class Blackout(object):
    """
    The blackout augmentation.

    With probability `blackout_prob`, zero out a contiguous run of rows of the
    sample. The run is at most `blackout_max_length` (a fraction of the number
    of rows) long and starts at a random row.
    """
    def __init__(self, blackout_max_length=0.3, blackout_prob=0.5):

        self.bomax = blackout_max_length
        self.bprob = blackout_prob


    def __call__(self, sample):
        """Apply the augmentation in place and return the (same) sample."""
        # `random` is not imported at the top of the notebook; import it here
        import random

        n_rows = sample.shape[0]
        blackout_times = int(np.random.uniform(0, 1)*n_rows*self.bomax)
        # randint needs an integer upper bound of at least 1
        max_start = max(int(n_rows - n_rows*self.bomax), 1)
        start = np.random.randint(0, max_start)
        if random.uniform(0, 1) < self.bprob:
            sample[start:(start+blackout_times), :] = 0
        return sample

class ChannelBlackout(object):
    """
    Randomly blackout channels.

    With probability `blackout_prob`, zero out `blackout_chans_max` randomly
    chosen entries along the last axis of the sample.
    """
    def __init__(self, blackout_chans_max=20, blackout_prob=0.2):
        self.bcm = blackout_chans_max
        self.bp = blackout_prob
    def __call__(self, sample):
        """Apply the augmentation in place and return the (same) sample."""
        # `random` is not imported at the top of the notebook; import it here
        import random

        if random.uniform(0, 1) < self.bp:
            inds = np.arange(sample.shape[-1])
            np.random.shuffle(inds)
            boi = inds[:self.bcm]  # indices of the channels to black out
            sample[..., boi] = 0
        # Always hand the sample back so the transform can be chained
        return sample

def normalize(x, axis=-1, order=2):
    """Scale a Numpy array to unit norm along one axis.

    Args:
        x: Numpy array to normalize.
        axis: axis along which the norm is computed.
        order: order of the norm (e.g. `order=2` for the L2 norm).
    Returns:
        A normalized copy of the array; slices whose norm is zero are
        returned unchanged.
    """
    norms = np.atleast_1d(np.linalg.norm(x, ord=order, axis=axis))
    # Divide all-zero slices by 1 instead of 0
    safe_norms = np.where(norms == 0, 1, norms)
    return x / np.expand_dims(safe_norms, axis)


class Normalize(object):
    """Transform that applies `normalize` to a sample along a fixed axis."""

    def __init__(self, axis):
        # Axis handed to `normalize` on every call
        self.axis = axis

    def __call__(self, sample):
        return normalize(sample, axis=self.axis)

class AdditiveNoise(object):
    """Transform that adds white (standard normal) noise scaled by `sigma`."""

    def __init__(self, sigma):
        self.sigma = sigma

    def __call__(self, sample):
        # One independent draw per element; the input itself is not modified
        noise = self.sigma*np.random.randn(*sample.shape)
        return sample + noise

class ScaleAugment(object):
    """Transform that multiplies a sample by one random factor.

    The factor is drawn uniformly between `low_range` and `up_range` on every
    call. The two bounds are printed once at construction time.
    """

    def __init__(self, low_range, up_range):
        self.low_range = low_range
        self.up_range = up_range  # e.g. .8
        print('scale', self.low_range, self.up_range)

    def __call__(self, sample):
        factor = np.random.uniform(self.low_range, self.up_range)
        return sample*factor

class LevelChannelNoise(object):
    """Transform that shifts every channel by one random offset.

    A single normal draw per entry of the last axis, scaled by `sigma`, is
    added to the whole sample, so all rows share the same per-channel offset.
    `channels` is stored but not used by `__call__`.
    """

    def __init__(self, sigma, channels=128):
        self.sigma = sigma  # the noise std
        self.channels = channels

    def __call__(self, sample):
        n_channels = sample.shape[-1]
        channel_offsets = self.sigma*np.random.randn(1, n_channels)
        # In-place add: the caller's array is modified and returned
        sample += channel_offsets
        return sample



In [None]:
def clean_labels(labels):
    """
    Strip stress markings and commas from the phone labels.

    in : pandas dataframe of the labels, with a 'ph_label' column whose
        entries are sequences of phone strings.
    out:

    labels - the same dataframe; its 'ph_label' column is replaced by lists in
        which every entry containing a comma is dropped and every remaining
        phone is cut to its first two characters (removing stress markings).
    all_ph - flat list of all the cleaned phonemes, to make the encoding dict.
    """
    newlabs, all_ph = [], []
    for phones in labels['ph_label']:
        cleaned = [ph[:2] for ph in phones if ',' not in ph]
        newlabs.append(cleaned)
        all_ph.extend(cleaned)
    labels['ph_label'] = newlabs
    return labels, all_ph

# NOTE: a function with this name and the same body is already defined in the
# augmentation cell earlier in this notebook; this later definition replaces it.
def normalize(x, axis=-1, order=2):
    """Scale a Numpy array to unit norm along one axis.

    Args:
        x: Numpy array to normalize.
        axis: axis along which the norm is computed.
        order: order of the norm (e.g. `order=2` for the L2 norm).
    Returns:
        A normalized copy of the array; slices whose norm is zero are
        returned unchanged.
    """
    norms = np.atleast_1d(np.linalg.norm(x, ord=order, axis=axis))
    # Divide all-zero slices by 1 instead of 0
    safe_norms = np.where(norms == 0, 1, norms)
    return x / np.expand_dims(safe_norms, axis)
def minmax_scaling(X):
    """Scale X to [0, 1] with the min/max taken over the first two axes.

    For a 3-D array this gives one min and one max per entry of the last axis.
    Entries whose min equals their max are mapped to 0 instead of producing
    NaN from a 0/0 division. Prints the mins and maxes that were used.

    Returns:
        The scaled copy of X.
    """
    chanmins = np.min(np.min(X, axis=0), axis=0)
    chanmax = np.max(np.max(X, axis=0), axis=0)
    span = chanmax - chanmins
    # A constant channel has span 0; divide by 1 so it becomes all zeros
    span = np.where(span == 0, 1, span)
    X = (X - chanmins) / span
    print('zero 1', chanmins, chanmax)
    return X

def pertrial_minmax(X):
    """Scale X to [0, 1] with the min/max taken along axis 1, per leading index.

    Slices whose min equals their max are mapped to 0 instead of producing NaN
    from a 0/0 division.

    Returns:
        The scaled copy of X.
    """
    chanmins = np.min(X, axis=1, keepdims=True)
    chanmax = np.max(X, axis=1, keepdims=True)
    span = chanmax - chanmins
    # A constant slice has span 0; divide by 1 so it becomes all zeros
    span = np.where(span == 0, 1, span)
    return (X - chanmins) / span

def rezscore(X):
    """Z-score X with statistics shared across the first two axes.

    The mean is taken over axes 0 and 1; the standard deviation is computed
    along axis 1 and then averaged over axis 0. X is centred with the mean and
    divided by that standard deviation (it used to be subtracted, which is not
    a z-score). A standard deviation of zero is replaced by 1 so constant
    entries become 0 instead of NaN. Prints the statistics that were used.

    Returns:
        The z-scored copy of X.
    """
    chanmeans = np.mean(np.mean(X, axis=0), axis=0)
    chanstd = np.mean(np.std(X, axis=1), axis=0)
    print('cm, cst', chanmeans, chanstd)
    safe_std = np.where(chanstd == 0, 1, chanstd)
    return (X - chanmeans) / safe_std

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import torchaudio

class GreedyCTCDecoder(torch.nn.Module):
    """Best-path (greedy) CTC decoder over a fixed label set."""

    def __init__(self, labels, blank=0):
        super().__init__()
        self.labels = labels
        self.blank = blank

    def forward(self, emission: torch.Tensor):
        """Given a sequence emission over labels, get the best path.

        Takes the most likely label at every step, merges consecutive repeats
        and removes the blank label.

        Args:
          emission (Tensor): Logit tensors. Shape `[num_seq, num_label]`.

        Returns:
          List[str]: The labels of the resulting transcript, in order.
        """
        best_path = emission.argmax(dim=-1)  # [num_seq,]
        collapsed = torch.unique_consecutive(best_path, dim=-1)
        return [self.labels[i] for i in collapsed if i != self.blank]


def batch_wer_cer(emission, gts, gt_phones, greedy, beam_search_decoder):
    """Beam-search decode a batch and sum its phone/char/word error rates.

    Inputs:
        emission - batch of emissions handed (on CPU) to `beam_search_decoder`
        gts - ground truth sentences, one string per batch entry
        gt_phones - ground truth phoneme sequences, one per batch entry
        greedy - object whose `labels` maps token ids to phoneme labels
        beam_search_decoder - callable returning, per batch entry, a list of
            hypotheses with `.words` and `.tokens`

    Outputs:
        Summed (not averaged) phone error rate, char error rate and word error
        rate over the batch, and the list of decoded transcripts.
    """
    edit_distance = torchaudio.functional.edit_distance
    batch_hypotheses = beam_search_decoder(emission.cpu())
    # Only the top hypothesis of every batch entry is scored
    transcripts = [" ".join(hypo[0].words) for hypo in batch_hypotheses]
    tokens = [hypo[0].tokens for hypo in batch_hypotheses]
    net_wer = 0
    net_cer = 0
    net_per = 0
    for beam_search_transcript, gt, gt_phonemes, trans_phones in zip(transcripts, gts, gt_phones, tokens):

        trans_phones = [greedy.labels[i] for i in trans_phones]
        # The first token of both phone sequences is skipped; max(..., 1)
        # keeps empty references from dividing by zero.
        net_per += edit_distance(gt_phonemes[1:], trans_phones[1:]) / max(len(gt_phonemes)-1, 1)
        net_cer += edit_distance(gt, beam_search_transcript) / max(len(gt), 1)
        net_wer += edit_distance(gt.split(' '), beam_search_transcript.split(' ')) / len(gt.split(' '))

    return net_per, net_cer, net_wer, transcripts


def greedy_beam_wer_cer(emission, gt, gt_phonemes, greedy, beam_search_decoder, print_hypo=False):
    """
    Does beam/greedy search and tells us how we're doing.

    Inputs:
        emission - the probabilities of each phoneme at each timestep
        gt - the text of the ground truth sentence
        gt_phonemes - the list of the ground truth phoneme sequence
        greedy - greedy decoder; called on `emission`, and its `labels` maps
            beam-search token ids to phoneme labels
        beam_search_decoder - a torchaudio.models.decoder ctc_decoder
        print_hypo - if True, print the ground truth and the words of up to 50
            beam hypotheses


    Outputs:
        Greedy phone error rate
        The ground truth text
        the beam search text ('nobeam' / 'nobeams' when no hypothesis exists)
        beam search char error rate
        beam search wer
        beam search phone error rate
    """
    edit_distance = torchaudio.functional.edit_distance
    # The first token of every phone sequence is skipped when scoring;
    # max(..., 1) keeps a reference with a single token from dividing by zero.
    n_ref_phones = max(len(gt_phonemes) - 1, 1)

    greedy_result = greedy(emission)
    greedy_per = edit_distance(gt_phonemes[1:], greedy_result[1:]) / n_ref_phones
    beam_search_result = beam_search_decoder(emission.cpu().unsqueeze(0))
    if len(beam_search_result) > 0:

        if print_hypo:
            print('gt:', gt)
            print('top 50')
            print(len(beam_search_result[0]))
            for b in beam_search_result[0][:50]:
                print(b.words)
            print('---')
        try:
            beam_search_transcript = " ".join(beam_search_result[0][0].words).strip()
        except Exception:
            # Best effort: no usable top hypothesis
            beam_search_transcript = "nobeam"
    else:
        beam_search_transcript = 'nobeams'
    try:
        beam_search_phones = [greedy.labels[i] for i in beam_search_result[0][0].tokens]
    except Exception:
        # Best effort: fall back to a one-phone list (a list, like the normal
        # case, so the slice below yields phones rather than characters)
        beam_search_phones = ['AH']

    # ignore silence token at the start
    beam_search_per = edit_distance(gt_phonemes[1:], beam_search_phones[1:]) / n_ref_phones
    # An empty ground truth sentence must not divide by zero
    beam_cer = edit_distance(gt, beam_search_transcript) / max(len(gt), 1)
    beam_wer = edit_distance(gt.split(' '), beam_search_transcript.split(' ')) / len(gt.split(' '))
    return greedy_per, gt, beam_search_transcript, beam_cer, beam_wer, beam_search_per

def test_ensemble(model, test_loader, greedy, beam_search_decoder,
                  texts, tokens, wandb_name ='final_te_', device='cuda',
                  print_greedy=True, verbose=True, printall=False,
                  print_hypo=False,
                  realtime_eval=False):
    """
    Evaluates the model on `test_loader`: CTC loss for every batch and, if
    `print_greedy` is true, wer/cer/per for every sample via
    `greedy_beam_wer_cer`. The metrics are also logged to wandb.

    Inputs:
    model - called as `model(x, l)` and expected to return
        `(emissions, _, lengths)`
    test_loader - yields `(x, y, l, targ_len, gtsent)` batches
    greedy - greedy ctc decoder
    beam_search_decoder - torchaudio ctc decoder
    texts - the ground truth text, indexed with the integer in `gtsent`.
        May not be needed for audio; then put in None and `gtsent[k]` itself
        is used as the ground truth.
    tokens - maps target ids in `y` to phoneme labels
    wandb_name - prefix of the metric names logged to wandb
    device - device the batches and the model are moved to
    print_greedy = True -
        Then wer, cer, per are evaluated
    verbose = True -
        Then we print out predictions to get a flavor of what was going on
    printall - if False, at most 11 example sentences are looped over
    print_hypo - forwarded to `greedy_beam_wer_cer`
    realtime_eval - if True, the per-sample metrics and texts are logged too

    Returns:
        (epoch_loss, model, wer, cer, per, wers, pers, cers, transcripts, gts,
        all_emish); wer/cer/per are None when `print_greedy` is False.
    """
#     model.eval()
    loss_fn = nn.CTCLoss()
    all_emish = []
    wers = []
    pers = []
    cers = []
    with torch.no_grad():
        total_loss, total_samps = 0, 0
        total_wer = 0
        total_cer=0
        total_gcer = 0
        total_per = 0
        gts, transcripts = [], []
        for x, y, l, targ_len, gtsent in test_loader:
            x = x.float().to(device)
            y = y.long().to(device)
            l = l.int().to(device)
            if not texts is None:
                # gtsent holds indices into `texts`
                gtsent = gtsent.long().cpu().numpy()
            else:
                gtsent = gtsent
            targ_len = targ_len.int().to(device)
            # Single-model "ensemble": the loop variable rebinds `model`, and
            # only the emissions of the first (k == 0) model are computed.
            models = [model]
            for k, model in enumerate(models):
                model.eval()
                model.to(device)
                if k == 0:
                    emissions, _,  lengths = model(x, l)
                    emissions = F.softmax(emissions, dim=-1)

            # Store the probabilities, then switch to log-probabilities for the loss
            all_emish.append(emissions.detach().cpu().numpy())
            emissions = torch.log(emissions)
            loss = loss_fn(emissions, y, lengths, targ_len)
            # NOTE(review): loss_fn is built with nn.CTCLoss defaults, so each
            # `loss` looks like an already-reduced batch value, while the sum is
            # later divided by the number of samples - confirm this is intended.
            total_loss += loss.item()
            total_samps += x.shape[0]


            ### The code to look at the text.  We dont want to do this every trial since
            # it can be expensive.
            if print_greedy:
                 # NOTE(review): permute(1, 0, 2) presumably turns (time, batch, label)
                 # emissions into one (time, label) slice per sample - confirm.
                 for k , e in enumerate(emissions.permute(1,0, 2)):
                    # Drop target ids -1 and 0 before mapping ids to phoneme labels
                    gt_phones = [tokens[yy] for yy in y.detach().cpu().numpy()[k] if not yy == -1 and not yy == 0]
                    if not texts is None:
                        gt  = texts[int(gtsent[k])]
                    else:
                        gt = gtsent[k]
                        # Get the ground truth text.
                    # The first returned value is the greedy phone error rate
                    greedy_cer, gt_, transcript,cer,  wer, per = greedy_beam_wer_cer(e, gt,
                                                                                     gt_phones,
                                                                                     greedy,
                                                                                     beam_search_decoder,
                                                                                    print_hypo)
                    gts.append(gt_)
                    transcripts.append(transcript)
                    wers.append(wer)
                    pers.append(per)
                    cers.append(cer)
                    total_wer += wer
                    total_cer += cer
                    total_per += per
                    total_gcer += greedy_cer


        if print_greedy and verbose:
            print('net wer', total_wer/total_samps)
            print('net cer', total_cer/total_samps)
            print('greedy per', total_gcer/total_samps)
            print('beam per', total_per/total_samps)

        # Logged unconditionally: when print_greedy is False the totals are 0
        # and the medians are taken over empty lists.
        wandb.log({
            wandb_name + 'wer': total_wer/total_samps,
            wandb_name + 'med_wer':  np.median(wers),
            wandb_name + 'per': total_per/total_samps,
            wandb_name + 'med_per': np.median(pers),
            wandb_name + 'cer' : total_cer/total_samps,
            wandb_name + 'med_cer' : np.median(cers),

        })

        if realtime_eval:
#             print('wers:', wers, len(wers))
#             print('pers:', pers)
#             print('cers:', cers)
            wandb.log({
                'allrt_wer':wers,
                'allrt_cer':cers,
                'allrt_per':pers,
                'gts':gts,
                'trans':transcripts
            })

        # Show a few randomly chosen sentences with their recomputed wer/cer
        ctr =0
        randomsent = np.arange(len(gts))
        np.random.shuffle(randomsent)
        for k in randomsent:
            gt = gts[k]
            trans = transcripts[k]
            wer = (torchaudio.functional.edit_distance(gt.split(' '), trans.split(' '))/len(gt.split(' ')))
            cer = (torchaudio.functional.edit_distance(gt, trans)/len(gt))
            if verbose:
                print('gt', gt)
                print('t', trans, 'wer: %.2f' %wer, 'cer: %.2f' %cer)
            ctr+=1
            if ctr > 10 and not printall:
                break

        # Raises ZeroDivisionError if test_loader yielded no batches
        epoch_loss = total_loss/total_samps
        # Currently I dont eval wer/cer every time because it is a bit computationally expensive.
        if print_greedy:
            return epoch_loss, model, total_wer/total_samps, total_cer/total_samps, total_per/total_samps, wers, pers, cers,  transcripts, gts, all_emish
        else:
            return epoch_loss, model, None, None, None, wers, pers, cers, transcripts, gts, all_emish

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
from speechbrain.nnet.RNN import LiGRU_Layer


class AUXCnnRnnClassifier(torch.nn.Module):
    """
    Conv1d -> GRU -> linear classifier with an auxiliary word-count head.

    A strided convolution (kernel size == stride == ``KS``) downsamples the
    input in time, a GRU runs over the packed sequences, and ``dense`` emits
    ``n_targ`` scores for every remaining time step.  ``word_ct_layer`` maps
    the last time step of the padded GRU output to ``nword_targ`` scores.

    Args:
        rnn_dim: conv output channels and GRU hidden size.
        KS: kernel size and stride of the input convolution.
        num_layers: number of stacked GRU layers.
        dropout: dropout probability (after the conv and between GRU layers).
        n_targ: number of per-time-step output classes.
        bidirectional: run the GRU in both directions (doubles its output width).
        in_channels: number of input feature channels.
        nword_targ: number of classes of the word-count head.
    """
    def __init__(self, rnn_dim, KS, num_layers, dropout, n_targ, bidirectional, in_channels=506, nword_targ=10):
        super().__init__()

        # Plain bookkeeping attributes.
        self.num_layers = num_layers
        self.rnn_dim = rnn_dim
        self.ks = KS
        self.mult = 2 if bidirectional else 1

        # Sub-modules (registered in the same order as before).
        self.preprocessing_conv = nn.Conv1d(in_channels=in_channels, out_channels=rnn_dim,
                                            kernel_size=KS, stride=KS)
        self.BiGRU = nn.GRU(input_size=rnn_dim, hidden_size=rnn_dim, num_layers=num_layers,
                            bidirectional=bidirectional, dropout=dropout)
        self.dropout = nn.Dropout(dropout)
        self.dense = nn.Linear(rnn_dim*self.mult, n_targ)
        self.word_ct_layer = nn.Linear(rnn_dim*self.mult, nword_targ)

    def forward(self, x, lens):
        """
        Args:
            x: input laid out as (batch, time, channels).
            lens: per-sample input lengths, in input samples.

        Returns:
            (per-step scores laid out (time, batch, n_targ),
             word-count scores (batch, nword_targ),
             per-sample output lengths from ``pad_packed_sequence``).
        """
        # The conv emits one frame per KS input samples.
        frame_lens = lens // self.ks
        # (batch, time, channels) -> (batch, channels, time) for Conv1d.
        feats = self.dropout(self.preprocessing_conv(x.permute(0, 2, 1)))
        # (batch, channels, time) -> (time, batch, channels) for the GRU.
        packed_in = pack_padded_sequence(feats.permute(2, 0, 1), frame_lens.int().cpu(),
                                         enforce_sorted=False)
        packed_out, _ = self.BiGRU(packed_in)
        rnn_out, lens_unpacked = pad_packed_sequence(packed_out)
        # The word-count head only sees the final (padded) time step.
        word_count_scores = self.word_ct_layer(rnn_out[-1])
        return self.dense(rnn_out), word_count_scores, lens_unpacked


class FlexibleLiGRUClassifier(torch.nn.Module):
    """
    Conv1d -> LiGRU -> linear classifier emitting scores at every time step.

    Unlike the GRU/LSTM variants in this notebook, sequences are NOT packed:
    the whole padded batch goes through the LiGRU layer, and the downsampled
    lengths are simply returned next to the outputs.

    Args:
        rnn_dim: conv output channels and LiGRU hidden size.
        KS: kernel size and stride of the input convolution.
        num_layers: forwarded to ``LiGRU_Layer``.
        batch_size: forwarded to ``LiGRU_Layer``.
        dropout: dropout probability (after the conv and inside the LiGRU).
        n_targ: number of per-time-step output classes.
        bidirectional: forwarded to ``LiGRU_Layer``; doubles the dense input width.
        in_channels: number of input feature channels.
    """
    def __init__(self, rnn_dim, KS, num_layers, batch_size, dropout, n_targ, bidirectional, in_channels=506):
        super().__init__()

        # Plain bookkeeping attributes.
        self.num_layers = num_layers
        self.rnn_dim = rnn_dim
        self.ks = KS
        self.mult = 2 if bidirectional else 1

        # Sub-modules (registered in the same order as before).
        self.preprocessing_conv = nn.Conv1d(in_channels=in_channels, out_channels=rnn_dim,
                                            kernel_size=KS, stride=KS)
        self.BiGRU = LiGRU_Layer(input_size=rnn_dim, hidden_size=rnn_dim,
                                 num_layers=num_layers, batch_size=batch_size,
                                 bidirectional=bidirectional, dropout=dropout)
        self.dropout = nn.Dropout(dropout)
        self.dense = nn.Linear(rnn_dim*self.mult, n_targ)

    def forward(self, x, lens):
        """
        Args:
            x: input laid out as (batch, time, channels).
            lens: per-sample input lengths, in input samples.

        Returns:
            (per-step scores, downsampled lengths as a CPU int tensor).
        """
        # The conv emits one frame per KS input samples.
        frame_lens = lens // self.ks
        # (batch, time, channels) -> (batch, channels, time) for Conv1d.
        feats = self.dropout(self.preprocessing_conv(x.permute(0, 2, 1)))
        # NOTE(review): the features are handed over as (time, batch, channels);
        # speechbrain's LiGRU_Layer presumably expects batch-first input --
        # confirm the intended layout before relying on this model.
        rnn_out = self.BiGRU(feats.permute(2, 0, 1))
        return self.dense(rnn_out), frame_lens.int().cpu()


class FlexLSTM(torch.nn.Module):
    """
    Conv1d -> LSTM -> linear classifier emitting scores at every time step.

    Same layout as the GRU classifiers in this notebook, with ``nn.LSTM`` as
    the recurrent core (the attribute keeps the name ``BiGRU``).  Variable
    lengths are handled with ``pack_padded_sequence`` / ``pad_packed_sequence``.

    Args:
        rnn_dim: conv output channels and LSTM hidden size.
        KS: kernel size and stride of the input convolution.
        num_layers: number of stacked LSTM layers.
        dropout: dropout probability (after the conv and between LSTM layers).
        n_targ: number of per-time-step output classes.
        bidirectional: run the LSTM in both directions (doubles its output width).
        in_channels: number of input feature channels.
    """
    def __init__(self, rnn_dim, KS, num_layers, dropout, n_targ, bidirectional, in_channels=506):
        super().__init__()

        direction_count = 2 if bidirectional else 1

        self.preprocessing_conv = nn.Conv1d(in_channels=in_channels, out_channels=rnn_dim,
                                            kernel_size=KS, stride=KS)
        self.BiGRU = nn.LSTM(input_size=rnn_dim, hidden_size=rnn_dim, num_layers=num_layers,
                             bidirectional=bidirectional, dropout=dropout)
        self.dropout = nn.Dropout(dropout)
        self.dense = nn.Linear(rnn_dim*direction_count, n_targ)

        self.num_layers = num_layers
        self.rnn_dim = rnn_dim
        self.ks = KS
        self.mult = direction_count

    def forward(self, x, lens):
        """
        Args:
            x: input laid out as (batch, time, channels).
            lens: per-sample input lengths, in input samples.

        Returns:
            (per-step scores laid out (time, batch, n_targ),
             per-sample output lengths from ``pad_packed_sequence``).
        """
        # The conv emits one frame per KS input samples.
        downsampled_lens = (lens // self.ks).int().cpu()
        # Swap to (batch, channels, time) for Conv1d, then regularize.
        conv_out = self.preprocessing_conv(x.transpose(1, 2))
        conv_out = self.dropout(conv_out)
        # The LSTM wants (time, batch, channels).
        time_major = conv_out.permute(2, 0, 1)
        packed_in = pack_padded_sequence(time_major, downsampled_lens, enforce_sorted=False)
        packed_out, _state = self.BiGRU(packed_in)
        padded_out, lens_unpacked = pad_packed_sequence(packed_out)
        return self.dense(padded_out), lens_unpacked


class CRDNN(torch.nn.Module):
    """
    Two-conv / GRU / two-dense classifier emitting scores at every time step.

    Pipeline: Conv1d(stride1) -> dropout -> activation -> Conv1d(stride2)
    -> dropout -> activation -> GRU over packed sequences -> dense1
    -> activation -> dropout -> dense2.

    Args:
        rnn_dim: channel count of both convolutions and GRU hidden size.
        KS: kernel size of the first convolution.
        num_layers: number of stacked GRU layers.
        dropout: dropout probability (convs, between GRU layers, after dense1).
        n_targ: number of per-time-step output classes.
        bidirectional: run the GRU in both directions (doubles its output width).
        stride1: stride of the first convolution.
        stride2: stride of the second convolution.
        in_channels: number of input feature channels.
        activation: non-linearity module applied after each conv and dense1.
        KS2: kernel size of the second convolution; defaults to ``KS``.
    """
    def __init__(self, rnn_dim, KS, num_layers, dropout, n_targ, bidirectional, stride1 = 2,
                 stride2= 1, in_channels=506, activation=nn.LeakyReLU(), KS2 = None):
        super().__init__()

        self.preprocessing_conv = nn.Conv1d(in_channels=in_channels,
                                           out_channels=rnn_dim,
                                           kernel_size=KS,
                                           stride=stride1)

        if KS2 is None:
            KS2 = KS
        self.conv2 = nn.Conv1d(in_channels=rnn_dim, out_channels=rnn_dim, kernel_size=KS2,
                              stride=stride2)

        self.act = activation
        self.BiGRU = nn.GRU(input_size=rnn_dim, hidden_size=rnn_dim,
                           num_layers =num_layers,
                            bidirectional=bidirectional,
                            dropout=dropout)

        self.num_layers = num_layers
        self.rnn_dim = rnn_dim
        self.ks = KS
        # forward() needs the strides to rescale the sequence lengths; they
        # were previously never stored, so every forward pass failed.
        self.stride1 = stride1
        self.stride2 = stride2

        self.dropout = nn.Dropout(dropout)
        if bidirectional:
            mult = 2
        else:
            mult = 1
        self.mult = mult
        self.dense1 = nn.Linear(rnn_dim*mult, rnn_dim)
        self.dense2 = nn.Linear(rnn_dim, n_targ)

    def forward(self, x, lens):
        """
        Args:
            x: input laid out as (batch, time, channels).
            lens: per-sample input lengths (tensor), in input samples.

        Returns:
            (per-step scores laid out (time, batch, n_targ),
             per-sample output lengths from ``pad_packed_sequence``).
        """
        # Bs, C, T for conv
        x = x.contiguous().permute(0, 2, 1)
        x = self.preprocessing_conv(x)
        x = self.dropout(x)
        x = self.act(x)
        x = self.conv2(x)
        x = self.dropout(x)
        x = self.act(x)

        # Rescale the lengths by the total temporal stride.  The convs are
        # unpadded, so kernels wider than their stride make the real output
        # shorter than lens // stride; cap at the number of frames that exist
        # so packing never asks for more time steps than the tensor holds.
        lens = lens//(self.stride1 * self.stride2)
        lens = torch.clamp(lens, max=x.shape[-1])

        # reshape for RNN.  T, B, C
        x = x.contiguous().permute(2, 0, 1)
        packed = pack_padded_sequence(x, lens.int().cpu(), enforce_sorted=False)
        emissions, hiddens = self.BiGRU(packed)
        unpacked_emissions, lens_unpacked = pad_packed_sequence(emissions)
        unpacked_outputs_ = self.dropout(self.act(self.dense1(unpacked_emissions)))
        unpacked_outputs = self.dense2(unpacked_outputs_)
        return unpacked_outputs, lens_unpacked

class FlexibleCnnRnnClassifier(torch.nn.Module):
    """
    Conv1d -> GRU -> linear classifier emitting scores at every time step.

    A strided convolution (kernel size == stride == ``KS``) downsamples the
    input in time; the GRU then runs over packed sequences so padding is not
    processed, and ``dense`` maps every GRU output frame to ``n_targ`` scores.

    Args:
        rnn_dim: conv output channels and GRU hidden size.
        KS: kernel size and stride of the input convolution.
        num_layers: number of stacked GRU layers.
        dropout: dropout probability (after the conv and between GRU layers).
        n_targ: number of per-time-step output classes.
        bidirectional: run the GRU in both directions (doubles its output width).
        in_channels: number of input feature channels.
    """
    def __init__(self, rnn_dim, KS, num_layers, dropout, n_targ, bidirectional, in_channels=506):
        super().__init__()

        self.num_layers, self.rnn_dim, self.ks = num_layers, rnn_dim, KS
        self.mult = 2 if bidirectional else 1

        self.preprocessing_conv = nn.Conv1d(in_channels=in_channels, out_channels=rnn_dim,
                                            kernel_size=KS, stride=KS)
        self.BiGRU = nn.GRU(input_size=rnn_dim, hidden_size=rnn_dim, num_layers=num_layers,
                            bidirectional=bidirectional, dropout=dropout)
        self.dropout = nn.Dropout(dropout)
        self.dense = nn.Linear(rnn_dim*self.mult, n_targ)

    def forward(self, x, lens):
        """
        Args:
            x: input laid out as (batch, time, channels).
            lens: per-sample input lengths, in input samples.

        Returns:
            (per-step scores laid out (time, batch, n_targ),
             per-sample output lengths from ``pad_packed_sequence``).
        """
        # One conv frame covers KS input samples.
        n_frames = lens // self.ks

        # Conv1d consumes (batch, channels, time).
        hidden = x.permute(0, 2, 1)
        hidden = self.preprocessing_conv(hidden)
        hidden = self.dropout(hidden)

        # The GRU consumes (time, batch, channels); pack to skip the padding.
        hidden = hidden.permute(2, 0, 1)
        packed_in = pack_padded_sequence(hidden, n_frames.int().cpu(), enforce_sorted=False)
        packed_out, _final_state = self.BiGRU(packed_in)
        rnn_frames, lens_unpacked = pad_packed_sequence(packed_out)

        return self.dense(rnn_frames), lens_unpacked

In [None]:
def check_1024_done(emissions):
    """
    Decide whether the trailing predictions look like the end of an utterance.

    Takes the last 8 rows of ``emissions`` (rows index time steps), adds up the
    first two columns of each row (per the original note these are the blank
    and silence classes), and reports whether the average of those sums is
    above 7.1/8 (about 88.8%).

    Returns a boolean tensor.
    """
    tail = emissions[-8:, :2]
    mass_per_step = tail.sum(dim=-1)
    threshold = 7.1/8
    return torch.mean(mass_per_step) > threshold


def test_ensemble_wpm(model, test_loader, greedy, beam_search_decoder,
                  texts, tokens, wandb_name ='final_te_', device='cuda',
                  print_greedy=True, verbose=True, printall=False,
                  print_hypo=False,
                  realtime_eval=False, paradigm='1024'):
    """
    Evaluate a model one trial at a time and estimate the speaking rate (WPM).

    For every trial the model is re-run on growing windows of the input
    (1.9 s ... 7.5 s after the go cue).  With ``paradigm == '1024'`` the scan
    stops as soon as ``check_1024_done`` fires on the window's probabilities.
    The CTC loss is computed on the final window; if ``print_greedy`` is true
    the emissions are also decoded and WER/CER/PER and words-per-minute are
    collected.  Summary metrics are always sent to ``wandb.log``.

    Inputs:
    model - network called as ``model(x, lens)`` returning (emissions, lengths)
    test_loader - yields (x, y, l, targ_len, gtsent); batch size must be 1
    greedy - greedy ctc decoder
    beam_search_decoder - torchaudio ctc decoder
    texts - the ground truth texts, indexed by ``gtsent``; if None, ``gtsent``
        itself is used as the ground truth
    tokens - lookup from label index to phone symbol
    wandb_name - prefix of the metric names logged to wandb
    device - device the model and data are moved to
    print_greedy = True - then wer, cer, per (and wpm) are evaluated
    verbose = True - then metrics and a few example predictions are printed
    printall - print every example instead of stopping after about ten
    print_hypo - forwarded to ``greedy_beam_wer_cer``
    realtime_eval - additionally log the per-trial metric lists to wandb
    paradigm - '1024' enables the early-stop check

    Returns:
    (epoch_loss, model, wer, cer, per, wers, pers, cers, transcripts, gts,
     all_emish, speech_times, wpms); wer/cer/per are None when
    ``print_greedy`` is false.
    """
    loss_fn = nn.CTCLoss()
    all_emish = []
    wers = []
    pers = []
    cers = []
    speech_times = []
    wpms = []

    # Loop-invariant: switch to inference mode on the target device once.
    model.eval()
    model.to(device)

    with torch.no_grad():
        total_loss, total_samps = 0, 0
        total_wer = 0
        total_cer=0
        total_gcer = 0
        total_per = 0
        gts, transcripts = [], []
        for x, y, l, targ_len, gtsent in test_loader:
            assert x.shape[0] == 1 # We need to go one at a time for the WPM calculation.
            x = x.float().to(device)
            y = y.long().to(device)
            if not texts is None:
                gtsent = gtsent.long().cpu().numpy()
            targ_len = targ_len.int().to(device)

            for t_elapsed in [1.9, 2.7, 3.5, 4.3, 5.1, 5.9, 6.7, 7.5]:
                # Time elapsed since go cue
                # Since we start 500ms prior need to add that back in.
                sample_ct = (int(((t_elapsed + .5)*200)/6))
                # The loader's input length is replaced by the window length.
                l = torch.tensor([sample_ct]).int().to(device)
                logits, lengths = model(x[:, :sample_ct], l)
                probs = F.softmax(logits, dim=-1)
                if paradigm == '1024' and check_1024_done(probs.squeeze()):
                    break

            if t_elapsed < 7.5 and paradigm == '1024':
                # Early stop happened, so lets account for when participant stopped talking as detailed in methods.
                silent_time = (8*4*6)/200 # 8 samples, theres a 4x conv downsample in the model, plus a 6x decimation, divide by 200Hz
                speaking_time = t_elapsed-silent_time
            else:
                # An early stop failed to occur, so we just use the time elapsed.
                # NOTE(review): a stop detected on the final 7.5 s window also
                # lands here (unchanged from before) -- confirm that is intended.
                speaking_time = t_elapsed
            speech_times.append(speaking_time)
            all_emish.append(probs.detach().cpu().numpy())
            # log_softmax instead of log(softmax): avoids -inf when a
            # probability underflows, and matches train_f / test_f.
            emissions = F.log_softmax(logits, dim=-1)
            loss = loss_fn(emissions, y, lengths, targ_len)
            total_loss += loss.item()
            total_samps += x.shape[0]


            ### The code to look at the text.  We dont want to do this every trial since
            # it can be expensive.
            if print_greedy:
                for k , e in enumerate(emissions.permute(1,0, 2)):
                    # Drop padding (-1) and blank (0) labels.
                    gt_phones = [tokens[yy] for yy in y.detach().cpu().numpy()[k] if not yy == -1 and not yy == 0]
                    # Get the ground truth text.
                    if not texts is None:
                        gt  = texts[int(gtsent[k])]
                    else:
                        gt = gtsent[k]
                    greedy_cer, gt_, transcript,cer,  wer, per = greedy_beam_wer_cer(e, gt,
                                                                                     gt_phones,
                                                                                     greedy,
                                                                                     beam_search_decoder,
                                                                                    print_hypo)
                    gts.append(gt_)
                    transcripts.append(transcript)
                    wpms.append(60*len(transcript.strip().split(' '))/speaking_time)
                    wers.append(wer)
                    pers.append(per)
                    cers.append(cer)
                    total_wer += wer
                    total_cer += cer
                    total_per += per
                    total_gcer += greedy_cer


        if print_greedy and verbose:
            print('net wer', total_wer/total_samps)
            print('net cer', total_cer/total_samps)
            print('greedy per', total_gcer/total_samps)
            print('beam per', total_per/total_samps)

        wandb.log({
            wandb_name + 'wer': total_wer/total_samps,
            wandb_name + 'med_wer':  np.median(wers),
            wandb_name + 'per': total_per/total_samps,
            wandb_name + 'med_per': np.median(pers),
            wandb_name + 'cer' : total_cer/total_samps,
            wandb_name + 'med_cer' : np.median(cers),

        })

        if realtime_eval:

            wandb.log({
                'allrt_wer':wers,
                'allrt_cer':cers,
                'allrt_per':pers,
                'gts':gts,
                'trans':transcripts
            })

        # Print a random handful of (ground truth, transcript) pairs.
        ctr =0
        randomsent = np.arange(len(gts))
        np.random.shuffle(randomsent)
        for k in randomsent:
            gt = gts[k]
            trans = transcripts[k]
            wer = (torchaudio.functional.edit_distance(gt.split(' '), trans.split(' '))/len(gt.split(' ')))
            cer = (torchaudio.functional.edit_distance(gt, trans)/len(gt))
            if verbose:
                print('gt', gt)
                print('t', trans, 'wer: %.2f' %wer, 'cer: %.2f' %cer)
            ctr+=1
            if ctr > 10 and not printall:
                break

        epoch_loss = total_loss/total_samps
        # Currently I dont eval wer/cer every time because it is a bit computationally expesnsive.
        if print_greedy:
            return epoch_loss, model, total_wer/total_samps, total_cer/total_samps, total_per/total_samps, wers, pers, cers,  transcripts, gts, all_emish, speech_times, wpms
        else:
            return epoch_loss, model, None, None, None, wers, pers, cers, transcripts, gts, all_emish, speech_times, wpms

In [None]:
from torch.nn import CTCLoss
import torch
import torch.nn.functional as F
import torch.nn as nn
import copy
import numpy as np
import torchaudio

# import wandb

def train_f(model, train_loader, optimizer, device):
    """
    Run one training epoch of a CTC model.

    Every batch from ``train_loader`` is a 5-tuple (inputs, targets, input
    lengths, target lengths, <ignored>).  Inputs and targets are moved to
    ``device``; the lengths stay on the CPU.  Gradients are clipped to a
    total norm of 1 before each optimizer step.

    Returns:
        (epoch_loss, model), where epoch_loss is the sum of the per-batch CTC
        losses divided by the number of samples seen.
    """
    model.train()
    ctc_criterion = nn.CTCLoss()
    running_loss, n_seen = 0, 0

    for feats, targets, in_lens, tgt_lens, _unused in train_loader:
        feats = feats.float().to(device)
        targets = targets.long().to(device)
        in_lens = in_lens.int().cpu()
        tgt_lens = tgt_lens.int().cpu()

        optimizer.zero_grad()
        scores, out_lens = model(feats, in_lens)
        log_probs = F.log_softmax(scores, dim=-1)
        batch_loss = ctc_criterion(log_probs, targets, out_lens, tgt_lens)

        batch_loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
        optimizer.step()

        running_loss += batch_loss.item()
        n_seen += feats.shape[0]

    return running_loss/n_seen, model

def test_f(model, test_loader, optimizer, device, greedy, beam_search_decoder, texts, tokens, print_greedy=False, verbose=True, printall=False,
          print_hypo=False, text_direct=False):
    """

    Tests model loss,
        if the print_greedy flag is true, then it calculates wer/cer/per

    returns updated model, loss, and metrics, if they were evaluated,
        otherwsie none.

    Inputs:
    Greedy =greedy ctc deocder
    beam_search_decoder = torchaudio ctc decoder
    print_greedy = True  -
        Then wer, cer, per are evaluated
    verbose = True -
        Then we print out predictions to get a flavor of what was going on :D
    texts - the ground truth text. May not need for audio,m then put in None.

    """
    model.eval()
    loss_fn = nn.CTCLoss()
    with torch.no_grad():
        total_loss, total_samps = 0, 0
        total_wer = 0
        total_cer=0
        total_gcer = 0
        total_per = 0
        gts, transcripts = [], []
        for x, y, l, targ_len, gtsent in test_loader:
            x = x.float().to(device)
            y = y.long().to(device)
            l = l.int().to(device)
            if not texts is None:
                gtsent = gtsent.long().cpu().numpy()
            else:
                gtsent = gtsent
            targ_len = targ_len.int().to(device)

            emissions, lengths = model(x, l)

            emissions = F.log_softmax(emissions, dim=-1)
            loss = loss_fn(emissions, y, lengths, targ_len)
            total_loss += loss.item()
            total_samps += x.shape[0]


            ### The code to look at the text.  We dont want to do this every trial since
            # it can be expensive.
            if print_greedy:
                 for k , e in enumerate(emissions.permute(1,0, 2)):
                    gt_phones = [tokens[yy] for yy in y.detach().cpu().numpy()[k] if not yy == -1 and not yy == 0]
                    if not texts is None:
                        gt  = texts[int(gtsent[k])]
                    else:
                        gt = gtsent[k]
                        # Get the ground truth text.
                    greedy_cer, gt_, transcript,cer,  wer, per = greedy_beam_wer_cer(e, gt,
                                                                                     gt_phones,
                                                                                     greedy,
                                                                                     beam_search_decoder,
                                                                                    print_hypo)
                    gts.append(gt_)
                    transcripts.append(transcript)
                    total_wer += wer
                    total_cer += cer
                    total_per += per
                    total_gcer += greedy_cer



        if print_greedy and verbose:
            print('net wer', total_wer/total_samps)
            print('net cer', total_cer/total_samps)
            print('greedy per', total_gcer/total_samps)
            print('beam per', total_per/total_samps)

            net_wer = total_wer/total_samps
            net_cer = total_cer/total_samps
            greedy_per = total_gcer/total_samps
            beam_per = total_per/total_samps

        ctr =0
        randomsent = np.arange(len(gts))
        np.random.shuffle(randomsent)
        for k in randomsent:
            gt = gts[k]
            trans = transcripts[k]
            wer = (torchaudio.functional.edit_distance(gt.split(' '), trans.split(' '))/len(gt.split(' ')))
            cer = (torchaudio.functional.edit_distance(gt, trans)/len(gt))
            if verbose:
                print('gt', gt)
                print('t', trans, 'wer: %.2f' %wer, 'cer: %.2f' %cer)
            ctr+=1
            if ctr > 10 and not printall:
                break
        epoch_loss = total_loss/total_samps
        # Currently I dont eval wer/cer every time because it is a bit computationally expesnsive.
        if print_greedy:
            return epoch_loss, model, total_wer/total_samps, total_cer/total_samps, total_per/total_samps
        else:
            return epoch_loss, model, total_wer/total_samps, total_cer/total_samps, total_per/total_samps


def train_loop(model, train_loader, test_loader, optimizer, device,
               texts, greedy, beam_search_decoder, tokens,
               patience=10, start_eval=50,
              wandb_log=False, max_epochs=100 ,wercalcrate=3,
              checkpoint_dir=None, train=True, printall=False,
              print_hypo=False, text_direct=False):
    """
    Train (and evaluate) the CTC model

    Non self-explanatory inputs:
    Device = cuda or cpu
    texts = list of the ground truth text
    greedy= greedyCTCDecoder
    beamsearchdecoder = torch ctc deocder
    patience = how long to wait for training to improve. We will do patience based on wer could be made adjustable
    start_eval - what epoch do we want to start evlauting WER/PER.
    wandb_log - log results on wandb.
    wercalcrate  - how often to calculate the wer
    max_epochs - upper bound on the number of epochs
    checkpoint_dir - if given, the state dict of every new best model with
        wer < .85 is saved there, named after the active wandb run
        ('model' when no run is active)
    train - if False, skip training and only evaluate each epoch

    Returns:
    (best_model, wer, cer, per): a deep copy of the model with the lowest WER
    (None if WER was never evaluated) and the metrics of the most recent
    epoch in which they were evaluated (None if never).
    """

    best_wer = np.inf
    best_model = None
    patience_ctr = 0
    # Most recent evaluated metrics; stay None until an evaluation happens.
    wer = cer = per = None

    for epoch in range(max_epochs):
        if train:
            tr_loss, model = train_f(model, train_loader, optimizer, device)
        else:
            tr_loss = np.nan
        te_loss, model, epoch_wer, epoch_cer, epoch_per = test_f(model, test_loader, optimizer,
                                               device, greedy,  beam_search_decoder, texts, tokens,
                                               print_greedy=(epoch%wercalcrate==0 and epoch >= start_eval), printall=printall,
                                               print_hypo=print_hypo, text_direct=text_direct)
        print('epoch', epoch, 'tr loss: %.3f' %tr_loss, 'te_loss: %.3f' %te_loss, flush=True)
        if epoch_wer is not None:
            wer, cer, per = epoch_wer, epoch_cer, epoch_per
            if wandb_log:
                import wandb
                wandb.log({
                    'tr_loss':tr_loss,
                    'te_loss':te_loss,
                    'wer':wer,
                    'cer':cer,
                    'per':per,
                    'patience_ctr':patience_ctr,
                    'best_wer':min(1, best_wer)
                })
            if wer < best_wer:
                best_wer = wer
                patience_ctr = 0
                best_model = copy.deepcopy(model)
                if checkpoint_dir is not None and wer < .85:
                    import os
                    # Import here as well: an import inside a function makes
                    # the name local, so relying on the one in the wandb_log
                    # branch fails whenever wandb_log is False.
                    try:
                        import wandb
                        active_run = wandb.run
                    except ImportError:
                        active_run = None
                    run_name = str(active_run.name) if active_run is not None else 'model'
                    torch.save(model.state_dict(), os.path.join(checkpoint_dir, run_name + ".pth"))
            else:
                patience_ctr +=1

        if patience_ctr > patience:
            break

    return best_model, wer, cer, per

#### I am Loading and Preprocessing the Data by Modifying Existing Code to Cater Toward My Specific Data (My Work)

In [None]:
# Load a GPU and modify this to be your directories!!!!
!nvidia-smi
!nvcc --version
!which nvcc
curdir = './' # TODO: Change to your current directory
data_dir = './data' # TODO: Change to where the data is stored.
device = 'cuda' # Set to cpu if you dont have a gpu avaiable!
print("Is CUDA available:", torch.cuda.is_available())



Thu Dec  7 04:21:03 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.105.17   Driver Version: 525.105.17   CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA A100-SXM...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   33C    P0    42W / 400W |      0MiB / 40960MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

#### For 50-phrase AAC Set

In [None]:
import numpy as np
import pandas as pd
import argparse
from os.path import join
import torchaudio
import torch
from torchaudio.models import decoder
from torchaudio.models.decoder import download_pretrained_files
print('torch version', torch.__version__)
print('torch audio version', torchaudio.__version__, '>= 0.12.0 needed')
curdir = './'
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
from torch.utils.data import DataLoader, TensorDataset
import copy
import wandb

# Change to your data dir.
data_dir = './data/'
# Use the GPU when one is visible, otherwise fall back to the CPU so the
# notebook still runs on CPU-only runtimes.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Set up the experiment, can change the hyperparameters as you see fit or edit for your own models.
# Every option has a default, so an empty argument list yields a complete config.
parser = argparse.ArgumentParser()

# --- Data / preprocessing ---
parser.add_argument('--decimation', default=6, type=int,
                    help='How much to downsample neural data')
parser.add_argument('--hidden_dim', type=int, default=512,
                    help="how many hid.units in model")
parser.add_argument('--lr', type=float, default=1e-3,
                    help='learning rate')
parser.add_argument('--ks', type=int, default=2,
                    help='ks of input conv')
parser.add_argument('--num_layers', type=int, default=3,
                    help='number of layers')
parser.add_argument('--dropout', type=float, default=0.6,
                    help='dropout amount')
parser.add_argument('--feat_stream', type=str, default='both',
                    help='which stream. both, hga, or raw')
parser.add_argument('--bs', type=int, default=64,
                    help='batch size')
parser.add_argument('--smooth', type=int, default=0,
                    help='how much smoothing to apply.')
# store_false: args['no_normalize'] is True by default and becomes False
# when the flag is passed.
parser.add_argument('--no_normalize', action='store_false',
                    help='Normalize the neural data or not')

# --- Beam search / language model ---
parser.add_argument('--LM_WEIGHT', type=float, default=3.23,
                    help='how much the LM is weighted during beam search')
parser.add_argument('--WORD_SCORE', type=float, default=-.26,
                    help='word insertion score for beam')
parser.add_argument('--beam_width', type=int, default=100,
                    help='beam size to use')

# --- Checkpointing / model variants ---
parser.add_argument('--checkpoint_dir', type=str, default=None,
                    help='where 2 save model')
parser.add_argument('--feedforward', action='store_true',
                    help='no bidirectional')
parser.add_argument('--pretrained', type=str, default=None,
                    help='path to a pretrained model to load')
parser.add_argument('--train_amt', type=float, default=1.0,
                    help='amt of train data to use')
parser.add_argument('--samples_to_trim', default=0, type=int,
                    help='num samps back to go (to shorten window)')
parser.add_argument('--ndense', default=40, type=int,
                    help='Use a different number of classes for a transfer model (useful for 50 phrase transfer.)')
# (A stray trailing comma after this call used to build a throwaway tuple.)
parser.add_argument('--normalization_strategy', type=str, default='typical',
                    help='how to normalize the data')
parser.add_argument('--transfer_audio', action='store_true',
                    help='true if transfer audio. then switch conv')
parser.add_argument('--num_50', type=int,
                    help='only used for 500 phrases')

torch version 2.1.0+cu118
torch audio version 2.1.0+cu118 >= 0.12.0 needed


_StoreAction(option_strings=['--num_50'], dest='num_50', nargs=None, const=None, default=None, type=<class 'int'>, choices=None, required=False, help='only used for 500 phrases', metavar=None)

#### Put In Arguments and Initialize WandB. **NOTE:** You must create a personal WandB account to continue. It will ask for your personal authorization token.

In [None]:
# Hyperparameters for this run, written as a command-line string so the same
# parser works both here and in a standalone script.
exp_str = '--hidden_dim 512 --ks 2 --dropout 0.6 --num_layers 3 --num_50 500 --samples_to_trim 0'
# Parse the string into a plain dict and register it as the wandb run config.
args = vars(parser.parse_args(exp_str.split()))
wandb.init(project='pub_code',
          config=args)

VBox(children=(Label(value='0.002 MB of 0.002 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
best_wer,█▇▆▆▆▅▅▅▅▅▅█▇▁▁▁▁▁▁▁▁▁▁▁█▇▁▁▁▁▁▁▁▁▁▁▁█
cer,█▆▆▆▅▅▅▅▅▅▅█▁▁▆▁▁▆▁▁▆▁▁▅█▁▁▇▁▁▆▁▁▆▁▁▅▇
num_samples,▁▁▁
patience_ctr,▁▁▁▁▁▁▁▂▂▁▁▁▁▁▂▂▃▄▅▅▆▇▇█▁▁▁▂▂▃▄▅▅▆▇▇█▁
per,█▇▆▆▆▅▆▅▅▅▅█▁▁▆▁▁▆▁▁▆▁▁▅█▁▁▇▁▁▆▁▁▆▁▁▆▇
te_loss,█▃▃▂▁▁▁▁▁▁▁▇▅▄▃▃▂▂▂▂▂▂▁▁█▅▃▃▃▂▂▂▂▂▂▂▂▃
tr_loss,█▃▂▂▂▁▁▁▁▁▁█▄▃▃▃▂▂▂▂▂▂▂▂█▄▃▃▃▂▂▂▂▂▂▂▂
wer,█▇▆▆▆▅▆▆▅▅▅█▁▁▇▁▁▆▁▁▆▁▁▆█▁▁▇▁▁▆▁▁▆▁▁▆▇

0,1
best_wer,1.0
cer,0.74736
num_samples,8600.0
patience_ctr,0.0
per,0.61646
te_loss,0.02611
tr_loss,
wer,0.84577


In [None]:
# Take the transcriptions prepared earlier in the notebook as the labels.
# The original pipeline read them from HDF5 dataframes (columns ph_label /
# txt_label); that path is kept below, commented out, for reference.
# labels = pd.read_hdf(join(data_dir, 'training_labels.h5'))
# labels_te = pd.read_hdf(join(data_dir, "training_labels_test.h5"))

# presumably lists of sentence strings -- TODO confirm against the
# preprocessing cell that defines them.
labels = train_transcriptions
labels_te = test_transcriptions

# Disabled sanity check from the dataframe-based pipeline:
# # Check no test sents in training data :D
# for l in labels_te['txt_label'].values:
#     assert not l in labels['txt_label'].values

# labels = pd.concat((labels, labels_te), ignore_index=True)
# Training labels first, test labels appended after them.
all_labels = labels + labels_te

In [None]:
# Next lets load in the neural data: training trials first, test trials
# appended after them along the first (trial) axis.
X = padded_train_features
X_te = padded_test_features

X = np.concatenate((X, X_te), axis=0)
print('train samples', X.shape[0] - X_te.shape[0])
print('test samples', X_te.shape[0])

# Collect the vocabulary: every distinct space-separated token in the labels.
all_words = set()
for label in all_labels:
    all_words.update(label.split(' '))

# Splitting on single spaces yields '' for leading/trailing/double spaces.
# discard() (unlike remove()) does not fail when no such empty token exists.
all_words.discard('')
# Print all unique words
print(all_words)
print(f'Number of total words {len(all_words)}')

train samples 8800
test samples 880
Number of total words 6875


In [None]:
# Normalize the neural data with the strategy chosen in the arguments.
# The normalization helpers are defined elsewhere in the notebook; per the
# original note, `normalize` scales so the 2-norm across time equals 1.
norm_strategy = args['normalization_strategy']
if norm_strategy == 'typical':
    # Normalize the two halves of the feature axis separately, in place.
    n_half_feats = X.shape[-1]//2
    X[:, :, :n_half_feats] = normalize(X[:, :, :n_half_feats])
    X[:, :, n_half_feats:] = normalize(X[:, :, n_half_feats:])
elif norm_strategy == 'none':
    print('no normalization.')
else:
    # Strategies that replace X wholesale.  The lambdas defer the helper
    # lookups, so only the selected helper has to exist.  Unknown strategy
    # names leave X untouched.
    whole_array_strategies = {
        'norm_all_at_once': lambda arr: normalize(arr),
        'norm_times': lambda arr: normalize(arr, axis=0),
        'zero_to_one': lambda arr: minmax_scaling(arr),
        'rezscore': lambda arr: rezscore(arr),
        'pertrial_minmax': lambda arr: pertrial_minmax(arr),
    }
    if norm_strategy in whole_array_strategies:
        X = whole_array_strategies[norm_strategy](X)


# Optionally keep a single feature stream to see the effect of hga vs hga + raw:
# 'hga' keeps the first half of the feature axis, 'raw' the second half.
selected_stream = args['feat_stream']
if selected_stream in ('hga', 'raw'):
    stream_split = X.shape[-1]//2
    X = X[:, :, :stream_split] if selected_stream == 'hga' else X[:, :, stream_split:]
print('final X shape', X.shape)

final X shape (9680, 919, 256)


### Now we need to prepare the labels for CTC decoding and decoder training

In [None]:
import sys
# Add the current working directory to the module search path.
sys.path.append('./')

In [None]:
# Show the current module search path.
print(sys.path)

['/content', '/env/python', '/usr/lib/python310.zip', '/usr/lib/python3.10', '/usr/lib/python3.10/lib-dynload', '', '/usr/local/lib/python3.10/dist-packages', '/usr/lib/python3/dist-packages', '/usr/local/lib/python3.10/dist-packages/IPython/extensions', '/root/.ipython', './']


In [None]:
# we will rely on the torchaudio ctc_decoder module, which you can read more about here: https://pytorch.org/audio/main/tutorials/asr_inference_with_ctc_decoder_tutorial.html
from torchaudio.models.decoder import ctc_decoder
from torchaudio.functional import edit_distance

In [None]:
# Pool the train and test phoneme sequences, then flatten them into one list of
# phoneme symbols, leaving out the '|' and 'SIL' markers.
phonemes = padded_train_phonemes + padded_test_phonemes
all_ph = [
    symbol
    for sequence in phonemes
    for symbol in sequence
    if symbol != '|' and symbol != 'SIL'
]

print(all_ph)

['T', 'R', 'AY', 'N', 'AA', 'T', 'T', 'UW', 'Y', 'UW', 'Z', 'EH', 'N', 'IY', 'IH', 'N', 'S', 'EH', 'K', 'T', 'AH', 'S', 'AY', 'D', 'Z', 'AE', 'T', 'AO', 'L', 'JH', 'OY', 'N', 'DH', 'AH', 'F', 'EY', 'S', 'B', 'UH', 'K', 'F', 'AE', 'N', 'P', 'EY', 'JH', 'DH', 'EY', 'JH', 'AH', 'S', 'T', 'M', 'UW', 'V', 'D', 'IH', 'N', 'T', 'UW', 'DH', 'AH', 'N', 'UW', 'B', 'IH', 'L', 'D', 'IH', 'NG', 'AH', 'B', 'IH', 'G', 'M', 'EH', 'T', 'R', 'AH', 'P', 'AA', 'L', 'AH', 'T', 'AH', 'N', 'EH', 'R', 'IY', 'AH', 'W', 'AH', 'T', 'IH', 'Z', 'DH', 'AE', 'T', 'R', 'EH', 'S', 'T', 'ER', 'AA', 'N', 'T', 'Y', 'UW', 'M', 'AY', 'T', 'W', 'AA', 'N', 'T', 'T', 'UW', 'AH', 'T', 'AE', 'K', 'AY', 'W', 'AA', 'Z', 'R', 'IH', 'L', 'IY', 'SH', 'AA', 'K', 'T', 'N', 'AY', 'S', 'HH', 'AW', 'S', 'AH', 'N', 'D', 'AO', 'L', 'DH', 'AE', 'T', 'IH', 'T', 'S', 'R', 'IH', 'L', 'IY', 'N', 'AY', 'S', 'T', 'UW', 'G', 'OW', 'AH', 'N', 'D', 'S', 'IY', 'DH', 'EH', 'M', 'DH', 'EY', 'HH', 'AE', 'V', 'AO', 'L', 'R', 'EH', 'D', 'IY', 'IH', 'K', '

In [None]:
# Give every distinct phoneme an integer id, assigned in sorted (alphabetical) order.
# '|' is filtered out here as well, so it never receives a phoneme id.
unique_phones = sorted(ph for ph in set(all_ph) if ph != '|')
phone_enc = dict(zip(unique_phones, range(len(unique_phones))))
print(phone_enc)

{'AA': 0, 'AE': 1, 'AH': 2, 'AO': 3, 'AW': 4, 'AY': 5, 'B': 6, 'CH': 7, 'D': 8, 'DH': 9, 'EH': 10, 'ER': 11, 'EY': 12, 'F': 13, 'G': 14, 'HH': 15, 'IH': 16, 'IY': 17, 'JH': 18, 'K': 19, 'L': 20, 'M': 21, 'N': 22, 'NG': 23, 'OW': 24, 'OY': 25, 'P': 26, 'R': 27, 'S': 28, 'SH': 29, 'T': 30, 'TH': 31, 'UH': 32, 'UW': 33, 'V': 34, 'W': 35, 'Y': 36, 'Z': 37, 'ZH': 38}


In [None]:
# Build the lexicon: a dict mapping each word to its pronunciation, written as
# space-separated phonemes followed by the '|' token. If a word occurs more than
# once, the pronunciation seen last wins.
from collections import defaultdict
lex = {}

for k, v in zip(all_labels, phonemes):
    if '|' not in v and 'SIL' not in v:
        # No '|' markers and no 'SIL' padding: store the whole label as one entry.
        lex[k] = (' '.join(v) + ' |')
    else:
        # Join with '_' and split on '|' so that each chunk holds the phonemes
        # of one word (chunks after the last word hold only 'SIL' padding).
        word_chunks = '_'.join(v).split('|')
        # split() with no argument ignores leading/trailing whitespace and
        # collapses runs of spaces of any length. Splitting on a single ' '
        # would yield empty "words" that still consume a phoneme chunk in the
        # zip below, shifting every later word onto the wrong pronunciation.
        for kk, vv in zip(k.split(), word_chunks):
            lex[kk] = vv.replace('_', ' ').strip() + ' |'
print(lex)



In [None]:
# Sanity check on one example: print the index of every sentence containing
# 'congregation', then show the label and phoneme sequence of the last match
# (index 0 is used when nothing matches).
hits = [pos for pos, text in enumerate(all_labels) if 'congregation' in text]
for hit in hits:
    print(hit)
index = hits[-1] if hits else 0
print(all_labels[index])
print(phonemes[index])

8189
why lacerate the  congregation
['W', 'AY', '|', 'L', 'AE', 'S', 'ER', 'EY', 'T', '|', 'DH', 'AH', '|', 'K', 'AA', 'NG', 'G', 'R', 'AH', 'G', 'EY', 'SH', 'AH', 'N', '|', 'SIL', 'SIL', 'SIL', 'SIL', 'SIL', 'SIL', 'SIL', 'SIL', 'SIL', 'SIL', 'SIL', 'SIL', 'SIL', 'SIL', 'SIL', 'SIL', 'SIL', 'SIL', 'SIL', 'SIL', 'SIL', 'SIL', 'SIL', 'SIL', 'SIL', 'SIL', 'SIL', 'SIL', 'SIL', 'SIL', 'SIL', 'SIL', 'SIL', 'SIL', 'SIL', 'SIL', 'SIL', 'SIL', 'SIL', 'SIL', 'SIL', 'SIL', 'SIL', 'SIL', 'SIL', 'SIL', 'SIL', 'SIL', 'SIL', 'SIL', 'SIL', 'SIL', 'SIL', 'SIL', 'SIL', 'SIL', 'SIL', 'SIL', 'SIL', 'SIL', 'SIL', 'SIL', 'SIL', 'SIL', 'SIL', 'SIL', 'SIL', 'SIL', 'SIL', 'SIL', 'SIL', 'SIL', 'SIL', 'SIL', 'SIL', 'SIL', 'SIL', 'SIL', 'SIL', 'SIL', 'SIL', 'SIL', 'SIL', 'SIL', 'SIL', 'SIL', 'SIL', 'SIL', 'SIL', 'SIL', 'SIL', 'SIL', 'SIL', 'SIL', 'SIL', 'SIL', 'SIL', 'SIL', 'SIL', 'SIL', 'SIL', 'SIL', 'SIL', 'SIL', 'SIL', 'SIL', 'SIL', 'SIL', 'SIL', 'SIL', 'SIL', 'SIL', 'SIL', 'SIL', 'SIL', 'SIL', 'SIL', 'SIL', 

In [None]:
# Write the lexicon to a text file for the ctc decoder: one "<word> <phonemes> |"
# line per lexicon entry.
strings = [k + ' ' + str(v) for k, v in lex.items()]

curdir = '/content/'
# Keep only lines longer than 3 characters.
# NOTE(review): presumably this drops degenerate entries — confirm the intent of the threshold.
strings = [s for s in strings if len(s) > 3]

# exist_ok=True makes the cell safe to re-run and avoids the check-then-create race.
output_dir = os.path.join(curdir, "for_ctc")
os.makedirs(output_dir, exist_ok=True)

# Plain write mode is enough: the file is only written here, never read back.
with open(os.path.join(output_dir, "lexicon_phrases_1k.txt"), "w") as f:
    f.writelines(s + '\n' for s in strings)

print('example lexicon items')
for s in strings[:5]:
    print(s)
print('vocabulary size:', len(strings))

example lexicon items
try T R AY |
not N AA T |
to T UW |
use Y UW Z |
any EH N IY |
vocabulary size: 6875


In [None]:
# Token inventory for the ctc decoder: '-' and '|' come first, followed by every
# phoneme in phone_enc order. The list is written one token per line.
# NOTE(review): '-' is presumably the CTC blank symbol — confirm against the decoder settings.
tokens = ['-', '|', *phone_enc]
tokens_path = join(curdir, 'for_ctc/tokens_phrases_1k.txt')
with open(tokens_path, 'w+') as f:
    for tok in tokens:
        f.write(tok + '\n')

# Token -> integer id (its position in `tokens`); used to turn phonemes into y labels.
enc_final = {tok: pos for pos, tok in enumerate(tokens)}
print(tokens)
print(enc_final)

['-', '|', 'AA', 'AE', 'AH', 'AO', 'AW', 'AY', 'B', 'CH', 'D', 'DH', 'EH', 'ER', 'EY', 'F', 'G', 'HH', 'IH', 'IY', 'JH', 'K', 'L', 'M', 'N', 'NG', 'OW', 'OY', 'P', 'R', 'S', 'SH', 'T', 'TH', 'UH', 'UW', 'V', 'W', 'Y', 'Z', 'ZH']
{'-': 0, '|': 1, 'AA': 2, 'AE': 3, 'AH': 4, 'AO': 5, 'AW': 6, 'AY': 7, 'B': 8, 'CH': 9, 'D': 10, 'DH': 11, 'EH': 12, 'ER': 13, 'EY': 14, 'F': 15, 'G': 16, 'HH': 17, 'IH': 18, 'IY': 19, 'JH': 20, 'K': 21, 'L': 22, 'M': 23, 'N': 24, 'NG': 25, 'OW': 26, 'OY': 27, 'P': 28, 'R': 29, 'S': 30, 'SH': 31, 'T': 32, 'TH': 33, 'UH': 34, 'UW': 35, 'V': 36, 'W': 37, 'Y': 38, 'Z': 39, 'ZH': 40}


In [None]:
# Beam-search CTC decoder built from the lexicon and token files written above
# plus the 3-gram language-model binary; beam width, LM weight and word score
# come from args.
decoder_paths = {
    'lexicon': os.path.join(curdir, 'for_ctc/lexicon_phrases_1k.txt'),
    'tokens': os.path.join(curdir, 'for_ctc/tokens_phrases_1k.txt'),
    'lm': os.path.join(curdir, 'multimodal-decoding/text/custom_lms/full_corpus_lm_3_abs_slm.binary'),
}
beam_search_decoder = ctc_decoder(
    **decoder_paths,
    nbest=3,
    beam_size=args['beam_width'],
    lm_weight=args['LM_WEIGHT'],
    word_score=args['WORD_SCORE'],
    sil_token='|',
    unk_word='<unk>',
)

# Two greedy decoders over the same token list (enc_final's keys are `tokens`).
greedy_decoder = GreedyCTCDecoder(tokens)
greedy = GreedyCTCDecoder(labels=list(enc_final))

# Prepare the CTC targets: each phoneme sequence is mapped to token ids with the
# 'SIL' padding removed, and wrapped in a leading and a trailing '|'.
# NOTE(review): the saved phoneme output shows sequences that already end in '|'
# before the SIL padding, so targets may end in two consecutive '|' — confirm
# this is intended.
y_final = [
    [enc_final['|'], *(enc_final[ph] for ph in targ if ph != "SIL"), enc_final['|']]
    for _, targ in zip(all_labels, phonemes)
]

# Right-pad every target with -1 up to the longest target, and record the true lengths.
targ_lengths = np.array([len(y) for y in y_final])
y_final_ = np.full((len(y_final), np.max(targ_lengths)), -1.0)
for row, y in enumerate(y_final):
    y_final_[row, :len(y)] = np.array(y)
Y = y_final_
outlens = targ_lengths

# Input lengths: start from all_lengths, subtract the trimmed samples, and cap
# each length at the time dimension of X.
# (Disabled alternative: lens = [(l//args['decimation']) for l in all_lengths],
#  which adjusts the lengths for decimation.)
lens = np.array(all_lengths) - args['samples_to_trim']
max_frames = X.shape[1]
lens = np.array([min(l, max_frames) for l in lens])
print(lens)

[419 279 322 ... 319 249 235]


In [None]:
# Scratch check: number of samples in X minus 249.
# NOTE(review): 249 is an unexplained magic number — confirm what it refers to, or remove this cell.
X.shape[0]-249

9431

In [None]:
# Show the shape of the combined feature array.
print(X.shape)

(9680, 919, 256)


# What to expect from training:

The model will run for around 100 epochs. Performance varies somewhat depending on initialization. We interrupted training here for the sake of time.

We are using a 3-gram LM and a small beam search here; as a result, accuracy during training is LOWER than when you test on the final set.

The WER should end around 35-42% prior to testing on the realtime blocks

Using the 5-gram LM and larger beam search will help get the WER much lower when we evaluate on the realtime blocks, below 30% without the blocking.

Note: The saved output here was interrupted for the sake of time

In [None]:
# Set up the data splits. The last `te_len` sample indices are held out as the
# test set (te_len is the number of test trials; the test trials were appended
# to the end of X). Of the remaining shuffled indices, the final 200 form the
# validation set and the rest are used for training.
print(X.shape, Y.shape)
te_len = X_te.shape[0]
gt_text = all_labels
trainsets = []
inds = np.arange(len(X))
np.random.seed(1337)
np.random.shuffle(inds)

# The split does not depend on the fold index, so sort once and use set
# lookups instead of repeated list scans. Ordering of every list is unchanged.
sorted_inds = sorted(inds)
# Slice from an explicit start: a plain [-te_len:] would select every sample
# when te_len == 0.
te_inds = sorted_inds[max(len(sorted_inds) - te_len, 0):]
te_lookup = set(te_inds)
for k in range(10):
    tr_inds = [i for i in inds if i not in te_lookup]
    val_inds = tr_inds[-200:]
    val_lookup = set(val_inds)
    tr_inds = [t for t in tr_inds if t not in val_lookup]
    trainsets.append((tr_inds, val_inds, list(te_inds)))

### Train the neural network. WER/CER are evaluated periodically inside
### train_loop (the saved output shows an evaluation after every 3 epochs).


for train, val, test in trainsets:
    # Optionally use only a fraction of the training indices (args['train_amt']).
    train_amt = int(len(train)*args['train_amt'])
    print('num samples', train_amt)
    wandb.log({'num_samples':train_amt})
    train = train[:train_amt]
    print(len(train), train_amt)

    # Slice features, targets, input lengths and target lengths for each split.
    # Note: this rebinds the notebook-level X_te to the held-out rows of X.
    X_tr, X_te, X_v = X[train], X[test], X[val]
    Y_tr, Y_te , Y_v = Y[train], Y[test], Y[val]
    lens_tr, lens_te, lens_v = lens[train], lens[test], lens[val]
    inds_tr, inds_te, inds_v = np.array(train), np.array(test), np.array(val) # for loading text labels.
    outlens_tr, outlens_te, outlens_v = outlens[train], outlens[test], outlens[val]

    # Make datasets. Each item is (features, targets, input length, target length,
    # original sample index).
    train_dset = TensorDataset(torch.from_numpy(X_tr.copy()),
                               torch.from_numpy(Y_tr.copy()),
                              torch.from_numpy(lens_tr.copy()),
                              torch.from_numpy(outlens_tr.copy()),
                              torch.from_numpy(inds_tr.copy()))
    test_dset = TensorDataset(torch.from_numpy(X_te.copy()),
                              torch.from_numpy(Y_te.copy()),
                             torch.from_numpy(lens_te.copy()),
                             torch.from_numpy(outlens_te.copy()),
                             torch.from_numpy(inds_te.copy()))
    val_dset = TensorDataset(torch.from_numpy(X_v.copy()),
                              torch.from_numpy(Y_v.copy()),
                             torch.from_numpy(lens_v.copy()),
                             torch.from_numpy(outlens_v.copy()),
                             torch.from_numpy(inds_v.copy()))

    # TODO: Add transforms from torchaudio.transforms
    # Only the training loader is shuffled.
    train_loader = DataLoader(train_dset, batch_size=args['bs'], shuffle=True)
    val_loader = DataLoader(val_dset, batch_size=args['bs'], shuffle=False)
    test_loader = DataLoader(test_dset, batch_size=args['bs'], shuffle=False)

    # Initialize the model.
    if not args['feedforward']:
        # When loading pretrained weights the output layer is first built with
        # args['ndense'] units (so the checkpoint loads) and replaced below.
        if args['pretrained'] is not None:
            n_targ = args['ndense']
        else:
            n_targ=len((enc_final))

        model = FlexibleCnnRnnClassifier(rnn_dim=args['hidden_dim'], KS=args['ks'],
                                         num_layers=args['num_layers'],
                                         dropout=args['dropout'], n_targ=n_targ,
                                  bidirectional=True, in_channels=X_tr.shape[-1])
    else:
        model = FlexibleCnnRnnClassifier(rnn_dim=args['hidden_dim'], KS=args['ks'],
                                 num_layers=args['num_layers'],
                                 dropout=args['dropout'], n_targ=len((enc_final)),
                          bidirectional=False, in_channels=X_tr.shape[-1])

    if args['pretrained'] is not None:
        model.load_state_dict(torch.load(join(curdir, args['pretrained'])))
        # Swap in a fresh output layer sized for this token set.
        # NOTE(review): the 2*hidden_dim input size presumably matches the
        # bidirectional model only — confirm before combining with 'feedforward'.
        model.dense = nn.Linear(2*args['hidden_dim'], len((enc_final)))

        if args['transfer_audio']:
            # Replace the input convolution so it matches this dataset's channel count.
            model.preprocessing_conv = torch.nn.Conv1d(in_channels=X_te.shape[-1],
                                               out_channels=args['hidden_dim'],
                                               kernel_size=args['ks'],
                                               stride=args['ks'])

    optimizer = torch.optim.Adam(model.parameters(), lr=args['lr'])
    model = model.to(device)
    # Train with the validation loader as the evaluation set.
    model, wer, cer, per =train_loop(model, train_loader,
                    val_loader,
                      optimizer,
                    device, gt_text, greedy, beam_search_decoder, tokens, start_eval=0,
                     wandb_log=True, checkpoint_dir=args['checkpoint_dir'], max_epochs= 200)

    # Currently we only use one fold for model dev.
    break

(9680, 919, 256) (9680, 80)
num samples 8600
8600 8600
net wer 0.921922258297258
net cer 0.8598880912624384
greedy per 0.5780285627070234
beam per 0.7054507391198415
gt more areas where they would pick it up
t is it wer: 0.88 cer: 0.89
gt what outfit does she drive for
t are you wer: 1.00 cer: 0.83
gt the best you can
t you can wer: 0.50 cer: 0.56
gt the homeowner can't be touched
t i am wer: 1.00 cer: 0.93
gt boy wouldn't i give to get
t i want to wer: 0.67 cer: 0.73
gt let all projects dry slowly for several days
t but wer: 1.00 cer: 0.98
gt have you spoken before
t you know wer: 0.75 cer: 0.68
gt i can hear it in your voice
t i can do wer: 0.71 cer: 0.74
gt why he's going to kill me he thought wildly
t what will you wer: 1.00 cer: 0.77
gt more of the volunteer network service
t the wer: 0.83 cer: 0.92
gt i just grew up in oklahoma
t i am wer: 0.83 cer: 0.85
epoch 0 tr loss: 0.045 te_loss: 0.037
epoch 1 tr loss: 0.027 te_loss: 0.031
epoch 2 tr loss: 0.024 te_loss: 0.027
net wer 0.756

In [None]:
# Report the word, character and phoneme error rates currently held in wer, cer and per.
metric_lines = [f'WER: {wer}', f'CER: {cer}', f'PER: {per}']
print('\n'.join(metric_lines))

WER: 0.5783245088245086
CER: 0.4791933195143656
PER: 0.4281504110506169


# Run on the realtime test data

In [None]:
# Step 1: Build a second beam-search decoder for the final evaluation. It uses
# the same lexicon and token files, the 5-gram language-model binary, a beam
# size of 3000 and an LM weight of 4.5. We found that a larger beam size is
# quite helpful here and lowers the word error rate.
final_decoder_kwargs = {
    'lexicon': os.path.join(curdir, 'for_ctc/lexicon_phrases_1k.txt'),
    'tokens': os.path.join(curdir, 'for_ctc/tokens_phrases_1k.txt'),
    'lm': os.path.join(curdir, 'multimodal-decoding/text/custom_lms/full_corpus_lm_5_abs_slm.binary'),
    'nbest': 3,
    'beam_size': 3000,
    'lm_weight': 4.5,
    'word_score': args['WORD_SCORE'],
    'sil_token': '|',
    'unk_word': '<unk>',
}
beam_search_decoder_final = ctc_decoder(**final_decoder_kwargs)

In [None]:
# Step 2: Evaluate on totally UNSEEN sentences
# Run on test set. For early stopping, this model used a different strategy (some predictions were left to end), see paper for more details.
# train=False with max_epochs=1 requests a single evaluation-only pass over test_loader.
# The decoder passed here is beam_search_decoder_final (larger beam, 5-gram LM)
# from Step 1; previously the training-time beam_search_decoder was passed, so
# the final decoder was never used.
model, wer, cer, per =train_loop(model, train_loader,
                test_loader,
                  optimizer,
                device, gt_text, greedy, beam_search_decoder_final, tokens, start_eval=0, max_epochs=1,
                 wandb_log=True, checkpoint_dir=args['checkpoint_dir'], train=False)


  result = _VF.gru(input, batch_sizes, hx, self._flat_weights, self.bias,


net wer 0.8220357447254172
net cer 0.7080618074994636
greedy per 0.4874620511663054
beam per 0.6279458036715946
gt the tooth fairy forgot to come when roger's tooth fell out
t that i do wer: 1.00 cer: 0.88
gt the junk mail that you get
t i like that you got wer: 0.67 cer: 0.46
gt those answers will be straightforward if you think them through carefully first
t and what do they know wer: 1.00 cer: 0.82
gt have you ever been drug tested
t i have one wer: 1.00 cer: 0.80
gt draw your own conclusions
t i can wer: 1.00 cer: 0.88
gt on a working farm
t working some wer: 0.75 cer: 0.53
gt i think we have about thirteen months left on it
t i think we have a thing on this one wer: 0.60 cer: 0.42
gt i mean lunch today was eighteen dollars
t me not to wer: 1.00 cer: 0.79
gt you are frequently exploited
t but wer: 1.00 cer: 0.93
gt the only other place i've ever vacationed
t go in wer: 1.00 cer: 0.90
gt i can finish it in a day
t can i see wer: 0.86 cer: 0.75
epoch 0 tr loss: nan te_loss: 0.025


In [None]:
# Report the word, character and phoneme error rates currently held in wer, cer and per.
metric_lines = [f'WER: {wer}', f'CER: {cer}', f'PER: {per}']
print('\n'.join(metric_lines))

WER: 0.8220357447254172
CER: 0.7080618074994636
PER: 0.6279458036715946
