# Loading the data

In [3]:
import tensorflow as tf
import numpy as np

In [4]:
# Getting the data
from pathlib import Path
train_dir = Path('/kaggle/input/vlg-recruitment-24-challenge/vlg-dataset/vlg-dataset/train')

In [5]:
# Defined the batch size here
batch_size = 32
# Image size 336 is used here because RandomCrop is used later to get a 224*224 part out
img_height = 336
img_width = 336

# Carrying out the train test split
train_ds = tf.keras.utils.image_dataset_from_directory(
  train_dir,
  label_mode='categorical',
  validation_split=0.2,
  subset='training',
  seed=123,
  image_size=(img_height, img_width),
  batch_size=batch_size
)

val_ds = tf.keras.utils.image_dataset_from_directory(
  train_dir,
  label_mode='categorical',
  validation_split=0.2,
  subset='validation',
  seed=123,
  image_size=(img_height, img_width),
  batch_size=batch_size
)

Found 9544 files belonging to 40 classes.
Using 7636 files for training.
Found 9544 files belonging to 40 classes.
Using 1908 files for validation.


In [6]:
# For faster loading
AUTOTUNE = tf.data.AUTOTUNE
train_ds = train_ds.cache().prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)

# Creating the model

In [7]:
# Initializing the model using sequential API
resnet_model=tf.keras.models.Sequential()

In [8]:
# Implementing the data augmentation and pixel value rescaling layer
data_augmentation = tf.keras.Sequential([
    tf.keras.layers.RandomFlip('horizontal'),
    tf.keras.layers.RandomRotation(0.2),  
    tf.keras.layers.RandomCrop(224, 224),
    tf.keras.layers.Resizing(224, 224),
    tf.keras.layers.Rescaling(1.0/255)      
])

In [9]:
# Using MobileNetV2 for feature extraction
pretrained_model = tf.keras.applications.MobileNetV2(
    input_shape=(224,224,3),
    include_top=False,
    weights="imagenet",
    pooling='avg',
)

# Freezing the pre-trained layers
pretrained_model.trainable = False

Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/mobilenet_v2/mobilenet_v2_weights_tf_dim_ordering_tf_kernels_1.0_224_no_top.h5
[1m9406464/9406464[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 0us/step


In [10]:
resnet_model.add(data_augmentation)
resnet_model.add(pretrained_model)
# Adding layers for personalized task
resnet_model.add(tf.keras.layers.Dense(512, activation='relu'))
resnet_model.add(tf.keras.layers.Dense(40, activation='softmax'))

In [11]:
resnet_model.compile(
    optimizer='Adam',  # Optimizer
    loss='categorical_crossentropy',  # Loss function
    metrics=['accuracy']  
)

# Training and fine tuning

Training

In [None]:
history = resnet_model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=15,
    batch_size=batch_size
)

Fine tuning

In [None]:
# Unfreezing the pretrained layers
pretrained_model.trainable = True

resnet_model.compile(optimizer=tf.keras.optimizers.Adam(1e-5),  # Very low learning rate
              loss='categorical_crossentropy',
              metrics=['accuracy'])


resnet_model.fit(train_ds, epochs=10, validation_data=val_ds)

# Creating the submission csv

In [None]:
# List of all class names
class_names = ['antelope', 'bat', 'beaver', 'blue+whale', 'bobcat', 'buffalo', 'chihuahua', 'chimpanzee', 'collie', 'cow', 'dalmatian', 'deer', 'dolphin', 'elephant', 'fox', 'german+shepherd', 'giant+panda', 'giraffe', 'gorilla', 'grizzly+bear', 'hamster', 'hippopotamus', 'horse', 'humpback+whale', 'killer+whale', 'leopard', 'lion', 'mole', 'moose', 'mouse', 'otter', 'ox', 'persian+cat', 'pig', 'polar+bear', 'rabbit', 'raccoon', 'rat', 'rhinoceros', 'seal', 'sheep', 'siamese+cat', 'skunk', 'spider+monkey', 'squirrel', 'tiger', 'walrus', 'weasel', 'wolf', 'zebra']


In [None]:
# Matrix of embedding vectors of all 50 classes
arr = np.array([
[-1, -1, -1, -1, 12.34, 0, 0, 0, 16.11, 9.19, 0, 38.09, 4.44, 28.55, 38.75, 5.68, 17.07, 39.99, 0, 0, 67.08, 7.78, 0, 60.24, 16.8, 40.59, 29.7, 5.56, 2.47, 0, 87.43, 0, 8.64, 9.04, 0, 9.23, 1.23, 0, 54.58, 70.86, 3.33, 33.56, 8.15, 26.14, 0, 67.85, 41.19, 7.36, 1.11, 6.94, 62.32, 0, 4.44, 0, 57.76, 12.63, 33.24, 61.86, 0, 0, 0, 0, 22.72, 55.81, 5.9, 0, 0, 19.88, 54.79, 4.94, 40.97, 0, 22.32, 0, 57.14, 0, 0, 1.23, 10.49, 39.24, 17.57, 50.59, 2.35, 9.7, 8.38],
[91.55, 1.39, 0, 54.76, 27.64, 0, 0, 0, 1.25, 1.25, 0, 38.75, 40.61, 30, 1.25, 72.47, 15.28, 45.69, 0, 5.28, 0, 0, 7.5, 10.96, 5.42, 15.11, 26.18, 62.03, 8.12, 0, 0, 54.12, 0, 30.33, 95.9, 2.5, 0, 0, 2.5, 80.96, 9.91, 7.5, 35.58, 35.73, 23.75, 6.48, 36.24, 36.81, 98.86, 45.3, 42.11, 6.25, 43.4, 0, 36.46, 46.43, 24.38, 10.62, 43.28, 28.22, 0, 8.12, 65.69, 65.69, 0, 9.38, 11.62, 14.03, 5.91, 58.01, 16.41, 37.52, 31.25, 0, 13.14, 1.39, 56.97, 80.62, 32.57, 8.98, 32.3, 62.25, 19.97, 34.91, 5.56],
[19.38, 0, 0, 87.81, 7.5, 0, 0, 0, 0, 7.5, 0, 46.25, 1.25, 25, 6.88, 43.12, 37.5, 8.75, 4.38, 13.12, 0, 24.38, 39.38, 0, 0, 86.56, 45.31, 5, 83.12, 29.69, 0, 30.62, 0, 3.12, 0, 0, 74.06, 8.75, 10.94, 25, 8.75, 32.81, 8.75, 24.38, 12.5, 24.38, 67.81, 2.5, 21.88, 36.88, 36.25, 48.75, 2.5, 3.12, 18.12, 0, 10.62, 4.69, 3.75, 7.5, 6.25, 1.25, 63.44, 19.06, 11.25, 33.75, 0, 0, 0, 19.06, 15.62, 0, 0, 0, 31.25, 65.62, 0, 0, 3.75, 31.88, 41.88, 23.44, 31.88, 33.44, 13.12],
[12.92, 4.38, 67.08, 7.5, 25.6, 0, 0, 0, 15.31, 23.75, 2.5, 0, 64.47, 45.17, 86.46, 0, 65.15, 11.88, 64.87, 0, 0, 0, 0, 0, 0, 26.42, 0, 0, 0, 69.32, 0, 0, 0, 13.75, 0, 0, 71.82, 0, 0, 21.42, 42.6, 55.26, 0, 25.37, 0, 0, 7.88, 35.48, 6.25, 0, 7.06, 28.81, 0, 75.25, 0, 0, 0, 0, 5.94, 0, 45.73, 0, 33.96, 30.21, 32.75, 5.31, 0, 0, 0, 0, 0, 0, 0, 74.11, 0, 76.61, 0, 0, 7.5, 44.58, 39.06, 33.12, 25.99, 10.83, 5],
[16.13, 9.44, 0, 38.39, 2.5, 37.08, 5, 22.83, 7.5, 33.55, 2.5, 64.71, 0, 2.5, 19.93, 23.37, 0, 42.51, 0, 0, 0, 22.81, 62.75, 3.75, 0, 64.5, 5.92, 73.9, 0, 0, 0, 71.75, 0, 9.59, 0, 3.75, 0, 0, 56.83, 73.97, 0, 32.99, 2.5, 47.3, 0, 74.25, 61.13, 6.38, 20.73, 28.02, 71.74, 15.28, 71.52, 0, 2.5, 0, 36.97, 0, 81.44, 13.54, 0, 66.98, 51.79, 42.75, 6.67, 1.39, 15.97, 18.34, 27.05, 30.95, 14.38, 17.08, 50.56, 0, 42.51, 0, 24.9, 26.2, 71.95, 0, 42.89, 13.47, 59.47, 20.28, 3.75],
[45.37, 0, 0, 61.05, 9.9, 0, 0, 0, 2.08, 0, 0, 39.24, 6.46, 42.36, 84.91, 0, 50.83, 0, 0, 0, 31.67, 0, 0, 7.5, 0, 5.39, 31.19, 0.62, 17.15, 0, 49.87, 0, 0, 20.97, 0, 0, 0, 0, 59.53, 24.79, 32.64, 50.81, 0, 21.13, 0, 61.82, 18.94, 29, 0, 0, 4.38, 0, 0, 0, 51.22, 0, 23.67, 54.04, 6.67, 0, 0, 0, 55.83, 18.33, 0, 0, 0, 6.94, 68.84, 0, 14.55, 0, 4.63, 0, 51.34, 9.38, 0, 0, 26.99, 19.88, 10, 52.99, 8.8, 9.38, 2.31],
[32.63, 10, 0, 64.79, 22.63, 2.5, 0.12, 0.25, 17.1, 11.79, 1.56, 51.16, 16.46, 0.62, 0, 86.05, 0.62, 34.3, 0, 0, 0, 22.1, 73.54, 0, 7.5, 48.8, 6.25, 52.61, 0, 4.38, 0, 29.76, 0, 29.4, 0, 8.75, 5, 0, 70.71, 27.87, 20.65, 4.3, 61.36, 6.88, 3.12, 81.01, 56.13, 13.06, 0, 0, 20.24, 0, 56.74, 0, 3.75, 1.25, 3.2, 3.75, 6.58, 15.2, 0, 1.38, 63.86, 37.23, 0, 1.51, 6.33, 3.2, 5.08, 5.08, 11.47, 0, 1.88, 0, 72.94, 1.25, 9.72, 0, 33.03, 36.73, 29.84, 1.25, 44.91, 9.8, 73.55],
[47.51, 9.93, 0, 69.1, 13.05, 7.14, 0, 2.5, 0.62, 0, 0, 81.98, 0.46, 20.96, 28.29, 38.26, 6.56, 45.75, 0, 77.77, 0, 9.69, 9.31, 65.34, 2.5, 54.26, 52.62, 24.34, 11.84, 0, 0, 14.44, 0, 49.83, 0, 10.31, 3.75, 0, 59.72, 56.94, 15.42, 41.19, 17.27, 41.83, 66.82, 44.53, 77.8, 7.5, 9.81, 0, 75.86, 2.5, 8.31, 0, 67.15, 22.28, 33.53, 11.94, 7.5, 15.36, 0, 8.75, 52.98, 74.34, 1.25, 8.75, 5, 65.09, 15.58, 63.58, 0, 91.98, 27.35, 0, 28.7, 1.25, 81.36, 10, 23.1, 29.64, 84.36, 74.51, 8.75, 41.93, 36.62],
[10.13, 41.37, 0, 47.27, 3.75, 8, 0.5, 0, 37, 9.09, 0, 78.21, 0, 1.25, 26.31, 21.19, 8.07, 21.02, 0, 0, 0, 12.95, 76.17, 15.91, 1.25, 67.09, 32.83, 49.64, 8.12, 6.25, 0, 16.7, 0, 17.54, 0, 0, 0, 0, 61.23, 49.05, 1.25, 19.55, 4.66, 15.34, 0.62, 70.14, 54.26, 1.88, 0, 0, 32.05, 0, 51.12, 0, 7.5, 0, 6.16, 0, 22.27, 4.54, 0, 3.75, 55.39, 40.72, 0, 6.25, 0, 10.86, 18.36, 0, 17.47, 0, 12.5, 0, 54.43, 0, 0, 0, 5.25, 43.09, 42.17, 0.62, 45.99, 18.57, 79.11],
[55.31, 55.46, 0, 58.48, 15.5, 1.49, 0.15, 0, 26.67, 31.35, 4.09, 32.05, 19.46, 42.45, 68.31, 4.61, 35.81, 5.13, 0, 0, 42.86, 7.61, 3.8, 13.72, 6.35, 62.77, 39.99, 2.8, 14.58, 0.19, 30.37, 0, 4.95, 34.8, 0, 0, 0.46, 0.19, 58.21, 10.89, 45.96, 47.02, 11.91, 18.63, 0, 62.42, 20.93, 44.12, 7.14, 0, 16.06, 0, 0.3, 0, 61.33, 2.32, 5.24, 60.89, 0, 0.72, 0, 0, 61.94, 61.57, 0, 8.76, 0.31, 8.03, 30.36, 10.21, 57.48, 0.59, 9.17, 0, 57.36, 0.55, 0, 0.32, 3.34, 47.87, 13.97, 51.57, 5.04, 18.89, 72.99],
[69.58, 73.33, 0, 6.39, 0, 0, 0, 0, 37.08, 100, 0, 27.15, 25.9, 7.5, 39.31, 8.12, 0, 63.68, 0, 0, 0, 7.5, 69.03, 40.07, 0, 53.75, 34.44, 35.56, 0, 0, 0, 4.17, 0, 0, 0, 0, 9.38, 0, 67.99, 61.74, 3.75, 34.93, 3.89, 23.75, 10.14, 77.92, 37.5, 3.75, 2.5, 0, 38.68, 0, 39.58, 0, 0, 0, 0, 0, 7.5, 0, 0, 8.75, 63.47, 29.86, 0, 1.25, 0, 0, 0, 0, 1.25, 0, 0, 0, 41.39, 1.25, 6.25, 0, 9.38, 31.67, 53.26, 24.44, 29.38, 11.25, 72.71],
[0, 20.34, 0, 75.85, 5.92, 0, 0, 0, 30.34, 48.08, 0, 52.33, 0, 18.22, 39.88, 15.36, 0, 42.52, 0, 0, 66.73, 0, 0, 48.98, 34.17, 43.85, 32.56, 4.61, 9.03, 0, 70.76, 0, 0, 6.53, 0, 10.24, 0, 0, 63.33, 73.73, 1.25, 23.88, 6.51, 32.58, 0, 68.71, 50, 1.25, 2.63, 0, 62, 0, 5.92, 0, 67.66, 0, 23.91, 69.35, 5.26, 4.61, 0, 0, 62.18, 45.19, 1.39, 4.85, 1.56, 19.79, 36.94, 69.22, 56.72, 0, 25.16, 0, 63.06, 0, 0, 0, 0, 67.49, 40.05, 55.46, 6.51, 31.55, 10.27],
[10.22, 21.53, 27.73, 0.33, 60.82, 0, 0, 0.16, 3.3, 0.8, 0, 0.62, 73.27, 39.21, 47.82, 17.26, 14.81, 28.2, 86.44, 0, 0, 0, 0, 5.76, 0, 41.85, 21.95, 11.78, 2.55, 11.96, 0, 0, 0, 7.57, 0.82, 1.29, 60.36, 0, 0, 59.09, 2.96, 41.88, 9.18, 26.29, 7.69, 4.8, 53.15, 1.72, 0.62, 0, 50.56, 56.88, 2.99, 11.8, 7.38, 0, 8.86, 0.75, 20.01, 4.02, 5.19, 6.24, 57.1, 56.41, 20.34, 24.07, 0, 0, 0, 0, 0, 0, 0, 77.19, 0, 71.4, 0, 0, 10.73, 21.06, 60.38, 49.62, 3.96, 14.05, 37.98],
[2.5, 3.75, 0, 15.23, 83.97, 0, 0, 0, 0, 1.25, 0, 1.14, 68.27, 77.62, 85.48, 1.25, 48.52, 1.25, 0, 0, 10.6, 17.1, 5.08, 39.14, 1.25, 51.97, 38.16, 0, 11.25, 0, 1.14, 0, 70.47, 49.67, 3.75, 0, 0, 0, 66.19, 3.75, 63.03, 67.45, 0, 25.56, 0, 70.61, 6.48, 53.62, 7.5, 0, 5.78, 0, 5.68, 0, 55.85, 1.25, 8.06, 43.83, 3.75, 2.5, 0, 0, 4.53, 84.56, 0, 0, 8.75, 36.91, 14.11, 7.58, 16.1, 36.33, 0, 0, 66.81, 1.25, 0, 0, 20.63, 39.36, 22.81, 49.87, 6.25, 12.29, 6.88],
[0, 1.56, 0, 51.25, 11.17, 48.92, 40.89, 3.75, 1.56, 0, 0, 66.35, 0, 5.7, 5.23, 52.79, 0, 50.61, 0, 6.25, 0, 20.98, 42.58, 7.59, 0, 66.65, 37.28, 64.48, 6.25, 7.5, 0, 32.64, 0, 11.66, 0, 0, 3.12, 0, 62.75, 67.51, 0, 26.72, 1.25, 21.6, 0, 66.89, 54, 5.7, 43.24, 21.91, 49.56, 28.57, 68.87, 0, 5.28, 0, 32.03, 0, 61.52, 19.58, 0, 61.2, 61.81, 43.33, 19.28, 7.5, 13.12, 12.5, 21.17, 61.82, 39.69, 5, 13.12, 0, 58.62, 0, 1.25, 4.06, 44.9, 5, 79.26, 18.75, 48.18, 27.88, 4.38],
[43.54, 15.88, 5, 54.16, 26.82, 3.12, 2.5, 0.38, 48.78, 11.59, 1.56, 66.05, 3.75, 18.46, 54.88, 1.25, 15.85, 38.05, 0, 0, 0, 41.94, 59.8, 23.91, 0, 72.3, 29.38, 75.43, 0, 5, 0, 31.83, 0, 23.41, 0, 0, 0, 0, 76.91, 57.02, 3.12, 62.33, 1.25, 50.69, 5, 82.93, 58.15, 1.88, 7.5, 0, 60.14, 6.25, 69.25, 0, 7.5, 0, 5.62, 0.62, 43.77, 11.25, 0, 21.02, 64.49, 48.92, 6.25, 2.5, 1.25, 2.5, 21.33, 17.89, 12.5, 0, 11.25, 0, 72.61, 3.75, 0, 2.5, 57.44, 10, 57.53, 12.5, 35.11, 16.53, 68.55],
[76.85, 72.33, 0, 5, 4.38, 0, 0, 0, 84.95, 24.17, 0, 84.15, 0, 10, 75.01, 3.75, 70.22, 1.25, 0, 14.31, 0, 41.61, 70.55, 2.5, 2.5, 9.03, 55.5, 18.75, 38.12, 5, 0, 38.44, 0, 20.23, 0, 1.25, 0, 0, 71.96, 5, 59.61, 46.11, 14.03, 17.69, 48.68, 55.2, 5.5, 53.51, 9.63, 18.66, 10, 21.25, 14.38, 0, 66.01, 6.07, 40.63, 41.84, 3.75, 3.75, 0, 6.25, 29.17, 84.4, 5, 5.14, 0, 31.96, 15.28, 48.03, 9.03, 45.65, 18.12, 0, 44.78, 1.25, 49.12, 14.4, 6.25, 65.84, 37.3, 28.1, 55.38, 38.19, 29.58],
[6.11, 11.87, 0, 32.21, 0, 24.7, 1.39, 48.43, 47.52, 77.08, 0, 17.17, 14.05, 13.57, 78.23, 0, 1.11, 68.82, 0, 0, 48.04, 0.56, 6.65, 84.01, 96.71, 44.8, 44.19, 0, 18.32, 6.67, 29.6, 0, 0, 22.09, 4.44, 2.02, 0, 0, 64.07, 31.86, 28.17, 31.43, 16.41, 21.52, 0, 77.12, 24.22, 35.89, 0, 0, 18.38, 0, 5.56, 0, 69.61, 9.49, 12.78, 55.84, 0, 0, 0, 0, 11.11, 66.16, 0, 0, 12.84, 36.14, 50.69, 11.11, 36.48, 16.92, 0, 0, 57.66, 0, 9.88, 0, 1.11, 45.36, 16.95, 40.89, 9.57, 8.69, 3.33],
[63.37, 1.79, 7.14, 45.51, 17.01, 8.48, 3.57, 0, 0, 0, 0, 79.21, 0, 28.08, 70.48, 2.5, 31.77, 12.68, 0, 61.5, 0, 8.12, 1.25, 33.57, 3.75, 10, 34.01, 22.14, 11.98, 0, 0, 4.55, 0, 42.6, 0, 7.5, 2.5, 0, 63.43, 37.06, 18.75, 73.42, 1.25, 65.79, 68.8, 37.17, 49.22, 6.56, 8.75, 0, 59.73, 7.41, 30.62, 0, 53.31, 15.36, 34.71, 7.5, 19.11, 12.63, 0, 11.25, 12.5, 73.83, 1.25, 8.75, 1.25, 15, 5, 33.19, 6.25, 77.01, 17.14, 0, 40.73, 1.25, 50.29, 14.12, 46.23, 12.13, 56.19, 54.5, 7.5, 44.17, 18.93],
[39.25, 1.39, 0, 74.14, 3.75, 0, 0, 0, 1.25, 0, 0, 82.37, 0, 21.82, 86.69, 0, 45.13, 0, 0, 11.65, 0, 3.75, 69.6, 9.01, 0, 9.38, 44.25, 64.69, 15, 1.25, 0, 68.87, 0, 11.25, 0, 0, 2.5, 0, 64.85, 46.97, 22.57, 78.48, 1.25, 48.89, 51.21, 36.77, 29.95, 32.57, 24.32, 86.14, 15.92, 32.15, 64.58, 0, 16.88, 0, 25.74, 0, 60.83, 5.26, 1.12, 26.05, 61.54, 3.95, 6.65, 2.78, 0, 0, 0, 77.4, 10, 2.5, 43.85, 0, 47.77, 7.64, 9.79, 53.14, 61.8, 12.5, 24, 3.12, 58.64, 20.14, 11.39],
[41.38, 39.71, 0, 62.76, 37.38, 17.5, 2.5, 10, 52.77, 8.5, 7.5, 86.82, 0, 5.56, 0, 76.33, 39.03, 5.56, 0, 7.5, 0, 7.5, 44.47, 0, 0, 26.34, 33.15, 11.11, 62.08, 5, 0, 22.53, 0, 26.23, 0, 23.61, 0, 35.99, 62.24, 39.33, 8.75, 3.89, 46.96, 0, 2.64, 72.49, 47.28, 29.95, 34.58, 26.88, 32.57, 0, 0, 0, 60.33, 10, 27.26, 22.71, 0, 9.61, 0, 0, 53.14, 20.56, 0, 3.75, 13.75, 5, 8.75, 2.5, 42.91, 0, 0, 0, 69.09, 0, 5, 2.5, 8.75, 40.69, 18.37, 16.98, 35, 36.6, 71.44],
[4.77, 0, 0, 18.61, 81.49, 0, 0, 0, 0, 0, 0, 0, 63.91, 66.12, 79.48, 2.5, 59.45, 0, 0, 0, 14.31, 8.33, 0, 3.56, 0, 7.27, 31.94, 8.75, 0, 28.12, 6.25, 0, 0, 17.5, 1.25, 0, 40.37, 0, 48.83, 2.5, 68.69, 49.76, 1.25, 32.09, 0, 60.22, 6.25, 50.23, 0, 0, 8.52, 23.33, 13.75, 6.25, 49.22, 0, 6.02, 18.52, 3.75, 2.5, 0, 0, 5, 68.55, 0, 3.75, 0, 3.41, 5.21, 8.75, 0, 28.47, 0, 0, 31.15, 46.78, 0, 0, 19.17, 29.58, 8.75, 20.6, 25.14, 2.31, 0],
[44.9, 42.91, 4.44, 69.41, 35.94, 0, 0, 0, 22.29, 15.8, 0, 40.58, 12.59, 42.45, 71.5, 0, 15.72, 47.96, 0, 0, 86.32, 3.7, 0, 70.57, 44.94, 70.42, 50.14, 2.92, 27.6, 0, 0, 0, 0, 33.07, 0, 0, 0, 0, 55.58, 81.68, 1.11, 69.13, 1.11, 51.14, 0, 70.35, 55.95, 1.11, 1.11, 1.11, 49, 0, 5.85, 0, 51.05, 0, 2.34, 72.19, 4.09, 0, 0, 0, 62.29, 50.07, 0, 2.92, 6.73, 8.48, 52.54, 10.76, 70.14, 3.33, 16.22, 0, 56.52, 2.22, 0, 0, 15.51, 35.39, 37.28, 36.47, 16.78, 14.62, 59.33],
[24.01, 5.92, 31.1, 8.75, 59.43, 0, 0, 0, 18.72, 6.88, 3.95, 0, 65.21, 68.52, 94.6, 0, 64.92, 13.75, 63.98, 0, 0, 0, 0, 0, 0, 44.8, 16.25, 0, 0, 42.37, 0, 0, 0, 15, 0, 0, 89.36, 0, 0, 26.7, 25.94, 63.27, 2.5, 28.98, 0, 0, 12.5, 29.86, 3.75, 0, 10.34, 42.24, 2.5, 82.3, 12.38, 0, 11.25, 0, 15.65, 10.34, 50.11, 0, 35.34, 27.84, 48.7, 27.61, 0, 0, 0, 0, 0, 0, 0, 89.36, 0, 89.95, 0, 0, 3.75, 60.72, 60.31, 51.03, 19.06, 20.34, 0],
[83.4, 64.79, 0, 0, 1.25, 0, 0, 0, 68.49, 32.69, 0, 1.25, 70.62, 57.04, 90.85, 1.25, 61.87, 22.68, 79.94, 0, 0, 0, 0, 0, 1.25, 41.67, 12.5, 45.15, 5, 30.22, 0, 0, 0, 7.5, 1.25, 0, 91.45, 0, 0, 57.37, 5.14, 63.35, 1.25, 10.45, 0, 0, 27.29, 13.23, 8.75, 0, 27.8, 66.75, 21.81, 32.86, 3.75, 0, 10.89, 0, 57.87, 6.61, 20.36, 15, 16.25, 12.5, 24.51, 30.11, 0, 0, 0, 0, 0, 0, 0, 88.28, 0, 79.49, 0, 0, 38.27, 9.77, 52.03, 24.94, 15.77, 13.41, 15.42],
[40.88, 19.44, 0, 31.33, 2.5, 20.42, 2.5, 24.58, 26.11, 74.97, 1.25, 43.27, 10.97, 12.18, 54.64, 5, 0, 73.48, 0, 0, 0, 19.45, 69.9, 40.39, 6.81, 48.22, 14.58, 75.91, 0, 8.75, 0, 68.8, 0, 11.02, 0, 0, 5.28, 0, 62.77, 85.26, 0, 55.18, 0, 62.22, 0, 78, 71.82, 7.5, 25.81, 14.76, 75.1, 26.19, 84.52, 0, 1.25, 1.25, 30.33, 9.87, 69.02, 6.25, 0, 79.76, 24.24, 62.61, 2.5, 0, 0, 38.05, 32.91, 18.05, 15.08, 68.99, 26.03, 0, 55.17, 2.5, 25.37, 10.56, 70.3, 0, 37.16, 11.25, 33.84, 22.08, 6.25],
[1.88, 1.25, 0, 43.91, 0, 13.12, 0, 34.38, 0, 0, 0, 67.03, 5, 16.25, 71.88, 0, 23.12, 22.81, 0, 0, 0, 27.81, 67.66, 10.62, 0, 62.22, 19.69, 75.94, 0, 3.75, 0, 57.5, 0, 30, 3.75, 6.25, 0.25, 0, 54.06, 65.47, 7.5, 77.19, 0, 50.94, 2.5, 65.55, 33.44, 36.88, 9.77, 0, 43.52, 5, 82.81, 0, 8.75, 0.62, 30.94, 7.5, 76.56, 16.25, 0, 77.34, 10.94, 66.02, 0, 3.75, 22.5, 53.05, 20, 22.5, 18.75, 63.05, 21.25, 0, 57.81, 0, 2.5, 13.75, 70.47, 4.38, 25.7, 56.17, 11.88, 43.12, 3.75],
[39.05, 0, 0, 51.33, 34.91, 0, 0, 0, 5.62, 0, 0, 45.81, 14.88, 8.96, 1.25, 59.06, 42.27, 20.91, 0, 7.92, 0, 19.36, 43.17, 0, 0, 20, 27.71, 10, 27.16, 1.88, 0, 28.73, 0, 10.42, 0, 0, 0, 77.61, 49.03, 20.9, 26.2, 1.25, 31.35, 16.41, 1.39, 51.05, 48.83, 20.67, 25.14, 31.69, 26.7, 0, 12.92, 0, 40.05, 31.67, 39.28, 5.62, 0.62, 11.49, 0, 0, 44.68, 24.72, 0, 1.39, 0, 9.38, 27.58, 43.67, 57.67, 0, 0.62, 0, 51.3, 0, 10, 9.72, 16.88, 29.47, 16.67, 0, 55.31, 29.79, 13.75],
[10.24, 6.25, 0, 91.2, 11.81, 0, 3.57, 0, 12.5, 12.5, 0, 36.85, 10, 28.57, 87.73, 0, 36.63, 11.25, 0, 0, 73.81, 2.5, 0, 44.24, 31.52, 29.59, 53.48, 12.08, 17.5, 0, 93.02, 0, 0, 27.06, 0, 0, 2.5, 0, 66.55, 22.74, 42.01, 68.53, 1.25, 35.5, 0, 82.61, 18.06, 37.86, 3.75, 0, 6.25, 7.5, 14.17, 0, 78.06, 0, 46.46, 77.33, 18.06, 2.5, 0, 11.25, 82.33, 34.31, 26.58, 3.75, 1.25, 5, 49.44, 58.47, 46.04, 3.75, 67.19, 0, 67.44, 1.25, 0, 2.5, 8.75, 40.42, 19.11, 30.34, 31.94, 10.14, 5.14],
[18.37, 55.35, 0, 32.53, 49.72, 0, 0, 0, 12.19, 6.25, 0, 51.94, 11.56, 1.25, 0, 83.07, 10, 7.5, 0, 1.56, 0, 0, 35.15, 0, 0, 70.04, 26.98, 8.75, 44.58, 5.62, 0, 20.23, 0, 24.06, 0, 0, 0, 27.47, 54.44, 55.15, 1.25, 0, 48.38, 0.62, 2.5, 45.35, 44.45, 3.12, 29.31, 26.07, 30.66, 0, 6.95, 0, 37.12, 8.4, 15.38, 21.25, 2.5, 22.97, 0, 0.62, 58.23, 53.33, 0, 0, 14.69, 14.69, 22.58, 32.09, 54.52, 0, 4.06, 0, 62.74, 0, 0, 0, 1.25, 50, 12.42, 23.96, 6.88, 23.26, 27.64],
[46.81, 0, 0, 44.86, 16.25, 0, 0, 0, 0, 0, 0, 46.94, 7.5, 20, 11.25, 54.44, 23.06, 30.14, 47.36, 10.28, 0, 10.42, 24.17, 0, 1.25, 43.61, 28.68, 22.5, 28.75, 17.5, 0, 28.75, 0, 12.64, 0, 0, 85, 4.51, 5, 40.97, 11.81, 11.39, 12.5, 11.39, 1.25, 35.56, 41.11, 2.5, 12.5, 10.14, 32.92, 83.33, 7.5, 6.25, 7.78, 4.03, 1.25, 0, 32.92, 0, 0, 3.75, 54.72, 30.83, 35.56, 35.97, 0, 0, 0, 25.83, 0, 0, 0, 35, 10, 78.75, 0, 0, 6.25, 33.47, 27.22, 17.92, 30.56, 23.19, 7.78],
[43.99, 27.68, 12.21, 50.8, 36.31, 5.91, 3.75, 0, 19.03, 15.62, 0, 30.02, 33.3, 52.54, 73.47, 0, 52.36, 3.75, 0, 0, 69.01, 3.75, 5, 12.78, 2.5, 58.05, 45.18, 8.75, 20.68, 0, 63.7, 0, 1.25, 49.28, 16.25, 0, 1.25, 0, 78.9, 5.28, 64.13, 88.78, 1.25, 30.51, 0, 77.18, 17.61, 51.44, 2.5, 0, 3.75, 1.25, 10.8, 0, 59.98, 2.5, 10.36, 66.23, 0, 2.5, 0, 0, 50.38, 54.7, 0, 18, 12.5, 21.75, 40.95, 6.25, 45.51, 5, 10, 0, 70.02, 2.5, 0, 1.25, 20.16, 33.65, 14.68, 16.79, 32.39, 13.12, 53.33],
[19.38, 50.09, 29.44, 8.98, 38.19, 0, 0, 0, 17.93, 6.25, 6.25, 90.19, 0, 6.25, 6.25, 42.02, 32.92, 14.54, 0, 0, 0, 42.44, 68.81, 8.89, 8.75, 66.8, 30.69, 41.26, 0, 0, 0, 55.68, 0, 7.86, 0, 6.25, 6.25, 0, 65.69, 26.98, 46.18, 12.58, 29.53, 12.5, 7.5, 66.15, 13.89, 49.04, 10.55, 1.56, 44.5, 39.96, 28.85, 0, 6.25, 5, 6.25, 6.25, 10.98, 5.08, 0, 9.03, 50.66, 54.31, 0, 1.39, 6.25, 6.25, 10.55, 8.98, 9.77, 6.25, 6.25, 0, 47.5, 1.25, 2.64, 0, 13.98, 43.69, 38.62, 6.25, 36.6, 9.17, 72.88],
[21.52, 25.04, 0, 35.13, 26.62, 0, 6.25, 3.75, 41.47, 21.23, 0, 17.05, 43.7, 51.37, 48.17, 12.78, 54.21, 0, 0, 0, 50.16, 7.5, 7.5, 1.25, 1.25, 40.56, 32.98, 5.62, 8.44, 0, 5, 0, 5, 51.17, 0, 0, 2.5, 0, 52.77, 16.24, 38.23, 27.44, 15.7, 14.4, 1.25, 68, 15.53, 38.78, 2.5, 0, 3.52, 0, 6.16, 0, 48.85, 4.51, 31.63, 9.42, 1.79, 9.11, 0, 0, 53.27, 41.02, 3.75, 10, 1.25, 8.12, 13.12, 2.5, 32.03, 2.5, 6.25, 0, 52.33, 3.87, 0, 3.75, 25.48, 27.85, 23.04, 28.07, 4.19, 17.88, 48.95],
[10, 95.62, 2.5, 3.12, 12.5, 0, 0, 2.5, 0, 0, 0, 81.94, 0, 22.29, 85.49, 0, 53.75, 0, 0, 5, 0, 26.25, 66.04, 3.12, 5, 10, 32.5, 62.15, 6.25, 10, 0, 59.03, 0, 22.5, 0, 3.75, 39.17, 0, 73.54, 38.68, 26.88, 77.08, 0, 16.25, 25, 66.04, 36.25, 29.38, 10, 58.12, 24.38, 62.29, 51.04, 0, 14.38, 0, 34.38, 0, 56.32, 21.25, 0, 30.38, 44.38, 31.25, 96.88, 48.96, 0, 0, 0, 0, 0, 0, 1.25, 35, 53.4, 45.9, 0, 17.5, 70.35, 3.75, 16.25, 13.75, 48.75, 19.38, 5],
[26.49, 64.32, 0, 47.12, 39.71, 0, 0, 0, 26.04, 9.26, 0, 80.89, 0, 7.95, 2.27, 63.35, 47.73, 5.05, 0, 0, 0, 15.21, 56.45, 4.51, 2.78, 61.1, 45.97, 0, 60.52, 0, 0, 17.59, 0, 12.71, 0, 87.59, 1.39, 16.89, 19.48, 63.26, 10.07, 4.79, 40, 13.96, 3.12, 49.13, 47.86, 22.73, 18.61, 19.68, 24.56, 0, 0, 0, 75.87, 0, 34.33, 48.98, 0, 0, 0, 0, 62.1, 45.71, 12.38, 11.46, 3.75, 21.11, 23.26, 28.32, 58.09, 0.25, 4, 0, 65.18, 0, 0, 7.27, 4.17, 61.21, 11.47, 38.16, 18.89, 38.89, 60.52],
[63.57, 43.1, 0, 17.29, 54.51, 0, 0, 0, 29.58, 22.87, 53.15, 67.37, 0, 3.41, 7.5, 45.54, 16.46, 0, 0, 12.27, 0, 22.08, 61.36, 0, 0, 69.78, 22.71, 26.25, 15.59, 1.04, 0, 27.28, 0, 19.97, 0, 1.14, 0.62, 0, 57.43, 38.21, 3.33, 10.62, 4.66, 13.11, 3.41, 66.87, 70.54, 4.17, 70.56, 29.83, 36.09, 31.25, 22.76, 0, 26.91, 12.33, 53.74, 3.41, 10.89, 34.74, 0, 8.98, 60.2, 24.7, 0, 13.12, 0, 0, 7.27, 63.95, 21.23, 0, 25.86, 0, 49.09, 2.08, 33.41, 7.27, 34.46, 2.27, 48.68, 13.01, 35.95, 28.26, 5],
[50.13, 33.78, 0, 46.23, 40.42, 0, 0, 0, 10.62, 1.25, 0, 48.79, 15.39, 7.36, 1.88, 73.98, 21.3, 28.76, 0, 3.34, 0, 10.78, 49.02, 0, 0, 73.69, 19.31, 42.87, 56.08, 0, 0, 36.73, 0, 48.31, 0, 8.12, 0, 26.44, 70.75, 57.14, 4.17, 18.19, 17.66, 20.24, 2.27, 73.62, 56.88, 11.95, 67.81, 31.16, 28.54, 1.25, 32.55, 0, 11, 25.8, 52.36, 0, 24, 60.85, 0, 3.4, 72.65, 70.36, 0, 11.67, 8.89, 9.03, 29.65, 36.71, 55.09, 1.39, 7.5, 0, 84.48, 0, 1.25, 20.47, 59.44, 11.81, 24.99, 18.94, 40.87, 23.12, 16.69],
[9.75, 7.5, 7.5, 19.9, 63.96, 0, 0, 0, 2.95, 1.39, 0, 1.14, 54.46, 76.17, 82, 0, 48.9, 0, 0, 0, 29.65, 17.99, 0, 9.9, 6.94, 17.05, 19.44, 0, 0, 3.12, 71.89, 3.75, 24.58, 22.52, 0, 0, 0, 0, 52.54, 10.21, 43.19, 55.38, 1.25, 12.04, 1.39, 58.54, 2.64, 32.6, 0, 0, 4.91, 1.04, 7.95, 0, 40.18, 9.72, 9.66, 23.06, 0, 0, 0, 0, 2.64, 72.1, 0, 0, 0, 39.55, 18.54, 0, 0, 27.78, 0, 0, 63.6, 13.54, 0, 0, 42.5, 13.33, 10.59, 22.85, 25.17, 3.75, 2.5],
[81.96, 29.31, 1.56, 35, 44.5, 0, 0, 0, 17.19, 23.12, 0, 20.45, 44.75, 49.97, 30.91, 24.32, 20.23, 10.94, 78.12, 2.5, 0, 2.5, 5, 0, 6.25, 41.14, 11.25, 10.31, 7.5, 8.3, 0, 3.75, 11.25, 6.25, 0, 8.75, 81.51, 0, 4.06, 29.32, 31.36, 23.18, 14.32, 10.45, 1.56, 11.25, 28.98, 20.8, 2.5, 0, 49.81, 70, 6.25, 12.5, 7.81, 1.25, 7.81, 0, 27.88, 0, 6.25, 6.25, 32.24, 30.31, 74.38, 59.15, 0, 0, 0, 0, 0, 0, 0, 82.53, 19.43, 78.92, 0, 5, 13.75, 35.31, 44.49, 49.84, 6.25, 18.18, 47.03],
[32.36, 89.09, 0, 8.41, 19.38, 0, 0, 0, 0, 0, 0, 61.16, 3.12, 9.55, 16.53, 18.75, 35.62, 2.5, 0, 0, 36.62, 3.75, 0, 10.8, 0, 11.65, 34.55, 1.25, 15, 0, 6.53, 0, 0, 23.75, 1.25, 1.25, 0, 0, 73.69, 4.66, 40.94, 10, 40, 8.41, 0, 80.14, 3.75, 36.22, 0, 0, 7.05, 0, 0, 0, 69.91, 0, 2.5, 68.35, 0, 1.25, 0, 0, 57.78, 57.33, 0, 5, 1.56, 17.19, 42.81, 1.25, 65.17, 2.5, 49.8, 0, 66.62, 0, 0, 0, 0, 54.38, 10.62, 73.69, 0, 7.5, 30.68],
[56.21, 23.51, 12.22, 32.69, 38.13, 0, 0, 0, 35.83, 6.94, 0, 72.76, 10, 5, 2.22, 61.5, 4.44, 61.86, 0, 0, 0, 31.44, 64.22, 26.27, 11.11, 59.15, 25.56, 39.91, 0, 9.44, 0, 59.78, 0, 13.61, 0, 0, 2.22, 0, 67.39, 43.24, 18.33, 8.96, 28.89, 35.03, 10, 80.44, 37.63, 32.01, 29.16, 5.62, 64.05, 30.39, 40.63, 0, 5, 3.89, 19.05, 0, 34.6, 18.37, 0, 51.61, 64.25, 58.2, 1.11, 7.78, 1.11, 5.56, 7.78, 4.44, 10, 2.22, 5.56, 0, 60.42, 2.22, 10, 1.11, 35.98, 28.82, 52.9, 3.33, 47.54, 17.22, 83.55],
[87.99, 85.35, 0, 0, 0, 0, 0, 0, 6.46, 1.25, 85.76, 80, 0.62, 8.12, 0, 64.33, 28.06, 16.88, 0, 8.33, 0, 27.71, 59.79, 1.39, 2.78, 83.33, 20, 8.75, 18.12, 5, 0, 17.57, 0, 100, 0, 0, 8.33, 2.5, 64.86, 30.21, 30.97, 3.12, 36.04, 3.75, 5, 70.83, 22.5, 19.93, 35.62, 33.68, 15.42, 6.94, 4.38, 0, 44.38, 6.46, 33.75, 0, 4.17, 5.56, 6.94, 0, 67.85, 21.88, 0, 0, 0, 8.33, 11.6, 47.85, 51.46, 0, 5.62, 0, 73.82, 4.03, 3.26, 10, 17.29, 46.11, 12.5, 2.5, 47.85, 18.19, 8.89],
[36.04, 6.77, 0, 55.21, 34.48, 0, 0, 0, 0, 0, 3.12, 64.93, 5, 7.81, 20, 45.24, 0, 44.93, 0, 67.64, 0, 18.75, 27.5, 55.69, 0, 72.26, 29.48, 3.12, 0, 3.12, 0, 0, 0, 16.98, 0, 0, 0, 0, 37.85, 58.12, 8.75, 24.17, 18.23, 24.06, 32.81, 32.4, 74.9, 7.5, 5, 0, 88.33, 0.62, 9.38, 0, 52.81, 12.57, 35.38, 21.15, 5, 2.5, 0, 2.08, 21.01, 45.83, 0, 0, 0, 8.75, 1.25, 33.26, 3.75, 61.88, 3.75, 0, 20.62, 0, 76.01, 1.25, 17.71, 23.68, 55, 39.06, 9.38, 31.67, 12.78],
[10.56, 13.19, 0, 64.51, 72.67, 0, 3.12, 0, 9.79, 2.5, 12.5, 78.78, 0, 14.62, 2.64, 74.1, 0, 19.58, 0, 17.19, 0, 23.12, 43.19, 1.39, 3.89, 84.69, 37.53, 0, 49.44, 0, 0, 29.19, 0, 1.95, 13.02, 23.02, 3.75, 8.06, 35.21, 54.93, 4.03, 14.17, 19.97, 15, 36.04, 42.67, 62.01, 2.5, 10, 51.26, 60.26, 3.75, 1.25, 0, 64.72, 3.75, 77.01, 2.5, 0, 0, 0, 0, 67.85, 38.65, 3.75, 7.5, 0, 10, 10, 66.53, 19.9, 6.25, 10, 0, 28.68, 1.25, 70.99, 1.25, 7.92, 38.16, 18.48, 9.17, 22.58, 21.56, 14.76],
[42.47, 30.12, 3.75, 20, 2.5, 72.91, 5, 17.18, 8.26, 6.08, 89.3, 66.23, 2.64, 16.01, 78.36, 0, 14.09, 48.87, 0, 0, 0, 33.28, 60.83, 19.2, 8.33, 66.65, 31.74, 85.88, 21.11, 0, 0, 75.05, 0, 28.52, 0, 0, 5, 0, 73.92, 76.21, 3.75, 84.04, 0, 82.34, 0, 77.12, 59.45, 10.62, 38.24, 14.24, 68.97, 13.61, 86.97, 0, 4.63, 1.25, 4.38, 0, 84.93, 2.75, 0, 87.22, 13.12, 83.28, 1.25, 0, 1.25, 47.83, 25.42, 21.32, 2.5, 75.79, 14.88, 0, 76.24, 4.06, 6.39, 20.14, 83.81, 3.75, 36.59, 23.96, 39.36, 24.34, 5.64],
[18.84, 4.82, 0, 67.59, 44.27, 0, 0, 0, 11.61, 13.12, 0, 18.99, 45.94, 62.06, 76.46, 2.5, 82.08, 9.11, 70.92, 1.25, 0, 2.5, 6.88, 0, 4.38, 20.76, 33.75, 31.88, 28.12, 24.74, 0, 5, 90.74, 28.12, 0, 5, 76.85, 0, 6.25, 24.76, 39.06, 49.94, 5, 16.55, 7.19, 11.88, 8.78, 52.14, 2.5, 6.88, 14.68, 76.63, 9, 0, 9.62, 2.5, 8.12, 0.25, 20.36, 11.88, 0, 10, 38.06, 16.88, 81.45, 38.78, 0, 0, 0, 0, 0, 0, 0, 74.18, 12.5, 62.3, 0, 0, 20, 29.43, 23.96, 60.41, 11.19, 33.77, 18.75],
[30.57, 14.11, 0, 62.48, 30.34, 0, 0, 2.5, 16.61, 8.75, 2.5, 45.41, 1.25, 5, 1.25, 51.98, 0, 54.05, 0, 0, 0, 10.42, 44.21, 6.25, 11.25, 30.55, 21.72, 36.87, 6.94, 0, 0, 39.14, 0, 22.39, 0, 7.5, 14.93, 20.83, 51.87, 61.5, 2.27, 8.89, 12.27, 22.11, 3.75, 57.65, 52.09, 1.25, 28.32, 3.57, 47.2, 9.58, 43.37, 0, 9.66, 3.75, 28.79, 7.5, 41.69, 11.43, 0, 19.84, 36.01, 23.61, 0, 6.39, 6.25, 16.25, 16.44, 31.66, 45.63, 1.25, 9.82, 0, 44.9, 5.42, 12.5, 10, 59.37, 3.75, 36.56, 13.06, 26.26, 10.14, 3.75],
[43.74, 23.96, 0, 58.07, 53.94, 0, 0, 0, 19.92, 0, 0, 67.2, 0, 5.83, 40.21, 10.92, 5.21, 41.55, 0, 0, 0, 23.96, 59.95, 1.56, 8.14, 62.59, 35.78, 58.82, 0, 0, 5.92, 25.69, 0, 4.17, 0, 0, 0, 1.56, 32.95, 63.13, 0, 61.63, 0, 63.39, 0, 80.84, 77.68, 0, 38.82, 4.17, 64.18, 2.78, 79.77, 0, 0, 0, 22.97, 0, 71.47, 33.4, 0, 73.45, 59.47, 52.78, 36.11, 11.67, 0, 9.5, 29.04, 34.25, 16.9, 10.76, 23.47, 0, 50.94, 0, 0, 30.49, 75.75, 0, 61.48, 46.81, 41.37, 18.4, 5],
[85.04, 85.04, 0, 0, 0, 0, 0, 0, 0, 0, 98.86, 36.4, 7.69, 20.8, 48.82, 8.18, 11.88, 38.74, 0, 0, 76.45, 0, 0, 39.02, 21.88, 61.43, 64.72, 0, 14.38, 6.25, 0, 0, 0, 9.55, 0, 4.17, 0, 0, 77.61, 70.85, 1.25, 29.85, 16.25, 38.76, 0, 89, 42.51, 20, 0, 0, 29.52, 0, 5.14, 0, 74.3, 0, 32.08, 78.67, 4.17, 0, 0, 0, 5.42, 90.92, 0, 0, 0, 41.53, 58.91, 8.75, 29.09, 10, 10, 0, 82.75, 0, 0, 0, 0, 53.25, 22.63, 80.6, 1.25, 19.09, 9.94]
])

In [49]:
# Matrix of embedding vectors of 40 seen classes
arr_for_calc = arr[[0, 1, 2, 3, 4, 5, 6, 9, 10, 11, 12, 13, 15, 16, 17, 19, 20, 21, 23, 24, 25, 26, 27, 29, 30, 31, 32, 33, 34, 36, 37, 39, 41, 42, 43, 45, 46, 47, 48, 49], :]

In [50]:

def emb_from_predictions(X):
    X = X.reshape(40,1)
    return np.sum(X*arr_for_calc, axis=0) # X is the product of the predicted probabilities out of each of the 40 seen classes and it is being multiplied by the corresponding embedding vectors of each of the seen class.

In [32]:
# For calculating cosine similarity
import numpy as np
from numpy.linalg import norm
from sklearn.metrics.pairwise import cosine_similarity
import pandas as pd
import os
from tensorflow.keras.preprocessing.image import load_img, img_to_array

# Test directory
test_dir = '/kaggle/input/vlg-recruitment-24-challenge/vlg-dataset/vlg-dataset/test' 
image_size = (336, 336)  
batch_size = 32  

# Loading and preprocessing test images
def preprocess_image(image_path):
    img = load_img(image_path, target_size=image_size)
    img_array = img_to_array(img)  
    return img_array

# Getting all image file names
image_paths = [os.path.join(test_dir, fname) for fname in os.listdir(test_dir)]
image_ids = [os.path.basename(fname) for fname in image_paths]  

# Preprocessing all test images
test_images = np.array([preprocess_image(path) for path in image_paths])

# Generating predictions
predictions = resnet_model.predict(test_images, batch_size=batch_size)

[1m94/94[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m90s[0m 956ms/step


In [48]:
# Calling all the functions created for building the csv
batch_embs = np.array([np.squeeze(emb_from_predictions(row)) for row in predictions]) 

cos_sim_matrix = cosine_similarity(batch_embs, arr) # Cosine similarity

predicted_num = np.array([np.argmax(row) for row in cos_sim_matrix]) # 

predicted_labels = [class_names[id] for id in predicted_num] # Converting index to name of class
# Here keeping the track of dimensions of each of the array or dataframe helps to understand what calculations are actually being
# carried out and thus helps in debugging too

# Submission DataFrame
submission = pd.DataFrame({
    "image_id": image_ids,  
    "class": predicted_labels  
})

# Saving to CSV
submission.to_csv("submission.csv", index=False)

print("Submission file saved as 'submission.csv'")

(3000, 40)
(3000, 31)
(3000, 50)
[41 37 15 ... 20 23 43]
Submission file saved as 'submission.csv'
