<a href="https://colab.research.google.com/github/goelnikhils-lgtm/languagemodels/blob/main/TOD_BERT_Multi_Task_Fine_Tuning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer
from torch.utils.data import Dataset, DataLoader

# --- 1. CONFIGURATION AND UTILITIES ---

# Backbone checkpoint handed to AutoModel/AutoTokenizer. Any BERT-compatible
# model name or local path works; swap in your own pre-trained TOD checkpoint.
TOD_BERT_MODEL = 'bert-base-uncased'

# Output widths of the two classification heads.
NUM_SLOTS = 50    # DST: one independent logit per (slot, value) pair
NUM_ACTIONS = 10  # Policy: one logit per discrete dialogue action

# Prefer the GPU whenever torch can see one; otherwise run on the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

def make_dummy_data(tokenizer, num_slots=None):
    """Build a tiny turn-level dialogue dataset for smoke-testing the model.

    Each sample pairs one user utterance (the encoder context) with three
    targets: a multi-hot DST vector, a policy action index, and the tokenized
    system response used as generation labels.

    Args:
        tokenizer: Hugging Face-style callable tokenizer. It is called with
            ``truncation=True``, ``padding='max_length'``, ``max_length`` and
            ``return_tensors='pt'``. It must return ``input_ids`` and
            ``attention_mask`` tensors with a leading batch dimension of 1.
        num_slots: Length of every DST label vector. Defaults to the
            module-level ``NUM_SLOTS`` so the labels match the width of the
            DST head's logits, which ``BCEWithLogitsLoss`` requires.

    Returns:
        list[dict]: One dict per turn with the keys ``input_ids`` and
        ``attention_mask`` (context, max_length 128), ``dst_labels`` (float,
        length ``num_slots``), ``policy_labels`` (0-d long tensor), and
        ``gen_labels`` (response token ids, max_length 64).

    Raises:
        ValueError: If a hand-written DST label is longer than ``num_slots``.
    """
    if num_slots is None:
        num_slots = NUM_SLOTS

    # Context: full dialogue history up to the current turn
    # Target 1 (DST): binary vector for which slots are active
    # Target 2 (Policy): integer index of the next action
    # Target 3 (Generation): tokenized ground-truth response
    dialogues = [
        ("Hello, I need a flight to London.", [1, 0, 0], 2, "Which city are you flying from?"),
        ("I'm leaving from New York.", [1, 1, 0], 3, "And what date are you looking for?"),
        ("I'll take any date.", [1, 1, 1], 4, "I'm searching for flights now."),
    ]

    data = []
    for context, dst_label, policy_label, response in dialogues:
        if len(dst_label) > num_slots:
            raise ValueError(
                f"DST label of length {len(dst_label)} exceeds num_slots={num_slots}"
            )
        # Right-pad with inactive slots: the DST target must have exactly the
        # same width as the DST logits, or BCEWithLogitsLoss raises.
        dst_vector = list(dst_label) + [0] * (num_slots - len(dst_label))

        # Encoder input (context)
        enc_input = tokenizer(context, truncation=True, padding='max_length', max_length=128, return_tensors='pt')

        # Decoder targets (response tokens)
        dec_input = tokenizer(response, truncation=True, padding='max_length', max_length=64, return_tensors='pt')

        data.append({
            # squeeze(0) drops only the batch dim added by return_tensors='pt'.
            'input_ids': enc_input['input_ids'].squeeze(0),
            'attention_mask': enc_input['attention_mask'].squeeze(0),
            'dst_labels': torch.tensor(dst_vector, dtype=torch.float),  # Multi-label binary
            'policy_labels': torch.tensor(policy_label, dtype=torch.long),  # Multi-class index
            'gen_labels': dec_input['input_ids'].squeeze(0),  # Sequence of tokens
        })
    return data

class DialogueDataset(Dataset):
    """Map-style torch Dataset over an already-built list of sample dicts.

    No copying or transformation happens here: indexing returns the stored
    element exactly as it appears in the wrapped sequence.
    """

    def __init__(self, data):
        # Hold a reference to the caller's sequence (not a copy).
        self.data = data

    def __len__(self):
        samples = self.data
        return len(samples)

    def __getitem__(self, idx):
        sample = self.data[idx]
        return sample

# --- 2. THE MULTI-HEAD MODEL ---

class TODBERTSystem(nn.Module):
    """Shared pre-trained encoder with three task heads for task-oriented dialogue.

    Heads:
      * DST: a two-layer MLP on the [CLS] (position 0) state. It emits
        ``num_slots`` independent logits (multi-label; pair with
        ``BCEWithLogitsLoss``).
      * Policy: a linear layer on the [CLS] state. It emits ``num_actions``
        logits (multi-class; pair with ``CrossEntropyLoss``).
      * Generation: a linear layer applied to every encoder position. It emits
        ``vocab_size`` logits per position.

    Args:
        num_slots: Width of the DST output.
        num_actions: Number of policy classes.
        vocab_size: Width of the generation head's output; it must cover every
            token id used as a generation label.
        model_name: Checkpoint name or path passed to
            ``AutoModel.from_pretrained``. Defaults to the module-level
            ``TOD_BERT_MODEL``.
    """

    def __init__(self, num_slots, num_actions, vocab_size, model_name=None):
        super().__init__()
        if model_name is None:
            model_name = TOD_BERT_MODEL
        # Pre-trained encoder (the TOD-BERT backbone). In a real TOD-BERT setup,
        # point model_name at the checkpoint produced by MLM/RCL pre-training.
        self.encoder = AutoModel.from_pretrained(model_name)
        hidden_size = self.encoder.config.hidden_size

        # --- 1. Dialogue State Tracking (DST) Head ---
        # One raw logit per slot (multi-label); the sigmoid lives in the loss.
        self.dst_head = nn.Sequential(
            nn.Linear(hidden_size, hidden_size // 2),
            nn.ReLU(),
            nn.Linear(hidden_size // 2, num_slots),
        )

        # --- 2. Dialogue Policy Head ---
        # Logits over the discrete next actions (multi-class classification).
        self.policy_head = nn.Linear(hidden_size, num_actions)

        # --- 3. Response Generation (Gen) Head ---
        # Maps each encoder position to vocabulary logits. NOTE(review): this is
        # a simplification rather than a real decoder. Every position is
        # computed from the full context, so it is not a causal language model.
        self.gen_head = nn.Linear(hidden_size, vocab_size)

    def forward(self, input_ids, attention_mask, gen_labels=None):
        """Encode the context once and run all three heads on it.

        Args:
            input_ids: Context token ids fed to the encoder.
            attention_mask: Encoder attention mask for ``input_ids``.
            gen_labels: Accepted for call-site compatibility but not used here.
                The generation head sees only encoder states (no teacher
                forcing); the loss against these labels is computed by the
                caller.

        Returns:
            tuple: ``(dst_logits, policy_logits, gen_logits)``, with shapes
            ``(batch, num_slots)``, ``(batch, num_actions)`` and
            ``(batch, seq_len, vocab_size)``. Here ``seq_len`` is the encoder
            output length.
        """
        encoder_outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask, return_dict=True)
        sequence_output = encoder_outputs.last_hidden_state

        # The position-0 ([CLS]) state summarises the context for the
        # classification heads.
        cls_output = sequence_output[:, 0, :]

        dst_logits = self.dst_head(cls_output)
        policy_logits = self.policy_head(cls_output)

        # Token logits for every encoder position.
        gen_logits = self.gen_head(sequence_output)

        return dst_logits, policy_logits, gen_logits

# --- 3. TRAINING LOOP EXECUTION ---

def train_tod_bert_system():
    """Run one demonstration epoch of multi-task fine-tuning on the dummy data.

    Builds the tokenizer, dataset, model and AdamW optimizer. For each batch it
    computes three losses and sums them with equal weights before one optimizer
    step:

    * DST: ``BCEWithLogitsLoss``.
    * Policy: ``CrossEntropyLoss``.
    * Generation: next-token ``CrossEntropyLoss`` that ignores padding.

    Returns:
        float: Mean of the summed per-batch losses over the epoch, or ``nan``
        if the dataloader yields no batches.
    """
    tokenizer = AutoTokenizer.from_pretrained(TOD_BERT_MODEL)
    dummy_data = make_dummy_data(tokenizer)
    dataset = DialogueDataset(dummy_data)
    dataloader = DataLoader(dataset, batch_size=4)

    # Initialize model and optimizer. len(tokenizer) also counts tokens added on
    # top of the base vocabulary (tokenizer.vocab_size omits them), so the gen
    # head covers every id the tokenizer can emit.
    model = TODBERTSystem(NUM_SLOTS, NUM_ACTIONS, len(tokenizer)).to(DEVICE)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

    # Define loss functions
    # BCE for DST (multi-label)
    dst_loss_fn = nn.BCEWithLogitsLoss()
    # CrossEntropy for Policy (multi-class)
    policy_loss_fn = nn.CrossEntropyLoss()
    # CrossEntropy for Generation (token prediction); padded label positions are skipped
    gen_loss_fn = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)

    print(f"Starting fine-tuning on {DEVICE}...")

    # Training loop (1 epoch for demonstration)
    model.train()
    epoch_loss = 0.0
    num_batches = 0

    for batch in dataloader:
        optimizer.zero_grad()

        # Move inputs and targets to device
        input_ids = batch['input_ids'].to(DEVICE)
        attention_mask = batch['attention_mask'].to(DEVICE)
        dst_labels = batch['dst_labels'].to(DEVICE)
        policy_labels = batch['policy_labels'].to(DEVICE)
        gen_labels = batch['gen_labels'].to(DEVICE)

        # Forward pass
        dst_logits, policy_logits, gen_logits = model(input_ids, attention_mask, gen_labels)

        # --- LOSS CALCULATION ---

        # 1. DST loss (multi-label classification)
        dst_loss = dst_loss_fn(dst_logits, dst_labels)

        # 2. Policy loss (multi-class classification)
        policy_loss = policy_loss_fn(policy_logits, policy_labels)

        # 3. Generation loss (next-token prediction)
        # gen_logits has one vector per encoder (context) position, while
        # gen_labels holds the response tokens. make_dummy_data pads these to
        # different lengths (128 vs 64). Shift by one and truncate both to the
        # common span so the flattened shapes agree. When the lengths are equal,
        # this reduces to the usual logits[:, :-1] / labels[:, 1:] pairing.
        # NOTE(review): pairing context position t with response token t+1 is a
        # positional simplification, not a true decoder.
        span = min(gen_logits.size(1), gen_labels.size(1)) - 1
        shifted_logits = gen_logits[:, :span, :].contiguous()
        shifted_labels = gen_labels[:, 1:span + 1].contiguous()

        gen_loss = gen_loss_fn(
            shifted_logits.view(-1, shifted_logits.size(-1)),
            shifted_labels.view(-1)
        )

        # --- Total loss (multi-task learning) ---
        # Weighted sum of the individual losses; equal weights (1.0) are a
        # common starting point and can be tuned.
        batch_loss = (1.0 * dst_loss) + (1.0 * policy_loss) + (1.0 * gen_loss)

        # Backward pass and optimization
        batch_loss.backward()
        optimizer.step()

        epoch_loss += batch_loss.item()
        num_batches += 1

        print(f"Batch Loss: {batch_loss.item():.4f} | DST: {dst_loss.item():.4f} | Policy: {policy_loss.item():.4f} | Gen: {gen_loss.item():.4f}")

    mean_loss = epoch_loss / num_batches if num_batches else float('nan')
    print(f"Mean Batch Loss: {mean_loss:.4f}")
    print("Fine-tuning simulation complete.")
    return mean_loss

if __name__ == '__main__':
    # Raise the transformers library's log threshold to ERROR so that only
    # the training progress lines from this script reach the console.
    from transformers import logging as _hf_log

    _hf_log.set_verbosity_error()
    train_tod_bert_system()