## Create training datasets for decoder

In [None]:
import sys
import os
sys.path.append('../../')
from utils.lmutils import cer_with_lm_decoder, build_lm_decoder

# Language-model directory handed to build_lm_decoder.
lmDir = '../../LanguageModels/Typing5gramPunc'

# Decoder settings gathered in one place and forwarded as keyword arguments.
_lm_decoder_options = dict(
    max_active=7000,
    min_active=200,
    beam=17,
    lattice_beam=8,
    acoustic_scale=1.0,  # lower means weight LM more highly
    ctc_blank_skip_threshold=1.0,
    length_penalty=0.0,
    nbest=100,
)

# Shared decoder object; later cells refer to it by this name.
ngramDecoder = build_lm_decoder(lmDir, **_lm_decoder_options)

In [None]:
import sys
sys.path.append('../../')
from pathlib import Path
import numpy as np
from T5_SessionArgs_comp import get_session_info
import utils.t5_handwriting_mat_to_tfrecord

# Sessions to process; each one is passed through
# utils.t5_handwriting_mat_to_tfrecord.main with its own argument dict.
sessions = [
    't5.2019.12.09',
    't5.2019.12.11',
    't5.2019.12.18',
    't5.2019.12.20',
    't5.2020.01.06',
    't5.2020.01.08',
    't5.2020.01.13',
    't5.2020.01.15',
]

participant = 't5'
# NOTE(review): presumably relates to the 'tfdata_20ms' output folder name
# (e.g. 10 ms bins x 2) — confirm against the converter.
bin_compression_factor = 2
channels_to_exclude = list(range(0, 0))  # empty range -> nothing excluded
channels_to_zero = list(range(0, 0))  # leave empty to not zero anything

for session in sessions:

    # Per-session settings supplied by T5_SessionArgs_comp.get_session_info.
    trials_to_remove, block_nums, num_test_trials = get_session_info(session)

    # Input directory for this session and the output directory nested inside it.
    session_path = str(Path('../../Data', participant, session))
    tfdata_path = str(Path(session_path, 'tfdata_20ms'))

    # Log both locations. The second line must report tfdata_path (the save
    # location), not session_path again.
    print(f'Session path: {session_path}')
    print(f'tfdata path: {tfdata_path}')
    print('\n')

    args = {
        'session_mat_path': session_path,
        'block_nums': block_nums,
        'num_test_trials': num_test_trials,
        'trials_to_remove': trials_to_remove,
        'channels_to_exclude': channels_to_exclude,
        'channels_to_zero': channels_to_zero,
        'include_thresh_crossings': True,
        'include_spike_power': False,
        'spike_pow_max': 50000,
        'z_score_data': True,
        'global_std': True,
        'bin_compression_factor': bin_compression_factor,
        'save_path': tfdata_path,
    }

    utils.t5_handwriting_mat_to_tfrecord.main(args)

## Train the decoder. Remember to restart the notebook first, then re-run the cell that builds `ngramDecoder` (the training cell below uses it)!

In [None]:
# Target character error rates; one decoder is trained per value, and each
# value is written into the training config as 'EndCER'.
all_CERs = [
    0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35,
    0.4, 0.45, 0.5, 0.55, 0.6, 0.65, 0.7,
]

import sys
import gc
sys.path.append('../../')
import numpy as np
import importlib
from utils.lmutils import cer_with_lm_decoder
from omegaconf import OmegaConf

from utils.brainToText_trainDecoder import brainToText_decoder

# Base training configuration; outputDir and EndCER are overwritten per run.
args = OmegaConf.load('T5_trainArgs_comp.yaml')

# NOTE: ngramDecoder is not defined in this cell; it must already exist in
# the running session before this loop starts.
for target_cer in all_CERs:
    cer_tag = str(target_cer)
    args['outputDir'] = f'T5_LM_comparison_models/{cer_tag}CER'
    args['EndCER'] = target_cer

    # Train a fresh decoder for this target, then run inference with it.
    decoder = brainToText_decoder(args)
    infOut, stats = decoder.train()
    out, out_by_day = decoder.inference()

    # Score the inference output through the language-model decoder.
    decoder_out = cer_with_lm_decoder(
        ngramDecoder,
        out,
        blankPenalty=1,
        rescore=False,
    )

    print('Target CER: ', cer_tag)
    print('Val CER post LM: ', decoder_out['cer'])
    print('Val WER post LM: ', decoder_out['wer'])

    # Show every reference transcript next to its decoded counterpart.
    transcript_pairs = zip(
        decoder_out['true_transcripts'],
        decoder_out['decoded_transcripts'],
    )
    for true_text, decoded_text in transcript_pairs:
        print('True :', true_text, ', Decoded: ', decoded_text)

    np.save(f'T5Train-{cer_tag}-WERPostLM-.npy', decoder_out['all_wer'])

    # Drop the reference to the trained decoder and collect garbage before
    # the next iteration builds a new one.
    decoder = None
    gc.collect()