In [2]:
import sys
sys.path.append("..")

import torch
from torch.optim import AdamW
from transformers import T5Tokenizer
from torch.utils.data import DataLoader
from transformers import T5ForConditionalGeneration

from datasets import load_dataset
from transformers import get_linear_schedule_with_warmup

from scripts.response.training import train_model
from scripts.response.inference import inference_model
from scripts.response.preprocessing import ResponseDataset

from scripts.global_vars import DEVICE, MAX_LENGTH, BATCH_SIZE

In [3]:
dataset = load_dataset("multi_woz_v22", trust_remote_code=True)

train_data = dataset['train']
val_data = dataset['validation']

In [5]:
from typing import List, Dict, Tuple


def preprocess_action_prediction(dialogue: Dict, max_turns: int = 5) -> List[Tuple[str, str]]:
    samples = []
    turns = dialogue["turns"]
    utterances = turns["utterance"]
    speakers = turns["speaker"]
    acts = turns["dialogue_acts"]

    context = []
    for i, (utt, spk) in enumerate(zip(utterances, speakers)):
        context.append(("USER" if spk == 0 else "SYS") + ": " + utt)
        
        # Keep only the last max_turns in context
        if len(context) > max_turns:
            context = context[-max_turns:]
        
        # if system turn → predict the action
        if spk == 1 and i < len(acts):
            act = acts[i]["dialog_act"]
            
            # flatten action into string label (e.g. "Restaurant-Inform(area=centre, pricerange=expensive)")
            label_parts = []
            for act_type, act_slots in zip(act["act_type"], act["act_slots"]):
                slots = [f"{s}={v}" for s, v in zip(act_slots.get("slot_name", []), act_slots.get("slot_value", []))]
                label_parts.append(f"{act_type}({', '.join(slots)})")
            action_label = " | ".join(label_parts)

            # join history up to now as input
            input_text = " ".join([c for c in context])
            samples.append((input_text, action_label))

    return samples

In [16]:
res = preprocess_action_prediction(train_data[0], max_turns=4)

In [24]:
res[2]

("USER: Any sort of food would be fine, as long as it is a bit expensive. Could I get the phone number for your recommendation? SYS: There is an Afrian place named Bedouin in the centre. How does that sound? USER: Sounds good, could I get that phone number? Also, could you recommend me an expensive hotel? SYS: Bedouin's phone is 01223367660. As far as hotels go, I recommend the University Arms Hotel in the center of town.",
 'Hotel-Recommend(area=center of town, name=the University Arms Hotel) | Restaurant-Inform(name=Bedouin, phone=01223367660)')