In [1]:
import random
random.seed(123)

In [2]:
import numpy as np
import os
from datetime import datetime

import datasets
import evaluate

import torch
import torch.nn as nn

from transformers import Trainer, TrainerCallback, TrainingArguments, EarlyStoppingCallback
from transformers import AutoTokenizer, BertConfig, BertModel, BertPreTrainedModel

In [3]:
MAX_LENGTH = 512
NUM_TRAIN_EPOCHS = 10

timestamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')

RESULTS_DIRECTORY = './results/experiment_{}'.format(timestamp)

LOGGING_DIRECTORY = './logs/experiement_{}'.format(timestamp)

RESULTS_DIRECTORY, LOGGING_DIRECTORY

('./results/experiment_2024-06-20_19-47-07',
 './logs/experiement_2024-06-20_19-47-07')

In [4]:
PRODUCT_LABELS = [
    'Vehicle loan or lease', 'Student loan', 'Consumer Loan'
]

ISSUE_LABELS = [
    'Managing the loan or lease', 'Problems at the end of the loan or lease', 'Struggling to pay your loan', 'Getting a loan or lease', 'Dealing with your lender or servicer', 'Incorrect information on your report', 'Problems when you are unable to pay', 'Taking out the loan or lease'
]

PRODUCT_LABELS_TO_IDX = {
    i : idx
    for idx, i in enumerate(PRODUCT_LABELS)
}

IDX_TO_PRODUCT_LABELS = {
    i: j
    for j, i in PRODUCT_LABELS_TO_IDX.items()
}

ISSUE_LABELS_TO_IDX = {
    i : idx
    for idx, i in enumerate(ISSUE_LABELS)
}

IDX_TO_ISSUE_LABELS = {
    i: j
    for j, i in ISSUE_LABELS_TO_IDX.items()
}


MAX_SEQUENCE_LENGTH = 512
PRETRAINED_MODEL_NAME = "distilbert-base-uncased"
NUM_TRAIN_EPOCHS = 10

In [5]:
# tokenize function
def tokenize(example, tokenizer):
    return tokenizer(
        example['consumer_complaint'], padding=True, truncation=True
    )


# label encoding
def label_encoder(example):
    example['product_label'] = PRODUCT_LABELS_TO_IDX[example['product']]     
    example['issue_label'] = ISSUE_LABELS_TO_IDX[example['issue']]
    return example


In [6]:
# initializing Config and Tokenizer
BERT_CONFIG = BertConfig.from_pretrained(PRETRAINED_MODEL_NAME)
TOKENIZER = AutoTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)

You are using a model of type distilbert to instantiate a model of type bert. This is not supported for all configurations of models and can yield errors.


In [7]:
# loading data
data = datasets.load_from_disk('data/complaints_dataset_obj')
data

DatasetDict({
    train: Dataset({
        features: ['product', 'issue', 'consumer_complaint', '__index_level_0__'],
        num_rows: 6800
    })
    val: Dataset({
        features: ['product', 'issue', 'consumer_complaint', '__index_level_0__'],
        num_rows: 1200
    })
    test: Dataset({
        features: ['product', 'issue', 'consumer_complaint', '__index_level_0__'],
        num_rows: 2000
    })
})

In [8]:
data = data.map(lambda x: tokenize(x, tokenizer=TOKENIZER))
data = data.map(lambda x: label_encoder(x))

data

Loading cached processed dataset at /run/media/kuldeepsingh/Work/college_stuff/github/multi_task_text_classification/data/complaints_dataset_obj/train/cache-b9cfa2472282e060.arrow
Loading cached processed dataset at /run/media/kuldeepsingh/Work/college_stuff/github/multi_task_text_classification/data/complaints_dataset_obj/val/cache-4cfff64b2cdcac2a.arrow
Loading cached processed dataset at /run/media/kuldeepsingh/Work/college_stuff/github/multi_task_text_classification/data/complaints_dataset_obj/test/cache-47425c01f75b9f93.arrow
Loading cached processed dataset at /run/media/kuldeepsingh/Work/college_stuff/github/multi_task_text_classification/data/complaints_dataset_obj/train/cache-4f752077556c36c9.arrow
Loading cached processed dataset at /run/media/kuldeepsingh/Work/college_stuff/github/multi_task_text_classification/data/complaints_dataset_obj/val/cache-bee9ca709a15e723.arrow
Loading cached processed dataset at /run/media/kuldeepsingh/Work/college_stuff/github/multi_task_text_cla

DatasetDict({
    train: Dataset({
        features: ['product', 'issue', 'consumer_complaint', '__index_level_0__', 'input_ids', 'attention_mask', 'product_label', 'issue_label'],
        num_rows: 6800
    })
    val: Dataset({
        features: ['product', 'issue', 'consumer_complaint', '__index_level_0__', 'input_ids', 'attention_mask', 'product_label', 'issue_label'],
        num_rows: 1200
    })
    test: Dataset({
        features: ['product', 'issue', 'consumer_complaint', '__index_level_0__', 'input_ids', 'attention_mask', 'product_label', 'issue_label'],
        num_rows: 2000
    })
})

In [9]:
# Defining the metrics 
# we will be using F1 scores for both product and issues 
# the best model will be decided using the f1 score for the product label (can be done using a weighted sum also)
PRODUCT_METRIC = evaluate.load("f1")
ISSUE_METRIC = evaluate.load("f1")

def compute_metrics(eval_pred):
    all_logits, all_labels = eval_pred
    product_logits, issue_logits = all_logits 
    product_labels, issue_labels = all_labels

    product_predictions = np.argmax(product_logits, axis=-1)
    issue_predictions = np.argmax(issue_logits, axis=-1)
    
    product_computed_metrics = PRODUCT_METRIC.compute(predictions=product_predictions, references=product_labels, average='weighted')
    issue_computed_metrics = ISSUE_METRIC.compute(predictions=issue_predictions, references=issue_labels, average='weighted')
    
    return {
        'f1_product': product_computed_metrics['f1'],
        'f1_issue': issue_computed_metrics['f1'],
    }

In [10]:
# model definition
class MultiTaskSentencePrediction(BertPreTrainedModel):
    def __init__(self, config, num_product_labels, num_issue_labels):
        super().__init__(config)
        self.num_product_labels = num_product_labels
        self.num_issue_labels = num_issue_labels

        self.bert = BertModel(config)

        self.product_classifier = nn.Linear(config.hidden_size, num_product_labels)
        self.issue_classifier = nn.Linear(config.hidden_size, num_issue_labels)

        classifier_dropout = config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
        
        self.dropout = nn.Dropout(classifier_dropout)
        self.init_weights()

    def forward(
            self, input_ids, attention_mask=None, token_type_ids=None, 
            product_label=None, issue_label=None 
    ):
        outputs = self.bert(
            input_ids, attention_mask=attention_mask,
            token_type_ids=token_type_ids
        )
        pooled_outputs = self.dropout(outputs[1])
        
        logits_product = self.product_classifier(pooled_outputs)
        logits_issue = self.issue_classifier(pooled_outputs)

        loss = None
        if product_label != None and issue_label != None:
            loss_fct1 = nn.CrossEntropyLoss()
            loss_fct2 = nn.CrossEntropyLoss()

            loss = loss_fct1(
                logits_product.view(-1, self.num_product_labels),
                product_label.view(-1)
            ) + loss_fct2(
                logits_issue.view(-1, self.num_issue_labels),
                issue_label.view(-1)
            )
        
        return (loss, logits_product, logits_issue) if loss is not None else (logits_product, logits_issue)

In [11]:
def data_collator(batch, padding_token_id=TOKENIZER.pad_token_id):
    input_ids = [item["input_ids"][:MAX_SEQUENCE_LENGTH] for item in batch]
    attention_masks = [item["attention_mask"][:MAX_SEQUENCE_LENGTH] for item in batch]
    # token_type_ids = [item['token_type_ids'][:MAX_SEQUENCE_LENGTH] for item in batch]
    product_label = [item["product_label"] for item in batch]
    issue_label = [item["issue_label"] for item in batch]

    max_len = max(len(ids) for ids in input_ids)
    input_ids = torch.tensor([ids + [padding_token_id] * (max_len - len(ids)) for ids in input_ids])
    attention_masks = torch.tensor([masks + [padding_token_id] * (max_len - len(masks)) for masks in attention_masks])
    # token_type_ids = torch.tensor([ids + [padding_token_id] * (max_len - len(ids)) for ids in token_type_ids])
    product_label = torch.tensor([i for i in product_label])
    issue_label = torch.tensor([i for i in issue_label])
    
    return {
        "input_ids": input_ids, 
        "attention_mask": attention_masks, 
        # 'token_type_ids': token_type_ids,
        "product_label": product_label, 
        'issue_label': issue_label
    }

In [12]:
model = MultiTaskSentencePrediction.from_pretrained(
    PRETRAINED_MODEL_NAME,
    config=BERT_CONFIG,
    num_product_labels=len(PRODUCT_LABELS), num_issue_labels=len(ISSUE_LABELS)
)

Some weights of MultiTaskSentencePrediction were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['encoder.layer.3.attention.self.value.bias', 'encoder.layer.4.attention.self.query.bias', 'encoder.layer.10.output.LayerNorm.weight', 'encoder.layer.4.output.dense.bias', 'encoder.layer.3.attention.output.LayerNorm.weight', 'encoder.layer.7.attention.self.query.weight', 'issue_classifier.bias', 'encoder.layer.3.attention.output.dense.weight', 'encoder.layer.5.intermediate.dense.weight', 'encoder.layer.5.output.LayerNorm.weight', 'encoder.layer.3.output.dense.weight', 'encoder.layer.7.attention.self.key.weight', 'encoder.layer.9.output.dense.bias', 'encoder.layer.10.output.LayerNorm.bias', 'pooler.dense.bias', 'encoder.layer.7.attention.self.value.bias', 'encoder.layer.9.intermediate.dense.weight', 'encoder.layer.7.attention.self.value.weight', 'encoder.layer.1.output.dense.bias', 'encoder.layer.5.attention.output.LayerNorm.weight', 'encoder.l

In [13]:
output_directory = RESULTS_DIRECTORY
evaluation_strategy = 'epoch'
per_device_train_batch_size = 4
per_device_eval_batch_size = 4
gradint_accumulation_steps = 2
learning_rate = 2e-5
weight_decay = 0.01
max_grad_norm = 1
num_train_epochs = NUM_TRAIN_EPOCHS
lr_scheduler_type = 'linear'
warmup_ratio = 0.05
logging_dir = LOGGING_DIRECTORY
logging_strategy = 'epoch'
save_strategy = 'epoch'
save_total_limit = 1
label_names = ['product_label', 'issue_label']
load_best_model_at_end = True
metric_for_best_model = 'eval_f1_product'
greater_is_better = True
label_smoothing_factor = 0
report_to = 'tensorboard'
gradient_checkpointing = False

In [14]:
# Setup training arguments
training_args = TrainingArguments(
    output_dir=output_directory,
    evaluation_strategy=evaluation_strategy,
    learning_rate=learning_rate,
    per_device_train_batch_size=per_device_train_batch_size,
    per_device_eval_batch_size=per_device_eval_batch_size,
    num_train_epochs=num_train_epochs,
    weight_decay=weight_decay,
    logging_dir=logging_dir,
    label_names=label_names,
    max_grad_norm=max_grad_norm,
    lr_scheduler_type=lr_scheduler_type,
    warmup_ratio=warmup_ratio,
    logging_strategy=logging_strategy,
    save_strategy=save_strategy,
    save_total_limit=save_total_limit,
    load_best_model_at_end=load_best_model_at_end,
    metric_for_best_model=metric_for_best_model,
    greater_is_better=greater_is_better,
    label_smoothing_factor=label_smoothing_factor,
    report_to=report_to,
    gradient_checkpointing=gradient_checkpointing
)

early_stop_callback = EarlyStoppingCallback(3)

In [15]:
# Initialize the Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=data['train'],
    eval_dataset=data['test'],
    # tokenizer=tokenizer,
    compute_metrics=compute_metrics,
    data_collator=data_collator,
    callbacks=[early_stop_callback]
)

dataloader_config = DataLoaderConfiguration(dispatch_batches=None)


In [16]:
trainer.train()

  0%|          | 0/17000 [00:00<?, ?it/s]

{'loss': 2.5144, 'learning_rate': 1.894736842105263e-05, 'epoch': 1.0}


  0%|          | 0/500 [00:00<?, ?it/s]

{'eval_loss': 2.050814390182495, 'eval_f1_product': 0.7330244035915434, 'eval_f1_issue': 0.381744481523802, 'eval_runtime': 29.7564, 'eval_samples_per_second': 67.212, 'eval_steps_per_second': 16.803, 'epoch': 1.0}
{'loss': 1.8145, 'learning_rate': 1.6842105263157896e-05, 'epoch': 2.0}


  0%|          | 0/500 [00:00<?, ?it/s]

{'eval_loss': 1.725622534751892, 'eval_f1_product': 0.7874698075495403, 'eval_f1_issue': 0.4062835622465695, 'eval_runtime': 27.0552, 'eval_samples_per_second': 73.923, 'eval_steps_per_second': 18.481, 'epoch': 2.0}
{'loss': 1.6667, 'learning_rate': 1.4736842105263159e-05, 'epoch': 3.0}


  0%|          | 0/500 [00:00<?, ?it/s]

{'eval_loss': 1.8767348527908325, 'eval_f1_product': 0.8008055543134794, 'eval_f1_issue': 0.43789558345102136, 'eval_runtime': 27.5807, 'eval_samples_per_second': 72.515, 'eval_steps_per_second': 18.129, 'epoch': 3.0}
{'loss': 1.5752, 'learning_rate': 1.263157894736842e-05, 'epoch': 4.0}


  0%|          | 0/500 [00:00<?, ?it/s]

{'eval_loss': 1.7505781650543213, 'eval_f1_product': 0.8061989367601342, 'eval_f1_issue': 0.46253332276111997, 'eval_runtime': 27.511, 'eval_samples_per_second': 72.698, 'eval_steps_per_second': 18.175, 'epoch': 4.0}
{'loss': 1.4193, 'learning_rate': 1.0526315789473684e-05, 'epoch': 5.0}


  0%|          | 0/500 [00:00<?, ?it/s]

{'eval_loss': 1.7924469709396362, 'eval_f1_product': 0.8220689915797993, 'eval_f1_issue': 0.49385480692359296, 'eval_runtime': 27.8452, 'eval_samples_per_second': 71.826, 'eval_steps_per_second': 17.956, 'epoch': 5.0}
{'loss': 1.2805, 'learning_rate': 8.421052631578948e-06, 'epoch': 6.0}


  0%|          | 0/500 [00:00<?, ?it/s]

{'eval_loss': 1.7341166734695435, 'eval_f1_product': 0.8340649212035753, 'eval_f1_issue': 0.517574158893834, 'eval_runtime': 28.0977, 'eval_samples_per_second': 71.18, 'eval_steps_per_second': 17.795, 'epoch': 6.0}
{'loss': 1.1544, 'learning_rate': 6.31578947368421e-06, 'epoch': 7.0}


  0%|          | 0/500 [00:00<?, ?it/s]

{'eval_loss': 1.811495065689087, 'eval_f1_product': 0.8416251659607228, 'eval_f1_issue': 0.5500737803473196, 'eval_runtime': 27.3928, 'eval_samples_per_second': 73.012, 'eval_steps_per_second': 18.253, 'epoch': 7.0}
{'loss': 1.0247, 'learning_rate': 4.210526315789474e-06, 'epoch': 8.0}


  0%|          | 0/500 [00:00<?, ?it/s]

{'eval_loss': 1.9135061502456665, 'eval_f1_product': 0.8441983795419775, 'eval_f1_issue': 0.5612920205131465, 'eval_runtime': 27.6158, 'eval_samples_per_second': 72.422, 'eval_steps_per_second': 18.106, 'epoch': 8.0}
{'loss': 0.923, 'learning_rate': 2.105263157894737e-06, 'epoch': 9.0}


  0%|          | 0/500 [00:00<?, ?it/s]

{'eval_loss': 1.9395087957382202, 'eval_f1_product': 0.8449531285266052, 'eval_f1_issue': 0.5646385661481008, 'eval_runtime': 27.8557, 'eval_samples_per_second': 71.799, 'eval_steps_per_second': 17.95, 'epoch': 9.0}
{'loss': 0.8313, 'learning_rate': 0.0, 'epoch': 10.0}


  0%|          | 0/500 [00:00<?, ?it/s]

{'eval_loss': 1.9870702028274536, 'eval_f1_product': 0.8420717800675612, 'eval_f1_issue': 0.5642200057911417, 'eval_runtime': 27.6061, 'eval_samples_per_second': 72.448, 'eval_steps_per_second': 18.112, 'epoch': 10.0}
{'train_runtime': 3656.3222, 'train_samples_per_second': 18.598, 'train_steps_per_second': 4.649, 'train_loss': 1.420393834731158, 'epoch': 10.0}


TrainOutput(global_step=17000, training_loss=1.420393834731158, metrics={'train_runtime': 3656.3222, 'train_samples_per_second': 18.598, 'train_steps_per_second': 4.649, 'train_loss': 1.420393834731158, 'epoch': 10.0})

In [17]:
trainer.evaluate()

  0%|          | 0/500 [00:00<?, ?it/s]

{'eval_loss': 1.9155381917953491,
 'eval_f1_product': 0.8350117423827834,
 'eval_f1_issue': 0.5478804429185374,
 'eval_runtime': 25.7507,
 'eval_samples_per_second': 77.668,
 'eval_steps_per_second': 19.417,
 'epoch': 10.0}