In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"]='5'
import torch  
torch.manual_seed(3407)
import random
random.seed(3407)
import numpy as np
np.random.seed(3407)

import logging
import pandas as pd
from seq2seq_model import Seq2SeqModel
logging.basicConfig(level=logging.INFO)
transformers_logger = logging.getLogger("transformers")
transformers_logger.setLevel(logging.WARNING)

In [2]:
train_data = pd.read_csv("./templated_data/en-train.csv").values.tolist()
train_df = pd.DataFrame(train_data, columns=["input_text", "target_text"])

eval_data = pd.read_csv("./templated_data/en-dev.csv").values.tolist()
eval_df = pd.DataFrame(eval_data, columns=["input_text", "target_text"])

model_args = {
    "reprocess_input_data": True,
    "overwrite_output_dir": True,
    "max_seq_length": 50,
    "train_batch_size": 100,
    "num_train_epochs": 20,
    "save_eval_checkpoints": False,
    "save_model_every_epoch": False,
    "evaluate_during_training": True,
    "evaluate_generated_text": True,
    "evaluate_during_training_verbose": True,
    "use_multiprocessing": False,
    "max_length": 25,
    "manual_seed": 4,
    "save_steps": 11898,
    "gradient_accumulation_steps": 1,
    "output_dir": "./exp/template",
    # "num_labels":33
}

# Initialize model
model = Seq2SeqModel(
    encoder_decoder_type="bart",
    encoder_decoder_name="facebook/bart-large",
    args=model_args,
    # use_cuda=False,
)

In [3]:
model.config

BartConfig {
  "_name_or_path": "facebook/bart-large",
  "activation_dropout": 0.1,
  "activation_function": "gelu",
  "add_bias_logits": false,
  "add_final_layer_norm": false,
  "architectures": [
    "BartModel"
  ],
  "attention_dropout": 0.1,
  "bos_token_id": 0,
  "classif_dropout": 0.1,
  "classifier_dropout": 0.0,
  "d_model": 1024,
  "decoder_attention_heads": 16,
  "decoder_ffn_dim": 4096,
  "decoder_layerdrop": 0.0,
  "decoder_layers": 12,
  "decoder_start_token_id": 2,
  "dropout": 0.1,
  "early_stopping": true,
  "encoder_attention_heads": 16,
  "encoder_ffn_dim": 4096,
  "encoder_layerdrop": 0.0,
  "encoder_layers": 12,
  "eos_token_id": 2,
  "forced_bos_token_id": 0,
  "forced_eos_token_id": 2,
  "gradient_checkpointing": false,
  "id2label": {
    "0": "LABEL_0",
    "1": "LABEL_1",
    "2": "LABEL_2"
  },
  "init_std": 0.02,
  "is_encoder_decoder": true,
  "label2id": {
    "LABEL_0": 0,
    "LABEL_1": 1,
    "LABEL_2": 2
  },
  "max_position_embeddings": 1024,
  "mode

In [4]:
# Train the model
model.train_model(train_df, eval_data=eval_df)

INFO:seq2seq_utils: Creating features from dataset file at cache_dir/


  0%|          | 0/25224 [00:00<?, ?it/s]

INFO:seq2seq_model: Training started


Epoch:   0%|          | 0/20 [00:00<?, ?it/s]

Running Epoch 0 of 20:   0%|          | 0/253 [00:00<?, ?it/s]

  torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
INFO:seq2seq_utils: Creating features from dataset file at cache_dir/


  0%|          | 0/1284 [00:00<?, ?it/s]

INFO:seq2seq_model:{'eval_loss': 0.31893494938101086, 'eval_acc': 0.9090489025798999}
INFO:seq2seq_model:Saving model into outputs/best_model


Running Epoch 1 of 20:   0%|          | 0/253 [00:00<?, ?it/s]

INFO:seq2seq_utils: Creating features from dataset file at cache_dir/


  0%|          | 0/1284 [00:00<?, ?it/s]

INFO:seq2seq_model:{'eval_loss': 0.25947871246600746, 'eval_acc': 0.91998459761263}
INFO:seq2seq_model:Saving model into outputs/best_model


Running Epoch 2 of 20:   0%|          | 0/253 [00:00<?, ?it/s]

INFO:seq2seq_utils: Creating features from dataset file at cache_dir/


  0%|          | 0/1284 [00:00<?, ?it/s]

INFO:seq2seq_model:{'eval_loss': 0.25699868328545405, 'eval_acc': 0.9249133615710435}
INFO:seq2seq_model:Saving model into outputs/best_model


Running Epoch 3 of 20:   0%|          | 0/253 [00:00<?, ?it/s]

INFO:seq2seq_utils: Creating features from dataset file at cache_dir/


  0%|          | 0/1284 [00:00<?, ?it/s]

INFO:seq2seq_model:{'eval_loss': 0.2674537917240436, 'eval_acc': 0.9219869079707355}


Running Epoch 4 of 20:   0%|          | 0/253 [00:00<?, ?it/s]

INFO:seq2seq_utils: Creating features from dataset file at cache_dir/


  0%|          | 0/1284 [00:00<?, ?it/s]

INFO:seq2seq_model:{'eval_loss': 0.2836513363102578, 'eval_acc': 0.924374278013092}


Running Epoch 5 of 20:   0%|          | 0/253 [00:00<?, ?it/s]

INFO:seq2seq_utils: Creating features from dataset file at cache_dir/


  0%|          | 0/1284 [00:00<?, ?it/s]

INFO:seq2seq_model:{'eval_loss': 0.294633080130038, 'eval_acc': 0.9239122063919908}


Running Epoch 6 of 20:   0%|          | 0/253 [00:00<?, ?it/s]

INFO:seq2seq_utils: Creating features from dataset file at cache_dir/


  0%|          | 0/1284 [00:00<?, ?it/s]

INFO:seq2seq_model:{'eval_loss': 0.3233650849426385, 'eval_acc': 0.9244512899499422}


Running Epoch 7 of 20:   0%|          | 0/253 [00:00<?, ?it/s]

INFO:seq2seq_utils: Creating features from dataset file at cache_dir/


  0%|          | 0/1284 [00:00<?, ?it/s]

INFO:seq2seq_model:{'eval_loss': 0.3239800650279344, 'eval_acc': 0.9229880631497882}
INFO:seq2seq_utils: Creating features from dataset file at cache_dir/


  0%|          | 0/1284 [00:00<?, ?it/s]

INFO:seq2seq_model:{'eval_loss': 0.3153576129202887, 'eval_acc': 0.9244512899499422}


Running Epoch 8 of 20:   0%|          | 0/253 [00:00<?, ?it/s]

INFO:seq2seq_utils: Creating features from dataset file at cache_dir/


  0%|          | 0/1284 [00:00<?, ?it/s]

INFO:seq2seq_model:{'eval_loss': 0.34357810943670897, 'eval_acc': 0.9266076241817481}


Running Epoch 9 of 20:   0%|          | 0/253 [00:00<?, ?it/s]

INFO:seq2seq_utils: Creating features from dataset file at cache_dir/


  0%|          | 0/1284 [00:00<?, ?it/s]

INFO:seq2seq_model:{'eval_loss': 0.3370363715087405, 'eval_acc': 0.9270696958028495}


Running Epoch 10 of 20:   0%|          | 0/253 [00:00<?, ?it/s]

INFO:seq2seq_utils: Creating features from dataset file at cache_dir/


  0%|          | 0/1284 [00:00<?, ?it/s]

INFO:seq2seq_model:{'eval_loss': 0.3541497187399716, 'eval_acc': 0.9273007316134001}


Running Epoch 11 of 20:   0%|          | 0/253 [00:00<?, ?it/s]

INFO:seq2seq_utils: Creating features from dataset file at cache_dir/


  0%|          | 0/1284 [00:00<?, ?it/s]

INFO:seq2seq_model:{'eval_loss': 0.3845473116391009, 'eval_acc': 0.9266076241817481}


Running Epoch 12 of 20:   0%|          | 0/253 [00:00<?, ?it/s]

INFO:seq2seq_utils: Creating features from dataset file at cache_dir/


  0%|          | 0/1284 [00:00<?, ?it/s]

INFO:seq2seq_model:{'eval_loss': 0.39517636088110647, 'eval_acc': 0.9272237196765498}


Running Epoch 13 of 20:   0%|          | 0/253 [00:00<?, ?it/s]

INFO:seq2seq_utils: Creating features from dataset file at cache_dir/


  0%|          | 0/1284 [00:00<?, ?it/s]

INFO:seq2seq_model:{'eval_loss': 0.38539703841674033, 'eval_acc': 0.9259915286869465}


Running Epoch 14 of 20:   0%|          | 0/253 [00:00<?, ?it/s]

INFO:seq2seq_utils: Creating features from dataset file at cache_dir/


  0%|          | 0/1284 [00:00<?, ?it/s]

INFO:seq2seq_model:{'eval_loss': 0.3968125297948255, 'eval_acc': 0.9274547554871005}


Running Epoch 15 of 20:   0%|          | 0/253 [00:00<?, ?it/s]

INFO:seq2seq_utils: Creating features from dataset file at cache_dir/


  0%|          | 0/1284 [00:00<?, ?it/s]

INFO:seq2seq_model:{'eval_loss': 0.41106409937004496, 'eval_acc': 0.9275317674239507}
INFO:seq2seq_utils: Creating features from dataset file at cache_dir/


  0%|          | 0/1284 [00:00<?, ?it/s]

INFO:seq2seq_model:{'eval_loss': 0.41100642471487475, 'eval_acc': 0.9282248748556026}


Running Epoch 16 of 20:   0%|          | 0/253 [00:00<?, ?it/s]

INFO:seq2seq_utils: Creating features from dataset file at cache_dir/


  0%|          | 0/1284 [00:00<?, ?it/s]

INFO:seq2seq_model:{'eval_loss': 0.41384772008851817, 'eval_acc': 0.9282248748556026}


Running Epoch 17 of 20:   0%|          | 0/253 [00:00<?, ?it/s]

INFO:seq2seq_utils: Creating features from dataset file at cache_dir/


  0%|          | 0/1284 [00:00<?, ?it/s]

INFO:seq2seq_model:{'eval_loss': 0.41559366294130773, 'eval_acc': 0.928378898729303}


Running Epoch 18 of 20:   0%|          | 0/253 [00:00<?, ?it/s]

INFO:seq2seq_utils: Creating features from dataset file at cache_dir/


  0%|          | 0/1284 [00:00<?, ?it/s]

INFO:seq2seq_model:{'eval_loss': 0.43124131961314943, 'eval_acc': 0.9287639584135541}


Running Epoch 19 of 20:   0%|          | 0/253 [00:00<?, ?it/s]

INFO:seq2seq_utils: Creating features from dataset file at cache_dir/


  0%|          | 0/1284 [00:00<?, ?it/s]

INFO:seq2seq_model:{'eval_loss': 0.4384798976876166, 'eval_acc': 0.9288409703504044}
INFO:seq2seq_model:Saving model into ./exp/template
INFO:seq2seq_model: Training of facebook/bart-large model complete. Saved to ./exp/template.


In [24]:
# Evaluate the model
results = model.eval_model(eval_df, 
                           verbose = False, 
                           silent = True)

INFO:seq2seq_utils: Creating features from dataset file at cache_dir/


  0%|          | 0/1284 [00:00<?, ?it/s]

In [25]:
results

{'eval_loss': 0.4384798976876166, 'eval_acc': 0.9288409703504044}

In [7]:
eval_df['input_text'].iloc[1282]

'afpm represents companies including chevron corporation exxonmobil koch industries marathon petroleum and valero energy .'

In [8]:
model.predict(['afpm represents companies including chevron corporation exxonmobil koch industries marathon petroleum and valero energy .'])

['koch industries is a PrivateCorp entity']

In [9]:
eval_df

Unnamed: 0,input_text,target_text
0,eli lilly founder president of pharmaceutical ...,eli lilly and company is a PublicCorp entity
1,eli lilly founder president of pharmaceutical ...,eli lilly is an OtherPER entity
2,christoph haberland designed a new marble pulp...,italy is a HumanSettlement entity
3,christoph haberland designed a new marble pulp...,christoph haberland is an OtherPER entity
4,christoph haberland designed a new marble pulp...,pulpit is an OtherPROD entity
...,...,...
1279,afpm represents companies including chevron co...,exxonmobil is a PublicCorp entity
1280,afpm represents companies including chevron co...,valero energy is a PublicCorp entity
1281,afpm represents companies including chevron co...,marathon petroleum is a PublicCorp entity
1282,afpm represents companies including chevron co...,chevron corporation is a PublicCorp entity


In [11]:
eval_df['preds'] = [model.predict([x]) for x in eval_df['input_text']]
eval_df

Unnamed: 0,input_text,target_text,preds
0,eli lilly founder president of pharmaceutical ...,eli lilly and company is a PublicCorp entity,[eli lilly is a Scientist entity]
1,eli lilly founder president of pharmaceutical ...,eli lilly is an OtherPER entity,[eli lilly is a Scientist entity]
2,christoph haberland designed a new marble pulp...,italy is a HumanSettlement entity,[pulpit is an OtherPROD entity]
3,christoph haberland designed a new marble pulp...,christoph haberland is an OtherPER entity,[pulpit is an OtherPROD entity]
4,christoph haberland designed a new marble pulp...,pulpit is an OtherPROD entity,[pulpit is an OtherPROD entity]
...,...,...,...
1279,afpm represents companies including chevron co...,exxonmobil is a PublicCorp entity,[koch industries is a PrivateCorp entity]
1280,afpm represents companies including chevron co...,valero energy is a PublicCorp entity,[koch industries is a PrivateCorp entity]
1281,afpm represents companies including chevron co...,marathon petroleum is a PublicCorp entity,[koch industries is a PrivateCorp entity]
1282,afpm represents companies including chevron co...,chevron corporation is a PublicCorp entity,[koch industries is a PrivateCorp entity]


In [18]:
eval_df

Unnamed: 0,input_text,target_text,preds
0,eli lilly founder president of pharmaceutical ...,eli lilly and company is a PublicCorp entity,[eli lilly is a Scientist entity]
1,eli lilly founder president of pharmaceutical ...,eli lilly is an OtherPER entity,[eli lilly is a Scientist entity]
2,christoph haberland designed a new marble pulp...,italy is a HumanSettlement entity,[pulpit is an OtherPROD entity]
3,christoph haberland designed a new marble pulp...,christoph haberland is an OtherPER entity,[pulpit is an OtherPROD entity]
4,christoph haberland designed a new marble pulp...,pulpit is an OtherPROD entity,[pulpit is an OtherPROD entity]
...,...,...,...
1279,afpm represents companies including chevron co...,exxonmobil is a PublicCorp entity,[koch industries is a PrivateCorp entity]
1280,afpm represents companies including chevron co...,valero energy is a PublicCorp entity,[koch industries is a PrivateCorp entity]
1281,afpm represents companies including chevron co...,marathon petroleum is a PublicCorp entity,[koch industries is a PrivateCorp entity]
1282,afpm represents companies including chevron co...,chevron corporation is a PublicCorp entity,[koch industries is a PrivateCorp entity]
