In [None]:
import sys
sys.path.append("..")

import torch
from torch.optim import AdamW
from transformers import T5Tokenizer
from torch.utils.data import DataLoader
from transformers import T5ForConditionalGeneration

from scripts.global_vars import (
    DEVICE, 
    BATCH_SIZE, 
    MODEL_NAME,
    USE_TRAINED_MODEL,
    MAX_LENGTH_ENCODER_RESPONSE, 
    MAX_LENGTH_DECODER_RESPONSE
)

from datasets import load_dataset
from transformers import get_linear_schedule_with_warmup

from scripts.utils import find_zero_percentage
from scripts.pytorch.training import train_model
from scripts.pytorch.inference import inference_model
from scripts.preprocessing.response import ResponseDataset


In [2]:
dataset = load_dataset("multi_woz_v22", trust_remote_code=True)

train_data = dataset['train']
val_data = dataset['validation']

In [3]:
tokenizer = T5Tokenizer.from_pretrained(
    legacy=True,
    pretrained_model_name_or_path=MODEL_NAME
)

train_response_dataset = ResponseDataset(
    data=dataset['train'],
    tokenizer=tokenizer,
    max_output_len=MAX_LENGTH_DECODER_RESPONSE,
    max_input_len=MAX_LENGTH_ENCODER_RESPONSE
)

valid_response_dataset = ResponseDataset(
    data=dataset['validation'],
    tokenizer=tokenizer,
    max_output_len=MAX_LENGTH_DECODER_RESPONSE,
    max_input_len=MAX_LENGTH_ENCODER_RESPONSE
)

train_loader_response = DataLoader(train_response_dataset, batch_size=BATCH_SIZE, shuffle=True)
valid_loader_response = DataLoader(valid_response_dataset, batch_size=BATCH_SIZE)

batch = next(iter(train_loader_response))
print("RESPONSE IDs shape:", batch['encoder_input_ids'].shape)
print("Response IDs shape:", batch['decoder_input_ids'].shape)

Processing dialogues: 100%|██████████| 8437/8437 [00:03<00:00, 2226.13it/s]
Processing dialogues: 100%|██████████| 1000/1000 [00:00<00:00, 2092.06it/s]


RESPONSE IDs shape: torch.Size([256, 64])
Response IDs shape: torch.Size([256, 32])


In [4]:
train_encoder_zero = find_zero_percentage(train_loader_response, "encoder_input_ids", MAX_LENGTH_ENCODER_RESPONSE)
train_decoder_zero = find_zero_percentage(train_loader_response, "decoder_input_ids", MAX_LENGTH_DECODER_RESPONSE)
valid_encoder_zero = find_zero_percentage(valid_loader_response, "encoder_input_ids", MAX_LENGTH_ENCODER_RESPONSE)
valid_decoder_zero = find_zero_percentage(valid_loader_response, "decoder_input_ids", MAX_LENGTH_DECODER_RESPONSE)

In [5]:
print(
    "Train Encoder Zero Percentage:", train_encoder_zero * 100,
    "\nTrain Decoder Zero Percentage:", train_decoder_zero * 100,
    "\nValid Encoder Zero Percentage:", valid_encoder_zero * 100,
    "\nValid Decoder Zero Percentage:", valid_decoder_zero * 100
)

Train Encoder Zero Percentage: 24.295009672641754 
Train Decoder Zero Percentage: 35.385385155677795 
Valid Encoder Zero Percentage: 23.39279055595398 
Valid Decoder Zero Percentage: 34.25803482532501


In [6]:
num_epochs = 10
num_training_steps = len(train_loader_response) * num_epochs
num_warmup_steps = num_training_steps // 10

response_model = T5ForConditionalGeneration.from_pretrained(
    pretrained_model_name_or_path=MODEL_NAME
).to(DEVICE)

optimizer = AdamW(
    response_model.parameters(),
    lr=5e-3,
    eps=1e-8
)

scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=num_warmup_steps,
    num_training_steps=num_training_steps
)

In [7]:
if not USE_TRAINED_MODEL:
    response_model = train_model(
        response_model,
        optimizer,
        scheduler,
        train_loader_response,
        valid_loader_response,
        num_epochs=num_epochs,
        device=DEVICE,
        save="../../models/multixoz_response_model.pth"
    )

else:
    response_model.load_state_dict(torch.load("../../models/multixoz_response_model.pth", weights_only=True))


Epoch 1/10
--------------------------------------------------


Training:   0%|          | 0/222 [00:00<?, ?it/s]Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.
Training: 100%|██████████| 222/222 [01:12<00:00,  3.06it/s]
Validation: 100%|██████████| 29/29 [00:03<00:00,  7.26it/s]


Training   - Loss: 1.4804
Validation - Loss: 0.8059
LR: 5.00e-03

Epoch 2/10
--------------------------------------------------


Training: 100%|██████████| 222/222 [01:12<00:00,  3.06it/s]
Validation: 100%|██████████| 29/29 [00:04<00:00,  7.25it/s]


Training   - Loss: 0.7896
Validation - Loss: 0.7197
LR: 4.44e-03

Epoch 3/10
--------------------------------------------------


Training: 100%|██████████| 222/222 [01:12<00:00,  3.06it/s]
Validation: 100%|██████████| 29/29 [00:04<00:00,  7.19it/s]


Training   - Loss: 0.7088
Validation - Loss: 0.6877
LR: 3.89e-03

Epoch 4/10
--------------------------------------------------


Training: 100%|██████████| 222/222 [01:12<00:00,  3.06it/s]
Validation: 100%|██████████| 29/29 [00:03<00:00,  7.25it/s]


Training   - Loss: 0.6664
Validation - Loss: 0.6728
LR: 3.33e-03

Epoch 5/10
--------------------------------------------------


Training: 100%|██████████| 222/222 [01:12<00:00,  3.06it/s]
Validation: 100%|██████████| 29/29 [00:04<00:00,  7.19it/s]


Training   - Loss: 0.6349
Validation - Loss: 0.6667
LR: 2.78e-03

Epoch 6/10
--------------------------------------------------


Training: 100%|██████████| 222/222 [01:12<00:00,  3.06it/s]
Validation: 100%|██████████| 29/29 [00:03<00:00,  7.26it/s]


Training   - Loss: 0.6082
Validation - Loss: 0.6571
LR: 2.22e-03

Epoch 7/10
--------------------------------------------------


Training: 100%|██████████| 222/222 [01:12<00:00,  3.06it/s]
Validation: 100%|██████████| 29/29 [00:04<00:00,  7.18it/s]


Training   - Loss: 0.5847
Validation - Loss: 0.6524
LR: 1.67e-03

Epoch 8/10
--------------------------------------------------


Training: 100%|██████████| 222/222 [01:12<00:00,  3.06it/s]
Validation: 100%|██████████| 29/29 [00:04<00:00,  7.12it/s]


Training   - Loss: 0.5636
Validation - Loss: 0.6517
LR: 1.11e-03

Epoch 9/10
--------------------------------------------------


Training: 100%|██████████| 222/222 [01:13<00:00,  3.04it/s]
Validation: 100%|██████████| 29/29 [00:04<00:00,  7.09it/s]


Training   - Loss: 0.5438
Validation - Loss: 0.6521
LR: 5.56e-04

Epoch 10/10
--------------------------------------------------


Training: 100%|██████████| 222/222 [01:12<00:00,  3.06it/s]
Validation: 100%|██████████| 29/29 [00:04<00:00,  7.19it/s]


Training   - Loss: 0.5269
Validation - Loss: 0.6535
LR: 0.00e+00


In [8]:
generated_outputs_train = inference_model(
    response_model, 
    tokenizer, 
    train_response_dataset.actions, 
    MAX_LENGTH_ENCODER_RESPONSE, 
    MAX_LENGTH_DECODER_RESPONSE, 
    DEVICE,
    batch_size=1024
) 

generated_outputs_valid = inference_model(
    response_model, 
    tokenizer, 
    valid_response_dataset.actions, 
    MAX_LENGTH_ENCODER_RESPONSE, 
    MAX_LENGTH_DECODER_RESPONSE, 
    DEVICE,
    batch_size=1024
) 

Inference: 100%|██████████| 56/56 [02:00<00:00,  2.15s/it]
Inference: 100%|██████████| 8/8 [00:16<00:00,  2.05s/it]


In [20]:
from typing import List, Dict, Tuple
from bert_score import score as bert_score
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction


def get_evaluation_score(predictions: List[str], references: List[str]) -> float:
    smoothing = SmoothingFunction().method1
    
    predictions = [pred.strip().lower() for pred in predictions]
    references = [ref.strip().lower() for ref in references]
    
    bleu_scores = [
        sentence_bleu(
            [ref.split()], 
            pred.split(), 
            smoothing_function=smoothing
        ) 
        for pred, ref in zip(predictions, references)
    ]

    P, R, F1 = bert_score(
        predictions, 
        references, 
        lang='en', 
        rescale_with_baseline=True, 
        verbose=True, 
        batch_size=256
    )
    
    avg_bert_f1 = F1.mean().item()
    avg_bleu = sum(bleu_scores) / len(bleu_scores)
    
    return {
        "bleu_score": avg_bleu,
        "bert_f1_score": avg_bert_f1
    }

In [None]:
bleu_score_train = get_evaluation_score(generated_outputs_train, train_response_dataset.actions)
bleu_score_valid = get_evaluation_score(generated_outputs_valid, valid_response_dataset.actions)

Some weights of RobertaModel were not initialized from the model checkpoint at roberta-large and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


calculating scores...
computing bert embedding.


  0%|          | 0/353 [00:00<?, ?it/s]

In [None]:
print("BLEU Score Train:", bleu_score_train["bleu_score"])
print("BERT F1 Score Train:", bleu_score_train["bert_f1_score"])
print("BLEU Score Valid:", bleu_score_valid["bleu_score"])
print("BERT F1 Score Valid:", bleu_score_valid["bert_f1_score"])

BLEU Score Train: 0.021270445050982297
BERT F1 Score Train: 0
BLEU Score Valid: 0.020149294690607063
BERT F1 Score Valid: 0


In [19]:
index = 700
inputs = valid_response_dataset.actions[index]

generated_output = inference_model(
    response_model,
    tokenizer,
    inputs,
    MAX_LENGTH_ENCODER_RESPONSE,
    MAX_LENGTH_DECODER_RESPONSE,
    DEVICE
)

print("User Action:", inputs)
print("Generated Response:", generated_output)
print("True Response:", valid_response_dataset.responses[index])

Inference: 100%|██████████| 1/1 [00:00<00:00,  9.15it/s]

User Action: [USER]: I would like to leave after 9:45. Please let me know the car type and contact number. [ACTION]: Taxi-Inform(phone=07597996556, type=red volkswagen)
Generated Response: Okay, I've booked you a red volkswagen and the contact number is 07597996556.
True Response: A red volkswagen will pick you up. The contact number is 07597996556.



