# Multi-label text classification using BERT

## Imports

In [1]:
%load_ext autoreload
%autoreload 2


import os
from datetime import datetime
from typing import List
import json
import glob
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from defi_textmine_2025.data.utils import TARGET_COL, LOGGING_DIR, INTERIM_DIR, MODELS_DIR, submission_path
import logging 

logging.basicConfig(
     level=logging.INFO, 
     format= '[%(asctime)s|%(levelname)s|%(module)s.py:%(lineno)s] %(message)s',
     datefmt='%H:%M:%S',
     filename=os.path.join(LOGGING_DIR, f'{datetime.now().strftime("%Y%m%dT%H%M%S")}.log')
 )
logging.getLogger().addHandler(logging.StreamHandler())
logging.info("## Imports")
import tqdm.notebook as tq
from tqdm import tqdm
# Create new `pandas` methods which use `tqdm` progress
# (can use tqdm_gui, optional kwargs, etc.)
tqdm.pandas()
from collections import defaultdict

from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.metrics import confusion_matrix, classification_report, f1_score
from sklearn.model_selection import train_test_split
import torch
import torch.nn as nn
from transformers import AdamW, FlaubertTokenizer, FlaubertModel

from defi_textmine_2025.data.utils import load_test_raw_data
from defi_textmine_2025.data.problem_formulation import TextToMultiLabelDataGenerator
from defi_textmine_2025.bert_dataset_and_models import LinearHeadFlaubertBasedModel, loss_fn

# BASE_CHECKPOINT = "camembert/camembert-base"
BASE_CHECKPOINT = "flaubert/flaubert_base_cased"
EMBEDDING_SIZE = 768 # 768 # 1024
TASK_NAME = "multilabel_tagged_text"
NOTA_LABEL = '#NOTA'  # out-of-domain, none of the above

entity_classes = {'TERRORIST_OR_CRIMINAL', 'LASTNAME', 'LENGTH', 'NATURAL_CAUSES_DEATH', 'COLOR', 'STRIKE', 'DRUG_OPERATION', 'HEIGHT', 'INTERGOVERNMENTAL_ORGANISATION', 'TRAFFICKING', 'NON_MILITARY_GOVERNMENT_ORGANISATION', 'TIME_MIN', 'DEMONSTRATION', 'TIME_EXACT', 'FIRE', 'QUANTITY_MIN', 'MATERIEL', 'GATHERING', 'PLACE', 'CRIMINAL_ARREST', 'CBRN_EVENT', 'ECONOMICAL_CRISIS', 'ACCIDENT', 'LONGITUDE', 'BOMBING', 'MATERIAL_REFERENCE', 'WIDTH', 'FIRSTNAME', 'MILITARY_ORGANISATION', 'CIVILIAN', 'QUANTITY_MAX', 'CATEGORY', 'POLITICAL_VIOLENCE', 'EPIDEMIC', 'TIME_MAX', 'TIME_FUZZY', 'NATURAL_EVENT', 'SUICIDE', 'CIVIL_WAR_OUTBREAK', 'POLLUTION', 'ILLEGAL_CIVIL_DEMONSTRATION', 'NATIONALITY', 'GROUP_OF_INDIVIDUALS', 'QUANTITY_FUZZY', 'RIOT', 'WEIGHT', 'THEFT', 'MILITARY', 'NON_GOVERNMENTAL_ORGANISATION', 'LATITUDE', 'COUP_D_ETAT', 'ELECTION', 'HOOLIGANISM_TROUBLEMAKING', 'QUANTITY_EXACT', 'AGITATING_TROUBLE_MAKING'}
categories_to_check = ['END_DATE', 'GENDER_MALE', 'WEIGHS', 'DIED_IN', 'HAS_FAMILY_RELATIONSHIP', 'IS_DEAD_ON', 'IS_IN_CONTACT_WITH', 'HAS_CATEGORY', 'HAS_CONTROL_OVER', 'IS_BORN_IN', 'IS_OF_SIZE', 'HAS_LATITUDE', 'IS_PART_OF', 'IS_OF_NATIONALITY', 'IS_COOPERATING_WITH', 'DEATHS_NUMBER', 'HAS_FOR_HEIGHT', 'INITIATED', 'WAS_DISSOLVED_IN', 'HAS_COLOR', 'CREATED', 'IS_LOCATED_IN', 'WAS_CREATED_IN', 'IS_AT_ODDS_WITH', 'HAS_CONSEQUENCE', 'HAS_FOR_LENGTH', 'INJURED_NUMBER', 'START_DATE', 'STARTED_IN', 'GENDER_FEMALE', 'HAS_LONGITUDE', 'RESIDES_IN', 'HAS_FOR_WIDTH', 'IS_BORN_ON', 'HAS_QUANTITY', 'OPERATES_IN', 'IS_REGISTERED_AS']

mlb = MultiLabelBinarizer()
mlb.fit([categories_to_check])
logging.info(f"{mlb.classes_=}")

generated_data_dir_path = os.path.join(INTERIM_DIR, "multilabel_tagged_text_dataset")
assert os.path.exists(generated_data_dir_path)

preprocessed_data_dir = os.path.join(INTERIM_DIR, "one_hot_multilabel_tagged_text_dataset")
labeled_preprocessed_data_dir_path = os.path.join(preprocessed_data_dir,"train")
! mkdir -p {labeled_preprocessed_data_dir_path}

model_dir_path = os.path.join(MODELS_DIR, f"finetuned-{BASE_CHECKPOINT}")
! mkdir -p {model_dir_path}
model_dict_state_path = os.path.join(model_dir_path,"MLTC_model_state.bin")

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
device

## Imports
mlb.classes_=array(['CREATED', 'DEATHS_NUMBER', 'DIED_IN', 'END_DATE', 'GENDER_FEMALE',
       'GENDER_MALE', 'HAS_CATEGORY', 'HAS_COLOR', 'HAS_CONSEQUENCE',
       'HAS_CONTROL_OVER', 'HAS_FAMILY_RELATIONSHIP', 'HAS_FOR_HEIGHT',
       'HAS_FOR_LENGTH', 'HAS_FOR_WIDTH', 'HAS_LATITUDE', 'HAS_LONGITUDE',
       'HAS_QUANTITY', 'INITIATED', 'INJURED_NUMBER', 'IS_AT_ODDS_WITH',
       'IS_BORN_IN', 'IS_BORN_ON', 'IS_COOPERATING_WITH', 'IS_DEAD_ON',
       'IS_IN_CONTACT_WITH', 'IS_LOCATED_IN', 'IS_OF_NATIONALITY',
       'IS_OF_SIZE', 'IS_PART_OF', 'IS_REGISTERED_AS', 'OPERATES_IN',
       'RESIDES_IN', 'STARTED_IN', 'START_DATE', 'WAS_CREATED_IN',
       'WAS_DISSOLVED_IN', 'WEIGHS'], dtype=object)


device(type='cuda')

In [2]:
def load_csv(dir_or_file_path: str, index_col=None, sep=',') -> pd.DataFrame:
    if os.path.isdir(dir_or_file_path):
        all_files = glob.glob(os.path.join(dir_or_file_path , "*.csv"))  
    else:
        assert dir_or_file_path.endswith(".csv")
        all_files = [dir_or_file_path]
    assert len(all_files) > 0
    return pd.concat([pd.read_csv(filename, index_col=index_col, header=0, sep=sep) for filename in all_files], axis=0, ignore_index=True)

def process_data(data: pd.DataFrame) -> pd.DataFrame:
    return pd.concat([data, pd.DataFrame(mlb.transform(data[TARGET_COL]), columns=mlb.classes_, index=data.index)], axis=1) # .drop([TARGET_COL], axis=1)


def format_relations_str_to_list(labels_as_str: str) -> List[str]:
    return json.loads(
        labels_as_str.replace("{", "[").replace("}", "]").replace("'", '"')
    )  if not pd.isnull(labels_as_str) else []


def process_csv_to_csv(in_dir_or_file_path: str, out_dir_path: str) -> None:
    """Convert labels, i.e. list of relations category, into one-hot vectors

    Args:
        in_dir_or_file_path (str): str
        out_dir_path (str): str
    """
    if not os.path.exists(out_dir_path):
        os.makedirs(out_dir_path)
    if os.path.isdir(in_dir_or_file_path):
        all_files = glob.glob(os.path.join(in_dir_or_file_path , "*.csv"))  
    else:
        assert in_dir_or_file_path.endswith(".csv")
        all_files = [in_dir_or_file_path]
    for filename in (pb:=tqdm(all_files)):
        pb.set_description(filename)
        preprocessed_data_filename = os.path.join(out_dir_path, os.path.basename(filename))
        df = load_csv(filename).assign(**{TARGET_COL: lambda df: df[TARGET_COL].apply(format_relations_str_to_list)})
        process_data(df).to_csv(preprocessed_data_filename, sep="\t")

## Preprocess and save data

- load generated data
- convert to dataframe
- convert categories into one-hot labels
- save into a tsv file

In [3]:
logging.info("## Preprocess and save data...")
process_csv_to_csv(os.path.join(generated_data_dir_path, "train"), labeled_preprocessed_data_dir_path)

## Preprocess and save data...
data/defi-text-mine-2025/interim/multilabel_tagged_text_dataset/train/41884.csv: 100%|██████████| 800/800 [00:04<00:00, 164.39it/s]


## Load preprocessed data

In [4]:
logging.info("## Load preprocessed data...")

## Load preprocessed data...


In [5]:
labeled_df = load_csv(labeled_preprocessed_data_dir_path, index_col=0, sep='\t')

In [6]:
labeled_df.sample(5)

Unnamed: 0,text_index,e1,e2,text,relations,CREATED,DEATHS_NUMBER,DIED_IN,END_DATE,GENDER_FEMALE,...,IS_OF_SIZE,IS_PART_OF,IS_REGISTERED_AS,OPERATES_IN,RESIDES_IN,STARTED_IN,START_DATE,WAS_CREATED_IN,WAS_DISSOLVED_IN,WEIGHS
101348,11683,10,5,\nUne cinquantaine de <e2><GROUP_OF_INDIVIDUAL...,[],0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
11782,41712,9,3,Un <e2><DEMONSTRATION>rassemblement</e2> a été...,[],0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
60235,3643,1,9,Un incident a eu lieu le 13 février 2010 au Ma...,[],0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
92089,11867,0,14,Des centaines de manifestants vêtus de tee-shi...,[],0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
116872,3834,8,17,Une manifestation de l'association des étudian...,"['IS_LOCATED_IN', 'HAS_CONTROL_OVER']",0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0


## Train-Validation split

- split such that each category exist at least in the train dataset

In [7]:
logging.info("## Train-Validation split...")

random_seed = 3508  # np.random.randint(10000)  # chosen such as each class has at least 3 validation examples
print(f"{random_seed=}")

train_df, valid_df = train_test_split(labeled_df, test_size = 0.3, shuffle=True, random_state=random_seed)
train_df.shape, valid_df.shape

train_category_sizes = train_df[mlb.classes_].sum(axis=0)
val_category_sizes = valid_df[mlb.classes_].sum(axis=0)
train_val_category_sizes_df = pd.DataFrame({"train": train_category_sizes, "valid": val_category_sizes}).sort_values("train", ascending=True)
train_val_category_sizes_df

## Train-Validation split...


random_seed=3508


Unnamed: 0,train,valid
HAS_LATITUDE,7,3
HAS_FOR_HEIGHT,8,4
HAS_LONGITUDE,9,3
HAS_FOR_WIDTH,9,5
WAS_CREATED_IN,10,5
WAS_DISSOLVED_IN,10,4
HAS_FOR_LENGTH,11,5
IS_BORN_ON,15,5
IS_REGISTERED_AS,25,9
DIED_IN,26,15


In [8]:
df_train_with_relation = train_df[train_df[mlb.classes_].sum(axis=1) >= 1]
df_train_with_relation.shape

(18453, 42)

In [9]:
df_train_without_relation = train_df[train_df[mlb.classes_].sum(axis=1) == 0]
df_train_without_relation.shape

(68350, 42)

### Chances of having a class in the training batch

In [10]:
BATCH_SIZE = 16
# train_val_category_sizes_df.assign(in_batch_proba = train_category_sizes.map(lambda category_size: 1 - ((train_df.shape[0] - category_size) / train_df.shape[0])**BATCH_SIZE))
train_val_category_sizes_df = train_val_category_sizes_df.assign(train_in_batch_proba = train_category_sizes.map(lambda categ_size: 1 - np.prod([(train_df.shape[0] - categ_size - i) / (train_df.shape[0] - i) for i in range(BATCH_SIZE)])))
train_val_category_sizes_df

Unnamed: 0,train,valid,train_in_batch_proba
HAS_LATITUDE,7,3,0.00129
HAS_FOR_HEIGHT,8,4,0.001474
HAS_LONGITUDE,9,3,0.001658
HAS_FOR_WIDTH,9,5,0.001658
WAS_CREATED_IN,10,5,0.001842
WAS_DISSOLVED_IN,10,4,0.001842
HAS_FOR_LENGTH,11,5,0.002026
IS_BORN_ON,15,5,0.002762
IS_REGISTERED_AS,25,9,0.004599
DIED_IN,26,15,0.004782


## Create the tokenized datasets for model input

In [11]:
# Hyperparameters
MAX_LEN = 150
# do_lowercase=False if using cased models, True if using uncased ones
tokenizer = FlaubertTokenizer.from_pretrained(BASE_CHECKPOINT, do_lowercase=False)
task_special_tokens = ["<e1>", "</e1>", "<e2>", "</e2>"] + [
    f"<{entity_class}>" for entity_class in entity_classes
]
# add special tokens to the tokenizer
num_added_tokens = tokenizer.add_tokens(task_special_tokens, special_tokens=True)
num_added_tokens



59

In [12]:
# Test the tokenizer
test_text = "La <e2><NON_MILITARY_GOVERNMENT_ORGANISATION>police</e2> tchèque a <e2><NON_MILITARY_GOVERNMENT_ORGANISATION>mis la main</e2> sur le couple responsable d'un trafic d'œuvres d'art. Il s'agit de <e1><TERRORIST_OR_CRIMINAL>Patel</e1> et Mirna Maroski. Une <e2><NON_MILITARY_GOVERNMENT_ORGANISATION>perquisition</e2> à leur domicile a permis de retrouver une centaine de tableaux d'artistes européens. Il y avait également des pots en céramique et en porcelaine d'origine chinoise, ainsi que plusieurs faux documents de voyage. Les époux Maroski ont été conduits au poste de <e2><NON_MILITARY_GOVERNMENT_ORGANISATION>police</e2> dans un véhicule blindé. Mirna Maroski s'est évanouie une fois arrivée au poste. Elle a été amenée en ambulance au CHU de Motol où elle a été soignée. Monsieur Sergueï Alekseï, le directeur de l'hôpital, a demandé à ses collaborateurs d'être vigilants et de ne pas se laisser corrompre par la criminelle."
# generate encodings
encodings = tokenizer.encode_plus(test_text, 
                                  add_special_tokens = True,
                                  max_length = MAX_LEN,
                                  truncation = True,
                                  padding = "max_length", 
                                  return_attention_mask = True, 
                                  return_tensors = "pt")
# we get a dictionary with three keys (see: https://huggingface.co/transformers/glossary.html) 
encodings

{'input_ids': tensor([[    0,    60, 68731, 68733,   873, 68732,  9307,    34, 68731, 68733,
           364,    17,   590, 68732,    37,    20,  1791,  1148,    24,    26,
          2951,    24,  2515,    24,  7069,    59,    53,   444,    15, 68729,
         68761,  9416,   615, 68730,    18,  7567,  3032,  1005,  1949,  3209,
            16,   154, 68731, 68733, 29390, 68732,    19,    81,  2254,    34,
           821,    15,  1307,    30,  5266,    15,  4678,    24,  1867,  2160,
            16,    59,    66,   110,   131,    23, 11514,    25, 12449,    18,
            25, 24094,    24,   677,  6571,    14,   120,    32,   182,  2989,
          1395,    15,  1154,    16,    64,  6641,  1005,  1949,  3209,    62,
            69, 14132,    36,  1056,    15, 68731, 68733,   873, 68732,    33,
            26,  1796, 40793,    16,  7567,  3032,  1005,  1949,  3209,    53,
            27, 20429,  1473,    30,   136,  1039,    36,  1056,    16,   145,
            34,    69, 20132,    25, 2

In [13]:
tokenizer.batch_decode(encodings['input_ids'])

["<s>La <e2> <NON_MILITARY_GOVERNMENT_ORGANISATION> police </e2> tchèque a <e2> <NON_MILITARY_GOVERNMENT_ORGANISATION> mis la main </e2> sur le couple responsable d' un trafic d' œuvres d' art. Il s' agit de <e1> <TERRORIST_OR_CRIMINAL> Patel </e1> et Mirna Maroski. Une <e2> <NON_MILITARY_GOVERNMENT_ORGANISATION> perquisition </e2> à leur domicile a permis de retrouver une centaine de tableaux d' artistes européens. Il y avait également des pots en céramique et en porcelaine d' origine chinoise, ainsi que plusieurs faux documents de voyage. Les époux Maroski ont été conduits au poste de <e2> <NON_MILITARY_GOVERNMENT_ORGANISATION> police </e2> dans un véhicule blindé. Mirna Maroski s' est évanouie une fois arrivée au poste. Elle a été amenée en ambulance au CHU de Motol où elle a été soignée. Monsieur Sergueï Alekseï, le directeur de l' hôpital, a </s>"]

In [14]:
len(tokenizer)

68788

In [15]:
class CustomDataset(torch.utils.data.Dataset):
    def __init__(self, df, tokenizer, max_len, target_list):
        self.tokenizer = tokenizer
        self.df = df
        # self.e1 = list(df['e1'])
        # self.e1 = list(df['e1'])
        # self.text_indexes = list(df['text_index'])
        self.title = list(df['text'])
        self.targets = self.df[target_list].values
        self.max_len = max_len

    def __len__(self):
        return len(self.title)

    def __getitem__(self, index):
        text = str(self.title[index])
        text = " ".join(text.split())
        inputs = self.tokenizer.encode_plus(
            text,
            None,
            add_special_tokens=True,
            max_length=self.max_len,
            padding='max_length',
            return_token_type_ids=True,
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt'
        )
        return {
            'input_ids': inputs['input_ids'].flatten(),
            'attention_mask': inputs['attention_mask'].flatten(),
            'token_type_ids': inputs["token_type_ids"].flatten(),
            'targets': torch.FloatTensor(self.targets[index]),
            'title': text,
            # 'text_index': self.text_index[index],
            # 'e1': self.e1[index],
            # 'e2': self.e2[index],
        }

In [16]:
# most_common_categories = df_train[mlb.classes_].sum().sort_values(ascending=False).index[:7]
# logging.info(most_common_categories)
# # target_list = mlb.classes_.tolist()
# target_list = most_common_categories
target_list = mlb.classes_
logging.info(f"{len(target_list)} categories = {target_list}")

37 categories = ['CREATED' 'DEATHS_NUMBER' 'DIED_IN' 'END_DATE' 'GENDER_FEMALE'
 'GENDER_MALE' 'HAS_CATEGORY' 'HAS_COLOR' 'HAS_CONSEQUENCE'
 'HAS_CONTROL_OVER' 'HAS_FAMILY_RELATIONSHIP' 'HAS_FOR_HEIGHT'
 'HAS_FOR_LENGTH' 'HAS_FOR_WIDTH' 'HAS_LATITUDE' 'HAS_LONGITUDE'
 'HAS_QUANTITY' 'INITIATED' 'INJURED_NUMBER' 'IS_AT_ODDS_WITH'
 'IS_BORN_IN' 'IS_BORN_ON' 'IS_COOPERATING_WITH' 'IS_DEAD_ON'
 'IS_IN_CONTACT_WITH' 'IS_LOCATED_IN' 'IS_OF_NATIONALITY' 'IS_OF_SIZE'
 'IS_PART_OF' 'IS_REGISTERED_AS' 'OPERATES_IN' 'RESIDES_IN' 'STARTED_IN'
 'START_DATE' 'WAS_CREATED_IN' 'WAS_DISSOLVED_IN' 'WEIGHS']


In [17]:
train_df.shape, valid_df.shape

((86803, 42), (37202, 42))

In [18]:
train_dataset = CustomDataset(train_df, tokenizer, MAX_LEN, target_list)
valid_dataset = CustomDataset(valid_df, tokenizer, MAX_LEN, target_list)

In [19]:
# testing the dataset
next(iter(train_dataset))

{'input_ids': tensor([    0,   154,  2040,   100,   438,    25,  3477,    53,    27,  7399,
            16,   156,    30,  3056,    28,  1303,  4032,   702,    19, 15460,
            14,    26,  9851,  1066,    15,  1012,  1665,    37,   738,   404,
            69, 11994,   278,    23,   707, 11380,    15, 12655,    18,    23,
         35068,    34,   441,   983,    16,   420, 23421,    14,    22,   438,
          2215,    19,    17,  2040,    51,    62,    42, 43727,    19,    83,
         14508,    22, 68731, 68749,  2272,  1414,    15,   836, 68732,   139,
            15,  2729,    22,  1232,    16,  2789,   403,    44,   566,    14,
            26,   983,    53,    27,   996,    33,    20,  1266,    18,    34,
          6436,    22, 15995,    18, 13433, 10014,   956,  8519,    19,   369,
            28,  9851,    16,   113,    27,    21,    26,    23,  3607,    14,
          1902,  2499,  2724, 13036,   114,  9411,    14,   913,    24,    30,
          1069,  7441,    29,   314, 17

## Create data loaders

In [20]:
TRAIN_BATCH_SIZE = 32
VALID_BATCH_SIZE = 32

# Data loaders
train_data_loader = torch.utils.data.DataLoader(train_dataset, 
    batch_size=TRAIN_BATCH_SIZE,
    shuffle=True,
    num_workers=0
)

val_data_loader = torch.utils.data.DataLoader(valid_dataset, 
    batch_size=VALID_BATCH_SIZE,
    shuffle=False,
    num_workers=0
)

## Compute class weights to handle imbalance into the loss function

In [21]:
# Source: https://www.tensorflow.org/tutorials/structured_data/imbalanced_data#calculate_class_weights
# Scaling by total/2 helps keep the loss to a similar magnitude.
n_examples = labeled_df.shape[0]
n_classes = len(target_list)
def get_cat_var_distribution(cat_var: pd.Series) -> pd.DataFrame:
    return pd.concat(
        [cat_var.value_counts(), cat_var.value_counts(normalize=True)], axis=1
    )
def compute_class_weights(lbl_df: pd.DataFrame, class_names: list) -> pd.Series:
    return lbl_df[target_list].sum(axis=0).map(lambda x: (1 / x) * (n_examples / n_classes)).rename("weight")

class_weights = compute_class_weights(labeled_df, target_list)
class_weights_tensor = torch.Tensor(class_weights.values).to(device, dtype = torch.float16)

pd.concat([labeled_df[target_list].sum(axis=0).rename("class_size"), class_weights], axis=1)

Unnamed: 0,class_size,weight
CREATED,126,26.599099
DEATHS_NUMBER,75,44.686486
DIED_IN,41,81.743573
END_DATE,874,3.834653
GENDER_FEMALE,414,8.095378
GENDER_MALE,908,3.691064
HAS_CATEGORY,894,3.748866
HAS_COLOR,91,36.829522
HAS_CONSEQUENCE,769,4.35824
HAS_CONTROL_OVER,4547,0.737076


## Prepare the model to trained

In [22]:
model = LinearHeadFlaubertBasedModel(
    tokenizer=tokenizer,
    embedding_model=FlaubertModel.from_pretrained(BASE_CHECKPOINT, output_loading_info=True)[0],  # flaubert, log =FlaubertModel.from_pretrained()...
    embedding_size=EMBEDDING_SIZE,
    hidden_dim=0,
    n_classes=len(target_list),
)

# Freezing BERT layers: (tested, weaker convergence)
# for param in model.embedding_model.parameters():
#     param.requires_grad = False

model.to(device)



LinearHeadFlaubertBasedModel(
  (embedding_model): FlaubertModel(
    (position_embeddings): Embedding(512, 768)
    (embeddings): Embedding(68788, 768)
    (layer_norm_emb): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (attentions): ModuleList(
      (0-11): 12 x MultiHeadAttention(
        (q_lin): Linear(in_features=768, out_features=768, bias=True)
        (k_lin): Linear(in_features=768, out_features=768, bias=True)
        (v_lin): Linear(in_features=768, out_features=768, bias=True)
        (out_lin): Linear(in_features=768, out_features=768, bias=True)
      )
    )
    (layer_norm1): ModuleList(
      (0-11): 12 x LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    )
    (ffns): ModuleList(
      (0-11): 12 x TransformerFFN(
        (lin1): Linear(in_features=768, out_features=3072, bias=True)
        (lin2): Linear(in_features=3072, out_features=768, bias=True)
        (act): GELUActivation()
      )
    )
    (layer_norm2): ModuleList(
      (0-11): 12 x L

In [23]:
# define the optimizer
optimizer = AdamW(model.parameters(), lr = 1e-5, weight_decay=0.1)
optimizer



AdamW (
Parameter Group 0
    betas: (0.9, 0.999)
    correct_bias: True
    eps: 1e-06
    lr: 1e-05
    weight_decay: 0.1
)

## Function to tain the model

In [24]:
# Training of the model for one epoch
def train_model(model, training_loader, optimizer, _class_weights_tensor):
    predictions = []
    prediction_probs = []
    target_values = []
    losses = []
    correct_predictions = 0
    num_samples = 0
    # set model to training mode (activate dropout, batch norm)
    model.train()
    # initialize the progress bar
    loop = tq.tqdm(enumerate(training_loader), total=len(training_loader), 
                      leave=True, colour='steelblue', desc="training")
    for batch_idx, data in loop:
        ids = data['input_ids'].to(device, dtype = torch.long)
        mask = data['attention_mask'].to(device, dtype = torch.long)
        token_type_ids = data['token_type_ids'].to(device, dtype = torch.long)
        targets = data['targets'].to(device, dtype = torch.float)

        # forward
        outputs = model(ids, mask, token_type_ids) # (batch,predict)=(32,37)
        loss = loss_fn(outputs, targets, _class_weights_tensor)
        losses.append(loss.item())
        # training accuracy, apply sigmoid, round (apply thresh 0.5)
        # outputs = torch.sigmoid(outputs).cpu().detach().numpy().round()
        # targets = targets.cpu().detach().numpy()
        # correct_predictions += np.sum(outputs==targets)
        # num_samples += targets.size   # total number of elements in the 2D array
        outputs = torch.sigmoid(outputs).cpu().detach()
        # thresholding at 0.5
        preds = outputs.round()
        targets = targets.cpu().detach()
        correct_predictions += np.sum(preds.numpy()==targets.numpy())
        num_samples += targets.numpy().size   # total number of elements in the 2D array
        
        # thresholding at 0.5
        preds = outputs.round()        
        predictions.extend(preds)
        prediction_probs.extend(outputs)
        target_values.extend(targets)

        # backward
        optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        # grad descent step
        optimizer.step()
        # Update progress bar
        loop.set_description(f"")
        loop.set_postfix(batch_loss=loss.cpu().detach().numpy())
        # break

    # returning: trained model, model accuracy, mean loss
    predictions = torch.stack(predictions)
    prediction_probs = torch.stack(prediction_probs)
    target_values = torch.stack(target_values)    

    return model, float(correct_predictions)/num_samples, f1_score(target_values, predictions, average="macro", zero_division=0), np.mean(losses)

# torch.cuda.empty_cache()
# train_model(model, train_data_loader, optimizer, class_weights_tensor)

## Function to evaluate the model

In [25]:
def eval_model(model, validation_loader, _class_weights_tensor):    
    predictions = []
    prediction_probs = []
    target_values = []
    losses = []
    correct_predictions = 0
    num_samples = 0
    # set model to eval mode (turn off dropout, fix batch norm)
    model.eval()

    with torch.no_grad():
        # for batch_idx, data in tqdm(enumerate(validation_loader, 0), "evaluating"):
        for data in tqdm(validation_loader, "evaluating"):
            ids = data['input_ids'].to(device, dtype = torch.long)
            mask = data['attention_mask'].to(device, dtype = torch.long)
            token_type_ids = data['token_type_ids'].to(device, dtype = torch.long)
            targets = data['targets'].to(device, dtype = torch.float)
            outputs = model(ids, mask, token_type_ids)

            loss = loss_fn(outputs, targets, _class_weights_tensor)
            losses.append(loss.item())

            # validation accuracy
            # add sigmoid, for the training sigmoid is in BCEWithLogitsLoss
            outputs = torch.sigmoid(outputs).cpu().detach()
            # thresholding at 0.5
            preds = outputs.round()
            targets = targets.cpu().detach()
            correct_predictions += np.sum(preds.numpy()==targets.numpy())
            num_samples += targets.numpy().size   # total number of elements in the 2D array

            predictions.extend(preds)
            prediction_probs.extend(outputs)
            target_values.extend(targets)
            # break
    
    predictions = torch.stack(predictions)
    prediction_probs = torch.stack(prediction_probs)
    target_values = torch.stack(target_values)    

    return float(correct_predictions)/num_samples, f1_score(target_values, predictions, average="macro", zero_division=0), np.mean(losses)


# eval_model(model, train_data_loader, class_weights_tensor)

## Function to apply the model

In [26]:
def get_predictions(model, data_loader):
    """
    Outputs:
      predictions - 
    """
    model = model.eval()
    
    titles = []
    predictions = []
    prediction_probs = []
    target_values = []

    with torch.no_grad():
      for data in tqdm(data_loader, "get_predictions"):
        title = data["title"]
        ids = data["input_ids"].to(device, dtype = torch.long)
        mask = data["attention_mask"].to(device, dtype = torch.long)
        token_type_ids = data['token_type_ids'].to(device, dtype = torch.long)
        targets = data["targets"].to(device, dtype = torch.float)
        
        outputs = model(ids, mask, token_type_ids)
        # add sigmoid, for the training sigmoid is in BCEWithLogitsLoss
        outputs = torch.sigmoid(outputs).detach().cpu()
        # thresholding at 0.5
        preds = outputs.round()
        targets = targets.detach().cpu()

        titles.extend(title)
        predictions.extend(preds)
        prediction_probs.extend(outputs)
        target_values.extend(targets)
        # break
    
    predictions = torch.stack(predictions)
    prediction_probs = torch.stack(prediction_probs)
    target_values = torch.stack(target_values)
    
    return titles, predictions, prediction_probs, target_values

# get_predictions(model, train_data_loader)

## Model Training

In [27]:
if os.path.exists(model_dict_state_path):  # to continue the training from a previous checkpoint
    logging.warning("The training will continue from a previous checkpoint...")
    logging.info("Loading the previous checkpoint...")
    model.load_state_dict(torch.load(model_dict_state_path))
    if not next(model.parameters()).is_cuda: # if model is not in GPU, then load it to GPU
        model = model.to(device)
elif not os.path.exists(os.path.dirname(model_dict_state_path)):
    os.makedirs(os.path.dirname(model_dict_state_path))

EPOCHS = 50
# THRESHOLD = 0.5 # threshold for the sigmoid
PATIENCE = 3
n_not_better_steps = 0
history = defaultdict(list)
best_f1_macro = 0


for epoch in range(1, EPOCHS+1):
    print(f'Epoch {epoch}/{EPOCHS}')
    model, train_acc, train_f1_macro, train_loss = train_model(model, train_data_loader, optimizer, class_weights_tensor)
    val_acc, val_f1_macro, val_loss = eval_model(model, val_data_loader, class_weights_tensor)

    print(f'train_loss={train_loss:.4f}, val_loss={val_loss:.4f} train_f1_macro={train_f1_macro:.4f}, val_f1_macro={val_f1_macro:.4f}')

    history['train_acc'].append(train_acc)
    history['train_f1_macro'].append(train_f1_macro)
    history['train_loss'].append(train_loss)
    history['val_acc'].append(val_acc)
    history['val_f1_macro'].append(val_f1_macro)
    history['val_loss'].append(val_loss)
    # save the best model
    if val_f1_macro > best_f1_macro:
        torch.save(model.state_dict(), model_dict_state_path)
        best_f1_macro = val_f1_macro
        n_not_better_steps = 0
    else: # check for early stopping
        n_not_better_steps += 1
        if n_not_better_steps >= PATIENCE:
            break
    # break

Epoch 1/50


training:   0%|          | 0/2713 [00:00<?, ?it/s]

evaluating: 100%|██████████| 1163/1163 [02:56<00:00,  6.58it/s]


train_loss=0.3462, val_loss=0.1411 train_f1_macro=0.0418, val_f1_macro=0.1695
Epoch 2/50


training:   0%|          | 0/2713 [00:00<?, ?it/s]

evaluating: 100%|██████████| 1163/1163 [02:56<00:00,  6.60it/s]


train_loss=0.1093, val_loss=0.0985 train_f1_macro=0.2783, val_f1_macro=0.3775
Epoch 3/50


training:   0%|          | 0/2713 [00:00<?, ?it/s]

In [None]:
plt.rcParams["figure.figsize"] = (10,7)
plt.plot(history['train_f1_macro'], label='train F1 macro')
plt.plot(history['val_f1_macro'], label='validation F1 macro')
plt.plot(history['train_loss'], label='train loss')
plt.plot(history['val_loss'], label='validation loss')
plt.title('Training history')
plt.ylabel('F1 macro / loss')
plt.xlabel('Epoch')
plt.legend()
plt.ylim([0, 1])
plt.grid()

## Evaluation of the model

In [None]:
# Loading pretrained model (best model)
# model = BERTClass(tokenizer)
model.load_state_dict(torch.load(model_dict_state_path))
model = model.to(device)

In [None]:
# Evaluate the model using the test data
# val_acc, val_f1_macro, val_loss = eval_model(val_data_loader, model)

In [None]:
# The accuracy looks OK, similar to the validation accuracy
# The model generalizes well !
# val_acc

In [None]:
titles, predictions, prediction_probs, target_values = get_predictions(model, val_data_loader)

In [None]:
# Generate Classification Metrics
#
# note that the total support is greater than the number of samples
# some samples have multiple lables

print(classification_report(target_values, predictions, target_names=target_list, zero_division=0))

In [None]:
# import seaborn as sns
# def show_confusion_matrix(confusion_matrix):
#     hmap = sns.heatmap(confusion_matrix, annot=True, fmt="d", cmap="Blues")
#     hmap.yaxis.set_ticklabels(hmap.yaxis.get_ticklabels(), rotation=0, ha='right')
#     hmap.xaxis.set_ticklabels(hmap.xaxis.get_ticklabels(), rotation=30, ha='right')
#     plt.ylabel('True category')
#     plt.xlabel('Predicted category');

In [None]:
# cm = confusion_matrix(target_values, predictions)
# df_cm = pd.DataFrame(cm, index=target_list, columns=target_list)
# show_confusion_matrix(df_cm)

## Prepare submission

In [None]:
df_test = load_csv(os.path.join(generated_data_dir_path, "test")) #.drop(TARGET_COL, axis=1)
df_test

In [None]:
# df_test.head().drop(TARGET_COL, axis=1).assign(**{cat: [0]*df_test.head().shape[0] for cat in target_list})

In [None]:
test_dataset = CustomDataset(df_test.drop(TARGET_COL, axis=1).assign(**{cat: [0]*df_test.shape[0] for cat in target_list}), tokenizer, MAX_LEN, target_list)

In [None]:
TEST_BATCH_SIZE = 32

test_data_loader = torch.utils.data.DataLoader(test_dataset, 
    batch_size=TEST_BATCH_SIZE,
    shuffle=False,
    num_workers=0
)

In [None]:
titles, predictions, prediction_probs, target_values = get_predictions(model, test_data_loader)

In [None]:

ml_labeled_test_df = pd.concat(
    [
        df_test.drop(TARGET_COL, axis=1),
        pd.Series(
            mlb.inverse_transform(
                pd.DataFrame(predictions.numpy(), columns=target_list, index=df_test.index)[mlb.classes_].values
            ),
            name=TARGET_COL,
            index=df_test.index
        )
    ],
    axis=1
)
ml_labeled_test_df

In [None]:
text_idx_to_relations = {
    text_index: [l[0] for l in group_df.drop(["text_index", "text"], axis=1)[group_df.relations.str.len()>0].apply(lambda row: [[row.iloc[0], r, row.iloc[1]] for r in row.iloc[-1]] if len(row.iloc[-1]) > 0 else [], axis=1).values.tolist()]
 for text_index, group_df in tqdm(ml_labeled_test_df.groupby("text_index"))
}

In [None]:
text_idx_to_relations[13]

In [None]:
test_index = load_test_raw_data().index
test_index

In [None]:
submission_df = pd.DataFrame({"id": list(text_idx_to_relations.keys()), TARGET_COL: list(text_idx_to_relations.values())}).set_index("id").loc[load_test_raw_data().index]
submission_df = submission_df.assign(relations= submission_df.relations.map(lambda x: str(x).replace("'", '"')))
submission_df

In [None]:
submission_df.to_csv(submission_path)

In [None]:
print(f"titles:{len(titles)} \npredictions:{predictions.shape} \nprediction_probs:{prediction_probs.shape} \ntarget_values:{target_values.shape}")