In [1]:
import random

import numpy as np
import pandas as pd
import torch

import transformer_repetition_kit as trk

In [2]:
# Run configuration: where the source data lives, plus the hyperparameters
# handed to trk.model_pipeline via `config`.
ASR_df_filepath = '../repetition_data_generation/data/rep_audio.csv'

config = {
    # optimisation / schedule
    'epochs': 5,
    'batch_size': 128,
    'learning_rate': 0.01,
    # input data
    'dataset': ASR_df_filepath,
    # model dimensions
    'hid_dim': 256,
    'enc_layers': 4,
    'dec_layers': 4,
    'enc_heads': 8,
    'dec_heads': 8,
    'enc_pf_dim': 512,
    'dec_pf_dim': 512,
    # regularisation
    'enc_dropout': 0.1,
    'dec_dropout': 0.2,
    'clip': 1,
    # tokenizer vocabulary size (read back when the BPE tokenizer is created)
    'bpe_vocab_size': 1600,
    'decode_trg': True,
    # NOTE(review): early_stop (6) is larger than epochs (5); if this is a
    # patience counted in epochs it can never trigger — confirm in trk.
    'early_stop': 6,
}

# Names of the intermediate files shared by the data-loading, tokenizer and
# iterator cells further down.
asr_text_filepath = 'asr.txt'
ttx_text_filepath = 'ttx.txt'
train_filename = 'train_sentence.csv'
valid_filename = 'valid_sentence.csv'
test_filename = 'test_sentence.csv'

In [3]:
# Seed every random number generator used in this notebook so that a
# Restart & Run All reproduces the same run.
SEED = 1234

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
# Also seed every visible GPU explicitly, not just the current device.
# This call is silently ignored when CUDA is unavailable.
torch.cuda.manual_seed_all(SEED)

# Ask cuDNN for deterministic kernels and pin the autotuner off, so the
# algorithm choice cannot vary between runs.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [4]:
# Load the raw table for inspection.
# The CSV carries its own header row: when it is read as data, the literal
# column names show up as the first record (see the first training example
# printed further down). header=0 makes pandas consume that row and replace it
# with the names below instead of keeping it as a bogus record.
# The first, unnamed column is used as the index.
df = pd.read_csv(
    ASR_df_filepath,
    names=[
        '',
        'audio_path',
        'asr_transcript',
        'original_text',
        'mutated_text',
        'index_tags',
        'err_tags',
    ],
    header=0,
    index_col='',
)

In [5]:
# Hand the source CSV and the intermediate file names to trk.load_data.
# Presumably this writes the train/valid/test splits and the two text files
# consumed by the tokenizer and iterator cells below — confirm in trk.
# NOTE(review): the first training example printed later is the tokenized
# column names ('mutated_text' / 'original_text'), so the CSV header row seems
# to be treated as data somewhere in this step — verify inside trk.load_data.
trk.load_data(
    ASR_df_filepath=ASR_df_filepath,
    asr_text_filepath=asr_text_filepath,
    ttx_text_filepath=ttx_text_filepath,
    train_filename=train_filename,
    valid_filename=valid_filename,
    test_filename=test_filename,
)

In [6]:
# Build the tokenizer shared by both text streams. Vocabulary size comes from
# the run config; save_tokenizer=True with the given filename asks the helper
# to persist it (presumably as JSON at that path — confirm in trk).
tokenizer = trk.create_train_bpe_tokenizer(
    config['bpe_vocab_size'],
    asr_text_filepath=asr_text_filepath,
    ttx_text_filepath=ttx_text_filepath,
    save_tokenizer=True,
    tokenizer_filename="./tokenizer-test.json",
)






In [7]:
# Turn the three split files into datasets plus the field objects (TTX, TRG,
# ASR) that the model pipeline needs. The same tokenizer is used for both the
# ASR and the TTX side.
(train_data, valid_data, test_data, TTX, TRG, ASR) = trk.produce_iterators(
    train_filename,
    valid_filename,
    test_filename,
    asr_tokenizer=tokenizer,
    ttx_tokenizer=tokenizer,
)

In [8]:
# Smoke-test the tokenizer on a sentence with punctuation, an emoji and a
# bracketed marker, printing the tokens and then their ids.
# NOTE(review): in the recorded output the emoji maps to [UNK] and "[WSP]" is
# split into single characters, i.e. it is not treated as one special token —
# confirm that is intended.
output = tokenizer.encode("Hello, y'all! How are you 😁 ? [WSP]")
print(output.tokens, output.ids, sep='\n')

['H', 'el', 'lo', ',', 'y', "'", 'all', '!', 'H', 'ow', 'are', 'you', '[UNK]', '?', '[', 'W', 'S', 'P', ']']
[32, 803, 294, 13, 79, 8, 329, 5, 32, 592, 275, 416, 0, 23, 51, 47, 43, 40, 53]


In [9]:
# Peek at the first two training examples (true text, ASR text, tags).
# Stop right after the second one instead of iterating the whole dataset
# only to skip every remaining item.
for i, t in enumerate(train_data):
    if i >= 2:
        break
    print(t.true_text, '\n', t.asr, '\n', t.tags, '\n')

['m', 'ut', 'ated', '_', 'text'] 
 ['or', 'ig', 'in', 'al', '_', 'text'] 
 [''] 

['member', 'ship', 'of', 'parliament', ':', 'see', ',,,', 'of', 'parliament', ':', 'see', 'min', 'utes', ',,,', 'min', 'utes'] 
 ['Member', 'ship', 'of', 'Parliament', ':', 'see', 'M', 'in', 'utes'] 
 [''] 



In [10]:
# Show whether PyTorch can see a CUDA device in this environment.
torch.cuda.is_available()

True

In [11]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
# To force a CPU run, replace the expression with torch.device('cpu').
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [15]:
# Run the training pipeline with the run config, the selected device, the
# three datasets and the three field objects; the result is bound to `model`.
# NOTE(review): the recorded output contains several
# "RuntimeError! Skipping this batch" messages and ends in KeyboardInterrupt,
# so `model` was never assigned in that session — investigate the skipped
# batches and re-run to completion before relying on `model`.
model = trk.model_pipeline(
    config, device,
    train_data, valid_data, test_data,
    TTX, TRG, ASR,
)

[34m[1mwandb[0m: Currently logged in as: [33mwitw[0m (use `wandb login --relogin` to force relogin)
  warn("The `IPython.html` package has been deprecated since IPython 4.0. "
[34m[1mwandb[0m: wandb version 0.12.1 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


The model has 10,145,289 trainable parameters
Loss after 00400 examples: 1.573
Loss after 00800 examples: 1.548
Loss after 01200 examples: 1.528
Loss after 01600 examples: 0.988
Loss after 02000 examples: 1.082

RuntimeError! Skipping this batch, using previous loss as est

Loss after 02400 examples: 1.127

RuntimeError! Skipping this batch, using previous loss as est

Loss after 02800 examples: 0.976
Loss after 03200 examples: 1.013
Loss after 03600 examples: 0.957
Loss after 04000 examples: 0.970
Loss after 04400 examples: 1.130
Loss after 04800 examples: 1.114
Loss after 05200 examples: 1.068
Loss after 05600 examples: 1.013
Loss after 06000 examples: 1.047

RuntimeError! Skipping this batch, using previous loss as est

Loss after 06400 examples: 1.031
Loss after 06787 examples: 1.023

RuntimeError! Skipping this batch, using previous loss as est

Loss after 07187 examples: 1.030
Loss after 07587 examples: 0.944
Loss after 07987 examples: 1.012
Epoch: 01 | Time: 2m 4s
	Train Loss: 1

0,1
_runtime,139.0
_timestamp,1630521793.0
_step,8387.0
epoch,1.0
loss,1.09529


0,1
_runtime,▁▁▂▂▂▃▃▃▃▄▄▄▅▅▅▆▆▆▇▇██
_timestamp,▁▁▂▂▂▃▃▃▃▄▄▄▅▅▅▆▆▆▇▇██
_step,▁▁▂▂▂▃▃▃▄▄▄▅▅▅▆▆▆▇▇▇██
epoch,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁█
loss,██▇▁▃▃▁▂▁▁▃▃▂▂▂▂▂▂▁▂▃


KeyboardInterrupt: 