In [None]:
import numpy as np

'''
This script creates a copy of the input dataset with the Shrek and Trololo images and all duplicate images removed.
'''

# Indexes, in the original dataset, of one sample of each unwanted image:
# `shrek` points at a Shrek image, `trol` at a Trololo image
shrek, trol = 58, 753

# Name of the file the cleaned collection will be written to
new_npz_filename = 'clean_collection.npz'

# Original collection file given by the teacher; `data` holds its image array
dataset = np.load('public_data.npz', allow_pickle=True)
data = dataset['data']

# Bookkeeping structures filled in while scanning the dataset:
#   count_per_image        - position of a unique image in `different_images` -> number of times it appears
#   different_images       - one (label, index in the dataset) entry per unique image
#   correct_images_indexes - indexes of the images that should remain in the dataset
count_per_image, different_images, correct_images_indexes = {}, [], []

# Load the labels once: each dataset['labels'] access reads the array from the
# .npz archive again, which is wasteful inside the loops below
labels = dataset['labels']

# Reference outlier images, taken from the indexes configured above instead of
# repeating the numeric literals here
shrek_image = data[shrek]
trol_image = data[trol]

# Iterate on the dataset
for i, image in enumerate(data):

  # Every copy of the Shrek or Trololo image is skipped
  if np.array_equal(image, shrek_image) or np.array_equal(image, trol_image):
    continue

  # Compare the image with the first occurrence of each unique image found so far
  for image_ind, (first_label, first_index) in enumerate(different_images):
    if np.array_equal(image, data[first_index]):

      count_per_image[image_ind] += 1     # One more copy of an already known image

      # If a duplicate image has a different label we print a warning (it should be impossible)
      if labels[i] != first_label:
        print(f"Image {i} is equal to image {first_index} but has a different label")

      break

  else:
    # The inner loop ended without a match: this image is seen for the first time
    correct_images_indexes.append(i)
    # Start its counter, keyed by the position it is about to take in different_images
    count_per_image[len(different_images)] = 1
    # Register the new unique image as (label, index in the dataset)
    different_images.append((labels[i], i))

# Report every image that appears more than once and total the extra copies per label
duplicates = {"healthy": 0, "unhealthy": 0}
for unique_pos, copies in count_per_image.items():
  # Images seen a single time have no duplicates to report
  if copies <= 1:
    continue

  first_label, first_index = different_images[unique_pos]
  print(f"Image {first_index} has {copies} copies in the dataset")

  # Only the copies beyond the first one count as duplicates
  extra_copies = copies - 1
  if first_label == "healthy":
    duplicates["healthy"] += extra_copies
  elif first_label == "unhealthy":
    duplicates["unhealthy"] += extra_copies
  else:
    # Any other label value is not expected here
    raise Exception("Impossible")

print(f"Total duplicates: {duplicates}")

# Keep only the samples whose indexes were collected in correct_images_indexes
non_outliers_labels = dataset['labels'][correct_images_indexes]
non_outliers_data = data[correct_images_indexes]

# Write the new collection under the same 'data' / 'labels' keys used to read the original
cleaned_arrays = {"data": non_outliers_data, "labels": non_outliers_labels}
np.savez(new_npz_filename, **cleaned_arrays)

Image 1 has 2 copies in the dataset
Image 3 has 2 copies in the dataset
Image 23 has 3 copies in the dataset
Image 44 has 2 copies in the dataset
Image 47 has 2 copies in the dataset
Image 54 has 2 copies in the dataset
Image 60 has 2 copies in the dataset
Image 88 has 2 copies in the dataset
Image 143 has 2 copies in the dataset
Image 148 has 3 copies in the dataset
Image 150 has 2 copies in the dataset
Image 158 has 2 copies in the dataset
Image 190 has 2 copies in the dataset
Image 191 has 2 copies in the dataset
Image 196 has 3 copies in the dataset
Image 227 has 2 copies in the dataset
Image 276 has 2 copies in the dataset
Image 284 has 2 copies in the dataset
Image 285 has 2 copies in the dataset
Image 335 has 2 copies in the dataset
Image 361 has 2 copies in the dataset
Image 380 has 2 copies in the dataset
Image 393 has 2 copies in the dataset
Image 395 has 4 copies in the dataset
Image 403 has 2 copies in the dataset
Image 436 has 2 copies in the dataset
Image 442 has 2 copies