In [1]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.preprocessing.image import ImageDataGenerator

In [2]:
# Training configuration read by the cells below.
epochs = 10            # number of passes requested from model.fit
batch_size = 32        # images per batch produced by the data generators
num_classes = 30       # number of units in the final softmax layer
img_size = (224, 224)  # target size every image is resized to

# Class numbers (1-indexed) for which a refund is granted.
refund_classes = [1, 12, 23]

In [3]:
# Build the input pipeline: one ImageDataGenerator that holds back 20% of the
# images for validation and runs MobileNetV2's preprocess_input on every image.
data_generator = ImageDataGenerator(
    preprocessing_function=tf.keras.applications.mobilenet_v2.preprocess_input,
    validation_split=0.2,
)

# Both iterators read the same "balls" directory with identical settings and
# differ only in which side of the split they serve. The training iterator is
# created first, then the validation one.
train_generator, validation_generator = (
    data_generator.flow_from_directory(
        directory="balls",
        target_size=img_size,
        batch_size=batch_size,
        class_mode="categorical",
        subset=subset_name,
    )
    for subset_name in ("training", "validation")
)

Found 2904 images belonging to 30 classes.
Found 711 images belonging to 30 classes.


In [4]:
# Assemble the classifier layer by layer: an ImageNet-pretrained MobileNetV2
# feature extractor followed by a small dense classification head.
model = keras.Sequential()
model.add(layers.Input(shape=(*img_size, 3)))
# NOTE(review): the backbone is never frozen here (no `trainable = False`), so
# its weights are presumably fine-tuned together with the head — confirm that
# this is intended.
model.add(
    keras.applications.MobileNetV2(
        include_top=False,
        weights="imagenet",
        input_tensor=None,
        pooling="avg",
    )
)
# NOTE(review): with pooling="avg" the backbone output should already be one
# flat feature vector per image, which would make this Flatten a no-op — verify.
model.add(layers.Flatten())
model.add(layers.Dense(1024, activation="relu"))
model.add(layers.Dropout(0.5))
model.add(layers.Dense(num_classes, activation="softmax"))

# One-hot labels from the generators (class_mode="categorical") pair with
# categorical cross-entropy.
model.compile(
    loss="categorical_crossentropy",
    optimizer="adam",
    metrics=["accuracy"],
)

# Train on the training subset and evaluate on the validation subset each epoch.
model.fit(train_generator, validation_data=validation_generator, epochs=epochs)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


<keras.callbacks.History at 0x1f3fc85a790>

In [6]:
# Use the model to make predictions
def predict_class(image_path):
    """Classify a single image and decide whether it qualifies for a refund.

    The image is loaded at ``img_size``, run through the same MobileNetV2
    preprocessing that the training/validation generators apply, and passed to
    the trained ``model``. The highest-scoring class index is converted to a
    1-indexed class number and checked against ``refund_classes``.

    Args:
        image_path: Path of the image file to classify.

    Returns:
        "refund" if the predicted class number is in ``refund_classes``,
        otherwise "declined".
    """
    img = keras.preprocessing.image.load_img(image_path, target_size=img_size)
    img_array = keras.preprocessing.image.img_to_array(img)

    # The generators feed the model images transformed by
    # mobilenet_v2.preprocess_input; raw 0-255 pixels would not match what the
    # network saw during training, so apply the same transform here.
    img_array = tf.keras.applications.mobilenet_v2.preprocess_input(img_array)
    img_array = tf.expand_dims(img_array, 0)  # Create batch axis

    predictions = model.predict(img_array)
    # NOTE(review): the +1 assumes refund_classes holds 1-indexed positions in
    # the generator's class ordering — verify against
    # train_generator.class_indices.
    predicted_class = predictions[0].argmax() + 1

    return "refund" if predicted_class in refund_classes else "declined"

In [7]:
# Classify a sample image; the returned decision string is the cell's output.
predict_class("1.jpg")



'declined'