# Finetune and Deploy BLIP-2 with Amazon SageMaker for Visual Question Answering

In this notebook, we are going to fine-tune BLIP-2 for visual question answering. This will be used for a fashion product description generation use case.

In our example, we are going to leverage Hugging Face [Transformers](https://huggingface.co/docs/transformers/index) and [PEFT](https://github.com/huggingface/peft) for the finetuning.

## 1. Setup development environment

Select the `Data Science 3.0` Image with `ml.t3.medium` instance type

Please make sure the IAM Role being used has the following permissions:
  - S3 Bucket access
  - SageMaker access

You can use the following IAM policy:
 - `arn:aws:iam::aws:policy/AmazonS3FullAccess`
 - `arn:aws:iam::aws:policy/AmazonSageMakerFullAcces`


In [25]:
! pip install -q --upgrade "scikit-image" "sagemaker>=2.190.0"

Here we set up the default session and bucket to use. If you want to use a different bucket, you can replace the bucket with the preferred bucket name.

In [26]:
import sagemaker
import boto3

sagemaker_session = sagemaker.Session()
sagemaker_session_bucket = None

if sagemaker_session_bucket is None and sagemaker_session is not None:
    # set to default bucket if a bucket name is not given, sagemaker will automatically create this bucket if it not exists
    sagemaker_session_bucket = sagemaker_session.default_bucket()

try:
    role = sagemaker.get_execution_role()
except ValueError:
    iam = boto3.client('iam')
    role = iam.get_role(RoleName='sagemaker_execution_role')['Role']['Arn']

sagemaker_session = sagemaker.Session(default_bucket=sagemaker_session_bucket)

print(f"sagemaker role arn: {role}")
print(f"sagemaker bucket: {sagemaker_session.default_bucket()}")
print(f"sagemaker session region: {sagemaker_session.boto_region_name}")

## 2. Load and prepare data

For the following demo we will be using tha Kaggle [Fashion Product Images Dataset](https://www.kaggle.com/datasets/paramaggarwal/fashion-product-images-dataset). We will use the shirt category `Tshirts` and `Shirts` as an example to finetune the model. The main files in the data sets are
1. `styles.csv`: all products and some of their key categories. We use this file to filter on the products that we're interested in.
2. `images.csv`: the link to all the images.
3. `images/product_id.jpg`: image of the product of id `product_id`.
4. `styles/product_id.json`: complete metadata of the product of id `product_id`.


#### You have to follow the following steps to load the data:
1. Sign in to Kaggle and download the dataset
2. Unzip the dataset and move the dataset into the folder `data`

For the following preprocessing code to work, the following structure is expected:
- `data/styles.csv`
- `data/images.csv`
- `data/styles/`
- `data/images/`

In the following processing, we perform two steps to create the train and test dataset.
1. Extract all the attributes from the product JSON file to create one CSV file containing information of all products and their attributes of interest. This can be used for data exploration.
2. Format the train and test dataset by 

### Combine all JSON attribute files into one dataset and filter the ids based on image availability.

In [None]:
pip install pandas

Collecting pandas
  Using cached pandas-2.3.3-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl.metadata (91 kB)
Collecting pytz>=2020.1 (from pandas)
  Using cached pytz-2025.2-py2.py3-none-any.whl.metadata (22 kB)
Collecting pytz>=2020.1 (from pandas)
  Using cached pytz-2025.2-py2.py3-none-any.whl.metadata (22 kB)
Collecting tzdata>=2022.7 (from pandas)
  Using cached tzdata-2025.2-py2.py3-none-any.whl.metadata (1.4 kB)
Using cached pandas-2.3.3-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl (12.4 MB)
Using cached pytz-2025.2-py2.py3-none-any.whl (509 kB)
Using cached tzdata-2025.2-py2.py3-none-any.whl (347 kB)
Collecting tzdata>=2022.7 (from pandas)
  Using cached tzdata-2025.2-py2.py3-none-any.whl.metadata (1.4 kB)
Using cached pandas-2.3.3-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl (12.4 MB)
Using cached pytz-2025.2-py2.py3-none-any.whl (509 kB)
Using cached tzdata-2025.2-py2.py3-none-any.whl (347 kB)
Installing collected packages: pytz, tz

In [None]:
import pandas as pd
import numpy as np
from tqdm import tqdm
import json

In [None]:
image_summary = pd.read_csv('../data/images.csv')
attr_summary = pd.read_csv('../data/styles.csv', on_bad_lines='skip')

In [None]:
# We can see that some images don't have a link so we should remove them from the dataset
undefined_images = image_summary[image_summary.link == 'undefined']
undefined_images

Unnamed: 0,filename,link
6697,39403.jpg,undefined
16207,39410.jpg,undefined
32324,39401.jpg,undefined
36399,39425.jpg,undefined
40022,12347.jpg,undefined


In [None]:
# remove these records from the attribute summary
undefined_images_id = undefined_images.filename.apply(lambda x: x.replace('.jpg', '')).values
attr_summary['id'] = attr_summary['id'].astype(str)
attr_summary = attr_summary[~attr_summary.id.isin(undefined_images_id)]

In [None]:
# select only the target shirt for finetuning
tops = attr_summary[attr_summary['articleType'].isin(['Shirts', 'Tshirts'])]
tops.shape

(10281, 10)

In [None]:
dataset = pd.DataFrame()

for id in tqdm(tops['id']):
    with open(f"../data/styles/{id}.json", "r") as f:
        fl = f.read()
        jsn = json.loads(fl)

    attr = jsn['data']['articleAttributes']
    descr = jsn['data']['productDescriptors']
    
    attr.update(descr)
    attr['id'] = id
    attr['baseColour'] = tops
    row_df = pd.json_normalize(attr, sep='_')
    dataset = pd.concat([dataset, row_df], ignore_index=True)
    
    
# also add color to the dataset
dataset['id'] = dataset['id'].astype(str)
dataset = dataset.merge(tops[['id', 'baseColour']], on='id')

dataset.to_csv("../data/dataset.csv", index=False)

 16%|█▌        | 1606/10281 [00:04<00:22, 390.41it/s]



KeyboardInterrupt: 

In [None]:
dataset.columns

Index(['Fit', 'Pattern', 'Body or Garment Size', 'Sleeve Length', 'Fabric',
       'Collar', 'id', 'baseColour', 'description_descriptorType',
       'description_value', 'Occasion', 'Fabric 2', 'Fabric 3', 'Neck',
       'materials_care_desc_descriptorType', 'materials_care_desc_value',
       'size_fit_desc_descriptorType', 'size_fit_desc_value',
       'style_note_descriptorType', 'style_note_value', 'Sleeve Styling',
       'Business Unit', 'Multipack Set', 'Main Trend', 'Print or Pattern Type',
       'Number of Components', 'Sport Team', 'Segment', 'Sport', 'Technology',
       'Players', 'Character', 'Wash Care', 'Number of Pockets',
       'Surface Styling', 'Length', 'Brick', 'Family', 'Class',
       'Pattern Coverage', 'Processing Time', 'Brand Fit Name', 'Plating',
       'Hemline', 'Placket', 'Placket Length', 'Transparency', 'Pocket Type',
       'Weave Pattern', 'Cuff', 'Stitch', 'Brand'],
      dtype='object')

In [None]:
# formulate questions and answers for finetuning
cols = ["id", "Fabric", "Fit", "Collar", "Pattern", "Sleeve Length", "Sleeve Styling", "Neck", "baseColour"]
dataset = dataset[cols]

product_attributes = [
    {
        "Attribute": "Fabric",
        "Prompt": "What is the fabric of the shirt in this picture?",
    },
    {
        "Attribute": "Fit",
        "Prompt": "What is the Fit of the shirt in this picture?",
    },
    {
        "Attribute": "Collar",
        "Prompt": "What is the collar of the shirt in this picture?",
    },
    {
        "Attribute": "Pattern",
        "Prompt": "What is the pattern of the shirt in this picture?",
    },
    {
        "Attribute": "Neck",
        "Prompt": "What is the neck type of the shirt in this picture?",
    },
    {
        "Attribute": "Sleeve Length",
        "Prompt": "What is the sleeve length of the shirt in this picture?",
    },
    {
        "Attribute": "Sleeve Styling",
        "Prompt": "What is the sleeve styling of the shirt in this picture?",
    },
    {
        "Attribute": "baseColour",
        "Prompt": "What is the colour of the shirt in this picture?",
    },
]

In [None]:
vqa_data = []

for index, row in dataset.iterrows():
    for attribute in product_attributes:
        if row[attribute['Attribute']] is not np.nan:
            item_data = {}
            item_data['id'] = row['id']
            item_data['Question'] = attribute['Prompt'] 
            item_data['Answer'] = f"{attribute['Attribute']}: {row[attribute['Attribute']]}"
            vqa_data.append(item_data)

vqa = pd.DataFrame.from_records(vqa_data)
print(vqa.shape)
vqa.sample(1)

(8688, 3)


Unnamed: 0,id,Question,Answer
4379,7353,What is the Fit of the shirt in this picture?,Fit: Regular Fit


In [None]:
# Create a reduced vqa (20% of products), stratified by articleType proportions
# We use the earlier 'tops' DataFrame which includes articleType per id

# Ensure 'tops' has id as string for consistency with vqa
tops_reduced_base = tops.copy()
tops_reduced_base['id'] = tops_reduced_base['id'].astype(str)

# Target fraction of product IDs to keep
target_frac = 0.2

# Sample product IDs stratified by articleType to preserve ratios
reduced_ids = (
    tops_reduced_base
    .groupby('articleType', group_keys=False)
    .apply(lambda g: g.sample(frac=target_frac, random_state=200))['id']
    .unique()
)

# Filter the VQA dataframe to these IDs only
vqa_reduced = vqa[vqa['id'].astype(str).isin(reduced_ids)].reset_index(drop=True)

print("Original VQA rows:", len(vqa))
print("Reduced VQA rows:", len(vqa_reduced))
print("Original unique ids:", vqa['id'].nunique(), "| Reduced unique ids:", vqa_reduced['id'].nunique())

Original VQA rows: 8688
Reduced VQA rows: 1873
Original unique ids: 1606 | Reduced unique ids: 339


  .apply(lambda g: g.sample(frac=target_frac, random_state=200))['id']


### Reduce dataset size by ~80% while preserving ratios
We reduce the dataset to ~20% of product IDs, stratified by `articleType` to maintain the same proportions as the full dataset. The subsequent train/test split will run on this reduced set.

In [None]:
# Build train/test from the reduced VQA set
train = vqa_reduced.groupby("id").sample(frac=0.8,random_state=200)
test = vqa_reduced.drop(train.index)

In [None]:
# we dont want to finetune on colour, however we want to extract colour during testing/inference
print(train.shape[0])
train = train[train['Question'] != "What is the colour of the shirt in this picture?"]
print(train.shape[0])

1562
1268


In [None]:
train.sample()

Unnamed: 0,id,Question,Answer
1831,40030,What is the fabric of the shirt in this picture?,Fabric: Cotton


### Upload data to S3 for the finetuning job

In [None]:

# If you still want to use S3/SageMaker parts, keep the uploads below.
# upload train and test sets
train.to_csv(f"s3://{sagemaker_session_bucket}/data/vqa_train.csv", index=False)
test.to_csv(f"s3://{sagemaker_session_bucket}/data/vqa_test.csv", index=False)

# upload images to S3
s3_client = boto3.client('s3')

for id in tqdm(tops['id']):
    s3_client.upload_file(f'/data/images/{id}.jpg', sagemaker_session_bucket, f'/data/images/{id}.jpg')

In [27]:
# Save reduced train and test sets locally and (optionally) to S3 if desired
# Local saves for the "Run locally" section downstream
train.to_csv("../data/vqa_train.csv", index=False)
test.to_csv("../data/vqa_test.csv", index=False)


## 3. Fine-Tune BLIP-2 with LoRA on Amazon SageMaker

In [85]:
## Specify inputs to Training Jobs
inputs = f"s3://{sagemaker_session_bucket}/data/"
image_s3_uri = f"s3://{sagemaker_session_bucket}/data/images"
output_path = f"s3://{sagemaker_session_bucket}/training"

In [86]:
from sagemaker.inputs import TrainingInput

input_file = TrainingInput(s3_data=inputs, input_mode="File")
images_input = TrainingInput(s3_data=image_s3_uri, input_mode="FastFile", content_type="application/jpeg")

In [117]:
from sagemaker.huggingface import HuggingFace

hyperparameters = {
    'epochs': 10,
    'file-name': "vqa_train.csv",
}

estimator = HuggingFace(
    entry_point="entrypoint_vqa_finetuning.py",
    source_dir="../src",
    role=role,
    instance_count=1,
    instance_type="ml.g5.2xlarge", 
    transformers_version='4.26',
    pytorch_version='1.13',
    py_version='py39',
    hyperparameters = hyperparameters,
    base_job_name="VQA",
    sagemaker_session=sagemaker_session,
    output_path=f"{output_path}/models",
    code_location=f"{output_path}/code",
    volume_size=60,
    metric_definitions=[
        {'Name': 'batch_loss', 'Regex': 'Loss: ([0-9\\.]+)'},
        {'Name': 'epoch_loss', 'Regex': 'Epoch Loss: ([0-9\\.]+)'}
    ],
)

estimator.fit({"images": images_input, "input_file": input_file})

INFO:sagemaker.image_uris:image_uri is not presented, retrieving image_uri based on instance_type, framework etc.
INFO:sagemaker:Creating training-job with name: VQA-2024-02-07-13-03-54-497


2024-02-07 13:03:55 Starting - Starting the training job...
2024-02-07 13:03:58 Pending - Training job waiting for capacity......
2024-02-07 13:05:07 Pending - Preparing the instances for training......
2024-02-07 13:06:14 Downloading - Downloading input data...................................................
2024-02-07 13:14:52 Training - Training image download completed. Training in progress.bash: cannot set terminal process group (-1): Inappropriate ioctl for device
bash: no job control in this shell
2024-02-07 13:14:54,024 sagemaker-training-toolkit INFO     Imported framework sagemaker_pytorch_container.training
2024-02-07 13:14:54,044 sagemaker-training-toolkit INFO     No Neurons detected (normal if no neurons installed)
2024-02-07 13:14:54,053 sagemaker_pytorch_container.training INFO     Block until all host DNS lookups succeed.
2024-02-07 13:14:54,057 sagemaker_pytorch_container.training INFO     Invoking user training script.
2024-02-07 13:14:54,258 sagemaker-training-toolk

## 4. Deploy Fine-tuned BLIP-2 on Amazon SageMaker

In [95]:
from sagemaker.huggingface import HuggingFaceModel

# create Hugging Face Model Class
model = HuggingFaceModel(
   model_data=estimator.model_data,
   role=role, 
   transformers_version="4.28", 
   pytorch_version="2.0", 
   py_version="py310",
   model_server_workers=1,
   sagemaker_session=sagemaker_session
)

In [96]:
endpoint_name = "endpoint-finetuned-blip2"

model.deploy(
    initial_instance_count=1,
    instance_type="ml.g5.2xlarge",
    endpoint_name=endpoint_name
)

INFO:sagemaker:Creating model with name: huggingface-pytorch-inference-2024-02-07-12-45-48-758
INFO:sagemaker:Creating endpoint-config with name endpoint-finetuned-blip2-final
INFO:sagemaker:Creating endpoint with name endpoint-finetuned-blip2-final


----------!

<sagemaker.huggingface.model.HuggingFacePredictor at 0x168a36e50>

## 5. Run inference on the model

In [None]:
# sample one test image
sample_image_id = test.sample(1)['id'].values[0]
test_image = f"../data/{sample_image_id}.jpg"
test_image

In [101]:
import base64
import re 

smr_client = boto3.client("sagemaker-runtime")

def encode_image(img_file):
    with open(img_file, "rb") as image_file:
        img_str = base64.b64encode(image_file.read())
        base64_string = img_str.decode("latin1")
    return base64_string

def run_inference(endpoint_name, inputs):
    response = smr_client.invoke_endpoint(
        EndpointName=endpoint_name, Body=json.dumps(inputs),  ContentType="application/json"
    )
    return response["Body"].read().decode("utf-8")
    
base64_string = encode_image(test_image)


attributes = []

for product_attribute in product_attributes:
    inputs = {
        "prompt": f"Question: {product_attribute['Prompt']} Answer: ",
        "image": base64_string
    }
    response = run_inference(endpoint_name, inputs)
    attributes.append(re.sub("[\"']", "", response))
    print(response)


'Sleeve Length: Long Sleeves'

## 6. Generate product description with Amazon Bedrock

In [38]:
prompt = f"""
You are an expert in writing product descriptions for shirts. Use the data below to create product description for a website. 
The product description should contain all given attributes.
Provide some inspirational sentences, on e.g. how the fabric moves. Think about what a potential customer wants to know about the shirts. 
"""

print(prompt)


You are an expert in writing product descriptions for shirts. Use the data below to create product description for a website. 
The product description should contain all given attributes.
Provide some inspirational sentences, on e.g. how the fabric moves. Think about what a potential customer wants to know about the shirts. 



In [27]:
attributes_content =  {"role": "user", "content": f"Here are the facts you need to create the product descriptions: <product_attributes>{', '.join(attributes)}</product_attributes>"}

In [34]:
bedrock = boto3.client(    
    service_name='bedrock-runtime',
    region_name='us-west-2',
)

model_id = "anthropic.claude-3-sonnet-20240229-v1:0"

body = json.dumps({
    "system":prompt,
    "messages": [attributes_content],
    "max_tokens": 400,
    "temperature": 0.1,
    "anthropic_version": "bedrock-2023-05-31",
})

accept = 'application/json'
contentType = 'application/json'

response = bedrock.invoke_model(
     body=body,
     modelId=model_id,
     accept=accept,
     contentType=contentType
 )

response_body = json.loads(response.get('body').read())

print(response_body['content'][0]['text'])

Classic Striped Shirt Relax into comfortable casual style with this classic collared striped shirt. With a regular fit that is neither too slim nor too loose, this versatile top layers perfectly under sweaters or jackets.


## 7. Delete resources

Delete the SageMaker endpoint and the endpoint configuration

In [None]:
client = boto3.client('sagemaker')

response = client.delete_endpoint(
    EndpointName=endpoint_name
)

response = client.delete_endpoint_config(
    EndpointConfigName=endpoint_name
)

# Run locally (no SageMaker)

This section lets you fine-tune and run BLIP-2 directly in this notebook on your school GPU cluster (via SSH tunnel), without SageMaker. It:
- Installs minimal dependencies in the current kernel
- Detects GPU and configures 8-bit / half precision
- Prepares the VQA dataset from the CSV and images you created above
- Fine-tunes with LoRA and saves artifacts locally
- Loads the saved model for local inference

Tip: You can skip or collapse the SageMaker sections above and start here if running locally.

In [None]:
# Install minimal dependencies (one-time per kernel)
import sys, subprocess

def pip_install(pkgs):
    subprocess.check_call([sys.executable, "-m", "pip", "install", "-q"] + pkgs)

pkgs = [
    "transformers>=4.31.0",
    "peft==0.5.0",
    "accelerate==0.21.0",
    "bitsandbytes==0.40.2",
    "safetensors>=0.3.3",
    "Pillow",
]

try:
    import transformers, peft, accelerate, bitsandbytes, safetensors, PIL
except Exception:
    pip_install(pkgs)

print("Dependencies are ready.")

In [None]:
# GPU check and device setup
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", device)
if device == "cuda":
    print("CUDA device count:", torch.cuda.device_count())
    print("Current device:", torch.cuda.current_device())
    print("Device name:", torch.cuda.get_device_name(0))

In [None]:
# Local paths and data settings
from pathlib import Path

# Adjust these to your environment
DATA_DIR = Path("../data")  # where images.csv/styles.csv/images live (as above)
LOCAL_OUT = Path("../artifacts")  # where to save models
LOCAL_OUT.mkdir(parents=True, exist_ok=True)

CSV_TRAIN = DATA_DIR / "vqa_train.csv"  # produced in earlier cells
IMAGES_DIR = DATA_DIR / "images"       # images folder

print("Train CSV:", CSV_TRAIN)
print("Images dir:", IMAGES_DIR)
print("Output dir:", LOCAL_OUT)

In [None]:
# Dataset and processor/model setup (local)
import os
import pandas as pd
import numpy as np
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from transformers import AutoProcessor, Blip2ForConditionalGeneration
from peft import LoraConfig, get_peft_model

class VQADataset(Dataset):
    def __init__(self, csv_file, root_dir, img_size=(1800, 2400)):
        self.attributes = pd.read_csv(csv_file)[["id", "Question", "Answer"]]
        self.root_dir = root_dir
        self.img_size = img_size

    def __len__(self):
        return len(self.attributes)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        item = {}
        item["id"] = str(self.attributes.iloc[idx, 0])
        item["question"] = (
            "Question: " + str(self.attributes.iloc[idx, 1]) + " Answer: "
        )
        item["answer"] = self.attributes.iloc[idx, 2]
        img_path = os.path.join(self.root_dir, (item["id"] + ".jpg"))

        image = Image.open(img_path).convert("RGB")
        if image.size != self.img_size:
            image = image.resize(self.img_size)
        return (np.array(image), item["question"], item["answer"], item["id"])

In [None]:
# Hyperparameters for local training
model_name = "Salesforce/blip2-flan-t5-xl"  # consider 'blip2-flan-t5-large' if memory is tight
batch_size = 8
epochs = 2
learning_rate = 2.5e-3
lora_cfg = dict(r=8, lora_alpha=32, lora_dropout=0.05, bias="none", target_modules=["q", "v"])

In [None]:
# Load model, processor and prepare dataloader
from peft import LoraConfig

processor = AutoProcessor.from_pretrained(model_name)

# Try 8-bit if CUDA available; otherwise load normally
if device == "cuda":
    model = Blip2ForConditionalGeneration.from_pretrained(
        model_name,
        device_map="auto",
        load_in_8bit=True,
    )
else:
    model = Blip2ForConditionalGeneration.from_pretrained(model_name)

config = LoraConfig(**lora_cfg)
model = get_peft_model(model, config)
model.train()

optimizer = torch.optim.Adam(model.parameters(), learning_rate)

train_dataset = VQADataset(csv_file=str(CSV_TRAIN), root_dir=str(IMAGES_DIR))
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

print(f"Dataset size: {len(train_dataset)} | Batches/epoch: {len(train_loader)}")

In [None]:
# Local training loop with periodic logging and NaN guard
import math

nan_flag = False
for epoch in range(epochs):
    batch_losses = []
    for idx, (imgs, questions, answers, _) in enumerate(train_loader):
        if nan_flag:
            break
        inputs = processor(
            images=imgs,
            text=list(questions),
            text_target=list(answers),
            padding=True,
            return_tensors="pt",
        )
        # Move tensors to model device; keep dtype logic safe for CPU
        inputs = inputs.to(model.device)
        # If running on CUDA with 8-bit base weights, keep pixel_values to float16
        if device == "cuda" and inputs.get("pixel_values") is not None:
            inputs["pixel_values"] = inputs["pixel_values"].to(torch.float16)

        outputs = model(
            pixel_values=inputs["pixel_values"],
            input_ids=inputs["input_ids"],
            labels=inputs["labels"],
        )
        loss = outputs.loss

        if torch.isnan(loss):
            print(f"loss is NaN at epoch {epoch}, batch {idx}; saving and stopping")
            nan_flag = True
            break

        if idx % 50 == 0:
            with torch.no_grad():
                gen = model.generate(
                    pixel_values=inputs["pixel_values"],
                    input_ids=inputs["input_ids"],
                    max_length=20,
                )
            print("=" * 30)
            print("question:", processor.batch_decode(inputs["input_ids"], skip_special_tokens=True)[0])
            print("correct:", answers[0])
            print("pred:", processor.batch_decode(gen, skip_special_tokens=True)[0])

        batch_losses.append(loss.item())
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

    if batch_losses:
        print(f"Epoch {epoch} | Loss: {sum(batch_losses)/len(batch_losses):.4f}")

In [None]:
# Save locally in HF format + minimal metadata
from pathlib import Path

SAVE_DIR = LOCAL_OUT / "blip2-vqa-lora"
SAVE_DIR.mkdir(parents=True, exist_ok=True)

model.eval()
processor.save_pretrained(str(SAVE_DIR))
# Save LoRA-adapted model
model.save_pretrained(str(SAVE_DIR))
print("Saved to:", SAVE_DIR)

In [None]:
# Local inference utilities
import base64, json
from io import BytesIO

from transformers import AutoProcessor, Blip2ForConditionalGeneration
from PIL import Image

LOAD_DIR = SAVE_DIR  # reuse just-saved directory

model_inf = Blip2ForConditionalGeneration.from_pretrained(
    str(LOAD_DIR), device_map="auto" if device == "cuda" else None, load_in_8bit=(device=="cuda")
)
processor_inf = AutoProcessor.from_pretrained(str(LOAD_DIR))


def run_local_inference(image_path: str, prompt: str, **gen_kwargs):
    image = Image.open(image_path).convert("RGB")
    inputs = processor_inf(images=image, text=prompt, return_tensors="pt")
    inputs = inputs.to(model_inf.device)
    if device == "cuda" and inputs.get("pixel_values") is not None:
        inputs["pixel_values"] = inputs["pixel_values"].to(torch.float16)
    output = model_inf.generate(**inputs, **gen_kwargs)
    return processor_inf.decode(output[0], skip_special_tokens=True)

print("Inference model loaded from:", LOAD_DIR)

In [None]:
# Try local inference on a random test image and prompts
import re

# Reuse earlier variables if present, otherwise pick one image from the folder
try:
    sample_image_id = test.sample(1)['id'].values[0]
    test_image = f"../data/{sample_image_id}.jpg"
except Exception:
    # Fallback: list images dir
    from glob import glob
    imgs = sorted(glob(str(IMAGES_DIR / "*.jpg")))
    test_image = imgs[0] if imgs else None

if test_image:
    print("Test image:", test_image)
    product_attributes = [
        {"Attribute": "Fabric", "Prompt": "What is the fabric of the shirt in this picture?"},
        {"Attribute": "Fit", "Prompt": "What is the Fit of the shirt in this picture?"},
        {"Attribute": "Collar", "Prompt": "What is the collar of the shirt in this picture?"},
    ]
    attributes = []
    for pa in product_attributes:
        prompt = f"Question: {pa['Prompt']} Answer: "
        resp = run_local_inference(test_image, prompt, max_length=20)
        print(pa["Attribute"], "->", re.sub("[\"']", "", resp))
        attributes.append(resp)
else:
    print("No images found for inference in:", IMAGES_DIR)

## Download dataset from Kaggle (local)

This section downloads the Kaggle Fashion Product Images dataset directly into `../data` using the Kaggle API. You need Kaggle credentials (see https://www.kaggle.com/settings/account -> Create New API Token) which provide `kaggle.json`.

Two options to authenticate:
- Place your `kaggle.json` at `~/.kaggle/kaggle.json` with mode 600, or
- Set environment variables `KAGGLE_USERNAME` and `KAGGLE_KEY` in this notebook session.

In [None]:
# Install kaggle package if missing
import sys, subprocess

def ensure_pkg(pkg):
    try:
        __import__(pkg)
    except ImportError:
        subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", pkg])

ensure_pkg("kaggle")
print("kaggle package ready")

In [None]:
# Configure credentials: prefers ~/.kaggle/kaggle.json; falls back to env vars
import os, json
from pathlib import Path

home = Path.home()
kaggle_dir = home / ".kaggle"
kaggle_json = kaggle_dir / "kaggle.json"

if kaggle_json.exists():
    # Ensure correct permissions for Kaggle CLI
    try:
        kaggle_dir.mkdir(parents=True, exist_ok=True)
        kaggle_json.chmod(0o600)
    except Exception as e:
        print("Warning: could not set permissions on kaggle.json:", e)
    print("Using credentials from:", kaggle_json)
else:
    # Look for env vars
    if os.getenv("KAGGLE_USERNAME") and os.getenv("KAGGLE_KEY"):
        print("Using KAGGLE_USERNAME/KAGGLE_KEY from environment")
    else:
        print("Missing Kaggle credentials. Upload kaggle.json to ~/.kaggle or set KAGGLE_USERNAME and KAGGLE_KEY.")
        raise SystemExit(1)

In [None]:
# Download and unzip the dataset into ../data using Kaggle API
from kaggle.api.kaggle_api_extended import KaggleApi
from zipfile import ZipFile

DATASET = "paramaggarwal/fashion-product-images-dataset"
TARGET_DIR = Path("../data")
TARGET_DIR.mkdir(parents=True, exist_ok=True)

a = KaggleApi()
a.authenticate()

print("Starting download... This is ~1.2GB and can take a while.")
a.dataset_download_files(DATASET, path=str(TARGET_DIR), quiet=False)

# Find the zip file(s) just downloaded and extract
zips = sorted(TARGET_DIR.glob("*.zip"))
if not zips:
    # The API may save to a single zip with dataset name
    zips = sorted(Path.cwd().glob("*.zip"))

if not zips:
    raise FileNotFoundError("No zip files found after download.")

for z in zips:
    print("Extracting:", z)
    with ZipFile(z, 'r') as zip_ref:
        zip_ref.extractall(TARGET_DIR)
    # Optional: remove zip to save space
    # z.unlink(missing_ok=True)

print("Download and extraction complete.")

In [30]:
# Verify expected files and structure
import pandas as pd
from pathlib import Path
TARGET_DIR = Path("../data")
expected_files = [
    TARGET_DIR / "images.csv",
    TARGET_DIR / "styles.csv",
    TARGET_DIR / "images",
    TARGET_DIR / "styles",
]

for f in expected_files:
    print(f, "->", ("OK" if f.exists() else "MISSING"))

if (TARGET_DIR / "images.csv").exists():
    df_images = pd.read_csv(TARGET_DIR / "images.csv")
    print("images.csv rows:", len(df_images))

if (TARGET_DIR / "styles.csv").exists():
    df_styles = pd.read_csv(TARGET_DIR / "styles.csv", on_bad_lines='skip')
    print("styles.csv rows:", len(df_styles))

# Display a few sample images if available
from glob import glob
samples = glob(str(TARGET_DIR / "images" / "*.jpg"))[:3]
print("Sample images:", samples)

../data/images.csv -> OK
../data/styles.csv -> OK
../data/images -> OK
../data/styles -> OK
images.csv rows: 44446
styles.csv rows: 44424
Sample images: ['../data/images/17397.jpg', '../data/images/32378.jpg', '../data/images/2141.jpg']


### Copy reduced subset assets
Copy only the files referenced by `vqa_reduced` into compact folders for faster local training or sharing:
- Images to `../data/images_reduced/`
- Style JSONs to `../data/styles_reduced/`

In [31]:
# Copy reduced images and styles
from pathlib import Path
import shutil
from tqdm import tqdm

DATA_DIR = Path("../data")
IMAGES_SRC = DATA_DIR / "images"
STYLES_SRC = DATA_DIR / "styles"
IMAGES_DST = DATA_DIR / "images_reduced"
STYLES_DST = DATA_DIR / "styles_reduced"

IMAGES_DST.mkdir(parents=True, exist_ok=True)
STYLES_DST.mkdir(parents=True, exist_ok=True)

reduced_ids_list = sorted(vqa_reduced['id'].astype(str).unique())
img_copied = 0
json_copied = 0

for pid in tqdm(reduced_ids_list, desc="Copying files"):
    src_img = IMAGES_SRC / f"{pid}.jpg"
    dst_img = IMAGES_DST / f"{pid}.jpg"
    if src_img.exists():
        if not dst_img.exists():
            shutil.copy2(src_img, dst_img)
        img_copied += 1
    
    src_json = STYLES_SRC / f"{pid}.json"
    dst_json = STYLES_DST / f"{pid}.json"
    if src_json.exists():
        if not dst_json.exists():
            shutil.copy2(src_json, dst_json)
        json_copied += 1

print(f"Images copied: {img_copied} -> {IMAGES_DST}")
print(f"Style JSONs copied: {json_copied} -> {STYLES_DST}")

Copying files: 100%|██████████| 339/339 [00:00<00:00, 1023.85it/s]

Images copied: 339 -> ../data/images_reduced
Style JSONs copied: 339 -> ../data/styles_reduced



