# Data processing
The MobileViT model takes two input tensors with the following shapes:
```
input = {
    "pixel_values": torch.FloatTensor of shape (batch_size, num_channels, height, width),
    "labels": torch.LongTensor of shape (batch_size,)
}
```

where 
* batch_size = number of images in each batch
* num_channels = the number of color channels (e.g., RGB = 3 channels, RGBA = 4 channels)
* height = the height of the image in pixels
* width = the width of the image in pixels

For convenience, the dataset of images and labels is processed into this shape in this Jupyter Notebook, so that the C# app can simply read the input in the required format. However, since this data processing only requires translating images into numbers, this step can be done in any language, and the result can be written to any file format (JSON was chosen for convenience).

In [None]:
import datasets
import torch
from torchvision import transforms
from PIL import Image
import json

In [None]:
# fetch the DiffusionFER dataset through the Hugging Face `datasets` library
from datasets import load_dataset

dataset = load_dataset(path="FER-Universe/DiffusionFER")

In [None]:
tensor_converter = transforms.Compose([
    # dataset has varying sizes of images; resizing every image to a fixed 256x256
    # (a power of 2) to match ONNX model inputs.
    # NOTE: an explicit (height, width) pair is required here. A bare int only fixes
    # the shorter edge and keeps the aspect ratio, so non-square images would produce
    # tensors of different shapes that cannot be stacked into a single batch.
    transforms.Resize((256, 256)),
    # convert the resized image into a tensor
    transforms.ToTensor()
])

def convert_to_tensor(list_of_png, labels, *, step=15, converter=None):
    """
    Converts every `step`-th image into a tensor and stacks the results into one batch tensor

    For the sake of the demo, only a subset of the examples is kept, for a smaller dataset.

    Args:
        list_of_png: indexable sequence of images accepted by `converter`
        labels: indexable sequence with one label per image, indexed like `list_of_png`
        step: keep one example out of every `step` (1 keeps all of them)
        converter: callable turning one image into a tensor; defaults to the
            notebook-level `tensor_converter`. Every produced tensor must have the
            same shape so that they can be stacked.

    Returns a tuple (mega_tensor, new_labels): the stacked tensor of shape
    (number_of_kept_images, *single_image_shape) and the list of matching labels.
    mega_tensor is None when no image is selected (empty input).
    """
    kept_indices = range(0, len(list_of_png), step)
    if not kept_indices:
        return None, []

    if converter is None:
        converter = tensor_converter

    # stack once at the end: growing the batch with vstack on every iteration
    # re-copies everything accumulated so far
    mega_tensor = torch.stack([converter(list_of_png[i]) for i in kept_indices])
    new_labels = [labels[i] for i in kept_indices]
    return mega_tensor, new_labels

# build the image batch and its labels from the training split
images, labels = convert_to_tensor(
    list_of_png=dataset['train']['image'],
    labels=dataset['train']['label'],
)

# the labels were collected together with the images, so both entries already
# have the same length and no truncation of the label list is needed
tensor_dataset = dict(image=images, label=labels)

In [None]:
def generate_json_dict(tensor_dict, keys_tensors, keys_1d):
    """
    Takes in a dictionary where the values are tensors (or plain 1d sequences)

    Basically changes each value into two fields: a shape & a flattened list, for easier conversion to OnnxValues

    Args:
        tensor_dict: dictionary holding the data to export
        keys_tensors: keys of tensor_dict whose values are multi-dimensional tensors;
            each is stored as a flattened list under the key and its shape under key + "_shape"
        keys_1d: keys of tensor_dict whose values are already one-dimensional
            (a list or a 1d tensor); each is stored as a list under the key and
            [length] under key + "_shape"

    Returns a dictionary that only contains JSON-serializable lists
    """
    json_dict = {}

    for key_name in keys_tensors:
        tensor = tensor_dict[key_name]
        # add field for the shape of the tensor
        json_dict[key_name + "_shape"] = list(tensor.shape)
        # flatten list
        json_dict[key_name] = torch.flatten(tensor).tolist()

    for key_name in keys_1d:
        values = tensor_dict[key_name]
        # tensors are not JSON-serializable, so turn them into plain lists;
        # other sequences are copied so the result does not alias the input
        values = values.tolist() if hasattr(values, "tolist") else list(values)
        # add field for the shape of the 1d data
        json_dict[key_name + "_shape"] = [len(values)]
        json_dict[key_name] = values

    return json_dict

json_dict = generate_json_dict(tensor_dataset, ['image'], ['label'])

In [None]:
# sanity check: display the shape of the stacked image tensor
tensor_dataset['image'].shape

In [None]:

# persist the flattened dataset as JSON so it can be read back in the required format
with open('mini_train.json', mode='w') as out_file:
    json.dump(json_dict, fp=out_file)