# Sketchy

In [None]:
import matplotlib.pyplot as plt
from sketchy.sketchy_dataset import SketchyDataset
from torch.utils.data import DataLoader
from tqdm import tqdm

In [None]:
# ---- Dataset configuration --------------------------------------------------
# Location and split of the Sketchy data; replace the placeholder with a real path.
dataset_root = "<path/to/sketchy/root>"
split = "train"
img_size = 512

# Which parts of each sample the dataset should load.
load_img = True
load_global_sketch = True
load_local_sketches = True

# Remaining SketchyDataset options, forwarded as-is.
with_shoes = False
concat_locals = True
compose_global_sketch = True

# Optional transforms; both are left unset (None) in this notebook.
img_transforms = None
global_sketch_transforms = None

# Collect everything into keyword arguments for the dataset constructor.
# NOTE: the constructor keyword is `load_local_sketch` (singular), while the
# notebook variable is `load_local_sketches`.
dataset_kwargs = {
    "dataset_root": dataset_root,
    "split": split,
    "img_size": img_size,
    "load_img": load_img,
    "load_global_sketch": load_global_sketch,
    "load_local_sketch": load_local_sketches,
    "img_transforms": img_transforms,
    "global_sketch_transforms": global_sketch_transforms,
    "with_shoes": with_shoes,
    "concat_locals": concat_locals,
    "compose_global_sketch": compose_global_sketch,
}
sketchy_dataset = SketchyDataset(**dataset_kwargs)
print(f"Number of images in {split} split: {len(sketchy_dataset)}")

In [None]:
# Wrap the dataset in a DataLoader. The dataset supplies its own collate_fn;
# per the notes in the batch cell, it turns each batch field into a list with
# one entry per item (presumably because items carry variable-length lists).
loader_options = {"batch_size": 8, "shuffle": False, "num_workers": 0}
dataloader = DataLoader(
    sketchy_dataset,
    collate_fn=sketchy_dataset.collate_fn,
    **loader_options,
)

## Visualize the item data

In [None]:
# Inspect a single dataset item directly (no collate function involved).
# SAMPLE_IDX must be smaller than len(sketchy_dataset).
SAMPLE_IDX = 13
item = sketchy_dataset[SAMPLE_IDX]
print("####### ITEM KEYS ########")
for key in item:
    print(f"{key}")

print("\n####### IMAGE ########")
# item['image'] is an image (by default PIL.Image)
plt.imshow(item['image'])
plt.axis('off')
plt.title("GT Image of item")
plt.show()


print("\n####### LOCAL DESCRIPTIONS ########")
# item['local_descriptions'] is a list of strings. Each string is a description of a single item in the image.
# NOTE: the local descriptions, local sketches, and masks are all aligned, meaning that the i-th local description corresponds to the i-th mask and i-th local sketch.
num_descriptions = len(item['local_descriptions'])
print(f"Number of local descriptions in item: {num_descriptions}")
for i, desc in enumerate(item['local_descriptions']):
    print(f"Local description {i}: {desc}")

print("\n####### GLOBAL SKETCH ########")
# item['global_sketch'] is an image
# visualize the global sketch
plt.imshow(item['global_sketch'])
plt.axis('off')
plt.title("Global Sketch")
plt.show()

print("\n####### LOCAL SKETCHES ########")
# item['local_sketches'] is a list of images, one per local sketch in the item,
# aligned index-by-index with item['local_descriptions'].
num_local_sketches = len(item['local_sketches'])
assert num_local_sketches == num_descriptions, "Number of local sketches will always be equal to number of local descriptions"
print(f"Number of local sketches in item: {num_local_sketches}")
# visualize the local sketches in a grid with at most MAX_NUM_COLUMNS columns
MAX_NUM_COLUMNS = 2
if num_local_sketches == 0:
    # Nothing to draw; this also avoids a division by zero columns below.
    print("Item has no local sketches to visualize.")
else:
    num_cols = min(num_local_sketches, MAX_NUM_COLUMNS)
    num_rows = -(-num_local_sketches // num_cols)  # ceiling division
    # squeeze=False always returns a 2D array of Axes, so one code path handles
    # a single sketch as well as many.
    fig, axs = plt.subplots(num_rows, num_cols, figsize=(5, max(5, 2.5 * num_rows)), squeeze=False)
    axs = axs.flatten()
    for i, (ax, sketch) in enumerate(zip(axs, item['local_sketches'])):
        ax.imshow(sketch)
        ax.set_title(f"Local Sketch {i}")
        ax.axis('off')
    # Hide grid cells left empty when the count is not a multiple of num_cols.
    for ax in axs[num_local_sketches:]:
        ax.axis('off')
    plt.show()

In [None]:
# Iterate over the dataloader.
# NOTE: the collate function changes how the data is structured: the fields
# shown below become lists with one entry per item in the batch. This is needed
# for batching the data.
# With VISUALIZE_FIRST_BATCH = False the loop simply runs through every batch
# (useful to check that loading works end to end). Set it to True to print and
# plot item 0 of the first batch; the loop then stops after that batch.
VISUALIZE_FIRST_BATCH = False
for batch in tqdm(dataloader, desc="Iterating over batches"):
    if not VISUALIZE_FIRST_BATCH:
        continue

    print("####### BATCH INFO ########")
    # every batch is a dictionary with the following keys:
    for key in batch:
        print(f"{key}")

    print("\n####### IMAGE ########")
    # batch['image'] is a list of images, one for each item in the batch
    plt.imshow(batch['image'][0])
    plt.axis('off')
    plt.title("GT Image of item 0 in batch")
    plt.show()


    print("####### LOCAL DESCRIPTIONS ########")
    # batch['local_descriptions'] is a list of lists of strings. Each item in the batch has a list. In each sublist, there is a description for each item in the image.
    num_descriptions = len(batch['local_descriptions'][0])
    print(f"Number of local descriptions in item 0: {num_descriptions}")
    for i, desc in enumerate(batch['local_descriptions'][0]):
        print(f"Local description {i}: {desc}")


    print("####### GLOBAL SKETCH ########")
    # batch['global_sketch'] is a list of images, one for each item in the batch
    # visualize the global sketch
    plt.imshow(batch['global_sketch'][0])
    plt.axis('off')
    plt.title("Global Sketch of item 0 in batch")
    plt.show()


    print("####### LOCAL SKETCHES ########")
    # batch['local_sketches'] is a list of lists of images. Each item in the batch has a list. In each sublist, there is an image for each local sketch in the item.
    item_sketches = batch['local_sketches'][0]
    num_local_sketches = len(item_sketches)
    assert num_local_sketches == num_descriptions, "Number of local sketches will always be equal to number of local descriptions"
    print(f"Number of local sketches in item 0: {num_local_sketches}")
    # visualize the local sketches in a grid with at most MAX_NUM_COLUMNS columns
    MAX_NUM_COLUMNS = 2
    if num_local_sketches == 0:
        # Nothing to draw; this also avoids a division by zero columns below.
        print("Item 0 has no local sketches to visualize.")
    else:
        num_cols = min(num_local_sketches, MAX_NUM_COLUMNS)
        num_rows = -(-num_local_sketches // num_cols)  # ceiling division
        # squeeze=False always returns a 2D array of Axes, so one code path
        # handles a single sketch as well as many.
        fig, axs = plt.subplots(num_rows, num_cols, figsize=(5, max(5, 2.5 * num_rows)), squeeze=False)
        axs = axs.flatten()
        for i, (ax, sketch) in enumerate(zip(axs, item_sketches)):
            ax.imshow(sketch)
            ax.set_title(f"Local Sketch {i}")
            ax.axis('off')
        # Hide grid cells left empty when the count is not a multiple of num_cols.
        for ax in axs[num_local_sketches:]:
            ax.axis('off')
        plt.show()
    break  # remove this to visualize every batch