<a href="https://colab.research.google.com/github/katybohanan/5588-hands-on-26/blob/main/multimodal_model_ecommerce_data.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torchvision.transforms as transforms
from transformers import AutoTokenizer, ViTFeatureExtractor, ViTModel, AutoModel
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import pandas as pd

In [2]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [3]:
from google.colab import files
uploaded = files.upload()

Saving kaggle.json to kaggle (1).json


In [4]:
!pip install -q kaggle

In [5]:
!mkdir -p ~/.kaggle
!cp kaggle.json ~/.kaggle/
!chmod 600 ~/.kaggle/kaggle.json

In [6]:
!kaggle datasets download -d vikashrajluhaniwal/fashion-images

Dataset URL: https://www.kaggle.com/datasets/vikashrajluhaniwal/fashion-images
License(s): CC0-1.0
fashion-images.zip: Skipping, found more recently modified local copy (use --force to force download)


In [7]:
!unzip fashion-images.zip

Archive:  fashion-images.zip
replace data/Apparel/Boys/Images/images_with_product_ids/10054.jpg? [y]es, [n]o, [A]ll, [N]one, [r]ename: a
error:  invalid response [a]
replace data/Apparel/Boys/Images/images_with_product_ids/10054.jpg? [y]es, [n]o, [A]ll, [N]one, [r]ename: A
  inflating: data/Apparel/Boys/Images/images_with_product_ids/10054.jpg  
  inflating: data/Apparel/Boys/Images/images_with_product_ids/10649.jpg  
  inflating: data/Apparel/Boys/Images/images_with_product_ids/10671.jpg  
  inflating: data/Apparel/Boys/Images/images_with_product_ids/12840.jpg  
  inflating: data/Apparel/Boys/Images/images_with_product_ids/12844.jpg  
  inflating: data/Apparel/Boys/Images/images_with_product_ids/12845.jpg  
  inflating: data/Apparel/Boys/Images/images_with_product_ids/12846.jpg  
  inflating: data/Apparel/Boys/Images/images_with_product_ids/12847.jpg  
  inflating: data/Apparel/Boys/Images/images_with_product_ids/13306.jpg  
  inflating: data/Apparel/Boys/Images/images_with_product_id

In [8]:
df = pd.read_csv('/content/data/fashion.csv')

In [9]:
df.head()

Unnamed: 0,ProductId,Gender,Category,SubCategory,ProductType,Colour,Usage,ProductTitle,Image,ImageURL
0,42419,Girls,Apparel,Topwear,Tops,White,Casual,Gini and Jony Girls Knit White Top,42419.jpg,http://assets.myntassets.com/v1/images/style/p...
1,34009,Girls,Apparel,Topwear,Tops,Black,Casual,Gini and Jony Girls Black Top,34009.jpg,http://assets.myntassets.com/v1/images/style/p...
2,40143,Girls,Apparel,Topwear,Tops,Blue,Casual,Gini and Jony Girls Pretty Blossom Blue Top,40143.jpg,http://assets.myntassets.com/v1/images/style/p...
3,23623,Girls,Apparel,Topwear,Tops,Pink,Casual,Doodle Kids Girls Pink I love Shopping Top,23623.jpg,http://assets.myntassets.com/v1/images/style/p...
4,47154,Girls,Apparel,Bottomwear,Capris,Black,Casual,Gini and Jony Girls Black Capris,47154.jpg,http://assets.myntassets.com/v1/images/style/p...


In [10]:
# Load the Vision Transformer model and feature extractor for images
vit_model = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")
feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")

# Load the BERT model and tokenizer for text
text_model_name = 'bert-base-uncased'  # You can use 'distilbert-base-uncased' for DistilBERT
text_model = AutoModel.from_pretrained(text_model_name)
tokenizer = AutoTokenizer.from_pretrained(text_model_name)


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


In [11]:
import os
import glob

# Define the root directory containing all subfolders with images
image_root_dir = "/content/data"  # Update this path

# Scan all subdirectories and create a mapping of image names to their full paths
image_paths = {os.path.basename(path): path for path in glob.glob(f"{image_root_dir}/**/*.jpg", recursive=True)}

# Function to get the correct image path from the mapping
def get_image_path(image_name):
    return image_paths.get(image_name, None)  # Return None if the image is missing

In [12]:
image_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

In [13]:
def preprocess(row):
    # Process text
    text_inputs = tokenizer(row["ProductTitle"], padding="max_length", truncation=True, return_tensors="pt")

    # Get the correct image path
    image_path = get_image_path(row["Image"])
    if image_path is None:
        raise FileNotFoundError(f"Image {row['Image']} not found in directories!")

    # Load and transform the image
    image = Image.open(image_path).convert("RGB")
    image_tensor = image_transform(image).unsqueeze(0)  # Add batch dimension

    return text_inputs, image_tensor


In [14]:
# Test with a random row from your dataset
sample_text, sample_image = preprocess(df.iloc[0])
print(sample_text, sample_image.shape)  # Ensure it prints expected values


{'input_ids': tensor([[  101, 18353,  2072,  1998,  6285,  2100,  3057, 22404,  2317,  2327,
           102,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,  

In [15]:
class MultiModalModel(nn.Module):
    def __init__(self, vit_model, text_model, hidden_size=768, num_classes=10):
        super(MultiModalModel, self).__init__()

        # Vision transformer (ViT) for image features
        self.vit_model = vit_model

        # BERT-based transformer for text features
        self.text_model = text_model

        # Fully connected layer for classification
        self.fc = nn.Linear(hidden_size * 2, num_classes)  # Concatenating image and text features

    def forward(self, image, text_input_ids, text_attention_mask):
        # Process the image through the ViT model
        vit_outputs = self.vit_model(pixel_values=image)
        vit_cls_token = vit_outputs.last_hidden_state[:, 0]  # CLS token

        # Process the text through the transformer model (BERT/DistilBERT)
        text_outputs = self.text_model(input_ids=text_input_ids, attention_mask=text_attention_mask)
        text_cls_token = text_outputs.last_hidden_state[:, 0]  # CLS token

        # Concatenate the image and text features
        combined_features = torch.cat((vit_cls_token, text_cls_token), dim=1)

        # Pass through the classifier
        logits = self.fc(combined_features)
        return logits


In [16]:
# Example image preprocessing (use PIL or any image format you have)
from PIL import Image
image = Image.open("/content/data/Apparel/Boys/Images/images_with_product_ids/10054.jpg")
image = feature_extractor(images=image, return_tensors="pt").pixel_values

# Example text preprocessing
text = "Sample text input."
text_inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)

# Create the model
model = MultiModalModel(vit_model, text_model)

# Forward pass through the model
logits = model(image, text_inputs['input_ids'], text_inputs['attention_mask'])


In [17]:
model = MultiModalModel(
    vit_model=vit_model,
    text_model=text_model,
    hidden_size=768,
    num_classes=len(df['Category'].unique())
)

print("Model Ready: ", model)


Model Ready:  MultiModalModel(
  (vit_model): ViTModel(
    (embeddings): ViTEmbeddings(
      (patch_embeddings): ViTPatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViTEncoder(
      (layer): ModuleList(
        (0-11): 12 x ViTLayer(
          (attention): ViTSdpaAttention(
            (attention): ViTSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): ViTSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ViTIntermediate(
            (dense): Li

In [18]:
import os

#Save the trained model
torch.save(model.state_dict(), 'multi_modal_model.pth')
print("Model saved successfully")

#Verify model file
if os.path.exists('multi_modal_model.pth'):
    print("Model file 'multi_modal_model.pth' exists.")
else:
    print("Model file 'multi_modal_model.pth' does not exist.")

Model saved successfully
Model file 'multi_modal_model.pth' exists.


In [19]:
!pip install flask flask-ngrok



In [20]:
from flask import Flask, request, jsonify
from flask_ngrok import run_with_ngrok
import io

In [21]:
app = Flask(__name__)
run_with_ngrok(app)

In [22]:
# Load pretrained models
vit_model = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")
text_model = AutoModel.from_pretrained("bert-base-uncased")
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

In [23]:
# Define the MultiModalModel
class MultiModalModel(nn.Module):
    def __init__(self, vit_model, text_model, hidden_size=768, num_classes=10):
        super(MultiModalModel, self).__init__()
        self.vit_model = vit_model
        self.text_model = text_model
        self.fc = nn.Linear(hidden_size * 2, num_classes)

    def forward(self, image, text_input_ids, text_attention_mask):
        vit_outputs = self.vit_model(pixel_values=image)
        vit_cls_token = vit_outputs.last_hidden_state[:, 0]

        text_outputs = self.text_model(input_ids=text_input_ids, attention_mask=text_attention_mask)
        text_cls_token = text_outputs.last_hidden_state[:, 0]

        combined_features = torch.cat((vit_cls_token, text_cls_token), dim=1)
        logits = self.fc(combined_features)
        return logits

# Load trained model
model = MultiModalModel(vit_model, text_model, hidden_size=768, num_classes=len(df['Category'].unique()))
model.load_state_dict(torch.load("multi_modal_model.pth", map_location=torch.device("cpu")))
model.eval()

  model.load_state_dict(torch.load("multi_modal_model.pth", map_location=torch.device("cpu")))


MultiModalModel(
  (vit_model): ViTModel(
    (embeddings): ViTEmbeddings(
      (patch_embeddings): ViTPatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViTEncoder(
      (layer): ModuleList(
        (0-11): 12 x ViTLayer(
          (attention): ViTSdpaAttention(
            (attention): ViTSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): ViTSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ViTIntermediate(
            (dense): Linear(in_featur

In [24]:
# Define image transformations
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

In [25]:
@app.route('/predict', methods=['POST'])
def predict():
    try:
        if 'image' not in request.files or 'text' not in request.form:
            return jsonify({'error': 'Missing image or text'}), 400

        image_file = request.files['image']
        text_input = request.form['text']

        # Process image
        image = Image.open(io.BytesIO(image_file.read())).convert('RGB')
        image = transform(image).unsqueeze(0)

        # Process text
        text_tokens = tokenizer(text_input, padding='max_length', truncation=True, max_length=128, return_tensors='pt')
        text_input_ids = text_tokens['input_ids']
        text_attention_mask = text_tokens['attention_mask']

        with torch.no_grad():
            output = model(image, text_input_ids, text_attention_mask)
            prediction = torch.argmax(output, dim=1).item()

        return jsonify({'prediction': prediction})

    except Exception as e:
        return jsonify({'error': str(e)}), 500

In [26]:
# Run Flask app
app.run()

 * Serving Flask app '__main__'
 * Debug mode: off


 * Running on http://127.0.0.1:5000
INFO:werkzeug:[33mPress CTRL+C to quit[0m
Exception in thread Thread-10:
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/urllib3/connection.py", line 198, in _new_conn
    sock = connection.create_connection(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/urllib3/util/connection.py", line 85, in create_connection
    raise err
  File "/usr/local/lib/python3.11/dist-packages/urllib3/util/connection.py", line 73, in create_connection
    sock.connect(sa)
ConnectionRefusedError: [Errno 111] Connection refused

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/urllib3/connectionpool.py", line 787, in urlopen
    response = self._make_request(
               ^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/urllib3/connectionpool.py", line 493, in _make_reque