<h1><center></center></h1>
<div style="display: flex; justify-content: center; margin: 0 auto;" align="center">
  <img src="https://myth-ai.com/wp-content/uploads/2023/05/646f153be1e56.png" href="https://myth-ai.com/" width="100px" align="center">
  <h1>Technical Assignment</h1>
</div>

<div align="center">
  <h2>
  Sketch Generation via Diffusion Models using Sequential Strokes
  </h2>
</div>


<div align="center">
  <img src="https://github.com/googlecreativelab/quickdraw-dataset/blob/master/preview.jpg?raw=true">
  <figcaption>
    Collection of 50 million drawings across 345 categories, contributed by players of the game Quick, Draw!. Drawings were captured as timestamped vectors.
    <i>Source: <a href="https://quickdraw.withgoogle.com/data/">Quick, Draw! Dataset</a>.</i>
  </figcaption>
</div>

---

## Objective

In this project, you are expected to implement a **conditional generative diffusion model** that learns to generate hand-drawn sketches in a **stroke-by-stroke** sequential manner. Rather than generating the entire sketch at once, your model should mimic the **sequential nature of human drawing**, producing strokes one after another in a realistic and interpretable way.

You will use the [Quick, Draw!](https://quickdraw.withgoogle.com/data/) dataset released by Google, which provides timestamped vector representations of user-drawn sketches across 345 object categories.

---

## Brief Explanation

You will design and train a **separate conditional diffusion model** for each of the following three categories:

- `cat`
- `bus`
- `rabbit`

Each model must learn to generate sketches from that category using **sequential stroke data**. That means you will build **three separate models** in total—one per category.

Your implementation must be documented in a reproducible Jupyter notebook, including training steps, visualizations, and both qualitative and quantitative evaluations.

- Include comprehensive documentation of your approach and design decisions.
- Provide clear training procedures, model architecture explanations, and inference code.
- Ensure full reproducibility (running all cells should yield consistent results with fixed random seeds).

---

## Data Specification

The Quick, Draw! dataset includes over 50 million sketches in vector format, with each sketch consisting of multiple strokes, where each stroke is a sequence of coordinates (`x`, `y`) along with timing information.

You can download the raw `.ndjson` files using the commands in this [section](#cell-id1). They download the required categories (`cat`, `bus`, `rabbit`) into the `../data` directory, which is the path the rest of this notebook reads from.

**⚠️ Note:** If you're not using Google Colab or Kaggle, make sure you have `gsutil` installed. You can install it via pip:

```bash
pip install gsutil
```

**⚠️ Important:** The dataset files are in [NDJSON](https://github.com/ndjson/ndjson-spec) format. Make sure to install the ndjson Python module before attempting to parse the files.

```bash
pip install ndjson
```

### Train/Test Subsets for Target Categories

After downloading the dataset into the `../data` directory, extract the provided `subset.zip` file (this notebook expects it at `../subset/`). This archive includes the predefined train/test splits for each of the three categories.

```
subset/
├── cat/
│  └── indices.json
├── bus/
│  └── indices.json
└── rabbit/
   └── indices.json
```

Each `indices.json` file contains a JSON structure with two keys:

- `"train"`: list of indices for training
- `"test"`: list of indices for testing

**⚠️ Important:** Strictly adhere to these predefined splits for consistent evaluation.


---


## Evaluation

You must evaluate your model both **qualitatively** and **quantitatively**.

### Quantitative Evaluation

Use the following metrics to compare the real test set sketches with those generated by your model:

- **FID (Fréchet Inception Distance)**
- **KID (Kernel Inception Distance)**

These metrics should be computed **separately for each category** using the sketches indexed under the `"test"` key in each category’s `indices.json` file.

> **Final submission must include three FID and three KID scores—one pair per category.**

### Qualitative Evaluation

Provide visual demonstrations including:

- Sample generated sketches for each category.
- Your submission must include three animated GIFs (one per category) showing the stroke-by-stroke generation process, similar to the provided `example.gif` file.
- Comparison between real and generated sketches.


---


## Deliverables

Your submission should include the following:

- A well-structured **Jupyter Notebook** that:
  - Explains your approach and design decisions
  - Implements the conditional diffusion model
  - Includes training procedure and inference pipeline code
  - Presents both qualitative and quantitative results
  - Visual examples of generated sketches for each of the 3 categories
  - Animated GIFs demonstrating progressive sketch generation (similar to the provided example.gif)
  - Clearly computed FID/KID scores for each category
- Model performance analysis across categories
- Comparison of generated vs. real sketch characteristics
- Discussion of limitations and potential improvements


> 🔒 All visualizations must be based on sketches generated by your own model. Using samples from external sources will be considered **plagiarism** and will result in disqualification.

> 🔁 The notebook must be **fully reproducible**: running all cells from top to bottom should produce the same results (assuming fixed random seed).

---

## Acknowledgements

- [The Quick, Draw! Dataset](https://github.com/googlecreativelab/quickdraw-dataset)
- [Quick, Draw! Kaggle Competition](https://www.kaggle.com/c/quickdraw-doodle-recognition/overview)
- [Diffusion Models Overview (Lil’Log)](https://lilianweng.github.io/posts/2021-07-11-diffusion-models/)
- [Ha, D., & Eck, D. (2017). A neural representation of sketch drawings. arXiv preprint arXiv:1704.03477.](https://arxiv.org/pdf/1704.03477)
- Special thanks to M. Sung, KAIST

# Download the Quick, Draw! Dataset

<a name="cell-id1"></a>

In [None]:
# If you're not using Colab or Kaggle, uncomment the following line:
# !pip install gsutil

In [None]:
%pip install ndjson

In [None]:
!mkdir -p ../data
!gsutil -m cp 'gs://quickdraw_dataset/full/simplified/cat.ndjson' ../data
!gsutil -m cp 'gs://quickdraw_dataset/full/simplified/bus.ndjson' ../data
!gsutil -m cp 'gs://quickdraw_dataset/full/simplified/rabbit.ndjson' ../data

# Solution

- Briefly explain why you chose the method you did.
- Discuss the drawbacks and advantages of your chosen method.
- Evaluate and discuss the results for each metric.

## Introduction

This notebook documents the technical development of the diffusion models. A separate report (general_report.pdf) covers related work and contains less technical detail.

In this assignment, diffusion models are designed to produce hand-drawn sketches in a sequential, stroke-by-stroke manner. To facilitate this, the raw Quick, Draw! data was converted into a 5D vector representation (Δx, Δy and a three-element one-hot pen state). Several architectures with different conditioning schemes were explored, including a Diffusion Transformer (DiT) and a DiT conditioned by an LSTM-based stroke history encoder.

<b>While the end-to-end training pipeline was successfully built, the models faced significant challenges. <u>Although the final visual outputs have not yet converged, this notebook provides a comprehensive record of the architectural design and of the training and evaluation process.</u></b>


The following sections discuss dataset processing, the diffusion models, the training strategy and the evaluation methods in detail.

## Dataset Processing

This project focuses on the Quick, Draw! dataset, which contains sketches from various categories. For this assignment, I will work with only three specific classes: cat, bus, and rabbit. Before delving into the diffusion model architecture and training strategy, I will use several helper functions to gain a deeper understanding of the data. The insights gathered here will be essential for the subsequent sections of this notebook.

In [None]:
# CONSTANTS

BATCH_SIZE = 4
CLASS_NAME = "cat" # cat, bus, rabbit are available
NDJSON_PATH = f"../data/{CLASS_NAME}.ndjson"
SET_INDICES_PATH = f"../subset/{CLASS_NAME}/indices.json"

The training and test splits are predefined in `indices.json`, so they are used directly for model training and evaluation.

In [None]:
# Read the data files
from src.utils.dataset_utils import read_ndjson_file, get_subset

dataset = read_ndjson_file(file_path=NDJSON_PATH)
train_set = get_subset(dataset=dataset, indices_json_path=SET_INDICES_PATH, subset_name="train")
test_set = get_subset(dataset=dataset, indices_json_path=SET_INDICES_PATH, subset_name="test")

print(f"Total Training Set Size: {len(train_set)}")
print(f"Total Testing Set Size: {len(test_set)}")
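The helpers `read_ndjson_file` and `get_subset` live in `src/utils/dataset_utils.py`, which is not shown here. The following is a minimal sketch of what such helpers might look like, assuming each line of the `.ndjson` file is one drawing record with a `"drawing"` field and that the `"train"`/`"test"` lists in `indices.json` are 0-based positions into that file. The names `read_ndjson_lines` and `select_subset` are hypothetical, and the sketch uses the standard-library `json` module instead of `ndjson`.

```python
import json
from typing import Any, Dict, List


def read_ndjson_lines(file_path: str) -> List[Dict[str, Any]]:
    """Parse one JSON object per non-empty line (hypothetical helper)."""
    with open(file_path, "r", encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]


def select_subset(dataset: List[Dict[str, Any]],
                  indices: Dict[str, List[int]],
                  subset_name: str) -> List[Any]:
    """Return the 'drawing' field of every record listed under subset_name."""
    # Assumption: indices are 0-based positions into the ndjson file.
    return [dataset[i]["drawing"] for i in indices[subset_name]]


# Usage with the paths defined above:
# dataset = read_ndjson_lines(NDJSON_PATH)
# with open(SET_INDICES_PATH) as f:
#     indices = json.load(f)
# train_set = select_subset(dataset, indices, "train")
```

Returning only the `"drawing"` field is a design choice of this sketch; the notebook's own `get_subset` may return full records.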

In [None]:
import random
from IPython.display import Image
from src.utils.dataset_utils import generate_gif

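# Randomly select a drawing from the training set and generate a GIF.
# Selecting by index (with a fixed seed) keeps the choice reproducible and lets
# the same drawing be re-rendered after the 5D round-trip below.
random.seed(42)
randomly_selected_drawing_idx = random.randrange(len(train_set))
randomly_selected_drawing = train_set[randomly_selected_drawing_idx]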

saved_path = generate_gif(drawing=randomly_selected_drawing,
                          output_path="./",
                          output_name="random_drawing.gif")

Image(filename=saved_path, width=300, height=300)

The Quick, Draw! dataset provides sketches as a sequence of strokes, with each stroke being an array of (x, y) coordinates. This structure makes it possible to render the final sketch sequentially, just as a human would draw it.

However, feeding these absolute coordinates directly into a deep learning network is not ideal, as they make the learning problem harder. The [SketchRNN](https://arxiv.org/abs/1704.03477) paper introduced a more effective 5D vector format. In this format, each point is represented by five elements: the first two (Δx, Δy) are the offset from the previous point, while the last three are a one-hot vector representing the pen's state. The third element is 1 if the pen is down (the stroke continues to the next point), the fourth is 1 if the pen is lifted after this point (end of a stroke), and the final element is 1 to signify the end of the entire drawing.

After this conversion, the dataset is properly structured to be fed into a deep learning network.
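A minimal NumPy sketch of the conversion described above, assuming the simplified Quick, Draw! layout in which each drawing is a list of strokes `[[x0, x1, ...], [y0, y1, ...]]`. Following SketchRNN's convention, the end-of-drawing flag is placed on an extra zero-offset row appended after the last point; the notebook's `convert_drawing_to_5d_format` may handle this differently. `strokes_to_stroke5` is a hypothetical name.

```python
import numpy as np


def strokes_to_stroke5(drawing):
    """Convert [[xs, ys], ...] into an (N + 1, 5) stroke-5 array."""
    rows = []
    for xs, ys in drawing:
        for i, (x, y) in enumerate(zip(xs, ys)):
            end_of_stroke = i == len(xs) - 1
            # (x, y, p1 = pen down, p2 = pen up after this point, p3 = end of drawing)
            rows.append([x, y,
                         0.0 if end_of_stroke else 1.0,
                         1.0 if end_of_stroke else 0.0,
                         0.0])
    points = np.asarray(rows, dtype=np.float32)

    stroke5 = points.copy()
    # Absolute coordinates -> offsets; the first point keeps its absolute position.
    stroke5[1:, :2] = points[1:, :2] - points[:-1, :2]
    # Append an end-of-drawing row with zero offset.
    eos = np.array([[0.0, 0.0, 0.0, 0.0, 1.0]], dtype=np.float32)
    return np.concatenate([stroke5, eos], axis=0)
```

For example, the two strokes `[[0, 10], [0, 0]]` and `[[10, 10], [5, 15]]` become the offsets (0, 0), (10, 0), (0, 5), (0, 10) followed by the end-of-drawing row.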


In [None]:
# Convert dataset to 5D format
from src.utils.dataset_utils import convert_drawing_to_5d_format

# (delta_x, delta_y, pen_down, pen_up, completed)
train_set_5d = convert_drawing_to_5d_format(train_set)
test_set_5d = convert_drawing_to_5d_format(test_set)

randomly_selected_drawing = random.choice(train_set_5d)
print(f"5D Format Shape: {randomly_selected_drawing.shape}")

Before diving into the training process, it's a good practice to verify that the data conversion works seamlessly. To do this, I have implemented a function that converts the 5D vector format back to the original coordinate format. This "round-trip" conversion test ensures that our preprocessing pipeline is reversible and correct before we proceed.
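The inverse mapping used by such a round-trip check can be sketched as a cumulative sum of the offsets, closing a stroke at every `p2 = 1` row and stopping at the first `p3 = 1` row. This is an illustrative stand-in for `convert_5d_to_raw_format`, whose implementation is not shown; the name `stroke5_to_strokes` is hypothetical.

```python
import numpy as np


def stroke5_to_strokes(stroke5):
    """Convert an (N, 5) stroke-5 array back to [[xs, ys], ...]."""
    coords = np.cumsum(stroke5[:, :2], axis=0)  # offsets -> absolute positions
    drawing, xs, ys = [], [], []
    for (x, y), (_, pen_up, end_of_drawing) in zip(coords, stroke5[:, 2:]):
        if end_of_drawing == 1:
            break
        xs.append(float(x))
        ys.append(float(y))
        if pen_up == 1:  # the pen lifts after this point: close the stroke
            drawing.append([xs, ys])
            xs, ys = [], []
    if xs:  # a trailing stroke without an explicit pen-up
        drawing.append([xs, ys])
    return drawing
```

Because Quick, Draw! coordinates are integers, a correct forward conversion followed by this inverse should reproduce them exactly, which is what the GIF comparison checks visually.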

In [None]:
# Test 5D converting format
from src.utils.dataset_utils import convert_5d_to_raw_format

raw_randomly_selected_drawing = convert_5d_to_raw_format(train_set_5d[randomly_selected_drawing_idx])
saved_path = generate_gif(drawing=raw_randomly_selected_drawing,
                          output_path="./",
                          output_name="regenerated_drawing.gif")

Image(filename=saved_path, width=300, height=300)

Finally, normalizing the data is a crucial preprocessing step that helps stabilize and accelerate the training of deep learning models. For this project, I will adopt the strategy from the [SketchRNN](https://arxiv.org/abs/1704.03477) paper. This approach involves calculating the standard deviation of all delta values (Δx, Δy) across the entire training set and then normalizing the data by dividing each value by this single statistic.

It is essential that the standard deviation is computed only from the training data. This same value must then be used to normalize the validation and test sets, and later to denormalize the model's output during inference. This practice prevents data leakage and ensures consistent scaling across all phases of the project.
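A minimal NumPy sketch of the SketchRNN-style scaling described above: a single standard deviation is computed over all Δx and Δy values of the training drawings pooled together, and only the offset columns are divided by it, leaving the pen-state columns untouched. `compute_delta_std`, `normalize` and `denormalize` are hypothetical names; the notebook's `calculate_normalization_params` also returns a mean, which this sketch does not subtract.

```python
import numpy as np


def compute_delta_std(train_drawings_5d):
    """One scalar std over every Δx and Δy in the training set."""
    deltas = np.concatenate([d[:, :2].reshape(-1) for d in train_drawings_5d])
    return float(np.std(deltas))


def normalize(drawing_5d, std):
    out = np.asarray(drawing_5d, dtype=np.float32).copy()
    out[:, :2] /= std  # scale offsets only; pen states stay one-hot
    return out


def denormalize(drawing_5d, std):
    out = np.asarray(drawing_5d, dtype=np.float32).copy()
    out[:, :2] *= std
    return out
```

The same `std` from the training split is then reused for the test split and for the model's outputs at inference time.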

In [None]:
# Calculate the mean and std of the dataset for normalization
from src.utils.dataset_utils import calculate_normalization_params

mean, std = calculate_normalization_params(train_set_5d)

print("Normalization Parameters for Training Dataset:")
print(f"Mean: {mean}, Std: {std}")

## Diffusion Model

In this section, I take a closer look at the Dataset objects and the diffusion model architectures. As mentioned before, none of the networks in this report was trained successfully, so I document each strategy I used while training them.

In [None]:
# Constants
import torch

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

### Unconditional/Conditional DiT Approach
In this section, I experimented with the same core architecture in different configurations. The main goal of this approach is to generate a complete sketch in a single sampling process.

#### QuickDraw Dataset Object
This custom Dataset object is initialized with the <i>drawings dataset</i>, the <i>maximum sequence length</i>, the <i>standard deviation of the training dataset</i>, a <i>random scale factor for data augmentation</i>, a <i>stroke augmentation probability</i> and a <i>limit value to clip outliers</i>. The dataset code can be reviewed in the <i>src/modules/dataset.py</i> file.

The data processing pipeline is as follows:
1. <b>Preprocessing</b>: First, drawings longer than max_seq_length are filtered out. Then, outlier delta values are clipped, and finally, the data is normalized using the pre-calculated standard deviation.
2. <b>`__getitem__`</b>: When a sample is requested, it is augmented at random according to the provided parameters. A special SOS (Start of Sequence) token is prepended to the drawing, which is then padded* with zeros to max_seq_length.

<b>The SOS token, [0, 0, 1, 0, 0], signals the start of a drawing.</b> This helps the model learn that the first real delta values represent absolute coordinates, a method inspired by the [SketchRNN](https://arxiv.org/abs/1704.03477) paper. The network first sees zero deltas with a "pen down" state, which sets the condition for the following stroke data.

I set the maximum sequence length to 96, following the [SketchKnitter](https://openreview.net/pdf?id=4eJ43EN2g6l) paper, which addresses the same task.

*Padding is essential for batching during model training, as it ensures all tensors in a batch have a uniform dimension. This allows the data to be processed in parallel on the GPU.
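The preprocessing and `__getitem__` steps above can be sketched in NumPy as follows. This is an illustration under stated assumptions, not the code in `src/modules/dataset.py`: offsets are clipped to `[-limit, limit]` before dividing by `std`, the augmentation shown is SketchRNN-style independent random scaling of x and y (the stroke-merging augmentation controlled by `augment_stroke_prob` is omitted), and the padded length is `max_seq_length + 1` so that the SOS row fits, matching the `96 + 1` used by `dit_sampling` later on. All function names are hypothetical.

```python
import numpy as np

SOS = np.array([0, 0, 1, 0, 0], dtype=np.float32)  # zero offsets, pen down


def preprocess(drawing_5d, std, limit=1000):
    """Clip outlier offsets, then normalize them by the training-set std."""
    out = np.asarray(drawing_5d, dtype=np.float32).copy()
    out[:, :2] = np.clip(out[:, :2], -limit, limit) / std
    return out


def random_scale(drawing_5d, random_scale_factor, rng):
    """Scale x and y offsets independently by factors in [1 - f, 1 + f]."""
    out = drawing_5d.copy()
    out[:, 0] *= rng.uniform(1 - random_scale_factor, 1 + random_scale_factor)
    out[:, 1] *= rng.uniform(1 - random_scale_factor, 1 + random_scale_factor)
    return out


def make_item(drawing_5d, max_seq_length=96):
    """Prepend SOS and zero-pad to max_seq_length + 1 rows."""
    seq = np.concatenate([SOS[None, :], drawing_5d], axis=0)
    padded = np.zeros((max_seq_length + 1, 5), dtype=np.float32)
    padded[: len(seq)] = seq  # drawings longer than max_seq_length were filtered out
    return padded
```

A training item would then be `make_item(random_scale(preprocess(d, std), 0.15, rng))` with, for example, `rng = np.random.default_rng(seed)`.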

In [None]:
# Create DataLoader for training and testing sets
import torch
from torch.utils.data import DataLoader
from src.modules.dataset import QuickDrawDataset


train_dataset = QuickDrawDataset(
    drawing_5d=train_set_5d,
    max_seq_length=96,         # Copied from the reference implementation
    std=std,                   # Training-set std, reused for every split
    random_scale_factor=0.15,  # Copied from the reference implementation
    augment_stroke_prob=0.10,  # Copied from the reference implementation
    limit=1000,                # Copied from the reference implementation
)

# No augmentation on the evaluation split, so validation losses are comparable
test_dataset = QuickDrawDataset(
    drawing_5d=test_set_5d,
    max_seq_length=96,
    std=std,
    random_scale_factor=0.0,
    augment_stroke_prob=0.0,
    limit=1000,
)

train_dataloader = DataLoader(
    dataset=train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    generator=torch.Generator().manual_seed(42),  # reproducible shuffling
)

test_dataloader = DataLoader(
    dataset=test_dataset,
    batch_size=BATCH_SIZE,
)

# Inspect one training batch
sample_batch = next(iter(train_dataloader))
print(f"Batch Shape: {sample_batch['stroke'].shape}")
print(f"Batch Sample: {sample_batch['stroke'][0, :5]}")

#### Diffusion Transformer Model
I chose a Diffusion Transformer (DiT) for this task because Transformer networks are exceptionally well-suited for sequential problems. Their self-attention mechanism is powerful for capturing the long-range dependencies between points and strokes, which is crucial for generating coherent sketches.

I experimented with two main variations of this architecture: an unconditional and a conditional approach. The model code can be reviewed in the <i>src/models/dit.py</i> file.

##### Unconditional DiT
In this approach, the model learns to generate a whole sketch at once from a noised sequence of 2D offset vectors (Δx, Δy). The architecture is designed with a multi-task learning objective:

1. The delta values are noised and fed into the Transformer backbone.

2. The model's final output layer is split into two separate "heads": one predicts the noise for the Δx, Δy values, and the other predicts the logits for the pen_state.

This design, inspired by the multi-headed outputs in models like [SketchKnitter](https://openreview.net/pdf?id=4eJ43EN2g6l), allows the network to learn both the continuous motion and discrete state predictions from a shared underlying representation.

![Unconditional Dit Architecture](assets/uncond_dit.png)
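A compact PyTorch sketch of the two-headed design: a shared Transformer encoder over the noised sequence, a sinusoidal timestep embedding added to every position, and two linear heads, one regressing the noise for (Δx, Δy) and one emitting 3-way pen-state logits. Layer sizes, the learned positional embedding and the class name `TinyTwoHeadDenoiser` are illustrative assumptions, not the architecture in `src/models/dit.py`. Setting `in_feats=5` gives an input layout like the pen-conditioned variant (noisy offsets concatenated with clean pen states).

```python
import math

import torch
import torch.nn as nn


def timestep_embedding(t: torch.Tensor, dim: int) -> torch.Tensor:
    """Sinusoidal embedding of (possibly rescaled) timesteps, shape (B, dim)."""
    half = dim // 2
    freqs = torch.exp(-math.log(10000.0) * torch.arange(half, dtype=torch.float32) / half)
    args = t.float()[:, None] * freqs[None, :].to(t.device)
    return torch.cat([torch.cos(args), torch.sin(args)], dim=-1)


class TinyTwoHeadDenoiser(nn.Module):
    """Shared Transformer backbone with a noise head and a pen-state head."""

    def __init__(self, in_feats=2, hidden=64, n_layers=2, n_heads=4, max_len=97, pen_feats=3):
        super().__init__()
        self.hidden = hidden
        self.in_proj = nn.Linear(in_feats, hidden)
        self.pos = nn.Parameter(torch.zeros(1, max_len, hidden))  # learned positions
        self.t_mlp = nn.Sequential(nn.Linear(hidden, hidden), nn.SiLU(), nn.Linear(hidden, hidden))
        layer = nn.TransformerEncoderLayer(hidden, n_heads, dim_feedforward=4 * hidden, batch_first=True)
        self.backbone = nn.TransformerEncoder(layer, n_layers)
        self.noise_head = nn.Linear(hidden, 2)        # noise for (dx, dy)
        self.pen_head = nn.Linear(hidden, pen_feats)  # logits for (p1, p2, p3)

    def forward(self, x, t):
        t_emb = self.t_mlp(timestep_embedding(t, self.hidden))
        h = self.in_proj(x) + self.pos[:, : x.shape[1]] + t_emb[:, None, :]
        h = self.backbone(h)
        return self.noise_head(h), self.pen_head(h)
```

Both heads read the same hidden states, which is what lets the motion and pen-state predictions share one representation.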

##### Conditional DiT

This approach frames the task as an inpainting problem, where the pen states are treated as a known "condition." The model's goal is to generate the correct movements (Δx, Δy) that correspond to this known sequence of actions.

During training, only the Δx, Δy values are noised, while the pen_state vectors are kept clean.

The model is fed both the noised deltas and the clean pen states as input.

The model's objective is to predict the noise only for the delta values, using the clean pen states as a guiding context.

This gives the model the explicit ability to decide how a stroke should look based on whether the pen is on the paper or being lifted.

![Conditional Dit Architecture](assets/cond_dit.png)
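For the conditional variant, the training input can be sketched as the standard DDPM forward process applied to the offsets only, $x_t = \sqrt{\bar\alpha_t}\,x_0 + \sqrt{1-\bar\alpha_t}\,\epsilon$, concatenated with the untouched pen states. This mirrors what `scheduler.add_noise` followed by `torch.cat` does in `single_step` further below; the linear β schedule from 1e-4 to 0.02 (the `diffusers` defaults) and the helper name `conditional_model_input` are assumptions for illustration.

```python
import torch


def conditional_model_input(stroke5, alphas_cumprod, t, noise):
    """Noise only (dx, dy); keep the 3 pen-state channels clean."""
    deltas, pen_state = stroke5[..., :2], stroke5[..., 2:]
    a_bar = alphas_cumprod[t].view(-1, 1, 1)  # one value per batch element
    noisy_deltas = a_bar.sqrt() * deltas + (1.0 - a_bar).sqrt() * noise
    return torch.cat([noisy_deltas, pen_state], dim=-1)


betas = torch.linspace(1e-4, 0.02, 100)  # linear schedule, 100 steps
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
```

The training target is then `noise` itself (noise prediction), and the MSE is computed on the two offset channels only.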


In [None]:
# Create Diffusion Model for First Approach
from src.models.dit import DiffusionTransformer

diffusion_model = DiffusionTransformer(
    input_feats=5,
    output_feats=2,
    pen_state_feats=2,
    pen_condition=True # Set to True for pen state conditioning
)
diffusion_model.to(DEVICE)


### Stroke History based Diffusion Transformer Model
In this section, I use cross-attention layers to condition the model on the stroke history. This is an autoregressive diffusion approach: in each generation step, a single stroke is sampled and appended to the stroke history. The Stroke History encoder is a separate LSTM-based network that extracts features from the history and conditions the DiT through cross-attention layers.


#### StrokeHistoryQuickDrawDataset Dataset Object
This custom Dataset object is initialized with the <i>drawings dataset</i>, the <i>maximum sequence length</i>, the <i>standard deviation of the training dataset</i>, a <i>random scale factor for data augmentation</i>, a <i>stroke augmentation probability</i> and a <i>limit value to clip outliers</i>. The dataset code can be reviewed in the <i>src/modules/dataset.py</i> file.

The data processing pipeline is as follows:
1. <b>Preprocessing</b>: First, drawings longer than max_seq_length are filtered out. Then, outlier delta values are clipped, and finally, the data is normalized using the pre-calculated standard deviation.
2. <b>`__getitem__`</b>: When a sample is requested, it is augmented at random according to the provided parameters, and a single stroke is selected from the drawing. The selected stroke is the input to the diffusion model, and the strokes that precede it are the input to the Stroke History model. In this case, the special SOS (Start of Sequence) token is prepended only to the stroke history. Finally, stroke_history and stroke are padded* with zeros to max_seq_length and max_single_stroke_length, respectively.

I set the maximum sequence length to 96, following the [SketchKnitter](https://openreview.net/pdf?id=4eJ43EN2g6l) paper, which addresses the same task. The maximum single-stroke length is computed after the preprocessing step.
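The history/target split can be sketched in NumPy as follows, assuming strokes are delimited by rows with `p2 = 1` and that stroke `k` is the target while strokes `0..k-1` (after SOS) form the history. Whether the notebook's `stroke_history_mask` marks real or padded rows is not shown; here 1 marks a real row. `split_strokes`, `make_history_item` and the placeholder `max_stroke_length=32` are hypothetical.

```python
import numpy as np

SOS = np.array([0, 0, 1, 0, 0], dtype=np.float32)


def split_strokes(drawing_5d):
    """Split a stroke-5 array into strokes, each ending at a pen-up (p2 = 1) row."""
    ends = np.where(drawing_5d[:, 3] == 1)[0]
    starts = np.concatenate([[0], ends[:-1] + 1])
    return [drawing_5d[s:e + 1] for s, e in zip(starts, ends)]


def make_history_item(drawing_5d, k, max_seq_length=96, max_stroke_length=32):
    """Stroke k is the target; SOS plus strokes 0..k-1 form the history."""
    strokes = split_strokes(drawing_5d)
    history = np.concatenate([SOS[None, :]] + strokes[:k], axis=0)
    target = strokes[k]

    history_padded = np.zeros((max_seq_length + 1, 5), dtype=np.float32)
    history_padded[: len(history)] = history
    history_mask = np.zeros(max_seq_length + 1, dtype=np.float32)
    history_mask[: len(history)] = 1.0  # 1 = real row, 0 = padding (assumed convention)

    stroke_padded = np.zeros((max_stroke_length, 5), dtype=np.float32)
    stroke_padded[: len(target)] = target
    return history_padded, history_mask, stroke_padded
```

During training, `k` would be drawn at random, e.g. `rng.integers(len(split_strokes(d)))`, and the real `max_single_stroke_length` would be `max(len(s) for d in data for s in split_strokes(d))`.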


In [None]:
# Create DataLoader for training and testing sets
import torch
from torch.utils.data import DataLoader
from src.modules.dataset import StrokeHistoryQuickDrawDataset


history_train_dataset = StrokeHistoryQuickDrawDataset(
    drawing_5d=train_set_5d,
    max_seq_length=96,         # Copied from the reference implementation
    std=std,                   # Training-set std, reused for every split
    random_scale_factor=0.15,  # Copied from the reference implementation
    augment_stroke_prob=0.10,  # Copied from the reference implementation
    limit=1000,                # Copied from the reference implementation
)

# No augmentation on the evaluation split, so validation losses are comparable
history_test_dataset = StrokeHistoryQuickDrawDataset(
    drawing_5d=test_set_5d,
    max_seq_length=96,
    std=std,
    random_scale_factor=0.0,
    augment_stroke_prob=0.0,
    limit=1000,
)

history_train_dataloader = DataLoader(
    dataset=history_train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    generator=torch.Generator().manual_seed(42),  # reproducible shuffling
)

history_test_dataloader = DataLoader(
    dataset=history_test_dataset,
    batch_size=BATCH_SIZE,
)

# Inspect one training batch
sample_batch = next(iter(history_train_dataloader))
print(f"Batch Shape: {sample_batch['stroke'].shape}")
print(f"Batch Sample: {sample_batch['stroke'][0, :5]}")

#### Stroke History Conditioned Diffusion Transformer
For the final experiment, I moved from a holistic generation model to a more advanced hierarchical and autoregressive architecture. This approach uses an additional network to encode the drawing history, providing a rich context to the Diffusion Transformer via cross-attention.

1. The stroke history is first fed into an LSTM-based StrokeHistoryEncoder. This network uses a bidirectional LSTM, which increases its power to extract complex relationships between the strokes by processing the sequence in both forward and reverse directions. A padding mask is provided to this encoder, which is used within its attention layers to ensure that the model does not form relationships between real stroke data and the meaningless padded areas.
2. The features extracted by the encoder, which represent a contextual summary of the drawing's past, are then fed as a condition to the DiffusionTransformer. The Transformer blocks in this model are equipped with cross-attention layers, a powerful and popular method for conditioning diffusion models on external information like text or, in this case, a drawing history.
3. The DiffusionTransformer's main task is to denoise the next target stroke. It takes the noisy stroke as input and, guided by the context from the history encoder, predicts the noise. Similar to the first unconditional approach, it uses two separate output heads to predict the noise for the delta values and the logits for the pen states.
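A compact PyTorch sketch of the three steps above: a bidirectional LSTM over the history, a self-attention layer that uses `key_padding_mask` so padded rows are never attended to, and a pre-norm Transformer block whose cross-attention reads the encoded history as keys and values. Sizes and the class names `TinyHistoryEncoder` and `CrossAttnBlock` are assumptions, not the modules in `src/models/history_encoder.py` or `src/models/dit.py`; the mask convention (1 marks a real row, 0 marks padding) is also assumed.

```python
import torch
import torch.nn as nn


class TinyHistoryEncoder(nn.Module):
    def __init__(self, in_feats=5, hidden=64, n_heads=4):
        super().__init__()
        # Bidirectional: hidden // 2 units per direction -> hidden features per step
        self.lstm = nn.LSTM(in_feats, hidden // 2, batch_first=True, bidirectional=True)
        self.attn = nn.MultiheadAttention(hidden, n_heads, batch_first=True)

    def forward(self, history, valid_mask):
        feats, _ = self.lstm(history)
        pad = valid_mask == 0  # True = padded position, ignored as a key
        out, _ = self.attn(feats, feats, feats, key_padding_mask=pad)
        return out, pad


class CrossAttnBlock(nn.Module):
    """Pre-norm block: self-attention, cross-attention to the history, MLP."""

    def __init__(self, hidden=64, n_heads=4):
        super().__init__()
        self.norm1 = nn.LayerNorm(hidden)
        self.norm2 = nn.LayerNorm(hidden)
        self.norm3 = nn.LayerNorm(hidden)
        self.self_attn = nn.MultiheadAttention(hidden, n_heads, batch_first=True)
        self.cross_attn = nn.MultiheadAttention(hidden, n_heads, batch_first=True)
        self.mlp = nn.Sequential(nn.Linear(hidden, 4 * hidden), nn.GELU(), nn.Linear(4 * hidden, hidden))

    def forward(self, x, context, context_pad):
        h = self.norm1(x)
        x = x + self.self_attn(h, h, h)[0]
        h = self.norm2(x)
        # Queries come from the noisy stroke, keys/values from the encoded history
        x = x + self.cross_attn(h, context, context, key_padding_mask=context_pad)[0]
        return x + self.mlp(self.norm3(x))
```

Because the SOS row is always present in the history, every query has at least one unmasked key, which avoids the all-masked case that would produce NaNs in attention.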

The primary drawback of this autoregressive architecture is its slow sampling speed. Because the model generates sketches sequentially—producing only one stroke at a time based on all previous strokes—the overall inference time scales linearly with the number of strokes in the final drawing. This iterative process is computationally expensive and makes generation significantly more time-consuming compared to holistic approaches that produce the entire sketch in a single pass.
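The linear cost can be made concrete with a schematic generation loop. `sample_stroke` is a hypothetical stand-in for one full reverse-diffusion pass of the history-conditioned model (`num_timesteps` denoiser calls). The loop appends each sampled stroke to the history and stops at an end-of-drawing flag or a stroke budget, so the total number of denoiser calls is the number of strokes times the number of timesteps.

```python
import numpy as np

SOS = np.array([[0, 0, 1, 0, 0]], dtype=np.float32)


def generate_autoregressively(sample_stroke, num_timesteps=100, max_strokes=20):
    """Schematic loop: one full diffusion sampling pass per stroke."""
    history = SOS.copy()
    denoiser_calls = 0
    for _ in range(max_strokes):
        stroke = sample_stroke(history)  # hypothetical: runs num_timesteps denoising steps
        denoiser_calls += num_timesteps
        history = np.concatenate([history, stroke], axis=0)
        if stroke[-1, 4] == 1:  # end-of-drawing flag on the last row
            break
    return history, denoiser_calls
```

For example, a 10-stroke sketch with 100 timesteps needs 1,000 denoiser forward passes, compared with 100 for a holistic model that denoises the whole sequence at once.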

![Stroke History Dit Architecture](assets/stroke_history_dit.png)


In [None]:
# Create the Stroke History conditioned model (second approach)
from src.models.history_encoder import StrokeHistoryEncoder
from src.models.dit import DiffusersBlockCrossAttentionTransformer

history_encoder = StrokeHistoryEncoder()
diffusion_model_wcrossattn = DiffusersBlockCrossAttentionTransformer(
    input_feats=2,
    output_feats=2,
    pen_state_feats=3,
    pen_condition=True # Set to True for pen state conditioning
)

history_encoder.to(DEVICE)
diffusion_model_wcrossattn.to(DEVICE)

## Model Training

This section describes my training strategy. The training functions are grouped together in the cells below.

In [None]:
# CONSTANTS
import torch

NUM_STEPS = 50000
NUM_TIMESTEPS = 100
LEARNING_RATE = 1e-4
MIN_LEARNING_RATE = 1e-6


To maintain consistency across experiments, I encapsulated the core training logic in a single `single_step` function. It is designed to be flexible, with parameters such as <i>pen_condition</i> that can be toggled for the different architectural approaches. Below is a description of this function and its key components.

<b>rescale_timesteps</b> <br>
This function is an implementation of a technique originally used in the [SketchKnitter](https://openreview.net/pdf?id=4eJ43EN2g6l) paper. It linearly maps each timestep t of the shorter schedule onto the 0 to 1000 range (t × 1000 / NUM_TIMESTEPS), so the timestep embedding receives values in the range of a standard 1000-step model. The paper proposes that using a smaller number of timesteps for sampling (inference) is more efficient. Following their findings, I set the <i>NUM_TIMESTEPS</i> hyperparameter to 100 for my experiments.

<b>single_step</b> <br>
This function handles a single forward pass of the training process, including the loss calculation. The total loss is composed of two distinct parts, creating a multi-task objective:

- MSE Loss: Used for the primary denoising task on the continuous delta values (Δx, Δy).
- Cross-Entropy Loss: Used for the pen state classification task (p1, p2, p3) when the pen states are predicted rather than given as a condition (`pen_condition=False`).

These two losses are combined using a lambda coefficient. For this project, I set lambda to 0.01, which weighs the contribution of the pen state loss. This weighting strategy is also adapted from the [SketchKnitter](https://openreview.net/pdf?id=4eJ43EN2g6l) paper.
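Written out, the objective computed by `single_step` is as follows, where $x_0$ are the normalized offsets, $p$ the one-hot pen states, $\lambda = 0.01$, and the cross-entropy term is present only when the pen states are predicted:

$$
x_t = \sqrt{\bar\alpha_t}\,x_0 + \sqrt{1-\bar\alpha_t}\,\epsilon, \qquad \epsilon \sim \mathcal{N}(0, I)
$$

$$
\mathcal{L} = \mathrm{MSE}\big(\hat\epsilon_\theta(x_t, t),\ \epsilon\big) + \lambda\,\mathrm{CE}\big(\hat p_\theta(x_t, t),\ p\big)
$$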

<b>train_loop</b> <br>
This function orchestrates the entire training process for a fixed number of optimizer steps (<i>NUM_STEPS</i>), iterating over the training data epoch by epoch. Each epoch has two phases: a training phase, where the model's weights are updated using the training data, and a validation phase, where the model's loss is evaluated on held-out data to monitor overfitting.

In [None]:
# Create Training Functions
import torch
import torch.nn.functional as F

PEN_LOSS_WEIGHT = 0.01  # lambda, adapted from SketchKnitter


def rescale_timesteps(t: torch.Tensor,
                      num_timestep: int = NUM_TIMESTEPS):
    # Map t in [0, num_timestep) onto the [0, 1000) range
    return t.float() * (1000.0 / num_timestep)


def single_step(batch: dict,
                diffusion_model: torch.nn.Module,
                scheduler,  # diffusers DDIMScheduler
                history_encoder: torch.nn.Module = None,
                pen_condition: bool = True):

    stroke = batch["stroke"].to(DEVICE).float()

    # Split into offsets (dx, dy) and the one-hot pen state (p1, p2, p3)
    real_pen_state = stroke[:, :, 2:]
    stroke = stroke[:, :, :2]

    # Optional stroke-history context for the cross-attention layers
    stroke_history_features = None
    if history_encoder is not None:
        stroke_history = batch["stroke_history"].to(DEVICE).float()
        stroke_history_mask = batch["stroke_history_mask"].to(DEVICE).float()
        stroke_history_features = history_encoder(stroke_history, stroke_history_mask)

    # Forward diffusion: noise the offsets at a random timestep per sample
    noise = torch.randn_like(stroke)
    timesteps = torch.randint(0, scheduler.config.num_train_timesteps,
                              (stroke.shape[0],), device=DEVICE).long()
    noisy_stroke = scheduler.add_noise(stroke, noise, timesteps)

    # Conditional variant: append the clean pen states to the noisy offsets
    if pen_condition:
        model_input = torch.cat([noisy_stroke, real_pen_state], dim=-1)
    else:
        model_input = noisy_stroke

    # Forward pass through the diffusion model
    predicted_noise, pen_state_out = diffusion_model(
        model_input,
        rescale_timesteps(timesteps),
        context=stroke_history_features,
    )

    # Denoising loss on the offsets (noise prediction)
    loss_delta = F.mse_loss(predicted_noise.float(), noise)

    if not pen_condition:
        # Pen states are predicted: cross-entropy against the one-hot targets
        pen_state_out = pen_state_out.reshape(-1, 3).float()
        real_pen_state = real_pen_state.reshape(-1, 3).float()
        loss_pen = F.cross_entropy(pen_state_out, real_pen_state)
    else:
        loss_pen = torch.zeros((), device=DEVICE)

    loss = loss_delta + PEN_LOSS_WEIGHT * loss_pen
    return loss, loss_delta, loss_pen


In [None]:
from tqdm.auto import tqdm


def train_loop(
        diffusion_model: torch.nn.Module,
        optimizer: torch.optim.Optimizer,
        scheduler,  # diffusers DDIMScheduler (noise schedule)
        scheduler_lr: torch.optim.lr_scheduler.LRScheduler,
        train_dataloader: torch.utils.data.DataLoader,
        test_dataloader: torch.utils.data.DataLoader,
        history_encoder: torch.nn.Module = None,
        pen_condition: bool = True,
        log_every: int = 1000):

    # Mixed precision only when a GPU is available
    use_amp = DEVICE == "cuda"
    scaler = torch.amp.GradScaler("cuda", enabled=use_amp)
    modules = [m for m in (diffusion_model, history_encoder) if m is not None]

    total_step = 0
    running_train_loss = 0.0
    progress = tqdm(total=NUM_STEPS, desc="Training")

    while total_step < NUM_STEPS:
        # Training phase
        for module in modules:
            module.train()
        for batch in train_dataloader:
            optimizer.zero_grad()
            # Only the forward pass and the loss run under autocast
            with torch.amp.autocast("cuda", enabled=use_amp):
                loss, loss_delta, loss_pen = single_step(
                    batch, diffusion_model, scheduler,
                    history_encoder=history_encoder, pen_condition=pen_condition
                )
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            scheduler_lr.step()

            running_train_loss += loss.item()
            total_step += 1
            progress.update(1)
            progress.set_postfix(
                train_loss=loss.item(),
                train_loss_delta=loss_delta.item(),
                train_loss_pen=float(loss_pen),
            )

            if total_step % log_every == 0:
                print(f"Step {total_step}: Train Loss = {running_train_loss / log_every:.4f}")
                running_train_loss = 0.0
            if total_step >= NUM_STEPS:
                break

        # Validation phase: average loss over the whole test split
        for module in modules:
            module.eval()
        test_loss_sum, num_test_batches = 0.0, 0
        with torch.no_grad():
            for batch in test_dataloader:
                test_loss, _, _ = single_step(
                    batch, diffusion_model, scheduler,
                    history_encoder=history_encoder, pen_condition=pen_condition
                )
                test_loss_sum += test_loss.item()
                num_test_batches += 1
        avg_test_loss = test_loss_sum / max(num_test_batches, 1)
        progress.set_postfix(val_loss=avg_test_loss)
        print(f"Step {total_step}: Validation Loss = {avg_test_loss:.4f}")

    progress.close()

<b>set_training_parameters</b> <br>
This helper function centralizes the initialization of the core training components. It sets up the DDIMScheduler for the diffusion process and CosineAnnealingLR to decay the learning rate over the training run. It also configures the AdamW optimizer for both setups, with and without a stroke history encoder.

If a history_encoder is provided, the function combines the parameters of the diffusion model and the encoder, so that the entire architecture is trained end-to-end. If no encoder is present, the optimizer covers the diffusion model alone.

In [None]:
# Set up schedulers and optimizers
from itertools import chain
from diffusers import DDIMScheduler
from torch.optim.lr_scheduler import CosineAnnealingLR


def set_training_parameters(diffusion_model, train_dataloader, history_encoder=None):
    scheduler = DDIMScheduler(
        num_train_timesteps=NUM_TIMESTEPS,
        beta_schedule='linear',
        prediction_type='epsilon'
    )

    if history_encoder is not None:
        optimizer = torch.optim.AdamW(chain(
            diffusion_model.parameters(),
            history_encoder.parameters()
        ), lr=LEARNING_RATE)
    else:
        optimizer = torch.optim.AdamW(diffusion_model.parameters(), lr=LEARNING_RATE)

    # One cosine decay over the whole run (scheduler_lr.step() is called once per optimizer step)
    scheduler_lr = CosineAnnealingLR(optimizer, T_max=NUM_STEPS, eta_min=MIN_LEARNING_RATE)

    return scheduler, optimizer, scheduler_lr


### Training and Sampling for Unconditional or Conditional Diffusion Transformer

In [None]:
PEN_CONDITION = False  # set to True to train the pen-state-conditioned variant instead

scheduler, optimizer, scheduler_lr = set_training_parameters(
    diffusion_model=diffusion_model,
    train_dataloader=train_dataloader
)

# Train either the unconditional model or the pen-state-conditioned model.
# The sampling cell below uses pen_condition_value=None, i.e. the unconditional setting.
train_loop(
    diffusion_model=diffusion_model,
    optimizer=optimizer,
    scheduler=scheduler,
    scheduler_lr=scheduler_lr,
    train_dataloader=train_dataloader,
    test_dataloader=test_dataloader,
    pen_condition=PEN_CONDITION
)

#### Sampling

This cell runs the full inference pipeline, from pure noise to an animated visualization of a generated sketch. The dit_sampling function initializes a tensor with Gaussian noise and refines it over 100 denoising steps: at each step, the trained diffusion model predicts the noise in the current sample, and the scheduler uses that prediction to produce a slightly cleaner sample. After the loop, the pen-state logits returned by the model are converted into one-hot vectors. The offsets and pen states are then concatenated into a 5D sequence (Δx, Δy, p1, p2, p3), denormalized, converted back into the raw stroke format, and rendered as an animated GIF that shows the model's sequential drawing process.
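The denoising loop relies on scheduler.step. The sketch below re-implements a deterministic DDIM update (η = 0) in NumPy to show what happens at each step: the scheduler estimates the clean sample x̂₀ = (x_t − √(1 − ᾱ_t)·ε̂)/√ᾱ_t from the predicted noise ε̂, then moves to the previous timestep with x_{t−1} = √ᾱ_{t−1}·x̂₀ + √(1 − ᾱ_{t−1})·ε̂. It assumes a linear β schedule from 1e-4 to 0.02 over 1000 training timesteps and 100 evenly spaced inference timesteps (990, 980, ..., 0), which I believe mirror the diffusers defaults but are not taken from the notebook. The trained network is replaced by a hypothetical oracle that knows the target sample, so the loop should land on that target.

```python
import numpy as np


def linear_alpha_bars(num_train_timesteps=1000, beta_start=1e-4, beta_end=0.02):
    """Cumulative products alpha_bar_t of a linear beta schedule (assumed defaults)."""
    betas = np.linspace(beta_start, beta_end, num_train_timesteps)
    return np.cumprod(1.0 - betas)


def ddim_step(x_t, eps_hat, t, t_prev, alpha_bars):
    """One deterministic DDIM update (eta = 0) from timestep t to t_prev."""
    a_t = alpha_bars[t]
    a_prev = alpha_bars[t_prev] if t_prev >= 0 else 1.0  # final step targets alpha_bar = 1
    x0_hat = (x_t - np.sqrt(1.0 - a_t) * eps_hat) / np.sqrt(a_t)
    return np.sqrt(a_prev) * x0_hat + np.sqrt(1.0 - a_prev) * eps_hat


def ddim_sample(target, num_inference_steps=100, num_train_timesteps=1000, seed=0):
    """Sampling loop from pure noise with a hypothetical oracle noise predictor."""
    alpha_bars = linear_alpha_bars(num_train_timesteps)
    step_ratio = num_train_timesteps // num_inference_steps
    timesteps = np.arange(num_inference_steps)[::-1] * step_ratio  # 990, 980, ..., 0
    x = np.random.default_rng(seed).standard_normal(target.shape)  # start from pure noise
    for t in timesteps:
        # Oracle: the noise that would turn `target` into the current x at timestep t
        eps_hat = (x - np.sqrt(alpha_bars[t]) * target) / np.sqrt(1.0 - alpha_bars[t])
        x = ddim_step(x, eps_hat, t, t - step_ratio, alpha_bars)
    return x
```

With a real model, ε̂ comes from the network instead of the oracle, and any error in that prediction shows up directly in x̂₀, which is why a poorly converged model yields unrecognizable sketches.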

In [None]:
from tqdm.auto import tqdm
import torch
import torch.nn.functional as F


def dit_sampling(diffusion_model: torch.nn.Module,
                 scheduler,  # diffusers noise scheduler, e.g. DDIMScheduler
                 timestep: int = 100,
                 max_seq_length: int = 96 + 1,
                 pen_condition_value: torch.Tensor = None):
    """Denoise Gaussian noise into (dx, dy) offsets; `timestep` is the number of inference steps."""
    diffusion_model.eval()

    # Only the (dx, dy) channels are diffused; the first position is the zero start token
    sample = torch.randn((1, max_seq_length, 2), device=DEVICE)
    sample[:, 0, :] = 0.0

    if pen_condition_value is not None:
        pen_channel = pen_condition_value.to(DEVICE).unsqueeze(-1)

    scheduler.set_timesteps(timestep)
    pen_values = None
    for t in tqdm(scheduler.timesteps):
        # Feed the fixed pen condition as an extra channel, but keep it out of the denoising update
        model_input = sample if pen_condition_value is None else torch.cat([sample, pen_channel], dim=-1)
        with torch.no_grad():
            predicted_noise, pen_values = diffusion_model(model_input, t.reshape(1,).to(DEVICE))

        sample = scheduler.step(predicted_noise, t, sample).prev_sample

    return sample, pen_values


def postprocess_drawing_torch(pen_state_logits):
    # argmax over the logits equals argmax over the softmax probabilities
    winner_indices = torch.argmax(pen_state_logits, dim=-1)
    return F.one_hot(winner_indices, num_classes=3).float()

In [None]:
from IPython.display import Image
from src.utils.dataset_utils import denormalize, convert_5d_to_raw_format, generate_gif

predicted_delta_values, predicted_pen_values = dit_sampling(
    diffusion_model=diffusion_model,
    scheduler=scheduler,
    timestep=100,
    max_seq_length=96 + 1,
    pen_condition_value=None
)

if predicted_pen_values is None:
    raise ValueError("The diffusion model returned no pen-state logits, so a 5D sketch cannot be assembled.")

# Build the 5D sequence: (dx, dy) offsets followed by one-hot pen states
predicted_pen_values = postprocess_drawing_torch(predicted_pen_values)
final_output = torch.cat([predicted_delta_values, predicted_pen_values], dim=-1)

denormalized = denormalize(final_output, std)
raw_drawing = convert_5d_to_raw_format(denormalized.detach().cpu().numpy()[0])
generate_gif(raw_drawing, "./test", fps=25)
Image(filename="./test/drawing.gif")
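The post-processing at the end of this cell relies on helpers from src.utils.dataset_utils whose code is not shown here. As a rough illustration of what such a conversion involves, the hypothetical stroke5_to_polylines below (not the notebook's convert_5d_to_raw_format) assumes the SketchRNN-style stroke-5 layout (Δx, Δy, p1, p2, p3), where p2 means the pen lifts after the point and p3 ends the sketch. It also assumes that denormalization amounts to multiplying the offsets by a single scale factor such as the dataset's std. It picks each pen state with an argmax, accumulates the offsets into absolute coordinates, and splits the sequence into separate strokes at every pen lift.

```python
import numpy as np


def stroke5_to_polylines(stroke5, scale=1.0):
    """Hypothetical helper: (N, 5) stroke-5 array -> list of (K, 2) absolute polylines.

    Columns: dx, dy, p1 (pen down), p2 (pen up after this point), p3 (end of sketch).
    `scale` undoes an assumed normalisation that divided the offsets by one std value.
    """
    stroke5 = np.asarray(stroke5, dtype=float)
    pen = np.argmax(stroke5[:, 2:5], axis=1)        # 0 = p1, 1 = p2, 2 = p3
    xy = np.cumsum(stroke5[:, :2] * scale, axis=0)  # offsets -> absolute positions
    polylines, current = [], []
    for point, state in zip(xy, pen):
        if state == 2:   # end-of-sketch token: stop reading
            break
        current.append(point)
        if state == 1:   # pen lifts after this point: close the current stroke
            polylines.append(np.array(current))
            current = []
    if current:
        polylines.append(np.array(current))
    return polylines
```

Each returned polyline can then be drawn as a connected line segment list, which is essentially what a GIF renderer does frame by frame.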

### Training and Sampling for the Stroke-History-Based Diffusion Transformer

In [None]:
scheduler, optimizer, scheduler_lr = set_training_parameters(
    diffusion_model=diffusion_model_wcrossattn,
    train_dataloader=history_train_dataloader,
    history_encoder=history_encoder
)

# Train the stroke-history-based diffusion model together with the history encoder
train_loop(
    diffusion_model=diffusion_model_wcrossattn,
    optimizer=optimizer,
    scheduler=scheduler,
    scheduler_lr=scheduler_lr,
    train_dataloader=history_train_dataloader,
    test_dataloader=history_test_dataloader,
    history_encoder=history_encoder,
    pen_condition=False 
)


#### Sampling

This cell runs the complete inference pipeline for the stroke-history-based Diffusion Transformer. Rather than generating the entire sketch at once, it builds the drawing sequentially, one stroke segment at a time, to mimic the way a person draws.

The main for loop runs for STEP_COUNT iterations and calls stroke_history_dit_sampling once per iteration to generate a single new stroke segment of max_single_stroke_length points. The prev_generated_stroke_history variable holds the drawing state: it starts as None, grows with each newly generated segment, and is passed as the condition for the next step.

Two models work together in each step. The stroke_history_encoder turns the current history (prefixed by a start token and accompanied by a padding mask) into context features, and the cross-attention diffusion model attends to those features while denoising random noise into the next coherent stroke segment. After all segments are generated, the accumulated sequence is denormalized, converted to the raw stroke format, and rendered as an animated GIF of the whole drawing process.
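The bookkeeping around prev_generated_stroke_history can be separated from the models. The NumPy sketch below uses hypothetical names (build_history, generate_sketch, and a stand-in fake_generator in place of the encoder and diffusion model). It builds the fixed-size history buffer with a start token and a validity mask, then appends each generated segment. It also keeps only the most recent points once the history would no longer fit in max_seq_length; this truncation is my own assumption about how overflow should be handled.

```python
import numpy as np

START_TOKEN = np.array([0.0, 0.0, 1.0, 0.0, 0.0])  # (dx, dy, p1, p2, p3)


def build_history(prev, max_seq_length=97):
    """Return a (max_seq_length, 5) history buffer and its (max_seq_length,) mask."""
    history = np.zeros((max_seq_length, 5))
    mask = np.zeros(max_seq_length)
    history[0], mask[0] = START_TOKEN, 1.0
    if prev is not None and len(prev) > 0:
        # Assumption: when the history overflows, keep only the most recent points
        prev = prev[-(max_seq_length - 1):]
        history[1:len(prev) + 1] = prev
        mask[1:len(prev) + 1] = 1.0
    return history, mask


def generate_sketch(generate_segment, step_count=2, max_seq_length=97):
    """Autoregressive loop: each new segment is conditioned on everything drawn so far."""
    prev = None
    for _ in range(step_count):
        history, mask = build_history(prev, max_seq_length)
        segment = generate_segment(history, mask)
        prev = segment if prev is None else np.concatenate([prev, segment], axis=0)
    return prev


def fake_generator(history, mask, segment_length=44):
    """Stand-in for encoder + diffusion model: emits a pen-down segment of unit steps."""
    segment = np.zeros((segment_length, 5))
    segment[:, 0] = 1.0  # dx = 1
    segment[:, 2] = 1.0  # p1: pen stays down
    return segment
```

With STEP_COUNT = 2 and 44-point segments, the history never exceeds 88 points plus the start token, so the truncation branch only matters for longer generations.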

In [None]:
from tqdm.auto import tqdm
import torch
import torch.nn.functional as F


def stroke_history_dit_sampling(
    diffusion_model_wcrossattn: torch.nn.Module,
    stroke_history_encoder: torch.nn.Module,
    scheduler,  # diffusers noise scheduler, e.g. DDIMScheduler
    timestep: int = 100,
    max_seq_length: int = 96 + 1,
    max_single_stroke_length: int = 44,
    prev_generated_stroke_history=None
):
    diffusion_model_wcrossattn.eval()
    stroke_history_encoder.eval()

    # History buffer: start token at position 0, previously generated points after it
    stroke_history = torch.zeros((1, max_seq_length, 5), device=DEVICE)
    stroke_history[:, 0, :] = torch.tensor([0.0, 0.0, 1.0, 0.0, 0.0], device=DEVICE)
    stroke_history_mask = torch.zeros((1, max_seq_length), device=DEVICE)
    stroke_history_mask[:, 0] = 1

    if prev_generated_stroke_history is not None:
        # Keep only the most recent points that fit after the start token
        prev = prev_generated_stroke_history[:, -(max_seq_length - 1):, :]
        stroke_history[:, 1:prev.shape[1] + 1, :] = prev
        stroke_history_mask[:, 1:prev.shape[1] + 1] = 1

    with torch.no_grad():
        stroke_history_features = stroke_history_encoder(stroke_history, stroke_history_mask)

    # Denoise one stroke segment of (dx, dy) offsets, conditioned on the history via cross-attention
    sample = torch.randn((1, max_single_stroke_length, 2), device=DEVICE)

    scheduler.set_timesteps(timestep)
    pen_values = None
    for t in tqdm(scheduler.timesteps):
        with torch.no_grad():
            predicted_noise, pen_values = diffusion_model_wcrossattn(
                sample,
                t.reshape(1,).to(DEVICE),
                context=stroke_history_features
            )

        sample = scheduler.step(predicted_noise, t, sample).prev_sample

    # Return the denoised offsets (not the last noise prediction) and the pen-state logits
    return sample, pen_values


def postprocess_drawing_torch(pen_state_logits):
    # argmax over the logits equals argmax over the softmax probabilities
    winner_indices = torch.argmax(pen_state_logits, dim=-1)
    return F.one_hot(winner_indices, num_classes=3).float()

In [None]:
from IPython.display import Image
from src.utils.dataset_utils import denormalize, convert_5d_to_raw_format, generate_gif


STEP_COUNT = 2  # number of stroke segments to generate
prev_generated_stroke_history = None
for _ in range(STEP_COUNT):
    predicted_delta_values, predicted_pen_values = stroke_history_dit_sampling(
        diffusion_model_wcrossattn=diffusion_model_wcrossattn,
        stroke_history_encoder=history_encoder,
        scheduler=scheduler,
        timestep=100,
        max_seq_length=96 + 1,
        max_single_stroke_length=44,
        prev_generated_stroke_history=prev_generated_stroke_history
    )

    # Assemble the new 5D segment and append it to the running history
    predicted_pen_values = postprocess_drawing_torch(predicted_pen_values)
    final_output = torch.cat([predicted_delta_values, predicted_pen_values], dim=-1)

    if prev_generated_stroke_history is None:
        prev_generated_stroke_history = final_output
    else:
        prev_generated_stroke_history = torch.cat([prev_generated_stroke_history, final_output], dim=1)


denormalized = denormalize(prev_generated_stroke_history, std)
raw_drawing = convert_5d_to_raw_format(denormalized.detach().cpu().numpy()[0])
generate_gif(raw_drawing, "./test", fps=25)
Image(filename="./test/drawing.gif")

## Evaluation

To quantify the quality and diversity of the generated sketches, I implemented a function that computes two standard metrics: <b>Fréchet Inception Distance (FID)</b> and <b>Kernel Inception Distance (KID)</b>.

The <b>calculate_fid_kid</b> function measures the distributional similarity between real sketches from the test data and sketches produced by the trained model. Both sets of sketches are first rendered as images and saved into separate directories; the function then uses the torch-fidelity library to compare the two image sets. Lower FID and KID scores indicate that the generated sketches are statistically closer to the real ones.

<b><u>However, this function was not executed for the final evaluation. Since the model did not converge to a state where it produced recognizable sketches, running a comparative analysis like FID/KID would not produce meaningful results.</u></b>
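For reference, both metrics reduce to simple formulas once each image has been mapped to a feature vector. The NumPy sketch below is an illustration, not torch-fidelity's implementation: it takes feature matrices directly (torch-fidelity extracts them with an Inception network), computes FID as ‖μ₁ − μ₂‖² + Tr(Σ₁ + Σ₂ − 2(Σ₁Σ₂)^{1/2}), and computes KID as a single unbiased MMD² estimate with the cubic polynomial kernel k(x, y) = (xᵀy/d + 1)³, whereas torch-fidelity averages KID over random subsets. The function names frechet_distance and kernel_distance are hypothetical.

```python
import numpy as np


def _sqrtm_psd(mat):
    """Square root of a symmetric positive semi-definite matrix via eigendecomposition."""
    vals, vecs = np.linalg.eigh(mat)
    return (vecs * np.sqrt(np.clip(vals, 0.0, None))) @ vecs.T


def frechet_distance(feats1, feats2):
    """FID-style distance between two (n, d) feature matrices."""
    mu1, mu2 = feats1.mean(axis=0), feats2.mean(axis=0)
    s1, s2 = np.cov(feats1, rowvar=False), np.cov(feats2, rowvar=False)
    # Tr((S1 S2)^{1/2}) equals Tr((S1^{1/2} S2 S1^{1/2})^{1/2}), a symmetric PSD form
    root1 = _sqrtm_psd(s1)
    middle = root1 @ s2 @ root1
    middle = (middle + middle.T) / 2.0  # symmetrise against round-off
    tr_covmean = np.sqrt(np.clip(np.linalg.eigvalsh(middle), 0.0, None)).sum()
    return float(np.sum((mu1 - mu2) ** 2) + np.trace(s1) + np.trace(s2) - 2.0 * tr_covmean)


def kernel_distance(feats1, feats2):
    """KID-style unbiased MMD^2 with the cubic polynomial kernel (single estimate, no subsets)."""
    d = feats1.shape[1]

    def kernel(a, b):
        return (a @ b.T / d + 1.0) ** 3

    kxx, kyy, kxy = kernel(feats1, feats1), kernel(feats2, feats2), kernel(feats1, feats2)
    m, n = len(feats1), len(feats2)
    term_xx = (kxx.sum() - np.trace(kxx)) / (m * (m - 1))  # drop self-similarity terms
    term_yy = (kyy.sum() - np.trace(kyy)) / (n * (n - 1))
    return float(term_xx + term_yy - 2.0 * kxy.mean())
```

Identical feature sets give an FID of zero, and shifting every feature by a constant vector c adds exactly ‖c‖², which makes the mean term easy to sanity-check.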

In [None]:
from torch_fidelity import calculate_metrics


def calculate_fid_kid(real_images_dir: str, generated_images_dir: str):
    # Compare two directories of rendered sketch images with Inception-based metrics
    metrics_dict = calculate_metrics(
        input1=real_images_dir,
        input2=generated_images_dir,
        cuda=str(DEVICE).startswith('cuda'),  # works whether DEVICE is a string or a torch.device
        isc=False,
        fid=True,
        kid=True,
        verbose=False
    )

    fid_score = metrics_dict['frechet_inception_distance']
    kid_mean = metrics_dict['kernel_inception_distance_mean']
    kid_std = metrics_dict['kernel_inception_distance_std']

    return fid_score, (kid_mean, kid_std)

# Calculate FID and KID scores (not executed; see the note above)


# References

- [SketchRNN](https://arxiv.org/abs/1704.03477)
- [SketchKnitter](https://openreview.net/pdf?id=4eJ43EN2g6l)