In [4]:
import tensorflow as tf
from tensorflow.keras.preprocessing import image
from tensorflow.keras.applications.vgg16 import VGG16, preprocess_input, decode_predictions

In [46]:
import numpy as np

In [5]:
# Root folder of the image dataset, given relative to the notebook's working
# directory. It is passed to image_dataset_from_directory when the dataset is built.
data_path = "data"

In [6]:
# Build a batched image dataset from the folders under `data_path`.
# The files are split 80/20 with a fixed seed and only the training
# portion is kept.
_loader_options = dict(
    validation_split=0.2,   # fraction of files held out for validation
    subset="training",      # keep the training side of the split
    seed=123,               # fixed seed so the split is reproducible
    image_size=(224, 224),  # same size preprocess_fn resizes to
    batch_size=32,
)
data_gen = tf.keras.preprocessing.image_dataset_from_directory(
    data_path, **_loader_options
)


Found 616 files belonging to 6 classes.
Using 493 files for training.


In [7]:
# Preprocessing step mapped over the dataset
def preprocess_fn(img, label):
    """Resize `img` to 224x224 and run it through VGG16's `preprocess_input`.

    Args:
        img: image tensor (or batch of images) yielded by the dataset.
        label: label(s) belonging to `img`; passed through untouched.

    Returns:
        A `(preprocessed_img, label)` tuple.
    """
    resized = tf.image.resize(img, (224, 224))
    return preprocess_input(resized), label

In [8]:
# Apply preprocess_fn to every element of the dataset.
# NOTE(review): this rebinds `data_gen` to its own mapped version, so running
# this cell a second time without re-creating the dataset applies
# preprocess_fn twice. Re-run the dataset-creation cell first.
data_gen = data_gen.map(preprocess_fn)

In [9]:
# Instantiate VGG16 with its pre-trained ImageNet weights. Its predictions are
# later decoded into ImageNet class names with decode_predictions.
model = VGG16(weights='imagenet')

In [55]:
from collections import defaultdict
# Accumulator: dataset label -> set of top-1 class names predicted for images
# with that label. Missing labels start out as an empty set.
output = defaultdict(set)

In [56]:
# Run VGG16 over every batch of the dataset and record, for each dataset
# label, the set of ImageNet top-1 class names predicted for its images.
for img, label in data_gen:
    # One forward pass per batch instead of one predict() call per image.
    # This also avoids rebinding the name `image`, which is the imported
    # tensorflow.keras.preprocessing.image module. verbose=0 keeps Keras
    # progress bars out of the cell output.
    preds = model.predict(img, verbose=0)
    # decode_predictions returns one list per image; with top=1 each list
    # holds a single (class_id, class_name, score) tuple.
    top1_per_image = decode_predictions(preds, top=1)
    for lab, best in zip(label.numpy(), top1_per_image):
        lab = int(lab)  # plain int key; compares/hashes equal to the NumPy scalar
        predicted_name = best[0][1]
        output[lab].add(predicted_name)
        print(lab, predicted_name)

2 binder
4 pill_bottle
5 packet
3 carpenter's_kit
4 bakery
3 carpenter's_kit
3 screwdriver
0 container_ship
4 restaurant
1 broom
3 hammer
2 matchstick
1 miniskirt
5 handkerchief
5 Christmas_stocking
3 screwdriver
3 hatchet
3 envelope
3 web_site
4 restaurant
5 lighter
5 hotdog
5 birdhouse
2 desktop_computer
2 combination_lock
0 pirate
5 lighter
3 paintbrush
5 envelope
0 cinema
1 crutch
1 jean
0 aircraft_carrier
5 matchstick
5 whistle
3 screwdriver
5 birdhouse
2 envelope
3 shovel
4 tray
0 container_ship
0 bullet_train
3 web_site
2 paintbrush
0 brass
2 crossword_puzzle
0 container_ship
5 sewing_machine
2 envelope
0 container_ship
3 screwdriver
0 container_ship
4 plate
4 tray
3 carpenter's_kit
3 puck
1 miniskirt
4 hotdog
1 shoe_shop
5 safety_pin
5 lighter
0 container_ship
4 shoe_shop
0 space_shuttle
5 knot
3 menu
4 candle
3 jigsaw_puzzle
3 rubber_eraser
3 forklift
1 sunglasses
2 switch
4 hair_spray
3 radiator
5 Christmas_stocking
3 menu
1 cowboy_boot
1 puck
4 confectionery
2 cleaver
0 cont

In [66]:
# Replace the integer label keys of `output` with readable class names.
# Names are listed in label-index order: 0 -> 'cargo', ..., 5 -> 'gift'.
# NOTE(review): presumably this matches the class-folder order used when the
# dataset was built - verify against the dataset's class names.
class_names = ['cargo', 'clothing', 'document', 'equipment', 'food', 'gift']
for label_index, class_name in enumerate(class_names):
    # pop() on a missing key raises KeyError (defaultdict does not apply its
    # factory to pop). Skip labels that are absent so the cell can be re-run
    # and tolerates labels that never occurred.
    if label_index in output:
        output[class_name] = output.pop(label_index)

In [67]:
# Write one "(key, {labels})" line per entry of `output` to output.txt.
# Labels are emitted in sorted order: iteration order of a set of strings can
# change between interpreter runs, which would make the file differ from run
# to run even when the predictions are identical.
with open("output.txt", 'w', encoding="utf-8") as f:
    for key, labels in output.items():
        if labels:
            rendered = "{" + ", ".join(repr(name) for name in sorted(labels)) + "}"
        else:
            rendered = repr(labels)  # an empty set is written as set()
        f.write(f"({key!r}, {rendered})\n")

In [58]:
# Display the label -> predicted-class-names mapping accumulated by the
# prediction loop.
output

defaultdict(set,
            {2: {'analog_clock',
              'binder',
              'birdhouse',
              'brass',
              'cleaver',
              'combination_lock',
              'crossword_puzzle',
              'desktop_computer',
              'envelope',
              'fountain_pen',
              'hard_disc',
              'jersey',
              'lab_coat',
              'lampshade',
              'matchstick',
              'menu',
              'packet',
              'paintbrush',
              'perfume',
              'photocopier',
              'refrigerator',
              'rubber_eraser',
              'rule',
              'scale',
              'slide_rule',
              'spatula',
              'swab',
              'switch',
              'toilet_seat',
              'water_bottle',
              'web_site',
              'window_shade'},
             4: {'analog_clock',
              'apron',
              'bagel',
              'bakery',
         