In [1]:
import logging
import fastwer
import numpy as np
import wandb
import torch.multiprocessing
from transformers import EncoderDecoderConfig, BertConfig
# Select the "file_system" strategy for sharing tensors between processes.
# NOTE(review): presumably chosen to avoid hitting the open-file-descriptor
# limit when worker processes exchange tensors — confirm.
torch.multiprocessing.set_sharing_strategy('file_system')

import pandas as pd
from aamod.seq2seq import (
    Seq2SeqModel,
    Seq2SeqArgs,
)
# from simpletransformersmod.seq2seq import (
#     Seq2SeqModel,
#     Seq2SeqArgs,
# )

# Keep the "transformers" logger quiet: only WARNING and above get through.
transformers_logger = logging.getLogger("transformers")
transformers_logger.setLevel(logging.WARNING)

# Everything else logs at INFO so training/evaluation progress is visible.
logging.basicConfig(level=logging.INFO)

# Run configuration handed to Seq2SeqModel below. Settings are grouped by
# concern; each one is a plain attribute assignment on the args object.
model_args = Seq2SeqArgs()

# --- Training ---
model_args.num_train_epochs = 1
model_args.train_batch_size = 25
model_args.overwrite_output_dir = True
# model_args.no_save = True  # disabled: checkpoints are currently written

# --- Evaluation / generation ---
model_args.evaluate_during_training = True
model_args.evaluate_during_training_verbose = False
model_args.evaluate_generated_text = True
model_args.eval_batch_size = 25
model_args.max_length = 50
model_args.use_multiprocessed_decoding = True

# --- Logging and caching locations ---
model_args.tensorboard_dir = "runs"
model_args.wandb_project = "cs224u"
model_args.cache_dir = "./cache_dir/"

# Build a BERT-to-BERT encoder/decoder configuration. The decoder config is
# flagged as a decoder with cross-attention so it can attend to the encoder.
config_encoder = BertConfig()
config_decoder = BertConfig()
config_decoder.is_decoder = True
config_decoder.add_cross_attention = True
config = EncoderDecoderConfig.from_encoder_decoder_configs(config_encoder, config_decoder)

# NOTE(review): this name is not referenced anywhere in the visible notebook
# (the model type is passed as a literal below); kept in case later cells use it.
encoder_decoder_name = "characterbert"

model = Seq2SeqModel(
    encoder_decoder_type="characterbert",
    encoder_name="bert-base-uncased",
    decoder_name="bert-base-uncased",
    args=model_args,
    config=config,
    # Use the GPU when one is present instead of hard-coding True, so the
    # notebook still runs (slowly) on a CPU-only machine rather than failing here.
    use_cuda=torch.cuda.is_available(),
)


# model = Seq2SeqModel(encoder_decoder_type="bart", encoder_decoder_name="google/roberta2roberta_L-24_cnn_daily_mail", args=model_args, use_cuda=True,)

# model = Seq2SeqModel(
#     encoder_type="bert",
#     encoder_name="bert-base-uncased",
#     decoder_name="bert-base-uncased",
#     args=model_args,
#     use_cuda=True,
# )


def count_matches(labels, preds):
    """Count the positions at which the prediction equals its label.

    Args:
        labels: Iterable of reference values.
        preds: Iterable of predicted values, paired with ``labels`` by position.

    Returns:
        Number of pairs where ``label == pred``. Pairing stops at the shorter
        of the two inputs, so extra trailing items are ignored.
    """
    return sum(1 for expected, predicted in zip(labels, preds) if expected == predicted)

def get_wer(labels, preds):
    """Average the per-sentence score from ``fastwer.score_sent`` over all pairs.

    Args:
        labels: Iterable of reference sentences.
        preds: Iterable of predicted sentences, paired with ``labels`` by position.

    Returns:
        ``np.mean`` of ``fastwer.score_sent(pred, label)`` for each pair (the
        prediction is passed first, the reference second).
    """
    sentence_scores = []
    for reference, hypothesis in zip(labels, preds):
        sentence_scores.append(fastwer.score_sent(hypothesis, reference))
    return np.mean(sentence_scores)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertLMHeadModel: ['cls.seq_relationship.bias', 'cls.seq_relations

In [1]:
import pandas as pd

# Load the pickled train/dev frames. Rows with missing values are dropped from
# the training frame only; the dev frame is used as stored.
# NOTE: unpickling can execute arbitrary code — only load files you trust.
train_df = pd.read_pickle("train.pkl").dropna()
dev_df = pd.read_pickle("dev.pkl")

In [2]:
# Keep only the first 200 training rows for a quick run.
train_df = train_df.iloc[:200]

In [3]:
# Display the training subset (columns: input_text, target_text).
train_df

Unnamed: 0,input_text,target_text
0,she beowright,She'll be all right.
1,six,six
2,alswell that ends well,All's well that ends well.
3,bo bed,Do you mean it?
4,new petches less invasing then the oma the o c...,The new patch is less invasive than the old on...
...,...,...
195,woman er por tar mones momant my tete eventoli...,"You might hear ""font families"" more than ""type..."
196,then men gliped white bertyes righten in a boo...,A man with a white beard is writing in a book ...
197,janlewas now along,Charly was now alone.
198,i had oughd my better otassis with latics,I wrote my bachelor thesis with latex.


In [4]:
# Keep only the first 200 dev rows so evaluation stays fast.
dev_df = dev_df.iloc[:200]

In [5]:
# Train the model, evaluating on the dev frame during training with two
# custom metrics: exact-match count ("matches") and mean sentence WER ("wer").
model.train_model(
    train_df,
    eval_data=dev_df,
    matches=count_matches,
    wer=get_wer,
    show_running_loss=True,
    args={"fp16": False},
)

# Final evaluation on the dev frame. The metric functions are passed here as
# well so that "matches" and "wer" are computed for this call; without them the
# values in the returned dict may be left over from the last in-training
# evaluation.
# NOTE(review): assumes eval_model accepts metric keyword arguments the same way
# train_model does (as in upstream Simple Transformers) — confirm for aamod.
results = model.eval_model(dev_df, matches=count_matches, wer=get_wer)

INFO:aamod.seq2seq.seq2seq_utils: Creating features from dataset file at ./cache_dir/


HBox(children=(FloatProgress(value=0.0, max=200.0), HTML(value='')))

INFO:aamod.seq2seq.seq2seq_utils: Saving features into cached file ./cache_dir/bert-base-uncased-bert-base-uncased_cached_128200
INFO:aamod.seq2seq.seq2seq_model: Training started





HBox(children=(FloatProgress(value=0.0, description='Epoch', max=1.0, style=ProgressStyle(description_width='i…

[34m[1mwandb[0m: Currently logged in as: [33mgbanerje[0m (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: wandb version 0.10.31 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


HBox(children=(FloatProgress(value=0.0, description='Running Epoch 0 of 1', max=8.0, style=ProgressStyle(descr…

INFO:aamod.seq2seq.seq2seq_model:Saving model into outputs/checkpoint-8-epoch-1





INFO:aamod.seq2seq.seq2seq_utils: Creating features from dataset file at ./cache_dir/


HBox(children=(FloatProgress(value=0.0, max=200.0), HTML(value='')))

INFO:aamod.seq2seq.seq2seq_utils: Saving features into cached file ./cache_dir/bert-base-uncased-bert-base-uncased_cached_128200





HBox(children=(FloatProgress(value=0.0, description='Generating outputs', max=8.0, style=ProgressStyle(descrip…




HBox(children=(FloatProgress(value=0.0, description='Decoding outputs', max=200.0, style=ProgressStyle(descrip…














INFO:aamod.seq2seq.seq2seq_model:Saving model into outputs/checkpoint-8-epoch-1
INFO:aamod.seq2seq.seq2seq_model:Saving model into outputs/best_model
INFO:aamod.seq2seq.seq2seq_model:Saving model into outputs/





INFO:aamod.seq2seq.seq2seq_model: Training of bert-base-uncased-bert-base-uncased model complete. Saved to outputs/.
INFO:aamod.seq2seq.seq2seq_utils: Creating features from dataset file at ./cache_dir/


HBox(children=(FloatProgress(value=0.0, max=200.0), HTML(value='')))

INFO:aamod.seq2seq.seq2seq_utils: Saving features into cached file ./cache_dir/bert-base-uncased-bert-base-uncased_cached_128200





HBox(children=(FloatProgress(value=0.0, description='Running Evaluation', max=8.0, style=ProgressStyle(descrip…




HBox(children=(FloatProgress(value=0.0, description='Generating outputs', max=8.0, style=ProgressStyle(descrip…




HBox(children=(FloatProgress(value=0.0, description='Decoding outputs', max=200.0, style=ProgressStyle(descrip…




INFO:aamod.seq2seq.seq2seq_model:{'eval_loss': 7.593917489051819, 'matches': 0, 'wer': 100.0}
