# DLA Training Notebook: Dialogue Generation with Adversarial Learning

This comprehensive notebook contains **ALL** the code needed to train a dialogue generation model using Generative Adversarial Networks (GANs).

## 📋 Contents

1. **Setup & Imports** - Install and import all dependencies
2. **Configuration** - Centralized hyperparameters
3. **Data Loading & Processing** - Multi-file .txt parser and tokenizer
4. **Model Architecture** - Generator (LSTM) and Discriminator (CNN)
5. **Training Functions** - Pre-training and adversarial training
6. **Evaluation & Metrics** - BLEU, diversity, perplexity
7. **Main Training Loop** - Execute the full training pipeline
8. **Generation & Testing** - Generate sample dialogues

## ⚡ Features

- **Self-contained**: No external Python files needed
- **Multi-file support**: Load multiple .txt dialogue files
- **Fast training**: Optimized with @tf.function decorators
- **Well-commented**: Clear explanations throughout
- **Production-ready**: Complete training pipeline

## 🚀 Usage

Simply run all cells in order. Training will automatically:

1. Load all .txt files from the `data/` directory
2. Build vocabulary and preprocess text
3. Pre-train Generator with Maximum Likelihood
4. Pre-train Discriminator
5. Perform adversarial training
6. Generate sample dialogues

---

## 1. Setup & Imports

Install and import all necessary libraries.

In [None]:
# Install required packages (uncomment if needed)
# !pip install tensorflow numpy

import os
import sys
import time
import glob
import re
import pickle
import logging
from dataclasses import dataclass
from typing import List, Dict, Tuple, Optional
from collections import Counter, defaultdict
from io import StringIO

import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

# Set random seeds for reproducibility
SEED = 42
np.random.seed(SEED)
tf.random.set_seed(SEED)

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

print("✓ All imports successful!")
print(f"TensorFlow version: {tf.__version__}")
print(f"GPU Available: {len(tf.config.list_physical_devices('GPU')) > 0}")

---

## 2. Configuration

All hyperparameters in one place for easy tuning.

In [None]:
class Config:
    """Centralized configuration for DLA model training"""

    # ==================== Data Configuration ====================
    DATA_DIR = "data"                      # Directory containing .txt dialogue files
    DATA_PATTERN = "*.txt"                 # File pattern to match
    MAX_SEQUENCE_LENGTH = 50               # Maximum words per dialogue turn
    VOCAB_SIZE = 10000                     # Number of unique words to keep
    MIN_WORD_FREQ = 2                      # Minimum word frequency for vocabulary

    # ==================== Model Architecture ====================
    EMBEDDING_DIM = 128                    # Size of word embeddings
    HIDDEN_DIM = 256                       # LSTM hidden state dimension
    GENERATOR_LSTM_LAYERS = 2              # Number of LSTM layers in Generator
    DROPOUT_RATE = 0.3                     # Dropout rate for regularization
    DISCRIMINATOR_KERNEL_SIZES = [3, 4, 5] # CNN kernel sizes for n-gram patterns
    DISCRIMINATOR_NUM_FILTERS = 128        # Number of filters per kernel

    # ==================== Training Configuration ====================
    BATCH_SIZE = 64                        # Training batch size
    GENERATOR_PRETRAIN_EPOCHS = 5          # Generator pre-training epochs (reduced for speed)
    DISCRIMINATOR_PRETRAIN_EPOCHS = 3      # Discriminator pre-training epochs (reduced)
    ADVERSARIAL_EPOCHS = 5                 # Adversarial training epochs (reduced for speed)
    G_STEPS = 1                            # Generator updates per adversarial epoch
    D_STEPS = 5                            # Discriminator updates per adversarial epoch
    GENERATOR_LR_PRETRAIN = 0.001          # Generator learning rate (pre-training)
    GENERATOR_LR_ADVERSARIAL = 0.0001      # Generator learning rate (adversarial)
    DISCRIMINATOR_LR = 0.0001              # Discriminator learning rate

    # ==================== Generation Configuration ====================
    TEMPERATURE = 1.0                      # Sampling temperature
    TOP_K = 50                             # Top-k sampling parameter
    TOP_P = 0.95                           # Nucleus sampling parameter
    MAX_GENERATION_LENGTH = 50             # Maximum generation length

    # ==================== System Configuration ====================
    RANDOM_SEED = 42                       # Random seed for reproducibility
    DISPLAY_EXAMPLES = 3                   # Number of examples to display
    CHECKPOINT_EVERY = 2                   # Save checkpoint every N epochs
    OUTPUT_DIR = "outputs"                 # Output directory
    MODEL_DIR = os.path.join(OUTPUT_DIR, "models")
    LOGS_DIR = os.path.join(OUTPUT_DIR, "logs")


# Create necessary directories
os.makedirs(Config.DATA_DIR, exist_ok=True)
os.makedirs(Config.OUTPUT_DIR, exist_ok=True)
os.makedirs(Config.MODEL_DIR, exist_ok=True)
os.makedirs(Config.LOGS_DIR, exist_ok=True)

print("✓ Configuration loaded!")
print(f"  - Data directory: {Config.DATA_DIR}")
print(f"  - Output directory: {Config.OUTPUT_DIR}")
print(f"  - Generator pre-train epochs: {Config.GENERATOR_PRETRAIN_EPOCHS}")
print(f"  - Discriminator pre-train epochs: {Config.DISCRIMINATOR_PRETRAIN_EPOCHS}")
print(f"  - Adversarial epochs: {Config.ADVERSARIAL_EPOCHS}")

---

## 3. Data Loading & Processing

Multi-file dialogue parser and tokenizer for text preprocessing.

In [None]:
@dataclass
class DialogueTurn:
    """Represents a single dialogue turn with context and response"""
    context: str
    response: str
    metadata: Optional[Dict] = None


print("✓ DialogueTurn dataclass defined")

In [None]:
class DialogueParser:
    """
    Flexible parser for various dialogue dataset formats.

    Supports:
    - Context-response pairs: 'context: ... response: ...'
    - Speaker dialogues: 'Speaker: dialogue text'
    - Mixed formats with scene descriptions
    """

    def __init__(self):
        # Regex pattern for context-response format
        self.context_response_pattern = re.compile(
            r'context:\s*(.+?)\s*response:\s*(.+?)(?=\ncontext:|$)',
            re.DOTALL
        )
        # Regex pattern for speaker dialogue format
        self.dialogue_pattern = re.compile(r'^(.+?):\s*(.+?)$', re.MULTILINE)

    def parse_file(self, filepath: str) -> List[DialogueTurn]:
        """Parse a single dialogue file"""
        logger.info(f"Parsing file: {filepath}")

        with open(filepath, 'r', encoding='utf-8') as f:
            content = f.read()

        # Try context-response format first
        turns = self._parse_context_response(content)

        if not turns:
            # Fall back to dialogue format
            turns = self._parse_dialogue_format(content)

        logger.info(f"  Extracted {len(turns)} dialogue turns")
        return turns

    def _parse_context_response(self, content: str) -> List[DialogueTurn]:
        """Parse 'context: ... response: ...' format"""
        matches = self.context_response_pattern.findall(content)
        turns = []

        for context, response in matches:
            context = self._clean_text(context)
            response = self._clean_text(response)

            if context and response:
                turns.append(DialogueTurn(
                    context=context,
                    response=response,
                    metadata={"format": "context_response"}
                ))

        return turns

    def _parse_dialogue_format(self, content: str) -> List[DialogueTurn]:
        """Parse 'Speaker: dialogue' format"""
        lines = content.strip().split('\n')
        turns = []
        context_buffer = []

        for line in lines:
            line = line.strip()
            if not line:
                continue

            match = self.dialogue_pattern.match(line)
            if match:
                speaker, dialogue = match.groups()
                dialogue = self._clean_text(dialogue)

                if context_buffer and dialogue:
                    # Create turn with previous context (last 3 turns)
                    context = " ".join(context_buffer[-3:])
                    turns.append(DialogueTurn(
                        context=context,
                        response=dialogue,
                        metadata={"format": "dialogue", "speaker": speaker}
                    ))

                context_buffer.append(f"{speaker}: {dialogue}")

        return turns

    def _clean_text(self, text: str) -> str:
        """Clean and normalize text"""
        text = re.sub(r'\s+', ' ', text)  # Remove extra whitespace
        text = text.replace('\\', ' ')    # Remove backslashes
        return text.strip()

    def parse_directory(self, directory: str, pattern: str = "*.txt") -> List[DialogueTurn]:
        """Parse all files in directory matching pattern"""
        all_turns = []
        file_pattern = os.path.join(directory, pattern)
        files = glob.glob(file_pattern)

        logger.info(f"Found {len(files)} files matching pattern: {pattern}")

        for filepath in sorted(files):
            try:
                turns = self.parse_file(filepath)
                all_turns.extend(turns)
            except Exception as e:
                logger.error(f"Error parsing {filepath}: {e}")
                continue

        return all_turns


print("✓ DialogueParser class defined")
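To make the context-response format concrete, the following minimal sketch applies the same regular expression as `DialogueParser.context_response_pattern` to a small in-memory string. The sample text is invented for illustration and is not taken from any dataset used with this notebook.

```python
import re

# Same pattern as DialogueParser.context_response_pattern.
CONTEXT_RESPONSE = re.compile(
    r'context:\s*(.+?)\s*response:\s*(.+?)(?=\ncontext:|$)',
    re.DOTALL,
)

# Hypothetical file content, made up for this illustration.
sample = (
    "context: How are you?\n"
    "response: Fine, thanks.\n"
    "context: Where are you going?\n"
    "response: To the market."
)

# Each match yields one (context, response) pair.
pairs = [(c.strip(), r.strip()) for c, r in CONTEXT_RESPONSE.findall(sample)]
for context, response in pairs:
    print(f"{context!r} -> {response!r}")
```

The lookahead `(?=\ncontext:|$)` ends each response at the next `context:` line or at the end of the text, so a response may span several lines.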

In [None]:
class Tokenizer:
    """
    Word-level tokenizer for dialogue text.

    Handles special tokens: <PAD>, <START>, <END>, <UNK>
    Converts text to integer sequences and vice versa.
    """

    def __init__(self, vocab_size=10000, min_freq=2):
        self.vocab_size = vocab_size
        self.min_freq = min_freq

        # Special tokens
        self.PAD = 0
        self.START = 1
        self.END = 2
        self.UNK = 3

        # Vocabularies
        self.word2idx = {'<PAD>': self.PAD, '<START>': self.START,
                         '<END>': self.END, '<UNK>': self.UNK}
        self.idx2word = {v: k for k, v in self.word2idx.items()}

        self.word_counts = Counter()
        self.vocab_built = False

    def fit_on_texts(self, texts: List[str]):
        """Build vocabulary from list of texts"""
        logger.info(f"Building vocabulary from {len(texts)} texts...")

        # Count word frequencies
        for text in texts:
            words = self._tokenize(text)
            self.word_counts.update(words)

        # Filter by minimum frequency
        filtered_words = [word for word, count in self.word_counts.items()
                          if count >= self.min_freq]

        # Sort by frequency and take top words
        sorted_words = sorted(filtered_words, key=lambda w: self.word_counts[w], reverse=True)
        top_words = sorted_words[:self.vocab_size - 4]  # Reserve space for special tokens

        # Build word to index mapping
        for idx, word in enumerate(top_words, start=4):
            self.word2idx[word] = idx
            self.idx2word[idx] = word

        self.vocab_built = True
        logger.info(f"  Vocabulary built: {len(self.word2idx)} words")
        logger.info(f"  Most common words: {sorted_words[:10]}")

    def _tokenize(self, text: str) -> List[str]:
        """Tokenize text into words"""
        text = text.lower()
        text = re.sub(r"([?.!,'])", r" \1 ", text)  # Keep punctuation as separate tokens
        return text.split()

    def text_to_sequence(self, text: str, add_start_end=False) -> List[int]:
        """Convert text to sequence of integer indices"""
        words = self._tokenize(text)
        sequence = [self.word2idx.get(word, self.UNK) for word in words]

        if add_start_end:
            sequence = [self.START] + sequence + [self.END]

        return sequence

    def sequence_to_text(self, sequence: List[int]) -> str:
        """Convert sequence of indices back to text"""
        words = [self.idx2word.get(idx, '<UNK>') for idx in sequence]
        # Remove special tokens
        words = [w for w in words if w not in ['<PAD>', '<START>', '<END>']]
        return ' '.join(words)


print("✓ Tokenizer class defined")
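The effect of the punctuation splitting and of `min_freq` is easiest to see on a tiny corpus. The sketch below re-implements only the relevant parts of `Tokenizer` as standalone functions so that it runs on its own. The helper names `tokenize`, `build_vocab`, and `encode` and the three-sentence corpus are assumptions made for this illustration.

```python
import re
from collections import Counter

PAD, START, END, UNK = 0, 1, 2, 3  # same indices as in Tokenizer


def tokenize(text):
    """Lowercase and split punctuation off, as Tokenizer._tokenize does."""
    text = text.lower()
    text = re.sub(r"([?.!,'])", r" \1 ", text)
    return text.split()


def build_vocab(texts, min_freq=2):
    """Keep words seen at least min_freq times; indices start after the special tokens."""
    counts = Counter(w for t in texts for w in tokenize(t))
    kept = sorted((w for w, c in counts.items() if c >= min_freq),
                  key=lambda w: counts[w], reverse=True)
    word2idx = {"<PAD>": PAD, "<START>": START, "<END>": END, "<UNK>": UNK}
    for idx, word in enumerate(kept, start=4):
        word2idx[word] = idx
    return word2idx


def encode(text, word2idx):
    """Map words to indices, wrapping the sequence in START and END."""
    return [START] + [word2idx.get(w, UNK) for w in tokenize(text)] + [END]


# Hypothetical corpus: only "hello" occurs twice.
corpus = ["Hello there!", "Hello, how are you?", "I don't know."]
vocab = build_vocab(corpus)

print(tokenize("I don't know."))      # the apostrophe becomes its own token
print(encode("Hello world!", vocab))  # rare and unseen words map to UNK
```

With `min_freq=2`, every word that appears only once is mapped to `<UNK>`, so on small datasets a lower threshold may be worth trying.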

In [None]:
def pad_sequences(sequences: List[List[int]], maxlen: int, padding='post') -> np.ndarray:
    """Pad (or truncate) sequences to the same length"""
    padded = np.zeros((len(sequences), maxlen), dtype=np.int32)

    for i, seq in enumerate(sequences):
        if len(seq) > maxlen:
            # Truncate
            if padding == 'post':
                padded[i] = seq[:maxlen]
            else:
                padded[i] = seq[-maxlen:]
        else:
            # Pad
            if padding == 'post':
                padded[i, :len(seq)] = seq
            elif len(seq) > 0:
                # Guard against the [-0:] slice, which would select the whole row
                padded[i, -len(seq):] = seq

    return padded


def create_batches(data: np.ndarray, batch_size: int, shuffle=True) -> List[np.ndarray]:
    """Create batches from data"""
    if shuffle:
        indices = np.random.permutation(len(data))
        data = data[indices]

    num_batches = len(data) // batch_size
    batches = []

    for i in range(num_batches):
        batch = data[i * batch_size:(i + 1) * batch_size]
        batches.append(batch)

    # Add remaining data as last batch
    if len(data) % batch_size != 0:
        batch = data[num_batches * batch_size:]
        batches.append(batch)

    return batches


def load_dialogue_data(config) -> Tuple[np.ndarray, Tokenizer, Dict]:
    """
    Load and process dialogue data from multiple .txt files.

    Returns:
        sequences: Padded integer sequences
        tokenizer: Fitted tokenizer
        stats: Dataset statistics
    """
    logger.info(f"Loading dialogue data from: {config.DATA_DIR}")

    # Parse all .txt files
    parser = DialogueParser()
    turns = parser.parse_directory(config.DATA_DIR, pattern=config.DATA_PATTERN)

    if len(turns) == 0:
        raise ValueError(f"No dialogue data found in {config.DATA_DIR}")

    # Calculate statistics
    context_lengths = [len(turn.context.split()) for turn in turns]
    response_lengths = [len(turn.response.split()) for turn in turns]

    stats = {
        'total_turns': len(turns),
        'avg_context_length': sum(context_lengths) / len(context_lengths),
        'avg_response_length': sum(response_lengths) / len(response_lengths),
        'max_context_length': max(context_lengths),
        'max_response_length': max(response_lengths),
    }

    logger.info(f"  Loaded {len(turns)} dialogue turns")
    logger.info(f"  Average context length: {stats['avg_context_length']:.1f} words")
    logger.info(f"  Average response length: {stats['avg_response_length']:.1f} words")

    # Build vocabulary
    contexts = [turn.context for turn in turns]
    responses = [turn.response for turn in turns]
    all_texts = contexts + responses

    tokenizer = Tokenizer(vocab_size=config.VOCAB_SIZE, min_freq=config.MIN_WORD_FREQ)
    tokenizer.fit_on_texts(all_texts)

    # Convert to sequences
    logger.info("Converting texts to sequences...")
    sequences = []

    for turn in turns:
        # Combine context and response for autoregressive training
        full_text = turn.context + " " + turn.response
        seq = tokenizer.text_to_sequence(full_text, add_start_end=True)
        sequences.append(seq)

    # Pad sequences (sequences longer than maxlen are truncated and lose their <END> token)
    sequences = pad_sequences(sequences, maxlen=config.MAX_SEQUENCE_LENGTH, padding='post')

    logger.info(f"  Processed {len(sequences)} sequences")
    logger.info(f"  Sequence shape: {sequences.shape}")

    return sequences, tokenizer, stats


print("✓ Data loading functions defined")
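The padded rows produced here are later split into inputs and targets shifted by one position (teacher forcing), with padding positions masked out of the loss. The following sketch shows that relationship on a single row using plain lists. The helper `pad_post` and the token ids 7, 8, 9 are hypothetical and chosen only for illustration.

```python
PAD, START, END = 0, 1, 2  # same special-token indices as in Tokenizer


def pad_post(seq, maxlen):
    """Truncate to maxlen, then right-pad with PAD (the 'post' mode of pad_sequences)."""
    seq = seq[:maxlen]
    return seq + [PAD] * (maxlen - len(seq))


# Hypothetical encoded turn: <START> w7 w8 w9 <END>
row = pad_post([START, 7, 8, 9, END], maxlen=8)

# Teacher forcing: the model sees tokens 0..n-2 and predicts tokens 1..n-1.
inputs, targets = row[:-1], row[1:]

# Padding positions in the targets get zero weight in the loss.
mask = [int(t != PAD) for t in targets]

print("row:    ", row)
print("inputs: ", inputs)
print("targets:", targets)
print("mask:   ", mask)
```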

---

## 4. Model Architecture

Generator (LSTM-based) and Discriminator (CNN-based) models.

In [None]:
class Generator(keras.Model):
    """
    LSTM-based Generator for dialogue generation.

    Architecture:
    1. Embedding layer: Word indices → dense vectors
    2. LSTM layers: Sequential processing with memory
    3. Dense output: Project to vocabulary with softmax

    Trained with:
    - Pre-training: Maximum Likelihood Estimation (MLE)
    - Adversarial: Policy gradient with Discriminator rewards
    """

    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_layers=2, dropout_rate=0.3):
        super(Generator, self).__init__()

        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers

        # Embedding layer - converts word indices to dense vectors
        self.embedding = layers.Embedding(
            input_dim=vocab_size,
            output_dim=embedding_dim,
            mask_zero=True,  # Mask padding tokens
            name="generator_embedding"
        )

        # Stack of LSTM layers
        self.lstm_layers = []
        for i in range(num_layers):
            self.lstm_layers.append(
                layers.LSTM(
                    hidden_dim,
                    return_sequences=True,  # Return full sequence
                    return_state=True,      # Return hidden and cell states
                    dropout=dropout_rate,
                    recurrent_dropout=dropout_rate,
                    name=f"generator_lstm_{i+1}"
                )
            )

        # Output layer - project to vocabulary space
        self.output_layer = layers.Dense(
            vocab_size,
            activation='softmax',  # Probability distribution
            name="generator_output"
        )

        self.dropout = layers.Dropout(dropout_rate)

    def call(self, inputs, training=False, states=None):
        """
        Forward pass through the generator.

        Args:
            inputs: Word indices (batch_size, sequence_length)
            training: Whether in training mode
            states: Initial LSTM states (optional)

        Returns:
            probs: Softmax probabilities (batch_size, seq_len, vocab_size)
            states: Final LSTM states
        """
        # Embed input words
        x = self.embedding(inputs)

        if training:
            x = self.dropout(x, training=training)

        # Pass through LSTM layers
        all_states = []
        for lstm_layer in self.lstm_layers:
            if states is not None and len(all_states) < len(states):
                x, h, c = lstm_layer(x, initial_state=states[len(all_states)])
            else:
                x, h, c = lstm_layer(x)
            all_states.append([h, c])

        if training:
            x = self.dropout(x, training=training)

        # Project to vocabulary space (softmax is applied by the output layer)
        probs = self.output_layer(x)

        return probs, all_states

    def generate_sequence(self, start_token, max_length, temperature=1.0, top_k=50, end_token=2):
        """
        Generate a sequence autoregressively (word-by-word).

        Args:
            start_token: Initial token to start generation
            max_length: Maximum sequence length
            temperature: Sampling temperature (higher = more random)
            top_k: Only sample from top-k most likely tokens
            end_token: Index of the end token (2 matches Tokenizer.END)

        Returns:
            generated_sequence: List of generated word indices
        """
        current_token = tf.constant([[start_token]])
        generated = [start_token]
        states = None

        for _ in range(max_length - 1):
            # Get next-token probabilities
            probs, states = self.call(current_token, training=False, states=states)

            # tf.random.categorical expects log-probabilities, so convert the
            # softmax output back to log space before applying temperature
            logits = tf.math.log(probs[:, -1, :] + 1e-9) / temperature

            # Top-k sampling
            if top_k > 0:
                top_k_logits, top_k_indices = tf.nn.top_k(logits, k=top_k)
                next_token_idx = tf.random.categorical(top_k_logits, num_samples=1)
                next_token = tf.gather(top_k_indices[0], next_token_idx[0])
            else:
                next_token = tf.random.categorical(logits, num_samples=1)[0]

            next_token = int(next_token[0])
            generated.append(next_token)

            # Stop at the end token (or at padding, index 0)
            if next_token in (end_token, 0):
                break

            current_token = tf.constant([[next_token]])

        return generated


print("✓ Generator class defined")
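Temperature and top-k sampling can be illustrated without TensorFlow. The sketch below is a standard-library approximation of the sampling step in `generate_sequence`, not the notebook's implementation. The function name `top_k_sample` and the toy distribution are assumptions. It starts from probabilities, because the Generator's output layer applies softmax, and moves to log space before dividing by the temperature.

```python
import math
import random


def top_k_sample(probs, k, temperature=1.0, rng=random):
    """Sample an index from the k most likely entries of a probability vector."""
    # Keep the indices of the k largest probabilities.
    top = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:k]

    # Temperature is applied to log-probabilities, not to probabilities.
    scaled = [math.log(probs[i] + 1e-9) / temperature for i in top]

    # Softmax over the surviving candidates (shifted by the max for stability).
    peak = max(scaled)
    weights = [math.exp(s - peak) for s in scaled]

    return rng.choices(top, weights=weights, k=1)[0]


# Toy next-token distribution over a 4-word vocabulary.
probs = [0.05, 0.6, 0.25, 0.1]
rng = random.Random(0)

greedy = [top_k_sample(probs, k=1, rng=rng) for _ in range(5)]
sampled = [top_k_sample(probs, k=2, temperature=1.5, rng=rng) for _ in range(10)]
print(greedy)   # k=1 always picks the most likely index
print(sampled)  # k=2 only ever yields the two most likely indices
```

A temperature below 1 sharpens the distribution toward the most likely token, while a temperature above 1 flattens it.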

In [None]:
class Discriminator(keras.Model):
    """
    CNN-based Discriminator to classify real vs generated dialogues.

    Architecture:
    1. Embedding layer: Word indices → dense vectors
    2. Multi-kernel CNN: Captures n-gram patterns (3-gram, 4-gram, 5-gram)
    3. Max pooling: Extracts most important features
    4. Dense layers: Binary classification (real/fake)

    Uses multiple kernel sizes to capture both short and long-range patterns,
    similar to Kim's CNN for text classification.
    """

    def __init__(self, vocab_size, embedding_dim, kernel_sizes=(3, 4, 5),
                 num_filters=128, dropout_rate=0.3):
        super(Discriminator, self).__init__()

        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim
        self.kernel_sizes = kernel_sizes
        self.num_filters = num_filters

        # Embedding layer
        self.embedding = layers.Embedding(
            input_dim=vocab_size,
            output_dim=embedding_dim,
            mask_zero=False,
            name="discriminator_embedding"
        )

        # Multiple CNN branches with different kernel sizes
        self.conv_layers = []
        self.pooling_layers = []

        for kernel_size in kernel_sizes:
            # 1D convolution over sequence
            conv = layers.Conv1D(
                filters=num_filters,
                kernel_size=kernel_size,
                activation='relu',
                padding='valid',
                name=f"discriminator_conv_{kernel_size}"
            )
            self.conv_layers.append(conv)

            # Global max pooling
            pool = layers.GlobalMaxPooling1D(name=f"discriminator_pool_{kernel_size}")
            self.pooling_layers.append(pool)

        self.concat = layers.Concatenate(name="discriminator_concat")
        self.dropout = layers.Dropout(dropout_rate)

        # Dense ReLU projection, named "highway" (it has no gating, so it is
        # not a full highway layer)
        self.highway = layers.Dense(
            num_filters * len(kernel_sizes),
            activation='relu',
            name="discriminator_highway"
        )

        # Final classification layer
        self.output_layer = layers.Dense(
            1,
            activation='sigmoid',  # Output probability in [0, 1]
            name="discriminator_output"
        )

    def call(self, inputs, training=False):
        """
        Forward pass through the discriminator.

        Args:
            inputs: Word indices (batch_size, sequence_length)
            training: Whether in training mode

        Returns:
            score: Probability that input is real (batch_size, 1)
        """
        # Embed input words
        x = self.embedding(inputs)

        # Apply multiple CNN branches
        conv_outputs = []
        for conv_layer, pool_layer in zip(self.conv_layers, self.pooling_layers):
            conv_out = conv_layer(x)  # Convolve over sequence
            pooled = pool_layer(conv_out)  # Max pool
            conv_outputs.append(pooled)

        # Concatenate features from different kernel sizes
        concatenated = self.concat(conv_outputs)

        if training:
            concatenated = self.dropout(concatenated, training=training)

        # Dense projection
        highway_out = self.highway(concatenated)

        if training:
            highway_out = self.dropout(highway_out, training=training)

        # Binary classification
        score = self.output_layer(highway_out)

        return score


print("✓ Discriminator class defined")
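The tensor sizes inside the Discriminator follow from simple arithmetic. The sketch below uses the values from `Config` (sequence length 50, kernel sizes 3, 4, 5, and 128 filters) and computes the length of each convolution output under `'valid'` padding, along with the width of the concatenated feature vector. The helper name `conv1d_valid_length` is an assumption for this example.

```python
def conv1d_valid_length(seq_len, kernel_size, stride=1):
    """Output length of a 1D convolution with 'valid' padding (no padding added)."""
    return (seq_len - kernel_size) // stride + 1


seq_len = 50             # Config.MAX_SEQUENCE_LENGTH
kernel_sizes = [3, 4, 5] # Config.DISCRIMINATOR_KERNEL_SIZES
num_filters = 128        # Config.DISCRIMINATOR_NUM_FILTERS

# Each branch slides its kernel over the sequence, one n-gram window per step.
lengths = {k: conv1d_valid_length(seq_len, k) for k in kernel_sizes}

# Global max pooling collapses the time axis, leaving num_filters values per branch.
feature_dim = num_filters * len(kernel_sizes)

for k, length in lengths.items():
    print(f"kernel {k}: conv output (batch, {length}, {num_filters}) -> pooled (batch, {num_filters})")
print(f"concatenated features: (batch, {feature_dim})")
```

Because global max pooling removes the time dimension, the feature width does not depend on the sequence length, only on the number of filters and branches.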

In [None]:
def build_models(config):
    """Build Generator and Discriminator from configuration"""

    # Build Generator
    generator = Generator(
        vocab_size=config.VOCAB_SIZE,
        embedding_dim=config.EMBEDDING_DIM,
        hidden_dim=config.HIDDEN_DIM,
        num_layers=config.GENERATOR_LSTM_LAYERS,
        dropout_rate=config.DROPOUT_RATE
    )

    # Build Discriminator
    discriminator = Discriminator(
        vocab_size=config.VOCAB_SIZE,
        embedding_dim=config.EMBEDDING_DIM,
        kernel_sizes=config.DISCRIMINATOR_KERNEL_SIZES,
        num_filters=config.DISCRIMINATOR_NUM_FILTERS,
        dropout_rate=config.DROPOUT_RATE
    )

    return generator, discriminator


print("✓ Model building function defined")

---

## 5. Training Functions

Pre-training and adversarial training functions with @tf.function decorators for speed.

In [None]:
# Training step functions with @tf.function for speed optimization

@tf.function
def pretrain_generator_step(generator, optimizer, inputs, targets):
    """
    Single pre-training step for Generator using MLE.

    Standard teacher-forcing: predict next token given previous tokens.
    """
    with tf.GradientTape() as tape:
        # Forward pass (the Generator returns softmax probabilities, not raw logits)
        probs, _ = generator(inputs, training=True)

        # Calculate cross-entropy loss on probabilities, hence from_logits=False
        loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=False)

        # Mask padding tokens
        mask = tf.cast(targets != 0, tf.float32)
        loss = loss_fn(targets, probs, sample_weight=mask)

    # Backward pass
    gradients = tape.gradient(loss, generator.trainable_variables)
    optimizer.apply_gradients(zip(gradients, generator.trainable_variables))

    return loss


@tf.function
def pretrain_discriminator_step(discriminator, optimizer, real_data, fake_data):
    """
    Single pre-training step for Discriminator.

    Train to distinguish real dialogues from generated ones.
    """
    with tf.GradientTape() as tape:
        # Get predictions
        real_scores = discriminator(real_data, training=True)
        fake_scores = discriminator(fake_data, training=True)

        # Labels: 1 for real, 0 for fake
        real_labels = tf.ones_like(real_scores)
        fake_labels = tf.zeros_like(fake_scores)

        # Calculate loss
        loss_fn = keras.losses.BinaryCrossentropy()
        real_loss = loss_fn(real_labels, real_scores)
        fake_loss = loss_fn(fake_labels, fake_scores)
        loss = (real_loss + fake_loss) / 2

        # Calculate accuracy
        real_acc = tf.reduce_mean(tf.cast(real_scores > 0.5, tf.float32))
        fake_acc = tf.reduce_mean(tf.cast(fake_scores < 0.5, tf.float32))
        accuracy = (real_acc + fake_acc) / 2

    # Backward pass
    gradients = tape.gradient(loss, discriminator.trainable_variables)
    optimizer.apply_gradients(zip(gradients, discriminator.trainable_variables))

    return loss, accuracy


def generate_fake_samples(generator, tokenizer, num_samples, max_length):
    """Generate fake samples from Generator for Discriminator training"""
    fake_samples = []

    for _ in range(num_samples):
        generated = generator.generate_sequence(
            start_token=tokenizer.START,
            max_length=max_length,
            temperature=1.0,
            top_k=50
        )
        fake_samples.append(generated)

    # Pad to same length
    fake_samples = pad_sequences(fake_samples, maxlen=max_length, padding='post')
    return fake_samples


def print_training_examples(generator, tokenizer, num_examples=3):
    """Generate and print example dialogues during training"""
    print("\n" + "="*60)
    print("Generated Examples:")
    print("="*60)

    for i in range(num_examples):
        generated = generator.generate_sequence(
            start_token=tokenizer.START,
            max_length=50,
            temperature=1.0,
            top_k=50
        )
        text = tokenizer.sequence_to_text(generated)
        print(f"\n{i+1}. {text}")

    print("="*60 + "\n")


print("✓ Training functions defined")
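The Discriminator's loss and accuracy can be checked by hand on a few scores. This sketch mirrors the arithmetic of `pretrain_discriminator_step` in plain Python. The four scores are invented for illustration, and the `binary_cross_entropy` helper is a simplified stand-in for `keras.losses.BinaryCrossentropy`, not its actual implementation.

```python
import math


def binary_cross_entropy(labels, scores, eps=1e-7):
    """Mean of -[y*log(p) + (1-y)*log(1-p)] over the batch, with clipping for stability."""
    total = 0.0
    for y, p in zip(labels, scores):
        p = min(max(p, eps), 1 - eps)
        total += -(y * math.log(p) + (1 - y) * math.log(1 - p))
    return total / len(labels)


# Hypothetical Discriminator outputs (probability that a sample is real).
real_scores = [0.9, 0.6]  # both real samples are scored above 0.5
fake_scores = [0.2, 0.7]  # the second fake sample fools the Discriminator

real_loss = binary_cross_entropy([1, 1], real_scores)
fake_loss = binary_cross_entropy([0, 0], fake_scores)
loss = (real_loss + fake_loss) / 2

real_acc = sum(s > 0.5 for s in real_scores) / len(real_scores)
fake_acc = sum(s < 0.5 for s in fake_scores) / len(fake_scores)
accuracy = (real_acc + fake_acc) / 2

print(f"loss={loss:.4f} accuracy={accuracy:.2f}")
```

An accuracy close to 0.5 means the Discriminator can no longer tell generated samples from real ones, while an accuracy close to 1.0 means the Generator's samples are still easy to spot.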

---

## 6. Evaluation & Metrics

BLEU score, diversity, and perplexity calculations.

In [None]:
def calculate_bleu(references: List[List[str]], hypothesis: List[str], n=2) -> float:
    """
    Calculate a simplified BLEU score for generated text: clipped n-gram
    precision for a single n-gram order, multiplied by a brevity penalty.

    Args:
        references: List of reference sequences
        hypothesis: Generated sequence
        n: N-gram size (2 for BLEU-2)

    Returns:
        BLEU score (0 to 1)
    """
    if not references:
        return 0.0

    # Count n-grams in hypothesis
    hyp_ngrams = defaultdict(int)
    for i in range(len(hypothesis) - n + 1):
        ngram = tuple(hypothesis[i:i + n])
        hyp_ngrams[ngram] += 1

    total = sum(hyp_ngrams.values())
    if total == 0:
        return 0.0

    # For each n-gram, keep its maximum count over all references
    max_ref_ngrams = defaultdict(int)
    for ref in references:
        ref_ngrams = defaultdict(int)
        for i in range(len(ref) - n + 1):
            ngram = tuple(ref[i:i + n])
            ref_ngrams[ngram] += 1

        for ngram, count in ref_ngrams.items():
            max_ref_ngrams[ngram] = max(max_ref_ngrams[ngram], count)

    # Count matching n-grams, clipped by the reference counts so that
    # precision cannot exceed 1 when several references are given
    matches = sum(min(count, max_ref_ngrams[ngram]) for ngram, count in hyp_ngrams.items())

    precision = matches / total

    # Brevity penalty
    ref_len = sum(len(ref) for ref in references) / len(references)
    hyp_len = len(hypothesis)

    if hyp_len > ref_len:
        bp = 1.0
    else:
        bp = np.exp(1 - ref_len / hyp_len) if hyp_len > 0 else 0.0

    return bp * precision


def calculate_diversity(sequences: List[List[int]], n=2) -> float:
    """
    Calculate diversity score - ratio of unique n-grams to total n-grams.

    Higher score = more diverse generation
    """
    all_ngrams = []

    for seq in sequences:
        for i in range(len(seq) - n + 1):
            ngram = tuple(seq[i:i + n])
            all_ngrams.append(ngram)

    if len(all_ngrams) == 0:
        return 0.0

    unique_ngrams = len(set(all_ngrams))
    total_ngrams = len(all_ngrams)

    return unique_ngrams / total_ngrams


print("✓ Evaluation functions defined")
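Two small worked examples may help when reading these metrics. The sketch below is a standalone illustration, not a replacement for the functions above. The helper names `ngrams`, `clipped_precision`, and `distinct_n` are assumptions. The first example is the classic degenerate hypothesis that repeats one word, which shows why n-gram matches are clipped by the reference counts.

```python
from collections import Counter


def ngrams(tokens, n):
    """All contiguous n-grams of a token list, as tuples."""
    return [tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1)]


def clipped_precision(references, hypothesis, n):
    """Share of hypothesis n-grams found in the references, clipped per n-gram."""
    hyp = Counter(ngrams(hypothesis, n))
    if not hyp:
        return 0.0
    max_ref = Counter()
    for ref in references:
        for gram, count in Counter(ngrams(ref, n)).items():
            max_ref[gram] = max(max_ref[gram], count)
    matched = sum(min(count, max_ref[gram]) for gram, count in hyp.items())
    return matched / sum(hyp.values())


def distinct_n(sequences, n):
    """Ratio of unique n-grams to total n-grams over all sequences."""
    grams = [g for seq in sequences for g in ngrams(seq, n)]
    return len(set(grams)) / len(grams) if grams else 0.0


references = [
    "the cat is on the mat".split(),
    "there is a cat on the mat".split(),
]
hypothesis = ["the"] * 7

# "the" occurs at most twice in any single reference, so only 2 of 7 matches count.
print(clipped_precision(references, hypothesis, n=1))

# Bigrams of a-b-a-b are (a,b), (b,a), (a,b): 2 unique out of 3.
print(distinct_n([["a", "b", "a", "b"]], n=2))
```

Without clipping, the repeated hypothesis would score a unigram precision of 7/7 against these references, which is why the per-n-gram maximum matters.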

---

## 7. Main Training Loop

Execute the full training pipeline: data loading, pre-training, and adversarial training.

In [None]:
def train_dla_model(config=Config):
    """
    Main training function for DLA model.

    Pipeline:
    1. Load and preprocess data
    2. Build models
    3. Pre-train Generator (MLE)
    4. Pre-train Discriminator
    5. Adversarial training
    6. Save models
    """

    print("\n" + "#"*60)
    print("# DLA TRAINING PIPELINE")
    print("#"*60 + "\n")

    training_start = time.time()

    # ==================== Step 1: Load Data ====================
    print("="*60)
    print("STEP 1: Loading Data")
    print("="*60)

    try:
        train_data, tokenizer, stats = load_dialogue_data(config)
    except ValueError as e:
        print(f"\n⚠️  ERROR: {e}")
        print("\nPlease add dialogue .txt files to the 'data/' directory.")
        print("Format: 'context: <text> response: <text>' or 'Speaker: dialogue text'")
        return None, None, None

    print(f"\n✓ Data loaded successfully!")
    print(f"  - Training samples: {len(train_data)}")
    print(f"  - Vocabulary size: {len(tokenizer.word2idx)}")
    print(f"  - Total dialogue turns: {stats['total_turns']}")

    # ==================== Step 2: Build Models ====================
    print("\n" + "="*60)
    print("STEP 2: Building Models")
    print("="*60)

    generator, discriminator = build_models(config)

    # Initialize optimizers
    gen_optimizer_pretrain = keras.optimizers.Adam(learning_rate=config.GENERATOR_LR_PRETRAIN)
    gen_optimizer_adversarial = keras.optimizers.Adam(learning_rate=config.GENERATOR_LR_ADVERSARIAL)
    disc_optimizer = keras.optimizers.Adam(learning_rate=config.DISCRIMINATOR_LR)

    print(f"\n✓ Models built successfully!")
    print(f"  - Generator: LSTM with {config.GENERATOR_LSTM_LAYERS} layers")
    print(f"  - Discriminator: CNN with kernels {config.DISCRIMINATOR_KERNEL_SIZES}")

    # Training history
    history = {
        'gen_pretrain_loss': [],
        'disc_pretrain_loss': [],
        'disc_pretrain_acc': [],
        'gen_adversarial_loss': [],
        'disc_adversarial_loss': [],
        'disc_adversarial_acc': [],
        'diversity': []
    }

    # ==================== Step 3: Pre-train Generator ====================
    print("\n" + "="*60)
    print("STEP 3: Pre-training Generator (MLE)")
    print("="*60)

    for epoch in range(config.GENERATOR_PRETRAIN_EPOCHS):
        epoch_start = time.time()
        epoch_losses = []

        batches = create_batches(train_data, config.BATCH_SIZE, shuffle=True)

        for batch in batches:
            inputs = batch[:, :-1]
            targets = batch[:, 1:]
            loss = pretrain_generator_step(generator, gen_optimizer_pretrain, inputs, targets)
            epoch_losses.append(float(loss))

        avg_loss = np.mean(epoch_losses)
        epoch_time = time.time() - epoch_start
        history['gen_pretrain_loss'].append(avg_loss)

        print(f"Epoch {epoch+1}/{config.GENERATOR_PRETRAIN_EPOCHS} - "
              f"Loss: {avg_loss:.4f} - Time: {epoch_time:.1f}s")

        # Show examples
        if (epoch + 1) % 2 == 0 or epoch == config.GENERATOR_PRETRAIN_EPOCHS - 1:
            print_training_examples(generator, tokenizer, num_examples=2)

    print("✓ Generator pre-training complete!\n")

    # ==================== Step 4: Pre-train Discriminator ====================
    print("="*60)
    print("STEP 4: Pre-training Discriminator")
    print("="*60)

    for epoch in range(config.DISCRIMINATOR_PRETRAIN_EPOCHS):
        epoch_start = time.time()
        epoch_losses = []
        epoch_accs = []

        batches = create_batches(train_data, config.BATCH_SIZE, shuffle=True)

        for real_batch in batches:
            fake_batch = generate_fake_samples(
                generator, tokenizer, len(real_batch), config.MAX_SEQUENCE_LENGTH
            )
            loss, accuracy = pretrain_discriminator_step(
                discriminator, disc_optimizer, real_batch, fake_batch
            )
            epoch_losses.append(float(loss))
            epoch_accs.append(float(accuracy))

        avg_loss = np.mean(epoch_losses)
        avg_acc = np.mean(epoch_accs)
        epoch_time = time.time() - epoch_start

        history['disc_pretrain_loss'].append(avg_loss)
        history['disc_pretrain_acc'].append(avg_acc)

        print(f"Epoch {epoch+1}/{config.DISCRIMINATOR_PRETRAIN_EPOCHS} - "
              f"Loss: {avg_loss:.4f} - Accuracy: {avg_acc:.4f} - Time: {epoch_time:.1f}s")

    print("✓ Discriminator pre-training complete!\n")

    # ==================== Step 5: Adversarial Training ====================
    print("="*60)
    print("STEP 5: Adversarial Training")
    print("="*60)

    for epoch in range(config.ADVERSARIAL_EPOCHS):
        epoch_start = time.time()
        gen_losses = []
        disc_losses = []
        disc_accs = []

        batches = create_batches(train_data, config.BATCH_SIZE, shuffle=True)

        for real_batch in batches:
            # Generate fake samples
            fake_batch = generate_fake_samples(
                generator, tokenizer, len(real_batch), config.MAX_SEQUENCE_LENGTH
            )

            # Train Discriminator
            disc_loss, disc_acc = pretrain_discriminator_step(
                discriminator, disc_optimizer, real_batch, fake_batch
            )
            disc_losses.append(float(disc_loss))
            disc_accs.append(float(disc_acc))

            # Train Generator (simplified - using MLE for stability)
            inputs = real_batch[:, :-1]
            targets = real_batch[:, 1:]
            gen_loss = pretrain_generator_step(generator, gen_optimizer_adversarial, inputs, targets)
            gen_losses.append(float(gen_loss))

        avg_gen_loss = np.mean(gen_losses)
        avg_disc_loss = np.mean(disc_losses)
        avg_disc_acc = np.mean(disc_accs)
        epoch_time = time.time() - epoch_start

        history['gen_adversarial_loss'].append(avg_gen_loss)
        history['disc_adversarial_loss'].append(avg_disc_loss)
        history['disc_adversarial_acc'].append(avg_disc_acc)

        print(f"Epoch {epoch+1}/{config.ADVERSARIAL_EPOCHS} - "
              f"G_Loss: {avg_gen_loss:.4f} - D_Loss: {avg_disc_loss:.4f} - "
              f"D_Acc: {avg_disc_acc:.4f} - Time: {epoch_time:.1f}s")

        # Periodic evaluation
        if (epoch + 1) % config.CHECKPOINT_EVERY == 0 or epoch == config.ADVERSARIAL_EPOCHS - 1:
            print_training_examples(generator, tokenizer, num_examples=config.DISPLAY_EXAMPLES)

            # Calculate diversity
            fake_samples = generate_fake_samples(
                generator, tokenizer, 100, config.MAX_SEQUENCE_LENGTH
            )
            diversity = calculate_diversity(fake_samples, n=2)
            history['diversity'].append(diversity)
            print(f"Diversity (2-gram): {diversity:.4f}\n")

    print("✓ Adversarial training complete!\n")

    # ==================== Training Complete ====================
    total_time = time.time() - training_start

    print("="*60)
    print("TRAINING COMPLETE!")
    print("="*60)
    print(f"Total training time: {total_time/60:.1f} minutes")

    # Final evaluation
    print("\nFinal Evaluation:")
    fake_samples = generate_fake_samples(generator, tokenizer, 200, config.MAX_SEQUENCE_LENGTH)
    diversity = calculate_diversity(fake_samples, n=2)
    print(f"  Final Diversity (2-gram): {diversity:.4f}")

    # Save models
    checkpoint_dir = os.path.join(config.MODEL_DIR, "checkpoint_final")
    os.makedirs(checkpoint_dir, exist_ok=True)

    generator.save_weights(os.path.join(checkpoint_dir, "generator.weights.h5"))
    discriminator.save_weights(os.path.join(checkpoint_dir, "discriminator.weights.h5"))

    with open(os.path.join(checkpoint_dir, "tokenizer.pkl"), 'wb') as f:
        pickle.dump(tokenizer, f)

    print(f"\n✓ Models saved to: {checkpoint_dir}")
    print("\n" + "#"*60 + "\n")

    return generator, 
discriminator, tokenizer, historyprint("✓ Main training function defined")
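
The training loop above reports `calculate_diversity(fake_samples, n=2)`, which is defined in an earlier cell. As a rough illustration of what a 2-gram diversity score usually measures, here is a minimal sketch of the common distinct-n metric (unique n-grams divided by total n-grams). The function name `distinct_n` and the toy samples are hypothetical, and the notebook's own `calculate_diversity` may be defined differently.

```python
from typing import Iterable, Sequence


def distinct_n(sequences: Iterable[Sequence[int]], n: int = 2) -> float:
    """Ratio of unique n-grams to total n-grams across all sequences.

    Hypothetical helper for illustration; not the notebook's calculate_diversity.
    """
    total = 0
    unique = set()
    for seq in sequences:
        # Slide a window of width n over the token ids.
        for i in range(len(seq) - n + 1):
            unique.add(tuple(seq[i:i + n]))
            total += 1
    # No n-grams at all (empty input or sequences shorter than n).
    return len(unique) / total if total else 0.0


# Toy token-id sequences: two near-duplicates and one fully repetitive sample.
samples = [[1, 2, 3, 4], [1, 2, 3, 5], [7, 7, 7, 7]]
print(f"distinct-2: {distinct_n(samples, n=2):.4f}")
```

A score near 1.0 means almost every n-gram is unique; a score near 0.0 suggests the generator keeps repeating the same phrases, which is a typical symptom of mode collapse.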

---

## 8. Generation & Testing

Generate sample dialogues and test the trained model.

In [None]:
def generate_dialogues(generator, tokenizer, num_samples=10, temperature=1.0, max_length=50):
    """
    Generate multiple dialogue samples.

    Args:
        generator: Trained generator model
        tokenizer: Tokenizer for decoding
        num_samples: Number of dialogues to generate
        temperature: Sampling temperature (higher = more diverse)
        max_length: Maximum generation length
    """
    print("\n" + "="*60)
    print(f"Generating {num_samples} Dialogue Samples")
    print("="*60 + "\n")

    for i in range(num_samples):
        generated = generator.generate_sequence(
            start_token=tokenizer.START,
            max_length=max_length,
            temperature=temperature,
            top_k=50
        )
        text = tokenizer.sequence_to_text(generated)
        print(f"{i+1}. {text}")

    print("\n" + "="*60 + "\n")


def interactive_generation(generator, tokenizer):
    """
    Interactive dialogue generation mode.

    The user enters a prompt and the model prints a sampled response.
    Simplification: the prompt is not yet encoded into the generator's
    input, so every response is sampled from the START token.
    """
    print("\n" + "="*60)
    print("Interactive Dialogue Generation")
    print("="*60)
    print("\nInstructions:")
    print("  - Enter a context/prompt to generate a response")
    print("  - Press Enter for random generation (no context)")
    print("  - Type 'quit' to exit\n")
    print("="*60 + "\n")

    while True:
        user_input = input("You: ").strip()

        if user_input.lower() in ['quit', 'exit', 'q']:
            print("\nGoodbye! 👋\n")
            break

        # Generate response. Prompted and empty input currently take the
        # same path: the prompt text is not used to condition the sample.
        generated = generator.generate_sequence(
            start_token=tokenizer.START,
            max_length=50,
            temperature=1.0,
            top_k=50
        )

        text = tokenizer.sequence_to_text(generated)
        print(f"DLA: {text}\n")


print("✓ Generation functions defined")
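
Both functions above pass `temperature` and `top_k=50` to `generator.generate_sequence`, whose implementation lives in an earlier cell. The sketch below shows one common way these two parameters interact when picking the next token: restrict the candidates to the `top_k` highest logits, then sample from a temperature-scaled softmax over them. It uses only the standard library; `sample_top_k` and the toy logits are hypothetical and are not taken from the notebook's generator.

```python
import math
import random
from typing import List, Optional


def sample_top_k(logits: List[float], temperature: float = 1.0,
                 top_k: int = 50, rng: Optional[random.Random] = None) -> int:
    """Sample a token id from the top_k highest logits after temperature scaling.

    Hypothetical helper for illustration; not the notebook's generate_sequence.
    """
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    rng = rng or random.Random()
    # Keep only the indices of the top_k largest logits.
    k = max(1, min(top_k, len(logits)))
    kept = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:k]
    # Temperature-scaled softmax weights (max-subtracted for numerical stability).
    scaled = [logits[i] / temperature for i in kept]
    peak = max(scaled)
    weights = [math.exp(s - peak) for s in scaled]
    # random.choices normalizes the weights internally.
    return rng.choices(kept, weights=weights, k=1)[0]


toy_logits = [2.0, 1.0, 0.5, -1.0, -3.0]
rng = random.Random(42)
draws = [sample_top_k(toy_logits, temperature=0.7, top_k=3, rng=rng) for _ in range(1000)]
# Only token ids 0, 1 and 2 can appear because top_k=3 removes the rest.
print({token: draws.count(token) for token in sorted(set(draws))})
```

With `top_k=1` this reduces to greedy decoding, while a large `top_k` and a high temperature approach sampling from the full, flattened distribution.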

---

## 🚀 Execute Training

Run this cell to start the training process!

In [None]:
# Train the model
generator, discriminator, tokenizer, history = train_dla_model(Config)

# Check if training was successful
if generator is not None:
    print("\n✅ Training completed successfully!")
    print("\nYou can now:")
    print("  1. Generate dialogues using generate_dialogues(generator, tokenizer)")
    print("  2. Try interactive mode with interactive_generation(generator, tokenizer)")
    print("  3. Load saved models from 'outputs/models/checkpoint_final/'")
else:
    print("\n⚠️  Training could not start - please add dialogue data to 'data/' directory")

---

## 💬 Generate Sample Dialogues

Generate sample dialogues from the trained model.

In [None]:
# Generate 10 sample dialogues with different temperatures
if generator is not None:
    print("\n🌟 Conservative Generation (temperature=0.7)")
    generate_dialogues(generator, tokenizer, num_samples=5, temperature=0.7)

    print("\n🎨 Creative Generation (temperature=1.3)")
    generate_dialogues(generator, tokenizer, num_samples=5, temperature=1.3)
else:
    print("⚠️  Please train the model first!")
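
To see why 0.7 behaves as the conservative setting and 1.3 as the creative one, it can help to look at what temperature does to a next-token distribution. The sketch below applies a temperature-scaled softmax to a set of made-up logits and prints the probability of the most likely token together with the entropy of the distribution. The helper names and the logits are hypothetical and independent of the notebook's generator.

```python
import math
from typing import List


def softmax_with_temperature(logits: List[float], temperature: float) -> List[float]:
    """Softmax over logits / temperature (max-subtracted for numerical stability)."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)
    exps = [math.exp(s - peak) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def entropy(probs: List[float]) -> float:
    """Shannon entropy in nats; higher means a flatter distribution."""
    return -sum(p * math.log(p) for p in probs if p > 0)


toy_logits = [2.0, 1.0, 0.5, -1.0]  # made-up scores for four candidate tokens
for t in (0.7, 1.0, 1.3):
    probs = softmax_with_temperature(toy_logits, t)
    print(f"T={t}: top prob={probs[0]:.3f}, entropy={entropy(probs):.3f}")
```

Lower temperatures concentrate probability on the highest-scoring token, so samples look safer and more repetitive; higher temperatures spread probability out, which increases variety at the cost of coherence.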

---

## 🎮 Interactive Mode

Try interactive dialogue generation! (Uncomment to activate)

In [None]:
# Uncomment to activate interactive mode
# if generator is not None:
#     interactive_generation(generator, tokenizer)
# else:
#     print("⚠️  Please train the model first!")

print("✓ Notebook complete! Uncomment the lines above to try interactive generation.")

---

## 📊 Summary

This notebook implements a complete Dialogue Learning Algorithm (DLA) using GANs:

### Architecture

- **Generator**: 2-layer LSTM with 256 hidden units
- **Discriminator**: Multi-kernel CNN (3, 4, 5-gram patterns)

### Training Pipeline

1. **Data Loading**: Parse multiple .txt dialogue files
2. **Pre-training**:
   - Generator: 5 epochs MLE
   - Discriminator: 3 epochs
3. **Adversarial Training**: 5 epochs in which discriminator updates alternate with generator updates (the generator step is simplified to MLE for stability)

### Key Features

- ✅ Self-contained (no external files needed)
- ✅ Multi-file .txt support
- ✅ Fast training (@tf.function decorators)
- ✅ Well-commented code
- ✅ Interactive generation mode

### Files Generated

- `outputs/models/checkpoint_final/generator.weights.h5`
- `outputs/models/checkpoint_final/discriminator.weights.h5`
- `outputs/models/checkpoint_final/tokenizer.pkl`

### Next Steps

1. Add more dialogue data to the `data/` directory
2. Adjust hyperparameters in the `Config` class
3. Experiment with different temperatures for generation
4. Try longer training (increase epochs)

**Happy Training! 🎉**