In [35]:
import tensorflow as tf
from tensorflow import keras

# Load the Fashion MNIST dataset
# Seed Python, NumPy and TensorFlow RNGs so weight init and batch shuffling are reproducible.
keras.utils.set_random_seed(42)

# Load the Fashion MNIST dataset: train split for fitting, test split for a held-out check
fashion_mnist = keras.datasets.fashion_mnist
(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()

# Scale uint8 pixels into [0, 1]; float32 avoids the float64 copy that `/ 255.0` alone creates
train_images = train_images.astype('float32') / 255.0
test_images = test_images.astype('float32') / 255.0

# Define the model architecture: 28x28 image -> flat vector -> hidden layer -> 10 class probabilities
model = keras.Sequential([
    keras.Input(shape=(28, 28)),  # explicit input instead of the deprecated Flatten(input_shape=...)
    keras.layers.Flatten(),
    keras.layers.Dense(128, activation='relu'),
    keras.layers.Dense(10, activation='softmax')
])

# Compile and train the model (integer labels -> sparse categorical cross-entropy)
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
history = model.fit(train_images, train_labels, epochs=15)

# Report accuracy on data the model never trained on, so overfitting across 15 epochs is visible
test_loss, test_accuracy = model.evaluate(test_images, test_labels, verbose=2)
print(f"Test accuracy: {test_accuracy:.4f}")

# Save the trained model. The HDF5 (.h5) name is kept because the inference cell loads this
# exact file; switch both cells to the native '.keras' format together if desired.
model.save('fashion_mnist_model.h5')


In [35]:
import tensorflow as tf
from tensorflow import keras
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from google.colab import files

# Load the Fashion MNIST dataset
# Only the test split is kept for inference; the training split is discarded.
fashion_mnist = keras.datasets.fashion_mnist
_, (test_images, _) = fashion_mnist.load_data()

# Scale uint8 pixel values into the [0, 1] range, matching the training cell's scaling.
test_images = np.divide(test_images, 255.0)

# Reload the network that the training cell saved to disk.
model = keras.models.load_model(filepath='fashion_mnist_model.h5')

# Human-readable label names; entry i is the name for predicted class index i.
class_names = [
    'T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
    'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot',
]

# Function to preprocess the uploaded image
def preprocess_image(image_path, invert=False):
    """Load an image file and turn it into a model-ready batch containing one image.

    The image is converted to 8-bit grayscale, resized to 28x28 (the
    ``input_shape`` the model was trained with), scaled into [0, 1] and
    given a leading batch axis.

    Args:
        image_path: Path or file-like object accepted by ``PIL.Image.open``.
        invert: If True, flip intensities (``1 - x``). NOTE(review): useful if
            uploads are dark items on a light background and the training
            images are the opposite polarity -- verify against your data.

    Returns:
        A float32 NumPy array of shape (1, 28, 28) with values in [0, 1].
    """
    # `Image.open` is lazy and keeps the file open; `with` guarantees it is closed.
    with Image.open(image_path) as source:
        # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same resampling filter.
        grayscale = source.convert('L').resize((28, 28), Image.LANCZOS)
    pixels = np.asarray(grayscale, dtype=np.float32) / 255.0
    if invert:
        pixels = 1.0 - pixels
    # Add the batch dimension expected by model.predict.
    return np.expand_dims(pixels, axis=0)

# Function to predict the class of the uploaded image
def predict_image(image_path):
    """Classify a single image file and return its class name.

    Relies on the module-level ``model`` and ``class_names`` defined earlier
    in this cell, and on ``preprocess_image`` for input preparation.

    Args:
        image_path: Path of the image file to classify.

    Returns:
        The entry of ``class_names`` whose model score is highest.
    """
    batch = preprocess_image(image_path)
    scores = model.predict(batch)[0]  # scores for the only image in the batch
    best_index = int(scores.argmax())
    return class_names[best_index]

# Upload the image file
# Upload an image file through the Colab file picker (returns {filename: bytes})
uploaded = files.upload()

# A cancelled picker yields an empty dict; fail with a clear message rather than an IndexError.
if not uploaded:
    raise ValueError("No file was uploaded - re-run this cell and select an image.")

# Only the first uploaded file is classified; say so explicitly when more were provided.
image_path = next(iter(uploaded))
if len(uploaded) > 1:
    print(f"{len(uploaded)} files uploaded; classifying only {image_path!r}.")

# Predict the class of the uploaded image
predicted_class = predict_image(image_path)
print(f"Predicted class: {predicted_class}")

# Copy the pixels into memory so the file handle is closed immediately.
with Image.open(image_path) as original:
    img = original.copy()

# Show the upload next to the 28x28 grayscale array the model actually received;
# `cmap` only applies to single-channel images, so an RGB upload on its own would
# not reflect what was classified.
model_input = preprocess_image(image_path)[0]
fig, (ax_upload, ax_input) = plt.subplots(1, 2, figsize=(8, 4))
ax_upload.imshow(img, cmap='gray')
ax_upload.set_title('Uploaded image')
ax_input.imshow(model_input, cmap='gray', vmin=0, vmax=1)
ax_input.set_title(f'Model input: {predicted_class}')
for ax in (ax_upload, ax_input):
    ax.axis('off')
plt.show()
