In [1]:
import sys
sys.path.append("..")

from torch.optim import AdamW
from transformers import T5Tokenizer
from torch.utils.data import DataLoader
from transformers import T5ForConditionalGeneration

from datasets import load_dataset
from transformers import get_linear_schedule_with_warmup

from scripts.response.training import train_model
from scripts.response.inference import inference_model
from scripts.response.preprocessing import ResponseDataset

from scripts.global_vars import DEVICE, MAX_LENGTH_RESPONSE, BATCH_SIZE

In [2]:
# Load the MultiWOZ 2.2 dialogue corpus from the Hugging Face hub and keep
# handles on the two splits this notebook works with.
dataset = load_dataset("multi_woz_v22", trust_remote_code=True)
train_data, val_data = dataset["train"], dataset["validation"]

In [3]:
# Tokenizer of the checkpoint that is fine-tuned further down.
tokenizer = T5Tokenizer.from_pretrained("google/t5-efficient-mini", legacy=True)

# Build the train and validation response datasets with identical settings;
# the generator is consumed in order, so the train split is processed first.
train_response_dataset, valid_response_dataset = (
    ResponseDataset(
        data=dataset[split_name],
        tokenizer=tokenizer,
        max_output_len=MAX_LENGTH_RESPONSE,
    )
    for split_name in ("train", "validation")
)

# Only the training loader shuffles; validation keeps the dataset order.
train_loader_response = DataLoader(
    dataset=train_response_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
)
valid_loader_response = DataLoader(
    dataset=valid_response_dataset,
    batch_size=BATCH_SIZE,
)

# Sanity check: pull a single batch and show the encoder/decoder tensor shapes.
batch = next(iter(train_loader_response))
print("Action IDs shape:", batch["encoder_input_ids"].shape)
print("Response IDs shape:", batch["decoder_input_ids"].shape)

Processing dialogues: 100%|██████████| 8437/8437 [00:03<00:00, 2295.16it/s]
Processing dialogues: 100%|██████████| 1000/1000 [00:00<00:00, 2137.98it/s]

Action IDs shape: torch.Size([64, 128])
Response IDs shape: torch.Size([64, 128])





In [4]:
# Pretrained seq2seq checkpoint, moved onto the configured device.
response_model = T5ForConditionalGeneration.from_pretrained("google/t5-efficient-mini")
response_model = response_model.to(DEVICE)

# Schedule length: one scheduler step per training batch over all epochs,
# with the first tenth of those steps used for linear warmup.
num_epochs = 5
num_training_steps = num_epochs * len(train_loader_response)
num_warmup_steps = num_training_steps // 10

# NOTE(review): lr=1e-2 is on the high side for AdamW fine-tuning; the recorded
# run converges with it, so it is kept as-is - confirm before reusing elsewhere.
optimizer = AdamW(params=response_model.parameters(), lr=1e-2, eps=1e-8)

# Linear warmup followed by linear decay to zero at num_training_steps.
scheduler = get_linear_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=num_warmup_steps,
    num_training_steps=num_training_steps,
)

In [5]:
# Fine-tune the response model with the project's training helper and rebind
# `response_model` to whatever the helper returns, so later cells use that object.
# The recorded output of this cell shows per-epoch training/validation loss and
# the current learning rate.
response_model = train_model(
    response_model,
    optimizer,
    scheduler,
    train_loader_response,
    valid_loader_response,
    num_epochs=num_epochs,
    device=DEVICE,
    # Presumably the checkpoint destination - confirm in scripts.response.training.
    # NOTE(review): "multixoz" looks like a typo for "multiwoz"; left unchanged
    # because other code may load the checkpoint from this exact path.
    save="../../models/multixoz_response_model.pth"
)


Epoch 1/5
--------------------------------------------------


Training:   0%|          | 0/888 [00:00<?, ?it/s]Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.
Training: 100%|██████████| 888/888 [01:55<00:00,  7.70it/s]
Validation: 100%|██████████| 116/116 [00:05<00:00, 20.89it/s]


Training   - Loss: 0.6345
Validation - Loss: 0.2137
LR: 8.89e-03

Epoch 2/5
--------------------------------------------------


Training: 100%|██████████| 888/888 [01:55<00:00,  7.71it/s]
Validation: 100%|██████████| 116/116 [00:05<00:00, 21.00it/s]


Training   - Loss: 0.2229
Validation - Loss: 0.1954
LR: 6.67e-03

Epoch 3/5
--------------------------------------------------


Training: 100%|██████████| 888/888 [01:54<00:00,  7.74it/s]
Validation: 100%|██████████| 116/116 [00:05<00:00, 20.74it/s]


Training   - Loss: 0.2012
Validation - Loss: 0.1848
LR: 4.44e-03

Epoch 4/5
--------------------------------------------------


Training: 100%|██████████| 888/888 [01:55<00:00,  7.71it/s]
Validation: 100%|██████████| 116/116 [00:05<00:00, 21.06it/s]


Training   - Loss: 0.1857
Validation - Loss: 0.1763
LR: 2.22e-03

Epoch 5/5
--------------------------------------------------


Training: 100%|██████████| 888/888 [01:54<00:00,  7.75it/s]
Validation: 100%|██████████| 116/116 [00:05<00:00, 21.01it/s]


Training   - Loss: 0.1708
Validation - Loss: 0.1693
LR: 0.00e+00


In [9]:
# Qualitative spot-check: generate a response for one validation example and
# print it next to the input action and the reference response.
index = 105
inputs = valid_response_dataset.actions[index]

generated_output = inference_model(
    response_model, tokenizer, inputs, MAX_LENGTH_RESPONSE, DEVICE
)

for label, text in (
    ("User Action:", inputs),
    ("Generated Response:", generated_output),
    ("True Response:", valid_response_dataset.responses[index]),
):
    print(label, text)

User Action: [USER]: I'm looking for a train leaving on Wednesday that's going to Cambridge. [ACTION]: Train-Request(departure=?)
Generated Response: Where will you be departing from?
True Response: Okay, where are you departing?
