# Dataset declaration
The custom dataset loads the BRUSH samples into memory; each signal represents one sentence.

In [None]:
import os
from torch.utils.data import Dataset
import numpy as np
import cv2
import math
import pickle

class BrushDataset(Dataset):
    """In-memory dataset of BRUSH handwriting samples.

    Each item is a pair ``(image, signal)``:
    Sample  = Offline image (2-D ``uint8`` canvas rendered from the signal)
    Label   = Online signal (rows of ``x, y, pen_up``), as stored on disk

    Only the online signals are read from the dataset files. The offline image of a
    drawing is rendered from its signal while loading, unless a cached
    ``<drawing_id>.npy`` file already sits next to the signal file, in which case the
    cached array is used. Freshly rendered images are written to that cache when
    ``save_to_file`` is true.
    """
    brush_root: str
    display_stats: bool
    save_to_file: bool

    # NOTE: the canvas is filled with DRAW_COLOR_BLACK (255) and strokes are drawn with
    # DRAW_COLOR_WHITE (0). Under the usual uint8 grayscale convention (0 = black,
    # 255 = white) the names are swapped with respect to what is rendered; the values
    # are kept as they are so that existing cached images stay consistent.
    DRAW_COLOR_WHITE = 0
    DRAW_COLOR_BLACK = 255
    # Stroke thickness, in pixels, passed to cv2.line.
    DRAW_COLOR_SIZE = 1

    signals: list[list]
    sentences: list[str]
    images: list

    size: int

    def __init__(self, brush_root, display_stats: bool = False, save_to_file: bool = True):
        """Load every sample found under ``brush_root``.

        Args
        -----
            brush_root: Folder containing one sub-folder per writer
            display_stats: Print progress information while loading
            save_to_file: Cache every freshly rendered image as ``<drawing_id>.npy``
        """
        super().__init__()
        self.brush_root = brush_root
        self.display_stats = display_stats
        self.save_to_file = save_to_file

        self.signals = []
        self.sentences = []
        self.images = []
        self.size = 0

        self.load_samples()

    #Override
    def __len__(self):
        return len(self.images)

    #Override
    def __getitem__(self, idx):
        """Return the ``(image, signal)`` pair stored at position ``idx``.

        Raises
        -----
            IndexError: If ``idx`` is past the end of the dataset. IndexError (rather
            than a bare Exception) is what lets ``for item in dataset`` terminate.
        """
        if idx >= self.size:
            raise IndexError(f"Invalid index: Dataset of size {self.size} has no item at index {idx}")
        return self.images[idx], self.signals[idx]

    def load_samples(self):
        """Load all samples from disk, rendering the offline images that are not cached yet."""
        try:
            entries = os.listdir(self.brush_root)
        except OSError:
            print(f"Impossible to read folder {self.brush_root}")
            raise

        # Every writer is a sub-folder of the root; stray files are skipped instead of
        # crashing the os.listdir call made on each writer below.
        writer_ids = [
            entry for entry in entries
            if os.path.isdir(os.path.join(self.brush_root, entry))
        ]
        total_writers = len(writer_ids)
        if self.display_stats:
            print(f"Loading {total_writers} writers")

        for i, writer_id in enumerate(writer_ids, start=1):
            writer_path = os.path.join(self.brush_root, writer_id)
            #Each drawing is present in three variants: n, n_resample20 and n_resample25
            #base dataloader selects default (10ms)
            #The ".npy" test additionally excludes the cached images stored alongside.
            drawing_ids = [name for name in os.listdir(writer_path) if "_" not in name and ".npy" not in name]

            if self.display_stats:
                print(f"{i}/{total_writers}: Detected {len(drawing_ids)} drawings")

            for drawing_id in drawing_ids:
                signal_path = os.path.join(writer_path, drawing_id)
                image_path = os.path.join(writer_path, f"{drawing_id}.npy")

                # The per-point character labels are not kept by this dataset.
                sentence, signal, _ = self.load_signal(signal_path)
                self.signals.append(signal)
                self.sentences.append(sentence)

                if os.path.isfile(image_path):
                    with open(image_path, "rb") as f:
                        image = np.load(f)
                else:
                    image = self.create_image(signal)
                    if self.save_to_file:
                        with open(image_path, "wb") as f:
                            np.save(f, image)

                # Store the image obtained above, so a cached file is actually used
                # instead of being rendered a second time.
                self.images.append(image)

        self.size = len(self.images)
        if self.display_stats:
            print(f"Loaded {self.size} data points")

    def load_signal(self, filepath: str) -> tuple[str, list, list]:
        """Load an online signal from a filepath
        Args
        -----
            filepath: The name of the file to load from

        Returns
        -----
            - str: Written sentence as string
            - list: Signal of x, y, penUp
            - list: List of one-hot vectors with same length as signal identifying character of point
        """
        # SECURITY: pickle.load can execute arbitrary code while unpickling; only point
        # this dataset at files coming from a trusted source.
        with open(filepath, 'rb') as f:
            [sentence, signal, label] = pickle.load(f)

        return sentence, signal, label

    def create_image(self, signal: "np.ndarray") -> "np.ndarray":
        """Render the offline image associated with the given online signal.

        Args
        -----
            signal: 2-D array indexed as ``signal[:, 0]`` (x) and ``signal[:, 1]`` (y),
            whose rows unpack as ``x, y, pen_up``

        Returns
        -----
            - np.ndarray: ``uint8`` canvas of shape ``(ceil(max y), ceil(max x))``
        """
        width = int(math.ceil(max(signal[:, 0])))
        height = int(math.ceil(max(signal[:, 1])))

        # NOTE: a point whose coordinate equals an integer maximum lands one pixel past
        # the last row/column of this canvas. The shape is kept as is so that images
        # match the ones already cached on disk.
        canvas = np.full((height, width), self.DRAW_COLOR_BLACK, dtype=np.uint8)

        #Draw lines from point (t-1) to current point (t) IFF the pen was not up at (t-1).
        #There is no previous point before point 0, so nothing is drawn for it.
        previous = None
        for x, y, eos in signal:
            point = (int(x), int(y))
            if previous is not None:
                cv2.line(canvas, previous, point, self.DRAW_COLOR_WHITE, self.DRAW_COLOR_SIZE)
            # A pen-up point ends the stroke: the next point starts a new one.
            previous = None if eos else point

        return canvas

In [None]:
def del_all_imgs(brush_root: str = "../../data/handwriting/refined_BRUSH/BRUSH"):
    """Delete every ``.npy`` file stored in the writer folders of a BRUSH dataset.

    Args
    -----
        brush_root: Folder containing one sub-folder per writer. Defaults to the
        dataset location used by this notebook.

    The path of each file is printed just before it is removed.
    """
    for writer_id in os.listdir(brush_root):
        writer_path = os.path.join(brush_root, writer_id)
        # Stray files at the root are not writer folders and cannot be listed.
        if not os.path.isdir(writer_path):
            continue

        for name in os.listdir(writer_path):
            # Match on the extension only, so that a file that merely contains
            # ".npy" somewhere in its name is left alone.
            if not name.endswith(".npy"):
                continue
            full_path = os.path.join(writer_path, name)
            print(full_path)
            os.remove(full_path)

# Remove the .npy image files from the dataset folders, so that the next dataset load
# has to render the images from the signals again.
del_all_imgs()

# Test dataset

In [27]:
import time

BRUSH_ROOT = "../../data/handwriting/refined_BRUSH/BRUSH"

# Timing run with save_to_file=False: images that have no cached .npy file are rendered
# from their signal and nothing is written back to disk.
# perf_counter is a monotonic clock, so the measurement is not affected by wall-clock
# adjustments happening during the (long) load.
start_time = time.perf_counter()
dataset = BrushDataset(brush_root=BRUSH_ROOT, display_stats=True, save_to_file=False)
print("--- %s seconds ---" % (time.perf_counter() - start_time))

print(f"Size of dataset: {dataset.size}")

Loading 170 writers
1/170: Detected 163 drawings
2/170: Detected 163 drawings
3/170: Detected 162 drawings
4/170: Detected 162 drawings
5/170: Detected 164 drawings
6/170: Detected 160 drawings
7/170: Detected 164 drawings
8/170: Detected 164 drawings
9/170: Detected 162 drawings
10/170: Detected 159 drawings
11/170: Detected 159 drawings
12/170: Detected 155 drawings
13/170: Detected 165 drawings
14/170: Detected 165 drawings
15/170: Detected 164 drawings
16/170: Detected 164 drawings
17/170: Detected 163 drawings
18/170: Detected 163 drawings
19/170: Detected 165 drawings
20/170: Detected 166 drawings
21/170: Detected 164 drawings
22/170: Detected 166 drawings
23/170: Detected 163 drawings
24/170: Detected 162 drawings
25/170: Detected 163 drawings
26/170: Detected 165 drawings
27/170: Detected 165 drawings
28/170: Detected 165 drawings
29/170: Detected 165 drawings
30/170: Detected 165 drawings
31/170: Detected 165 drawings
32/170: Detected 161 drawings
33/170: Detected 165 drawings

In [None]:
import time

BRUSH_ROOT = "../../data/handwriting/refined_BRUSH/BRUSH"

# Timing run with save_to_file=True: images that have no cached .npy file are rendered
# from their signal and then saved next to it.
# perf_counter is a monotonic clock, so the measurement is not affected by wall-clock
# adjustments happening during the (long) load.
start_time = time.perf_counter()
dataset = BrushDataset(brush_root=BRUSH_ROOT, display_stats=True, save_to_file=True)
print("--- %s seconds ---" % (time.perf_counter() - start_time))

print(f"Size of dataset: {dataset.size}")

In [None]:
import time

BRUSH_ROOT = "../../data/handwriting/refined_BRUSH/BRUSH"

# Timing run with save_to_file=False, meant to be executed once the .npy files exist
# (e.g. after a save_to_file=True run): the cached images are then read back from disk.
# perf_counter is a monotonic clock, so the measurement is not affected by wall-clock
# adjustments happening during the (long) load.
start_time = time.perf_counter()
dataset = BrushDataset(brush_root=BRUSH_ROOT, display_stats=True, save_to_file=False)
print("--- %s seconds ---" % (time.perf_counter() - start_time))

print(f"Size of dataset: {dataset.size}")

Results of dataset creation
- Building images from signals:
- Building and saving images to file:
- Loading from file:

In [None]:
import random
import matplotlib.pyplot as plt

# Pick a random valid index. randrange excludes its upper bound, whereas
# randint(0, dataset.size) could return dataset.size itself, which is not a valid index.
n = random.randrange(dataset.size)

image, signal = dataset[n]

print(f"Sentence: {dataset.sentences[n]}")

# Display the offline image of the selected sample.
plt.imshow(image)
plt.show()