# Inference and Prediction

This notebook is used to load the trained model and test it on new images, such as screenshots from Amazon.

In [5]:
import torch
from PIL import Image
import joblib
from torchvision import transforms
import os
import torch.nn as nn
from torchvision import models

class MultiOutputModel(nn.Module):
    def __init__(self, num_genders, num_colors, num_seasons, num_products):
        super().__init__()
        base = models.resnet18(pretrained=True)
        self.features = nn.Sequential(*list(base.children())[:-1])
        self.dropout = nn.Dropout(0.3)
        self.gender = nn.Linear(512, num_genders)
        self.color = nn.Linear(512, num_colors)
        self.season = nn.Linear(512, num_seasons)
        self.product = nn.Linear(512, num_products)

    def forward(self, x):
        x = self.features(x).squeeze()
        x = self.dropout(x)
        return {
            'gender': self.gender(x),
            'color': self.color(x),
            'season': self.season(x),
            'product': self.product(x)
        }


In [6]:
# 🔁 Load everything
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

gender_enc = joblib.load("gender_encoder.pkl")
color_enc = joblib.load("baseColour_encoder.pkl")
season_enc = joblib.load("usage_encoder.pkl")
product_enc = joblib.load("masterCategory_encoder.pkl")

model = MultiOutputModel(len(gender_enc.classes_), len(color_enc.classes_),
                         len(season_enc.classes_), len(product_enc.classes_))
model.load_state_dict(torch.load("fashion_model.pth", map_location=device))
model = model.to(device).eval()

transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor()
])




In [None]:
# 🔮 Prediction Function
def predict(img_path):
    image = Image.open(img_path).convert("RGB")
    image = transform(image).unsqueeze(0).to(device)
    with torch.no_grad():
        outputs = model(image)
        results = {
            'Gender': gender_enc.inverse_transform([outputs['gender'].argmax(0).item()])[0],
            'Color': color_enc.inverse_transform([outputs['color'].argmax(0).item()])[0],
            'Season': season_enc.inverse_transform([outputs['season'].argmax(0).item()])[0],
            'Product': product_enc.inverse_transform([outputs['product'].argmax(0).item()])[0]
        }
    return results


In [None]:
# 🖼️ Predict from Amazon Screenshot
predict("/Users/lakshiitakalyanasundaram/Desktop/projects/CodeMonk Assignment/fashion-product-classifier/amazon_screenshots/Screenshot 2025-07-20 at 1.50.43 AM.png")


IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1)