In [1]:
import pandas as pd

In [2]:
from datasets import load_from_disk
cloth_dataset = load_from_disk("cloth_ds")

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
cloth_dataset["train"]

Dataset({
    features: ['image', 'text'],
    num_rows: 62928
})

In [4]:
import re

NORMALIZATION_MAP = {
    # garments
    "jumper": "sweater",
    "tank top": "top",
    "vest top": "top",
    "vest": "top",
    "kaftan": "dress",
    "caftan": "dress",
    "hooded sweatshirt": "hoodie",

    # necklines
    "turtle neck": "high neck",
    "turtleneck": "high neck",
    "polo neck": "high neck",
    "mock neck": "high neck",

    # sleeves
    "bell sleeves": "long sleeve",
    "long sleeved": "long sleeve",
    "short sleeved": "short sleeve",
    "long sleeves": "long sleeve",
    "short sleeves": "short sleeve",

    # patterns
    "print motif": "pattern",
    "printed": "pattern",
    "all over pattern": "pattern",
    "floral": "pattern",

    # materials
    "rib knit": "knit",
    "stretch jersey": "jersey",
}

def normalize_text(text: str) -> str:
    text = text.lower()
    text = text.replace("-", " ")
    text = re.sub(r"\s+", " ", text)

    for k, v in NORMALIZATION_MAP.items():
        text = text.replace(k, v)

    return text.strip()


In [5]:
COLOR_MAP = {
    "black": ["black", "jet black", "charcoal"],
    "white": ["white", "ivory", "off-white", "cream"],
    "blue": ["blue", "navy", "dark blue", "light blue", "sky blue", "royal blue"],
    "red": ["red", "burgundy", "wine", "maroon"],
    "green": ["green", "olive", "khaki", "mint"],
    "pink": ["pink", "rose", "blush", "fuchsia"],
    "purple": ["purple", "lilac", "lavender"],
    "brown": ["brown", "tan", "camel", "beige"],
    "grey": ["grey", "gray", "charcoal"],
    "orange": ["orange", "rust", "coral"],
    "yellow": ["yellow", "mustard"],
}


In [6]:
MATERIAL_MAP = {
    "cotton": ["cotton", "organic cotton"],
    "jersey": ["jersey", "stretch jersey"],
    "silk": ["silk", "satin", "charmeuse"],
    "wool": ["wool", "merino", "cashmere"],
    "linen": ["linen"],
    "leather": ["leather", "faux leather", "vegan leather"],
    "denim": ["denim", "jean"],
    "polyester": ["polyester", "poly"],
    "viscose": ["viscose", "rayon"],
    "knit": ["knit", "ribbed knit"],
}


In [7]:
FIT = {
    "fitted": ["fitted", "body-hugging", "figure-hugging"],
    "oversized": ["oversized", "relaxed", "boxy"],
    "slim": ["slim", "tailored"],
    "loose": ["loose", "flowy", "relaxed fit"],
    "cropped": ["cropped"],
}


In [8]:
SLEEVE = {
    "long": ["long sleeve", "long-sleeved"],
    "short": ["short sleeve", "short-sleeved"],
    "sleeveless": ["sleeveless"],
    "three-quarter": ["3/4 sleeve", "three-quarter sleeve"],
    "cap": ["cap sleeve"],
    "batwing": ["batwing sleeve"],
}


In [9]:
NECKLINE = {
    "v-neck": ["v-neck", "v neck"],
    "scoop": ["scoop neck", "scooped neckline"],
    "round": ["round neck", "crew neck"],
    "boat": ["boat neck", "bateau"],
    "square": ["square neckline"],
    "high": ["high neck", "turtleneck", "mock neck"],
    "off-shoulder": ["off the shoulder", "off-shoulder"],
}


In [10]:
GARMENT = {
    "top": ["top", "blouse", "shirt", "tee", "t-shirt"],
    "dress": ["dress", "gown"],
    "jacket": ["jacket", "coat", "blazer"],
    "pants": ["pants", "trousers", "jeans"],
    "skirt": ["skirt"],
    "bra": [ "bra", "bralette", "sports bra"],
    "underwear": [ "underwear", "panties", "briefs", "thong"],
    "lingerie": [ "lingerie", "bodysuit", "corset"],
    "socks": [   "socks", "ankle socks", "crew socks" ],
    "hosiery": [ "tights", "stockings", "pantyhose" ],
    "pajamas": ["pajamas", "pyjamas", "sleepwear", "nightwear" ],
    "scarf": ["scarf"],
    "belt": ["belt"],
    "hat": [ "hat", "cap", "beanie" ],
}


In [11]:
PATTERN = {
    "solid": ["solid"],
    "striped": ["striped", "stripes"],
    "floral": ["floral"],
    "printed": ["print", "printed"],
    "polka": ["polka dot", "polka-dot"],
    "pattern": ["pattern"],
}


In [12]:
LENGTH = {
    "mini": ["mini"],
    "midi": ["midi"],
    "maxi": ["maxi", "full-length"],
    "cropped": ["cropped", "crop"],
}


In [13]:
ATTRIBUTES = {
    "color": COLOR_MAP,
    "material": MATERIAL_MAP,
    "fit": FIT,
    "sleeve": SLEEVE,
    "neckline": NECKLINE,
    "garment": GARMENT,
    "pattern": PATTERN,
    "length": LENGTH,
}


In [14]:
def contains_phrase(text: str, phrase: str) -> bool:
    return re.search(rf"\b{re.escape(phrase)}\b", text) is not None


In [15]:
def extract_attributes(text):
    if not isinstance(text, str):
        return {attr: [] for attr in ATTRIBUTES}

    text = normalize_text(text)

    extracted = {attr: [] for attr in ATTRIBUTES}

    for attr, canon_map in ATTRIBUTES.items():
        for canon, variants in canon_map.items():
            for v in variants:
                if contains_phrase(text, v):
                    extracted[attr].append(canon)
                    break

    return extracted


In [16]:
gt = "solid dark blue fitted top in soft stretch jersey with a wide neckline and long sleeves"
pred = "a woman's long sleeved top with a scoop neck"

print(extract_attributes(gt))
print(extract_attributes(pred))


{'color': ['blue'], 'material': ['jersey'], 'fit': ['fitted'], 'sleeve': ['long'], 'neckline': [], 'garment': ['top'], 'pattern': ['solid'], 'length': []}
{'color': [], 'material': [], 'fit': [], 'sleeve': ['long'], 'neckline': ['scoop'], 'garment': ['top'], 'pattern': [], 'length': []}


In [17]:
from transformers import BlipProcessor, BlipForConditionalGeneration

processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")


  if not hasattr(np, "object"):
Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


In [18]:
from collections import defaultdict

def init_counters():
    return {
        attr: {"tp": 0, "fp": 0, "fn": 0}
        for attr in ATTRIBUTES
    }


In [19]:
def generate_blip_caption(image):
    inputs = processor(image, return_tensors="pt")
    out = model.generate(**inputs, max_length=80)
    caption = processor.decode(out[0], skip_special_tokens=True)
    return caption

In [20]:
def evaluate(dataset, blip_caption_fn):
    """
    dataset: iterable of dicts with keys:
        - "image"
        - "text"  (ground truth description)

    blip_caption_fn(image) -> caption string
    """

    COUNTS = init_counters()

    for sample in dataset:
        gt_text = sample["text"]
        blip_text = blip_caption_fn(sample["image"])

        gt_attrs = extract_attributes(gt_text)
        pred_attrs = extract_attributes(blip_text)

        for attr in ATTRIBUTES:
            gt_set = set(gt_attrs[attr])
            pred_set = set(pred_attrs[attr])

            COUNTS[attr]["tp"] += len(gt_set & pred_set)
            COUNTS[attr]["fp"] += len(pred_set - gt_set)
            COUNTS[attr]["fn"] += len(gt_set - pred_set)

    return COUNTS


In [21]:
def compute_metrics(COUNTS):
    metrics = {}

    for attr, c in COUNTS.items():
        tp, fp, fn = c["tp"], c["fp"], c["fn"]

        precision = tp / (tp + fp) if (tp + fp) > 0 else None
        recall = tp / (tp + fn) if (tp + fn) > 0 else None

        metrics[attr] = {
            "precision": precision,
            "recall": recall,
            "tp": tp,
            "fp": fp,
            "fn": fn
        }

    return metrics


In [22]:
eval_sample = cloth_dataset["train"].shuffle(seed=42).select(range(800))

In [23]:
COUNTS = evaluate(eval_sample, generate_blip_caption)
metrics = compute_metrics(COUNTS)

for attr, m in metrics.items():
    print(attr, m)


color {'precision': 0.7160326086956522, 'recall': 0.6704834605597965, 'tp': 527, 'fp': 209, 'fn': 259}
material {'precision': 0.7043478260869566, 'recall': 0.11538461538461539, 'tp': 81, 'fp': 34, 'fn': 621}
fit {'precision': 0.08333333333333333, 'recall': 0.017045454545454544, 'tp': 3, 'fp': 33, 'fn': 173}
sleeve {'precision': 0.5454545454545454, 'recall': 0.08759124087591241, 'tp': 24, 'fp': 20, 'fn': 250}
neckline {'precision': 0.6176470588235294, 'recall': 0.12138728323699421, 'tp': 21, 'fp': 13, 'fn': 152}
garment {'precision': 0.8270181219110379, 'recall': 0.6077481840193705, 'tp': 502, 'fp': 105, 'fn': 324}
pattern {'precision': 0.3628691983122363, 'recall': 0.1416803953871499, 'tp': 86, 'fp': 151, 'fn': 521}
length {'precision': 0.11627906976744186, 'recall': 0.5555555555555556, 'tp': 5, 'fp': 38, 'fn': 4}


In [26]:
for i in range(10):
    sample = eval_sample[i]
    blip = generate_blip_caption(sample["image"])

    print("GT:", sample["text"])
    print("BLIP:", blip)
    print("GT attrs:", extract_attributes(sample["text"]))
    print("BLIP attrs:", extract_attributes(blip))
    print("-" * 60)


GT: all over pattern pink kaftan in an airy patterned weave with a slight sheen small opening with a covered button at the top and a narrow tie belt at the waist unlined
BLIP: the pink floral dress is a great way to wear it
GT attrs: {'color': ['pink'], 'material': [], 'fit': [], 'sleeve': [], 'neckline': [], 'garment': ['top', 'dress', 'belt'], 'pattern': ['pattern'], 'length': []}
BLIP attrs: {'color': ['pink'], 'material': [], 'fit': [], 'sleeve': [], 'neckline': [], 'garment': ['dress'], 'pattern': ['pattern'], 'length': []}
------------------------------------------------------------
GT: solid pink wide top in sturdy jersey with low dropped shoulders and short sleeves with a wide flounce
BLIP: a pink top with bell sleeves
GT attrs: {'color': ['pink'], 'material': ['jersey'], 'fit': [], 'sleeve': ['short'], 'neckline': [], 'garment': ['top'], 'pattern': ['solid'], 'length': []}
BLIP attrs: {'color': ['pink'], 'material': [], 'fit': [], 'sleeve': ['long'], 'neckline': [], 'garment':