In [3]:
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from scipy.spatial import distance

# Load the data
data = np.load('public_data.npz', allow_pickle=True)
X = data['data']
y = data['labels']

X = ((X - X.min()) * (1/(X.max() - X.min()) * 255)).astype('uint8')

# Define templates for outlier detection
template_plant = X[0]  # Use the first image as a template for plants
template_shrek = X[58] # Load a template image of Shrek
template_trololo = X[338]  # Load a template image of "TROLOLO"

# Define a function to calculate the Euclidean distance between two images
def calculate_distance(image1, image2):
    return distance.euclidean(image1.flatten(), image2.flatten())

# Set a threshold for outlier detection
threshold = 5000

removed_images = []
removed_labels = []

cleaned_images = []
cleaned_labels = []

for i in range(len(X)):
    img = X[i]
    label = y[i]

    if i == 0:
        cleaned_images.append(img)
        cleaned_labels.append(label)
    else:
        distance_to_plant = calculate_distance(img, template_plant)
        distance_to_shrek = calculate_distance(img, template_shrek)
        distance_to_trololo = calculate_distance(img, template_trololo)

        if (
            distance_to_plant < threshold or
            distance_to_shrek < threshold or
            distance_to_trololo < threshold
        ):
            removed_images.append(img)
            removed_labels.append(label)
        else:
            cleaned_images.append(img)
            cleaned_labels.append(label)

X_cleaned = np.array(cleaned_images)
y_cleaned = np.array(cleaned_labels)

X_removed = np.array(removed_images)
y_removed = np.array(removed_labels)

np.savez('/kaggle/working/plants.npz', data=X_cleaned, labels=y_cleaned)
np.savez('/kaggle/working/shrek_trololo.npz', data=X_removed, labels=y_removed)

In [4]:
print("Number of plant images:", len(X_cleaned))
print("Number of Shrek and Trololo images:", len(X_removed))

Number of plant images: 5004
Number of Shrek and Trololo images: 196


In [5]:
from IPython.display import FileLink
FileLink(r'plants.npz')

In [6]:
from IPython.display import FileLink
FileLink(r'shrek_trololo.npz')