In [1]:
import sys
sys.path.append("..")

import torch
from torch.optim import AdamW
from transformers import T5Tokenizer
from torch.utils.data import DataLoader
from transformers import T5ForConditionalGeneration

from datasets import load_dataset
from transformers import get_linear_schedule_with_warmup

from scripts.response.training import train_model
from scripts.response.inference import inference_model
# from scripts.response.preprocessing import ActionDataset

from scripts.global_vars import DEVICE, MAX_LENGTH_ACTION, BATCH_SIZE, MAX_TURNS

In [2]:
dataset = load_dataset("multi_woz_v22", trust_remote_code=True)

train_data = dataset['train']
val_data = dataset['validation']

In [3]:
from typing import List, Dict, Tuple


def preprocess_action_prediction(dialogue: Dict, max_turns: int = 5) -> List[Tuple[str, str]]:
    samples = []
    turns = dialogue["turns"]
    utterances = turns["utterance"]
    speakers = turns["speaker"]
    acts = turns["dialogue_acts"]

    context = []
    for i, (utt, spk) in enumerate(zip(utterances, speakers)):
        context.append(("USER" if spk == 0 else "SYS") + ": " + utt)
        
        # Keep only the last max_turns in context
        if len(context) > max_turns:
            context = context[-max_turns:]
        
        # if system turn → predict the action
        if spk == 1 and i < len(acts):
            act = acts[i]["dialog_act"]
            
            # flatten action into string label (e.g. "Restaurant-Inform(area=centre, pricerange=expensive)")
            label_parts = []
            for act_type, act_slots in zip(act["act_type"], act["act_slots"]):
                slots = [f"{s}={v}" for s, v in zip(act_slots.get("slot_name", []), act_slots.get("slot_value", []))]
                label_parts.append(f"{act_type}({', '.join(slots)})")
            action_label = " | ".join(label_parts)

            # join history up to now as input
            input_text = " ".join([c for c in context])
            samples.append((input_text, action_label))

    return samples

In [4]:
res = preprocess_action_prediction(train_data[0], max_turns=MAX_TURNS)

In [5]:
res[2]

("USER: Sounds good, could I get that phone number? Also, could you recommend me an expensive hotel? SYS: Bedouin's phone is 01223367660. As far as hotels go, I recommend the University Arms Hotel in the center of town.",
 'Hotel-Recommend(area=center of town, name=the University Arms Hotel) | Restaurant-Inform(name=Bedouin, phone=01223367660)')

In [6]:
import tqdm
from torch.utils.data import Dataset


class ActionDataset(Dataset):
    def __init__(self, data: List[Dict], tokenizer: T5Tokenizer, max_output_len: int = 64):
        self.tokenizer = tokenizer
        self.max_output_len = max_output_len
        
        self.inputs = []
        self.actions = []

        for dialogue in tqdm.tqdm(data, desc="Processing dialogues"):
            preprocessed_text = preprocess_action_prediction(dialogue)

            for (input, action) in preprocessed_text:
                self.inputs.append(input)
                self.actions.append(action)

    def __len__(self):
        return len(self.actions)

    def __getitem__(self, idx):
        input_enc = self.tokenizer(
            self.inputs[idx], 
            padding='max_length', 
            truncation=True, 
            max_length=self.max_output_len, 
            return_tensors="pt"
        )

        action_enc = self.tokenizer(
            self.actions[idx],
            padding='max_length',
            truncation=True,
            max_length=self.max_output_len,
            return_tensors="pt"
        )

        return {
            "encoder_input_ids": input_enc["input_ids"].squeeze(0),
            "encoder_attention_mask": input_enc["attention_mask"].squeeze(0),
            "decoder_input_ids": action_enc["input_ids"].squeeze(0),
            "decoder_attention_mask": action_enc["attention_mask"].squeeze(0)
        }

In [7]:
tokenizer = T5Tokenizer.from_pretrained(
    legacy=True,
    pretrained_model_name_or_path="google/t5-efficient-mini"
)

train_action_dataset = ActionDataset(
    data=dataset['train'],
    tokenizer=tokenizer,
    max_output_len=MAX_LENGTH_ACTION
)

valid_action_dataset = ActionDataset(
    data=dataset['validation'],
    tokenizer=tokenizer,
    max_output_len=MAX_LENGTH_ACTION
)

train_loader_action = DataLoader(train_action_dataset, batch_size=BATCH_SIZE, shuffle=True)
valid_loader_action = DataLoader(valid_action_dataset, batch_size=BATCH_SIZE)

batch = next(iter(train_loader_action))
print("Inputs IDs shape:", batch['encoder_input_ids'].shape)
print("Action IDs shape:", batch['decoder_input_ids'].shape)

Processing dialogues: 100%|██████████| 8437/8437 [00:03<00:00, 2258.98it/s]
Processing dialogues: 100%|██████████| 1000/1000 [00:00<00:00, 2131.37it/s]


Inputs IDs shape: torch.Size([64, 128])
Action IDs shape: torch.Size([64, 128])


In [8]:
num_epochs = 5
num_training_steps = len(train_loader_action) * num_epochs
num_warmup_steps = num_training_steps // 10

action_model = T5ForConditionalGeneration.from_pretrained(
    "google/t5-efficient-mini"
).to(DEVICE)

optimizer = AdamW(
    action_model.parameters(),
    lr=1e-3,
    eps=1e-8
)

scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=num_warmup_steps,
    num_training_steps=num_training_steps
)

In [9]:
action_model = train_model(
    action_model,
    optimizer,
    scheduler,
    train_loader_action,
    valid_loader_action,
    num_epochs=num_epochs,
    device=DEVICE,
    save="../../models/multixoz_action_model.pth"
)


Epoch 1/5
--------------------------------------------------


Training:   0%|          | 0/888 [00:00<?, ?it/s]Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.
Training: 100%|██████████| 888/888 [01:58<00:00,  7.49it/s]
Validation: 100%|██████████| 116/116 [00:06<00:00, 19.16it/s]


Training   - Loss: 1.1459
Validation - Loss: 0.0826
LR: 8.89e-04

Epoch 2/5
--------------------------------------------------


Training: 100%|██████████| 888/888 [01:58<00:00,  7.49it/s]
Validation: 100%|██████████| 116/116 [00:06<00:00, 19.16it/s]


Training   - Loss: 0.0905
Validation - Loss: 0.0618
LR: 6.67e-04

Epoch 3/5
--------------------------------------------------


Training: 100%|██████████| 888/888 [01:58<00:00,  7.48it/s]
Validation: 100%|██████████| 116/116 [00:06<00:00, 19.32it/s]


Training   - Loss: 0.0700
Validation - Loss: 0.0548
LR: 4.44e-04

Epoch 4/5
--------------------------------------------------


Training: 100%|██████████| 888/888 [01:58<00:00,  7.47it/s]
Validation: 100%|██████████| 116/116 [00:06<00:00, 19.28it/s]


Training   - Loss: 0.0606
Validation - Loss: 0.0509
LR: 2.22e-04

Epoch 5/5
--------------------------------------------------


Training: 100%|██████████| 888/888 [01:58<00:00,  7.48it/s]
Validation: 100%|██████████| 116/116 [00:06<00:00, 19.03it/s]


Training   - Loss: 0.0553
Validation - Loss: 0.0499
LR: 0.00e+00


In [10]:
index = 2
inputs = train_action_dataset.inputs[index]

generated_action = inference_model(
    action_model,
    tokenizer,
    inputs,
    MAX_LENGTH_ACTION,
    DEVICE
)

print("User Inputs:", inputs)
print("Generated Action:", generated_action)
print("True Action:", train_action_dataset.actions[index])

User Inputs: SYS: I have several options for you; do you prefer African, Asian, or British food? USER: Any sort of food would be fine, as long as it is a bit expensive. Could I get the phone number for your recommendation? SYS: There is an Afrian place named Bedouin in the centre. How does that sound? USER: Sounds good, could I get that phone number? Also, could you recommend me an expensive hotel? SYS: Bedouin's phone is 01223367660. As far as hotels go, I recommend the University Arms Hotel in the center of town.
Generated Action: Restaurant-Inform(food=Afrian, name=Afrian)
True Action: Hotel-Recommend(area=center of town, name=the University Arms Hotel) | Restaurant-Inform(name=Bedouin, phone=01223367660)
