<a href="https://colab.research.google.com/github/juhiikataria/dress_prediction_model/blob/main/dress_prediction_model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import csv

# Hand-labelled samples: (path of the image file, dress-style label).
image_data = [
    ("/content/image1.jpeg", "formal"),
    ("/content/image2.jpeg", "formal"),
    ("/content/images3.jpeg", "formal"),
    ("/content/image4.jpeg", "formal"),
    ("/content/images5.jpeg", "not formal"),
    ("/content/images =6.jpeg", "not formal"),
    ("/content/images7.jpeg", "not formal"),
    ("/content/images8.jpeg", "not formal"),
    ("/content/images9.jpeg", "formal"),
    ("/content/images10.jpeg", "formal"),
]

# Destination of the label table; later cells read the same file name back.
csv_file_path = "image_data.csv"

# Header row first, then one row per sample, written in a single call.
# newline='' leaves line-ending handling to the csv module.
table_rows = [["Image", "Label"], *image_data]
with open(csv_file_path, mode='w', newline='', encoding='utf-8') as out_handle:
    csv.writer(out_handle).writerows(table_rows)

print("CSV file created:", csv_file_path)

CSV file created: image_data.csv


In [2]:
import pandas as pd

# Read the freshly written label table back and print it in full as a sanity
# check (it only has ten rows).
data = pd.read_csv(filepath_or_buffer='image_data.csv')
print(data)

                     Image       Label
0     /content/image1.jpeg      formal
1     /content/image2.jpeg      formal
2    /content/images3.jpeg      formal
3     /content/image4.jpeg      formal
4    /content/images5.jpeg  not formal
5  /content/images =6.jpeg  not formal
6    /content/images7.jpeg  not formal
7    /content/images8.jpeg  not formal
8    /content/images9.jpeg      formal
9   /content/images10.jpeg      formal


In [3]:

# Environment setup. Upgrade pip itself first, then install exact, pinned
# versions of the libraries so the notebook resolves the same packages on
# every run.
%pip install --upgrade pip
# PyTorch and torchdata (pinned); --quiet trims pip's install log.
%pip install --disable-pip-version-check \
    torch==1.13.1 \
    torchdata==0.5.1 --quiet

# Model / evaluation libraries (pinned).
# NOTE(review): of the packages below, only `transformers` is imported by the
# cells shown in this notebook; the others look unused here — confirm before
# trimming the list.
%pip install \
    transformers==4.27.2 \
    datasets==2.11.0 \
    evaluate==0.4.0 \
    rouge_score==0.1.2 \
    loralib==0.1.1 \
    peft==0.3.0 --quiet

Collecting pip
  Downloading pip-23.2.1-py3-none-any.whl (2.1 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.1/2.1 MB[0m [31m19.8 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: pip
  Attempting uninstall: pip
    Found existing installation: pip 23.1.2
    Uninstalling pip-23.1.2:
      Successfully uninstalled pip-23.1.2
Successfully installed pip-23.2.1
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m887.5/887.5 MB[0m [31m1.9 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m4.6/4.6 MB[0m [31m65.9 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m849.3/849.3 kB[0m [31m44.7 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m557.1/557.1 MB[0m [31m1.1 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m317.1/317.1 MB[0m [31m4.2 MB/s[0m eta [36m0:00:0

In [4]:
import pandas as pd
import torch
import torch.nn as nn
from PIL import Image
from transformers import AutoImageProcessor, ViTModel
from sklearn.model_selection import train_test_split

In [5]:
# Read the image/label table from the CSV file into a DataFrame.
dataset_path = "image_data.csv"
data = pd.read_csv(filepath_or_buffer=dataset_path)

In [6]:

# Pull the two columns out of the table as plain Python lists.
# The CSV is expected to provide "Image" and "Label" columns.
images, labels = (data[column].tolist() for column in ("Image", "Label"))


In [7]:
# Encode the text labels as class ids: "not formal" -> 0, "formal" -> 1.
label_to_index = {
    "not formal": 0,
    "formal": 1,
}
# Surrounding whitespace is ignored; a label missing from the mapping raises KeyError.
numerical_labels = list(map(lambda text: label_to_index[text.strip()], labels))


In [8]:
# Split dataset into train and test (20% held out).
# random_state pins the shuffle so Restart & Run All yields the same split,
# and therefore the same training set and loss curve, every time.
train_images, test_images, train_labels, test_labels = train_test_split(
    images,
    numerical_labels,
    test_size=0.2,
    random_state=42,
)


In [9]:
# The preprocessing config and the ViT backbone are fetched from the same
# pre-trained checkpoint, so the identifier is written once.
vit_checkpoint = "google/vit-base-patch16-224-in21k"
image_processor = AutoImageProcessor.from_pretrained(vit_checkpoint)
model = ViTModel.from_pretrained(vit_checkpoint)


Downloading (…)rocessor_config.json:   0%|          | 0.00/160 [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/502 [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/346M [00:00<?, ?B/s]

In [10]:
class DressStyleClassifier(nn.Module):
    """Single linear layer that maps a feature vector to per-class scores."""

    def __init__(self, input_dim, num_classes):
        """Create the layer.

        Args:
            input_dim: size of each incoming feature vector.
            num_classes: number of scores produced per input.
        """
        super().__init__()
        # The attribute name `fc` determines the state_dict keys
        # ("fc.weight", "fc.bias") used when saving and loading weights.
        self.fc = nn.Linear(in_features=input_dim, out_features=num_classes)

    def forward(self, x):
        """Return the raw (unnormalised) class scores for `x`."""
        logits = self.fc(x)
        return logits



In [11]:
# Preprocess images and extract features for training
def extract_cls_features(image_paths):
    """Run each image through the ViT backbone and collect one vector per image.

    For every path the image is opened, converted to 3-channel RGB (so
    grayscale, RGBA or CMYK files do not break the preprocessing, which
    normalises three channels), preprocessed with `image_processor`, and fed to
    `model`. The hidden state at sequence position 0 (the [CLS] token in ViT)
    is kept as the image's feature vector.

    Args:
        image_paths: iterable of paths to image files.

    Returns:
        A tensor with one row per image, built by concatenating the per-image
        vectors along dim 0.

    Raises:
        Whatever `Image.open` raises for a missing/unreadable file, and the
        error `torch.cat` raises when `image_paths` is empty.
    """
    cls_vectors = []
    # Feature extraction only: no gradients are needed for the frozen backbone.
    with torch.no_grad():
        for image_path in image_paths:
            # The context manager closes the file handle once the pixels have
            # been copied by convert().
            with Image.open(image_path) as raw_image:
                rgb_image = raw_image.convert("RGB")
            inputs = image_processor(images=rgb_image, return_tensors="pt")
            hidden_states = model(**inputs).last_hidden_state
            cls_vectors.append(hidden_states[:, 0, :])
    return torch.cat(cls_vectors, dim=0)


train_features = extract_cls_features(train_images)

In [12]:
# Initialize the classifier, loss and optimizer.
# Seed PyTorch's RNG first so the linear layer's random initial weights, and
# with them the training losses, are the same on every run.
torch.manual_seed(42)
classifier = DressStyleClassifier(train_features.shape[-1], num_classes=2)  # Two classes: formal and not formal
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(classifier.parameters(), lr=0.001)

In [13]:
# Full-batch training of the linear head on the pre-computed features.
num_epochs = 10
# The targets never change between epochs, so build the tensor once instead of
# on every iteration. CrossEntropyLoss expects integer (long) class indices.
train_targets = torch.tensor(train_labels, dtype=torch.long)
for epoch in range(num_epochs):
    optimizer.zero_grad()                       # clear gradients from the previous step
    outputs = classifier(train_features)        # class scores for every training sample
    loss = criterion(outputs, train_targets)
    loss.backward()
    optimizer.step()
    print(f"Epoch {epoch+1}/{num_epochs} - Loss: {loss.item():.4f}")

Epoch 1/10 - Loss: 0.7420
Epoch 2/10 - Loss: 0.7053
Epoch 3/10 - Loss: 0.6701
Epoch 4/10 - Loss: 0.6362
Epoch 5/10 - Loss: 0.6038
Epoch 6/10 - Loss: 0.5728
Epoch 7/10 - Loss: 0.5432
Epoch 8/10 - Loss: 0.5150
Epoch 9/10 - Loss: 0.4882
Epoch 10/10 - Loss: 0.4628


In [14]:
# Persist only the learned parameters (the state dict), not the whole module.
classifier_weights = classifier.state_dict()
torch.save(classifier_weights, "dress_style_classifier.pth")


In [15]:
# Rebuild a classifier with the same feature width, restore the saved
# parameters into it, and switch it to evaluation mode.
feature_dim = train_features.shape[-1]
trained_classifier = DressStyleClassifier(feature_dim, num_classes=2)
saved_weights = torch.load("dress_style_classifier.pth")
trained_classifier.load_state_dict(saved_weights)
# eval() returns the module, so as the last expression it also displays the
# model summary under the cell.
trained_classifier.eval()


DressStyleClassifier(
  (fc): Linear(in_features=768, out_features=2, bias=True)
)

In [52]:
# Get an image path from the user.
# Leading/trailing whitespace (easy to pick up when pasting a path) is removed
# so it does not cause a spurious "file not found"; spaces inside the path are
# left untouched.
new_image_path = input("Enter the path to the image: ").strip()

Enter the path to the image: /content/WhatsApp Image 2023-09-01 at 12.03.03 AM.jpeg


In [53]:
# Preprocess and classify the new image.
# The image is converted to 3-channel RGB so grayscale/RGBA/CMYK uploads do not
# break the preprocessing; the context manager closes the file once the pixels
# have been copied by convert().
with Image.open(new_image_path) as raw_image:
    image = raw_image.convert("RGB")

# Inference only: skip autograd bookkeeping, matching how the training
# features were extracted.
with torch.no_grad():
    inputs = image_processor(images=image, return_tensors="pt")
    new_features = model(**inputs).last_hidden_state
    # Position 0 of the sequence is the vector the classifier was trained on.
    predicted_style = trained_classifier(new_features[:, 0, :])


In [54]:
# Pick the index of the highest class score. No softmax is applied here: the
# argmax is taken directly over the classifier's raw outputs.
predicted_class_idx = predicted_style.argmax(dim=1).item()

In [55]:
# Translate the class id back to its text label (0 -> "not formal", 1 -> "formal").
index_to_label = dict(enumerate(("not formal", "formal")))
predicted_label = index_to_label[predicted_class_idx]

print("Predicted Dress Style:", predicted_label)

Predicted Dress Style: not formal
