In [1]:
!pip install pandas==1.3.4
!pip install transformers==4.12.5
!pip install datasets==1.15.1
!pip install ipywidgets

Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com
Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com
Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com
Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com


In [2]:
import os
import pickle

from collections import Counter

# import pandas as pd
from sklearn.metrics import classification_report

import numpy as np
import torch
import torch.nn as nn

import transformers
from transformers import Trainer
from transformers import BertTokenizer
from transformers import BertForSequenceClassification
from transformers import Trainer, TrainingArguments
from transformers.data.data_collator import DataCollatorWithPadding

import datasets
from datasets import Dataset
from datasets import ClassLabel
from datasets import load_metric

In [3]:
dataset_mlc = torch.load(os.path.join("/notebooks/MTL_tasks/data", 'pe_dataset_for_MLC_w_combined.pt'))

In [4]:
dataset_mlc

DatasetDict({
    train: Dataset({
        features: ['essay_nr', 'component_id', 'label_and_comp_idxs', 'text', 'label_ComponentType', 'relation_SupportAttack', 'label_RelationType', 'label_LinkedNotLinked', 'split_y', 'essay', 'argument_bound_1', 'argument_bound_2', 'argument_id', 'sentence', 'paragraph', 'para_nr', 'total_paras', 'token_count', 'token_count_covering_para', 'tokens_count_covering_sentence', 'preceeding_tokens_in_sentence_count', 'succeeding_tokens_in_sentence_count', 'token_ratio', 'relative_position_in_para_char', 'is_in_intro', 'relative_position_in_para_token', 'is_in_conclusion', 'is_first_in_para', 'is_last_in_para', 'nr_preceeding_comps_in_para', 'nr_following_comps_in_para', 'structural_fts_as_text', 'structural_fts_as_text_combined', 'mc', 'cl', 'prem', 'link_present', 'support', 'attack'],
        num_rows: 3770
    })
    test: Dataset({
        features: ['essay_nr', 'component_id', 'label_and_comp_idxs', 'text', 'label_ComponentType', 'relation_SupportAtt

In [5]:
# what do we need: 'text', 'structural_fts_as_text', 'structural_fts_as_text_combined' 'mc', 'cl', 'prem', 'link'
# that's enough

In [6]:
cols_not_needed = ['essay_nr', 'component_id', 'label_and_comp_idxs', 'label_ComponentType', 'relation_SupportAttack', 'label_RelationType', 'label_LinkedNotLinked', 'split_y', 'essay', 'argument_bound_1', 'argument_bound_2', 'argument_id', 'sentence', 'paragraph', 'para_nr', 'total_paras', 'token_count', 'token_count_covering_para', 'tokens_count_covering_sentence', 'preceeding_tokens_in_sentence_count', 'succeeding_tokens_in_sentence_count', 'token_ratio', 'relative_position_in_para_char', 'is_in_intro', 'relative_position_in_para_token', 'is_in_conclusion', 'is_first_in_para', 'is_last_in_para', 'nr_preceeding_comps_in_para', 'nr_following_comps_in_para']

In [7]:
dataset_mlc = dataset_mlc.remove_columns(cols_not_needed)

In [8]:
dataset_mlc = dataset_mlc.remove_columns(['support', 'attack'])

In [9]:
dataset_mlc

DatasetDict({
    train: Dataset({
        features: ['text', 'structural_fts_as_text', 'structural_fts_as_text_combined', 'mc', 'cl', 'prem', 'link_present'],
        num_rows: 3770
    })
    test: Dataset({
        features: ['text', 'structural_fts_as_text', 'structural_fts_as_text_combined', 'mc', 'cl', 'prem', 'link_present'],
        num_rows: 1260
    })
    validation: Dataset({
        features: ['text', 'structural_fts_as_text', 'structural_fts_as_text_combined', 'mc', 'cl', 'prem', 'link_present'],
        num_rows: 943
    })
})

In [10]:
dataset_mlc['train'][100]

{'text': 'we cannot ignore its negative effects',
 'structural_fts_as_text': 'Topic: Advantages and disadvantages of machines instead of human to do the work, Sentence: In my opinion, although using machines have many benefits, we cannot ignore its negative effects., Para Number: 1, First in Para: No, Last in Para: Yes, Is in Introduction: Yes, Is in Conclusion: No',
 'structural_fts_as_text_combined': 'Topic: Advantages and disadvantages of machines instead of human to do the work, Sentence: In my opinion, although using machines have many benefits, we cannot ignore its negative effects., Structural Features: 1, No, Yes, Yes, No',
 'mc': 1,
 'cl': 0,
 'prem': 0,
 'link_present': 1}

In [11]:
labels = [label for label in dataset_mlc['train'].features.keys() if label not in ['text', 'structural_fts_as_text', 'structural_fts_as_text_combined']]
id2label = {idx:label for idx, label in enumerate(labels)}
label2id = {label:idx for idx, label in enumerate(labels)}
labels

['mc', 'cl', 'prem', 'link_present']

In [12]:
import transformers
#from transformers import Trainer
from transformers import AutoTokenizer
#from transformers import BertForSequenceClassification
#from transformers import Trainer, TrainingArguments
#from transformers.data.data_collator import DataCollatorWithPadding

In [13]:
#from transformers import AutoTokenizer
import numpy as np

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

def preprocess_data(examples):
    # take a batch of texts
    text = examples["structural_fts_as_text_combined"]
    # encode them
    encoding = tokenizer(text, padding="max_length", truncation=True, max_length=256)
    # add labels
    labels_batch = {k: examples[k] for k in examples.keys() if k in labels}
    # create numpy array of shape (batch_size, num_labels)
    labels_matrix = np.zeros((len(text), len(labels)))
    # fill numpy array
    for idx, label in enumerate(labels):
        labels_matrix[:, idx] = labels_batch[label]

        encoding["labels"] = labels_matrix.tolist()

    return encoding

In [14]:
encoded_dataset = dataset_mlc.map(preprocess_data, batched=True, remove_columns=dataset_mlc['train'].column_names)



  0%|          | 0/4 [00:00<?, ?ba/s]

  0%|          | 0/2 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

In [15]:
example = encoded_dataset['train'][0]
print(example.keys())

dict_keys(['attention_mask', 'input_ids', 'labels', 'token_type_ids'])


In [16]:
tokenizer.decode(example['input_ids'])

'[CLS] topic : gender equality at university admission, sentence : therefore, universities follow the requirement of job providers and decide subject suitable for particular gender., structural features : 3, no, yes, no, no [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [P

In [17]:
example['labels']

[0.0, 1.0, 0.0, 1.0]

In [18]:
[id2label[idx] for idx, label in enumerate(example['labels']) if label == 1.0]

['cl', 'link_present']

In [19]:
encoded_dataset.set_format("torch")

In [20]:
from transformers import AutoModelForSequenceClassification

model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased", 
                                                           problem_type="multi_label_classification", 
                                                           num_labels=len(labels),
                                                           id2label=id2label,
                                                           label2id=label2id)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight', 'cls.predictions.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at

In [21]:
model

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, element

In [22]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [23]:
model.to(device)

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, element

In [24]:
device

device(type='cuda')

In [25]:
batch_size = 16
metric_name = "f1"
nr_epochs = 6
results_folder = "/notebooks/MTL_tasks/results"

In [26]:
from transformers import TrainingArguments, Trainer

args = TrainingArguments(
    
    output_dir=results_folder,          
    
    # params
    num_train_epochs=nr_epochs,               # nb of epochs
    per_device_train_batch_size=batch_size,   # batch size per device during training
    per_device_eval_batch_size=batch_size,    # cf. paper Sun et al.
    learning_rate=1e-5,#2e-5,                 # cf. paper Sun et al.
#     warmup_steps=500,                         # number of warmup steps for learning rate scheduler
    warmup_ratio=0.1,                         # cf. paper Sun et al.
    weight_decay=0.01,                        # strength of weight decay
    
    # eval
    evaluation_strategy="steps",              # cf. paper Sun et al.
    eval_steps=20,                            # cf. paper Sun et al.
    
    # log
#     logging_dir="/notebooks/Results/bert_sequence_classification/tb_logs",  
#     logging_strategy='steps',
#     logging_steps=20,
    
    # save
    save_strategy='steps',
    save_total_limit=2,
    # save_steps=20, # default 500
    load_best_model_at_end=True,              # cf. paper Sun et al.
    # metric_for_best_model='eval_loss' 
    metric_for_best_model=metric_name
    
    
    
#     f"bert-finetuned-pe-mlc",
#     evaluation_strategy = "epoch",
#     save_strategy = "epoch",
#     learning_rate=2e-5,
#     per_device_train_batch_size=batch_size,
#     per_device_eval_batch_size=batch_size,
#     num_train_epochs=5,
#     weight_decay=0.01,
#     load_best_model_at_end=True,
#     metric_for_best_model=metric_name,
#     #push_to_hub=True,
)

In [27]:
from sklearn.metrics import f1_score, roc_auc_score, accuracy_score
from transformers import EvalPrediction
import torch
    
# source: https://jesusleal.io/2021/04/21/Longformer-multilabel-classification/
def multi_label_metrics(predictions, labels, threshold=0.5):
    # first, apply sigmoid on predictions which are of shape (batch_size, num_labels)
    sigmoid = torch.nn.Sigmoid()
    probs = sigmoid(torch.Tensor(predictions))
    # next, use threshold to turn them into integer predictions
    y_pred = np.zeros(probs.shape)
    y_pred[np.where(probs >= threshold)] = 1
    # finally, compute metrics
    y_true = labels
    f1_macro_average = f1_score(y_true=y_true, y_pred=y_pred, average='macro')
    roc_auc = roc_auc_score(y_true, y_pred, average = 'micro')
    accuracy = accuracy_score(y_true, y_pred)
    # return as dictionary
    metrics = {'f1': f1_macro_average,
               'roc_auc': roc_auc,
               'accuracy': accuracy}
    return metrics

def compute_metrics(p: EvalPrediction):
    preds = p.predictions[0] if isinstance(p.predictions, 
            tuple) else p.predictions
    result = multi_label_metrics(
        predictions=preds, 
        labels=p.label_ids)
    return result

In [28]:
encoded_dataset['train'][0]['labels'].type()

'torch.FloatTensor'

In [29]:
encoded_dataset['train']['input_ids'][0]

tensor([  101,  8476,  1024,  5907,  9945,  2012,  2118,  9634,  1010,  6251,
         1024,  3568,  1010,  5534,  3582,  1996,  9095,  1997,  3105, 11670,
         1998,  5630,  3395,  7218,  2005,  3327,  5907,  1012,  1010,  8332,
         2838,  1024,  1017,  1010,  2053,  1010,  2748,  1010,  2053,  1010,
         2053,   102,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0, 

In [30]:
#forward pass
# outputs = model(input_ids=encoded_dataset['train']['input_ids'][0].unsqueeze(0), labels=encoded_dataset['train'][0]['labels'].unsqueeze(0))
# outputs

In [31]:
trainer = Trainer(
    model,
    args,
    train_dataset=encoded_dataset["train"],
    eval_dataset=encoded_dataset["validation"],
    tokenizer=tokenizer,
    compute_metrics=compute_metrics
)

In [32]:
trainer.train()

***** Running training *****
  Num examples = 3770
  Num Epochs = 6
  Instantaneous batch size per device = 16
  Total train batch size (w. parallel, distributed & accumulation) = 16
  Gradient Accumulation steps = 1
  Total optimization steps = 1416


Step,Training Loss,Validation Loss,F1,Roc Auc,Accuracy
20,No log,0.616356,0.188984,0.64043,0.5228
40,No log,0.581645,0.189987,0.642065,0.5228
60,No log,0.520545,0.403593,0.700057,0.45281
80,No log,0.497388,0.421229,0.712257,0.520679
100,No log,0.476716,0.495992,0.736699,0.593849
120,No log,0.46497,0.530684,0.749638,0.632025
140,No log,0.4468,0.557834,0.763622,0.601273
160,No log,0.450329,0.556003,0.759919,0.559915
180,No log,0.425393,0.553727,0.762536,0.617179
200,No log,0.424346,0.53311,0.752891,0.605514


***** Running Evaluation *****
  Num examples = 943
  Batch size = 16
***** Running Evaluation *****
  Num examples = 943
  Batch size = 16
***** Running Evaluation *****
  Num examples = 943
  Batch size = 16
***** Running Evaluation *****
  Num examples = 943
  Batch size = 16
***** Running Evaluation *****
  Num examples = 943
  Batch size = 16
***** Running Evaluation *****
  Num examples = 943
  Batch size = 16
***** Running Evaluation *****
  Num examples = 943
  Batch size = 16
***** Running Evaluation *****
  Num examples = 943
  Batch size = 16
***** Running Evaluation *****
  Num examples = 943
  Batch size = 16
***** Running Evaluation *****
  Num examples = 943
  Batch size = 16
***** Running Evaluation *****
  Num examples = 943
  Batch size = 16
***** Running Evaluation *****
  Num examples = 943
  Batch size = 16
***** Running Evaluation *****
  Num examples = 943
  Batch size = 16
***** Running Evaluation *****
  Num examples = 943
  Batch size = 16
***** Running Evalua

TrainOutput(global_step=1416, training_loss=0.3630037738778497, metrics={'train_runtime': 1365.1621, 'train_samples_per_second': 16.569, 'train_steps_per_second': 1.037, 'total_flos': 2975839472885760.0, 'train_loss': 0.3630037738778497, 'epoch': 6.0})

In [33]:
trainer.evaluate()

***** Running Evaluation *****
  Num examples = 943
  Batch size = 16


{'eval_loss': 0.3537537753582001,
 'eval_f1': 0.7316491358231823,
 'eval_roc_auc': 0.817733943064615,
 'eval_accuracy': 0.6956521739130435,
 'eval_runtime': 9.8135,
 'eval_samples_per_second': 96.092,
 'eval_steps_per_second': 6.012,
 'epoch': 6.0}

### inference on the test set

In [34]:
dataset_mlc['test'][500]

{'text': 'the individual should finance his or her education',
 'structural_fts_as_text': 'Topic: Should the Government Provide Free College?, Sentence: In short, the individual should finance his or her education because it is a personal choice, Para Number: 5, First in Para: Yes, Last in Para: No, Is in Introduction: No, Is in Conclusion: Yes',
 'structural_fts_as_text_combined': 'Topic: Should the Government Provide Free College?, Sentence: In short, the individual should finance his or her education because it is a personal choice, Structural Features: 5, Yes, No, No, Yes',
 'mc': 1,
 'cl': 0,
 'prem': 0,
 'link_present': 1}

In [35]:
text = "Using public transportation has a lot of advantages for the modern society facing a lot of problems: the environmental population, the isolation in life, the depletion of natural resources"

encoding = tokenizer(text, return_tensors="pt")
encoding = {k: v.to(trainer.model.device) for k,v in encoding.items()}

outputs = trainer.model(**encoding)

In [36]:
logits = outputs.logits
logits.shape

torch.Size([1, 4])

In [37]:
# apply sigmoid + threshold
sigmoid = torch.nn.Sigmoid()
probs = sigmoid(logits.squeeze().cpu())
predictions = np.zeros(probs.shape)
predictions[np.where(probs >= 0.5)] = 1
# turn predicted id's into actual label names
predicted_labels = [id2label[idx] for idx, label in enumerate(predictions) if label == 1.0]
print(predicted_labels)

['prem']


In [38]:
from transformers.data.data_collator import DataCollatorWithPadding

In [39]:
test_trainer = Trainer(model, data_collator=DataCollatorWithPadding(tokenizer))
test_raw_preds, test_labels, _ = test_trainer.predict(encoded_dataset["test"])
# test_preds = np.argmax(test_raw_preds, axis=0)

No `TrainingArguments` passed, using `output_dir=tmp_trainer`.
PyTorch: setting up devices
The default value for the training argument `--report_to` will change in v5 (from all installed integrations to none). In v5, you will need to use `--report_to all` to get the same behavior as now. You should start updating your code and make this info disappear :-).
***** Running Prediction *****
  Num examples = 1260
  Batch size = 8


In [40]:
test_raw_preds.shape

(1260, 4)

In [41]:
test_labels

array([[1., 0., 0., 1.],
       [1., 0., 0., 1.],
       [0., 1., 0., 0.],
       ...,
       [0., 0., 1., 0.],
       [0., 0., 1., 0.],
       [0., 0., 1., 0.]], dtype=float32)

In [42]:
sigmoid = torch.nn.Sigmoid()

In [43]:
test_raw_preds[100]

array([-4.845425 , -2.4463947,  2.5750885, -1.5848606], dtype=float32)

In [44]:
try_pred = sigmoid(torch.tensor(test_raw_preds[0]))

In [45]:
try_pred

tensor([0.6311, 0.3490, 0.0473, 0.6861])

In [46]:
test_preds = sigmoid(torch.tensor(test_raw_preds))

In [47]:
test_preds

tensor([[0.6311, 0.3490, 0.0473, 0.6861],
        [0.8102, 0.2365, 0.0459, 0.7859],
        [0.0689, 0.5003, 0.1705, 0.6263],
        ...,
        [0.0083, 0.0794, 0.9330, 0.1557],
        [0.0142, 0.0548, 0.9538, 0.1742],
        [0.0134, 0.0602, 0.9541, 0.1601]])

In [48]:
# nice! now do the thresholding

In [49]:
predictions = np.zeros(test_preds.shape)
predictions[np.where(test_preds >= 0.5)] = 1
# turn predicted id's into actual label names
#predicted_labels = [id2label[idx] for idx, label in enumerate(predictions) if label == 1.0]
#print(predicted_labels)

In [50]:
predictions

array([[1., 0., 0., 1.],
       [1., 0., 0., 1.],
       [0., 1., 0., 1.],
       ...,
       [0., 0., 1., 0.],
       [0., 0., 1., 0.],
       [0., 0., 1., 0.]])

In [51]:
from sklearn.metrics import classification_report

In [52]:
print(classification_report(test_labels, predictions, digits=3))

              precision    recall  f1-score   support

           0      0.770     0.941     0.847       153
           1      0.730     0.546     0.625       302
           2      0.898     0.918     0.908       805
           3      0.750     0.657     0.700       516

   micro avg      0.822     0.781     0.801      1776
   macro avg      0.787     0.766     0.770      1776
weighted avg      0.815     0.781     0.794      1776
 samples avg      0.834     0.801     0.808      1776



  _warn_prf(average, modifier, msg_start, len(result))


## task: do separate classification reports for separate classes:

In [53]:
test_labels

array([[1., 0., 0., 1.],
       [1., 0., 0., 1.],
       [0., 1., 0., 0.],
       ...,
       [0., 0., 1., 0.],
       [0., 0., 1., 0.],
       [0., 0., 1., 0.]], dtype=float32)

In [54]:
predictions

array([[1., 0., 0., 1.],
       [1., 0., 0., 1.],
       [0., 1., 0., 1.],
       ...,
       [0., 0., 1., 0.],
       [0., 0., 1., 0.],
       [0., 0., 1., 0.]])

In [55]:
test_labels.shape

(1260, 4)

In [56]:
predictions.shape

(1260, 4)

In [57]:
# separate test labels into separate task classes

In [58]:
# first three classes

test_labels[:,0:3]

array([[1., 0., 0.],
       [1., 0., 0.],
       [0., 1., 0.],
       ...,
       [0., 0., 1.],
       [0., 0., 1.],
       [0., 0., 1.]], dtype=float32)

In [59]:
# linked/not_linked class

In [60]:
test_labels[:,3]

array([1., 1., 0., ..., 0., 0., 0.], dtype=float32)

In [61]:
test_labels[:,3].shape

(1260,)

In [62]:
test_labels_comp_classes = test_labels[:,0:3]

In [63]:
test_labels_link_class = test_labels[:,3]

In [64]:
test_labels_comp_classes.shape

(1260, 3)

In [65]:
test_labels_link_class.shape

(1260,)

In [66]:
# separate test predictions into separate task classes

In [67]:
predictions[:,0:3]

array([[1., 0., 0.],
       [1., 0., 0.],
       [0., 1., 0.],
       ...,
       [0., 0., 1.],
       [0., 0., 1.],
       [0., 0., 1.]])

In [68]:
predictions[:,3]

array([1., 1., 1., ..., 0., 0., 0.])

In [69]:
test_predictions_comp_classes = predictions[:,0:3]

In [70]:
test_predictions_link_class = predictions[:,3]

In [71]:
test_predictions_comp_classes.shape

(1260, 3)

In [72]:
test_predictions_link_class.shape

(1260,)

In [73]:
# now do two separate classification reports

In [74]:
# for comp classes

print(classification_report(test_labels_comp_classes, test_predictions_comp_classes, digits=3))

              precision    recall  f1-score   support

           0      0.770     0.941     0.847       153
           1      0.730     0.546     0.625       302
           2      0.898     0.918     0.908       805

   micro avg      0.848     0.832     0.840      1260
   macro avg      0.799     0.802     0.793      1260
weighted avg      0.842     0.832     0.833      1260
 samples avg      0.829     0.832     0.830      1260



  _warn_prf(average, modifier, msg_start, len(result))


In [75]:
# for link class

print(classification_report(test_labels_link_class, test_predictions_link_class, digits=3))

              precision    recall  f1-score   support

         0.0      0.781     0.848     0.813       744
         1.0      0.750     0.657     0.700       516

    accuracy                          0.770      1260
   macro avg      0.765     0.753     0.757      1260
weighted avg      0.768     0.770     0.767      1260

