## Homework 3: DCGAN for Flowers102

In [None]:
import torch
from torchvision import datasets, transforms
import scipy as sp

import json

## Flowers102 dataset import

In [None]:
# Flowers102 images are 500 px along their shorter side (the longer side is
# >= 500), so a 500x500 center crop is the largest square crop available.
IMAGE_CROP = 500
# side length of the square images, as fixed by the problem description
IMAGE_SIZE = 256

# per-channel statistics for Normalize: (x - 0.5) / 0.5 maps the [0, 1]
# output of ToTensor onto [-1, 1]
_CHANNEL_MEAN = (0.5, 0.5, 0.5)
_CHANNEL_STD = (0.5, 0.5, 0.5)

# preprocessing applied to every image:
#   square center crop -> downscale to IMAGE_SIZE -> tensor -> rescale to [-1, 1]
transform = transforms.Compose(
    [
        transforms.CenterCrop(IMAGE_CROP),
        transforms.Resize(IMAGE_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(_CHANNEL_MEAN, _CHANNEL_STD),
    ]
)


def _load_flowers_split(split):
    """Return the Flowers102 ``split`` under ./data with ``transform`` applied.

    ``download=True`` lets torchvision fetch the archive into ``root`` when it
    is not already present.
    """
    return datasets.Flowers102(
        root="data",
        split=split,
        download=True,
        transform=transform,
    )


# load the three official splits in their canonical order
train_dataset, val_dataset, test_dataset = (
    _load_flowers_split(split_name) for split_name in ("train", "val", "test")
)

# pool every split together: images are to be drawn from the whole dataset
flower_dataset = torch.utils.data.ConcatDataset(
    [train_dataset, val_dataset, test_dataset]
)

print('Total number of image samples in dataset:', len(flower_dataset))

## Flower category dictionary

In [None]:
# Load the 'label : flower category' mapping.
# NOTE: dataset class labels are 0-based (0..101), whereas the JSON keys are
# 1-based labels stored as strings ("1".."102"), hence the +1 / -1 shifts below.
# NOTE(review): this JSON is not among the files torchvision's Flowers102
# downloads -- presumably it is supplied separately; confirm it exists.
# JSON text is UTF-8 by specification, so decode it explicitly instead of
# relying on the platform's locale encoding.
with open('./data/flowers-102/flower-categories.json', 'r', encoding='utf-8') as f:
    label_to_flowername = json.load(f)

# example lookup: 0-based class label -> flower name
classlabel = 0
flowername = label_to_flowername[str(classlabel + 1)]
print(classlabel, ':', flowername)

# inverse mapping: flower name -> 1-based label (kept as the JSON key string).
# If a name appeared under several labels, the last one iterated would win.
flowername_to_label = {name: label for label, name in label_to_flowername.items()}

# example lookup: flower name -> 0-based class label
flowername = 'lotus'
classlabel = int(flowername_to_label[flowername]) - 1
print(flowername, ':', classlabel)