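# Problem Statement: AI-Enhanced Product Photoshoot Visuals and Filter

Generate product images from a text prompt with Stable Diffusion (KerasCV), then classify them with an ImageNet-pretrained MobileNetV2 to filter out images that do not match the requested product.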

In [None]:
!pip install keras_cv

In [None]:
import random

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from keras_cv.models import StableDiffusion

In [None]:
# Build the KerasCV Stable Diffusion text-to-image model.
# It turns a textual product prompt into 512x512 images.
model_diffusion = StableDiffusion(img_height=512, img_width=512)

In [None]:
# ImageNet-pretrained MobileNetV2 classifier, used further down to label the
# generated images and to filter out ones that don't look like the product.
model_classification = tf.keras.applications.mobilenet_v2.MobileNetV2(weights="imagenet")

In [None]:
def generate_images(prompt, batch_size=3, seed=None):
    """Generate ``batch_size`` images for ``prompt`` with the Stable Diffusion model.

    Args:
        prompt: Text description of the product to render.
        batch_size: Number of images to generate in one call.
        seed: Optional integer forwarded to ``text_to_image`` so a run can be
            reproduced; ``None`` (the default) keeps generation non-deterministic.

    Returns:
        Whatever ``model_diffusion.text_to_image`` returns -- presumably a batch of
        RGB images (TODO confirm shape/dtype against the installed KerasCV version).
    """
    return model_diffusion.text_to_image(prompt, batch_size=batch_size, seed=seed)


In [None]:
def plot_images(images):
    """Show ``images`` side by side in a single row, with axes hidden.

    Args:
        images: Sequence of images accepted by ``Axes.imshow`` (e.g. HxWx3 arrays).
            An empty sequence prints a notice instead of drawing an empty figure.
    """
    images = list(images)
    if not images:
        # plt.subplots(1, 0) cannot build a grid, and there is nothing to draw anyway.
        print("No images to plot.")
        return

    n_images = len(images)
    # Size the canvas to the row (~5 inches per image) rather than a fixed 20x20
    # figure that is mostly blank space for a single row of images.
    fig, axes = plt.subplots(1, n_images, figsize=(5 * n_images, 5), squeeze=False)
    for ax, image in zip(axes[0], images):
        ax.imshow(image)
        ax.axis("off")
    fig.tight_layout()
    plt.show()

In [None]:
def preprocess_images(images):
    """Turn a sequence of images into a batch ready for MobileNetV2.

    Every image is resized to 224x224 (the input shape MobileNetV2 expects here),
    run through ``mobilenet_v2.preprocess_input``, and the results are collected
    into a single NumPy array whose first axis indexes the images.
    """
    batch = [
        tf.keras.applications.mobilenet_v2.preprocess_input(tf.image.resize(img, (224, 224)))
        for img in images
    ]
    return np.array(batch)

In [None]:
def classify_images(images, model, top=5):
    """Classify ``images`` with an ImageNet model and return decoded top predictions.

    Args:
        images: Sequence of images; each is resized and preprocessed by
            ``preprocess_images`` before prediction.
        model: Model whose ``predict`` output is accepted by MobileNetV2's
            ``decode_predictions`` (i.e. a probability vector over ImageNet classes).
        top: Number of highest-scoring classes to keep per image. The default of 5
            matches ``decode_predictions``' own default.

    Returns:
        One list per image of ``(imagenet_class_id, label, score)`` tuples, as
        produced by ``decode_predictions``. An empty input returns ``[]``.
    """
    images = list(images)
    if not images:
        # Nothing to classify; avoid handing model.predict an empty batch.
        return []

    predictions = model.predict(preprocess_images(images))
    # decode_predictions maps each probability vector to its `top` most likely
    # ImageNet classes as (class_id, human-readable label, score) tuples.
    return tf.keras.applications.mobilenet_v2.decode_predictions(predictions, top=top)

In [None]:
def label_matches_prompt(label, prompt):
    """Return True if an ImageNet ``label`` plausibly names the product in ``prompt``.

    ImageNet labels use underscores (``"water_bottle"``, ``"running_shoe"``), so the
    label is lower-cased and its underscores become spaces before comparison. It
    matches when the whole label occurs in the prompt (``"sunglass"`` vs
    ``"Sunglasses"``) or when it shares a whole word with it (``"running_shoe"``
    vs ``"Shoe"``).
    """
    label_text = label.lower().replace("_", " ")
    prompt_text = prompt.lower()
    return label_text in prompt_text or bool(set(label_text.split()) & set(prompt_text.split()))


def filter_non_relevant_images(images, model, prompt=None, threshold=0.5, top_k=5):
    """Return the subset of ``images`` that the classifier considers relevant.

    Args:
        images: Sequence of images to screen.
        model: ImageNet classifier (e.g. MobileNetV2) whose ``predict`` output is a
            probability vector per image.
        prompt: Optional product text. When given, an image is kept if any of its
            ``top_k`` predicted labels matches the prompt (see ``label_matches_prompt``).
        threshold: Used only when ``prompt`` is None: an image is kept if its most
            likely class has probability greater than this value.
        top_k: How many top predicted labels to compare against ``prompt``.

    Returns:
        A list of the kept images, in their original order (``[]`` for empty input).
    """
    images = list(images)
    if not images:
        return []

    # One batched forward pass instead of a separate predict() per image.
    probabilities = model.predict(preprocess_images(images))

    if prompt is None:
        # Column 0 of the output is one specific ImageNet class, not a generic
        # "relevant" score, so use the confidence of the most likely class instead.
        keep = [float(np.max(row)) > threshold for row in probabilities]
    else:
        decoded = tf.keras.applications.mobilenet_v2.decode_predictions(probabilities, top=top_k)
        keep = [
            any(label_matches_prompt(label, prompt) for _, label, _ in image_preds)
            for image_preds in decoded
        ]

    return [image for image, is_relevant in zip(images, keep) if is_relevant]

In [None]:
# Example usage: pick a random product category to use as the text prompt.
entities = [
    "Shoe", "Sneaker", "Bottle", "Cup", "Sandal", "Perfume", "Toy", "Sunglasses",
    "Car", "Water Bottle", "Chair", "Office Chair", "Can", "Cap", "Hat",
    "Couch", "Wristwatch", "Glass", "Bag", "Handbag", "Baggage", "Suitcase",
    "Headphones", "Jar", "Vase"
]

ENTITY_SEED = None  # set to an int to make the entity choice reproducible
_entity_rng = random.Random(ENTITY_SEED)


def select_entity(rng=None):
    """Return one product category from ``entities``, chosen at random.

    Args:
        rng: Optional object with a ``choice`` method (e.g. ``random.Random(0)``).
            Defaults to this cell's generator, which is seeded from ``ENTITY_SEED``.
    """
    chooser = _entity_rng if rng is None else rng
    return chooser.choice(entities)


# Assigned unconditionally so `prompt` always exists; previously it was only
# defined under `if __name__ == "__main__":` yet used outside that guard.
chosen_entity = select_entity()
print("Chosen entity:", chosen_entity)
prompt = chosen_entity
batch_size = 2

In [None]:
# Generate candidate images, drop the ones the classifier deems irrelevant, and
# plot what survives -- falling back to every image if nothing passes the filter.
generated_images = generate_images(prompt, batch_size=batch_size)
relevant_images = filter_non_relevant_images(generated_images, model_classification)
if relevant_images:
    plot_images(relevant_images)
else:
    print("No generated image passed the relevance filter; showing all of them.")
    plot_images(generated_images)

In [None]:
# Run MobileNetV2 over every generated image and keep its decoded top predictions.
classifications = classify_images(images=generated_images, model=model_classification)

In [None]:
# Score each generated image by whether its top-1 ImageNet label names the prompt.
# ImageNet labels use underscores (e.g. "water_bottle", "running_shoe"), so they are
# normalised to space-separated words before being compared with the prompt.
prompt_text = prompt.lower()
prompt_words = set(prompt_text.split())

accuracies = []
for i, img_class in enumerate(classifications):
    _, top_label, top_score = img_class[0]  # highest-scoring prediction for this image
    label_text = top_label.lower().replace("_", " ")
    is_match = label_text in prompt_text or bool(prompt_words & set(label_text.split()))

    # Confidence-weighted top-1 hit: the classifier's confidence when its top label
    # matches the prompt, 0 otherwise. (Scoring a mismatch as 1 - score would reward
    # uncertain misclassifications with a high "accuracy".)
    accuracy = float(top_score) if is_match else 0.0
    accuracies.append(accuracy)

    # Print the classification result including accuracy
    print(f"Image {i + 1} classification (Accuracy: {accuracy:.2f}):")
    for rank, (_, label, score) in enumerate(img_class, start=1):
        print(f"{rank}: {label} ({score:.2f})")
    print()

# Print the average accuracy (guarded: an empty batch would divide by zero)
if accuracies:
    average_accuracy = sum(accuracies) / len(accuracies)
    print(f"Average Accuracy: {average_accuracy:.2f}")
else:
    print("No classifications to score.")
