In [8]:
from simpletransformers.classification import ClassificationModel
import pandas as pd
from sklearn.metrics import classification_report
import logging
import csv
import os
import numpy as np
from transformers import  utils, Trainer, TrainingArguments, ElectraTokenizer, ElectraForSequenceClassification


In [3]:
logging.basicConfig(level=logging.INFO)
transformers_logger = logging.getLogger("transformers")
transformers_logger.setLevel(logging.WARNING)


In [14]:
LOAD_SAVED_MODEL = False
AUGMENT_WITH_NEUTRAL = True
saved_model_path = "models/binary/electra_classifier"
model_name = "howey/electra-base-mnli"
data_dir = "data/argumentation"

train_df = pd.read_csv(os.path.join(data_dir, 'train_iam.tsv'), sep='\t')
dev_df = pd.read_csv(os.path.join(data_dir, 'dev_iam.txt'), sep='\t')
test_df = pd.read_csv(os.path.join(data_dir, 'test_iam.txt'), sep='\t')
all_claims = pd.read_csv(os.path.join(data_dir, 'claims.txt'), sep='\t')
np.random.seed(42)

if AUGMENT_WITH_NEUTRAL:
    neutral_claims = all_claims[all_claims.type=='O'] 
    lower_bound = 0
    
    min_train_label = min(train_df['label'].value_counts())
    train_sample = neutral_claims.iloc[:min_train_label]
    train_df = pd.concat([train_df, train_sample]).sample(frac=1)
    lower_bound = min_train_label
    
    min_dev_label = min(dev_df['label'].value_counts())
    dev_sample = neutral_claims.iloc[lower_bound: lower_bound + min_dev_label]    
    dev_df = pd.concat([dev_df, dev_sample]).sample(frac=1)
    lower_bound = lower_bound + min_dev_label
    
    min_test_label = min(test_df['label'].value_counts())
    test_sample = neutral_claims.iloc[lower_bound: lower_bound + min_test_label]    
    test_df = pd.concat([test_df, test_sample]).sample(frac=1)
    
    
train_df.columns = ['claim_label', 'text_a', 'text_b', 'id', 'labels']
train_df = train_df[['text_a', 'text_b', 'labels']]

dev_df.columns = ['claim_label', 'text_a', 'text_b', 'id', 'labels']
dev_df = dev_df[['text_a', 'text_b', 'labels']]

test_df.columns = ['claim_label', 'text_a', 'text_b', 'id', 'labels']
test_df = dev_df[['text_a', 'text_b', 'labels']]

test_df['labels'].value_counts()

226    -1
1947    0
217     1
239    -1
433     1
       ..
163     1
95     -1
439     1
355     1
2013    0
Name: labels, Length: 725, dtype: int64

In [15]:
train_args = {
    'evaluate_during_training': True,
    'evaluate_during_training_verbose': True,
    'max_seq_length': 128,
    'num_train_epochs': 10,
    'train_batch_size': 16,
    'labels_list': [1, 0, -1],
    'use_multiprocessing': False,
    'use_multiprocessing_for_evaluation': False,
    'overwrite_output_dir': True,
    'evaluate_during_training_steps': 100000
}

model = ClassificationModel('bert', 'bert-base-cased', num_labels=3, args=train_args, use_cuda=True)

# Define metric
def clf_report(labels, preds):
    return classification_report(labels, preds, output_dict=True, labels = [1, 0, -1] ,target_names=['supporting', 'neutral', 'counter'])


model.train_model(train_df, eval_df=dev_df, clf_report=clf_report)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-cased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
INFO:simpletransformers.classification.classification_utils: Converting to features started. Cache is not used.
INFO:simpletransformers.classification.classification_utils: Saving features into cached file cache_dir/cached_train_bert_128_3_3


Epoch:   0%|          | 0/10 [00:00<?, ?it/s]

Running Epoch 0 of 10:   0%|          | 0/353 [00:00<?, ?it/s]

INFO:simpletransformers.classification.classification_utils: Converting to features started. Cache is not used.
INFO:simpletransformers.classification.classification_utils: Saving features into cached file cache_dir/cached_dev_bert_128_3_3
INFO:simpletransformers.classification.classification_model:{'mcc': 0.3098273119471111, 'clf_report': {'0.0': {'precision': 0.5252808988764045, 'recall': 0.722007722007722, 'f1-score': 0.608130081300813, 'support': 259.0}, '1.0': {'precision': 0.87, 'recall': 0.37339055793991416, 'f1-score': 0.5225225225225225, 'support': 233.0}, '2.0': {'precision': 0.42379182156133827, 'recall': 0.4892703862660944, 'f1-score': 0.4541832669322709, 'support': 233.0}, 'accuracy': 0.5351724137931034, 'macro avg': {'precision': 0.6063575734792476, 'recall': 0.5282228887379102, 'f1-score': 0.5282786235852022, 'support': 725.0}, 'weighted avg': {'precision': 0.6034499961831458, 'recall': 0.5351724137931034, 'f1-score': 0.5311422620687966, 'support': 725.0}}, 'eval_loss': 

Running Epoch 1 of 10:   0%|          | 0/353 [00:00<?, ?it/s]

INFO:simpletransformers.classification.classification_utils: Converting to features started. Cache is not used.
INFO:simpletransformers.classification.classification_utils: Saving features into cached file cache_dir/cached_dev_bert_128_3_3
INFO:simpletransformers.classification.classification_model:{'mcc': 0.42219189999587475, 'clf_report': {'0.0': {'precision': 0.5805471124620061, 'recall': 0.7374517374517374, 'f1-score': 0.6496598639455782, 'support': 259.0}, '1.0': {'precision': 0.7167630057803468, 'recall': 0.5321888412017167, 'f1-score': 0.6108374384236452, 'support': 233.0}, '2.0': {'precision': 0.5829596412556054, 'recall': 0.5579399141630901, 'f1-score': 0.5701754385964912, 'support': 233.0}, 'accuracy': 0.6137931034482759, 'macro avg': {'precision': 0.6267565864993193, 'recall': 0.6091934976055148, 'f1-score': 0.6102242469885716, 'support': 725.0}, 'weighted avg': {'precision': 0.625099419154533, 'recall': 0.6137931034482759, 'f1-score': 0.6116384898035815, 'support': 725.0}},

Running Epoch 2 of 10:   0%|          | 0/353 [00:00<?, ?it/s]

INFO:simpletransformers.classification.classification_utils: Converting to features started. Cache is not used.
INFO:simpletransformers.classification.classification_utils: Saving features into cached file cache_dir/cached_dev_bert_128_3_3
INFO:simpletransformers.classification.classification_model:{'mcc': 0.4875585841617746, 'clf_report': {'0.0': {'precision': 0.591715976331361, 'recall': 0.7722007722007722, 'f1-score': 0.6700167504187604, 'support': 259.0}, '1.0': {'precision': 0.8014184397163121, 'recall': 0.48497854077253216, 'f1-score': 0.6042780748663101, 'support': 233.0}, '2.0': {'precision': 0.6544715447154471, 'recall': 0.6909871244635193, 'f1-score': 0.6722338204592903, 'support': 233.0}, 'accuracy': 0.6537931034482759, 'macro avg': {'precision': 0.6825353202543734, 'recall': 0.6493888124789412, 'f1-score': 0.648842881914787, 'support': 725.0}, 'weighted avg': {'precision': 0.6792783506792033, 'recall': 0.6537931034482759, 'f1-score': 0.6496022206473432, 'support': 725.0}}, 

Running Epoch 3 of 10:   0%|          | 0/353 [00:00<?, ?it/s]

INFO:simpletransformers.classification.classification_utils: Converting to features started. Cache is not used.
INFO:simpletransformers.classification.classification_utils: Saving features into cached file cache_dir/cached_dev_bert_128_3_3
INFO:simpletransformers.classification.classification_model:{'mcc': 0.4987124090289529, 'clf_report': {'0.0': {'precision': 0.6875, 'recall': 0.806949806949807, 'f1-score': 0.7424511545293073, 'support': 259.0}, '1.0': {'precision': 0.8787878787878788, 'recall': 0.37339055793991416, 'f1-score': 0.5240963855421688, 'support': 233.0}, '2.0': {'precision': 0.5527950310559007, 'recall': 0.7639484978540773, 'f1-score': 0.6414414414414416, 'support': 233.0}, 'accuracy': 0.6537931034482759, 'macro avg': {'precision': 0.7063609699479265, 'recall': 0.6480962875812661, 'f1-score': 0.6359963271709725, 'support': 725.0}, 'weighted avg': {'precision': 0.7056845765428974, 'recall': 0.6537931034482759, 'f1-score': 0.6398140175314094, 'support': 725.0}}, 'eval_loss'

Running Epoch 4 of 10:   0%|          | 0/353 [00:00<?, ?it/s]

INFO:simpletransformers.classification.classification_utils: Converting to features started. Cache is not used.
INFO:simpletransformers.classification.classification_utils: Saving features into cached file cache_dir/cached_dev_bert_128_3_3
INFO:simpletransformers.classification.classification_model:{'mcc': 0.49085421677469643, 'clf_report': {'0.0': {'precision': 0.6011080332409973, 'recall': 0.8378378378378378, 'f1-score': 0.7, 'support': 259.0}, '1.0': {'precision': 0.875, 'recall': 0.45064377682403434, 'f1-score': 0.5949008498583569, 'support': 233.0}, '2.0': {'precision': 0.6188524590163934, 'recall': 0.648068669527897, 'f1-score': 0.6331236897274634, 'support': 233.0}, 'accuracy': 0.6524137931034483, 'macro avg': {'precision': 0.6983201640857969, 'recall': 0.6455167613965896, 'f1-score': 0.6426748465286067, 'support': 725.0}, 'weighted avg': {'precision': 0.6948339359451559, 'recall': 0.6524137931034483, 'f1-score': 0.6447306451358567, 'support': 725.0}}, 'eval_loss': 2.22639989820

Running Epoch 5 of 10:   0%|          | 0/353 [00:00<?, ?it/s]

INFO:simpletransformers.classification.classification_utils: Converting to features started. Cache is not used.
INFO:simpletransformers.classification.classification_utils: Saving features into cached file cache_dir/cached_dev_bert_128_3_3
INFO:simpletransformers.classification.classification_model:{'mcc': 0.5182434694531054, 'clf_report': {'0.0': {'precision': 0.688135593220339, 'recall': 0.7837837837837838, 'f1-score': 0.7328519855595668, 'support': 259.0}, '1.0': {'precision': 0.8372093023255814, 'recall': 0.463519313304721, 'f1-score': 0.5966850828729282, 'support': 233.0}, '2.0': {'precision': 0.584717607973422, 'recall': 0.7553648068669528, 'f1-score': 0.659176029962547, 'support': 233.0}, 'accuracy': 0.6717241379310345, 'macro avg': {'precision': 0.7033541678397809, 'recall': 0.6675559679851526, 'f1-score': 0.6629043661316807, 'support': 725.0}, 'weighted avg': {'precision': 0.7028083982672215, 'recall': 0.6717241379310345, 'f1-score': 0.6654128324835773, 'support': 725.0}}, 'ev

Running Epoch 6 of 10:   0%|          | 0/353 [00:00<?, ?it/s]

INFO:simpletransformers.classification.classification_utils: Converting to features started. Cache is not used.
INFO:simpletransformers.classification.classification_utils: Saving features into cached file cache_dir/cached_dev_bert_128_3_3
INFO:simpletransformers.classification.classification_model:{'mcc': 0.4776364370939345, 'clf_report': {'0.0': {'precision': 0.6111111111111112, 'recall': 0.8494208494208494, 'f1-score': 0.7108239095315023, 'support': 259.0}, '1.0': {'precision': 0.8636363636363636, 'recall': 0.40772532188841204, 'f1-score': 0.553935860058309, 'support': 233.0}, '2.0': {'precision': 0.592156862745098, 'recall': 0.648068669527897, 'f1-score': 0.6188524590163934, 'support': 233.0}, 'accuracy': 0.6427586206896552, 'macro avg': {'precision': 0.6889681124975243, 'recall': 0.6350716136123862, 'f1-score': 0.6278707428687348, 'support': 725.0}, 'weighted avg': {'precision': 0.6861759993443564, 'recall': 0.6427586206896552, 'f1-score': 0.6308456150525031, 'support': 725.0}}, '

Running Epoch 7 of 10:   0%|          | 0/353 [00:00<?, ?it/s]

INFO:simpletransformers.classification.classification_utils: Converting to features started. Cache is not used.
INFO:simpletransformers.classification.classification_utils: Saving features into cached file cache_dir/cached_dev_bert_128_3_3
INFO:simpletransformers.classification.classification_model:{'mcc': 0.5021316554570409, 'clf_report': {'0.0': {'precision': 0.6304985337243402, 'recall': 0.8301158301158301, 'f1-score': 0.7166666666666667, 'support': 259.0}, '1.0': {'precision': 0.8677685950413223, 'recall': 0.45064377682403434, 'f1-score': 0.5932203389830508, 'support': 233.0}, '2.0': {'precision': 0.6045627376425855, 'recall': 0.6824034334763949, 'f1-score': 0.6411290322580645, 'support': 233.0}, 'accuracy': 0.6606896551724138, 'macro avg': {'precision': 0.7009432888027494, 'recall': 0.654387680138753, 'f1-score': 0.6503386793025939, 'support': 725.0}, 'weighted avg': {'precision': 0.6984169941378685, 'recall': 0.6606896551724138, 'f1-score': 0.6527173381597884, 'support': 725.0}},

Running Epoch 8 of 10:   0%|          | 0/353 [00:00<?, ?it/s]

INFO:simpletransformers.classification.classification_utils: Converting to features started. Cache is not used.
INFO:simpletransformers.classification.classification_utils: Saving features into cached file cache_dir/cached_dev_bert_128_3_3
INFO:simpletransformers.classification.classification_model:{'mcc': 0.49789721475017157, 'clf_report': {'0.0': {'precision': 0.6426426426426426, 'recall': 0.8262548262548263, 'f1-score': 0.7229729729729729, 'support': 259.0}, '1.0': {'precision': 0.8738738738738738, 'recall': 0.41630901287553645, 'f1-score': 0.5639534883720929, 'support': 233.0}, '2.0': {'precision': 0.5871886120996441, 'recall': 0.7081545064377682, 'f1-score': 0.642023346303502, 'support': 233.0}, 'accuracy': 0.6565517241379311, 'macro avg': {'precision': 0.7012350428720535, 'recall': 0.6502394485227102, 'f1-score': 0.6429832692161893, 'support': 725.0}, 'weighted avg': {'precision': 0.6991337981741711, 'recall': 0.6565517241379311, 'f1-score': 0.6458518654888463, 'support': 725.0}}

Running Epoch 9 of 10:   0%|          | 0/353 [00:00<?, ?it/s]

INFO:simpletransformers.classification.classification_utils: Converting to features started. Cache is not used.
INFO:simpletransformers.classification.classification_utils: Saving features into cached file cache_dir/cached_dev_bert_128_3_3
INFO:simpletransformers.classification.classification_model:{'mcc': 0.4869243614691068, 'clf_report': {'0.0': {'precision': 0.6184971098265896, 'recall': 0.8262548262548263, 'f1-score': 0.7074380165289257, 'support': 259.0}, '1.0': {'precision': 0.8761061946902655, 'recall': 0.4248927038626609, 'f1-score': 0.5722543352601156, 'support': 233.0}, '2.0': {'precision': 0.5939849624060151, 'recall': 0.6781115879828327, 'f1-score': 0.6332665330661323, 'support': 233.0}, 'accuracy': 0.6496551724137931, 'macro avg': {'precision': 0.6961960889742901, 'recall': 0.6430863727001066, 'f1-score': 0.6376529616183912, 'support': 725.0}, 'weighted avg': {'precision': 0.693409642825545, 'recall': 0.6496551724137931, 'f1-score': 0.6401555980703552, 'support': 725.0}}, 

(3530,
 defaultdict(list,
             {'global_step': [353,
               706,
               1059,
               1412,
               1765,
               2118,
               2471,
               2824,
               3177,
               3530],
              'train_loss': [0.5000143051147461,
               0.4420318603515625,
               0.501922070980072,
               0.056644558906555176,
               0.0009263952379114926,
               0.0024284522514790297,
               0.00035910806036554277,
               0.0001368423254461959,
               0.00012917320418637246,
               0.00024039547133725137],
              'mcc': [0.3098273119471111,
               0.42219189999587475,
               0.4875585841617746,
               0.4987124090289529,
               0.49085421677469643,
               0.5182434694531054,
               0.4776364370939345,
               0.5021316554570409,
               0.49789721475017157,
               0.4869243614691068],
  