In [None]:
# !pip3 install vncorenlp
# !pip3 install fairseq
# !pip3 install fastBPE
!pip3 install transformers



In [None]:
from transformers import *
import torch
import pandas as pd
import numpy as np
from sklearn.metrics import accuracy_score, f1_score
from sklearn.model_selection import train_test_split
import os
from tqdm import tqdm, trange

GroupViT models are not usable since `tensorflow_probability` can't be loaded. It seems you have `tensorflow_probability` installed with the wrong tensorflow version.Please try to reinstall it following the instructions here: https://github.com/tensorflow/probability.
TAPAS models are not usable since `tensorflow_probability` can't be loaded. It seems you have `tensorflow_probability` installed with the wrong tensorflow version. Please try to reinstall it following the instructions here: https://github.com/tensorflow/probability.


R-BERT model architecture similar to the one proposed in [1], reimplementation inspired by [2]

[1] Enriching Pre-trained Language Model with Entity Information for
Relation Classification, Shanchan Wu, Yifan He, https://arxiv.org/pdf/1905.08284v1

[2] Reimplement R-BERT from "Enriching Pre-trained Language Model with Entity Information for Relation Classification" paper, https://github.com/heraclex12/R-BERT-Relation-Classification

In [None]:
class FCLayer(torch.nn.Module):
  """Small feed-forward head: dropout, an optional tanh, then a linear projection.

  Args:
    input_dim: size of the last dimension of the incoming tensor.
    output_dim: size of the last dimension produced by the linear layer.
    dropout_rate: probability handed to ``torch.nn.Dropout``.
    use_activation: when True, tanh is applied after dropout and before the
      linear projection; when False the tanh step is skipped.
  """

  def __init__(self, input_dim, output_dim, dropout_rate=0., use_activation=True):
    super().__init__()
    self.use_activation = use_activation
    self.dropout = torch.nn.Dropout(p=dropout_rate)
    self.linear = torch.nn.Linear(in_features=input_dim, out_features=output_dim)
    self.tanh = torch.nn.Tanh()

  def forward(self, x):
    """Return ``linear(tanh(dropout(x)))``, leaving out tanh when it is disabled."""
    hidden = self.dropout(x)
    activated = self.tanh(hidden) if self.use_activation else hidden
    return self.linear(activated)

In [None]:
class RBERT(BertPreTrainedModel):
  """RoBERTa encoder with an R-BERT style relation-classification head.

  The pooled sentence vector and the averaged hidden states of the two entity
  spans each pass through their own FCLayer; the three results are concatenated
  and projected to ``config.num_labels`` logits.

  Args:
    config: model configuration; ``hidden_size`` and ``num_labels`` are read from it.
    args: dict-like settings; only ``args['DROPOUT_RATE']`` is read here.
  """
  base_model_prefix = "roberta"
  config_class = RobertaConfig

  def __init__(self, config, args):
    super(RBERT, self).__init__(config)
    self.roberta = RobertaModel(config=config)
    self.num_labels = config.num_labels
    self.cls_fc_layer = FCLayer(config.hidden_size, config.hidden_size, args['DROPOUT_RATE'])
    self.e1_fc_layer = FCLayer(config.hidden_size, config.hidden_size, args['DROPOUT_RATE'])
    self.e2_fc_layer = FCLayer(config.hidden_size, config.hidden_size, args['DROPOUT_RATE'])
    self.label_classifier = FCLayer(config.hidden_size * 3, self.num_labels, args['DROPOUT_RATE'], use_activation=False)

  @staticmethod
  def entity_average(hidden_output, e_mask):
    """Average the hidden states at the positions selected by ``e_mask``.

    Args:
      hidden_output: tensor of shape (batch, seq_len, hidden).
      e_mask: tensor of shape (batch, seq_len); non-zero entries mark the
        positions that belong to the entity.

    Returns:
      Tensor of shape (batch, hidden). A row whose mask is entirely zero yields
      a zero vector.
    """
    e_mask_unsqueeze = e_mask.unsqueeze(1)  # (batch, 1, seq_len)
    # Clamp to 1 so an all-zero mask gives 0 / 1 = 0 instead of 0 / 0 = NaN,
    # which would otherwise propagate into the logits and the loss.
    length_tensor = (e_mask != 0).sum(dim=1).unsqueeze(1).clamp(min=1)

    # (batch, 1, seq_len) @ (batch, seq_len, hidden) -> (batch, 1, hidden)
    sum_vector = torch.bmm(e_mask_unsqueeze.float(), hidden_output).squeeze(1)
    avg_vector = sum_vector.float() / length_tensor.float()
    return avg_vector

  def forward(self, input_ids, attention_mask, labels=None, e1_mask=None, e2_mask=None):
    """Encode the batch and classify the relation between the two entities.

    Args:
      input_ids: token ids passed to the RoBERTa encoder.
      attention_mask: attention mask passed to the RoBERTa encoder.
      labels: optional targets. When given, the loss is prepended to the
        returned tuple (MSE if ``num_labels == 1``, cross-entropy otherwise).
      e1_mask: (batch, seq_len) mask of the first entity span. Required.
      e2_mask: (batch, seq_len) mask of the second entity span. Required.

    Returns:
      ``(loss, logits, *extra)`` when ``labels`` is given, else
      ``(logits, *extra)``, where ``extra`` is whatever the encoder returned
      after its first two outputs.

    Raises:
      ValueError: if either entity mask is missing.
    """
    # The masks only carry defaults because they follow the now-optional
    # `labels` in the signature; they are still mandatory.
    if e1_mask is None or e2_mask is None:
      raise ValueError("e1_mask and e2_mask are required")

    outputs = self.roberta(input_ids, attention_mask=attention_mask)
    sequence_output = outputs[0]
    pooled_output = outputs[1]

    e1_h = self.entity_average(sequence_output, e1_mask)
    e2_h = self.entity_average(sequence_output, e2_mask)

    pooled_output = self.cls_fc_layer(pooled_output)
    e1_h = self.e1_fc_layer(e1_h)
    e2_h = self.e2_fc_layer(e2_h)

    concat_h = torch.cat([pooled_output, e1_h, e2_h], dim=-1)
    logits = self.label_classifier(concat_h)

    outputs = (logits, ) + outputs[2:]

    if labels is not None:
      if self.num_labels == 1:
        loss_fct = torch.nn.MSELoss()
        # Regression targets must share the logits' floating dtype.
        loss = loss_fct(logits.view(-1), labels.view(-1).to(logits.dtype))
      else:
        loss_fct = torch.nn.CrossEntropyLoss()
        loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

      outputs = (loss,) + outputs
    return outputs

Loading our distantly supervised dataset and the official test dataset

In [None]:
# Both splits are read without a header row and given the column names (label, text).
# NOTE(review): the recorded output of this cell shows an extra "Unnamed: 0" column,
# so the files may carry their own header/index column — confirm the layout.
train_df = pd.read_csv(
    "dataset.csv",
    header=None,
    names=["label", "text"],
)
test_df = pd.read_csv(
    "test.tsv",
    sep="\t",
    header=None,
    names=["label", "text"],
)
# label.txt is parsed as a fixed-width file into a single `label` column; the
# frame's index values are what convert_lines uses as class ids.
labels = pd.read_fwf("label.txt", header=None, names=["label"])
train_df.head()

Unnamed: 0,label,text
0,"Component-Whole(e2,e1)",the system described above has its greatest ap...
1,Other,the <e1> child </e1> was carefully wrapped and...
2,"Instrument-Agency(e2,e1)",the <e1> author </e1> keygen uses <e2> disasse...
3,Other,misty <e1> ridge </e1> uprises from the <e2> s...
4,"Member-Collection(e1,e2)",the <e1> student </e1> <e2> association </e2> ...


In [None]:
def seed_everything(SEED):
  """Seed every random number generator this notebook can draw from.

  Seeds Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices) and
  asks cuDNN for deterministic kernels with autotuning switched off.

  Args:
    SEED: integer seed applied to all generators.
  """
  # Local import: the notebook's import cell does not bring in `random`.
  import random

  random.seed(SEED)
  np.random.seed(SEED)
  torch.manual_seed(SEED)
  # manual_seed_all covers every visible GPU, not just the current one; it is
  # safe to call when CUDA is unavailable.
  torch.cuda.manual_seed_all(SEED)
  torch.backends.cudnn.deterministic = True
  # Autotuning can select different kernels from run to run.
  torch.backends.cudnn.benchmark = False

# Pick the device once: the GPU when PyTorch can see one, the CPU otherwise.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
if use_cuda:
  print("We will use the GPU:", torch.cuda.get_device_name())
else:
  print("We will use the CPU.")

# Hyper-parameters and run settings shared by the cells below.
args = {
    "NUM_LABELS": len(labels),          # one class per row of the label frame
    "DROPOUT_RATE": 0.1,                # dropout inside every FCLayer of RBERT
    "LEARNING_RATE": 2e-5,              # optimizer learning rate
    "EPOCHS": 5,                        # passes over the training loader
    "MAX_SEQUENCE_LENGTH": 384,         # padded/truncated length built by convert_lines
    "BATCH_SIZE": 16,                   # training batch size (the test loader uses twice this)
    "ADAM_EPSILON": 1e-8,               # optimizer epsilon
    "GRADIENT_ACCUMULATION_STEPS": 1,   # read by the training cell
    "MAX_GRAD_NORM": 1.0,               # gradient-norm clipping threshold
    "LOGGING_STEPS": 250,               # not referenced elsewhere in the visible notebook
    "SAVE_STEPS": 250,                  # not referenced elsewhere in the visible notebook
    "WEIGHT_DECAY": 0.0,                # applied to parameters outside the no-decay group
    "NUM_WARMUP_STEPS": 0,              # warm-up steps of the linear schedule
}

We will use the GPU: Tesla T4


Modified function for filtering usable sentences, inserting the entity markers, and building the entity masks for the R-BERT approach.

In [None]:
def convert_lines(df, label_indexes, max_seq_len, tokenizer, cls_token='[CLS]',
                  sep_token='[SEP]', pad_token=0, add_sep_token=False, mask_padding_with_zero=True):
  """Turn tagged sentences into fixed-length tensors for the R-BERT model.

  Each sentence is lower-cased and tokenized once. The ``<e1>``/``</e1>`` tokens
  are replaced by ``$`` and the ``<e2>``/``</e2>`` tokens by ``#``; the sequence
  is truncated, wrapped in the special tokens and padded to ``max_seq_len``.
  The entity masks are 1 from the opening marker to the closing marker
  (both inclusive) and 0 elsewhere.

  Args:
    df: DataFrame with ``text`` and ``label`` columns.
    label_indexes: DataFrame with a ``label`` column; the index value of the
      first row matching a sentence's label becomes its class id.
    max_seq_len: length of every produced sequence.
    tokenizer: object providing ``tokenize(str)`` and ``convert_tokens_to_ids(list)``.
    cls_token: token placed at the start of every sequence.
    sep_token: token appended when ``add_sep_token`` is True.
    pad_token: id used to pad ``input_ids``.
    add_sep_token: whether to append ``sep_token`` after the sentence tokens.
    mask_padding_with_zero: when True, real tokens get attention 1 and padding 0;
      when False the two values are swapped.

  Returns:
    ``torch.utils.data.TensorDataset`` of
    ``(input_ids, attention_mask, labels, e1_mask, e2_mask)``, all ``torch.long``.

  Raises:
    IndexError: if a kept sentence has a label absent from ``label_indexes``.

  Sentences are skipped, with a printed message, when an entity tag is missing
  or when a tag would fall outside the truncated sequence.
  """
  print("Converting sentence...")

  # Build the label -> id lookup once; setdefault keeps the first matching row,
  # mirroring `label_indexes.index[label_indexes.label == label][0]`.
  label_to_id = {}
  for label_id, label_name in zip(label_indexes.index, label_indexes.label):
    label_to_id.setdefault(label_name, label_id)

  # Room reserved for cls_token, plus sep_token when it is requested.
  special_tokens_count = 2 if add_sep_token else 1
  max_content_len = max_seq_len - special_tokens_count
  real_flag = 1 if mask_padding_with_zero else 0
  pad_flag = 0 if mask_padding_with_zero else 1

  input_ids = []
  attention_masks = []
  e1_masks = []
  e2_masks = []
  labels = []

  for row in df.itertuples():
    if (row.Index % 5000 == 0 and row.Index > 0) or row.Index == len(df) - 1:
      print('Parsing {} of {}'.format(row.Index + 1, len(df)))

    # Single tokenization pass, on the same text the features are built from.
    tokens = tokenizer.tokenize(row.text.lower())

    try:  # Handle cases where special tokens are missing
      e11_p = tokens.index("<e1>")
      e12_p = tokens.index("</e1>")
      e21_p = tokens.index("<e2>")
      e22_p = tokens.index("</e2>")
    except ValueError:
      print(f"Skipping sentence due to missing special tokens: {row.text}")
      continue  # Skip to the next sentence

    # Replace the entity tags with the marker characters.
    tokens[e11_p] = '$'
    tokens[e12_p] = '$'
    tokens[e21_p] = '#'
    tokens[e22_p] = '#'

    # Shift by one for the cls_token that is prepended below.
    e11_p += 1
    e12_p += 1
    e21_p += 1
    e22_p += 1

    # After the shift, position `kept_len` is the last sentence token that
    # survives truncation; a marker beyond it would index past the sequence.
    kept_len = min(len(tokens), max_content_len)
    if max(e11_p, e12_p, e21_p, e22_p) > kept_len:
      print(f"Skipping sentence because an entity marker falls beyond max_seq_len={max_seq_len}: {row.text}")
      continue

    if row.label not in label_to_id:
      raise IndexError("Label {!r} not found in label_indexes".format(row.label))

    tokens = tokens[:kept_len]
    if add_sep_token:
      tokens.append(sep_token)
    tokens.insert(0, cls_token)

    input_id = tokenizer.convert_tokens_to_ids(tokens)
    attention_mask = [real_flag] * len(input_id)

    padding_length = max_seq_len - len(input_id)
    input_id = input_id + ([pad_token] * padding_length)
    attention_mask = attention_mask + ([pad_flag] * padding_length)

    e1_mask = [0] * len(attention_mask)
    e2_mask = [0] * len(attention_mask)
    for i in range(e11_p, e12_p + 1):
      e1_mask[i] = 1
    for i in range(e21_p, e22_p + 1):
      e2_mask[i] = 1

    assert len(input_id) == max_seq_len, "Error with input length {} vs {}".format(len(input_id), max_seq_len)
    assert len(attention_mask) == max_seq_len, "Error with attention mask length {} vs {}".format(len(attention_mask), max_seq_len)

    input_ids.append(input_id)
    attention_masks.append(attention_mask)
    labels.append(label_to_id[row.label])
    e1_masks.append(e1_mask)
    e2_masks.append(e2_mask)

  print(f"\nNumber of sentences left after removing invalid ones: {len(input_ids)}")
  dataset = torch.utils.data.TensorDataset(torch.tensor(input_ids, dtype=torch.long),
                          torch.tensor(attention_masks, dtype=torch.long),
                          torch.tensor(labels, dtype=torch.long),
                          torch.tensor(e1_masks, dtype=torch.long),
                          torch.tensor(e2_masks, dtype=torch.long))
  return dataset

Tokenization and model loading

In [None]:
# Entity tags registered with the tokenizer so they survive tokenization as single tokens.
ADDITIONAL_SPECIAL_TOKENS = ["<e1>", "</e1>", "<e2>", "</e2>"]
def load_model(args, mode='en', model_name='roberta-base'):
  """Build the tokenizer, config and RBERT model from a pretrained checkpoint.

  Args:
    args: settings dict; ``args['NUM_LABELS']`` sets the number of classes and
      the whole dict is forwarded to ``RBERT``.
    mode: unused; kept so existing calls keep working.
    model_name: checkpoint name or path used for the tokenizer, the config and
      the encoder weights.

  Returns:
    ``(config, tokenizer, model)``; the model is moved to the notebook's
    global ``device``.
  """
  tokenizer = RobertaTokenizer.from_pretrained(model_name)
  tokenizer.add_special_tokens({"additional_special_tokens" : ADDITIONAL_SPECIAL_TOKENS})

  config = RobertaConfig.from_pretrained(model_name, num_labels=args['NUM_LABELS'])
  model = RBERT.from_pretrained(model_name, config=config, args=args)
  # The tokenizer gained tokens above, so the embedding matrix has to grow with it.
  model.resize_token_embeddings(len(tokenizer))
  model.to(device)
  return config, tokenizer, model

# Seed the RNGs before the model is built. RBERT's FCLayer heads are created
# in its constructor — presumably randomly initialised rather than restored
# from the roberta-base checkpoint; confirm against the loading log.
seed_everything(42)
config, tokenizer, model = load_model(args)

loading file vocab.json from cache at /root/.cache/huggingface/hub/models--roberta-base/snapshots/e2da8e2f811d1448a5b465c236feacd80ffbac7b/vocab.json
loading file merges.txt from cache at /root/.cache/huggingface/hub/models--roberta-base/snapshots/e2da8e2f811d1448a5b465c236feacd80ffbac7b/merges.txt
loading file added_tokens.json from cache at None
loading file special_tokens_map.json from cache at None
loading file tokenizer_config.json from cache at /root/.cache/huggingface/hub/models--roberta-base/snapshots/e2da8e2f811d1448a5b465c236feacd80ffbac7b/tokenizer_config.json
loading file tokenizer.json from cache at /root/.cache/huggingface/hub/models--roberta-base/snapshots/e2da8e2f811d1448a5b465c236feacd80ffbac7b/tokenizer.json
loading file chat_template.jinja from cache at None
loading configuration file config.json from cache at /root/.cache/huggingface/hub/models--roberta-base/snapshots/e2da8e2f811d1448a5b465c236feacd80ffbac7b/config.json
Model config RobertaConfig {
  "_name_or_path"

Training and Test data processing

In [None]:
# Special-token settings shared by both conversions.
# NOTE(review): pad_token=1 is presumably roberta-base's <pad> id — confirm with tokenizer.pad_token_id.
conversion_kwargs = dict(cls_token='<s>', sep_token='</s>', pad_token=1)
max_len = args['MAX_SEQUENCE_LENGTH']

train_dataset = convert_lines(train_df, labels, max_len, tokenizer, **conversion_kwargs)
test_dataset = convert_lines(test_df, labels, max_len, tokenizer, **conversion_kwargs)

# Training batches are drawn in random order.
train_sampler = torch.utils.data.RandomSampler(train_dataset)
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    sampler=train_sampler,
    batch_size=args['BATCH_SIZE'],
)

# The test set is read in order, with twice the training batch size.
test_sampler = torch.utils.data.SequentialSampler(test_dataset)
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    sampler=test_sampler,
    batch_size=args['BATCH_SIZE'] * 2,
)

Converting sentence...
Skipping sentence due to missing special tokens: people have been moving back into <e2> downtown </e2>
Skipping sentence due to missing special tokens: even commercial <e1> networks </e1> have moved into highdefinition broadcast
Skipping sentence due to missing special tokens: suicide one the leading causes <e2> death </e2> among preadolescents and teens and victims bullying are increased risk for committing suicide
Skipping sentence due to missing special tokens: newspapers swap content via widgets with the help the newsgator <e2> service </e2>
Skipping sentence due to missing special tokens: calluses are caused improperly fitting shoes <e2> skin </e2> abnormality
Skipping sentence due to missing special tokens: adults use <e2> drugs </e2> for this purpose
Skipping sentence due to missing special tokens: leakage and fire caused due <e2> corrosion </e2> bypass piping for recirculation gas fuel oil desulphurization unit
Skipping sentence due to missing special tok

Evaluation Metrics

In [None]:
def evaluate(model, device, test_loader):
    """Run the model over ``test_loader`` and compute classification metrics.

    Args:
        model: callable module returning ``(loss, logits, ...)`` when given
            ``input_ids``, ``attention_mask``, ``labels``, ``e1_mask`` and ``e2_mask``.
        device: device every batch tensor is moved to.
        test_loader: iterable of 5-tuples
            ``(input_ids, attention_mask, labels, e1_mask, e2_mask)``.

    Returns:
        dict with ``accuracy``, macro ``f1_score``, ``pred`` (argmax class ids
        as a NumPy array) and ``loss`` (mean of the per-batch losses).

    Raises:
        ValueError: if ``test_loader`` yields no batches.
    """
    model.eval()

    total_loss = 0.0
    # Collect per-batch arrays and concatenate once, instead of re-copying the
    # growing result with np.append on every batch.
    logit_chunks = []
    label_chunks = []

    for batch in tqdm(test_loader, desc="Evaluating"):
        batch = tuple(t.to(device) for t in batch)
        inputs = {'input_ids': batch[0],
                  'attention_mask': batch[1],
                  'labels': batch[2],
                  'e1_mask': batch[3],
                  'e2_mask': batch[4]}
        with torch.no_grad():
            tmp_eval_loss, logits = model(**inputs)[:2]

        total_loss += tmp_eval_loss.mean().item()
        logit_chunks.append(logits.detach().cpu().numpy())
        label_chunks.append(inputs['labels'].detach().cpu().numpy())

    if not logit_chunks:
        raise ValueError("test_loader produced no batches; nothing to evaluate")

    preds = np.argmax(np.concatenate(logit_chunks, axis=0), axis=1)
    out_label_ids = np.concatenate(label_chunks, axis=0)

    result = {'accuracy': accuracy_score(out_label_ids, preds),
              'f1_score': f1_score(out_label_ids, preds, average='macro'),
              'pred': preds,
              'loss': total_loss / len(logit_chunks)}
    return result

def save_model():
    """Write the global ``model``'s state_dict to relation_data/trained_models/model.bin."""
    target_path = 'relation_data/trained_models/model.bin'
    # torch.save does not create missing parent directories.
    os.makedirs(os.path.dirname(target_path), exist_ok=True)
    torch.save(model.state_dict(), target_path)

def load_saved_model(args):
  """Rebuild RBERT and restore the weights written by ``save_model``.

  Args:
    args: settings dict; ``args['NUM_LABELS']`` sets the number of classes and
      the whole dict is forwarded to ``RBERT``.

  Returns:
    ``(config, tokenizer, model)``; the model is on the notebook's global ``device``.
  """
  tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
  tokenizer.add_special_tokens({"additional_special_tokens" : ADDITIONAL_SPECIAL_TOKENS})

  config = RobertaConfig.from_pretrained('roberta-base', num_labels = args['NUM_LABELS'])
  # model.bin is a bare state_dict, so build the architecture first and load
  # the weights into it instead of passing the file to from_pretrained.
  model = RBERT(config, args)
  # The training model had its embeddings resized for the added entity tokens;
  # match that shape before loading or the embedding weights will not fit.
  model.resize_token_embeddings(len(tokenizer))
  state_dict = torch.load('relation_data/trained_models/model.bin', map_location=device)
  model.load_state_dict(state_dict)

  # Use the notebook's device rather than assuming CUDA is available.
  model.to(device)
  return config, tokenizer, model

Model Training

In [None]:


# Number of batches whose gradients are summed before one optimizer update.
accumulation_steps = max(1, int(args['GRADIENT_ACCUMULATION_STEPS']))
t_total = len(train_loader) // accumulation_steps * args['EPOCHS']

# Prepare optimizer and schedule (linear warmup and decay).
# Biases and LayerNorm weights are kept out of weight decay.
no_decay = ['bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [
    {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
      'weight_decay': args['WEIGHT_DECAY']},
    {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
]
# Use PyTorch's AdamW: the copy exported by transformers is deprecated and is
# no longer shipped in recent releases.
optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args['LEARNING_RATE'], eps=args['ADAM_EPSILON'])
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args['NUM_WARMUP_STEPS'], num_training_steps=t_total)


global_step = 0  # number of optimizer updates performed
tr_loss = 0.0    # running sum of the (scaled) batch losses
model.zero_grad()
train_iterator = trange(int(args['EPOCHS']), desc="Epoch")

for _ in train_iterator:
    epoch_iterator = tqdm(train_loader, desc="Iteration")
    for step, batch in enumerate(epoch_iterator):
        # evaluate() switches the model to eval mode, so re-enable training mode.
        model.train()
        batch = tuple(t.to(device) for t in batch)  # GPU or CPU
        inputs = {'input_ids': batch[0],
                  'attention_mask': batch[1],
                  'labels': batch[2],
                  'e1_mask': batch[3],
                  'e2_mask': batch[4]}
        outputs = model(**inputs)
        loss = outputs[0]

        # Scale so the accumulated gradient is the mean over the group of batches.
        if accumulation_steps > 1:
            loss = loss / accumulation_steps

        loss.backward()
        tr_loss += loss.item()

        # Update only once per group so the number of scheduler steps matches t_total.
        if (step + 1) % accumulation_steps == 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), args['MAX_GRAD_NORM'])
            optimizer.step()
            scheduler.step()  # Update learning rate schedule
            model.zero_grad()
            global_step += 1

    # Discard gradients of an incomplete trailing group (no-op when accumulation_steps == 1).
    model.zero_grad()

    # NOTE(review): accuracy is reported on test_loader after every epoch; no
    # separate validation split is visible in this notebook.
    print("\n====Evaluation====")
    print("\nACCURACY: ", evaluate(model, device, test_loader)['accuracy'])


Epoch:   0%|          | 0/5 [00:00<?, ?it/s]
Iteration:   0%|          | 0/1013 [00:00<?, ?it/s][A
Iteration:   0%|          | 1/1013 [00:02<41:08,  2.44s/it][A
Iteration:   0%|          | 2/1013 [00:03<27:15,  1.62s/it][A
Iteration:   0%|          | 3/1013 [00:04<22:49,  1.36s/it][A
Iteration:   0%|          | 4/1013 [00:05<20:48,  1.24s/it][A
Iteration:   0%|          | 5/1013 [00:06<19:39,  1.17s/it][A
Iteration:   1%|          | 6/1013 [00:07<18:58,  1.13s/it][A
Iteration:   1%|          | 7/1013 [00:08<18:34,  1.11s/it][A
Iteration:   1%|          | 8/1013 [00:09<18:19,  1.09s/it][A
Iteration:   1%|          | 9/1013 [00:10<18:08,  1.08s/it][A
Iteration:   1%|          | 10/1013 [00:11<18:01,  1.08s/it][A
Iteration:   1%|          | 11/1013 [00:13<18:00,  1.08s/it][A
Iteration:   1%|          | 12/1013 [00:14<17:59,  1.08s/it][A
Iteration:   1%|▏         | 13/1013 [00:15<17:57,  1.08s/it][A
Iteration:   1%|▏         | 14/1013 [00:16<17:57,  1.08s/it][A
Iteration:   


====Evaluation====



Evaluating:   0%|          | 0/85 [00:00<?, ?it/s][A
Evaluating:   1%|          | 1/85 [00:00<00:49,  1.70it/s][A
Evaluating:   2%|▏         | 2/85 [00:01<00:49,  1.68it/s][A
Evaluating:   4%|▎         | 3/85 [00:01<00:48,  1.69it/s][A
Evaluating:   5%|▍         | 4/85 [00:02<00:48,  1.68it/s][A
Evaluating:   6%|▌         | 5/85 [00:02<00:47,  1.68it/s][A
Evaluating:   7%|▋         | 6/85 [00:03<00:46,  1.68it/s][A
Evaluating:   8%|▊         | 7/85 [00:04<00:46,  1.68it/s][A
Evaluating:   9%|▉         | 8/85 [00:04<00:45,  1.69it/s][A
Evaluating:  11%|█         | 9/85 [00:05<00:45,  1.69it/s][A
Evaluating:  12%|█▏        | 10/85 [00:05<00:44,  1.69it/s][A
Evaluating:  13%|█▎        | 11/85 [00:06<00:43,  1.69it/s][A
Evaluating:  14%|█▍        | 12/85 [00:07<00:43,  1.69it/s][A
Evaluating:  15%|█▌        | 13/85 [00:07<00:42,  1.69it/s][A
Evaluating:  16%|█▋        | 14/85 [00:08<00:42,  1.69it/s][A
Evaluating:  18%|█▊        | 15/85 [00:08<00:41,  1.68it/s][A
Evaluatin


ACCURACY:  0.7861612072138388



Iteration:   0%|          | 0/1013 [00:00<?, ?it/s][A
Iteration:   0%|          | 1/1013 [00:01<17:15,  1.02s/it][A
Iteration:   0%|          | 2/1013 [00:02<17:20,  1.03s/it][A
Iteration:   0%|          | 3/1013 [00:03<17:21,  1.03s/it][A
Iteration:   0%|          | 4/1013 [00:04<17:20,  1.03s/it][A
Iteration:   0%|          | 5/1013 [00:05<17:20,  1.03s/it][A
Iteration:   1%|          | 6/1013 [00:06<17:17,  1.03s/it][A
Iteration:   1%|          | 7/1013 [00:07<17:17,  1.03s/it][A
Iteration:   1%|          | 8/1013 [00:08<17:17,  1.03s/it][A
Iteration:   1%|          | 9/1013 [00:09<17:16,  1.03s/it][A
Iteration:   1%|          | 10/1013 [00:10<17:15,  1.03s/it][A
Iteration:   1%|          | 11/1013 [00:11<17:14,  1.03s/it][A
Iteration:   1%|          | 12/1013 [00:12<17:14,  1.03s/it][A
Iteration:   1%|▏         | 13/1013 [00:13<17:13,  1.03s/it][A
Iteration:   1%|▏         | 14/1013 [00:14<17:12,  1.03s/it][A
Iteration:   1%|▏         | 15/1013 [00:15<17:11,  1.03s/


====Evaluation====



Evaluating:   0%|          | 0/85 [00:00<?, ?it/s][A
Evaluating:   1%|          | 1/85 [00:00<00:49,  1.68it/s][A
Evaluating:   2%|▏         | 2/85 [00:01<00:49,  1.69it/s][A
Evaluating:   4%|▎         | 3/85 [00:01<00:48,  1.69it/s][A
Evaluating:   5%|▍         | 4/85 [00:02<00:48,  1.68it/s][A
Evaluating:   6%|▌         | 5/85 [00:02<00:47,  1.68it/s][A
Evaluating:   7%|▋         | 6/85 [00:03<00:46,  1.69it/s][A
Evaluating:   8%|▊         | 7/85 [00:04<00:46,  1.68it/s][A
Evaluating:   9%|▉         | 8/85 [00:04<00:45,  1.68it/s][A
Evaluating:  11%|█         | 9/85 [00:05<00:45,  1.68it/s][A
Evaluating:  12%|█▏        | 10/85 [00:05<00:44,  1.68it/s][A
Evaluating:  13%|█▎        | 11/85 [00:06<00:44,  1.68it/s][A
Evaluating:  14%|█▍        | 12/85 [00:07<00:43,  1.68it/s][A
Evaluating:  15%|█▌        | 13/85 [00:07<00:42,  1.68it/s][A
Evaluating:  16%|█▋        | 14/85 [00:08<00:42,  1.68it/s][A
Evaluating:  18%|█▊        | 15/85 [00:08<00:41,  1.68it/s][A
Evaluatin


ACCURACY:  0.813029076186971



Iteration:   0%|          | 0/1013 [00:00<?, ?it/s][A
Iteration:   0%|          | 1/1013 [00:01<17:31,  1.04s/it][A
Iteration:   0%|          | 2/1013 [00:02<17:30,  1.04s/it][A
Iteration:   0%|          | 3/1013 [00:03<17:34,  1.04s/it][A
Iteration:   0%|          | 4/1013 [00:04<17:29,  1.04s/it][A
Iteration:   0%|          | 5/1013 [00:05<17:27,  1.04s/it][A
Iteration:   1%|          | 6/1013 [00:06<17:25,  1.04s/it][A
Iteration:   1%|          | 7/1013 [00:07<17:22,  1.04s/it][A
Iteration:   1%|          | 8/1013 [00:08<17:22,  1.04s/it][A
Iteration:   1%|          | 9/1013 [00:09<17:21,  1.04s/it][A
Iteration:   1%|          | 10/1013 [00:10<17:20,  1.04s/it][A
Iteration:   1%|          | 11/1013 [00:11<17:19,  1.04s/it][A
Iteration:   1%|          | 12/1013 [00:12<17:16,  1.04s/it][A
Iteration:   1%|▏         | 13/1013 [00:13<17:17,  1.04s/it][A
Iteration:   1%|▏         | 14/1013 [00:14<17:18,  1.04s/it][A
Iteration:   1%|▏         | 15/1013 [00:15<17:16,  1.04s/


====Evaluation====



Evaluating:   0%|          | 0/85 [00:00<?, ?it/s][A
Evaluating:   1%|          | 1/85 [00:00<00:49,  1.71it/s][A
Evaluating:   2%|▏         | 2/85 [00:01<00:49,  1.69it/s][A
Evaluating:   4%|▎         | 3/85 [00:01<00:48,  1.69it/s][A
Evaluating:   5%|▍         | 4/85 [00:02<00:47,  1.69it/s][A
Evaluating:   6%|▌         | 5/85 [00:02<00:47,  1.69it/s][A
Evaluating:   7%|▋         | 6/85 [00:03<00:46,  1.69it/s][A
Evaluating:   8%|▊         | 7/85 [00:04<00:46,  1.68it/s][A
Evaluating:   9%|▉         | 8/85 [00:04<00:45,  1.68it/s][A
Evaluating:  11%|█         | 9/85 [00:05<00:45,  1.68it/s][A
Evaluating:  12%|█▏        | 10/85 [00:05<00:44,  1.68it/s][A
Evaluating:  13%|█▎        | 11/85 [00:06<00:43,  1.68it/s][A
Evaluating:  14%|█▍        | 12/85 [00:07<00:43,  1.68it/s][A
Evaluating:  15%|█▌        | 13/85 [00:07<00:42,  1.68it/s][A
Evaluating:  16%|█▋        | 14/85 [00:08<00:42,  1.68it/s][A
Evaluating:  18%|█▊        | 15/85 [00:08<00:41,  1.68it/s][A
Evaluatin


ACCURACY:  0.8225984541774015



Iteration:   0%|          | 0/1013 [00:00<?, ?it/s][A
Iteration:   0%|          | 1/1013 [00:01<17:31,  1.04s/it][A
Iteration:   0%|          | 2/1013 [00:02<17:24,  1.03s/it][A
Iteration:   0%|          | 3/1013 [00:03<17:24,  1.03s/it][A
Iteration:   0%|          | 4/1013 [00:04<17:23,  1.03s/it][A
Iteration:   0%|          | 5/1013 [00:05<17:20,  1.03s/it][A
Iteration:   1%|          | 6/1013 [00:06<17:19,  1.03s/it][A
Iteration:   1%|          | 7/1013 [00:07<17:17,  1.03s/it][A
Iteration:   1%|          | 8/1013 [00:08<17:16,  1.03s/it][A
Iteration:   1%|          | 9/1013 [00:09<17:14,  1.03s/it][A
Iteration:   1%|          | 10/1013 [00:10<17:13,  1.03s/it][A
Iteration:   1%|          | 11/1013 [00:11<17:13,  1.03s/it][A
Iteration:   1%|          | 12/1013 [00:12<17:12,  1.03s/it][A
Iteration:   1%|▏         | 13/1013 [00:13<17:12,  1.03s/it][A
Iteration:   1%|▏         | 14/1013 [00:14<17:11,  1.03s/it][A
Iteration:   1%|▏         | 15/1013 [00:15<17:08,  1.03s/


====Evaluation====



Evaluating:   0%|          | 0/85 [00:00<?, ?it/s][A
Evaluating:   1%|          | 1/85 [00:00<00:49,  1.71it/s][A
Evaluating:   2%|▏         | 2/85 [00:01<00:48,  1.70it/s][A
Evaluating:   4%|▎         | 3/85 [00:01<00:48,  1.70it/s][A
Evaluating:   5%|▍         | 4/85 [00:02<00:47,  1.69it/s][A
Evaluating:   6%|▌         | 5/85 [00:02<00:47,  1.69it/s][A
Evaluating:   7%|▋         | 6/85 [00:03<00:46,  1.69it/s][A
Evaluating:   8%|▊         | 7/85 [00:04<00:46,  1.69it/s][A
Evaluating:   9%|▉         | 8/85 [00:04<00:45,  1.69it/s][A
Evaluating:  11%|█         | 9/85 [00:05<00:44,  1.69it/s][A
Evaluating:  12%|█▏        | 10/85 [00:05<00:44,  1.69it/s][A
Evaluating:  13%|█▎        | 11/85 [00:06<00:43,  1.69it/s][A
Evaluating:  14%|█▍        | 12/85 [00:07<00:43,  1.69it/s][A
Evaluating:  15%|█▌        | 13/85 [00:07<00:42,  1.69it/s][A
Evaluating:  16%|█▋        | 14/85 [00:08<00:41,  1.69it/s][A
Evaluating:  18%|█▊        | 15/85 [00:08<00:41,  1.69it/s][A
Evaluatin


ACCURACY:  0.8255428781744572



Iteration:   0%|          | 0/1013 [00:00<?, ?it/s][A
Iteration:   0%|          | 1/1013 [00:01<17:31,  1.04s/it][A
Iteration:   0%|          | 2/1013 [00:02<17:28,  1.04s/it][A
Iteration:   0%|          | 3/1013 [00:03<17:25,  1.04s/it][A
Iteration:   0%|          | 4/1013 [00:04<17:25,  1.04s/it][A
Iteration:   0%|          | 5/1013 [00:05<17:22,  1.03s/it][A
Iteration:   1%|          | 6/1013 [00:06<17:21,  1.03s/it][A
Iteration:   1%|          | 7/1013 [00:07<17:20,  1.03s/it][A
Iteration:   1%|          | 8/1013 [00:08<17:19,  1.03s/it][A
Iteration:   1%|          | 9/1013 [00:09<17:17,  1.03s/it][A
Iteration:   1%|          | 10/1013 [00:10<17:17,  1.03s/it][A
Iteration:   1%|          | 11/1013 [00:11<17:15,  1.03s/it][A
Iteration:   1%|          | 12/1013 [00:12<17:14,  1.03s/it][A
Iteration:   1%|▏         | 13/1013 [00:13<17:14,  1.03s/it][A
Iteration:   1%|▏         | 14/1013 [00:14<17:12,  1.03s/it][A
Iteration:   1%|▏         | 15/1013 [00:15<17:11,  1.03s/


====Evaluation====



Evaluating:   0%|          | 0/85 [00:00<?, ?it/s][A
Evaluating:   1%|          | 1/85 [00:00<00:49,  1.71it/s][A
Evaluating:   2%|▏         | 2/85 [00:01<00:48,  1.69it/s][A
Evaluating:   4%|▎         | 3/85 [00:01<00:48,  1.69it/s][A
Evaluating:   5%|▍         | 4/85 [00:02<00:47,  1.69it/s][A
Evaluating:   6%|▌         | 5/85 [00:02<00:47,  1.69it/s][A
Evaluating:   7%|▋         | 6/85 [00:03<00:46,  1.69it/s][A
Evaluating:   8%|▊         | 7/85 [00:04<00:46,  1.69it/s][A
Evaluating:   9%|▉         | 8/85 [00:04<00:45,  1.69it/s][A
Evaluating:  11%|█         | 9/85 [00:05<00:44,  1.69it/s][A
Evaluating:  12%|█▏        | 10/85 [00:05<00:44,  1.69it/s][A
Evaluating:  13%|█▎        | 11/85 [00:06<00:43,  1.69it/s][A
Evaluating:  14%|█▍        | 12/85 [00:07<00:43,  1.69it/s][A
Evaluating:  15%|█▌        | 13/85 [00:07<00:42,  1.69it/s][A
Evaluating:  16%|█▋        | 14/85 [00:08<00:42,  1.69it/s][A
Evaluating:  18%|█▊        | 15/85 [00:08<00:41,  1.69it/s][A
Evaluatin


ACCURACY:  0.8295914611704085





Model result on test data

In [None]:
# Final metrics on the test loader after the last epoch: `evaluate` returns the
# accuracy, the macro-averaged F1 score and the predicted class ids.
result = evaluate(model, device, test_loader)
print("\nAccuracy: {}\nF1-score:{}".format(result['accuracy'], result['f1_score']))
# To map the predicted ids back to label names: labels.iloc[result['pred']].values

Evaluating: 100%|██████████| 85/85 [00:50<00:00,  1.68it/s]


Accuracy: 0.8295914611704085
F1-score:0.7977865887682739





Save model

In [None]:
# NOTE(review): this pickles the whole module object rather than its state_dict,
# so reloading presumably needs the RBERT and FCLayer class definitions to be
# importable — confirm that is intended. save_model() stores only the state_dict.
torch.save(model, 'model.pt')

Official F1 score metric without 'Other' class

In [None]:
def evaluate_f1_excluding_classes(model, device, test_loader, excluded_classes=()):
    """
    Computes the macro F1 score over every class except the excluded ones.

    Predictions are collected for the whole loader, and the per-class F1 scores
    are averaged over the classes seen in the labels or predictions that are
    not in ``excluded_classes``. All samples stay in the computation, so a
    sample of an excluded class that is predicted as a kept class still counts
    as a false positive for that kept class, and an excluded class that is
    predicted for a kept sample does not add a zero term to the average.

    Parameters:
    - model: The trained model; called with input_ids, attention_mask, labels,
      e1_mask and e2_mask, and expected to return (loss, logits, ...).
    - device: Device (cpu or cuda) the batches are moved to.
    - test_loader: DataLoader for the test set.
    - excluded_classes: Iterable of class ids to leave out of the macro average.

    Returns:
    - result: A dictionary with 'f1_score' (macro F1 over the kept classes) and
      'pred' (predicted class ids of the samples whose true label is not excluded).

    Raises:
    - ValueError: if test_loader yields no batches or no class remains after exclusion.
    """
    model.eval()

    # Collect per-batch arrays and concatenate once, instead of re-copying the
    # growing result with np.append on every batch.
    logit_chunks = []
    label_chunks = []

    for batch in tqdm(test_loader, desc="Evaluating"):
        batch = tuple(t.to(device) for t in batch)
        inputs = {'input_ids': batch[0],
                  'attention_mask': batch[1],
                  'labels': batch[2],
                  'e1_mask': batch[3],
                  'e2_mask': batch[4]}
        with torch.no_grad():
            logits = model(**inputs)[1]

        logit_chunks.append(logits.detach().cpu().numpy())
        label_chunks.append(inputs['labels'].detach().cpu().numpy())

    if not logit_chunks:
        raise ValueError("test_loader produced no batches; nothing to evaluate")

    preds = np.argmax(np.concatenate(logit_chunks, axis=0), axis=1)
    out_label_ids = np.concatenate(label_chunks, axis=0)

    excluded = np.asarray(list(excluded_classes))

    # Classes that take part in the macro average.
    kept_classes = np.setdiff1d(np.union1d(out_label_ids, preds), excluded)
    if kept_classes.size == 0:
        raise ValueError("no class left to score after applying excluded_classes")

    # Restrict the average with `labels=` rather than dropping samples.
    f1 = f1_score(out_label_ids, preds, labels=kept_classes, average='macro')

    # Predictions for the samples whose true label is not excluded.
    mask = np.isin(out_label_ids, excluded, invert=True)
    filtered_preds = preds[mask]

    result = {'f1_score': f1, 'pred': filtered_preds}
    return result


Improved F1 score (18 classes) for the approach

In [None]:
# Score the model with class id 0 excluded.
# NOTE(review): this assumes row 0 of label.txt is the 'Other' class — confirm against the file.
result1 = evaluate_f1_excluding_classes(model, device, test_loader, excluded_classes=[0])
print("\nF1-score (excluding class 0):", result1['f1_score'])

Evaluating: 100%|██████████| 85/85 [00:50<00:00,  1.67it/s]


F1-score (excluding class 0): 0.8054423006727647



