In [1]:
import sys
sys.path.insert(0, "../../src")
from pathlib import Path
from collections import Counter

import numpy as np
from scipy.special import softmax

from gen.util import read_data, write_jsonl
from rte.aggregate import generate_micro_macro_df

In [2]:
root_data = Path("../../data").resolve()
root_model = Path("../../models").resolve()

In [3]:
# constants
LOOKUP = {
    "verifiable": {"no": "NOT VERIFIABLE", "yes": "VERIFIABLE"},
    "label": {"nei": "NOT ENOUGH INFO", "r": "REFUTES", "s": "SUPPORTS"}
}

SEED = 123456789

LABEL2ID = {"SUPPORTS": 0, "NOT ENOUGH INFO": 1, "REFUTES": 2}
ID2LABEL = {0: "SUPPORTS", 1: "NOT ENOUGH INFO", 2: "REFUTES"}

# Init

In [4]:
import evaluate
from datasets import Dataset, DatasetDict
from transformers import (
    AutoTokenizer, 
    AutoModelForSequenceClassification,
    TrainingArguments, 
    Trainer,
    DataCollatorWithPadding,
    TextClassificationPipeline,
    pipeline
)

import torch
torch.backends.cuda.matmul.allow_tf32 = True

  from .autonotebook import tqdm as notebook_tqdm


# Huggingface Init

## Model

In [5]:
accuracy_metric = evaluate.load("accuracy")
recall_metric = evaluate.load("recall")
precision_metric = evaluate.load("precision")
f1_metric = evaluate.load("f1")

In [6]:
model_checkpoint = "xlnet-base-cased"
def model_init():
    return AutoModelForSequenceClassification.from_pretrained(
        model_checkpoint, 
        num_labels=3, 
        id2label=ID2LABEL, 
        label2id=LABEL2ID
    )

model = model_init()
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, do_lower_case=True)
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

def preprocess(examples):
    return tokenizer(examples["evidence"], examples["claim"], max_length=1024, truncation="only_first")

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    
    results = {}
    results.update(accuracy_metric.compute(predictions=predictions, references=labels))
    results.update(recall_metric.compute(predictions=predictions, references=labels, average="macro"))
    results.update(precision_metric.compute(predictions=predictions, references=labels, average="macro"))
    results.update(f1_metric.compute(predictions=predictions, references=labels, average="macro"))
    
    return results

Some weights of the model checkpoint at xlnet-base-cased were not used when initializing XLNetForSequenceClassification: ['lm_loss.weight', 'lm_loss.bias']
- This IS expected if you are initializing XLNetForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing XLNetForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of XLNetForSequenceClassification were not initialized from the model checkpoint at xlnet-base-cased and are newly initialized: ['logits_proj.weight', 'sequence_summary.summary.weight', 'sequence_summary.summary.bias', 'logits_proj.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions a

In [7]:
dataset = ["fever", "climatefeverpure", "fever-climatefeverpure", "climatefever", "fever-climatefever"]
doc_sent = ["doc", "sent"]

di = 0
ds = 0

model_store_path = root_model.joinpath("sentence-models" if ds == 1 else "document-models")
model_store_path.mkdir(exist_ok=True)
model_store_path = model_store_path / f"{dataset[di]}-{model_checkpoint}-{doc_sent[ds]}"

## Dataset

In [8]:
datap = root_data / f"{doc_sent[ds]}-dataset"

data = DatasetDict({
    "train": Dataset.from_list(read_data(datap / f"{dataset[di]}.train.n5.jsonl")),
    "validation": Dataset.from_list(read_data(datap / f"{dataset[di]}.dev.n5.jsonl")),
    "test": Dataset.from_list(read_data(datap / f"{dataset[di]}.test.n5.jsonl"))
}).map(preprocess, batched=True)

                                                                     

## Trainer

In [8]:
# effective batch size of 32
per_device_train_batch_size = 4
gradient_accumulation_steps = 8
per_device_eval_batch_size = 32


learning_rate = 4e-4
epoch = 4
metric_name = "f1"
warmup_ratio=0.1
save_steps=200
eval_steps=200

# Hyperparameter tuning

In [9]:
# shard the data if the dataset is large for hyperparameter tuning
shard = data["train"].num_rows > 50000
hp_tune_train = data["train"].shuffle(seed=SEED).shard(num_shards=5, index=1)

In [10]:
training_args = TrainingArguments(
    model_store_path,
    overwrite_output_dir=True,
    evaluation_strategy = "steps",
    eval_steps=eval_steps,
    save_strategy = "no",
    learning_rate=learning_rate,
    per_device_train_batch_size=per_device_train_batch_size,
    gradient_accumulation_steps=gradient_accumulation_steps,
    per_device_eval_batch_size=per_device_eval_batch_size,
    num_train_epochs=epoch,
    weight_decay=0.01,
    metric_for_best_model=metric_name,
    push_to_hub=False,
    seed=SEED,
    data_seed=SEED,
    warmup_ratio=warmup_ratio,
    report_to="tensorboard",
    tf32=True
)

_ = model.train()
trainer = Trainer(
    model_init=model_init,
    args=training_args,
    train_dataset=hp_tune_train if shard else data["train"],
    eval_dataset=data["validation"],
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics
)

Some weights of the model checkpoint at xlnet-base-cased were not used when initializing XLNetForSequenceClassification: ['lm_loss.weight', 'lm_loss.bias']
- This IS expected if you are initializing XLNetForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing XLNetForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of XLNetForSequenceClassification were not initialized from the model checkpoint at xlnet-base-cased and are newly initialized: ['logits_proj.weight', 'sequence_summary.summary.weight', 'logits_proj.bias', 'sequence_summary.summary.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions a

In [11]:
def optuna_hp_space(trial):
    return {
        "learning_rate": trial.suggest_categorical("learning_rate", [1e-5, 3e-5, 2e-5]),
        "num_train_epochs": trial.suggest_int("num_train_epochs", 1, 5),
    }

def compute_objective(metrics):
    return metrics["f1"]

In [12]:
best_run = trainer.hyperparameter_search(
    direction="maximize", 
    n_trials=10, 
    hp_space=optuna_hp_space
)

[I 2023-07-04 13:11:59,323] A new study created in memory with name: no-name-64ecacb2-32d0-4e4b-9979-492bc4e80cc0
Some weights of the model checkpoint at xlnet-base-cased were not used when initializing XLNetForSequenceClassification: ['lm_loss.weight', 'lm_loss.bias']
- This IS expected if you are initializing XLNetForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing XLNetForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of XLNetForSequenceClassification were not initialized from the model checkpoint at xlnet-base-cased and are newly initialized: ['logits_proj.weight', 'sequence_summary.summary.weight', 'logits_proj.bias', 'sequence_summar

Step,Training Loss,Validation Loss,Accuracy,Recall,Precision,F1
200,No log,0.327203,0.860186,0.860186,0.861096,0.860057
400,No log,0.264002,0.910291,0.910291,0.917895,0.90979
600,0.429200,0.216248,0.928693,0.928693,0.92849,0.928586
800,0.429200,0.211402,0.928893,0.928893,0.930698,0.928863
1000,0.207100,0.277709,0.928193,0.928193,0.932454,0.928417
1200,0.207100,0.227207,0.934193,0.934193,0.936673,0.934151
1400,0.207100,0.231121,0.929893,0.929893,0.935631,0.929655
1600,0.162400,0.212054,0.935694,0.935694,0.938987,0.935658
1800,0.162400,0.219816,0.933293,0.933293,0.936543,0.93345
2000,0.136900,0.242423,0.942094,0.942094,0.94341,0.942127


[I 2023-07-04 13:38:42,117] Trial 0 finished with value: 3.773920629657178 and parameters: {'learning_rate': 2e-05, 'num_train_epochs': 3}. Best is trial 0 with value: 3.773920629657178.
Some weights of the model checkpoint at xlnet-base-cased were not used when initializing XLNetForSequenceClassification: ['lm_loss.weight', 'lm_loss.bias']
- This IS expected if you are initializing XLNetForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing XLNetForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of XLNetForSequenceClassification were not initialized from the model checkpoint at xlnet-base-cased and are newly initialized: ['logits_proj.weight'

Step,Training Loss,Validation Loss,Accuracy,Recall,Precision,F1
200,No log,0.294838,0.877988,0.877988,0.876522,0.876881
400,No log,0.244326,0.914291,0.914291,0.91651,0.914301
600,0.399100,0.216877,0.925393,0.925393,0.926572,0.925342
800,0.399100,0.208642,0.928793,0.928793,0.929972,0.928828


[I 2023-07-04 13:47:22,731] Trial 1 finished with value: 3.716385750791969 and parameters: {'learning_rate': 1e-05, 'num_train_epochs': 1}. Best is trial 0 with value: 3.773920629657178.
Some weights of the model checkpoint at xlnet-base-cased were not used when initializing XLNetForSequenceClassification: ['lm_loss.weight', 'lm_loss.bias']
- This IS expected if you are initializing XLNetForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing XLNetForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of XLNetForSequenceClassification were not initialized from the model checkpoint at xlnet-base-cased and are newly initialized: ['logits_proj.weight'

Step,Training Loss,Validation Loss,Accuracy,Recall,Precision,F1
200,No log,0.647677,0.651365,0.651365,0.490384,0.542775
400,No log,0.287442,0.891089,0.891089,0.898472,0.890323
600,0.549100,0.227832,0.919792,0.919792,0.92168,0.919731
800,0.549100,0.222048,0.925993,0.925993,0.928381,0.925861
1000,0.225100,0.248289,0.930593,0.930593,0.931507,0.930479
1200,0.225100,0.202605,0.934193,0.934193,0.934863,0.934172
1400,0.225100,0.196666,0.934093,0.934093,0.936395,0.933885
1600,0.183100,0.222643,0.932793,0.932793,0.936958,0.93247
1800,0.183100,0.190744,0.936394,0.936394,0.939722,0.936275
2000,0.155600,0.230196,0.938594,0.938594,0.940186,0.938668


  _warn_prf(average, modifier, msg_start, len(result))
[I 2023-07-04 14:31:44,012] Trial 2 finished with value: 3.7692215900331143 and parameters: {'learning_rate': 1e-05, 'num_train_epochs': 5}. Best is trial 0 with value: 3.773920629657178.
Some weights of the model checkpoint at xlnet-base-cased were not used when initializing XLNetForSequenceClassification: ['lm_loss.weight', 'lm_loss.bias']
- This IS expected if you are initializing XLNetForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing XLNetForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of XLNetForSequenceClassification were not initialized from the model checkpoint at xlnet-bas

Step,Training Loss,Validation Loss,Accuracy,Recall,Precision,F1
200,No log,0.61277,0.654465,0.654465,0.490371,0.544762
400,No log,0.266906,0.894589,0.894589,0.901333,0.894318
600,0.518400,0.21627,0.923792,0.923792,0.924839,0.923751
800,0.518400,0.233802,0.922692,0.922692,0.926638,0.922399
1000,0.216300,0.251286,0.927793,0.927793,0.928314,0.927651
1200,0.216300,0.220513,0.931993,0.931993,0.933338,0.932013
1400,0.216300,0.214051,0.930293,0.930293,0.93375,0.930007
1600,0.182200,0.234253,0.927493,0.927493,0.932438,0.927329
1800,0.182200,0.203328,0.935894,0.935894,0.939039,0.935873
2000,0.155700,0.223856,0.937894,0.937894,0.93864,0.938048


  _warn_prf(average, modifier, msg_start, len(result))
[I 2023-07-04 15:07:32,702] Trial 3 finished with value: 3.7723604692706596 and parameters: {'learning_rate': 1e-05, 'num_train_epochs': 4}. Best is trial 0 with value: 3.773920629657178.
Some weights of the model checkpoint at xlnet-base-cased were not used when initializing XLNetForSequenceClassification: ['lm_loss.weight', 'lm_loss.bias']
- This IS expected if you are initializing XLNetForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing XLNetForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of XLNetForSequenceClassification were not initialized from the model checkpoint at xlnet-bas

Step,Training Loss,Validation Loss,Accuracy,Recall,Precision,F1
200,No log,0.294838,0.877988,0.877988,0.876522,0.876881
400,No log,0.244326,0.914291,0.914291,0.91651,0.914301
600,0.399100,0.216877,0.925393,0.925393,0.926572,0.925342
800,0.399100,0.208642,0.928793,0.928793,0.929972,0.928828


[I 2023-07-04 15:16:12,364] Trial 4 finished with value: 3.716385750791969 and parameters: {'learning_rate': 1e-05, 'num_train_epochs': 1}. Best is trial 0 with value: 3.773920629657178.
Some weights of the model checkpoint at xlnet-base-cased were not used when initializing XLNetForSequenceClassification: ['lm_loss.weight', 'lm_loss.bias']
- This IS expected if you are initializing XLNetForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing XLNetForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of XLNetForSequenceClassification were not initialized from the model checkpoint at xlnet-base-cased and are newly initialized: ['logits_proj.weight'

Step,Training Loss,Validation Loss,Accuracy,Recall,Precision,F1
200,No log,0.248323,0.909991,0.909991,0.9099,0.90976
400,No log,0.246912,0.922392,0.922392,0.926015,0.922162
600,0.345100,0.206967,0.928193,0.928193,0.932654,0.928079
800,0.345100,0.195123,0.937394,0.937394,0.938856,0.937429


[I 2023-07-04 15:24:52,664] Trial 5 finished with value: 3.751072853025028 and parameters: {'learning_rate': 3e-05, 'num_train_epochs': 1}. Best is trial 0 with value: 3.773920629657178.
Some weights of the model checkpoint at xlnet-base-cased were not used when initializing XLNetForSequenceClassification: ['lm_loss.weight', 'lm_loss.bias']
- This IS expected if you are initializing XLNetForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing XLNetForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of XLNetForSequenceClassification were not initialized from the model checkpoint at xlnet-base-cased and are newly initialized: ['logits_proj.weight'

Step,Training Loss,Validation Loss,Accuracy,Recall,Precision,F1
200,No log,0.340842,0.844984,0.844984,0.84624,0.844499


[I 2023-07-04 15:26:51,689] Trial 6 pruned. 
Some weights of the model checkpoint at xlnet-base-cased were not used when initializing XLNetForSequenceClassification: ['lm_loss.weight', 'lm_loss.bias']
- This IS expected if you are initializing XLNetForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing XLNetForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of XLNetForSequenceClassification were not initialized from the model checkpoint at xlnet-base-cased and are newly initialized: ['logits_proj.weight', 'sequence_summary.summary.weight', 'logits_proj.bias', 'sequence_summary.summary.bias']
You should probably TRAIN this model on a down-strea

Step,Training Loss,Validation Loss,Accuracy,Recall,Precision,F1
200,No log,0.392673,0.816982,0.816982,0.832459,0.812412


[I 2023-07-04 15:28:50,571] Trial 7 pruned. 
Some weights of the model checkpoint at xlnet-base-cased were not used when initializing XLNetForSequenceClassification: ['lm_loss.weight', 'lm_loss.bias']
- This IS expected if you are initializing XLNetForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing XLNetForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of XLNetForSequenceClassification were not initialized from the model checkpoint at xlnet-base-cased and are newly initialized: ['logits_proj.weight', 'sequence_summary.summary.weight', 'logits_proj.bias', 'sequence_summary.summary.bias']
You should probably TRAIN this model on a down-strea

Step,Training Loss,Validation Loss,Accuracy,Recall,Precision,F1
200,No log,0.398272,0.811881,0.811881,0.833953,0.805728


[I 2023-07-04 15:30:49,223] Trial 8 pruned. 
Some weights of the model checkpoint at xlnet-base-cased were not used when initializing XLNetForSequenceClassification: ['lm_loss.weight', 'lm_loss.bias']
- This IS expected if you are initializing XLNetForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing XLNetForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of XLNetForSequenceClassification were not initialized from the model checkpoint at xlnet-base-cased and are newly initialized: ['logits_proj.weight', 'sequence_summary.summary.weight', 'logits_proj.bias', 'sequence_summary.summary.bias']
You should probably TRAIN this model on a down-strea

Step,Training Loss,Validation Loss,Accuracy,Recall,Precision,F1
200,No log,0.300135,0.876988,0.876988,0.877071,0.876767
400,No log,0.266715,0.908391,0.908391,0.915148,0.908128


[I 2023-07-04 15:34:46,550] Trial 9 pruned. 


In [13]:
best_run

BestRun(run_id='0', objective=3.773920629657178, hyperparameters={'learning_rate': 2e-05, 'num_train_epochs': 3}, run_summary=None)

## Train with best hyperparameters

In [9]:
training_args = TrainingArguments(
    model_store_path,
    overwrite_output_dir=True,
    evaluation_strategy = "steps",
    eval_steps=1000,
    save_strategy="steps",
    save_steps=1000,
    save_total_limit=5,
    learning_rate=learning_rate,
    per_device_train_batch_size=per_device_train_batch_size,
    gradient_accumulation_steps=gradient_accumulation_steps,
    per_device_eval_batch_size=per_device_eval_batch_size,
    num_train_epochs=epoch,
    weight_decay=0.01,
    load_best_model_at_end=True,
    metric_for_best_model=metric_name,
    push_to_hub=False,
    seed=SEED,
    data_seed=SEED,
    warmup_ratio=warmup_ratio,
    report_to="tensorboard",
    tf32=True
)

trainer = Trainer(
    model_init=model_init,
    args=training_args,
    train_dataset=data["train"],
    eval_dataset=data["validation"],
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics
)

for n, v in best_run.hyperparameters.items():
    setattr(trainer.args, n, v)
    
trainer.train()

Some weights of the model checkpoint at xlnet-base-cased were not used when initializing XLNetForSequenceClassification: ['lm_loss.bias', 'lm_loss.weight']
- This IS expected if you are initializing XLNetForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing XLNetForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of XLNetForSequenceClassification were not initialized from the model checkpoint at xlnet-base-cased and are newly initialized: ['sequence_summary.summary.bias', 'logits_proj.weight', 'sequence_summary.summary.weight', 'logits_proj.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions a

Step,Training Loss,Validation Loss,Accuracy,Recall,Precision,F1
1000,0.2137,0.20707,0.928893,0.928893,0.932075,0.928513
2000,0.1732,0.171401,0.950295,0.950295,0.950874,0.95024
3000,0.1604,0.150962,0.953895,0.953895,0.954582,0.953922
4000,0.1439,0.201096,0.945195,0.945195,0.948588,0.945032
5000,0.1188,0.161349,0.955196,0.955196,0.956071,0.955226
6000,0.1091,0.180735,0.951595,0.951595,0.953588,0.95163
7000,0.1137,0.131233,0.958296,0.958296,0.958895,0.958337
8000,0.1091,0.155056,0.956696,0.956696,0.958075,0.956721
9000,0.1051,0.138161,0.960396,0.960396,0.961059,0.960428
10000,0.0809,0.183378,0.955496,0.955496,0.956557,0.955589


IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)



In [10]:
trainer.save_model(model_store_path.parent / (model_store_path.stem + ".out"))

# Evaluate

In [11]:
import pandas as pd
from sklearn.metrics import classification_report

## Test on validation data

In [12]:
preds = trainer.predict(data["validation"])
val = generate_doc_df(data["validation"], preds)

In [13]:
print(classification_report(y_true=val["actual"], y_pred=val["predicted"]))

                 precision    recall  f1-score   support

NOT ENOUGH INFO       1.00      1.00      1.00      3333
        REFUTES       0.96      0.93      0.94      3333
       SUPPORTS       0.93      0.97      0.95      3333

       accuracy                           0.96      9999
      macro avg       0.96      0.96      0.96      9999
   weighted avg       0.96      0.96      0.96      9999



## Test on test data

In [14]:
preds = trainer.predict(data["test"])
tes = generate_doc_df(data["test"], preds)

In [16]:
print(classification_report(y_true=tes["actual"], y_pred=tes["predicted"]))

                 precision    recall  f1-score   support

NOT ENOUGH INFO       1.00      0.99      1.00      3333
        REFUTES       0.96      0.90      0.93      3333
       SUPPORTS       0.90      0.96      0.93      3333

       accuracy                           0.95      9999
      macro avg       0.95      0.95      0.95      9999
   weighted avg       0.95      0.95      0.95      9999

