In [1]:
from torch import nn
from transformers.modeling_outputs import SequenceClassifierOutput
from transformers.models.bert.modeling_bert import BertModel, BertPreTrainedModel

In [2]:
import torch
torch.cuda.empty_cache()

In [3]:
import sklearn
import pandas as pd
import numpy as np

In [4]:
from transformers import AutoTokenizer
from datasets import Dataset, DatasetDict
from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer
from transformers import AutoConfig

In [5]:
import os
from os import listdir
import sys
import json
from os import path

In [6]:
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
from sklearn import metrics

In [7]:
from sklearn.model_selection import KFold

In [8]:
models_dir = "models/bert-k-fold"

In [9]:
import pandas as pd
labels = open('data/classes.txt').read().splitlines()
df = pd.read_csv("data/belief_benchmark_all_train.csv")

In [10]:
df.head()

Unnamed: 0,text,sentiment,label
0,Private-sector actors expect their business pa...,NEGATIVE,1
1,Lack of product uniformity is perceived by con...,"NEGATIVE, NEGATIVE, NEGATIVE, NEGATIVE",1
2,"However , to the extent that the empirical str...",POSITIVE,2
3,Firms prefer to contract with producers with w...,POSITIVE,2
4,Farmers attempted to use alternate wetting and...,POSITIVE,2


In [11]:
transformer_name = "bert-base-cased"
tokenizer = AutoTokenizer.from_pretrained(transformer_name)
# NOTE: for cross validation, the model should be initialized inside the cv loop

In [12]:
def tokenize(batch):
    return tokenizer(batch['text'], truncation=True)

In [13]:
def compute_metrics(eval_pred):
    y_true = eval_pred.label_ids
    y_pred = np.argmax(eval_pred.predictions, axis=-1)
    report = metrics.classification_report(y_true, y_pred)
    print("report: \n", report)
    
    print("rep type: ", type(report))
    

    return {'f1':metrics.f1_score(y_true, y_pred, average="macro")}

In [14]:
# Note: not used right now, but can be
# https://github.com/huggingface/transformers/blob/65659a29cf5a079842e61a63d57fa24474288998/src/transformers/models/bert/modeling_bert.py#L1486

class BertForSequenceClassification(BertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        self.init_weights()
        
    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, labels=None, **kwargs):
        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            **kwargs,
        )
        cls_outputs = outputs.last_hidden_state[:, 0, :]
        cls_outputs = self.dropout(cls_outputs)
        logits = self.classifier(cls_outputs)
        loss = None
        if labels is not None:
            loss_fn = nn.CrossEntropyLoss()
            loss = loss_fn(logits, labels)
        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


In [15]:
# this is for creating cross-validation folds
def get_sample_based_on_idx(data, indeces):
    return data.iloc[indeces, :].reset_index()

In [16]:
# defining hyperparams
num_epochs = 8
batch_size = 6
weight_decay = 0.01
print(f"num_epochs: {num_epochs}, batch_size: {batch_size}, weight_decay: {weight_decay}")
training_args = TrainingArguments(
    output_dir="./results_triggerless", # is this location in the tmp dir? 
    log_level='error',
    num_train_epochs=num_epochs,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    evaluation_strategy='epoch',
    save_strategy='epoch',
    weight_decay=weight_decay,
    load_best_model_at_end=True, # this is supposed to make sure the best model is loaded by the trainer at the end
    metric_for_best_model="eval_f1" 
    )

In [17]:
# specify what percent of the all_train data should be used in cross-fold validation 
# we can use this to create a learning curve, e.g., performance with 50, 80, 90% of data

from sklearn.model_selection import train_test_split
train_size = 0.8

if train_size < 1.0:
    df, _ = train_test_split(df, test_size=1.0-train_size, random_state=1, stratify=df[['label']])
print(f"train_size: {train_size} (total samples: {len(df)})")

In [18]:
output = open("original.txt", "a")

fold = 0
kfold = KFold(n_splits=5, shuffle=True, random_state=1)
for train_df_idx, eval_df_idx in kfold.split(df):
    
    print("FOLD: ", fold)
    output.write(f"FOLD: {fold}\n")
    new_df = pd.DataFrame()
    
    train_df = get_sample_based_on_idx(df, train_df_idx)
    print("LEN DF: ", len(train_df))
    output.write(f"LEN DF: {len(train_df)}\n")
#     train_df['label'] = [int(item) for item in train_df["annotation: b (belief or attitude), n (not a belief and not an attitude)"]]
    print("done train df")
    output.write("done train df\n")
    eval_df = get_sample_based_on_idx(df, eval_df_idx)
#     eval_df["label"] = [int(item) for item in eval_df['annotation: b (belief or attitude), n (not a belief and not an attitude)']]
    print("done eval df")
    output.write("done eval df\n")
    print("LEN EVAL: ", len(eval_df))
    output.write(f"LEN EVAL: {len(eval_df)}\n")
#     print(eval_df.head())
    ds = DatasetDict()
    ds['train'] = Dataset.from_pandas(train_df)
    ds['validation'] = Dataset.from_pandas(eval_df)
    train_ds = ds['train'].map(
        tokenize, batched=True,
        # remove_columns=['index', 'sentence', 'trigger', 'annotation: b (belief or attitude), n (not a belief and not an attitude)', 'paragraph'],
    )
    eval_ds = ds['validation'].map(
        tokenize,
        batched=True,
        # remove_columns=['index', 'sentence', 'trigger', 'annotation: b (belief or attitude), n (not a belief and not an attitude)', 'paragraph'],
    )

#     config = AutoConfig.from_pretrained(
#         transformer_name,
#         num_labels=2,
#     )

    model = AutoModelForSequenceClassification.from_pretrained(transformer_name, num_labels=4)
    tokenizer = AutoTokenizer.from_pretrained(transformer_name)
#     model = (
#         BertForSequenceClassification
#         .from_pretrained(transformer_name, config=config)
#     )
    trainer = Trainer(
        model=model,
        args=training_args,
        compute_metrics=compute_metrics,
        train_dataset=train_ds,
        eval_dataset=eval_ds,
        tokenizer=tokenizer,
    )
    trainer.train()
    # after training, predict (will use best model?)
    preds = trainer.predict(eval_ds)
#     print("HERE: " , preds)
    final_preds = [np.argmax(x) for x in preds.predictions]
    real_f1 = metrics.f1_score(final_preds, eval_df["label"], average="macro")
    print("F-1: ", real_f1)
    output.write(f"F-1: {real_f1}\n")
    model_name = f"{transformer_name}-best-of-fold-{fold}-f1-{real_f1}"
    model_dir = os.path.join(models_dir, model_name)

    trainer.save_model(model_dir)

    for i, item in enumerate(final_preds):
        if item != eval_ds["label"][i]:
            wrong_df = pd.DataFrame()
            wrong_df["text"] = [eval_df["text"][i]]
            wrong_df["real"] = [eval_df["label"][i]]
            wrong_df["predicted"] = [item]
            new_df = pd.concat([new_df, wrong_df])

    new_df.to_csv(f"{models_dir}/wrong_predictions_{fold}.csv")
    fold += 1
        

FOLD:  0
LEN DF:  138
done train df
done eval df
LEN EVAL:  35


Map:   0%|          | 0/138 [00:00<?, ? examples/s]

Map:   0%|          | 0/35 [00:00<?, ? examples/s]

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at b

  0%|          | 0/184 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


report: 
               precision    recall  f1-score   support

           0       0.34      1.00      0.51        12
           1       0.00      0.00      0.00         9
           2       0.00      0.00      0.00        14

    accuracy                           0.34        35
   macro avg       0.11      0.33      0.17        35
weighted avg       0.12      0.34      0.18        35

rep type:  <class 'str'>
{'eval_loss': 1.2403892278671265, 'eval_f1': 0.1702127659574468, 'eval_runtime': 0.2612, 'eval_samples_per_second': 133.999, 'eval_steps_per_second': 22.971, 'epoch': 1.0}


  0%|          | 0/6 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


report: 
               precision    recall  f1-score   support

           0       0.00      0.00      0.00        12
           1       0.00      0.00      0.00         9
           2       0.40      1.00      0.57        14

    accuracy                           0.40        35
   macro avg       0.13      0.33      0.19        35
weighted avg       0.16      0.40      0.23        35

rep type:  <class 'str'>
{'eval_loss': 1.2753573656082153, 'eval_f1': 0.1904761904761905, 'eval_runtime': 0.2511, 'eval_samples_per_second': 139.378, 'eval_steps_per_second': 23.893, 'epoch': 2.0}


  0%|          | 0/6 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


report: 
               precision    recall  f1-score   support

           0       0.00      0.00      0.00        12
           1       0.00      0.00      0.00         9
           2       0.40      1.00      0.57        14

    accuracy                           0.40        35
   macro avg       0.13      0.33      0.19        35
weighted avg       0.16      0.40      0.23        35

rep type:  <class 'str'>
{'eval_loss': 1.2630975246429443, 'eval_f1': 0.1904761904761905, 'eval_runtime': 0.2426, 'eval_samples_per_second': 144.297, 'eval_steps_per_second': 24.737, 'epoch': 3.0}


  0%|          | 0/6 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


report: 
               precision    recall  f1-score   support

           0       0.34      0.83      0.49        12
           1       0.00      0.00      0.00         9
           2       0.00      0.00      0.00        14

    accuracy                           0.29        35
   macro avg       0.11      0.28      0.16        35
weighted avg       0.12      0.29      0.17        35

rep type:  <class 'str'>
{'eval_loss': 1.2740401029586792, 'eval_f1': 0.1626016260162602, 'eval_runtime': 0.2608, 'eval_samples_per_second': 134.184, 'eval_steps_per_second': 23.003, 'epoch': 4.0}


  0%|          | 0/6 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


report: 
               precision    recall  f1-score   support

           0       0.00      0.00      0.00        12
           1       0.00      0.00      0.00         9
           2       0.40      1.00      0.57        14

    accuracy                           0.40        35
   macro avg       0.13      0.33      0.19        35
weighted avg       0.16      0.40      0.23        35

rep type:  <class 'str'>
{'eval_loss': 1.4980642795562744, 'eval_f1': 0.1904761904761905, 'eval_runtime': 0.3777, 'eval_samples_per_second': 92.678, 'eval_steps_per_second': 15.888, 'epoch': 5.0}


  0%|          | 0/6 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


report: 
               precision    recall  f1-score   support

           0       0.42      0.67      0.52        12
           1       0.00      0.00      0.00         9
           2       0.40      0.43      0.41        14
           3       0.00      0.00      0.00         0

    accuracy                           0.40        35
   macro avg       0.21      0.27      0.23        35
weighted avg       0.30      0.40      0.34        35

rep type:  <class 'str'>
{'eval_loss': 1.335120677947998, 'eval_f1': 0.23248053392658513, 'eval_runtime': 0.3201, 'eval_samples_per_second': 109.343, 'eval_steps_per_second': 18.745, 'epoch': 6.0}


  0%|          | 0/6 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


report: 
               precision    recall  f1-score   support

           0       0.60      0.25      0.35        12
           1       0.00      0.00      0.00         9
           2       0.46      0.79      0.58        14
           3       0.00      0.00      0.00         0

    accuracy                           0.40        35
   macro avg       0.26      0.26      0.23        35
weighted avg       0.39      0.40      0.35        35

rep type:  <class 'str'>
{'eval_loss': 1.4206159114837646, 'eval_f1': 0.23297213622291024, 'eval_runtime': 0.2953, 'eval_samples_per_second': 118.526, 'eval_steps_per_second': 20.319, 'epoch': 7.0}


  0%|          | 0/6 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


report: 
               precision    recall  f1-score   support

           0       0.50      0.33      0.40        12
           1       0.00      0.00      0.00         9
           2       0.50      0.86      0.63        14
           3       0.00      0.00      0.00         0

    accuracy                           0.46        35
   macro avg       0.25      0.30      0.26        35
weighted avg       0.37      0.46      0.39        35

rep type:  <class 'str'>
{'eval_loss': 1.33847975730896, 'eval_f1': 0.2578947368421053, 'eval_runtime': 0.2739, 'eval_samples_per_second': 127.768, 'eval_steps_per_second': 21.903, 'epoch': 8.0}
{'train_runtime': 65.3822, 'train_samples_per_second': 16.885, 'train_steps_per_second': 2.814, 'train_loss': 1.075794800468113, 'epoch': 8.0}


  0%|          | 0/6 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


report: 
               precision    recall  f1-score   support

           0       0.50      0.33      0.40        12
           1       0.00      0.00      0.00         9
           2       0.50      0.86      0.63        14
           3       0.00      0.00      0.00         0

    accuracy                           0.46        35
   macro avg       0.25      0.30      0.26        35
weighted avg       0.37      0.46      0.39        35

rep type:  <class 'str'>
F-1:  0.2578947368421053
FOLD:  1
LEN DF:  138
done train df
done eval df
LEN EVAL:  35


Map:   0%|          | 0/138 [00:00<?, ? examples/s]

Map:   0%|          | 0/35 [00:00<?, ? examples/s]



  0%|          | 0/184 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


report: 
               precision    recall  f1-score   support

           0       0.23      1.00      0.37         8
           1       0.00      0.00      0.00         6
           2       0.00      0.00      0.00        15
           3       0.00      0.00      0.00         6

    accuracy                           0.23        35
   macro avg       0.06      0.25      0.09        35
weighted avg       0.05      0.23      0.09        35

rep type:  <class 'str'>
{'eval_loss': 1.4057917594909668, 'eval_f1': 0.09302325581395347, 'eval_runtime': 0.275, 'eval_samples_per_second': 127.287, 'eval_steps_per_second': 21.821, 'epoch': 1.0}


  0%|          | 0/6 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


report: 
               precision    recall  f1-score   support

           0       0.23      1.00      0.37         8
           1       0.00      0.00      0.00         6
           2       0.00      0.00      0.00        15
           3       0.00      0.00      0.00         6

    accuracy                           0.23        35
   macro avg       0.06      0.25      0.09        35
weighted avg       0.05      0.23      0.09        35

rep type:  <class 'str'>
{'eval_loss': 1.4484484195709229, 'eval_f1': 0.09302325581395347, 'eval_runtime': 0.3448, 'eval_samples_per_second': 101.494, 'eval_steps_per_second': 17.399, 'epoch': 2.0}


  0%|          | 0/6 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


report: 
               precision    recall  f1-score   support

           0       1.00      0.38      0.55         8
           1       0.00      0.00      0.00         6
           2       0.47      1.00      0.64        15
           3       0.00      0.00      0.00         6

    accuracy                           0.51        35
   macro avg       0.37      0.34      0.30        35
weighted avg       0.43      0.51      0.40        35

rep type:  <class 'str'>
{'eval_loss': 1.3577896356582642, 'eval_f1': 0.29593810444874274, 'eval_runtime': 0.2918, 'eval_samples_per_second': 119.965, 'eval_steps_per_second': 20.565, 'epoch': 3.0}


  0%|          | 0/6 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


report: 
               precision    recall  f1-score   support

           0       0.00      0.00      0.00         8
           1       0.00      0.00      0.00         6
           2       0.44      0.93      0.60        15
           3       0.00      0.00      0.00         6

    accuracy                           0.40        35
   macro avg       0.11      0.23      0.15        35
weighted avg       0.19      0.40      0.26        35

rep type:  <class 'str'>
{'eval_loss': 1.3718316555023193, 'eval_f1': 0.14893617021276595, 'eval_runtime': 0.3442, 'eval_samples_per_second': 101.693, 'eval_steps_per_second': 17.433, 'epoch': 4.0}


  0%|          | 0/6 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


report: 
               precision    recall  f1-score   support

           0       0.28      0.88      0.42         8
           1       0.33      0.17      0.22         6
           2       0.86      0.40      0.55        15
           3       0.00      0.00      0.00         6

    accuracy                           0.40        35
   macro avg       0.37      0.36      0.30        35
weighted avg       0.49      0.40      0.37        35

rep type:  <class 'str'>
{'eval_loss': 1.1904164552688599, 'eval_f1': 0.297979797979798, 'eval_runtime': 0.2691, 'eval_samples_per_second': 130.049, 'eval_steps_per_second': 22.294, 'epoch': 5.0}


  0%|          | 0/6 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


report: 
               precision    recall  f1-score   support

           0       0.50      0.62      0.56         8
           1       0.40      0.33      0.36         6
           2       0.60      0.80      0.69        15
           3       0.00      0.00      0.00         6

    accuracy                           0.54        35
   macro avg       0.38      0.44      0.40        35
weighted avg       0.44      0.54      0.48        35

rep type:  <class 'str'>
{'eval_loss': 1.2217886447906494, 'eval_f1': 0.4012265512265512, 'eval_runtime': 0.261, 'eval_samples_per_second': 134.103, 'eval_steps_per_second': 22.989, 'epoch': 6.0}


  0%|          | 0/6 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


report: 
               precision    recall  f1-score   support

           0       0.50      1.00      0.67         8
           1       0.50      0.50      0.50         6
           2       0.85      0.73      0.79        15
           3       0.00      0.00      0.00         6

    accuracy                           0.63        35
   macro avg       0.46      0.56      0.49        35
weighted avg       0.56      0.63      0.57        35

rep type:  <class 'str'>
{'eval_loss': 1.1068669557571411, 'eval_f1': 0.488095238095238, 'eval_runtime': 0.2605, 'eval_samples_per_second': 134.36, 'eval_steps_per_second': 23.033, 'epoch': 7.0}


  0%|          | 0/6 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


report: 
               precision    recall  f1-score   support

           0       0.57      1.00      0.73         8
           1       0.40      0.33      0.36         6
           2       0.75      0.80      0.77        15
           3       0.00      0.00      0.00         6

    accuracy                           0.63        35
   macro avg       0.43      0.53      0.47        35
weighted avg       0.52      0.63      0.56        35

rep type:  <class 'str'>
{'eval_loss': 1.1687743663787842, 'eval_f1': 0.46627565982404695, 'eval_runtime': 0.2568, 'eval_samples_per_second': 136.309, 'eval_steps_per_second': 23.367, 'epoch': 8.0}
{'train_runtime': 73.9768, 'train_samples_per_second': 14.924, 'train_steps_per_second': 2.487, 'train_loss': 0.8878276659094769, 'epoch': 8.0}


  0%|          | 0/6 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


report: 
               precision    recall  f1-score   support

           0       0.50      1.00      0.67         8
           1       0.50      0.50      0.50         6
           2       0.85      0.73      0.79        15
           3       0.00      0.00      0.00         6

    accuracy                           0.63        35
   macro avg       0.46      0.56      0.49        35
weighted avg       0.56      0.63      0.57        35

rep type:  <class 'str'>
F-1:  0.488095238095238
FOLD:  2
LEN DF:  138
done train df
done eval df
LEN EVAL:  35


Map:   0%|          | 0/138 [00:00<?, ? examples/s]

Map:   0%|          | 0/35 [00:00<?, ? examples/s]



  0%|          | 0/184 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


report: 
               precision    recall  f1-score   support

           0       0.00      0.00      0.00        14
           1       0.00      0.00      0.00         5
           2       0.43      1.00      0.60        15
           3       0.00      0.00      0.00         1

    accuracy                           0.43        35
   macro avg       0.11      0.25      0.15        35
weighted avg       0.18      0.43      0.26        35

rep type:  <class 'str'>
{'eval_loss': 1.1925339698791504, 'eval_f1': 0.15, 'eval_runtime': 0.2635, 'eval_samples_per_second': 132.826, 'eval_steps_per_second': 22.77, 'epoch': 1.0}


  0%|          | 0/6 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


report: 
               precision    recall  f1-score   support

           0       0.50      0.07      0.12        14
           1       0.00      0.00      0.00         5
           2       0.42      0.93      0.58        15
           3       0.00      0.00      0.00         1

    accuracy                           0.43        35
   macro avg       0.23      0.25      0.18        35
weighted avg       0.38      0.43      0.30        35

rep type:  <class 'str'>
{'eval_loss': 1.171134114265442, 'eval_f1': 0.17708333333333331, 'eval_runtime': 0.278, 'eval_samples_per_second': 125.92, 'eval_steps_per_second': 21.586, 'epoch': 2.0}


  0%|          | 0/6 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


report: 
               precision    recall  f1-score   support

           0       0.00      0.00      0.00        14
           1       0.00      0.00      0.00         5
           2       0.43      1.00      0.60        15
           3       0.00      0.00      0.00         1

    accuracy                           0.43        35
   macro avg       0.11      0.25      0.15        35
weighted avg       0.18      0.43      0.26        35

rep type:  <class 'str'>
{'eval_loss': 1.2608928680419922, 'eval_f1': 0.15, 'eval_runtime': 0.2672, 'eval_samples_per_second': 130.974, 'eval_steps_per_second': 22.453, 'epoch': 3.0}


  0%|          | 0/6 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


report: 
               precision    recall  f1-score   support

           0       0.50      0.07      0.12        14
           1       0.00      0.00      0.00         5
           2       0.42      0.93      0.58        15
           3       0.00      0.00      0.00         1

    accuracy                           0.43        35
   macro avg       0.23      0.25      0.18        35
weighted avg       0.38      0.43      0.30        35

rep type:  <class 'str'>
{'eval_loss': 1.3489395380020142, 'eval_f1': 0.17708333333333331, 'eval_runtime': 0.2612, 'eval_samples_per_second': 133.978, 'eval_steps_per_second': 22.968, 'epoch': 4.0}


  0%|          | 0/6 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


report: 
               precision    recall  f1-score   support

           0       1.00      0.07      0.13        14
           1       0.00      0.00      0.00         5
           2       0.47      1.00      0.64        15
           3       0.00      0.00      0.00         1

    accuracy                           0.46        35
   macro avg       0.37      0.27      0.19        35
weighted avg       0.60      0.46      0.33        35

rep type:  <class 'str'>
{'eval_loss': 1.3954232931137085, 'eval_f1': 0.19290780141843972, 'eval_runtime': 0.2753, 'eval_samples_per_second': 127.14, 'eval_steps_per_second': 21.795, 'epoch': 5.0}


  0%|          | 0/6 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


report: 
               precision    recall  f1-score   support

           0       0.78      0.50      0.61        14
           1       0.29      0.40      0.33         5
           2       0.58      0.73      0.65        15
           3       0.00      0.00      0.00         1

    accuracy                           0.57        35
   macro avg       0.41      0.41      0.40        35
weighted avg       0.60      0.57      0.57        35

rep type:  <class 'str'>
{'eval_loss': 1.0531623363494873, 'eval_f1': 0.3972719522591645, 'eval_runtime': 0.261, 'eval_samples_per_second': 134.099, 'eval_steps_per_second': 22.988, 'epoch': 6.0}


  0%|          | 0/6 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


report: 
               precision    recall  f1-score   support

           0       0.75      0.21      0.33        14
           1       0.15      0.40      0.22         5
           2       0.61      0.73      0.67        15
           3       0.00      0.00      0.00         1

    accuracy                           0.46        35
   macro avg       0.38      0.34      0.31        35
weighted avg       0.58      0.46      0.45        35

rep type:  <class 'str'>
{'eval_loss': 1.2484222650527954, 'eval_f1': 0.3055555555555556, 'eval_runtime': 0.2672, 'eval_samples_per_second': 131.012, 'eval_steps_per_second': 22.459, 'epoch': 7.0}


  0%|          | 0/6 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


report: 
               precision    recall  f1-score   support

           0       0.62      0.36      0.45        14
           1       0.33      0.40      0.36         5
           2       0.52      0.73      0.61        15
           3       0.00      0.00      0.00         1

    accuracy                           0.51        35
   macro avg       0.37      0.37      0.36        35
weighted avg       0.52      0.51      0.50        35

rep type:  <class 'str'>
{'eval_loss': 1.1822458505630493, 'eval_f1': 0.35732323232323226, 'eval_runtime': 0.2607, 'eval_samples_per_second': 134.263, 'eval_steps_per_second': 23.017, 'epoch': 8.0}
{'train_runtime': 59.4124, 'train_samples_per_second': 18.582, 'train_steps_per_second': 3.097, 'train_loss': 0.7985156515370244, 'epoch': 8.0}


  0%|          | 0/6 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


report: 
               precision    recall  f1-score   support

           0       0.78      0.50      0.61        14
           1       0.29      0.40      0.33         5
           2       0.58      0.73      0.65        15
           3       0.00      0.00      0.00         1

    accuracy                           0.57        35
   macro avg       0.41      0.41      0.40        35
weighted avg       0.60      0.57      0.57        35

rep type:  <class 'str'>
F-1:  0.3972719522591645
FOLD:  3
LEN DF:  139
done train df
done eval df
LEN EVAL:  34


Map:   0%|          | 0/139 [00:00<?, ? examples/s]

Map:   0%|          | 0/34 [00:00<?, ? examples/s]



  0%|          | 0/192 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


report: 
               precision    recall  f1-score   support

           0       0.00      0.00      0.00         9
           1       0.00      0.00      0.00         5
           2       0.52      1.00      0.68        17
           3       0.00      0.00      0.00         3

    accuracy                           0.50        34
   macro avg       0.13      0.25      0.17        34
weighted avg       0.26      0.50      0.34        34

rep type:  <class 'str'>
{'eval_loss': 1.2123949527740479, 'eval_f1': 0.16999999999999998, 'eval_runtime': 0.2619, 'eval_samples_per_second': 129.828, 'eval_steps_per_second': 22.911, 'epoch': 1.0}


  0%|          | 0/6 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


report: 
               precision    recall  f1-score   support

           0       1.00      0.22      0.36         9
           1       0.00      0.00      0.00         5
           2       0.53      1.00      0.69        17
           3       0.00      0.00      0.00         3

    accuracy                           0.56        34
   macro avg       0.38      0.31      0.26        34
weighted avg       0.53      0.56      0.44        34

rep type:  <class 'str'>
{'eval_loss': 1.1285600662231445, 'eval_f1': 0.26437847866419295, 'eval_runtime': 0.2605, 'eval_samples_per_second': 130.52, 'eval_steps_per_second': 23.033, 'epoch': 2.0}


  0%|          | 0/6 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


report: 
               precision    recall  f1-score   support

           0       0.50      0.22      0.31         9
           1       1.00      0.40      0.57         5
           2       0.57      0.94      0.71        17
           3       0.00      0.00      0.00         3

    accuracy                           0.59        34
   macro avg       0.52      0.39      0.40        34
weighted avg       0.57      0.59      0.52        34

rep type:  <class 'str'>
{'eval_loss': 1.1088975667953491, 'eval_f1': 0.3975579975579976, 'eval_runtime': 0.2606, 'eval_samples_per_second': 130.487, 'eval_steps_per_second': 23.027, 'epoch': 3.0}


  0%|          | 0/6 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


report: 
               precision    recall  f1-score   support

           0       0.60      0.33      0.43         9
           1       0.33      0.80      0.47         5
           2       0.71      0.71      0.71        17
           3       0.00      0.00      0.00         3

    accuracy                           0.56        34
   macro avg       0.41      0.46      0.40        34
weighted avg       0.56      0.56      0.54        34

rep type:  <class 'str'>
{'eval_loss': 1.3464467525482178, 'eval_f1': 0.4012605042016807, 'eval_runtime': 0.2565, 'eval_samples_per_second': 132.554, 'eval_steps_per_second': 23.392, 'epoch': 4.0}


  0%|          | 0/6 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


report: 
               precision    recall  f1-score   support

           0       0.33      0.22      0.27         9
           1       0.50      0.80      0.62         5
           2       0.75      0.88      0.81        17
           3       0.00      0.00      0.00         3

    accuracy                           0.62        34
   macro avg       0.40      0.48      0.42        34
weighted avg       0.54      0.62      0.57        34

rep type:  <class 'str'>
{'eval_loss': 1.4763931035995483, 'eval_f1': 0.4232155232155232, 'eval_runtime': 0.2812, 'eval_samples_per_second': 120.896, 'eval_steps_per_second': 21.335, 'epoch': 5.0}


  0%|          | 0/6 [00:00<?, ?it/s]

report: 
               precision    recall  f1-score   support

           0       0.30      0.33      0.32         9
           1       0.43      0.60      0.50         5
           2       0.81      0.76      0.79        17
           3       0.00      0.00      0.00         3

    accuracy                           0.56        34
   macro avg       0.39      0.42      0.40        34
weighted avg       0.55      0.56      0.55        34

rep type:  <class 'str'>
{'eval_loss': 1.6467573642730713, 'eval_f1': 0.4009170653907496, 'eval_runtime': 0.2574, 'eval_samples_per_second': 132.086, 'eval_steps_per_second': 23.309, 'epoch': 6.0}


  0%|          | 0/6 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


report: 
               precision    recall  f1-score   support

           0       0.25      0.22      0.24         9
           1       0.50      0.80      0.62         5
           2       0.78      0.82      0.80        17
           3       0.00      0.00      0.00         3

    accuracy                           0.59        34
   macro avg       0.38      0.46      0.41        34
weighted avg       0.53      0.59      0.55        34

rep type:  <class 'str'>
{'eval_loss': 1.7017351388931274, 'eval_f1': 0.41266968325791853, 'eval_runtime': 0.2631, 'eval_samples_per_second': 129.212, 'eval_steps_per_second': 22.802, 'epoch': 7.0}


  0%|          | 0/6 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


report: 
               precision    recall  f1-score   support

           0       0.33      0.22      0.27         9
           1       0.50      0.80      0.62         5
           2       0.75      0.88      0.81        17
           3       0.00      0.00      0.00         3

    accuracy                           0.62        34
   macro avg       0.40      0.48      0.42        34
weighted avg       0.54      0.62      0.57        34

rep type:  <class 'str'>
{'eval_loss': 1.747233271598816, 'eval_f1': 0.4232155232155232, 'eval_runtime': 0.2558, 'eval_samples_per_second': 132.94, 'eval_steps_per_second': 23.46, 'epoch': 8.0}
{'train_runtime': 62.1013, 'train_samples_per_second': 17.906, 'train_steps_per_second': 3.092, 'train_loss': 0.4552481174468994, 'epoch': 8.0}


  0%|          | 0/6 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


report: 
               precision    recall  f1-score   support

           0       0.33      0.22      0.27         9
           1       0.50      0.80      0.62         5
           2       0.75      0.88      0.81        17
           3       0.00      0.00      0.00         3

    accuracy                           0.62        34
   macro avg       0.40      0.48      0.42        34
weighted avg       0.54      0.62      0.57        34

rep type:  <class 'str'>
F-1:  0.4232155232155232
FOLD:  4
LEN DF:  139
done train df
done eval df
LEN EVAL:  34


Map:   0%|          | 0/139 [00:00<?, ? examples/s]

Map:   0%|          | 0/34 [00:00<?, ? examples/s]



  0%|          | 0/192 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


report: 
               precision    recall  f1-score   support

           0       0.00      0.00      0.00        15
           1       0.00      0.00      0.00         6
           2       0.26      1.00      0.42         9
           3       0.00      0.00      0.00         4

    accuracy                           0.26        34
   macro avg       0.07      0.25      0.10        34
weighted avg       0.07      0.26      0.11        34

rep type:  <class 'str'>
{'eval_loss': 1.361945629119873, 'eval_f1': 0.10465116279069768, 'eval_runtime': 0.2427, 'eval_samples_per_second': 140.067, 'eval_steps_per_second': 24.718, 'epoch': 1.0}


  0%|          | 0/6 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


report: 
               precision    recall  f1-score   support

           0       0.50      0.87      0.63        15
           1       0.00      0.00      0.00         6
           2       0.62      0.56      0.59         9
           3       0.00      0.00      0.00         4

    accuracy                           0.53        34
   macro avg       0.28      0.36      0.31        34
weighted avg       0.39      0.53      0.44        34

rep type:  <class 'str'>
{'eval_loss': 1.1791257858276367, 'eval_f1': 0.30559540889526543, 'eval_runtime': 0.2464, 'eval_samples_per_second': 137.984, 'eval_steps_per_second': 24.35, 'epoch': 2.0}


  0%|          | 0/6 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


report: 
               precision    recall  f1-score   support

           0       0.57      0.27      0.36        15
           1       0.57      0.67      0.62         6
           2       0.40      0.89      0.55         9
           3       0.00      0.00      0.00         4

    accuracy                           0.47        34
   macro avg       0.39      0.46      0.38        34
weighted avg       0.46      0.47      0.42        34

rep type:  <class 'str'>
{'eval_loss': 1.4302103519439697, 'eval_f1': 0.38268627923800336, 'eval_runtime': 0.246, 'eval_samples_per_second': 138.231, 'eval_steps_per_second': 24.394, 'epoch': 3.0}


  0%|          | 0/6 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


report: 
               precision    recall  f1-score   support

           0       0.56      0.67      0.61        15
           1       0.33      0.17      0.22         6
           2       0.46      0.67      0.55         9
           3       0.00      0.00      0.00         4

    accuracy                           0.50        34
   macro avg       0.34      0.38      0.34        34
weighted avg       0.43      0.50      0.45        34

rep type:  <class 'str'>
{'eval_loss': 1.457134485244751, 'eval_f1': 0.3434343434343434, 'eval_runtime': 0.2461, 'eval_samples_per_second': 138.181, 'eval_steps_per_second': 24.385, 'epoch': 4.0}


  0%|          | 0/6 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


report: 
               precision    recall  f1-score   support

           0       0.63      0.80      0.71        15
           1       0.75      0.50      0.60         6
           2       0.55      0.67      0.60         9
           3       0.00      0.00      0.00         4

    accuracy                           0.62        34
   macro avg       0.48      0.49      0.48        34
weighted avg       0.56      0.62      0.58        34

rep type:  <class 'str'>
{'eval_loss': 1.5320632457733154, 'eval_f1': 0.4764705882352941, 'eval_runtime': 0.2496, 'eval_samples_per_second': 136.239, 'eval_steps_per_second': 24.042, 'epoch': 5.0}


  0%|          | 0/6 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


report: 
               precision    recall  f1-score   support

           0       0.83      0.33      0.48        15
           1       0.67      1.00      0.80         6
           2       0.42      0.89      0.57         9
           3       0.00      0.00      0.00         4

    accuracy                           0.56        34
   macro avg       0.48      0.56      0.46        34
weighted avg       0.60      0.56      0.50        34

rep type:  <class 'str'>
{'eval_loss': 2.1215927600860596, 'eval_f1': 0.4619047619047619, 'eval_runtime': 0.2496, 'eval_samples_per_second': 136.215, 'eval_steps_per_second': 24.038, 'epoch': 6.0}


  0%|          | 0/6 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


report: 
               precision    recall  f1-score   support

           0       0.67      0.53      0.59        15
           1       0.57      0.67      0.62         6
           2       0.47      0.78      0.58         9
           3       0.00      0.00      0.00         4

    accuracy                           0.56        34
   macro avg       0.43      0.49      0.45        34
weighted avg       0.52      0.56      0.52        34

rep type:  <class 'str'>
{'eval_loss': 1.5521236658096313, 'eval_f1': 0.44782763532763536, 'eval_runtime': 0.2464, 'eval_samples_per_second': 138.007, 'eval_steps_per_second': 24.354, 'epoch': 7.0}


  0%|          | 0/6 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


report: 
               precision    recall  f1-score   support

           0       0.69      0.60      0.64        15
           1       0.57      0.67      0.62         6
           2       0.50      0.78      0.61         9
           3       0.00      0.00      0.00         4

    accuracy                           0.59        34
   macro avg       0.44      0.51      0.47        34
weighted avg       0.54      0.59      0.55        34

rep type:  <class 'str'>
{'eval_loss': 1.5565199851989746, 'eval_f1': 0.4667343526039178, 'eval_runtime': 0.2634, 'eval_samples_per_second': 129.061, 'eval_steps_per_second': 22.775, 'epoch': 8.0}
{'train_runtime': 62.9142, 'train_samples_per_second': 17.675, 'train_steps_per_second': 3.052, 'train_loss': 0.4454421599706014, 'epoch': 8.0}


  0%|          | 0/6 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


report: 
               precision    recall  f1-score   support

           0       0.63      0.80      0.71        15
           1       0.75      0.50      0.60         6
           2       0.55      0.67      0.60         9
           3       0.00      0.00      0.00         4

    accuracy                           0.62        34
   macro avg       0.48      0.49      0.48        34
weighted avg       0.56      0.62      0.58        34

rep type:  <class 'str'>
F-1:  0.4764705882352941


In [19]:
holdout_df = df = pd.read_csv("data/belief_benchmark_holdout.csv")
holdout_ds = Dataset.from_pandas(holdout_df)
holdout_ds = holdout_ds.map(tokenize, batched=True)

trainer.predict(holdout_ds)

Map:   0%|          | 0/146 [00:00<?, ? examples/s]

  0%|          | 0/25 [00:00<?, ?it/s]

report: 
               precision    recall  f1-score   support

           0       0.53      0.80      0.63        49
           1       0.67      0.38      0.49        26
           2       0.65      0.63      0.64        59
           3       0.00      0.00      0.00        12

    accuracy                           0.59       146
   macro avg       0.46      0.45      0.44       146
weighted avg       0.56      0.59      0.56       146

rep type:  <class 'str'>


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


PredictionOutput(predictions=array([[ 3.12748075e+00, -1.11327565e+00, -7.41954386e-01,
        -1.59095049e+00],
       [ 2.39527154e+00, -1.87988591e+00,  1.11360741e+00,
        -1.95397151e+00],
       [-2.22187090e+00, -1.84547091e+00,  4.52050304e+00,
        -1.39935017e+00],
       [ 3.85241342e+00, -1.50445819e+00, -1.24845469e+00,
        -1.21730661e+00],
       [ 2.42610884e+00, -2.10561609e+00,  9.97956753e-01,
        -1.43785489e+00],
       [ 7.10558712e-01,  2.01579595e+00, -2.03670073e+00,
        -7.05665588e-01],
       [ 2.23595381e+00,  1.01800799e+00, -2.00976968e+00,
        -1.51437628e+00],
       [-1.18690062e+00, -2.12582541e+00,  4.07473803e+00,
        -1.74679625e+00],
       [ 2.16095924e+00, -1.74904287e+00,  1.31997478e+00,
        -1.46853054e+00],
       [-2.51119757e+00, -1.46990967e+00,  4.16050053e+00,
        -6.10300362e-01],
       [-7.23505914e-01, -2.18629956e+00,  2.73375130e+00,
        -1.09689020e-01],
       [ 3.44948816e+00,  2.81788707

In [20]:
import torch
torch.cuda.empty_cache()

In [21]:
torch.cuda.memory_summary(device=None, abbreviated=False)
output.write("{torch.cuda.memory_summary(device=None, abbreviated=False)}")

output.close()