In [1]:
# This is the training script for fine-tuning bert on 
# unaltered GPT data and manually labelled data
# For better performance/generalization, look for augmented dataset
# Read README.md for comments and details.

In [2]:

# all classes

classes = ["banking","valuation","household","real estate","corporate","external","sovereign","technology", "climate", "energy", "health", "eu"]


In [3]:
from transformers import AutoModelForSequenceClassification, AutoTokenizer, Trainer, TrainingArguments
import torch
import numpy as np
from sklearn.metrics import accuracy_score, precision_recall_fscore_support,top_k_accuracy_score
import math
import pickle
from datasets import Dataset

In [4]:
# load bert-based and finbert
model = AutoModelForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=len(classes))
finbert = AutoModelForSequenceClassification.from_pretrained('ProsusAI/finbert')
tokenizer = AutoTokenizer.from_pretrained('ProsusAI/finbert', use_fast =True)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at

In [5]:
# weights transfer for encoder layers only 
finbert_weights = finbert.state_dict()
model_weights = model.state_dict()
del finbert_weights["bert.pooler.dense.weight"]
del finbert_weights["bert.pooler.dense.bias"]
del finbert_weights["classifier.weight"]
del finbert_weights["classifier.bias"]
finbert_weights["bert.pooler.dense.weight"] = model_weights["bert.pooler.dense.weight"]
finbert_weights["bert.pooler.dense.bias"] = model_weights["bert.pooler.dense.bias"]
finbert_weights["classifier.weight"] = model_weights["classifier.weight"]
finbert_weights["classifier.bias"] = model_weights["classifier.bias"]

model.load_state_dict(finbert_weights)


<All keys matched successfully>

In [6]:
# flatten to one list for all 3

# manual labelled
with open('train_data.pickle', 'rb') as file:
    train = pickle.load(file)

# gpt labelled p1
with open('gpt.pickle', 'rb') as file:
    gpt = pickle.load(file)

# gpt labelled p2
with open('gpt_p2.pickle', 'rb') as file:
    gpt2 = pickle.load(file)
gpt = [item for sublist in gpt for item in sublist]
gpt2 = [item for sublist in gpt2 for item in sublist]

mixed = gpt + gpt2

In [7]:
print(len(mixed))

2458


In [8]:
sample = 1

additional_text = []
additional_label = []

for idx in range(1,len(mixed)-1):
    sent1, sent2 = mixed[idx-1:idx+1]
    for i in range(1):
        additional_text.append(sent1["text"]+' '+ sent2["text"])
        dist = (np.array(sent1["dist"]) + np.array(sent2["dist"]))/2
        additional_label.append(np.argmax(dist))

for idx in range(2,len(mixed)-1):
    sent1, sent2, sent3 = mixed[idx-2:idx+1]
    for i in range(1):
        additional_text.append(sent1["text"]+' '+ sent2["text"]+' '+sent3["text"])
        dist = (np.array(sent1["dist"]) + np.array(sent2["dist"])+ np.array(sent3["dist"]))/3
        additional_label.append(np.argmax(dist))

text_max = [item["text"] for i in range(sample) for item in train]
label_max = [np.argmax(item["dist"]) for i in range(sample) for item in train]

text_max_mixed = [item["text"] for i in range(sample) for item in mixed]
label_max_mixed = [np.argmax(item["dist"]) for i in range(sample) for item in mixed]


In [9]:
import random
def randomize(text, label):
    temp = list(zip(text, label))
    random.shuffle(temp)
    comb_text, comb_label =  zip(*temp)
    return comb_text, comb_label

a1 = randomize(text_max, label_max)
a2 = randomize(text_max_mixed, label_max_mixed)
a3 = randomize(additional_text, additional_label)

test_text, train_text =  [*a1[0][math.ceil(len(a1[0])*0.8):],*a2[0][math.ceil(len(a2[0])*0.8):],*a3[0][math.ceil(len(a3[0])*0.8):] ], [*a1[0][:math.ceil(len(a1[0])*0.8)],*a2[0][:math.ceil(len(a2[0])*0.8)],*a3[0][:math.ceil(len(a3[0])*0.8)] ]
test_label, train_label = [*a1[1][math.ceil(len(a1[1])*0.8):],*a2[1][math.ceil(len(a2[1])*0.8):],*a3[1][math.ceil(len(a3[1])*0.8):] ],  [*a1[1][:math.ceil(len(a1[1])*0.8)],*a2[1][:math.ceil(len(a2[1])*0.8)],*a3[1][:math.ceil(len(a3[1])*0.8)] ]


In [10]:
print(len(train_text))
print(len(test_text))

6105
1525


In [11]:
assert len(train_text) == len(train_label)
assert len(test_text) == len(test_label)

In [12]:
train_dataset = Dataset.from_dict({"text":train_text, "label":train_label})
test_dataset = Dataset.from_dict({"text":test_text, "label":test_label})

In [13]:
def tokenize(batch):
    return tokenizer(batch['text'], padding=True, truncation=True)
train_dataset = train_dataset.map(tokenize, batched=True, batch_size=len(train_dataset))
test_dataset = test_dataset.map(tokenize, batched=True, batch_size=len(test_dataset))
train_dataset.set_format('torch', columns=['input_ids', 'attention_mask', 'label'])
test_dataset.set_format('torch', columns=['input_ids', 'attention_mask', 'label'])

100%|██████████| 1/1 [00:00<00:00,  1.71ba/s]
100%|██████████| 1/1 [00:00<00:00,  8.85ba/s]


In [14]:
train_dataset["input_ids"]

tensor([[  101,  2152,  7016,  ...,     0,     0,     0],
        [  101,  1996,  2204,  ...,     0,     0,     0],
        [  101,  2000,  9585,  ...,     0,     0,     0],
        ...,
        [  101,  2339,  2106,  ...,     0,     0,     0],
        [  101,  1998,  2023,  ...,     0,     0,     0],
        [  101,  2005, 12194,  ...,     0,     0,     0]])

In [15]:
def compute_metrics(pred):
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)
    precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average='macro')
    acc = accuracy_score(labels, preds)
    top3 = top_k_accuracy_score(labels, pred.predictions,k=3)
    top2 = top_k_accuracy_score(labels, pred.predictions,k=2)
    return {
        'accuracy': acc,
        'f1': f1,
        'precision': precision,
        'recall': recall,
        'top3': top3,
        'top2': top2
    }

training_args = TrainingArguments(
    
    output_dir='./results',
    learning_rate=2e-5,
    num_train_epochs=20,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=64,
    warmup_steps=500,
    weight_decay=0.01,
    metric_for_best_model="accuracy",
    evaluation_strategy='epoch',
    save_strategy = "epoch",
    logging_dir='./logs',
    save_total_limit = 1, # Only last 5 models are saved. Older ones are deleted.
    load_best_model_at_end=True,
)
   
trainer = Trainer(
    model=model,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=train_dataset,
    eval_dataset=test_dataset
)

In [16]:
trainer.train()

The following columns in the training set  don't have a corresponding argument in `BertForSequenceClassification.forward` and have been ignored: text.
***** Running training *****
  Num examples = 6105
  Num Epochs = 20
  Instantaneous batch size per device = 16
  Total train batch size (w. parallel, distributed & accumulation) = 16
  Gradient Accumulation steps = 1
  Total optimization steps = 7640
  5%|▌         | 382/7640 [01:00<18:29,  6.54it/s]The following columns in the evaluation set  don't have a corresponding argument in `BertForSequenceClassification.forward` and have been ignored: text.
***** Running Evaluation *****
  Num examples = 1525
  Batch size = 64
  _warn_prf(average, modifier, msg_start, len(result))

  5%|▌         | 382/7640 [01:04<18:29,  6.54it/s]Saving model checkpoint to ./results\checkpoint-382
Configuration saved in ./results\checkpoint-382\config.json


{'eval_loss': 1.3231501579284668, 'eval_accuracy': 0.6216393442622951, 'eval_f1': 0.2641283468775932, 'eval_precision': 0.24853064605410116, 'eval_recall': 0.2845097464096975, 'eval_top3': 0.838688524590164, 'eval_top2': 0.7704918032786885, 'eval_runtime': 3.7587, 'eval_samples_per_second': 405.724, 'eval_steps_per_second': 6.385, 'epoch': 1.0}


Model weights saved in ./results\checkpoint-382\pytorch_model.bin
Deleting older checkpoint [results\checkpoint-1] due to args.save_total_limit
  7%|▋         | 501/7640 [01:27<17:59,  6.61it/s]

{'loss': 1.8285, 'learning_rate': 2e-05, 'epoch': 1.31}


 10%|█         | 764/7640 [02:09<15:31,  7.38it/s]The following columns in the evaluation set  don't have a corresponding argument in `BertForSequenceClassification.forward` and have been ignored: text.
***** Running Evaluation *****
  Num examples = 1525
  Batch size = 64
  _warn_prf(average, modifier, msg_start, len(result))

 10%|█         | 764/7640 [02:12<15:31,  7.38it/s]Saving model checkpoint to ./results\checkpoint-764
Configuration saved in ./results\checkpoint-764\config.json


{'eval_loss': 0.9345628023147583, 'eval_accuracy': 0.7134426229508197, 'eval_f1': 0.48283670571727316, 'eval_precision': 0.5513118992814193, 'eval_recall': 0.47658011177435505, 'eval_top3': 0.9121311475409836, 'eval_top2': 0.8544262295081967, 'eval_runtime': 3.5088, 'eval_samples_per_second': 434.619, 'eval_steps_per_second': 6.84, 'epoch': 2.0}


Model weights saved in ./results\checkpoint-764\pytorch_model.bin
Deleting older checkpoint [results\checkpoint-2] due to args.save_total_limit
 13%|█▎        | 1000/7640 [03:00<22:27,  4.93it/s]

{'loss': 0.8524, 'learning_rate': 1.8599439775910366e-05, 'epoch': 2.62}


 15%|█▌        | 1146/7640 [03:30<19:51,  5.45it/s]The following columns in the evaluation set  don't have a corresponding argument in `BertForSequenceClassification.forward` and have been ignored: text.
***** Running Evaluation *****
  Num examples = 1525
  Batch size = 64
  _warn_prf(average, modifier, msg_start, len(result))

 15%|█▌        | 1146/7640 [03:34<19:51,  5.45it/s]Saving model checkpoint to ./results\checkpoint-1146
Configuration saved in ./results\checkpoint-1146\config.json


{'eval_loss': 0.8317504525184631, 'eval_accuracy': 0.7586885245901639, 'eval_f1': 0.5665392120509521, 'eval_precision': 0.65642689537675, 'eval_recall': 0.5560937077533173, 'eval_top3': 0.921311475409836, 'eval_top2': 0.8767213114754099, 'eval_runtime': 3.9623, 'eval_samples_per_second': 384.875, 'eval_steps_per_second': 6.057, 'epoch': 3.0}


Model weights saved in ./results\checkpoint-1146\pytorch_model.bin
Deleting older checkpoint [results\checkpoint-382] due to args.save_total_limit
 20%|█▉        | 1501/7640 [04:47<15:20,  6.67it/s]

{'loss': 0.4911, 'learning_rate': 1.719887955182073e-05, 'epoch': 3.93}


 20%|██        | 1528/7640 [04:51<13:34,  7.50it/s]The following columns in the evaluation set  don't have a corresponding argument in `BertForSequenceClassification.forward` and have been ignored: text.
***** Running Evaluation *****
  Num examples = 1525
  Batch size = 64
  _warn_prf(average, modifier, msg_start, len(result))

 20%|██        | 1528/7640 [04:54<13:34,  7.50it/s]Saving model checkpoint to ./results\checkpoint-1528
Configuration saved in ./results\checkpoint-1528\config.json


{'eval_loss': 0.816765308380127, 'eval_accuracy': 0.7816393442622951, 'eval_f1': 0.6126348897134983, 'eval_precision': 0.699358173506377, 'eval_recall': 0.5876805128153205, 'eval_top3': 0.9324590163934426, 'eval_top2': 0.8950819672131147, 'eval_runtime': 3.1849, 'eval_samples_per_second': 478.817, 'eval_steps_per_second': 7.535, 'epoch': 4.0}


Model weights saved in ./results\checkpoint-1528\pytorch_model.bin
Deleting older checkpoint [results\checkpoint-764] due to args.save_total_limit
 25%|██▌       | 1910/7640 [05:55<15:24,  6.20it/s]The following columns in the evaluation set  don't have a corresponding argument in `BertForSequenceClassification.forward` and have been ignored: text.
***** Running Evaluation *****
  Num examples = 1525
  Batch size = 64

 25%|██▌       | 1910/7640 [05:59<15:24,  6.20it/s]Saving model checkpoint to ./results\checkpoint-1910
Configuration saved in ./results\checkpoint-1910\config.json


{'eval_loss': 0.9040604829788208, 'eval_accuracy': 0.779016393442623, 'eval_f1': 0.6330762323191418, 'eval_precision': 0.7638509118154689, 'eval_recall': 0.6076949595443342, 'eval_top3': 0.9331147540983606, 'eval_top2': 0.8918032786885246, 'eval_runtime': 3.5426, 'eval_samples_per_second': 430.47, 'eval_steps_per_second': 6.775, 'epoch': 5.0}


Model weights saved in ./results\checkpoint-1910\pytorch_model.bin
Deleting older checkpoint [results\checkpoint-1146] due to args.save_total_limit
 26%|██▌       | 2001/7640 [06:15<14:38,  6.42it/s]

{'loss': 0.2686, 'learning_rate': 1.5798319327731094e-05, 'epoch': 5.24}


 30%|███       | 2292/7640 [06:59<13:08,  6.79it/s]The following columns in the evaluation set  don't have a corresponding argument in `BertForSequenceClassification.forward` and have been ignored: text.
***** Running Evaluation *****
  Num examples = 1525
  Batch size = 64

 30%|███       | 2292/7640 [07:02<13:08,  6.79it/s]Saving model checkpoint to ./results\checkpoint-2292
Configuration saved in ./results\checkpoint-2292\config.json


{'eval_loss': 0.9727253913879395, 'eval_accuracy': 0.7901639344262295, 'eval_f1': 0.6586085765092596, 'eval_precision': 0.7097317461015495, 'eval_recall': 0.63810521327575, 'eval_top3': 0.9291803278688524, 'eval_top2': 0.8918032786885246, 'eval_runtime': 3.1727, 'eval_samples_per_second': 480.66, 'eval_steps_per_second': 7.564, 'epoch': 6.0}


Model weights saved in ./results\checkpoint-2292\pytorch_model.bin
Deleting older checkpoint [results\checkpoint-1528] due to args.save_total_limit
 33%|███▎      | 2501/7640 [07:36<13:25,  6.38it/s]

{'loss': 0.1551, 'learning_rate': 1.4397759103641458e-05, 'epoch': 6.54}


 35%|███▌      | 2674/7640 [08:03<12:42,  6.51it/s]The following columns in the evaluation set  don't have a corresponding argument in `BertForSequenceClassification.forward` and have been ignored: text.
***** Running Evaluation *****
  Num examples = 1525
  Batch size = 64

 35%|███▌      | 2674/7640 [08:06<12:42,  6.51it/s]Saving model checkpoint to ./results\checkpoint-2674
Configuration saved in ./results\checkpoint-2674\config.json


{'eval_loss': 1.0239416360855103, 'eval_accuracy': 0.801967213114754, 'eval_f1': 0.6848236633018122, 'eval_precision': 0.7188087439065267, 'eval_recall': 0.6700492686549996, 'eval_top3': 0.9350819672131148, 'eval_top2': 0.898360655737705, 'eval_runtime': 3.4169, 'eval_samples_per_second': 446.309, 'eval_steps_per_second': 7.024, 'epoch': 7.0}


Model weights saved in ./results\checkpoint-2674\pytorch_model.bin
Deleting older checkpoint [results\checkpoint-1910] due to args.save_total_limit
 39%|███▉      | 3001/7640 [08:57<11:13,  6.89it/s]

{'loss': 0.0936, 'learning_rate': 1.2997198879551822e-05, 'epoch': 7.85}


 40%|████      | 3056/7640 [09:05<10:09,  7.52it/s]The following columns in the evaluation set  don't have a corresponding argument in `BertForSequenceClassification.forward` and have been ignored: text.
***** Running Evaluation *****
  Num examples = 1525
  Batch size = 64

 40%|████      | 3056/7640 [09:08<10:09,  7.52it/s]Saving model checkpoint to ./results\checkpoint-3056
Configuration saved in ./results\checkpoint-3056\config.json


{'eval_loss': 1.1397786140441895, 'eval_accuracy': 0.7993442622950819, 'eval_f1': 0.6760728010262976, 'eval_precision': 0.7189146365000553, 'eval_recall': 0.6590336212439833, 'eval_top3': 0.9350819672131148, 'eval_top2': 0.898360655737705, 'eval_runtime': 3.084, 'eval_samples_per_second': 494.482, 'eval_steps_per_second': 7.782, 'epoch': 8.0}


Model weights saved in ./results\checkpoint-3056\pytorch_model.bin
Deleting older checkpoint [results\checkpoint-2292] due to args.save_total_limit
 45%|████▌     | 3438/7640 [10:14<11:11,  6.26it/s]The following columns in the evaluation set  don't have a corresponding argument in `BertForSequenceClassification.forward` and have been ignored: text.
***** Running Evaluation *****
  Num examples = 1525
  Batch size = 64

 45%|████▌     | 3438/7640 [10:18<11:11,  6.26it/s]Saving model checkpoint to ./results\checkpoint-3438
Configuration saved in ./results\checkpoint-3438\config.json


{'eval_loss': 1.1892651319503784, 'eval_accuracy': 0.8052459016393443, 'eval_f1': 0.6917401467605856, 'eval_precision': 0.7241559320256402, 'eval_recall': 0.6738060778810601, 'eval_top3': 0.9304918032786885, 'eval_top2': 0.9036065573770492, 'eval_runtime': 3.6103, 'eval_samples_per_second': 422.397, 'eval_steps_per_second': 6.648, 'epoch': 9.0}


Model weights saved in ./results\checkpoint-3438\pytorch_model.bin
Deleting older checkpoint [results\checkpoint-2674] due to args.save_total_limit
 46%|████▌     | 3501/7640 [10:31<12:02,  5.73it/s]

{'loss': 0.0549, 'learning_rate': 1.1596638655462186e-05, 'epoch': 9.16}


 50%|█████     | 3820/7640 [11:29<10:05,  6.31it/s]The following columns in the evaluation set  don't have a corresponding argument in `BertForSequenceClassification.forward` and have been ignored: text.
***** Running Evaluation *****
  Num examples = 1525
  Batch size = 64

 50%|█████     | 3820/7640 [11:33<10:05,  6.31it/s]Saving model checkpoint to ./results\checkpoint-3820
Configuration saved in ./results\checkpoint-3820\config.json


{'eval_loss': 1.3900656700134277, 'eval_accuracy': 0.7921311475409836, 'eval_f1': 0.692325295710151, 'eval_precision': 0.7107144207362891, 'eval_recall': 0.6969758656179778, 'eval_top3': 0.9173770491803279, 'eval_top2': 0.8872131147540984, 'eval_runtime': 3.3332, 'eval_samples_per_second': 457.525, 'eval_steps_per_second': 7.2, 'epoch': 10.0}


Model weights saved in ./results\checkpoint-3820\pytorch_model.bin
Deleting older checkpoint [results\checkpoint-3056] due to args.save_total_limit
 52%|█████▏    | 4001/7640 [12:08<08:43,  6.96it/s]

{'loss': 0.0237, 'learning_rate': 1.0196078431372549e-05, 'epoch': 10.47}


 55%|█████▌    | 4202/7640 [12:37<07:46,  7.37it/s]The following columns in the evaluation set  don't have a corresponding argument in `BertForSequenceClassification.forward` and have been ignored: text.
***** Running Evaluation *****
  Num examples = 1525
  Batch size = 64

 55%|█████▌    | 4202/7640 [12:40<07:46,  7.37it/s]Saving model checkpoint to ./results\checkpoint-4202
Configuration saved in ./results\checkpoint-4202\config.json


{'eval_loss': 1.30657160282135, 'eval_accuracy': 0.8045901639344263, 'eval_f1': 0.7008775117123527, 'eval_precision': 0.7653860892522012, 'eval_recall': 0.6744593488621988, 'eval_top3': 0.9311475409836065, 'eval_top2': 0.9009836065573771, 'eval_runtime': 3.2293, 'eval_samples_per_second': 472.238, 'eval_steps_per_second': 7.432, 'epoch': 11.0}


Model weights saved in ./results\checkpoint-4202\pytorch_model.bin
Deleting older checkpoint [results\checkpoint-3820] due to args.save_total_limit
 57%|█████▋    | 4330/7640 [13:03<11:05,  4.97it/s]

RuntimeError: transform: failed to synchronize: cudaErrorLaunchFailure: unspecified launch failure

In [None]:
trainer.evaluate()

The following columns in the evaluation set  don't have a corresponding argument in `BertForSequenceClassification.forward` and have been ignored: text.
***** Running Evaluation *****
  Num examples = 1526
  Batch size = 64
100%|██████████| 24/24 [00:03<00:00,  6.37it/s]


{'eval_loss': 1.3667666912078857,
 'eval_accuracy': 0.8289646133682831,
 'eval_f1': 0.7600767747856746,
 'eval_precision': 0.7783009182422279,
 'eval_recall': 0.74921395397379,
 'eval_top3': 0.936435124508519,
 'eval_top2': 0.9121887287024901,
 'eval_runtime': 3.9514,
 'eval_samples_per_second': 386.188,
 'eval_steps_per_second': 6.074,
 'epoch': 20.0}

In [None]:

predict_dataset = Dataset.from_dict({"text":["In contrast to the radical forces buffeting valuations, for most companies, 2020 was a year of “strategy lockdown.",
"Domestic policies thus tended to reinforce negative spillovers and exacerbate systemic risk across the euro area.",
"Mortgage interest rate in selected European countries as of 4th quarter of 2019 and 2020 increased",
"Accordingly, I shall spend most of my allotted time outlining the chosen monetary policy instruments and procedures of the ESCB and the considerations that have been raised relating to strategy.",
"Credit card interest rate has gone up a lot."]})

def tokenize(batch):
    return tokenizer(batch['text'], padding=True, truncation=True)
predict_dataset = predict_dataset.map(tokenize, batched=True, batch_size=len(train_dataset))

trainer.predict(predict_dataset)

100%|██████████| 1/1 [00:00<00:00, 167.14ba/s]
The following columns in the test set  don't have a corresponding argument in `BertForSequenceClassification.forward` and have been ignored: text.
***** Running Prediction *****
  Num examples = 5
  Batch size = 64
  0%|          | 0/1 [00:00<?, ?it/s]

PredictionOutput(predictions=array([[-0.54258686, -2.283698  , -1.2337625 ,  0.6101017 ,  8.834227  ,
        -0.60477626, -0.1470656 , -0.25806686, -1.9158627 , -0.45982993,
        -0.08515993, -1.3472333 ],
       [ 9.702999  , -0.29276183, -1.5581697 , -1.5263258 , -1.2451688 ,
        -0.5294447 ,  2.195682  , -1.8384721 , -1.5489938 , -3.0295794 ,
        -1.7577528 , -1.730811  ],
       [ 1.4455788 , -1.5327429 ,  2.143541  ,  7.9250283 , -1.1390123 ,
        -3.0250149 , -0.9481666 , -2.0208018 ,  0.02337106, -0.8389467 ,
        -0.08845273, -0.9033677 ],
       [-1.1007565 , -1.3134553 , -1.9952139 , -0.77075744, -1.4220358 ,
         0.21309426, 10.992527  , -0.8264481 , -0.6472416 , -2.5305545 ,
        -1.0293103 , -1.4311087 ],
       [ 2.7709064 ,  7.787258  ,  2.5337188 , -1.0464295 , -2.2166548 ,
        -1.2778659 , -2.6960404 , -2.199739  , -1.8598685 , -1.4836167 ,
        -1.3202602 , -1.9072998 ]], dtype=float32), label_ids=None, metrics={'test_runtime': 0.0389, 