In [1]:
from torch import nn
from transformers.modeling_outputs import SequenceClassifierOutput
from transformers.models.bert.modeling_bert import BertModel, BertPreTrainedModel

import torch
torch.cuda.empty_cache()

In [2]:
import pandas as pd
import numpy as np
import os

In [3]:
from transformers import AutoTokenizer
from datasets import Dataset, DatasetDict
from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer

In [4]:
from sklearn.model_selection import train_test_split
from sklearn import metrics
from sklearn.model_selection import KFold

In [5]:
models_dir = "models/bert-k-fold"

In [6]:
labels = open('data/classes.txt').read().splitlines()
df = pd.read_csv("data/belief_benchmark_all_train.csv")
df.head()

Unnamed: 0,text,sentiment,label
0,Private-sector actors expect their business pa...,NEGATIVE,1
1,Lack of product uniformity is perceived by con...,"NEGATIVE, NEGATIVE, NEGATIVE, NEGATIVE",1
2,"However , to the extent that the empirical str...",POSITIVE,2
3,Firms prefer to contract with producers with w...,POSITIVE,2
4,Farmers attempted to use alternate wetting and...,POSITIVE,2


In [7]:
transformer_name = "bert-base-cased"
tokenizer = AutoTokenizer.from_pretrained(transformer_name)
# NOTE: for cross validation, the model should be initialized inside the cv loop

In [8]:
def tokenize(batch):
    return tokenizer(batch['text'], truncation=True)

In [9]:
def compute_metrics(eval_pred):
    y_true = eval_pred.label_ids
    y_pred = np.argmax(eval_pred.predictions, axis=-1)
    report = metrics.classification_report(y_true, y_pred)
    print("report: \n", report)
    
    print("rep type: ", type(report))
    

    return {'f1':metrics.f1_score(y_true, y_pred, average="macro")}

In [10]:
class BertForSequenceClassification(BertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        self.init_weights()
        
    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, labels=None, **kwargs):
        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            **kwargs,
        )
        cls_outputs = outputs.last_hidden_state[:, 0, :]
        cls_outputs = self.dropout(cls_outputs)
        logits = self.classifier(cls_outputs)
        loss = None
        if labels is not None:
            loss_fn = nn.CrossEntropyLoss()
            loss = loss_fn(logits, labels)
        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


In [11]:
# this is for creating cross-validation folds
def get_sample_based_on_idx(data, indeces):
    return data.iloc[indeces, :].reset_index()

In [12]:
# defining hyperparams
num_epochs = 8
batch_size = 6
weight_decay = 0.01
print(f"num_epochs: {num_epochs}, batch_size: {batch_size}, weight_decay: {weight_decay}")
training_args = TrainingArguments(
    output_dir="./results_sentiment_analysis",
    log_level='error',
    num_train_epochs=num_epochs,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    evaluation_strategy='epoch',
    save_strategy='epoch',
    weight_decay=weight_decay,
    load_best_model_at_end=True, # this is supposed to make sure the best model is loaded by the trainer at the end
    metric_for_best_model="eval_f1" 
    )

num_epochs: 8, batch_size: 6, weight_decay: 0.01


In [13]:
# specify what percent of the all_train data should be used in cross-fold validation 
# we can use this to create a learning curve, e.g., performance with 50, 80, 90% of data
train_size = 0.8

if train_size < 1.0:
    df, _ = train_test_split(df, test_size=1.0-train_size, random_state=1, stratify=df[['label']])
print(f"train_size: {train_size} (total samples: {len(df)})")

train_size: 0.8 (total samples: 173)


In [14]:
output = open("original.txt", "a")

fold = 0
kfold = KFold(n_splits=5, shuffle=True, random_state=1)
for train_df_idx, eval_df_idx in kfold.split(df):
    
    print(f"************** BEGIN FOLD: {fold+1} **************")
    output.write(f"FOLD: {fold}\n")
    new_df = pd.DataFrame()
    
    train_df = get_sample_based_on_idx(df, train_df_idx)
    print("LEN DF: ", len(train_df))
    output.write(f"LEN DF: {len(train_df)}\n")

    print("done train df")
    output.write("done train df\n")
    eval_df = get_sample_based_on_idx(df, eval_df_idx)

    print("done eval df")
    output.write("done eval df\n")
    print("LEN EVAL: ", len(eval_df))
    output.write(f"LEN EVAL: {len(eval_df)}\n")

    ds = DatasetDict()
    ds['train'] = Dataset.from_pandas(train_df)
    ds['validation'] = Dataset.from_pandas(eval_df)
    train_ds = ds['train'].map(
        tokenize, batched=True
    )
    eval_ds = ds['validation'].map(
        tokenize,
        batched=True
    )

    model = AutoModelForSequenceClassification.from_pretrained(transformer_name, num_labels=4)
    tokenizer = AutoTokenizer.from_pretrained(transformer_name)

    trainer = Trainer(
        model=model,
        args=training_args,
        compute_metrics=compute_metrics,
        train_dataset=train_ds,
        eval_dataset=eval_ds,
        tokenizer=tokenizer,
    )
    trainer.train()
    # after training, predict (will use best model?)
    preds = trainer.predict(eval_ds)
#     print("HERE: " , preds)
    final_preds = [np.argmax(x) for x in preds.predictions]
    real_f1 = metrics.f1_score(final_preds, eval_df["label"], average="macro")
    print("F-1: ", real_f1)
    output.write(f"F-1: {real_f1}\n")
    model_name = f"{transformer_name}-best-of-fold-{fold}-f1-{real_f1}"
    model_dir = os.path.join(models_dir, model_name)

    trainer.save_model(model_dir)

    for i, item in enumerate(final_preds):
        if item != eval_ds["label"][i]:
            wrong_df = pd.DataFrame()
            wrong_df["text"] = [eval_df["text"][i]]
            wrong_df["real"] = [eval_df["label"][i]]
            wrong_df["predicted"] = [item]
            new_df = pd.concat([new_df, wrong_df])

    new_df.to_csv(f"{models_dir}/wrong_predictions_{fold}.csv")

    print(f"************** END FOLD: {fold+1} **************\n")
    fold += 1
        

************** BEGIN FOLD: 1 **************
LEN DF:  138
done train df
done eval df
LEN EVAL:  35


Map:   0%|          | 0/138 [00:00<?, ? examples/s]

Map:   0%|          | 0/35 [00:00<?, ? examples/s]

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.predictions.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at b

  0%|          | 0/184 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


report: 
               precision    recall  f1-score   support

           0       0.33      0.08      0.13        12
           1       0.00      0.00      0.00         9
           2       0.41      0.93      0.57        14

    accuracy                           0.40        35
   macro avg       0.25      0.34      0.23        35
weighted avg       0.28      0.40      0.27        35

rep type:  <class 'str'>
{'eval_loss': 1.1941088438034058, 'eval_f1': 0.2328502415458937, 'eval_runtime': 0.2585, 'eval_samples_per_second': 135.41, 'eval_steps_per_second': 23.213, 'epoch': 1.0}


  0%|          | 0/6 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


report: 
               precision    recall  f1-score   support

           0       0.50      0.25      0.33        12
           1       0.00      0.00      0.00         9
           2       0.45      0.93      0.60        14

    accuracy                           0.46        35
   macro avg       0.32      0.39      0.31        35
weighted avg       0.35      0.46      0.36        35

rep type:  <class 'str'>
{'eval_loss': 1.248270034790039, 'eval_f1': 0.31266149870801035, 'eval_runtime': 0.249, 'eval_samples_per_second': 140.577, 'eval_steps_per_second': 24.099, 'epoch': 2.0}


  0%|          | 0/6 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


report: 
               precision    recall  f1-score   support

           0       0.38      0.25      0.30        12
           1       1.00      0.11      0.20         9
           2       0.56      1.00      0.72        14
           3       0.00      0.00      0.00         0

    accuracy                           0.51        35
   macro avg       0.48      0.34      0.30        35
weighted avg       0.61      0.51      0.44        35

rep type:  <class 'str'>
{'eval_loss': 1.1952141523361206, 'eval_f1': 0.3044871794871795, 'eval_runtime': 0.2508, 'eval_samples_per_second': 139.573, 'eval_steps_per_second': 23.927, 'epoch': 3.0}


  0%|          | 0/6 [00:00<?, ?it/s]

report: 
               precision    recall  f1-score   support

           0       0.62      0.42      0.50        12
           1       0.80      0.89      0.84         9
           2       0.65      0.79      0.71        14

    accuracy                           0.69        35
   macro avg       0.69      0.70      0.68        35
weighted avg       0.68      0.69      0.67        35

rep type:  <class 'str'>
{'eval_loss': 1.0636564493179321, 'eval_f1': 0.6839275608375778, 'eval_runtime': 0.2473, 'eval_samples_per_second': 141.504, 'eval_steps_per_second': 24.258, 'epoch': 4.0}


  0%|          | 0/6 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


report: 
               precision    recall  f1-score   support

           0       0.40      0.50      0.44        12
           1       1.00      0.22      0.36         9
           2       0.67      0.71      0.69        14
           3       0.00      0.00      0.00         0

    accuracy                           0.51        35
   macro avg       0.52      0.36      0.37        35
weighted avg       0.66      0.51      0.52        35

rep type:  <class 'str'>
{'eval_loss': 1.2590973377227783, 'eval_f1': 0.3744339951236503, 'eval_runtime': 0.2433, 'eval_samples_per_second': 143.842, 'eval_steps_per_second': 24.659, 'epoch': 5.0}


  0%|          | 0/6 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


report: 
               precision    recall  f1-score   support

           0       0.67      0.50      0.57        12
           1       0.78      0.78      0.78         9
           2       0.64      0.64      0.64        14
           3       0.00      0.00      0.00         0

    accuracy                           0.63        35
   macro avg       0.52      0.48      0.50        35
weighted avg       0.69      0.63      0.65        35

rep type:  <class 'str'>
{'eval_loss': 1.181607961654663, 'eval_f1': 0.498015873015873, 'eval_runtime': 0.2519, 'eval_samples_per_second': 138.937, 'eval_steps_per_second': 23.818, 'epoch': 6.0}


  0%|          | 0/6 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


report: 
               precision    recall  f1-score   support

           0       0.55      0.50      0.52        12
           1       0.80      0.44      0.57         9
           2       0.67      0.71      0.69        14
           3       0.00      0.00      0.00         0

    accuracy                           0.57        35
   macro avg       0.50      0.41      0.45        35
weighted avg       0.66      0.57      0.60        35

rep type:  <class 'str'>
{'eval_loss': 1.4113693237304688, 'eval_f1': 0.44570571856928676, 'eval_runtime': 0.2551, 'eval_samples_per_second': 137.219, 'eval_steps_per_second': 23.523, 'epoch': 7.0}


  0%|          | 0/6 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


report: 
               precision    recall  f1-score   support

           0       0.60      0.50      0.55        12
           1       0.83      0.56      0.67         9
           2       0.67      0.71      0.69        14
           3       0.00      0.00      0.00         0

    accuracy                           0.60        35
   macro avg       0.53      0.44      0.48        35
weighted avg       0.69      0.60      0.63        35

rep type:  <class 'str'>
{'eval_loss': 1.3613492250442505, 'eval_f1': 0.47544409613375127, 'eval_runtime': 0.2671, 'eval_samples_per_second': 131.032, 'eval_steps_per_second': 22.463, 'epoch': 8.0}
{'train_runtime': 64.2664, 'train_samples_per_second': 17.179, 'train_steps_per_second': 2.863, 'train_loss': 0.5774908895077913, 'epoch': 8.0}


  0%|          | 0/6 [00:00<?, ?it/s]

report: 
               precision    recall  f1-score   support

           0       0.62      0.42      0.50        12
           1       0.80      0.89      0.84         9
           2       0.65      0.79      0.71        14

    accuracy                           0.69        35
   macro avg       0.69      0.70      0.68        35
weighted avg       0.68      0.69      0.67        35

rep type:  <class 'str'>
F-1:  0.6839275608375778
************** END FOLD: 1 **************

************** BEGIN FOLD: 2 **************
LEN DF:  138
done train df
done eval df
LEN EVAL:  35


Map:   0%|          | 0/138 [00:00<?, ? examples/s]

Map:   0%|          | 0/35 [00:00<?, ? examples/s]



  0%|          | 0/184 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


report: 
               precision    recall  f1-score   support

           0       0.23      1.00      0.37         8
           1       0.00      0.00      0.00         6
           2       0.00      0.00      0.00        15
           3       0.00      0.00      0.00         6

    accuracy                           0.23        35
   macro avg       0.06      0.25      0.09        35
weighted avg       0.05      0.23      0.09        35

rep type:  <class 'str'>
{'eval_loss': 1.4057917594909668, 'eval_f1': 0.09302325581395347, 'eval_runtime': 0.2503, 'eval_samples_per_second': 139.851, 'eval_steps_per_second': 23.974, 'epoch': 1.0}


  0%|          | 0/6 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


report: 
               precision    recall  f1-score   support

           0       0.23      1.00      0.37         8
           1       0.00      0.00      0.00         6
           2       0.00      0.00      0.00        15
           3       0.00      0.00      0.00         6

    accuracy                           0.23        35
   macro avg       0.06      0.25      0.09        35
weighted avg       0.05      0.23      0.09        35

rep type:  <class 'str'>
{'eval_loss': 1.4484484195709229, 'eval_f1': 0.09302325581395347, 'eval_runtime': 0.2535, 'eval_samples_per_second': 138.063, 'eval_steps_per_second': 23.668, 'epoch': 2.0}


  0%|          | 0/6 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


report: 
               precision    recall  f1-score   support

           0       1.00      0.38      0.55         8
           1       0.00      0.00      0.00         6
           2       0.47      1.00      0.64        15
           3       0.00      0.00      0.00         6

    accuracy                           0.51        35
   macro avg       0.37      0.34      0.30        35
weighted avg       0.43      0.51      0.40        35

rep type:  <class 'str'>
{'eval_loss': 1.3577896356582642, 'eval_f1': 0.29593810444874274, 'eval_runtime': 0.2529, 'eval_samples_per_second': 138.39, 'eval_steps_per_second': 23.724, 'epoch': 3.0}


  0%|          | 0/6 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


report: 
               precision    recall  f1-score   support

           0       0.00      0.00      0.00         8
           1       0.00      0.00      0.00         6
           2       0.44      0.93      0.60        15
           3       0.00      0.00      0.00         6

    accuracy                           0.40        35
   macro avg       0.11      0.23      0.15        35
weighted avg       0.19      0.40      0.26        35

rep type:  <class 'str'>
{'eval_loss': 1.3718316555023193, 'eval_f1': 0.14893617021276595, 'eval_runtime': 0.2567, 'eval_samples_per_second': 136.335, 'eval_steps_per_second': 23.372, 'epoch': 4.0}


  0%|          | 0/6 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


report: 
               precision    recall  f1-score   support

           0       0.28      0.88      0.42         8
           1       0.33      0.17      0.22         6
           2       0.86      0.40      0.55        15
           3       0.00      0.00      0.00         6

    accuracy                           0.40        35
   macro avg       0.37      0.36      0.30        35
weighted avg       0.49      0.40      0.37        35

rep type:  <class 'str'>
{'eval_loss': 1.1904164552688599, 'eval_f1': 0.297979797979798, 'eval_runtime': 0.2558, 'eval_samples_per_second': 136.829, 'eval_steps_per_second': 23.456, 'epoch': 5.0}


  0%|          | 0/6 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


report: 
               precision    recall  f1-score   support

           0       0.50      0.62      0.56         8
           1       0.40      0.33      0.36         6
           2       0.60      0.80      0.69        15
           3       0.00      0.00      0.00         6

    accuracy                           0.54        35
   macro avg       0.38      0.44      0.40        35
weighted avg       0.44      0.54      0.48        35

rep type:  <class 'str'>
{'eval_loss': 1.2217886447906494, 'eval_f1': 0.4012265512265512, 'eval_runtime': 0.2577, 'eval_samples_per_second': 135.809, 'eval_steps_per_second': 23.281, 'epoch': 6.0}


  0%|          | 0/6 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


report: 
               precision    recall  f1-score   support

           0       0.50      1.00      0.67         8
           1       0.50      0.50      0.50         6
           2       0.85      0.73      0.79        15
           3       0.00      0.00      0.00         6

    accuracy                           0.63        35
   macro avg       0.46      0.56      0.49        35
weighted avg       0.56      0.63      0.57        35

rep type:  <class 'str'>
{'eval_loss': 1.1068669557571411, 'eval_f1': 0.488095238095238, 'eval_runtime': 0.2565, 'eval_samples_per_second': 136.431, 'eval_steps_per_second': 23.388, 'epoch': 7.0}


  0%|          | 0/6 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


report: 
               precision    recall  f1-score   support

           0       0.57      1.00      0.73         8
           1       0.40      0.33      0.36         6
           2       0.75      0.80      0.77        15
           3       0.00      0.00      0.00         6

    accuracy                           0.63        35
   macro avg       0.43      0.53      0.47        35
weighted avg       0.52      0.63      0.56        35

rep type:  <class 'str'>
{'eval_loss': 1.1687743663787842, 'eval_f1': 0.46627565982404695, 'eval_runtime': 0.2662, 'eval_samples_per_second': 131.504, 'eval_steps_per_second': 22.543, 'epoch': 8.0}
{'train_runtime': 66.8361, 'train_samples_per_second': 16.518, 'train_steps_per_second': 2.753, 'train_loss': 0.8878276659094769, 'epoch': 8.0}


  0%|          | 0/6 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


report: 
               precision    recall  f1-score   support

           0       0.50      1.00      0.67         8
           1       0.50      0.50      0.50         6
           2       0.85      0.73      0.79        15
           3       0.00      0.00      0.00         6

    accuracy                           0.63        35
   macro avg       0.46      0.56      0.49        35
weighted avg       0.56      0.63      0.57        35

rep type:  <class 'str'>
F-1:  0.488095238095238
************** END FOLD: 2 **************

************** BEGIN FOLD: 3 **************
LEN DF:  138
done train df
done eval df
LEN EVAL:  35


Map:   0%|          | 0/138 [00:00<?, ? examples/s]

Map:   0%|          | 0/35 [00:00<?, ? examples/s]



  0%|          | 0/184 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


report: 
               precision    recall  f1-score   support

           0       0.00      0.00      0.00        14
           1       0.00      0.00      0.00         5
           2       0.43      1.00      0.60        15
           3       0.00      0.00      0.00         1

    accuracy                           0.43        35
   macro avg       0.11      0.25      0.15        35
weighted avg       0.18      0.43      0.26        35

rep type:  <class 'str'>
{'eval_loss': 1.1925339698791504, 'eval_f1': 0.15, 'eval_runtime': 0.2675, 'eval_samples_per_second': 130.864, 'eval_steps_per_second': 22.434, 'epoch': 1.0}


  0%|          | 0/6 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


report: 
               precision    recall  f1-score   support

           0       0.50      0.07      0.12        14
           1       0.00      0.00      0.00         5
           2       0.42      0.93      0.58        15
           3       0.00      0.00      0.00         1

    accuracy                           0.43        35
   macro avg       0.23      0.25      0.18        35
weighted avg       0.38      0.43      0.30        35

rep type:  <class 'str'>
{'eval_loss': 1.171134114265442, 'eval_f1': 0.17708333333333331, 'eval_runtime': 0.2693, 'eval_samples_per_second': 129.973, 'eval_steps_per_second': 22.281, 'epoch': 2.0}


  0%|          | 0/6 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


report: 
               precision    recall  f1-score   support

           0       0.00      0.00      0.00        14
           1       0.00      0.00      0.00         5
           2       0.43      1.00      0.60        15
           3       0.00      0.00      0.00         1

    accuracy                           0.43        35
   macro avg       0.11      0.25      0.15        35
weighted avg       0.18      0.43      0.26        35

rep type:  <class 'str'>
{'eval_loss': 1.2608928680419922, 'eval_f1': 0.15, 'eval_runtime': 0.2615, 'eval_samples_per_second': 133.848, 'eval_steps_per_second': 22.945, 'epoch': 3.0}


  0%|          | 0/6 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


report: 
               precision    recall  f1-score   support

           0       0.50      0.07      0.12        14
           1       0.00      0.00      0.00         5
           2       0.42      0.93      0.58        15
           3       0.00      0.00      0.00         1

    accuracy                           0.43        35
   macro avg       0.23      0.25      0.18        35
weighted avg       0.38      0.43      0.30        35

rep type:  <class 'str'>
{'eval_loss': 1.3489395380020142, 'eval_f1': 0.17708333333333331, 'eval_runtime': 0.2649, 'eval_samples_per_second': 132.142, 'eval_steps_per_second': 22.653, 'epoch': 4.0}


  0%|          | 0/6 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


report: 
               precision    recall  f1-score   support

           0       1.00      0.07      0.13        14
           1       0.00      0.00      0.00         5
           2       0.47      1.00      0.64        15
           3       0.00      0.00      0.00         1

    accuracy                           0.46        35
   macro avg       0.37      0.27      0.19        35
weighted avg       0.60      0.46      0.33        35

rep type:  <class 'str'>
{'eval_loss': 1.3954232931137085, 'eval_f1': 0.19290780141843972, 'eval_runtime': 0.2669, 'eval_samples_per_second': 131.12, 'eval_steps_per_second': 22.478, 'epoch': 5.0}


  0%|          | 0/6 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


report: 
               precision    recall  f1-score   support

           0       0.78      0.50      0.61        14
           1       0.29      0.40      0.33         5
           2       0.58      0.73      0.65        15
           3       0.00      0.00      0.00         1

    accuracy                           0.57        35
   macro avg       0.41      0.41      0.40        35
weighted avg       0.60      0.57      0.57        35

rep type:  <class 'str'>
{'eval_loss': 1.0531623363494873, 'eval_f1': 0.3972719522591645, 'eval_runtime': 0.278, 'eval_samples_per_second': 125.912, 'eval_steps_per_second': 21.585, 'epoch': 6.0}


  0%|          | 0/6 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


report: 
               precision    recall  f1-score   support

           0       0.75      0.21      0.33        14
           1       0.15      0.40      0.22         5
           2       0.61      0.73      0.67        15
           3       0.00      0.00      0.00         1

    accuracy                           0.46        35
   macro avg       0.38      0.34      0.31        35
weighted avg       0.58      0.46      0.45        35

rep type:  <class 'str'>
{'eval_loss': 1.2484222650527954, 'eval_f1': 0.3055555555555556, 'eval_runtime': 0.2615, 'eval_samples_per_second': 133.856, 'eval_steps_per_second': 22.947, 'epoch': 7.0}


  0%|          | 0/6 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


report: 
               precision    recall  f1-score   support

           0       0.62      0.36      0.45        14
           1       0.33      0.40      0.36         5
           2       0.52      0.73      0.61        15
           3       0.00      0.00      0.00         1

    accuracy                           0.51        35
   macro avg       0.37      0.37      0.36        35
weighted avg       0.52      0.51      0.50        35

rep type:  <class 'str'>
{'eval_loss': 1.1822458505630493, 'eval_f1': 0.35732323232323226, 'eval_runtime': 0.2736, 'eval_samples_per_second': 127.943, 'eval_steps_per_second': 21.933, 'epoch': 8.0}
{'train_runtime': 64.9175, 'train_samples_per_second': 17.006, 'train_steps_per_second': 2.834, 'train_loss': 0.7985156515370244, 'epoch': 8.0}


  0%|          | 0/6 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


report: 
               precision    recall  f1-score   support

           0       0.78      0.50      0.61        14
           1       0.29      0.40      0.33         5
           2       0.58      0.73      0.65        15
           3       0.00      0.00      0.00         1

    accuracy                           0.57        35
   macro avg       0.41      0.41      0.40        35
weighted avg       0.60      0.57      0.57        35

rep type:  <class 'str'>
F-1:  0.3972719522591645
************** END FOLD: 3 **************

************** BEGIN FOLD: 4 **************
LEN DF:  139
done train df
done eval df
LEN EVAL:  34


Map:   0%|          | 0/139 [00:00<?, ? examples/s]

Map:   0%|          | 0/34 [00:00<?, ? examples/s]



  0%|          | 0/192 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


report: 
               precision    recall  f1-score   support

           0       0.00      0.00      0.00         9
           1       0.00      0.00      0.00         5
           2       0.52      1.00      0.68        17
           3       0.00      0.00      0.00         3

    accuracy                           0.50        34
   macro avg       0.13      0.25      0.17        34
weighted avg       0.26      0.50      0.34        34

rep type:  <class 'str'>
{'eval_loss': 1.2123949527740479, 'eval_f1': 0.16999999999999998, 'eval_runtime': 0.2603, 'eval_samples_per_second': 130.628, 'eval_steps_per_second': 23.052, 'epoch': 1.0}


  0%|          | 0/6 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


report: 
               precision    recall  f1-score   support

           0       1.00      0.22      0.36         9
           1       0.00      0.00      0.00         5
           2       0.53      1.00      0.69        17
           3       0.00      0.00      0.00         3

    accuracy                           0.56        34
   macro avg       0.38      0.31      0.26        34
weighted avg       0.53      0.56      0.44        34

rep type:  <class 'str'>
{'eval_loss': 1.1285600662231445, 'eval_f1': 0.26437847866419295, 'eval_runtime': 0.2711, 'eval_samples_per_second': 125.402, 'eval_steps_per_second': 22.13, 'epoch': 2.0}


  0%|          | 0/6 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


report: 
               precision    recall  f1-score   support

           0       0.50      0.22      0.31         9
           1       1.00      0.40      0.57         5
           2       0.57      0.94      0.71        17
           3       0.00      0.00      0.00         3

    accuracy                           0.59        34
   macro avg       0.52      0.39      0.40        34
weighted avg       0.57      0.59      0.52        34

rep type:  <class 'str'>
{'eval_loss': 1.1088975667953491, 'eval_f1': 0.3975579975579976, 'eval_runtime': 0.2654, 'eval_samples_per_second': 128.103, 'eval_steps_per_second': 22.606, 'epoch': 3.0}


  0%|          | 0/6 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


report: 
               precision    recall  f1-score   support

           0       0.60      0.33      0.43         9
           1       0.33      0.80      0.47         5
           2       0.71      0.71      0.71        17
           3       0.00      0.00      0.00         3

    accuracy                           0.56        34
   macro avg       0.41      0.46      0.40        34
weighted avg       0.56      0.56      0.54        34

rep type:  <class 'str'>
{'eval_loss': 1.3464467525482178, 'eval_f1': 0.4012605042016807, 'eval_runtime': 0.2625, 'eval_samples_per_second': 129.509, 'eval_steps_per_second': 22.854, 'epoch': 4.0}


  0%|          | 0/6 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


report: 
               precision    recall  f1-score   support

           0       0.33      0.22      0.27         9
           1       0.50      0.80      0.62         5
           2       0.75      0.88      0.81        17
           3       0.00      0.00      0.00         3

    accuracy                           0.62        34
   macro avg       0.40      0.48      0.42        34
weighted avg       0.54      0.62      0.57        34

rep type:  <class 'str'>
{'eval_loss': 1.4763931035995483, 'eval_f1': 0.4232155232155232, 'eval_runtime': 0.2642, 'eval_samples_per_second': 128.688, 'eval_steps_per_second': 22.71, 'epoch': 5.0}


  0%|          | 0/6 [00:00<?, ?it/s]

report: 
               precision    recall  f1-score   support

           0       0.30      0.33      0.32         9
           1       0.43      0.60      0.50         5
           2       0.81      0.76      0.79        17
           3       0.00      0.00      0.00         3

    accuracy                           0.56        34
   macro avg       0.39      0.42      0.40        34
weighted avg       0.55      0.56      0.55        34

rep type:  <class 'str'>
{'eval_loss': 1.6467573642730713, 'eval_f1': 0.4009170653907496, 'eval_runtime': 0.2585, 'eval_samples_per_second': 131.54, 'eval_steps_per_second': 23.213, 'epoch': 6.0}


  0%|          | 0/6 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


report: 
               precision    recall  f1-score   support

           0       0.25      0.22      0.24         9
           1       0.50      0.80      0.62         5
           2       0.78      0.82      0.80        17
           3       0.00      0.00      0.00         3

    accuracy                           0.59        34
   macro avg       0.38      0.46      0.41        34
weighted avg       0.53      0.59      0.55        34

rep type:  <class 'str'>
{'eval_loss': 1.7017351388931274, 'eval_f1': 0.41266968325791853, 'eval_runtime': 0.2572, 'eval_samples_per_second': 132.171, 'eval_steps_per_second': 23.324, 'epoch': 7.0}


  0%|          | 0/6 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


report: 
               precision    recall  f1-score   support

           0       0.33      0.22      0.27         9
           1       0.50      0.80      0.62         5
           2       0.75      0.88      0.81        17
           3       0.00      0.00      0.00         3

    accuracy                           0.62        34
   macro avg       0.40      0.48      0.42        34
weighted avg       0.54      0.62      0.57        34

rep type:  <class 'str'>
{'eval_loss': 1.747233271598816, 'eval_f1': 0.4232155232155232, 'eval_runtime': 0.256, 'eval_samples_per_second': 132.813, 'eval_steps_per_second': 23.438, 'epoch': 8.0}
{'train_runtime': 66.3877, 'train_samples_per_second': 16.75, 'train_steps_per_second': 2.892, 'train_loss': 0.4552481174468994, 'epoch': 8.0}


  0%|          | 0/6 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


report: 
               precision    recall  f1-score   support

           0       0.33      0.22      0.27         9
           1       0.50      0.80      0.62         5
           2       0.75      0.88      0.81        17
           3       0.00      0.00      0.00         3

    accuracy                           0.62        34
   macro avg       0.40      0.48      0.42        34
weighted avg       0.54      0.62      0.57        34

rep type:  <class 'str'>
F-1:  0.4232155232155232
************** END FOLD: 4 **************

************** BEGIN FOLD: 5 **************
LEN DF:  139
done train df
done eval df
LEN EVAL:  34


Map:   0%|          | 0/139 [00:00<?, ? examples/s]

Map:   0%|          | 0/34 [00:00<?, ? examples/s]



  0%|          | 0/192 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


report: 
               precision    recall  f1-score   support

           0       0.00      0.00      0.00        15
           1       0.00      0.00      0.00         6
           2       0.26      1.00      0.42         9
           3       0.00      0.00      0.00         4

    accuracy                           0.26        34
   macro avg       0.07      0.25      0.10        34
weighted avg       0.07      0.26      0.11        34

rep type:  <class 'str'>
{'eval_loss': 1.361945629119873, 'eval_f1': 0.10465116279069768, 'eval_runtime': 0.2636, 'eval_samples_per_second': 128.991, 'eval_steps_per_second': 22.763, 'epoch': 1.0}


  0%|          | 0/6 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


report: 
               precision    recall  f1-score   support

           0       0.50      0.87      0.63        15
           1       0.00      0.00      0.00         6
           2       0.62      0.56      0.59         9
           3       0.00      0.00      0.00         4

    accuracy                           0.53        34
   macro avg       0.28      0.36      0.31        34
weighted avg       0.39      0.53      0.44        34

rep type:  <class 'str'>
{'eval_loss': 1.1791257858276367, 'eval_f1': 0.30559540889526543, 'eval_runtime': 0.2479, 'eval_samples_per_second': 137.125, 'eval_steps_per_second': 24.199, 'epoch': 2.0}


  0%|          | 0/6 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


report: 
               precision    recall  f1-score   support

           0       0.57      0.27      0.36        15
           1       0.57      0.67      0.62         6
           2       0.40      0.89      0.55         9
           3       0.00      0.00      0.00         4

    accuracy                           0.47        34
   macro avg       0.39      0.46      0.38        34
weighted avg       0.46      0.47      0.42        34

rep type:  <class 'str'>
{'eval_loss': 1.4302103519439697, 'eval_f1': 0.38268627923800336, 'eval_runtime': 0.2774, 'eval_samples_per_second': 122.581, 'eval_steps_per_second': 21.632, 'epoch': 3.0}


  0%|          | 0/6 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


report: 
               precision    recall  f1-score   support

           0       0.56      0.67      0.61        15
           1       0.33      0.17      0.22         6
           2       0.46      0.67      0.55         9
           3       0.00      0.00      0.00         4

    accuracy                           0.50        34
   macro avg       0.34      0.38      0.34        34
weighted avg       0.43      0.50      0.45        34

rep type:  <class 'str'>
{'eval_loss': 1.457134485244751, 'eval_f1': 0.3434343434343434, 'eval_runtime': 0.2408, 'eval_samples_per_second': 141.17, 'eval_steps_per_second': 24.912, 'epoch': 4.0}


  0%|          | 0/6 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


report: 
               precision    recall  f1-score   support

           0       0.63      0.80      0.71        15
           1       0.75      0.50      0.60         6
           2       0.55      0.67      0.60         9
           3       0.00      0.00      0.00         4

    accuracy                           0.62        34
   macro avg       0.48      0.49      0.48        34
weighted avg       0.56      0.62      0.58        34

rep type:  <class 'str'>
{'eval_loss': 1.5320632457733154, 'eval_f1': 0.4764705882352941, 'eval_runtime': 0.2565, 'eval_samples_per_second': 132.575, 'eval_steps_per_second': 23.396, 'epoch': 5.0}


  0%|          | 0/6 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


report: 
               precision    recall  f1-score   support

           0       0.83      0.33      0.48        15
           1       0.67      1.00      0.80         6
           2       0.42      0.89      0.57         9
           3       0.00      0.00      0.00         4

    accuracy                           0.56        34
   macro avg       0.48      0.56      0.46        34
weighted avg       0.60      0.56      0.50        34

rep type:  <class 'str'>
{'eval_loss': 2.1215927600860596, 'eval_f1': 0.4619047619047619, 'eval_runtime': 0.2472, 'eval_samples_per_second': 137.544, 'eval_steps_per_second': 24.272, 'epoch': 6.0}


  0%|          | 0/6 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


report: 
               precision    recall  f1-score   support

           0       0.67      0.53      0.59        15
           1       0.57      0.67      0.62         6
           2       0.47      0.78      0.58         9
           3       0.00      0.00      0.00         4

    accuracy                           0.56        34
   macro avg       0.43      0.49      0.45        34
weighted avg       0.52      0.56      0.52        34

rep type:  <class 'str'>
{'eval_loss': 1.5521236658096313, 'eval_f1': 0.44782763532763536, 'eval_runtime': 0.271, 'eval_samples_per_second': 125.462, 'eval_steps_per_second': 22.14, 'epoch': 7.0}


  0%|          | 0/6 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


report: 
               precision    recall  f1-score   support

           0       0.69      0.60      0.64        15
           1       0.57      0.67      0.62         6
           2       0.50      0.78      0.61         9
           3       0.00      0.00      0.00         4

    accuracy                           0.59        34
   macro avg       0.44      0.51      0.47        34
weighted avg       0.54      0.59      0.55        34

rep type:  <class 'str'>
{'eval_loss': 1.5565199851989746, 'eval_f1': 0.4667343526039178, 'eval_runtime': 0.241, 'eval_samples_per_second': 141.078, 'eval_steps_per_second': 24.896, 'epoch': 8.0}
{'train_runtime': 67.4981, 'train_samples_per_second': 16.475, 'train_steps_per_second': 2.845, 'train_loss': 0.4454421599706014, 'epoch': 8.0}


  0%|          | 0/6 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


report: 
               precision    recall  f1-score   support

           0       0.63      0.80      0.71        15
           1       0.75      0.50      0.60         6
           2       0.55      0.67      0.60         9
           3       0.00      0.00      0.00         4

    accuracy                           0.62        34
   macro avg       0.48      0.49      0.48        34
weighted avg       0.56      0.62      0.58        34

rep type:  <class 'str'>
F-1:  0.4764705882352941
************** END FOLD: 5 **************



In [15]:
print("\n******************* holdout results ******************* ")
holdout_df = df = pd.read_csv("data/belief_benchmark_holdout.csv")
holdout_ds = Dataset.from_pandas(holdout_df)
holdout_ds = holdout_ds.map(tokenize, batched=True)

preds = trainer.predict(holdout_ds)


******************* holdout results ******************* 


Map:   0%|          | 0/146 [00:00<?, ? examples/s]

  0%|          | 0/25 [00:00<?, ?it/s]

report: 
               precision    recall  f1-score   support

           0       0.53      0.80      0.63        49
           1       0.67      0.38      0.49        26
           2       0.65      0.63      0.64        59
           3       0.00      0.00      0.00        12

    accuracy                           0.59       146
   macro avg       0.46      0.45      0.44       146
weighted avg       0.56      0.59      0.56       146

rep type:  <class 'str'>


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


In [16]:
final_preds = [np.argmax(x) for x in preds.predictions]
real_f1 = metrics.f1_score(final_preds, holdout_df["label"], average="macro")
print("F-1: ", real_f1)
y_pred = []
for i, item in enumerate(final_preds):
    y_pred.append(item)

y_true = holdout_ds["label"]


F-1:  0.4399705634987384


In [17]:
print(f"y_true: {y_true}")
print(f"y_pred: {y_pred}")
import pickle
pickle.dump(y_pred, open(f"few_shot_results/y_pred_{int(train_size*100)}", "wb"))
pickle.dump(y_true, open(f"few_shot_results/y_true_{int(train_size*100)}", "wb"))

y_true: [0, 2, 2, 0, 0, 1, 1, 2, 2, 2, 2, 1, 0, 1, 2, 0, 0, 2, 0, 2, 2, 2, 2, 0, 1, 2, 1, 0, 2, 0, 0, 3, 3, 0, 2, 2, 1, 1, 0, 1, 1, 2, 0, 2, 2, 2, 1, 2, 0, 3, 2, 2, 1, 0, 1, 1, 0, 0, 0, 3, 0, 0, 0, 2, 2, 2, 2, 1, 1, 1, 0, 2, 0, 3, 2, 2, 3, 0, 0, 2, 0, 2, 2, 2, 2, 1, 1, 0, 0, 2, 3, 1, 0, 2, 0, 0, 2, 1, 3, 2, 2, 1, 2, 3, 2, 0, 2, 1, 1, 0, 2, 0, 0, 2, 0, 0, 0, 0, 0, 0, 2, 1, 2, 2, 2, 2, 0, 2, 3, 2, 0, 2, 2, 2, 0, 3, 3, 2, 0, 2, 1, 2, 0, 0, 0, 2]
y_pred: [0, 0, 2, 0, 0, 1, 0, 2, 0, 2, 2, 0, 0, 1, 0, 0, 0, 2, 0, 1, 0, 2, 0, 0, 1, 2, 1, 0, 0, 0, 0, 0, 2, 0, 2, 0, 2, 1, 0, 1, 0, 0, 2, 0, 2, 2, 0, 2, 2, 2, 2, 2, 0, 0, 0, 1, 0, 0, 2, 2, 0, 2, 0, 0, 2, 2, 2, 1, 1, 2, 2, 2, 0, 1, 2, 2, 2, 2, 1, 0, 0, 2, 2, 0, 0, 0, 2, 0, 0, 0, 2, 0, 0, 2, 0, 0, 2, 0, 0, 2, 2, 2, 0, 2, 2, 1, 2, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 2, 2, 0, 2, 0, 2, 1, 2, 2, 2, 0, 0, 0, 0, 2, 0, 0, 2, 1, 2, 2, 0, 0, 2]


In [18]:
torch.cuda.empty_cache()
torch.cuda.memory_summary(device=None, abbreviated=False)
output.write("{torch.cuda.memory_summary(device=None, abbreviated=False)}")

output.close()