## Importing the Required Libraries

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import PIL
import os
import time
import io
from PIL import Image, ImageOps

import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.tensorboard import SummaryWriter
from sklearn.model_selection import train_test_split
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas

In [None]:
# Create the data, TensorBoard-log and checkpoint directories.
# NOTE(review): `sketch_data` is relative to the current working directory, while HParams
# points at '/content/sketch_data'; the two only coincide when the notebook runs from /content.
!mkdir -p sketch_data
!mkdir -p /content/cRNN_logs
!mkdir -p /content/cRNN_checkpoints

In [None]:
# Pick the GPU when PyTorch can see one, otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda") if use_cuda else torch.device("cpu")
print(f"Using device: {device}")

Using device: cuda


## Defining Required Hyperparameters

In [None]:
class HParams():
    """Hyperparameters for the class-conditional sketch decoder.

    Any attribute can be overridden by keyword, e.g. ``HParams(num_epochs=5, M=10)``.
    Unknown keywords raise ``AttributeError`` so a typo cannot be silently ignored.
    """
    def __init__(self, **overrides):
        # Data and Classes
        self.classes = ['cat', 'apple', 'airplane','candle','alarm clock' ] # one sketchrnn_<class>.npz per entry is expected in data_location
        self.data_location = '/content/sketch_data'  # <<<--- POINT TO YOUR DATA DIRECTORY
        self.log_dir = '/content/cRNN_logs' # <<<--- TensorBoard logs for this model
        self.checkpoint_dir = '/content/cRNN_checkpoints' # <<<--- Checkpoints for this model
        self.max_seq_length = 200 # Max sequence length to keep (inclusive)
        self.min_seq_length = 10  # Min sequence length to keep (inclusive)

        # Model Architecture
        self.dec_hidden_size = 512 # Decoder RNN hidden size
        self.M = 20                # Number of mixture components
        self.class_embedding_size = 64 # Size of the learned class embedding vector

        # Training Parameters
        self.batch_size = 100
        self.lr = 0.001
        self.scheduler_factor = 0.5 # Factor to reduce LR by
        # Number of scheduler.step() calls without improvement before the LR is reduced.
        # The training loop steps the scheduler only when it validates (every
        # epochs_til_validation epochs), so this counts validation rounds, not epochs.
        self.scheduler_patience = 10
        self.grad_clip = 1.0       # Gradient-norm clipping threshold
        self.temperature = 0.4     # Sampling temperature
        # NOTE: ConditionalDecoderRNN uses a single-layer nn.LSTM; PyTorch's LSTM dropout only
        # acts between stacked layers, so this value is currently not applied.
        self.dropout = 0.5

        # Data Splitting (fractions of the full dataset)
        self.validation_split = 0.15
        self.test_split = 0.15

        # Training Control
        self.num_epochs = 20 # <<<--- Adjust epochs
        self.epochs_til_checkpoint = 500  # > num_epochs, so only "best model" checkpoints are written
        self.epochs_til_validation = 20

        # Apply caller overrides last, rejecting names that are not existing hyperparameters.
        for name, value in overrides.items():
            if not hasattr(self, name):
                raise AttributeError(f"Unknown hyperparameter: {name!r}")
            setattr(self, name, value)


In [None]:
# Global hyperparameter object; later cells and helpers (e.g. purify) read it as `hp`.
hp = HParams()

## Preprocessing & Loading Data

In [None]:
def max_size(data):
    """Return the length of the longest sequence in `data`, or 0 if `data` is empty."""
    return max((len(seq) for seq in data), default=0)

def purify(strokes, min_len=None, max_len=None, clip=1000):
    """Keep sequences whose length is within bounds, clip their values, and cast to float32.

    Args:
        strokes: iterable of 2-D arrays (one row per point).
        min_len: inclusive minimum number of rows; defaults to the global hp.min_seq_length.
        max_len: inclusive maximum number of rows; defaults to the global hp.max_seq_length.
        clip: values are clipped to [-clip, clip] to tame outlier offsets (1000 by default).

    Returns:
        list of float32 copies of the kept sequences, in input order.
    """
    # Only consult the notebook-global `hp` when the caller did not pass explicit bounds.
    if min_len is None:
        min_len = hp.min_seq_length
    if max_len is None:
        max_len = hp.max_seq_length
    return [
        np.clip(seq, -clip, clip).astype(np.float32)
        for seq in strokes
        if min_len <= seq.shape[0] <= max_len
    ]

def calculate_normalizing_scale_factor(strokes):
    """Pooled standard deviation of every x- and y-offset (columns 0 and 1).

    Each non-empty sequence contributes its column 0 followed by its column 1.
    Returns 1.0 when there is nothing to measure.
    """
    pooled_columns = [
        column
        for seq in strokes
        if len(seq) > 0
        for column in (seq[:, 0], seq[:, 1])
    ]
    if not pooled_columns:
        return 1.0
    return np.std(np.concatenate(pooled_columns))

def normalize(strokes, scale_factor=None):
    """Divide the offset columns (0 and 1) of every sequence by a common scale factor.

    Args:
        strokes: list of (sequence, class_index) pairs.
        scale_factor: divisor to use. If None (default), it is computed from `strokes`
            with calculate_normalizing_scale_factor. Pass the training-set value to
            normalize validation/test data without leaking their statistics.

    Returns:
        (list of (normalized_copy, class_index) pairs, scale_factor actually used).
        Input arrays are not modified.
    """
    if scale_factor is None:
        scale_factor = calculate_normalizing_scale_factor([seq for seq, _ in strokes])
    # Guard against dividing by a (near-)zero spread.
    if scale_factor < 1e-6:
        scale_factor = 1.0

    normalized_data = []
    for seq, class_index in strokes:
        normalized_seq = seq.copy()
        normalized_seq[:, 0:2] /= scale_factor
        normalized_data.append((normalized_seq, class_index))
    return normalized_data, scale_factor

In [None]:
print("Loading and preprocessing data...")
all_data_raw = []
class_dict = {class_name: i for i, class_name in enumerate(hp.classes)}
num_classes = len(hp.classes)
os.makedirs(hp.data_location, exist_ok=True)
loaded_counts = {}  # class name -> number of sequences kept after purify()

for class_name in hp.classes:
    file_path = os.path.join(hp.data_location, f'sketchrnn_{class_name}.npz')
    loaded_count = 0
    try:
        # SECURITY: allow_pickle=True lets np.load unpickle object arrays, which can execute
        # code -- only load .npz files from a trusted source. The `with` block closes the
        # underlying zip file handle, which the bare np.load() call leaked.
        with np.load(file_path, encoding='latin1', allow_pickle=True) as dataset:
            class_index = class_dict[class_name]
            for split in ['train', 'valid', 'test']:
                if split in dataset:
                    purified_data = purify(dataset[split])
                    all_data_raw.extend((seq, class_index) for seq in purified_data)
                    loaded_count += len(purified_data)
        loaded_counts[class_name] = loaded_count
        print(f"Loaded {loaded_count} sequences for class '{class_name}'")
    except FileNotFoundError:
        print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
        print(f"Warning: File not found for class {class_name} at {file_path}")
        print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
    except Exception as e:
        print(f"Error loading or processing file {file_path}: {e}")

if not all_data_raw:
    raise ValueError("No data loaded. Please check `hp.data_location` and ensure .npz files exist.")

# num_classes still counts every configured class, so classes without data get embedding
# rows that are never indexed during training; make that visible instead of silent.
classes_without_data = [name for name in hp.classes if loaded_counts.get(name, 0) == 0]
if classes_without_data:
    print(f"Warning: no sequences loaded for {classes_without_data}; "
          "their class embeddings will receive no training signal.")


Loading and preprocessing data...
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
Loaded 75000 sequences for class 'airplane'
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!


In [None]:
# Scale every (dx, dy) offset by one global standard deviation.
# NOTE(review): the scale factor is computed over ALL loaded data, before the
# train/validation/test split below, so held-out statistics leak into the
# normalisation; consider computing it on the training split only.
all_data_normalized, data_scale_factor = normalize(all_data_raw)
print(f"Total sequences after purification: {len(all_data_normalized)}")
print(f"Data normalization scale factor: {data_scale_factor}")

Total sequences after purification: 75000
Data normalization scale factor: 53.127464294433594


In [None]:
# Stratified two-stage split: hold out (validation + test) of the data, then divide
# that holdout between validation and test in proportion to the two fractions.
train_data, temp_data = train_test_split(
    all_data_normalized,
    test_size=hp.validation_split + hp.test_split,
    random_state=42,
    stratify=[label for _, label in all_data_normalized],
)
validation_data, test_data = train_test_split(
    temp_data,
    test_size=hp.test_split / (hp.validation_split + hp.test_split),
    random_state=42,
    stratify=[label for _, label in temp_data],
)
print(f"Training samples: {len(train_data)}")
print(f"Validation samples: {len(validation_data)}")
print(f"Test samples: {len(test_data)}")


Training samples: 52500
Validation samples: 11250
Test samples: 11250


In [None]:
# Longest sequence over every non-empty split; used as the fixed padded length for batches.
Nmax = max(
    (max_size([seq for seq, _ in split]) for split in (train_data, validation_data, test_data) if split),
    default=0,
)
print(f"Nmax (max sequence length after processing): {Nmax}")

Nmax (max sequence length after processing): 99


In [None]:
def make_batch(batch_size, dataset, Nmax, device):
    """Sample a random, padded batch in stroke-5 format.

    Draws min(batch_size, len(dataset)) distinct items from `dataset`, a list of
    (sequence, class_index) pairs whose sequences have columns [dx, dy, pen_lift].
    Each one becomes an (Nmax, 5) array [dx, dy, p1, p2, p3]. Rows past the
    sequence are zero except p3 = 1.

    Returns:
        (tensor of shape (Nmax, B, 5), list of true lengths, LongTensor of class indices),
        or (None, None, None) when `dataset` is empty.
    """
    n_items = len(dataset)
    if n_items == 0:
        return None, None, None
    chosen = np.random.choice(n_items, min(batch_size, n_items), replace=False)

    padded_seqs, seq_lengths, labels = [], [], []
    for idx in chosen:
        strokes3, label = dataset[idx]
        n = len(strokes3)
        lift = strokes3[:, 2]

        stroke5 = np.zeros((Nmax, 5), dtype=np.float32)
        stroke5[:n, :2] = strokes3[:, :2]   # offsets
        stroke5[:n - 1, 2] = 1 - lift[:-1]  # p1: pen stays down after this point
        stroke5[:n, 3] = lift               # p2: pen lifted after this point
        stroke5[n - 1:, 4] = 1              # p3: last real point and all padding
        stroke5[n - 1, 2:4] = 0             # last real point carries only p3

        padded_seqs.append(stroke5)
        seq_lengths.append(n)
        labels.append(label)

    batch_tensor = torch.from_numpy(np.stack(padded_seqs, axis=1)).to(device)
    return batch_tensor, seq_lengths, torch.tensor(labels, dtype=torch.long, device=device)

## Decoder RNN Model

In [None]:
class ConditionalDecoderRNN(nn.Module):
    def __init__(self, num_classes, hp):
        """
        Conditional RNN Decoder module.

        Args:
            num_classes (int): Number of rows in the class embedding table.
            hp (HParams): Hyperparameters; reads class_embedding_size, dec_hidden_size, M and dropout.
        """
        super(ConditionalDecoderRNN, self).__init__()
        self.hp = hp

        # Class embedding layer
        self.embedding = nn.Embedding(num_classes, hp.class_embedding_size)

        # Maps a class embedding to the initial LSTM (hidden, cell) pair: 2 * dec_hidden_size outputs.
        self.fc_init_hc = nn.Linear(hp.class_embedding_size, 2 * hp.dec_hidden_size)

        # Unidirectional LSTM; each step's input is 5 stroke features + the class embedding.
        num_layers = 1
        # nn.LSTM applies `dropout` only between stacked layers. With one layer it does
        # nothing and PyTorch emits a UserWarning, so only forward it when stacking.
        self.lstm = nn.LSTM(
            input_size=5 + hp.class_embedding_size,
            hidden_size=hp.dec_hidden_size,
            num_layers=num_layers,
            dropout=hp.dropout if num_layers > 1 else 0.0,
            batch_first=False # Input shape: (seq_len, batch, input_size)
        )

        # GMM head: 6 values per mixture component (pi, mu_x, mu_y, sigma_x, sigma_y, rho_xy)
        # followed by 3 pen-state logits.
        self.fc_params = nn.Linear(hp.dec_hidden_size, 6 * hp.M + 3)

        # Set module to training mode initially
        self.train()

    def forward(self, inputs, class_indices, hidden_cell=None):
        """
        Forward pass of the conditional decoder.

        Args:
            inputs (Tensor): (seq_len, batch_size, 5 + class_embedding_size); stroke features
                already concatenated with the class embedding.
            class_indices (Tensor): (batch_size,) class indices; only used to build the
                initial state when hidden_cell is None.
            hidden_cell (tuple, optional): (hidden, cell) from a previous call, each
                (1, batch_size, dec_hidden_size). None initialises them from the class.

        Returns:
            pi, mu_x, mu_y, sigma_x, sigma_y, rho_xy: each (seq_len, batch_size, M).
            q: (seq_len, batch_size, 3) pen-state probabilities.
            hidden, cell: final LSTM states.
        """
        if hidden_cell is None:
            embedded_init = self.embedding(class_indices)  # (batch_size, class_embedding_size)
            hidden_cell_init = torch.tanh(self.fc_init_hc(embedded_init))  # (batch_size, 2 * dec_hidden_size)
            hidden, cell = torch.split(hidden_cell_init, self.hp.dec_hidden_size, dim=1)
            # nn.LSTM expects (num_layers, batch_size, dec_hidden_size).
            hidden_cell = (hidden.unsqueeze(0).contiguous(), cell.unsqueeze(0).contiguous())

        outputs, (hidden, cell) = self.lstm(inputs, hidden_cell)

        # Flatten time and batch: (seq_len * batch_size, dec_hidden_size) -> GMM parameters.
        output_reshaped = outputs.view(-1, self.hp.dec_hidden_size)
        y = self.fc_params(output_reshaped)

        # Chunks of 6 columns: M per-component chunks, then a final 3-column pen chunk.
        len_out = outputs.size(0)
        batch_s = outputs.size(1)
        params = torch.split(y, 6, dim=1)
        params_mixture = torch.stack(params[:-1], dim=0)  # (M, seq_len * batch_size, 6)
        params_pen = params[-1]
        pi_logits, mu_x_raw, mu_y_raw, sigma_x_log, sigma_y_log, rho_xy_tanh = torch.split(params_mixture, 1, dim=2)

        pi = F.softmax(pi_logits.squeeze(2).transpose(0, 1), dim=-1).view(len_out, batch_s, self.hp.M)
        sigma_x = torch.exp(sigma_x_log.squeeze(2).transpose(0, 1)).view(len_out, batch_s, self.hp.M)
        sigma_y = torch.exp(sigma_y_log.squeeze(2).transpose(0, 1)).view(len_out, batch_s, self.hp.M)
        rho_xy = torch.tanh(rho_xy_tanh.squeeze(2).transpose(0, 1)).view(len_out, batch_s, self.hp.M)
        mu_x = mu_x_raw.squeeze(2).transpose(0, 1).contiguous().view(len_out, batch_s, self.hp.M)
        mu_y = mu_y_raw.squeeze(2).transpose(0, 1).contiguous().view(len_out, batch_s, self.hp.M)
        q = F.softmax(params_pen, dim=-1).view(len_out, batch_s, 3)

        return pi, mu_x, mu_y, sigma_x, sigma_y, rho_xy, q, hidden, cell

## Complete Model

In [None]:
class Model():
    """Trains, evaluates, checkpoints and samples from a ConditionalDecoderRNN.

    Besides its constructor arguments, this class uses names defined in other notebook
    cells: `ConditionalDecoderRNN`, `make_batch`, `writer` (TensorBoard SummaryWriter),
    `sample_bivariate_normal` and `make_image`.
    """

    def __init__(self, num_classes, hp, device):
        self.hp = hp
        self.device = device
        self.num_classes = num_classes

        # Only the decoder is needed (there is no encoder in this model).
        self.decoder = ConditionalDecoderRNN(num_classes, hp).to(device)

        # Materialise the parameters: Module.parameters() is a one-shot generator. Handing
        # the same generator to Adam and later to clip_grad_norm_ would leave the latter
        # with nothing to clip.
        self.params = list(self.decoder.parameters())
        self.optimizer = optim.Adam(self.params, hp.lr)
        # `verbose` is deprecated for LR schedulers in recent PyTorch; the current LR is
        # logged to TensorBoard every epoch in train_epoch instead.
        self.scheduler = ReduceLROnPlateau(self.optimizer, 'min', factor=hp.scheduler_factor,
                                           patience=hp.scheduler_patience)
        self.best_val_loss = float('inf')
        # Epoch passed to the most recent train_epoch call; evaluate() uses it as the
        # TensorBoard step when it is not given one explicitly.
        self._last_train_epoch = 0

    def _teacher_forcing_input(self, batch, class_indices, Nmax):
        """Builds decoder input: SOS followed by the first Nmax-1 strokes, each step
        concatenated with the class embedding. Returns (Nmax, B, 5 + class_embedding_size)."""
        batch_size = batch.size(1)
        sos = torch.tensor([0, 0, 1, 0, 0], device=self.device, dtype=torch.float32).repeat(batch_size, 1).unsqueeze(0)
        stroke_input = torch.cat([sos, batch[:Nmax - 1]], dim=0)              # (Nmax, B, 5)
        embedded = self.decoder.embedding(class_indices)                       # (B, EmbSize)
        embedded_expanded = embedded.unsqueeze(0).expand(Nmax, -1, -1)        # (Nmax, B, EmbSize)
        return torch.cat([stroke_input, embedded_expanded], dim=2)

    def make_target(self, batch, lengths, Nmax):
        """Prepares target tensors for the reconstruction loss.

        Args:
            batch: (Nmax, B, 5) padded stroke-5 batch (as produced by make_batch).
            lengths: true length of each sequence (B entries).
            Nmax: number of target steps.

        Returns:
            mask (Nmax, B): 1 for steps < min(length, Nmax), else 0.
            dx, dy (Nmax, B, M): target offsets repeated across the M components.
            p (Nmax, B, 3): target pen states.
        """
        target_batch = batch[:Nmax]
        steps = torch.arange(Nmax, device=self.device).unsqueeze(1)            # (Nmax, 1)
        seq_lens = torch.as_tensor(lengths, device=self.device).unsqueeze(0)   # (1, B)
        mask = (steps < seq_lens).to(torch.get_default_dtype())

        dx = target_batch[:, :, 0].unsqueeze(-1).repeat(1, 1, self.hp.M)
        dy = target_batch[:, :, 1].unsqueeze(-1).repeat(1, 1, self.hp.M)
        p = target_batch[:, :, 2:5]
        return mask, dx, dy, p

    def train_epoch(self, epoch, dataset, Nmax):
        """Runs one training epoch of len(dataset) // batch_size randomly sampled batches.

        Logs the average loss and the current learning rate to the global TensorBoard
        `writer`. Returns the average loss, or inf if there is not enough data for a batch.
        """
        self.decoder.train()
        self._last_train_epoch = epoch
        total_loss = 0
        num_batches = len(dataset) // self.hp.batch_size
        if num_batches == 0:
            print("Warning: Not enough data for a single batch in training set.")
            return float('inf')

        for _ in range(num_batches):
            batch, lengths, class_indices = make_batch(self.hp.batch_size, dataset, Nmax, self.device)
            if batch is None:
                continue
            decoder_input = self._teacher_forcing_input(batch, class_indices, Nmax)

            self.optimizer.zero_grad()
            # hidden_cell=None: the decoder initialises its state from the class embedding.
            pi, mu_x, mu_y, sigma_x, sigma_y, rho_xy, q, _, _ = self.decoder(decoder_input, class_indices)

            mask, dx, dy, p = self.make_target(batch, lengths, Nmax)
            loss = self.reconstruction_loss(mask, dx, dy, p, pi, mu_x, mu_y, sigma_x, sigma_y, rho_xy, q)

            loss.backward()
            nn.utils.clip_grad_norm_(self.params, self.hp.grad_clip)
            self.optimizer.step()

            total_loss += loss.item()

        avg_loss = total_loss / num_batches
        print(f"Epoch {epoch} | Train Loss: {avg_loss:.4f}")
        writer.add_scalar('Loss/Train', avg_loss, epoch)
        writer.add_scalar('Params/Learning_Rate', self.optimizer.param_groups[0]['lr'], epoch)

        return avg_loss

    def evaluate(self, dataset, Nmax, epoch=None):
        """Average teacher-forced reconstruction loss on `dataset` (no gradient).

        Args:
            dataset: list of (sequence, class_index) pairs.
            Nmax: padded sequence length.
            epoch: TensorBoard step for 'Loss/Validation'. Defaults to the epoch of the
                most recent train_epoch call instead of a notebook-global variable.

        NOTE: batches come from make_batch, which samples randomly, so this is a
        stochastic estimate rather than an exhaustive pass over `dataset`.
        """
        self.decoder.eval()
        total_loss = 0
        num_batches = len(dataset) // self.hp.batch_size
        if num_batches == 0:
            print("Warning: Not enough data for a single batch in validation set.")
            return float('inf')

        with torch.no_grad():
            for _ in range(num_batches):
                batch, lengths, class_indices = make_batch(self.hp.batch_size, dataset, Nmax, self.device)
                if batch is None:
                    continue
                decoder_input = self._teacher_forcing_input(batch, class_indices, Nmax)
                pi, mu_x, mu_y, sigma_x, sigma_y, rho_xy, q, _, _ = self.decoder(decoder_input, class_indices)

                mask, dx, dy, p = self.make_target(batch, lengths, Nmax)
                loss = self.reconstruction_loss(mask, dx, dy, p, pi, mu_x, mu_y, sigma_x, sigma_y, rho_xy, q)
                total_loss += loss.item()

        avg_loss = total_loss / num_batches
        print(f"--- Validation | Loss: {avg_loss:.4f} ---")
        step = epoch if epoch is not None else getattr(self, '_last_train_epoch', 0)
        writer.add_scalar('Loss/Validation', avg_loss, step)

        return avg_loss

    # --- Loss Function (SketchRNN reconstruction loss) ---
    def bivariate_normal_pdf(self, dx, dy, mu_x, mu_y, sigma_x, sigma_y, rho_xy):
        """Log-density of a correlated bivariate normal, with clamps for numerical safety."""
        sigma_x = torch.clamp(sigma_x, min=1e-5)
        sigma_y = torch.clamp(sigma_y, min=1e-5)
        rho_xy = torch.clamp(rho_xy, min=-1.0 + 1e-5, max=1.0 - 1e-5)
        norm1 = dx - mu_x
        norm2 = dy - mu_y
        s1s2 = sigma_x * sigma_y
        z = (norm1 / sigma_x)**2 + (norm2 / sigma_y)**2 - 2 * rho_xy * norm1 * norm2 / s1s2
        one_minus_rho_sq = torch.clamp(1.0 - rho_xy**2, min=1e-5)
        log_pdf = -z / (2 * one_minus_rho_sq) - torch.log(2 * np.pi * s1s2 * torch.sqrt(one_minus_rho_sq))
        return log_pdf

    def reconstruction_loss(self, mask, dx, dy, p, pi, mu_x, mu_y, sigma_x, sigma_y, rho_xy, q):
        """Masked mean negative log-likelihood of offsets (GMM) plus pen states (categorical)."""
        log_pdf_vals = self.bivariate_normal_pdf(dx, dy, mu_x, mu_y, sigma_x, sigma_y, rho_xy)
        log_pi = torch.log(torch.clamp(pi, min=1e-5))
        # log sum_k pi_k N_k, computed stably in log space.
        log_likelihood_stroke = torch.logsumexp(log_pi + log_pdf_vals, dim=2)
        log_pen_likelihood = torch.sum(p * torch.log(torch.clamp(q, min=1e-5)), dim=2)
        masked_log_likelihood = mask * (log_likelihood_stroke + log_pen_likelihood)
        total_elements = torch.sum(mask)
        if total_elements == 0:
            # Zero that stays attached to the graph, so loss.backward() still works.
            return masked_log_likelihood.sum() * 0.0
        return -torch.sum(masked_log_likelihood) / total_elements

    # --- Checkpointing (single decoder model) ---
    def save_checkpoint(self, epoch, is_best=False):
        """Saves decoder/optimizer/scheduler state for `epoch`; also writes the 'best' file if is_best."""
        os.makedirs(self.hp.checkpoint_dir, exist_ok=True)
        filename = os.path.join(self.hp.checkpoint_dir, f"cRNN_checkpoint_epoch_{epoch}.pth")
        save_content = {
            'epoch': epoch,
            'decoder_state_dict': self.decoder.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            'best_val_loss': self.best_val_loss,
            'hp': vars(self.hp)
        }
        torch.save(save_content, filename)
        print(f"Checkpoint saved to {filename}")
        if is_best:
            best_filename = os.path.join(self.hp.checkpoint_dir, "cRNN_best_model.pth")
            torch.save(save_content, best_filename)
            print(f"*** Best validation model saved to {best_filename} ***")

    def load_checkpoint(self, filename="cRNN_best_model.pth"):
        """Restores state from checkpoint_dir/filename.

        Returns the epoch to resume from (saved epoch + 1), or 0 if the file is missing
        or cannot be loaded.
        """
        filepath = os.path.join(self.hp.checkpoint_dir, filename)
        if not os.path.exists(filepath):
            print(f"Checkpoint file not found: {filepath}. Training from scratch.")
            return 0
        try:
            checkpoint = torch.load(filepath, map_location=self.device)
            self.decoder.load_state_dict(checkpoint['decoder_state_dict'])
            self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            if 'scheduler_state_dict' in checkpoint:
                self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
            self.best_val_loss = checkpoint.get('best_val_loss', float('inf'))
            start_epoch = checkpoint['epoch'] + 1
            print(f"Loaded checkpoint '{filepath}' (trained up to epoch {checkpoint['epoch']})")
            print(f"Best validation loss was: {self.best_val_loss}")
            return start_epoch
        except Exception as e:
            print(f"Error loading checkpoint {filepath}: {e}. Training from scratch.")
            return 0

    # --- Conditional Generation & Sampling (autoregressive decoder) ---
    def conditional_generation(self, epoch, class_index, Nmax, data_scale_factor, temp=None):
        """Autoregressively samples up to Nmax strokes for `class_index` and saves a PNG.

        `temp` (default hp.temperature) is used only for this call; hp.temperature is
        restored afterwards even if sampling raises.
        """
        self.decoder.eval()
        if temp is None:
            temp = self.hp.temperature

        original_temp = self.hp.temperature
        self.hp.temperature = temp  # read by sample_next_state

        seq_x, seq_y, seq_z = [], [], []
        try:
            with torch.no_grad():
                class_indices = torch.tensor([class_index], dtype=torch.long, device=self.device)
                embedded_step = self.decoder.embedding(class_indices).unsqueeze(0)  # (1, 1, EmbSize)
                hidden_cell = None  # the first decoder call initialises it from the class

                # Start with the SOS token.
                s = torch.tensor([0, 0, 1, 0, 0], device=self.device, dtype=torch.float32).view(1, 1, 5)
                for i in range(Nmax):
                    decoder_input_step = torch.cat([s, embedded_step], dim=2)  # (1, 1, 5 + EmbSize)

                    # One step; outputs are stored on self for sample_next_state.
                    self.pi, self.mu_x, self.mu_y, self.sigma_x, self.sigma_y, \
                    self.rho_xy, self.q, hidden, cell = \
                        self.decoder(decoder_input_step, class_indices, hidden_cell)
                    hidden_cell = (hidden, cell)

                    s, dx, dy, pen_down, end_drawing = self.sample_next_state()

                    # Denormalised per-step offsets.
                    seq_x.append(dx * data_scale_factor)
                    seq_y.append(dy * data_scale_factor)
                    seq_z.append(pen_down)

                    if end_drawing:
                        print(f"Generated sequence length: {i+1}")
                        break
        finally:
            self.hp.temperature = original_temp

        if seq_x:
            # make_image accumulates offsets itself, so pass raw per-step deltas; passing
            # cumulative positions would integrate them twice.
            sequence = np.stack([
                np.asarray(seq_x, dtype=np.float64),
                np.asarray(seq_y, dtype=np.float64),
                np.asarray(seq_z, dtype=np.float64),
            ]).T
            class_name = self.hp.classes[class_index]
            img_name = f"cRNN_epoch_{epoch}_class_{class_name}_temp_{temp:.2f}"
            make_image(sequence, img_name)

    def sample_next_state(self):
        """Samples the next stroke-5 state from the last single-step decoder outputs.

        Reads self.pi, self.mu_x, self.mu_y, self.sigma_x, self.sigma_y, self.rho_xy and
        self.q at index [0, 0] (batch of one, one time step) and self.hp.temperature.

        Returns:
            (next_state of shape (1, 1, 5), dx, dy, pen_down, end_drawing)
        """
        def adjust_temp(pi_pdf, temp):
            # Temperature-scale a categorical distribution in log space and renormalise stably.
            pi_pdf = np.log(pi_pdf + 1e-8) / temp
            pi_pdf -= np.max(pi_pdf)
            pi_pdf = np.exp(pi_pdf)
            pi_pdf /= np.sum(pi_pdf)
            return pi_pdf

        temp = self.hp.temperature
        # float64 keeps np.random.choice's "probabilities sum to 1" check comfortably satisfied.
        pi_probs = self.pi[0, 0, :].detach().cpu().numpy().astype(np.float64)
        q_probs = self.q[0, 0, :].detach().cpu().numpy().astype(np.float64)

        pi_idx = np.random.choice(self.hp.M, p=adjust_temp(pi_probs, temp))
        q_idx = np.random.choice(3, p=adjust_temp(q_probs, temp))

        # .item() copies out Python floats; on CPU, .numpy() would share memory with the
        # stored decoder outputs and any in-place arithmetic downstream would write through.
        mu_x = self.mu_x[0, 0, pi_idx].item()
        mu_y = self.mu_y[0, 0, pi_idx].item()
        sigma_x = self.sigma_x[0, 0, pi_idx].item()
        sigma_y = self.sigma_y[0, 0, pi_idx].item()
        rho_xy = self.rho_xy[0, 0, pi_idx].item()

        dx, dy = sample_bivariate_normal(mu_x, mu_y, sigma_x, sigma_y, rho_xy, temp, greedy=False)

        next_state = torch.zeros(5, device=self.device)
        next_state[0] = dx
        next_state[1] = dy
        next_state[q_idx + 2] = 1

        pen_down = (q_idx == 0)
        end_drawing = (q_idx == 2)

        return next_state.view(1, 1, 5), dx, dy, pen_down, end_drawing

## Sampling and Plotting Helpers

In [None]:
def sample_bivariate_normal(mu_x, mu_y, sigma_x, sigma_y, rho_xy, temp, greedy=False):
    """Draw one (x, y) sample from a correlated bivariate normal.

    Standard deviations are scaled by sqrt(temp) and floored at 1e-4; rho_xy is
    clipped into (-1, 1). With greedy=True the mean is returned unchanged.
    On a LinAlgError the mean is returned after printing a warning.
    """
    if greedy:
        return mu_x, mu_y
    # Rebind instead of `*=`: callers may pass 0-d numpy arrays that share memory with
    # model output tensors, and an in-place multiply would write through into them.
    scale = np.sqrt(temp)
    sigma_x = max(float(sigma_x) * scale, 1e-4)
    sigma_y = max(float(sigma_y) * scale, 1e-4)
    rho_xy = float(np.clip(rho_xy, -1 + 1e-4, 1 - 1e-4))

    mean = [mu_x, mu_y]
    cov_xy = rho_xy * sigma_x * sigma_y
    cov = [[sigma_x * sigma_x, cov_xy], [cov_xy, sigma_y * sigma_y]]
    try:
        x = np.random.multivariate_normal(mean, cov, 1)
        return x[0][0], x[0][1]
    except np.linalg.LinAlgError as e:
        print(f"Warning: LinAlgError in multivariate_normal: {e}. Returning mean.")
        return mu_x, mu_y

In [None]:
def make_image(sequence, name='output_sketch'):
    """Render a stroke sequence to `<name>.png`.

    Args:
        sequence: array of shape (N, >=3). Columns 0-1 are per-step (dx, dy) offsets,
            accumulated here into absolute positions. Column 2 is a pen-down flag;
            a value < 0.5 ends the current stroke after that point.
        name: output path without the '.png' extension.
    """
    if sequence is None or len(sequence) == 0 or sequence.shape[1] < 3:
        print("Warning: Cannot make image from empty or invalid sequence.")
        return
    pen_lift_indices = np.where(sequence[:, 2] < 0.5)[0]
    strokes = np.split(sequence, pen_lift_indices + 1)

    fig, ax = plt.subplots(figsize=(6, 6))
    ax.set_aspect('equal', adjustable='box')
    cumulative_pos = np.zeros(2)
    for s in strokes:
        if len(s) == 0:
            continue
        points = np.cumsum(s[:, :2], axis=0) + cumulative_pos
        # Assumes screen-style offsets (y grows downward, as in QuickDraw/sketch-rnn data).
        # Negating y once draws the sketch upright. The previous extra invert_yaxis()
        # cancelled this negation and rendered sketches upside down.
        ax.plot(points[:, 0], -points[:, 1], 'k-')
        cumulative_pos = points[-1]
    ax.axis('off')

    save_filename = name + '.png'
    try:
        fig.savefig(save_filename, bbox_inches='tight', pad_inches=0.1)
        print(f"Image saved to {save_filename}")
    except Exception as e:
        print(f"Error saving image {save_filename}: {e}")
    plt.close(fig)

## Training the Model

In [None]:
# Build the decoder on `device` together with its Adam optimizer and ReduceLROnPlateau scheduler.
model = Model(num_classes=num_classes, hp=hp, device=device)



In [None]:
# TensorBoard writer. Model.train_epoch and Model.evaluate log through this global name,
# so this cell must run before the training loop.
writer = SummaryWriter(hp.log_dir)

In [None]:
for epoch in range(hp.num_epochs):
    # One epoch of randomly sampled training batches.
    train_loss = model.train_epoch(epoch, train_data, Nmax)

    # Validate every `epochs_til_validation` epochs and always after the last epoch.
    should_validate = (epoch + 1) % hp.epochs_til_validation == 0 or epoch == hp.num_epochs - 1
    if not should_validate:
        continue

    val_loss = model.evaluate(validation_data, Nmax)
    # The scheduler is stepped only here, so its patience counts validation rounds, not epochs.
    model.scheduler.step(val_loss)

    is_best = val_loss < model.best_val_loss
    if is_best:
        model.best_val_loss = val_loss
        print(f"*** New best validation loss: {val_loss:.4f} at epoch {epoch} ***")

    # Save on the periodic schedule, and whenever validation improves (which also writes the 'best' file).
    if is_best or (epoch + 1) % hp.epochs_til_checkpoint == 0:
        model.save_checkpoint(epoch, is_best=is_best)


Epoch 0 | Train Loss: 0.8402
Epoch 1 | Train Loss: 0.5056
Epoch 2 | Train Loss: 0.4167
Epoch 3 | Train Loss: 0.3660
Epoch 4 | Train Loss: 0.3296
Epoch 5 | Train Loss: 0.3018
Epoch 6 | Train Loss: 0.2789
Epoch 7 | Train Loss: 0.2597
Epoch 8 | Train Loss: 0.2442
Epoch 9 | Train Loss: 0.2325
Epoch 10 | Train Loss: 0.2235
Epoch 11 | Train Loss: 0.2115
Epoch 12 | Train Loss: 0.2026
Epoch 13 | Train Loss: 0.1964
Epoch 14 | Train Loss: 0.1876
Epoch 15 | Train Loss: 0.1802
Epoch 16 | Train Loss: 0.1731
Epoch 17 | Train Loss: 0.1659
Epoch 18 | Train Loss: 0.1629
Epoch 19 | Train Loss: 0.1563
--- Validation | Loss: 0.2124 ---
*** New best validation loss: 0.2124 at epoch 19 ***
Checkpoint saved to /content/cRNN_checkpoints/cRNN_checkpoint_epoch_19.pth
*** Best validation model saved to /content/cRNN_checkpoints/cRNN_best_model.pth ***


In [None]:
# NOTE: load_checkpoint returns the epoch to resume from (saved epoch + 1), or 0 when nothing
# was loaded -- it is not a boolean success flag despite this variable's name.
load_success = model.load_checkpoint("cRNN_best_model.pth")

Loaded checkpoint '/content/cRNN_checkpoints/cRNN_best_model.pth' (trained up to epoch 19)
Best validation loss was: 0.2123752439948084


## Sketch Generation Function

In [None]:
# --- Animation Function (Conditional RNN version - vertically flipped output) ---
def generate_drawing_animation_cRNN(model, class_index, Nmax, data_scale_factor, num_frames=None, save_path='cRNN_drawing_animation_flipped.gif', temperature=0.4): # Changed default save_path
    """
    Generate a GIF animation of one sketch sampled from the conditional RNN.

    The sketch is rendered vertically flipped by inverting the y axis. Unlike
    flipping the saved image, this keeps the frame title upright and readable.

    Args:
        model: Trained model exposing ``decoder`` (with an ``embedding`` layer),
            ``hp``, ``device`` and ``sample_next_state()``.
        class_index (int): Index into ``model.hp.classes`` to condition on.
        Nmax (int): Maximum number of steps to sample.
        data_scale_factor (float): Multiplier applied to each sampled (dx, dy) to
            undo the dataset normalization.
        num_frames (int, optional): If given and smaller than the number of
            generated steps, frames are subsampled evenly down to this count.
        save_path (str): Output GIF path.
        temperature (float): Sampling temperature. ``model.hp.temperature`` is set
            to this while sampling and always restored afterwards.

    Returns:
        str | None: ``save_path`` on success, ``None`` if no sequence was
        generated or the GIF could not be saved.
    """
    model.decoder.eval()  # Set decoder to eval mode

    # Use model's hp, but override temperature for this generation only.
    hp = model.hp
    original_temp = hp.temperature
    hp.temperature = temperature

    seq_x, seq_y, seq_z = [], [], []
    try:
        # no_grad also covers the embedding lookup, so no autograd graph is built.
        with torch.no_grad():
            class_indices = torch.tensor([class_index], dtype=torch.long, device=model.device)
            # The class embedding is the same for every step, so build (1, 1, EmbSize) once.
            embedded_step = model.decoder.embedding(class_indices).unsqueeze(0)
            hidden_cell = None  # Let decoder initialize
            s = torch.tensor([0, 0, 1, 0, 0], device=model.device, dtype=torch.float32).view(1, 1, 5)  # SOS token

            for _ in range(Nmax):  # Generate up to Nmax steps
                decoder_input_step = torch.cat([s, embedded_step], dim=2)  # (1, 1, 5 + EmbSize)

                # sample_next_state() reads the outputs stored on these model attributes.
                (model.pi, model.mu_x, model.mu_y, model.sigma_x, model.sigma_y,
                 model.rho_xy, model.q, hidden, cell) = model.decoder(decoder_input_step, class_indices, hidden_cell)
                hidden_cell = (hidden, cell)  # Update for next step

                s, dx, dy, pen_down, end_drawing = model.sample_next_state()

                # Store denormalized stroke offsets and pen state (True if down).
                seq_x.append(dx * data_scale_factor)
                seq_y.append(dy * data_scale_factor)
                seq_z.append(pen_down)

                if end_drawing:
                    break
    finally:
        hp.temperature = original_temp  # Restore even if sampling raised

    if not seq_x:
        print("Animation failed: No sequence generated.")
        return None

    # --- Frame Generation and GIF Saving ---
    x_sample = np.cumsum(seq_x, 0)
    y_sample = np.cumsum(seq_y, 0)
    z_sample = np.asarray(seq_z, dtype=bool)  # bool so that `~` below is a logical NOT

    total_generated_frames = len(x_sample)
    frame_indices = np.arange(total_generated_frames)
    if num_frames is not None and 0 < num_frames < total_generated_frames:
        frame_indices = np.linspace(0, total_generated_frames - 1, num_frames, dtype=int)

    output_frames = len(frame_indices)
    print(f"Creating vertically flipped animation '{save_path}' with {output_frames} frames...")

    # Plot limits for the whole sketch, with a fixed pad when a dimension is flat.
    x_min, x_max = np.min(x_sample), np.max(x_sample)
    y_min, y_max = np.min(y_sample), np.max(y_sample)
    x_range = x_max - x_min
    y_range = y_max - y_min
    x_padding = x_range * 0.1 if x_range > 1e-4 else 1.0
    y_padding = y_range * 0.1 if y_range > 1e-4 else 1.0
    plot_xlim = (x_min - x_padding, x_max + x_padding)
    # Top < bottom inverts the y axis: this is the vertical flip.
    plot_ylim = (y_max + y_padding, y_min - y_padding)

    fig = plt.figure(figsize=(6, 6), dpi=100)
    ax = fig.add_subplot(111)
    FigureCanvas(fig)  # Attach an Agg canvas once; savefig() renders every frame
    frames = []
    try:
        for frame_num, idx in enumerate(frame_indices):
            ax.clear()
            stop = idx + 1
            # Start a new line after every point where the pen is up.
            breaks = np.flatnonzero(~z_sample[:stop]) + 1
            for seg_x, seg_y in zip(np.split(x_sample[:stop], breaks), np.split(y_sample[:stop], breaks)):
                if len(seg_x) > 1:
                    ax.plot(seg_x, seg_y, 'k-', linewidth=1.5)
                elif len(seg_x) == 1:
                    ax.plot(seg_x[0], seg_y[0], 'k.', markersize=2)

            # Re-apply appearance after clear() so it holds on every frame.
            ax.set_aspect('equal', adjustable='box')
            ax.set_xlim(plot_xlim)
            ax.set_ylim(plot_ylim)
            ax.axis('off')
            ax.set_title(f"Class: {hp.classes[class_index]} | Temp: {temperature:.2f} | Frame {frame_num+1}/{output_frames}")

            with io.BytesIO() as buf:
                fig.savefig(buf, format='png', bbox_inches='tight', pad_inches=0.1)
                buf.seek(0)
                frames.append(Image.open(buf).convert('RGB'))  # convert() loads pixels before buf closes
    finally:
        plt.close(fig)  # Free the figure even if rendering fails

    if not frames:
        print("No frames generated for GIF.")
        return None

    # --- Save GIF ---
    duration_ms = max(20, 1000 // 30)  # ~30 fps, at least 20 ms per frame
    try:
        frames[0].save(
            save_path,
            format='GIF',
            append_images=frames[1:],
            save_all=True,
            duration=duration_ms,
            loop=0
        )
        print(f"Animation saved to {save_path}")
        return save_path
    except Exception as e:
        print(f"Error saving GIF {save_path}: {e}")
        return None

In [None]:
# Reload the best checkpoint, then render one animation per known class.
load_success = model.load_checkpoint("cRNN_best_model.pth")
if not load_success:
    print("Could not load model.")
else:
    print("Generating final animations using the best model...")
    for class_idx, class_name in enumerate(hp.classes):
        print(f"--- Generating animation for: {class_name} ---")
        output_gif_path = f"cRNN_{class_name}_final_animation_FLIPPED.gif"
        generate_drawing_animation_cRNN(
            model=model,
            class_index=class_idx,
            Nmax=Nmax,
            data_scale_factor=data_scale_factor,
            num_frames=200,
            save_path=output_gif_path,
            temperature=0.35,
        )

Loaded checkpoint '/content/cRNN_checkpoints/cRNN_best_model.pth' (trained up to epoch 19)
Best validation loss was: 0.2123752439948084
Generating final animations using the best model...
--- Generating animation for: cat ---
Creating vertically flipped animation 'cRNN_cat_final_animation_FLIPPED.gif' with 52 frames...
Animation saved to cRNN_cat_final_animation_FLIPPED.gif
--- Generating animation for: apple ---
Creating vertically flipped animation 'cRNN_apple_final_animation_FLIPPED.gif' with 99 frames...
Animation saved to cRNN_apple_final_animation_FLIPPED.gif
--- Generating animation for: airplane ---
Creating vertically flipped animation 'cRNN_airplane_final_animation_FLIPPED.gif' with 99 frames...
Animation saved to cRNN_airplane_final_animation_FLIPPED.gif
--- Generating animation for: candle ---
Creating vertically flipped animation 'cRNN_candle_final_animation_FLIPPED.gif' with 29 frames...
Animation saved to cRNN_candle_final_animation_FLIPPED.gif
--- Generating animation f

## Bonus Task 1: Drawing Several Classes Simultaneously

In [None]:
def generate_drawing_animation_cRNN_bonus(model, class_indices, Nmax, data_scale_factor, num_frames=None, save_path='cRNN_parallel_animation_flipped.gif', temperature=0.4):
    """
    Generate one GIF in which several class-conditioned sketches are drawn simultaneously.

    One sketch is sampled per entry of ``class_indices`` and shown in its own
    subplot of a near-square grid. Every sketch advances one step per frame;
    sketches that finish early keep showing their final state. Each subplot is
    vertically flipped by inverting its y axis, so subplot titles stay upright.

    Args:
        model: Trained model exposing ``decoder`` (with an ``embedding`` layer),
            ``hp``, ``device`` and ``sample_next_state()``.
        class_indices (list[int]): Classes to draw, one subplot each.
        Nmax (int): Maximum number of steps sampled per sketch.
        data_scale_factor (float): Multiplier applied to each sampled (dx, dy).
        num_frames (int, optional): If given and smaller than the longest sketch,
            frames are subsampled evenly down to this count.
        save_path (str): Output GIF path.
        temperature (float): Sampling temperature. ``model.hp.temperature`` is set
            to this while sampling and always restored afterwards.

    Returns:
        str | None: ``save_path`` on success, ``None`` if nothing was generated
        or the GIF could not be saved.
    """
    model.decoder.eval()
    hp = model.hp

    def _sample_sketch(class_index):
        """Sample one sketch; return absolute x, absolute y and boolean pen-down arrays."""
        seq_x, seq_y, seq_z = [], [], []
        with torch.no_grad():
            class_indices_tensor = torch.tensor([class_index], dtype=torch.long, device=model.device)
            embedded_step = model.decoder.embedding(class_indices_tensor).unsqueeze(0)  # same every step
            hidden_cell = None
            s = torch.tensor([0, 0, 1, 0, 0], device=model.device, dtype=torch.float32).view(1, 1, 5)
            for _ in range(Nmax):
                decoder_input_step = torch.cat([s, embedded_step], dim=2)
                # sample_next_state() reads the outputs stored on these model attributes.
                (model.pi, model.mu_x, model.mu_y, model.sigma_x, model.sigma_y,
                 model.rho_xy, model.q, hidden, cell) = model.decoder(decoder_input_step, class_indices_tensor, hidden_cell)
                hidden_cell = (hidden, cell)
                s, dx, dy, pen_down, end_drawing = model.sample_next_state()
                seq_x.append(dx * data_scale_factor)
                seq_y.append(dy * data_scale_factor)
                seq_z.append(pen_down)
                if end_drawing:
                    break
        return np.cumsum(seq_x, 0), np.cumsum(seq_y, 0), np.asarray(seq_z, dtype=bool)

    def _padded_limits(values):
        """Return (lo, hi) with 10% padding, or a unit pad when the data is flat."""
        if len(values) == 0:
            return -1.0, 1.0
        lo, hi = np.min(values), np.max(values)
        span = hi - lo
        pad = span * 0.1 if span > 1e-4 else 1.0
        return lo - pad, hi + pad

    def _draw_strokes(ax, xs, ys, pen_down):
        """Plot the strokes, starting a new line after every pen-up point."""
        breaks = np.flatnonzero(~pen_down) + 1
        for seg_x, seg_y in zip(np.split(xs, breaks), np.split(ys, breaks)):
            if len(seg_x) > 1:
                ax.plot(seg_x, seg_y, 'k-', linewidth=1.5)
            elif len(seg_x) == 1:
                ax.plot(seg_x[0], seg_y[0], 'k.', markersize=2)

    # Pre-generate all sequences first; the temperature override is undone even on error.
    original_temp = hp.temperature
    hp.temperature = temperature
    all_sequences = []
    try:
        for class_index in class_indices:
            x_sample, y_sample, z_sample = _sample_sketch(class_index)
            all_sequences.append({
                'x': x_sample,
                'y': y_sample,
                'z': z_sample,
                'class_name': hp.classes[class_index],
                'length': len(x_sample),
                'xlim': _padded_limits(x_sample),  # fixed per sketch, so computed once
                'ylim': _padded_limits(y_sample),
            })
    finally:
        hp.temperature = original_temp

    max_length = max((seq['length'] for seq in all_sequences), default=0)
    if max_length == 0:
        print("No valid sequences generated for any class")
        return None

    # Determine frame indices based on longest sequence
    frame_indices = np.arange(max_length)
    if num_frames is not None and 0 < num_frames < max_length:
        frame_indices = np.linspace(0, max_length - 1, num_frames, dtype=int)

    # Near-square subplot grid
    n_classes = len(class_indices)
    n_cols = int(np.ceil(np.sqrt(n_classes)))
    n_rows = int(np.ceil(n_classes / n_cols))

    fig = plt.figure(figsize=(6 * n_cols, 6 * n_rows), dpi=100)
    FigureCanvas(fig)  # Attach an Agg canvas once; savefig() renders every frame
    axes = [fig.add_subplot(n_rows, n_cols, i + 1) for i in range(n_classes)]

    frames = []
    print(f"Creating parallel animation with {len(frame_indices)} frames...")
    try:
        for frame_idx in frame_indices:
            for seq, ax in zip(all_sequences, axes):
                ax.clear()
                current_idx = min(frame_idx, seq['length'] - 1)  # finished sketches stay on their last step
                stop = current_idx + 1
                _draw_strokes(ax, seq['x'][:stop], seq['y'][:stop], seq['z'][:stop])

                # Re-apply appearance after clear() so it holds on every frame.
                ax.set_aspect('equal', adjustable='box')
                ax.axis('off')
                ax.set_xlim(seq['xlim'])
                ax.set_ylim(seq['ylim'][::-1])  # inverted y axis = vertical flip
                ax.set_title(f"{seq['class_name']}\nStep: {current_idx+1}/{seq['length']}")

            # Save combined frame
            with io.BytesIO() as buf:
                fig.savefig(buf, format='png', bbox_inches='tight', pad_inches=0.1)
                buf.seek(0)
                frames.append(Image.open(buf).convert('RGB'))
    finally:
        plt.close(fig)  # Free the figure even if rendering fails

    if not frames:
        print("No frames generated for GIF.")
        return None

    duration = max(20, 1000 // 30)  # ~30 fps, at least 20 ms per frame
    try:
        frames[0].save(
            save_path,
            format='GIF',
            append_images=frames[1:],
            save_all=True,
            duration=duration,
            loop=0
        )
    except Exception as e:
        print(f"Error saving GIF {save_path}: {e}")
        return None
    print(f"Parallel animation saved to {save_path}")
    return save_path


In [None]:
# Sample the first four classes and animate them together in a 2x2 grid.
parallel_gif_path = generate_drawing_animation_cRNN_bonus(
    model=model,
    class_indices=list(range(4)),  # 4 classes
    Nmax=Nmax,
    data_scale_factor=data_scale_factor,
    num_frames=300,
    save_path="four_classes_parallel.gif",
    temperature=0.4,
)
parallel_gif_path


Creating parallel animation with 129 frames...
Parallel animation saved to four_classes_parallel.gif


'four_classes_parallel.gif'

## Bonus Task 2: Composing a Multi-Object Scene

In [None]:
def _generate_single_object_sequence(model, class_index, Nmax, data_scale_factor, temperature):
    """
    Sample the stroke sequence (dx, dy, pen_down) for a single object.

    Args:
        model: Trained model exposing ``decoder`` (with an ``embedding`` layer),
            ``hp``, ``device`` and ``sample_next_state()``.
        class_index (int): Index into ``model.hp.classes`` to condition on.
        Nmax (int): Maximum number of steps to sample.
        data_scale_factor (float): Multiplier applied to each sampled (dx, dy) to
            undo the dataset normalization.
        temperature (float): Sampling temperature. ``model.hp.temperature`` is set
            to this for the duration of the call and always restored afterwards,
            even if sampling raises.

    Returns:
        tuple[np.ndarray, np.ndarray, np.ndarray]: relative dx, relative dy, and a
        boolean pen-down array (True while the pen is down), one entry per sampled
        step. All three are empty when ``Nmax`` is 0.
    """
    model.decoder.eval()  # Set decoder to eval mode

    # Use model's hp, but override temperature for this generation only.
    hp = model.hp
    original_temp = hp.temperature
    hp.temperature = temperature

    seq_x_rel, seq_y_rel, seq_z_pen_down = [], [], []  # Relative dx, dy and pen state
    try:
        # no_grad also covers the embedding lookup, so no autograd graph is built.
        with torch.no_grad():
            class_indices = torch.tensor([class_index], dtype=torch.long, device=model.device)
            # The class embedding is the same for every step: (1, EmbSize) -> (1, 1, EmbSize) once.
            embedded_step = model.decoder.embedding(class_indices).unsqueeze(0)
            hidden_cell = None  # Let decoder initialize
            # Start with SOS token
            s = torch.tensor([0, 0, 1, 0, 0], device=model.device, dtype=torch.float32).view(1, 1, 5)

            for _ in range(Nmax):  # Generate up to Nmax steps
                decoder_input_step = torch.cat([s, embedded_step], dim=2)  # (1, 1, 5 + EmbSize)

                # sample_next_state() reads the outputs stored on these model attributes.
                (model.pi, model.mu_x, model.mu_y, model.sigma_x, model.sigma_y,
                 model.rho_xy, model.q, hidden, cell) = model.decoder(decoder_input_step, class_indices, hidden_cell)
                hidden_cell = (hidden, cell)  # Update for next step

                s, dx, dy, pen_down, end_drawing = model.sample_next_state()

                # Store denormalized *relative* stroke
                seq_x_rel.append(dx * data_scale_factor)
                seq_y_rel.append(dy * data_scale_factor)
                seq_z_pen_down.append(pen_down)

                if end_drawing:
                    break
    finally:
        hp.temperature = original_temp  # Restore even if sampling raised

    # Explicit bool dtype so callers can use `~` as a logical NOT, even when empty.
    return np.array(seq_x_rel), np.array(seq_y_rel), np.array(seq_z_pen_down, dtype=bool)


In [None]:
def generate_scene_animation(model, class_indices, Nmax, data_scale_factor,
                             num_frames=None, save_path='cRNN_scene_animation_flipped.gif',
                             temperature=0.4, spacing_factor=1.5, initial_offset=(0,0)):
    """
    Generate a scene animation with multiple objects drawn one after another.

    Each object is sampled independently with ``_generate_single_object_sequence``
    and placed to the right of the previous one. The drawing is rendered
    vertically flipped by inverting the y axis, so the frame title stays upright.

    Args:
        model (Model): The loaded trained model instance.
        class_indices (list[int]): Class indices of the objects to draw, in drawing order.
        Nmax (int): Max sequence length *per object*.
        data_scale_factor (float): Normalization scale factor.
        num_frames (int, optional): Total frames for the entire animation. If omitted
            (or not smaller than the stroke count), one frame per stroke is used,
            capped at 1000.
        save_path (str): Path to save the output GIF.
        temperature (float): Sampling temperature.
        spacing_factor (float): After an object, the next one starts at
            ``max_x + width * (spacing_factor - 1)`` of that object. A sketch that
            would extend left of that point is shifted right so objects do not overlap.
        initial_offset (tuple): Starting (x, y) for the first object.

    Returns:
        str | None: ``save_path`` on success, ``None`` if no strokes were generated
        or the GIF could not be saved.
    """
    print(f"Starting scene generation for classes: {[model.hp.classes[i] for i in class_indices]}")
    model.decoder.eval()
    hp = model.hp

    object_xs, object_ys, object_pens = [], [], []  # Per-object absolute coords and pen states
    drawn_class_indices = []  # Class index of every object actually added (skips excluded)
    current_offset_x, current_offset_y = initial_offset

    for i, class_idx in enumerate(class_indices):
        print(f"  Generating object {i+1}/{len(class_indices)}: {hp.classes[class_idx]}")
        seq_x_rel, seq_y_rel, seq_z = _generate_single_object_sequence(
            model, class_idx, Nmax, data_scale_factor, temperature
        )

        if len(seq_x_rel) == 0:
            print(f"  Warning: No sequence generated for object {hp.classes[class_idx]}. Skipping.")
            continue

        # Absolute coordinates for this object
        abs_x = np.cumsum(seq_x_rel) + current_offset_x
        abs_y = np.cumsum(seq_y_rel) + current_offset_y
        if drawn_class_indices:
            # Only the start point lands on the offset; push the object right if any
            # part of it would reach back into the previous object.
            abs_x = abs_x + max(0.0, current_offset_x - float(np.min(abs_x)))

        # Lift the pen on the object's last point so its final stroke is never
        # joined to the first point of the next object.
        pen_down = np.array(seq_z, dtype=bool)  # copy: do not mutate the returned array
        pen_down[-1] = False

        object_xs.append(abs_x)
        object_ys.append(abs_y)
        object_pens.append(pen_down)
        drawn_class_indices.append(class_idx)

        # --- Offset for the next object (simple horizontal placement) ---
        min_x, max_x = float(np.min(abs_x)), float(np.max(abs_x))
        current_offset_x = max_x + (max_x - min_x) * (spacing_factor - 1.0)
        print(f"    Object {hp.classes[class_idx]} added. Next offset: ({current_offset_x:.1f}, {current_offset_y:.1f})")

    if not object_xs:
        print("Animation failed: No strokes generated for the entire scene.")
        return None

    # --- Combine all objects into one stroke sequence ---
    all_abs_x = np.concatenate(object_xs)
    all_abs_y = np.concatenate(object_ys)
    all_pen_down = np.concatenate(object_pens)
    # object_start_indices[k] is where object k starts; the last entry is the total length.
    object_start_indices = np.concatenate([[0], np.cumsum([len(xs) for xs in object_xs])])
    total_scene_strokes = len(all_abs_x)

    # --- Frame selection ---
    max_default_frames = 1000  # Limit default frames for very long sequences
    frame_indices = np.arange(total_scene_strokes)
    if num_frames is not None and 0 < num_frames < total_scene_strokes:
        frame_indices = np.linspace(0, total_scene_strokes - 1, num_frames, dtype=int)
    elif total_scene_strokes > max_default_frames:
        frame_indices = np.linspace(0, total_scene_strokes - 1, max_default_frames, dtype=int)

    output_frames = len(frame_indices)
    print(f"Creating vertically flipped scene animation '{save_path}' with {output_frames} frames...")

    # Plot limits for the entire scene, with a fixed pad when a dimension is flat.
    x_min, x_max = np.min(all_abs_x), np.max(all_abs_x)
    y_min, y_max = np.min(all_abs_y), np.max(all_abs_y)
    x_range = x_max - x_min
    y_range = y_max - y_min
    x_padding = x_range * 0.1 if x_range > 1e-4 else 5.0
    y_padding = y_range * 0.1 if y_range > 1e-4 else 5.0
    plot_xlim = (x_min - x_padding, x_max + x_padding)
    # Top < bottom inverts the y axis: this is the vertical flip.
    plot_ylim = (y_max + y_padding, y_min - y_padding)

    fig = plt.figure(figsize=(8, 6), dpi=100)  # Wider figure for the scene
    ax = fig.add_subplot(111)
    FigureCanvas(fig)  # Attach an Agg canvas once; savefig() renders every frame
    frames = []
    try:
        for frame_num, master_idx in enumerate(frame_indices):
            # Object k satisfies object_start_indices[k] <= master_idx < object_start_indices[k+1].
            object_idx = int(np.searchsorted(object_start_indices, master_idx, side='right')) - 1
            current_class_name = hp.classes[drawn_class_indices[object_idx]]

            ax.clear()
            stop = master_idx + 1
            # Start a new line after every point where the pen is up.
            breaks = np.flatnonzero(~all_pen_down[:stop]) + 1
            for seg_x, seg_y in zip(np.split(all_abs_x[:stop], breaks), np.split(all_abs_y[:stop], breaks)):
                if len(seg_x) > 1:
                    ax.plot(seg_x, seg_y, 'k-', linewidth=1.5)

            # Re-apply appearance after clear() so it holds on every frame.
            ax.set_aspect('equal', adjustable='box')
            ax.set_xlim(plot_xlim)
            ax.set_ylim(plot_ylim)
            ax.axis('off')
            ax.set_title(f"Scene | Drawing: {current_class_name} | Temp: {temperature:.2f} | Frame {frame_num+1}/{output_frames}")

            with io.BytesIO() as buf:
                fig.savefig(buf, format='png', bbox_inches='tight', pad_inches=0.1)
                buf.seek(0)
                frames.append(Image.open(buf).convert('RGB'))
    finally:
        plt.close(fig)  # Free the figure even if rendering fails

    if not frames:
        print("No frames generated for scene GIF.")
        return None

    # --- Save GIF ---
    duration_ms = max(20, 1000 // 30)  # Aim for ~30 fps, min 20ms
    try:
        frames[0].save(
            save_path,
            format='GIF',
            append_images=frames[1:],
            save_all=True,
            duration=duration_ms,
            loop=0  # Loop indefinitely
        )
        print(f"Scene animation saved to {save_path}")
        return save_path
    except Exception as e:
        print(f"Error saving scene GIF {save_path}: {e}")
        return None

In [None]:
# Map each class name to the integer index the model is conditioned on.
class_to_idx = dict(zip(hp.classes, range(len(hp.classes))))
scene_objects = ['cat', 'apple', 'airplane']  # Example scene

In [None]:
# Resolve the requested scene objects to class indices, then render the scene.
try:
    scene_indices = [class_to_idx[name] for name in scene_objects]
except KeyError as e:
    # str(KeyError) is already quoted, so use the raw key to avoid printing ''name''.
    print(f"Error: Class '{e.args[0]}' not found in model's known classes: {hp.classes}")
    print("Cannot generate scene.")
    scene_indices = []  # Prevent further execution

if scene_indices:
    output_gif_path = "my_cool_scene_animation_FLIPPED.gif"
    scene_gif_path = generate_scene_animation(
        model=model,
        class_indices=scene_indices,
        Nmax=hp.max_seq_length,  # Use the hp value for max length per object
        data_scale_factor=data_scale_factor,  # Use the correct scale factor!
        num_frames=400,  # Total frames for the whole animation (adjust as needed)
        save_path=output_gif_path,
        temperature=0.3,  # Adjust temperature for desired randomness
        spacing_factor=1.3,  # Adjust spacing between objects
        initial_offset=(0, 0),  # Start drawing at origin
    )
    # generate_scene_animation returns None on failure; only report success when it succeeded.
    if scene_gif_path is None:
        print("\nScene generation failed.")
    else:
        print("\nScene generation complete.")

Starting scene generation for classes: ['cat', 'apple', 'airplane']
  Generating object 1/3: cat
    Object cat added. Next offset: (115.1, 0.0)
  Generating object 2/3: apple
    Object apple added. Next offset: (308.8, 0.0)
  Generating object 3/3: airplane
    Object airplane added. Next offset: (663.4, 0.0)
Creating vertically flipped scene animation 'my_cool_scene_animation_FLIPPED.gif' with 400 frames...
Scene animation saved to my_cool_scene_animation_FLIPPED.gif

Scene generation complete.
