In [2]:
import os
import random

import numpy as np
import torch

import transformer_repetition_kit as trk

In [3]:
# Run configuration: dataset location, hyper-parameters handed to
# trk.model_pipeline, and the working-file names shared by the cells below.

# Built with os.path.join so the relative path resolves on POSIX as well as
# on Windows (where the result is the same backslash-separated string as the
# previously hard-coded literal).
ASR_df_filepath = os.path.join(
    '..', 'repetition_data_generation', 'data', 'generated_data.csv'
)

config = dict(
    # Training schedule / optimisation.
    epochs=5,
    batch_size=12,
    learning_rate=0.0005,
    dataset=ASR_df_filepath,
    # Model size: hidden width, then layers / heads / feed-forward width for
    # the encoder and the decoder.
    hid_dim=256,
    enc_layers=4,
    dec_layers=4,
    enc_heads=8,
    dec_heads=8,
    enc_pf_dim=512,
    dec_pf_dim=512,
    enc_dropout=0.1,
    dec_dropout=0.2,
    # NOTE(review): the exact meaning of the keys below is defined inside
    # transformer_repetition_kit -- `clip` is presumably a gradient-clipping
    # threshold and `early_stop` a patience in epochs; confirm there.
    clip=1,
    bpe_vocab_size=1200,
    decode_trg=False,
    early_stop=3,
)

# Plain-text files passed to trk.load_data and to the tokenizer trainer.
asr_text_filepath = 'asr.txt'
ttx_text_filepath = 'ttx.txt'

# Per-split CSV file names passed to trk.load_data and trk.produce_iterators.
train_filename = 'train_sentence.csv'
valid_filename = 'valid_sentence.csv'
test_filename = 'test_sentence.csv'

In [4]:
# Seed every random number generator used in this notebook so that a
# Restart & Run All reproduces the same run.
SEED = 1234

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
# Seed all CUDA devices, not only the current one; this is a no-op when CUDA
# is unavailable.
torch.cuda.manual_seed_all(SEED)

# Ask cuDNN for deterministic kernels and switch off its autotuner, which
# could otherwise pick different algorithms from one run to the next.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [5]:
# Hand the dataset path and the working-file names from the config cell to
# trk.load_data.
# NOTE(review): presumably this writes the train/valid/test CSVs and the two
# text files that the later cells read -- confirm in transformer_repetition_kit.
trk.load_data(
    ASR_df_filepath=ASR_df_filepath,
    train_filename=train_filename,
    valid_filename=valid_filename,
    test_filename=test_filename,
    asr_text_filepath=asr_text_filepath,
    ttx_text_filepath=ttx_text_filepath,
)

  df.tags = df.tags.str.replace(r'\n', '')


In [6]:
# Create the tokenizer via trk.create_train_bpe_tokenizer, using the vocabulary
# size from `config` and the two text files named in the config cell, and ask
# for it to be saved to disk.
# NOTE(review): going by the helper's name this trains a BPE tokenizer on
# those files -- confirm in transformer_repetition_kit.
tokenizer = trk.create_train_bpe_tokenizer(
    config['bpe_vocab_size'],
    asr_text_filepath=asr_text_filepath,
    ttx_text_filepath=ttx_text_filepath,
    save_tokenizer=True,
    # os.path.join yields '.\\tokenizer-test.json' on Windows (unchanged) and
    # './tokenizer-test.json' elsewhere, rather than a file whose name
    # literally starts with '.\' on POSIX systems.
    tokenizer_filename=os.path.join('.', 'tokenizer-test.json'),
)

In [7]:
# Unpack the six values returned by trk.produce_iterators for the three split
# files and the tokenizer: the three data splits followed by TTX, TRG and ASR.
# NOTE(review): TTX / TRG / ASR look like per-column field or vocabulary
# objects -- confirm in transformer_repetition_kit.
(train_data, valid_data, test_data,
 TTX, TRG, ASR) = trk.produce_iterators(
    train_filename, valid_filename, test_filename, tokenizer
)

In [8]:
# Smoke-test the tokenizer on a sentence containing capitals, punctuation,
# an emoji and a bracketed tag, then show the tokens above their ids.
output = tokenizer.encode("Hello, y'all! How are you 😁 ? [WSP]")
print(output.tokens, output.ids, sep='\n')

['[UNK]', 'el', 'lo', ',', 'y', "'", 'all', '!', '[UNK]', 'ow', 'are', 'you', '[UNK]', '?', '[', '[UNK]', '[UNK]', '[UNK]', ']']
[0, 687, 179, 12, 56, 8, 217, 5, 0, 514, 164, 293, 0, 28, 29, 0, 0, 0, 30]


In [9]:
# Preview the first two training examples: reference text, ASR tokens, tags.
# Stop as soon as two have been printed instead of walking the entire
# training set only to skip every remaining example.
for i, t in enumerate(train_data):
    if i >= 2:
        break
    print(t.true_text, '\n', t.asr, '\n', t.tags, '\n')

['we', 'have', 'adopted', 'further', 'measures', 'of', 'a', 'procedural', 'nature'] 
 ['we', 'have', 'adopted', 'further', 'measures', 'of', 'a', 'proced', 'ural', 'nat', 'ure'] 
 ["'O'", "'O'", "'O'", "'O'", "'O'", "'O'", "'RB2'", "'O'", "'O'"] 

['as', 'a', 'result', 'of', 'the', 'crisis', 'in', 'the', 'next', 'few', 'years', 'the', 'situation', 'will', 'only', 'get', 'worse'] 
 ['result', 'of', 'the', 'crisis', 'in', 'the', 'next', 'fe', 'w', 'years', 'the', 'situation', 'will', 'only', 'get', 'next', 'fe', 'w', 'years', 'the', 'situation', 'will', 'only', 'ye', 't', 'next', 'fe', 'w', 'years', 'the', 'situation', 'will', 'only', 'get', 'wor', 'se'] 
 ["'O'", "'O'", "'O'", "'O'", "'O'", "'O'", "'O'", "'O'", "'RB3'", "'RI3'", "'RI3'", "'RI3'", "'RI3'", "'RI3'", "'RI3'", "'RI3'", "'O'"] 



In [10]:
# Display whether PyTorch can see a CUDA device; the next cell selects the
# training device based on the same check.
torch.cuda.is_available()

True

In [11]:
# Train on the GPU when one is available, otherwise fall back to the CPU.
# To force CPU execution (e.g. for debugging), replace the expression below
# with torch.device('cpu').
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [12]:
# Pass the run config, the target device, the three data splits and the
# TTX / TRG / ASR objects to trk.model_pipeline and keep the returned model.
# The recorded output shows that the pipeline logs to Weights & Biases and
# prints the number of trainable parameters.
model = trk.model_pipeline(
    config, device,
    train_data, valid_data, test_data,
    TTX, TRG, ASR,
)

[34m[1mwandb[0m: Currently logged in as: [33mwitw[0m (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: wandb version 0.12.10 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


The model has 13,618,185 trainable parameters


0,1
_runtime,16
_timestamp,1645459141
_step,0


0,1
_runtime,▁
_timestamp,▁
_step,▁


KeyboardInterrupt: 