In [4]:
import logging
import wandb
import torch.multiprocessing
torch.multiprocessing.set_sharing_strategy('file_system')
import numpy as np
import fastwer

import pandas as pd
from simpletransformers.seq2seq import (
    Seq2SeqModel,
    Seq2SeqArgs,
)


logging.basicConfig(level=logging.INFO)
# Keep simpletransformers' INFO progress logs, but only show WARNING and above
# from Hugging Face transformers.
transformers_logger = logging.getLogger("transformers")
transformers_logger.setLevel(logging.WARNING)

# Seed for reproducible training runs (data shuffling, dropout). simpletransformers
# applies model_args.manual_seed when Seq2SeqModel is constructed, so it must be
# set before the model is created below.
SEED = 42

model_args = Seq2SeqArgs()
model_args.manual_seed = SEED
model_args.num_train_epochs = 15
# model_args.no_save = True
# Generate dev-set text during evaluation so that the custom metrics passed to
# train_model / eval_model (matches, wer) can be computed on the decoded output.
model_args.evaluate_generated_text = True
model_args.evaluate_during_training = True
model_args.evaluate_during_training_verbose = False
model_args.tensorboard_dir = "runs"
model_args.max_length = 200  # maximum length of generated output sequences
model_args.train_batch_size = 15
model_args.overwrite_output_dir = True
model_args.wandb_project = "cs224u"

# Initialize the model from the pretrained facebook/bart-large-cnn checkpoint.
# NOTE(review): bart-large-cnn is a summarization checkpoint; its bundled generation
# defaults may explain the long, run-on outputs seen from model.predict below.
# Consider facebook/bart-large or overriding generation settings. TODO confirm.
model = Seq2SeqModel(
    encoder_decoder_type="bart",
    encoder_decoder_name="facebook/bart-large-cnn",
    args=model_args,
    use_cuda=True,
)

# Alternatives tried previously:
# - resume from the best checkpoint of an earlier run:
# model = Seq2SeqModel(encoder_decoder_type="bart", encoder_decoder_name="./outputs/best_model", args=model_args, use_cuda=True,)
# - BERT-to-BERT encoder-decoder:
# model = Seq2SeqModel(
#     encoder_type="bert",
#     encoder_name="bert-base-uncased",
#     decoder_name="bert-base-uncased",
#     args=model_args,
#     use_cuda=True,
# )


def count_matches(labels, preds):
    """Count positions where a prediction exactly equals its reference label.

    Labels and predictions are paired positionally with ``zip``, so any surplus
    items in the longer sequence are ignored.

    Args:
        labels: Reference (target) sequences.
        preds: Model predictions aligned with ``labels``.

    Returns:
        int: Number of pairs for which ``label == pred`` is truthy (0 for empty input).
    """
    return sum(1 for ref, hyp in zip(labels, preds) if ref == hyp)

def get_wer(labels, preds):
    """Average sentence-level word error rate of predictions against labels.

    Each (label, pred) pair is scored with ``fastwer.score_sent(pred, label)``,
    i.e. hypothesis first and reference second. The per-sentence scores are then
    averaged with ``np.mean``. Pairs are formed with ``zip``, so surplus items in
    the longer sequence are ignored.

    Note: this is a mean of per-sentence scores, not a corpus-level WER. For empty
    input, ``np.mean`` of an empty list yields NaN (with a RuntimeWarning).
    """
    sentence_scores = []
    for reference, hypothesis in zip(labels, preds):
        sentence_scores.append(fastwer.score_sent(hypothesis, reference))
    return np.mean(sentence_scores)

In [5]:
import pandas as pd

# NOTE(review): pd.read_pickle can execute arbitrary code; only load pickles you produced yourself.
train_df = pd.read_pickle("train.pkl")
dev_df = pd.read_pickle("dev.pkl")

# Apply the same cleaning to both splits. Rows with missing values were already
# dropped from train. A missing input_text/target_text in dev would presumably fail
# later, once it reaches the tokenizer or the string-based metrics during evaluation.
train_df = train_df.dropna()
dev_df = dev_df.dropna()

In [6]:
# train_df = train_df.head(100)

In [7]:
# dev_df = dev_df.head(20)

In [8]:
# Train the model. Weights & Biases logging is driven by model_args.wandb_project,
# so no manual wandb.init() / wandb.join() is needed here.

# Custom metrics computed on the generated dev-set text; each is called as fn(labels, preds).
eval_metrics = {"matches": count_matches, "wer": get_wer}

model.train_model(
    train_df,
    eval_data=dev_df,
    show_running_loss=True,
    args={"fp16": False},
    **eval_metrics,
)

# Evaluate on the dev set with the same metrics used during training; without them
# eval_model only reports the loss-based results.
# NOTE(review): `model` now holds the final-epoch weights. If a best checkpoint was
# saved (e.g. ./outputs/best_model), evaluate that instead when reporting results.
results = model.eval_model(dev_df, **eval_metrics)

INFO:simpletransformers.seq2seq.seq2seq_utils: Creating features from dataset file at cache_dir/


HBox(children=(FloatProgress(value=0.0, max=16163.0), HTML(value='')))

INFO:simpletransformers.seq2seq.seq2seq_model: Training started





HBox(children=(FloatProgress(value=0.0, description='Epoch', max=15.0, style=ProgressStyle(description_width='…

[34m[1mwandb[0m: Currently logged in as: [33mgbanerje[0m (use `wandb login --relogin` to force relogin)


HBox(children=(FloatProgress(value=0.0, description='Running Epoch 0 of 15', max=1078.0, style=ProgressStyle(d…





KeyboardInterrupt: 

In [6]:
# Spot-check predictions on a few hand-written noisy sentences.
model.predict([
    "Hee walks dogks",
    "Hai my precous boi",
    "tteko",
    "e trade often coing sides with other traes",
    "he kepts extensive nodes on a cosing playurs",
])

HBox(children=(FloatProgress(value=0.0, description='Generating outputs', max=1.0, style=ProgressStyle(descrip…




["Hee walks dogks. He walks dog. He walked dogks, he walks dogk. He was walking dogks? He walks dogs. He's walking dogk? He walked dogs. he walks dogs? He's not",
 "Hai my precous boi. I'm not sure what to do. I've got a plan. I'll be back. I hope. I love you. I really do. But I'm a little nervous. I don",
 'Tteko is a national program of the Togo Football Association. The team is based in the city of Tteko. The club has been in business since the 1970s. The organization has been around for more than 30 years',
 'The trade often coing sides with other traes. The trade often trade often with other trade often. The traes are often on opposite sides of the war. The war has been going on for years. The two sides trade often',
 'He kept extensive nodes on a cosing playurs. He kept them on a number of different nodes. He also kept them in a number on a different part of the building. He had a lot of nodes on the building, he']

In [7]:
# Peek at the first five rows of the dev split.
dev_df.head(n=5)

Unnamed: 0,input_text,target_text
0,the coma sat to te parnting afternoon and the ...,Takuma Sato's disappointing afternoon ended wi...
1,leat remark many has fhurtorexxpended to inclu...,Tony Roma's menu has further expanded to inclu...
2,there is nobody that's rich and stupid and not...,There's nobody that rich and stupid and narcis...
3,e trade often coing sides with other traes wit...,The trail often coincides with other trails wi...
4,he kepts extensive nodes on a cosing playurs i...,He kept extensive notes on opposing players an...


In [9]:
# Single-sentence sanity check of the model's correction output.
model.predict(["Hee woks dogks"])

HBox(children=(FloatProgress(value=0.0, description='Generating outputs', max=1.0, style=ProgressStyle(descrip…




['Hee woks!']

In [6]:
# Display the repr of the current `model` object.
model

<simpletransformers.seq2seq.seq2seq_model.Seq2SeqModel at 0x7f95b0f93d50>