In [None]:
!pip install numpy==1.22.0

### Investigating augmentation

Since we decided to use data augmentation, we ran this brief study to identify suitable augmentation parameters.

Imports

In [None]:
import numpy as np  
from matplotlib import pyplot as plt  
from keras.preprocessing.image import ImageDataGenerator  
from sklearn.model_selection import train_test_split  
from sklearn.preprocessing import MinMaxScaler

Util functions

In [None]:
def normalize_0_255(tensor_4d_data, per_image=True):
    """Min-max rescale a batch of images into the uint8 range [0, 255].

    Args:
        tensor_4d_data: array-like of shape (num_images, height, width, channels).
        per_image: if True (default), every image is stretched independently by
            its own min and max. If False, a single global min/max over the whole
            batch is used. Min-max stretching is invariant to multiplying an image
            by a positive constant, so per-image mode hides uniform brightness
            scaling; use per_image=False when brightness differences between
            images should stay visible.

    Returns:
        np.uint8 array with the same shape as the input. Images (or, in global
        mode, a batch) with a constant value map to 0.

    Raises:
        ValueError: if the input is not 4-dimensional.
    """
    data = np.asarray(tensor_4d_data)
    if data.ndim != 4:
        raise ValueError(
            "expected a 4-D array (num_images, height, width, channels), "
            f"got shape {data.shape}"
        )
    if data.size == 0:
        return np.zeros(data.shape, dtype=np.uint8)
    # Integer input (e.g. uint8) must be promoted before subtracting/dividing;
    # float input keeps its precision (float32 stays float32 to save memory).
    if not np.issubdtype(data.dtype, np.floating):
        data = data.astype(np.float64)

    # Gotcha: sklearn's MinMaxScaler on a (num_images, num_pixels) reshape scales
    # each *column*, i.e. each pixel position across the batch, not each image.
    # Reducing over (height, width, channels) gives one min/max per image.
    reduce_axes = (1, 2, 3) if per_image else None
    lo = data.min(axis=reduce_axes, keepdims=True)
    hi = data.max(axis=reduce_axes, keepdims=True)
    span = hi - lo
    # Constant images have zero span; divide by 1 so they map to 0, not NaN.
    span[span == 0] = 1

    scaled = data - lo  # new array, so the in-place ops below don't touch the input
    scaled /= span
    scaled *= 255.0
    # Round (instead of truncating) and clip to absorb floating-point error.
    np.rint(scaled, out=scaled)
    np.clip(scaled, 0, 255, out=scaled)
    return scaled.astype(np.uint8)



#Plot plants
def plot_plants(my_images, my_labels, myrange=10, num_img=10):
    """Preview images in rows, each image titled with its label.

    Draws up to ``myrange`` separate figures. Figure ``r`` is a single row of
    ``num_img`` images taken consecutively from ``my_images`` (items
    ``r*num_img`` .. ``r*num_img + num_img - 1``), each titled with the
    matching entry of ``my_labels``.

    Args:
        my_images: indexable sequence of images accepted by ``Axes.imshow``.
        my_labels: indexable sequence of labels aligned with ``my_images``.
        myrange: maximum number of rows (figures) to draw.
        num_img: number of images per row.

    If fewer than ``myrange * num_img`` images are available, drawing stops
    early (unused panels of a partial last row are hidden) instead of raising
    IndexError.
    """
    total = min(len(my_images), len(my_labels))
    for row in range(myrange):
        start = row * num_img
        if start >= total:
            break
        # squeeze=False keeps `axes` 2-D even when num_img == 1; otherwise
        # plt.subplots returns a bare Axes that cannot be indexed.
        fig, axes = plt.subplots(1, num_img, figsize=(30, 30), squeeze=False)
        for col, ax in enumerate(axes[0]):
            idx = start + col
            if idx >= total:
                ax.axis('off')
                continue
            ax.imshow(my_images[idx])
            ax.set_title(my_labels[idx], size=20)
        plt.tight_layout()
        plt.show()
        # Free the (large, 30x30 inch) figure once it has been displayed.
        plt.close(fig)


In [None]:
# Load the outlier-free dataset. The `with` block closes the underlying .npz
# file handle once the arrays have been read.
# NOTE(review): allow_pickle=True lets np.load unpickle object arrays, which can
# run arbitrary code; keep it only for a trusted file that really needs it
# (e.g. labels stored with object dtype).
DATA_PATH = '/kaggle/input/dataset-plants/PULITO/public_data_not_outliers.npz'
with np.load(DATA_PATH, allow_pickle=True) as data:
    print(data.files)
    img_array = data['data']
    labels = data['labels']

# Split the dataset into training and test sets (80-20 split), stratified by
# label and seeded for reproducibility. The 80% part (augm_*) is what the
# augmentation previews below use.
augm_data, test_data, augm_labels, test_labels = train_test_split(img_array,
                                                                  labels,
                                                                  test_size=0.2,
                                                                  random_state=42,
                                                                  stratify=labels)

In [None]:
# Bring the training images into displayable uint8 [0, 255] and preview them.
augm_data = normalize_0_255(augm_data)
plot_plants(augm_data, augm_labels, myrange=10, num_img=10)

### Changing the brightness range

In [None]:
# Augmentation preview: brightness_range=[0.5, 1.5] on top of rotation, shifts,
# shear, flips and 'reflect' fill.
datagen = ImageDataGenerator(rotation_range=45,
                             width_shift_range=0.2,
                             height_shift_range=0.2,
                             brightness_range=[0.5,1.5],
                             shear_range=0.2,
                             vertical_flip=True,
                             horizontal_flip=True,
                             fill_mode='reflect')

# datagen.fit() is only required for featurewise_center,
# featurewise_std_normalization or zca_whitening. None of these is enabled, so
# it is skipped; call datagen.fit(augm_data) if one of them is added.

# Generate 3 augmented samples per training image (3 passes over augm_data;
# batch_size=1 keeps the stop check exact). seed=42, as in the split, makes the
# shuffle order reproducible, so each preview cell starts from the same plants
# and the settings can be compared side by side.
target_count = len(augm_data) * 3
augmented_images = []
augmented_labels = []
for img_batch, label_batch in datagen.flow(augm_data, augm_labels,
                                           batch_size=1, seed=42):
    augmented_images.extend(img_batch)
    augmented_labels.extend(label_batch)
    if len(augmented_images) >= target_count:
        break

augmented_images = normalize_0_255(np.stack(augmented_images))
plot_plants(augmented_images, augmented_labels)

In [None]:
# Augmentation preview: brightness_range=[0.8, 1.1] on top of rotation, shifts,
# shear, flips and 'reflect' fill.
datagen = ImageDataGenerator(rotation_range=45,
                             width_shift_range=0.2,
                             height_shift_range=0.2,
                             brightness_range=[0.8,1.1],
                             shear_range=0.2,
                             vertical_flip=True,
                             horizontal_flip=True,
                             fill_mode='reflect')

# datagen.fit() is only required for featurewise_center,
# featurewise_std_normalization or zca_whitening. None of these is enabled, so
# it is skipped; call datagen.fit(augm_data) if one of them is added.

# Generate 3 augmented samples per training image (3 passes over augm_data;
# batch_size=1 keeps the stop check exact). seed=42, as in the split, makes the
# shuffle order reproducible, so each preview cell starts from the same plants
# and the settings can be compared side by side.
target_count = len(augm_data) * 3
augmented_images = []
augmented_labels = []
for img_batch, label_batch in datagen.flow(augm_data, augm_labels,
                                           batch_size=1, seed=42):
    augmented_images.extend(img_batch)
    augmented_labels.extend(label_batch)
    if len(augmented_images) >= target_count:
        break

augmented_images = normalize_0_255(np.stack(augmented_images))
plot_plants(augmented_images, augmented_labels)

In [None]:
# Augmentation preview: brightness_range=[0.9, 1.1] on top of rotation, shifts,
# shear, flips and 'reflect' fill.
datagen = ImageDataGenerator(rotation_range=45,
                             width_shift_range=0.2,
                             height_shift_range=0.2,
                             brightness_range=[0.9,1.1],
                             shear_range=0.2,
                             vertical_flip=True,
                             horizontal_flip=True,
                             fill_mode='reflect')

# datagen.fit() is only required for featurewise_center,
# featurewise_std_normalization or zca_whitening. None of these is enabled, so
# it is skipped; call datagen.fit(augm_data) if one of them is added.

# Generate 3 augmented samples per training image (3 passes over augm_data;
# batch_size=1 keeps the stop check exact). seed=42, as in the split, makes the
# shuffle order reproducible, so each preview cell starts from the same plants
# and the settings can be compared side by side.
target_count = len(augm_data) * 3
augmented_images = []
augmented_labels = []
for img_batch, label_batch in datagen.flow(augm_data, augm_labels,
                                           batch_size=1, seed=42):
    augmented_images.extend(img_batch)
    augmented_labels.extend(label_batch)
    if len(augmented_images) >= target_count:
        break

augmented_images = normalize_0_255(np.stack(augmented_images))
plot_plants(augmented_images, augmented_labels)

### Conclusion on brightness range: it seems to me that either 0.9-1.1 or 0.8-1.2 is the best choice

## Changing the shear range (first reduced, then removed)

In [None]:
# Augmentation preview: reduced shear (shear_range=0.1) with
# brightness_range=[0.9, 1.2], rotation, shifts, flips and 'reflect' fill.
datagen = ImageDataGenerator(rotation_range=45,
                             width_shift_range=0.2,
                             height_shift_range=0.2,
                             brightness_range=[0.9,1.2],
                             shear_range=0.1,
                             vertical_flip=True,
                             horizontal_flip=True,
                             fill_mode='reflect')

# datagen.fit() is only required for featurewise_center,
# featurewise_std_normalization or zca_whitening. None of these is enabled, so
# it is skipped; call datagen.fit(augm_data) if one of them is added.

# Generate 3 augmented samples per training image (3 passes over augm_data;
# batch_size=1 keeps the stop check exact). seed=42, as in the split, makes the
# shuffle order reproducible, so each preview cell starts from the same plants
# and the settings can be compared side by side.
target_count = len(augm_data) * 3
augmented_images = []
augmented_labels = []
for img_batch, label_batch in datagen.flow(augm_data, augm_labels,
                                           batch_size=1, seed=42):
    augmented_images.extend(img_batch)
    augmented_labels.extend(label_batch)
    if len(augmented_images) >= target_count:
        break

augmented_images = normalize_0_255(np.stack(augmented_images))
plot_plants(augmented_images, augmented_labels)

In [None]:
# Augmentation preview: no shear at all, with brightness_range=[0.9, 1.2],
# rotation, shifts, flips and 'reflect' fill.
datagen = ImageDataGenerator(rotation_range=45,
                             width_shift_range=0.2,
                             height_shift_range=0.2,
                             brightness_range=[0.9,1.2],
                             vertical_flip=True,
                             horizontal_flip=True,
                             fill_mode='reflect')

# datagen.fit() is only required for featurewise_center,
# featurewise_std_normalization or zca_whitening. None of these is enabled, so
# it is skipped; call datagen.fit(augm_data) if one of them is added.

# Generate 3 augmented samples per training image (3 passes over augm_data;
# batch_size=1 keeps the stop check exact). seed=42, as in the split, makes the
# shuffle order reproducible, so each preview cell starts from the same plants
# and the settings can be compared side by side.
target_count = len(augm_data) * 3
augmented_images = []
augmented_labels = []
for img_batch, label_batch in datagen.flow(augm_data, augm_labels,
                                           batch_size=1, seed=42):
    augmented_images.extend(img_batch)
    augmented_labels.extend(label_batch)
    if len(augmented_images) >= target_count:
        break

augmented_images = normalize_0_255(np.stack(augmented_images))
plot_plants(augmented_images, augmented_labels)

## Changing the fill mode

In [None]:
# Augmentation preview: alternative fill mode, compared against 'reflect'.
# 'rotate' is not a valid fill_mode; Keras accepts 'constant', 'nearest',
# 'reflect' or 'wrap'. 'nearest' (the Keras default) is used here.
# NOTE(review): swap in 'wrap' or 'constant' if that was the intended comparison.
datagen = ImageDataGenerator(rotation_range=45,
                             width_shift_range=0.2,
                             height_shift_range=0.2,
                             brightness_range=[0.9,1.2],
                             vertical_flip=True,
                             horizontal_flip=True,
                             fill_mode='nearest')

# datagen.fit() is only required for featurewise_center,
# featurewise_std_normalization or zca_whitening. None of these is enabled, so
# it is skipped; call datagen.fit(augm_data) if one of them is added.

# Generate 3 augmented samples per training image (3 passes over augm_data;
# batch_size=1 keeps the stop check exact). seed=42, as in the split, makes the
# shuffle order reproducible, so each preview cell starts from the same plants
# and the settings can be compared side by side.
target_count = len(augm_data) * 3
augmented_images = []
augmented_labels = []
for img_batch, label_batch in datagen.flow(augm_data, augm_labels,
                                           batch_size=1, seed=42):
    augmented_images.extend(img_batch)
    augmented_labels.extend(label_batch)
    if len(augmented_images) >= target_count:
        break

augmented_images = normalize_0_255(np.stack(augmented_images))
plot_plants(augmented_images, augmented_labels)

### Also removing rotation and shifts

In [None]:
# Augmentation preview: only brightness_range=[0.9, 1.2] and flips
# (no rotation, shifts or shear, hence no fill mode needed).
datagen = ImageDataGenerator(brightness_range=[0.9,1.2],
                             vertical_flip=True,
                             horizontal_flip=True)

# datagen.fit() is only required for featurewise_center,
# featurewise_std_normalization or zca_whitening. None of these is enabled, so
# it is skipped; call datagen.fit(augm_data) if one of them is added.

# Generate 3 augmented samples per training image (3 passes over augm_data;
# batch_size=1 keeps the stop check exact). seed=42, as in the split, makes the
# shuffle order reproducible, so each preview cell starts from the same plants
# and the settings can be compared side by side.
target_count = len(augm_data) * 3
augmented_images = []
augmented_labels = []
for img_batch, label_batch in datagen.flow(augm_data, augm_labels,
                                           batch_size=1, seed=42):
    augmented_images.extend(img_batch)
    augmented_labels.extend(label_batch)
    if len(augmented_images) >= target_count:
        break

augmented_images = normalize_0_255(np.stack(augmented_images))
plot_plants(augmented_images, augmented_labels)

## Considerations:

- Due to the shape of the images, we prefer not to include zoom, since we believe it might cause the loss of features that are important for determining the class.
- We opt for a 0.9-1.1 brightness change.
- We keep `reflect` as the fill mode.
- We keep the rotation range used above (45°).
- As for the other parameters, we used slight variations across different models.


We also ran some basic CNN models with slightly different augmentation settings (not shown), and the results were consistent with our choice.