# Hair Style Classification
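Zero-shot hair style classification (curly vs. straight) with OpenAI CLIP: the image is scored against one text prompt per label, and the best-matching label is returned.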

In [1]:
# Install CLIP
%pip install git+https://github.com/openai/CLIP.git

Collecting git+https://github.com/openai/CLIP.git
  Cloning https://github.com/openai/CLIP.git to /private/var/folders/98/2x4_h0694k17yxw6khx3tmx80000gn/T/pip-req-build-1za9sy7m
  Running command git clone -q https://github.com/openai/CLIP.git /private/var/folders/98/2x4_h0694k17yxw6khx3tmx80000gn/T/pip-req-build-1za9sy7m
  Resolved https://github.com/openai/CLIP.git to commit 3702849800aa56e2223035bccd1c6ef91c704ca8
You should consider upgrading via the '/usr/local/bin/python3 -m pip install --upgrade pip' command.[0m
Note: you may need to restart the kernel to use updated packages.


In [2]:
from functools import lru_cache
from tempfile import gettempdir

from PIL import Image
from torch import no_grad
import clip

# Zero-shot text prompts, keyed by the label that `predict` returns.
# CLIP scores the image against each description; the best-matching key wins.
PROMPTS = {
    "CURLY": "curly hair, twisted hair, loc hair, bundled hair, tangled hair",
    "STRAIGHT": "straight hair, pressed hair, flat hair"
}

@lru_cache(maxsize=1)
def _load_clip ():
    """
    Load CLIP and tokenize the label prompts once, then reuse them for every `predict` call.

    Returns:
        tuple: `(model, preprocess, prompt_tokens)`, where `prompt_tokens` holds one
        tokenized prompt per entry of `PROMPTS`, in the order of `PROMPTS.keys()`.
    """
    # Use `download_root=gettempdir()` cos we're creating an endpoint with `SERVERLESS` acceleration
    # Serverless endpoints don't have writable file systems except for the temp directory
    model, preprocess = clip.load("ViT-B/32", device="cpu", download_root=gettempdir())
    # The prompts never depend on the input image, so tokenize them once alongside the model
    prompt_tokens = clip.tokenize(list(PROMPTS.values()))
    return model, preprocess, prompt_tokens

def predict (image: Image.Image) -> str:
    """
    Predict whether hair in a given image is curly or straight.

    The CLIP model is loaded on the first call and cached for subsequent calls,
    so only the first prediction pays the model-loading cost.

    Parameters:
        image (PIL.Image): Input image.

    Returns:
        str: One of ["CURLY", "STRAIGHT"].
    """
    model, preprocess, prompt_tokens = _load_clip()
    # Add a leading dimension so the single image is passed as a batch of one
    image_tensor = preprocess(image).unsqueeze(dim=0)
    # Infer
    with no_grad():
        logits_per_image, _ = model(image_tensor, prompt_tokens)
    # Row 0 is the (only) image's score against each prompt; `.item()` gives a plain int index
    result_idx = logits_per_image[0].argmax().item()
    return list(PROMPTS.keys())[result_idx]