In [1]:
import torch, json, gc, os, re
from pathlib import Path
from PIL import Image
from tqdm import tqdm
from transformers import AutoProcessor, AutoModelForVision2Seq
import pandas as pd
USE_FP16       = True
MAX_NEW_TOKENS = 150
BATCH_SIZE     = 3


2025-04-20 10:49:31.666204: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1745146171.899154      31 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1745146171.967323      31 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
BASE_DIR        = Path("/kaggle/input/indofashion-nlp")
IMG_DIR         = BASE_DIR/"train_sand"/"train_sand"
META_JSON_PATH  = BASE_DIR/"indofashion_metadata.json"
OUTPUT_CSV      = "LLaVADescriptionsIndofashion.csv"
CHECKPOINT_FILE = "llava_indofashion_checkpoint.json"


Using device: cuda


In [3]:
meta_df = pd.read_json(META_JSON_PATH, lines=True)
print("Available columns:", meta_df.columns.tolist())


Available columns: ['image_id', 'class_label', 'brand', 'product_title']


In [4]:
meta_df = pd.read_json(META_JSON_PATH, lines=True)
meta_df = meta_df[~meta_df["class_label"]
                  .isin(["petticoats","mojaris_men","mojaris_women","dupattas"])]
meta_df = meta_df.set_index("image_id", drop=False)
meta_df.rename(
    columns={
        "product_title": "title",
    },
    inplace=True
)

In [5]:
def build_prompt_llava_indofashion(row):
    item_cls = row["class_label"].strip().lower()
    brand    = row["brand"].strip()
    title    = row["title"].strip().rstrip(".")

    words = title.split()
    if len(words) > 12:
        title = " ".join(words[:12]) + "…"

    title = title.replace('"',"").replace("“","").replace("”","")
    parts = []
    if item_cls: parts.append(item_cls)
    if brand:    parts.append(f"by {brand}")
    if title:    parts.append(f'titled "{title}"')
    descriptor = ", ".join(parts)
    return (
        "USER: <image>\n"
        "You’re an e‑commerce copywriter. Describe this garment in vivid detail—"
        "focus on its color, fabric/material, pattern or embroidery, silhouette, "
        "neckline, sleeves or drape style, length, and any special accents or embellishments. "
        f"This is {descriptor}. "
        "Do NOT mention model, person, or background—only describe the garment itself.\n"
        "ASSISTANT:"
    )

meta_df["prompt"] = meta_df.apply(build_prompt_llava_indofashion, axis=1)

In [6]:
def load_checkpoint():
    if os.path.exists(CHECKPOINT_FILE):
        return json.load(open(CHECKPOINT_FILE))
    return {"processed_ids": [], "descriptions": []}

def save_checkpoint(cp):
    json.dump(cp, open(CHECKPOINT_FILE, "w"))

def clean_up():
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()



In [7]:
model_id  = "llava-hf/llava-1.5-13b-hf"
processor = AutoProcessor.from_pretrained(model_id)
model     = AutoModelForVision2Seq.from_pretrained(
    model_id,
    torch_dtype=torch.float16 if USE_FP16 else torch.float32,
    device_map="auto",
    low_cpu_mem_usage=True
)
model.gradient_checkpointing_enable()
model.eval()

processor_config.json:   0%|          | 0.00/173 [00:00<?, ?B/s]

chat_template.json:   0%|          | 0.00/701 [00:00<?, ?B/s]

preprocessor_config.json:   0%|          | 0.00/505 [00:00<?, ?B/s]

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


tokenizer_config.json:   0%|          | 0.00/1.45k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/500k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/3.62M [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/41.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/552 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/1.10k [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/77.2k [00:00<?, ?B/s]

Fetching 6 files:   0%|          | 0/6 [00:00<?, ?it/s]

model-00001-of-00006.safetensors:   0%|          | 0.00/4.96G [00:00<?, ?B/s]

model-00005-of-00006.safetensors:   0%|          | 0.00/4.93G [00:00<?, ?B/s]

model-00006-of-00006.safetensors:   0%|          | 0.00/2.02G [00:00<?, ?B/s]

model-00002-of-00006.safetensors:   0%|          | 0.00/4.97G [00:00<?, ?B/s]

model-00003-of-00006.safetensors:   0%|          | 0.00/4.88G [00:00<?, ?B/s]

model-00004-of-00006.safetensors:   0%|          | 0.00/4.93G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/6 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/141 [00:00<?, ?B/s]

LlavaForConditionalGeneration(
  (vision_tower): CLIPVisionModel(
    (vision_model): CLIPVisionTransformer(
      (embeddings): CLIPVisionEmbeddings(
        (patch_embedding): Conv2d(3, 1024, kernel_size=(14, 14), stride=(14, 14), bias=False)
        (position_embedding): Embedding(577, 1024)
      )
      (pre_layrnorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
      (encoder): CLIPEncoder(
        (layers): ModuleList(
          (0-23): 24 x CLIPEncoderLayer(
            (self_attn): CLIPSdpaAttention(
              (k_proj): Linear(in_features=1024, out_features=1024, bias=True)
              (v_proj): Linear(in_features=1024, out_features=1024, bias=True)
              (q_proj): Linear(in_features=1024, out_features=1024, bias=True)
              (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
            )
            (layer_norm1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
            (mlp): CLIPMLP(
              (activation_fn): Q

In [8]:
def clean_up():
    torch.cuda.empty_cache()
    gc.collect()

ckpt = load_checkpoint()
processed = set(ckpt["processed_ids"])
results = ckpt["descriptions"]
to_do = [obj_id for obj_id in meta_df.index if obj_id not in processed]
for index in tqdm(range(0, len(to_do), BATCH_SIZE)):
    batch_ids = to_do[index : index + BATCH_SIZE]
    imgs, prompts = [], []
    valid_batch_ids = []  # Keep track of IDs with valid images
    
    for obj_id in batch_ids:
        # Look for any of the three file formats
        img_path = None
        for ext in [".jpeg", ".jpg", ".png"]:
            potential_path = IMG_DIR/f"{obj_id}{ext}"
            if potential_path.exists():
                img_path = potential_path
                break
        
        if img_path is None:
            print(f"Warning: No image found for {obj_id}")
            continue
            
        img = Image.open(img_path).convert("RGB")
        img = img.resize((336, 336), Image.LANCZOS)
        imgs.append(img)
        prompts.append(meta_df.at[obj_id, "prompt"])
        valid_batch_ids.append(obj_id)  # Only add valid IDs
        
    if not imgs:
        continue  
        
    clean_up()
    inputs = processor(
        images=imgs,
        text=prompts,
        return_tensors="pt",
        padding=True,
        truncation=True
    ).to(device)
    
    with torch.amp.autocast('cuda', enabled=USE_FP16), torch.no_grad():
        out_ids = model.generate(
            **inputs,
            max_new_tokens=MAX_NEW_TOKENS,
            do_sample=False,
            use_cache=True,

            min_length=10,  
            num_return_sequences=1,
            no_repeat_ngram_size=3,
            early_stopping=True 
        )
    
    outputs = processor.batch_decode(out_ids, skip_special_tokens=True)
    
    for obj_id, raw in zip(valid_batch_ids, outputs):
        desc = raw.split("ASSISTANT:", 1)[-1].strip()
        if not desc.endswith(('.', '!', '?')):
            # Try to find the last complete sentence
            last_sentence_end = max(desc.rfind('.'), desc.rfind('!'), desc.rfind('?'))
            if last_sentence_end > len(desc) * 0.5:  # Only truncate if we have at least half the text
                desc = desc[:last_sentence_end + 1]
            else:
                desc = desc + "."
        
        results.append({"image_id": obj_id, "description": desc})
        processed.add(obj_id)
    
    save_checkpoint({"processed_ids": list(processed), "descriptions": results})
    clean_up()

pd.DataFrame(results).to_csv(OUTPUT_CSV, index=False)
print(f"Descriptions saved - {OUTPUT_CSV}")

  0%|          | 0/667 [00:00<?, ?it/s]Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.
 27%|██▋       | 178/667 [1:18:41<3:37:04, 26.63s/it]



 70%|██████▉   | 465/667 [3:25:23<1:30:20, 26.83s/it]



 79%|███████▉  | 530/667 [3:54:10<1:00:58, 26.70s/it]



100%|██████████| 667/667 [4:54:22<00:00, 26.48s/it]  

Descriptions saved - LLaVADescriptionsIndofashion.csv



