# Image data preprocessing

For the image data we use a waste recycling plant dataset, which can be found on [Kaggle](https://www.kaggle.com/datasets/parohod/warp-waste-recycling-plant-dataset). There are three different versions of this dataset. We are using WaRP-C, in which each image is a cutout of a single waste object.
In this notebook we are going to process our dataset so that it can be used for machine learning classification. The first step is to import the libraries. Below the imports you will also find a function which takes a dataset as input and plots 20 random images.

In [164]:
import os
from PIL import Image
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import torch
from torchvision.transforms import v2
from torch.utils.data import WeightedRandomSampler, Dataset, DataLoader

In [165]:
def plot_20_images(dataset):
    df_sample = dataset.sample(n=20)
    print(df_sample)
    fig, axes = plt.subplots(nrows=4, ncols=5, figsize=(10,10))
    # axes.flat walks the 4x5 grid row by row, one subplot per sampled image
    for ax, (index, row) in zip(axes.flat, df_sample.iterrows()):
        ax.imshow(row["image"])
    return fig, axes

## Loading the image dataset
As before, the first step is to load our image dataset into memory. We will again use PIL Image to do this. However, this time we only store the label, the file path and the image itself in the pandas dataframe.

In [None]:
count = 0
images = pd.DataFrame(columns=["label", "path", "image"])
directory = 'datasets/Warp-C/'
for root_dir, cur_dir, files in os.walk(directory):
    print("root dir: " + str(root_dir))
    label = os.path.basename(os.path.normpath(root_dir))
    for file in files:
        if file.endswith(".jpg"):
            file_name = os.path.join(root_dir, file)
            count += 1
            image = Image.open(file_name)
            # Keep the values in the same order as the dataframe columns
            row = [label, file_name, image]
            images.loc[len(images)] = row

print("file count: " + str(count))
print(images)

root dir: datasets/Warp-C/
root dir: datasets/Warp-C/bottle-blue
root dir: datasets/Warp-C/bottle-blue-full
root dir: datasets/Warp-C/bottle-blue5l
root dir: datasets/Warp-C/bottle-blue5l-full
root dir: datasets/Warp-C/bottle-dark
root dir: datasets/Warp-C/bottle-dark-full
root dir: datasets/Warp-C/bottle-green
root dir: datasets/Warp-C/bottle-green-full
root dir: datasets/Warp-C/bottle-milk
root dir: datasets/Warp-C/bottle-milk-full
root dir: datasets/Warp-C/bottle-multicolor
root dir: datasets/Warp-C/bottle-multicolor-full
root dir: datasets/Warp-C/bottle-oil
root dir: datasets/Warp-C/bottle-oil-full
root dir: datasets/Warp-C/bottle-transp
root dir: datasets/Warp-C/bottle-transp-full
root dir: datasets/Warp-C/bottle-yogurt
root dir: datasets/Warp-C/canister
root dir: datasets/Warp-C/cans
root dir: datasets/Warp-C/cardboard-juice
root dir: datasets/Warp-C/cardboard-milk
root dir: datasets/Warp-C/detergent-box
root dir: datasets/Warp-C/detergent-color
root dir: datasets/Warp-C/detergen

Before we start processing our data we need to have a function which saves all processed images as a new dataset. Do this using the following code:

    save image:
    PIL_image.save("path/to/file.png")

    iterate over pandas dataframe:
    for idx, row in result.iterrows():

In [179]:
def save_dataset(dataframe):
    return None
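
One possible way to fill in this function is sketched below. It is not the reference solution. The extra `output_dir` parameter and the `<output_dir>/<label>/<name>.png` layout are assumptions made for this example. The sketch also assumes that the `label` column still holds the class name as a string.

```python
import os

from PIL import Image
import pandas as pd


def save_dataset(dataframe, output_dir):
    """Save every image as <output_dir>/<label>/<original file name>.png.

    output_dir and the folder layout are assumptions for this sketch.
    """
    saved_paths = []
    for idx, row in dataframe.iterrows():
        # One sub-folder per class, mirroring the layout of the source dataset
        label_dir = os.path.join(output_dir, str(row["label"]))
        os.makedirs(label_dir, exist_ok=True)
        # Reuse the original file name, but store the result as PNG
        stem = os.path.splitext(os.path.basename(row["path"]))[0]
        target = os.path.join(label_dir, stem + ".png")
        row["image"].save(target)
        saved_paths.append(target)
    return saved_paths
```

Returning the list of written paths makes it easy to check afterwards which files were created.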

## Data transformation

### Resizing your images
The first step for images is to resize them. A machine learning model has a fixed input size, which means that all input images need to have the same dimensions.

To resize an image you can use the following PIL Image function:
<code>resized_image = image.resize((size_width, size_height))</code>

You will also have to work with a lambda expression. This means that you define a function (here a "resize" function) which you then pass to the "apply" function of your dataframe column, so that the function is applied to the image in every row.

<code>
dataframe["image"] = dataframe["image"].apply(lambda img:function(inputs))
</code>
</br>
You can use the plot_20_images(dataset) function to look at your results.

Don't forget to save your image dataset with DVC.

In [None]:
def resize(image):

    return resized_image
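
A minimal sketch of how the resize step could look, using a tiny stand-in dataframe instead of the real dataset. The target size of 224 x 224 and the `size` parameter are assumptions. Choose the dimensions your model expects.

```python
import pandas as pd
from PIL import Image

TARGET_SIZE = (224, 224)  # assumed (width, height)


def resize(image, size=TARGET_SIZE):
    """Return a copy of the PIL image scaled to size (aspect ratio is not preserved)."""
    return image.resize(size)


# Tiny stand-in dataframe with images of different sizes
sizes = {"cans": (120, 80), "canister": (64, 200)}
images = pd.DataFrame({"label": ["cans", "canister"]})
images["image"] = images["label"].map(lambda name: Image.new("RGB", sizes[name]))

# Apply the resize function to the image in every row
images["image"] = images["image"].apply(lambda img: resize(img))
print([img.size for img in images["image"]])  # [(224, 224), (224, 224)]
```

Note that `image.resize` stretches the image to the requested size. If distortion is a problem for your objects, padding or cropping to a square first is an alternative.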

### Data augmentation

When creating your image dataset, you might have taken a certain amount of pictures with a specific lighting setup or a specific camera angle. However, most of the time you want your model to be robust against different lighting setups, angles and noise. To help with this, you can use data augmentation and transform techniques.

Transformations allow you to change your original image in a specific way. This can include things like cropping, changing colors, adding noise, rotating or flipping the image, and many more. PyTorch (through torchvision) has an easy way to apply a chain of transformations to an image:

    transforms = v2.Compose([

        v2.transformation1(),
        v2.transformation2(),
        v2.transformation3(),
    ])

    image = transforms(image)


You can find a list of the possible transformations [here](https://docs.pytorch.org/vision/stable/transforms.html). Try to apply at least 3 different transformations to the dataset. Do this by creating a transform() function and applying it to all images in the dataframe using a lambda expression. If you want to create a larger dataset, you can add the transformed images to the original dataset.

In [180]:
def transform(image): 
    
    return image_transformed

### Normalizing your pixel values

Normally your data is saved as three matrices containing values between 0 and 255 (RGB values). However, to create a consistent scale in your dataset it is often a good idea to normalize these values. We will use transform functions from torchvision. First we need to convert our PIL image to a tensor image consisting of floats, which also rescales the values to between 0 and 1 (this conversion is also necessary for certain other transformations). The Normalize transform then uses the mean and standard deviation to calculate the normalized values: it subtracts the mean and divides by the standard deviation per channel, so the result is centered around 0 rather than limited to the range 0 to 1. Therefore we calculate the mean and standard deviation of the tensor image and, lastly, transform it with the Normalize function. To apply this to all images, you can again use a lambda function. You will need the following code:

<code>transform = v2.Compose([<i>list, of, transformations</i>])</code></br>

    image_tensor = transform(image)
    
    mean = image_tensor.mean([1,2])
    std = image_tensor.std([1,2])

    change int to float: v2.ToDtype(torch.float32, scale=True)
    normalize: v2.Normalize(mean, std)
    convert PIL to tensor: v2.ToImage()
    Convert tensor to PIL: v2.ToPILImage()

Don't forget to save your dataset with DVC.

In [181]:
def normalize_image(image):

    return normalized_image

### Dealing with class imbalance

In the original dataset we checked the number of samples per class. This is done to find out if there is class imbalance. If one class has a lot of data samples and another has very few, there might be problems during training. For example, you might find that the model learns to always return the overrepresented class instead of actually learning something of value. This problem can be solved when creating the dataloader, which will feed your data to the model during training.

There are several ways to deal with class imbalance. These include oversampling (adding some samples from the underrepresented class multiple times), undersampling (leaving out some samples from the overrepresented class) and class weighting (sampling the data with weights representing the class imbalance). These techniques are sometimes applied before training, like we will do here in the dataloader. However, you can also pass class weights to certain loss functions in PyTorch to deal with class imbalance.
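
To illustrate the loss-function alternative mentioned above, here is a small sketch using the `weight` argument of `nn.CrossEntropyLoss`. The class counts are made up for the example. Note that this loss expects class indices as targets, not the one-hot vectors used later in this notebook.

```python
import torch
from torch import nn

# Hypothetical counts for two classes: 8 samples versus 2 samples
class_counts = torch.tensor([8.0, 2.0])
class_weights = 1.0 / class_counts
class_weights = class_weights / class_weights.sum()  # rescale so the weights sum to 1

# Mistakes on the rare class now contribute more to the loss
criterion = nn.CrossEntropyLoss(weight=class_weights)

logits = torch.zeros(2, 2)       # two samples with equal scores for both classes
targets = torch.tensor([0, 1])   # class indices
loss = criterion(logits, targets)
```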

In this exercise we are going to use a weighted random sampler in a dataloader. This means that we will create a torch dataset and dataloader and add the sampler to the dataloader. That way the data that the model gets during training will be sampled using weights that reflect the class imbalance.

First, you will need to calculate the class weights. The class weights are calculated as follows: $\frac{1}{\text{class count}}$. This means that we first need to count the number of times each class is represented. Do this for the image dataset.

    count = dataset["column"].value_counts().to_list()

When you have your class weights, they can be used in our WeightedRandomSampler. The sampler expects one weight per sample, so every sample gets the weight of its class. The weights also need to be converted to a tensor.

    samples_weights = torch.from_numpy(np.array(weights))
    sampler = WeightedRandomSampler(samples_weights, len(samples_weights))
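
The sketch below walks through these steps on a small made-up label column: count the classes, turn the counts into class weights, and look up the weight of each sample's class. The label values are placeholders. On the real data you would use the `label` column of the image dataframe.

```python
import pandas as pd
import torch
from torch.utils.data import WeightedRandomSampler

# Made-up, imbalanced label column: 8 samples of one class, 2 of another
labels = pd.Series(["cans"] * 8 + ["canister"] * 2, name="label")

class_counts = labels.value_counts()   # number of samples per class
class_weights = 1.0 / class_counts     # weight = 1 / class count

# One weight per sample: look up the weight of each sample's class
weights = labels.map(class_weights).to_numpy()
samples_weights = torch.from_numpy(weights)

# Draw as many samples per epoch as the dataset has, with replacement
sampler = WeightedRandomSampler(samples_weights, num_samples=len(samples_weights), replacement=True)
```

With these weights the total weight of each class is the same, so on average both classes are drawn equally often, even though one of them has four times as many samples.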

Lastly we create a torch Dataset and DataLoader from our pandas dataframe. To represent your labels as a one-hot encoding we created a function for you. It changes the labels from a unique string label to something that the machine learning model will understand: a vector of zeroes and ones. We also made a custom torch Dataset to use in the DataLoader, since we store PIL images instead of tensors.

In [175]:
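def create_onehotlabels(images_normal):
    # One column per class; each row is True only in the column of its own class
    result = pd.get_dummies(images_normal["label"])
    onehotlist = []
    for idx, row in result.iterrows():
        # Keep all class columns so every class gets its own position in the vector
        boollist = row.tolist()
        onehotlist.append([int(value) for value in boollist])

    images_normal["label"] = onehotlist

    return images_normal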

In [None]:
class ImageLabelDataset(Dataset):
    def __init__(self, dataframe):
        self.dataframe = dataframe
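        # v2.ToTensor() is deprecated; this pair converts a PIL image to a float tensor in [0, 1]
        self.transformer = v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True)])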

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        image = self.transformer(self.dataframe.iloc[idx]['image'])
        label = self.dataframe.iloc[idx]['label']
        label = torch.tensor(label, dtype=torch.float32)

        return image, label


Use all previously defined functions and classes to create a dataloader which will load the image and its one-hot encoded label:

    dataset = ImageLabelDataset(images)
    dataloader = DataLoader(dataset, batch_size=batch_size, sampler=sampler)