## Importing the Required Libraries

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import PIL

import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F

from PIL import Image
import io
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas

In [None]:
# True when PyTorch can see a CUDA device; read as a global by the batching,
# encoder and model code further down to decide where tensors live.
use_cuda = torch.cuda.is_available()

In [None]:
# Torch device object derived from the same CUDA availability check.
# NOTE(review): nothing visible in this file reads this global (Model builds
# its own self.device) — confirm before relying on or removing it.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Show whether CUDA was detected.
print(use_cuda)

## Defining Required Hyperparameters

In [None]:
class HParams():
    """Plain container for every tunable value used by the notebook.

    All values are stored as instance attributes, in the order listed below,
    so ``vars(instance)`` enumerates them in that same order.
    """

    def __init__(self):
        settings = (
            # --- data ---
            # Path handed to np.load when the dataset is read.
            ('data_location', '/content/sketchrnn_apple.npz'),
            # --- network sizes ---
            # Hidden size of the encoder LSTM (per direction).
            ('enc_hidden_size', 256),
            # Hidden size of the decoder LSTM.
            ('dec_hidden_size', 512),
            # Dimension of the latent vector z.
            ('Nz', 128),
            # Number of mixture components (decoder emits 6*M + 3 values).
            ('M', 20),
            # Passed as ``dropout=`` to both nn.LSTM constructors.
            # NOTE(review): PyTorch documents this as a drop probability
            # applied between stacked layers — confirm 0.9 is intended.
            ('dropout', 0.9),
            # --- optimisation ---
            # Number of sequences drawn per training batch.
            ('batch_size', 100),
            # eta_min and R feed the KL weight computed in Model.train.
            ('eta_min', 0.01),
            ('R', 0.99995),
            # Lower bound applied to the KL term.
            ('KL_min', 0.2),
            # Multiplier on the KL term.
            ('wKL', 0.5),
            # Initial Adam learning rate, its per-call decay factor and the
            # threshold below which decay stops.
            ('lr', 0.001),
            ('lr_decay', 0.9999),
            ('min_lr', 0.00001),
            # Maximum gradient norm used when clipping.
            ('grad_clip', 1.),
            # --- sampling / filtering ---
            # Sampling temperature used during generation.
            ('temperature', 0.4),
            # Sequences longer than this are dropped by purify().
            ('max_seq_length', 200),
        )
        for key, value in settings:
            setattr(self, key, value)

In [None]:
# Single shared hyperparameter instance; the functions and classes below read
# it as a module-level global.
hp = HParams()

## Preprocessing & Loading Dataset

In [None]:
def max_size(data):
    """Return the length of the longest sequence in ``data``.

    Raises:
        ValueError: if ``data`` is empty.
    """
    return max(len(seq) for seq in data)

def purify(strokes):
    """Drop sequences that are too short or too long and clamp large values.

    A sequence is kept when it has more than 10 rows and at most
    ``hp.max_seq_length`` rows. Every entry of a kept sequence is limited to
    the range [-1000, 1000] and the result is returned as a new float32 array,
    so the input arrays are left untouched.

    Args:
        strokes: iterable of 2-D arrays, one per drawing.

    Returns:
        list of float32 arrays, in the original order.
    """
    def _acceptable(seq):
        return 10 < seq.shape[0] <= hp.max_seq_length

    return [
        np.maximum(np.minimum(seq, 1000), -1000).astype(np.float32)
        for seq in strokes
        if _acceptable(seq)
    ]

def calculate_normalizing_scale_factor(strokes):
    """Return the standard deviation of all (delta_x, delta_y) values.

    The first two columns of every sequence are pooled into one flat array
    (x and y interleaved row by row) and its standard deviation is returned.
    The pooling is done with NumPy slicing rather than a per-element Python
    loop, which yields the same array far faster on large datasets.

    Args:
        strokes: sequence of 2-D arrays with at least two columns.

    Returns:
        The pooled standard deviation; NaN (with NumPy's usual warning) when
        ``strokes`` contains no values at all.
    """
    # Row-major ravel of the first two columns keeps the x0, y0, x1, y1, ...
    # ordering that the element-wise loop used to build.
    offsets = [np.asarray(seq)[:, :2].ravel() for seq in strokes]
    # np.concatenate rejects an empty list, so fall back to an empty array to
    # keep the "no data -> NaN" result.
    pooled = np.concatenate(offsets) if offsets else np.array([])
    return np.std(pooled)

def normalize(strokes):
    """Rescale the (delta_x, delta_y) columns of every sequence.

    The scale factor comes from ``calculate_normalizing_scale_factor`` and is
    computed once over the whole dataset. The division is done in place, so
    the arrays passed in are modified; the returned list holds those same
    sequences in their original order.
    """
    factor = calculate_normalizing_scale_factor(strokes)
    for sequence in strokes:
        sequence[:, :2] /= factor
    return list(strokes)

In [None]:
# Load the .npz archive named in the hyperparameters.
# NOTE: allow_pickle=True lets NumPy unpickle object arrays, which can run
# arbitrary code — only point hp.data_location at files you trust.
dataset = np.load(hp.data_location, encoding='latin1',allow_pickle=True)
# Take the 'train' entry, filter/clamp it, then rescale the offsets in place.
data = dataset['train']
data = purify(data)
data = normalize(data)
# Longest remaining sequence; used below as the padded sequence length.
Nmax = max_size(data)

In [None]:
def make_batch(batch_size):
    """Draw a random batch from ``data`` and pad it to five-column form.

    ``batch_size`` indices are sampled with ``np.random.choice`` (so repeats
    are possible). Each selected sequence is written into a zero array of
    shape (Nmax, 5):

    * columns 0-1: the first two columns of the sequence,
    * column 2:    ``1 - third column`` for every row but the last,
    * column 3:    the third column,
    * column 4:    1 from the last real row through the padding.

    Columns 2 and 3 of the last real row are then cleared.

    Returns:
        (batch, lengths): a float tensor of shape (Nmax, batch_size, 5),
        moved to the GPU when ``use_cuda`` is set, and the list of original
        sequence lengths.
    """
    picks = np.random.choice(len(data), batch_size)
    padded = []
    lengths = []

    for idx in picks:
        seq = data[idx]
        n_rows = len(seq[:, 0])
        last = n_rows - 1

        row = np.zeros((Nmax, 5))
        row[:n_rows, :2] = seq[:, :2]
        row[:n_rows, 3] = seq[:, 2]
        row[:last, 2] = 1 - seq[:-1, 2]
        row[last:, 4] = 1
        # The final real row carries only the end flag.
        row[last, 2:4] = 0

        padded.append(row)
        lengths.append(n_rows)

    # Stack along axis 1 so the tensor is (time, batch, features).
    batch = torch.from_numpy(np.stack(padded, axis=1)).float()
    if use_cuda:
        batch = batch.cuda()
    return batch, lengths

In [None]:
def lr_decay(optimizer):
    """Multiply each parameter group's learning rate by ``hp.lr_decay``.

    Groups whose rate is not above ``hp.min_lr`` are left alone. The optimizer
    is modified in place and also returned.
    """
    floor = hp.min_lr
    factor = hp.lr_decay
    for group in optimizer.param_groups:
        if group['lr'] > floor:
            group['lr'] *= factor
    return optimizer

## Encoder & Decoder RNN

In [None]:
class EncoderRNN(nn.Module):
    """Bidirectional LSTM that maps a stroke batch to a sampled latent ``z``.

    The final hidden states of both directions are concatenated and projected
    to ``mu`` and ``sigma_hat`` (each of size ``hp.Nz``); ``z`` is then drawn
    as ``mu + exp(sigma_hat / 2) * noise``.
    """

    def __init__(self):
        super().__init__()
        # Bidirectional LSTM over 5-feature stroke rows.
        # NOTE(review): PyTorch documents LSTM ``dropout`` as acting between
        # stacked layers only; with the default single layer it likely has no
        # effect — confirm.
        self.lstm = nn.LSTM(5, hp.enc_hidden_size,
                            dropout=hp.dropout, bidirectional=True)
        # Projections from the concatenated final hidden states.
        self.fc_mu = nn.Linear(2 * hp.enc_hidden_size, hp.Nz)
        self.fc_sigma = nn.Linear(2 * hp.enc_hidden_size, hp.Nz)
        # Start in training mode.
        self.train()

    def forward(self, inputs, batch_size, hidden_cell=None):
        """Encode ``inputs`` and sample a latent vector.

        Args:
            inputs: tensor of shape (seq_len, batch_size, 5).
            batch_size: number of sequences in ``inputs``; used to size the
                initial LSTM state.
            hidden_cell: optional ``(hidden, cell)`` initial state; zeros are
                used when omitted.

        Returns:
            (z, mu, sigma_hat), each of shape (batch_size, hp.Nz).
            ``sigma_hat`` is the raw projection (``sigma = exp(sigma_hat/2)``).
        """
        inputs = inputs.float()
        if hidden_cell is None:
            # Allocate the initial state on the same device as the input
            # instead of branching on the global CUDA flag.
            state_shape = (2, batch_size, hp.enc_hidden_size)
            hidden_cell = (
                torch.zeros(state_shape, device=inputs.device),
                torch.zeros(state_shape, device=inputs.device),
            )
        _, (hidden, _) = self.lstm(inputs, hidden_cell)
        # hidden is (2, batch, hidden_size): index 0 is the forward direction,
        # index 1 the backward one. Join them to (batch, 2 * hidden_size).
        hidden_cat = torch.cat([hidden[0], hidden[1]], dim=1)

        mu = self.fc_mu(hidden_cat)
        sigma_hat = self.fc_sigma(hidden_cat)
        sigma = torch.exp(sigma_hat / 2.)
        # Standard-normal noise created directly with mu's shape, dtype and
        # device (no CPU allocation followed by a device copy).
        noise = torch.randn_like(mu)
        z = mu + sigma * noise
        # mu and sigma_hat are returned because the KL loss needs them.
        return z, mu, sigma_hat

In [None]:
class DecoderRNN(nn.Module):
    """LSTM decoder that emits mixture and pen-state parameters per step.

    Each output row of ``fc_params`` holds ``6 * hp.M + 3`` values: ``hp.M``
    groups of six mixture parameters (pi, mu_x, mu_y, sigma_x, sigma_y,
    rho_xy) followed by three pen-state logits.
    """

    def __init__(self):
        super().__init__()
        # Produces the initial hidden and cell state from z.
        self.fc_hc = nn.Linear(hp.Nz, 2 * hp.dec_hidden_size)
        # Unidirectional LSTM fed with [stroke row, z] at every step.
        # NOTE(review): PyTorch documents LSTM ``dropout`` as acting between
        # stacked layers only; with the default single layer it likely has no
        # effect — confirm.
        self.lstm = nn.LSTM(hp.Nz + 5, hp.dec_hidden_size, dropout=hp.dropout)
        # Maps LSTM features to the distribution parameters.
        self.fc_params = nn.Linear(hp.dec_hidden_size, 6 * hp.M + 3)

    def forward(self, inputs, z, hidden_cell=None):
        """Run the decoder and return the output distribution parameters.

        Args:
            inputs: tensor of shape (seq_len, batch, hp.Nz + 5).
            z: latent tensor of shape (batch, hp.Nz); only used to build the
                initial state when ``hidden_cell`` is None.
            hidden_cell: optional ``(hidden, cell)`` state from a previous
                call.

        Returns:
            (pi, mu_x, mu_y, sigma_x, sigma_y, rho_xy, q, hidden, cell).
            The first six have shape (len_out, batch, hp.M) and ``q`` has
            shape (len_out, batch, 3). ``len_out`` is the number of input
            steps in training mode and 1 in eval mode, where only the final
            hidden state is decoded.
        """
        if hidden_cell is None:
            # Initialise the state from z.
            hidden, cell = torch.split(
                torch.tanh(self.fc_hc(z)), hp.dec_hidden_size, 1)
            hidden_cell = (hidden.unsqueeze(0).contiguous(),
                           cell.unsqueeze(0).contiguous())
        outputs, (hidden, cell) = self.lstm(inputs, hidden_cell)

        # Training consumes every step of the sequence at once; generation
        # feeds one sample at a time and only needs the latest state.
        if self.training:
            features = outputs
            # Taken from the actual sequence length rather than a global, so
            # any input length reshapes correctly.
            len_out = outputs.size(0)
        else:
            features = hidden
            len_out = 1
        y = self.fc_params(features.reshape(-1, hp.dec_hidden_size))

        # Chunks of six are the mixture groups; the trailing chunk of three
        # holds the pen-state logits.
        params = torch.split(y, 6, 1)
        params_mixture = torch.stack(params[:-1])
        params_pen = params[-1]
        pi, mu_x, mu_y, sigma_x, sigma_y, rho_xy = torch.split(
            params_mixture, 1, 2)

        def rows_first(t):
            # (M, rows, 1) -> (rows, M). Squeezing only the last axis keeps
            # the mixture axis intact even when M == 1 or rows == 1.
            return t.transpose(0, 1).squeeze(-1)

        def per_step(t):
            return t.reshape(len_out, -1, hp.M)

        pi = per_step(F.softmax(rows_first(pi), dim=-1))
        sigma_x = per_step(torch.exp(rows_first(sigma_x)))
        sigma_y = per_step(torch.exp(rows_first(sigma_y)))
        rho_xy = per_step(torch.tanh(rows_first(rho_xy)))
        mu_x = per_step(rows_first(mu_x))
        mu_y = per_step(rows_first(mu_y))
        q = F.softmax(params_pen, dim=-1).view(len_out, -1, 3)

        return pi, mu_x, mu_y, sigma_x, sigma_y, rho_xy, q, hidden, cell


## Complete Model & Training

In [None]:
class Model():
    """Encoder/decoder pair with its optimizers, losses and sampling code.

    The decoder outputs of the most recent forward pass are kept on the
    instance (``pi``, ``mu_x``, ``mu_y``, ``sigma_x``, ``sigma_y``,
    ``rho_xy``, ``q``) and read by the loss and sampling methods.
    """

    def __init__(self):
        self.device = torch.device("cuda" if use_cuda else "cpu")
        self.encoder = EncoderRNN().to(self.device)
        self.decoder = DecoderRNN().to(self.device)

        self.encoder_optimizer = optim.Adam(self.encoder.parameters(), hp.lr)
        self.decoder_optimizer = optim.Adam(self.decoder.parameters(), hp.lr)
        # Weight applied to the KL term; updated on every train() call.
        self.eta_step = hp.eta_min

    def make_target(self, batch, lengths):
        """Build the reconstruction targets for a padded batch.

        Args:
            batch: tensor of shape (Nmax, batch_size, 5).
            lengths: original length of every sequence in the batch.

        Returns:
            (mask, dx, dy, p): ``mask`` is (Nmax + 1, batch_size) with 1 for
            steps before each sequence's length; ``dx`` and ``dy`` are the
            first two feature columns expanded to hp.M mixture components;
            ``p`` holds the three pen-state columns.
        """
        batch_size = batch.size(1)
        # Append one end-of-sequence row, created in the batch's dtype so the
        # concatenation does not depend on type promotion.
        eos = torch.tensor([0, 0, 0, 0, 1], device=self.device,
                           dtype=batch.dtype)
        eos = eos.repeat(batch_size, 1).unsqueeze(0)
        batch = torch.cat([batch, eos], dim=0)

        # mask[t, b] = 1 while t < lengths[b]; broadcasting a column of step
        # indices against the row of lengths gives the whole mask at once.
        steps = torch.arange(Nmax + 1, device=self.device).unsqueeze(1)
        lengths_tensor = torch.tensor(lengths, device=self.device)
        mask = (steps < lengths_tensor).float()

        dx = batch[:, :, 0].unsqueeze(-1).expand(-1, -1, hp.M)
        dy = batch[:, :, 1].unsqueeze(-1).expand(-1, -1, hp.M)
        p = batch[:, :, 2:5]

        return mask, dx, dy, p

    def train(self, epoch):
        """Run one optimisation step on a freshly sampled batch.

        Args:
            epoch: index of this call; used as the step counter for the KL
                weight schedule, for logging, and to trigger a sample drawing
                every 100 calls.
        """
        self.encoder.train()
        self.decoder.train()

        batch, lengths = make_batch(hp.batch_size)
        batch = batch.to(self.device)

        # self.sigma holds the encoder's raw sigma_hat output.
        z, self.mu, self.sigma = self.encoder(batch, hp.batch_size)

        # Decoder input: start-of-sequence row followed by the batch, with z
        # appended to every time step.
        sos = torch.tensor([0, 0, 1, 0, 0], device=self.device,
                           dtype=batch.dtype)
        sos = sos.repeat(hp.batch_size, 1).unsqueeze(0)
        batch_init = torch.cat([sos, batch], dim=0)
        z_stack = z.unsqueeze(0).expand(Nmax + 1, -1, -1)
        inputs = torch.cat([batch_init, z_stack], dim=2)

        self.pi, self.mu_x, self.mu_y, self.sigma_x, self.sigma_y, \
            self.rho_xy, self.q, _, _ = self.decoder(inputs, z)

        mask, dx, dy, p = self.make_target(batch, lengths)

        self.encoder_optimizer.zero_grad()
        self.decoder_optimizer.zero_grad()

        # KL weight schedule: 1 - (1 - eta_min) * R**step. The exponent was
        # previously missing, which made the weight the same constant on
        # every call instead of growing towards 1.
        self.eta_step = 1 - (1 - hp.eta_min) * hp.R ** epoch

        LKL = self.kullback_leibler_loss()
        LR = self.reconstruction_loss(mask, dx, dy, p, epoch)
        loss = LR + LKL

        loss.backward()

        nn.utils.clip_grad_norm_(self.encoder.parameters(), hp.grad_clip)
        nn.utils.clip_grad_norm_(self.decoder.parameters(), hp.grad_clip)

        self.encoder_optimizer.step()
        self.decoder_optimizer.step()

        # Log and decay the learning rates on every call.
        print(f"Epoch {epoch} | Loss: {loss.item():.4f} | LR: {LR.item():.4f} | LKL: {LKL.item():.4f}")
        self.encoder_optimizer = lr_decay(self.encoder_optimizer)
        self.decoder_optimizer = lr_decay(self.decoder_optimizer)

        if epoch % 100 == 0:
            self.conditional_generation(epoch)

    def bivariate_normal_pdf(self, dx, dy):
        """Evaluate each mixture component's bivariate normal density.

        Uses ``self.mu_x``, ``self.mu_y``, ``self.sigma_x``, ``self.sigma_y``
        and ``self.rho_xy``; ``dx`` and ``dy`` must broadcast against them.
        """
        z_x = ((dx - self.mu_x) / self.sigma_x) ** 2
        z_y = ((dy - self.mu_y) / self.sigma_y) ** 2
        z_xy = (dx - self.mu_x) * (dy - self.mu_y) / (self.sigma_x * self.sigma_y)
        z = z_x + z_y - 2 * self.rho_xy * z_xy
        one_minus_rho_sq = 1 - self.rho_xy ** 2
        exp = torch.exp(-z / (2 * one_minus_rho_sq))
        norm = 2 * np.pi * self.sigma_x * self.sigma_y * torch.sqrt(one_minus_rho_sq)
        return exp / norm

    def reconstruction_loss(self, mask, dx, dy, p, epoch):
        """Return the stroke-offset loss plus the pen-state loss.

        Both terms are summed over the batch and divided by
        ``Nmax * hp.batch_size``. ``epoch`` is accepted for interface
        compatibility and is not used.
        """
        pdf = self.bivariate_normal_pdf(dx, dy)
        denom = float(Nmax * hp.batch_size)
        # 1e-5 keeps the log finite when the mixture density underflows.
        LS = -torch.sum(mask * torch.log(1e-5 + torch.sum(self.pi * pdf, 2))) / denom
        LP = -torch.sum(p * torch.log(self.q)) / denom
        return LS + LP

    def kullback_leibler_loss(self):
        """Return the weighted KL term, floored at ``hp.KL_min``.

        The result is a one-element tensor equal to
        ``hp.wKL * self.eta_step * max(KL, hp.KL_min)``.
        """
        LKL = -0.5 * torch.sum(1 + self.sigma - self.mu ** 2 - torch.exp(self.sigma)) \
            / float(hp.Nz * hp.batch_size)
        # Built from self.mu so it shares its device and dtype.
        KL_min = self.mu.new_tensor([hp.KL_min])
        return hp.wKL * self.eta_step * torch.max(LKL, KL_min)

    def save(self, epoch):
        """Write encoder and decoder weights to the working directory.

        File names embed a random number and ``epoch``.
        """
        sel = np.random.rand()
        torch.save(self.encoder.state_dict(),
                   'encoderRNN_sel_%3f_epoch_%d.pth' % (sel, epoch))
        torch.save(self.decoder.state_dict(),
                   'decoderRNN_sel_%3f_epoch_%d.pth' % (sel, epoch))

    def load(self, encoder_name, decoder_name):
        """Load encoder and decoder weights from the given files.

        Tensors are mapped onto ``self.device`` while loading, so weights
        saved on a GPU can be restored on a CPU-only machine.
        """
        saved_encoder = torch.load(encoder_name, map_location=self.device)
        saved_decoder = torch.load(decoder_name, map_location=self.device)
        self.encoder.load_state_dict(saved_encoder)
        self.decoder.load_state_dict(saved_decoder)

    def conditional_generation(self, epoch):
        """Encode one random drawing, decode a new one and save it as an image.

        Both networks are switched to eval mode. Sampling stops after ``Nmax``
        steps or when the end-of-sequence state is drawn, in which case the
        step index is printed.
        """
        batch, _ = make_batch(1)
        # Eval mode makes the decoder return one step at a time.
        self.encoder.train(False)
        self.decoder.train(False)

        seq_x = []
        seq_y = []
        seq_z = []
        # No gradients are needed while sampling.
        with torch.no_grad():
            z, _, _ = self.encoder(batch, 1)
            s = torch.tensor([0., 0., 1., 0., 0.],
                             device=self.device).view(1, 1, -1)
            hidden_cell = None
            for i in range(Nmax):
                decoder_input = torch.cat([s, z.unsqueeze(0)], 2)
                self.pi, self.mu_x, self.mu_y, self.sigma_x, self.sigma_y, \
                    self.rho_xy, self.q, hidden, cell = \
                    self.decoder(decoder_input, z, hidden_cell)
                hidden_cell = (hidden, cell)
                s, dx, dy, pen_down, eos = self.sample_next_state()
                seq_x.append(dx)
                seq_y.append(dy)
                seq_z.append(pen_down)
                if eos:
                    print(i)
                    break

        # Offsets are relative, so accumulate them into absolute positions.
        x_sample = np.cumsum(seq_x, 0)
        y_sample = np.cumsum(seq_y, 0)
        z_sample = np.array(seq_z)
        sequence = np.stack([x_sample, y_sample, z_sample]).T
        make_image(sequence, epoch)

    def sample_next_state(self):
        """Sample the next stroke row from the stored decoder outputs.

        Returns:
            (state, x, y, flag, eos): ``state`` is a (1, 1, 5) tensor on
            ``self.device`` holding the sampled offsets and a one-hot pen
            state; ``x`` and ``y`` are the sampled offsets; ``flag`` is True
            when pen state 1 was drawn (callers use it to mark stroke
            breaks); ``eos`` is True when pen state 2 was drawn.
        """
        def adjust_temp(pi_pdf):
            # Sharpen or flatten the distribution by the temperature, then
            # renormalise (max is subtracted for numerical stability).
            pi_pdf = np.log(pi_pdf) / hp.temperature
            pi_pdf -= pi_pdf.max()
            pi_pdf = np.exp(pi_pdf)
            pi_pdf /= pi_pdf.sum()
            return pi_pdf

        # Choose a mixture component.
        pi = adjust_temp(self.pi.data[0, 0, :].cpu().numpy())
        pi_idx = np.random.choice(hp.M, p=pi)
        # Choose a pen state.
        q = adjust_temp(self.q.data[0, 0, :].cpu().numpy())
        q_idx = np.random.choice(3, p=q)
        # Read the chosen component's parameters as plain Python floats; a
        # NumPy view would share memory with the stored tensors on CPU and
        # could be modified by the sampler.
        mu_x = self.mu_x.data[0, 0, pi_idx].item()
        mu_y = self.mu_y.data[0, 0, pi_idx].item()
        sigma_x = self.sigma_x.data[0, 0, pi_idx].item()
        sigma_y = self.sigma_y.data[0, 0, pi_idx].item()
        rho_xy = self.rho_xy.data[0, 0, pi_idx].item()
        x, y = sample_bivariate_normal(mu_x, mu_y, sigma_x, sigma_y, rho_xy,
                                       greedy=False)

        next_state = torch.zeros(5)
        next_state[0] = x
        next_state[1] = y
        next_state[q_idx + 2] = 1
        return (next_state.to(self.device).view(1, 1, -1),
                x, y, q_idx == 1, q_idx == 2)

## Some Miscellaneous Functions

In [None]:
def sample_bivariate_normal(mu_x, mu_y, sigma_x, sigma_y, rho_xy, greedy=False,
                            temperature=None):
    """Draw one (x, y) sample from a bivariate normal distribution.

    Args:
        mu_x, mu_y: means.
        sigma_x, sigma_y: standard deviations (before temperature scaling).
        rho_xy: correlation coefficient.
        greedy: when True, return the means without sampling.
        temperature: both standard deviations are multiplied by
            ``sqrt(temperature)``. Defaults to ``hp.temperature``.

    Returns:
        (x, y): the sampled values, or ``(mu_x, mu_y)`` when ``greedy``.

    The inputs are never modified. Scaling used to be done with ``*=``, which
    writes into the caller's memory when 0-d NumPy arrays are passed (for
    example views of CPU tensors).
    """
    if greedy:
        return mu_x, mu_y
    if temperature is None:
        temperature = hp.temperature

    scale = np.sqrt(temperature)
    scaled_x = sigma_x * scale
    scaled_y = sigma_y * scale
    covariance = rho_xy * scaled_x * scaled_y
    cov = [[scaled_x * scaled_x, covariance],
           [covariance, scaled_y * scaled_y]]
    sample = np.random.multivariate_normal([mu_x, mu_y], cov, 1)
    return sample[0][0], sample[0][1]

In [None]:
def make_image(sequence, epoch, name='_output_'):
    """Plot a drawing stroke by stroke and save it as ``<epoch><name>.jpg``.

    Args:
        sequence: array with one row per point; columns 0 and 1 are the x and
            y positions and a value > 0 in column 2 ends a stroke after that
            row.
        epoch: prefix of the output file name.
        name: middle part of the output file name.

    The rendered canvas is read back through ``buffer_rgba`` when available,
    then through the older ``tostring_rgb``, and finally through a temporary
    PNG file (``<epoch><name>_temp.png``). All pyplot figures are closed
    before returning, including when rendering or saving fails.
    """
    strokes = np.split(sequence, np.where(sequence[:, 2] > 0)[0] + 1)
    fig = plt.figure()
    try:
        ax = fig.add_subplot(111)
        for stroke in strokes:
            # y is negated before plotting.
            ax.plot(stroke[:, 0], -stroke[:, 1])
        canvas = fig.canvas
        canvas.draw()
        width, height = canvas.get_width_height()

        try:
            # Newer matplotlib: RGBA buffer; drop the alpha channel.
            rgba = np.asarray(canvas.buffer_rgba())
            pil_image = PIL.Image.fromarray(rgba[:, :, :3])
        except Exception:
            try:
                # Older matplotlib: raw RGB bytes.
                pil_image = PIL.Image.frombytes(
                    'RGB', (width, height), canvas.tostring_rgb())
            except Exception:
                # Last resort: round-trip through a PNG on disk. The PNG may
                # carry an alpha channel, which JPEG cannot store, so convert
                # to RGB before saving.
                temp_path = f"{epoch}{name}_temp.png"
                fig.savefig(temp_path)
                pil_image = PIL.Image.open(temp_path).convert('RGB')

        pil_image.save(str(epoch) + name + '.jpg', "JPEG")
    finally:
        plt.close("all")

In [None]:
def save_model(model, save_dir='./saved_models', model_name='sketchrnn_apple'):
    """Write the model weights and the current hyperparameters to disk.

    Three files sharing one timestamp are created inside ``save_dir`` (which
    is created if missing): the encoder weights, the decoder weights and a
    text file with one ``name: value`` line per attribute of ``hp``.

    Args:
        model: object exposing ``encoder`` and ``decoder`` modules.
        save_dir: target directory.
        model_name: prefix of the generated file names.

    Returns:
        tuple: (encoder_path, decoder_path).
    """
    import os
    import time

    os.makedirs(save_dir, exist_ok=True)
    # One timestamp for all three files so they can be matched up later.
    stamp = time.strftime("%Y%m%d_%H%M%S")

    def _target(kind, extension):
        return os.path.join(save_dir, f"{model_name}_{kind}_{stamp}.{extension}")

    encoder_path = _target("encoder", "pth")
    decoder_path = _target("decoder", "pth")

    torch.save(model.encoder.state_dict(), encoder_path)
    torch.save(model.decoder.state_dict(), decoder_path)

    print("Model saved successfully:")
    print(f"- Encoder: {encoder_path}")
    print(f"- Decoder: {decoder_path}")

    # Record the hyperparameters next to the weights.
    config_path = _target("config", "txt")
    with open(config_path, 'w') as f:
        f.writelines(f"{param}: {value}\n" for param, value in vars(hp).items())

    print(f"- Config: {config_path}")

    return encoder_path, decoder_path

## Sketch Generation Function

In [None]:
def generate_drawing_animation(model, num_frames=None, save_path='drawing_animation.gif', temperature=0.4):
    """Sample a drawing from the model and save its progress as a GIF.

    A random drawing from the dataset is encoded, a new sequence is decoded
    step by step, and one frame is rendered for every point (or for
    ``num_frames`` evenly spaced points).

    Args:
        model: the trained model (its encoder and decoder are switched to
            eval mode and left there).
        num_frames: maximum number of frames to include (None = all).
        save_path: path of the output GIF.
        temperature: sampling temperature; ``hp.temperature`` is set to this
            value while sampling and restored afterwards, even on error.

    Returns:
        ``save_path``.

    Raises:
        ValueError: if ``num_frames`` is given and smaller than 1.
    """
    if num_frames is not None and num_frames < 1:
        raise ValueError("num_frames must be at least 1")

    model.encoder.train(False)
    model.decoder.train(False)

    batch, _ = make_batch(1)

    seq_x = []
    seq_y = []
    seq_z = []

    # Sampling reads the temperature from the shared hyperparameters, so swap
    # it in for the duration of the decode loop only.
    original_temp = hp.temperature
    hp.temperature = temperature
    try:
        # No gradients are needed while sampling.
        with torch.no_grad():
            z, _, _ = model.encoder(batch, 1)

            s = torch.Tensor([0, 0, 1, 0, 0]).view(1, 1, -1)
            if use_cuda:
                s = s.cuda()
            hidden_cell = None

            for i in range(Nmax):
                decoder_input = torch.cat([s, z.unsqueeze(0)], 2)
                model.pi, model.mu_x, model.mu_y, model.sigma_x, model.sigma_y, \
                    model.rho_xy, model.q, hidden, cell = \
                    model.decoder(decoder_input, z, hidden_cell)
                hidden_cell = (hidden, cell)
                s, dx, dy, pen_down, eos = model.sample_next_state()

                seq_x.append(dx)
                seq_y.append(dy)
                seq_z.append(pen_down)

                if eos:
                    print(f"Drawing completed in {i} steps")
                    break
    finally:
        hp.temperature = original_temp

    # Offsets are relative, so accumulate them into absolute positions.
    x_sample = np.cumsum(seq_x, 0)
    y_sample = np.cumsum(seq_y, 0)
    z_sample = np.array(seq_z)

    total_frames = len(x_sample)
    if num_frames is not None and num_frames < total_frames:
        # Keep evenly spaced points.
        indices = np.linspace(0, total_frames - 1, num_frames, dtype=int)
        x_sample = x_sample[indices]
        y_sample = y_sample[indices]
        z_sample = z_sample[indices]
        total_frames = num_frames

    frames = []
    print(f"Creating animation with {total_frames} frames...")

    fig = plt.figure(figsize=(7, 7), dpi=100)
    try:
        ax = fig.add_subplot(111)
        # Attach an Agg canvas once; savefig renders each frame through it.
        FigureCanvas(fig)

        # Fixed axis limits keep the picture from jumping between frames.
        flipped_y = -y_sample
        x_min, x_max = x_sample.min(), x_sample.max()
        y_min, y_max = flipped_y.min(), flipped_y.max()
        margin_x = (x_max - x_min) * 0.2
        margin_y = (y_max - y_min) * 0.2

        # Index just past every flagged point: where one stroke ends and the
        # next begins.
        stroke_breaks = np.flatnonzero(z_sample) + 1

        for i in range(total_frames):
            ax.clear()

            shown = i + 1
            # Stroke boundaries among the points visible in this frame.
            bounds = [0, *stroke_breaks[stroke_breaks <= shown], shown]
            for start, end in zip(bounds[:-1], bounds[1:]):
                if end > start:
                    ax.plot(x_sample[start:end], flipped_y[start:end],
                            'k-', linewidth=2)

            ax.set_xlim(x_min - margin_x, x_max + margin_x)
            ax.set_ylim(y_min - margin_y, y_max + margin_y)
            ax.axis('off')
            ax.set_title(f"Frame {i+1}/{total_frames}")

            # Render to an in-memory PNG and keep a detached copy of it.
            with io.BytesIO() as buf:
                fig.savefig(buf, format='png', bbox_inches='tight', pad_inches=0.1)
                buf.seek(0)
                frames.append(Image.open(buf).copy())

        duration = 1000 / 30  # milliseconds per frame (30 fps)
        frames[0].save(
            save_path,
            format='GIF',
            append_images=frames[1:],
            save_all=True,
            duration=duration,
            loop=0
        )
    finally:
        plt.close(fig)

    print(f"Animation saved to {save_path}")

    return save_path

## Sketch Generation by class

In [None]:
def load_model_by_class(class_name):
    """Build a Model and load the saved weights for ``class_name``.

    The weights are read from ``sketchrnn_<class_name>_encoder.pth`` and
    ``sketchrnn_<class_name>_decoder.pth``, resolved relative to the current
    working directory.

    Args:
        class_name (str): class whose weight files should be loaded, e.g.
            'cat', 'dog', 'airplane', 'apple' or 'book'.

    Returns:
        Model: a new instance with both state dicts loaded.
    """
    weight_files = [
        f"sketchrnn_{class_name}_{part}.pth" for part in ("encoder", "decoder")
    ]
    model = Model()
    model.load(*weight_files)
    return model

def generate_sketch_by_class(class_name, num_frames=None, save_path=None, temperature=0.4):
    """Load the model for ``class_name`` and render a drawing animation.

    Args:
        class_name (str): class whose weight files should be loaded, e.g.
            'cat', 'dog', 'airplane', 'apple' or 'book'.
        num_frames (int, optional): maximum number of frames in the GIF.
        save_path (str, optional): output file; defaults to
            ``<class_name>_sketch.gif``.
        temperature (float, optional): sampling temperature passed through to
            the animation function.

    Returns:
        str: the path of the saved GIF.
    """
    model = load_model_by_class(class_name)
    target = save_path if save_path is not None else f"{class_name}_sketch.gif"
    return generate_drawing_animation(
        model,
        num_frames=num_frames,
        save_path=target,
        temperature=temperature,
    )

In [None]:
# Example run: expects sketchrnn_book_encoder.pth and sketchrnn_book_decoder.pth
# in the working directory and writes book_sketch.gif.
generate_sketch_by_class('book')