# transformers: CLIPSeg for semantic segmentation

In [None]:
%matplotlib inline

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch
import requests
from PIL import Image
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation

## Load image

In [None]:
# load image
url = 'http://images.cocodataset.org/val2017/000000039769.jpg'

# Download with a timeout so a stalled connection cannot hang the notebook,
# and fail loudly on HTTP errors instead of handing an error page to PIL.
with requests.get(url, stream=True, timeout=30) as response:
    response.raise_for_status()
    # `convert` forces the full decode while the stream is still open, so the
    # resulting image no longer depends on the (now closed) connection.
    image = Image.open(response.raw).convert('RGB')

In [None]:
# display the downloaded image in a single-axes figure
image_array = np.asarray(image)

fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(6, 4))
ax.imshow(image_array)
# keep pixels square; resize the axes box rather than the data limits
ax.set_aspect('equal', adjustable='box')
fig.tight_layout()

## Load model

In [None]:
# set model name (passed to `from_pretrained` for both the processor and the model)
# alternative checkpoint: 'CIDAS/clipseg-rd16'  (reduced dim. 16)
model_name = 'CIDAS/clipseg-rd64-refined'  # reduced dim. 64

In [None]:
# create text and image preprocessors
processor = CLIPSegProcessor.from_pretrained(model_name)

# load model; `device_map='auto'` leaves device placement to the library
# NOTE(review): this option presumably needs the `accelerate` package — confirm
model = CLIPSegForImageSegmentation.from_pretrained(model_name, device_map='auto')
# switch to evaluation mode (`eval()` works in place, no rebinding needed)
model.eval()

print(f'Model device: {model.device}')
print(f'Model dtype: {model.dtype}')
# the footprint is reported in bytes; divide by 2**30 so the value really is GiB
print(f'Memory footprint: {model.get_memory_footprint() / 2**30:.2f} GiB')

print(f'\nEmbedding dim.: {model.config.projection_dim}')
print(f'Reduced dim.: {model.config.reduce_dim}')

## Run model

In [None]:
# set candidate captions (one text prompt per label)
candidate_labels = ['cat', 'remote', 'blanket', 'background']
candidate_captions = [f'a {label}' for label in candidate_labels]

# preprocess inputs: the image is repeated so that every caption is paired
# with its own copy of the image
inputs = processor(
    text=candidate_captions,
    images=[image] * len(candidate_captions),
    return_tensors='pt',
    padding=True,
)

# Use double quotes for the keys: reusing the enclosing quote character inside
# an f-string replacement field is a SyntaxError before Python 3.12.
print(f'Input IDs shape: {inputs["input_ids"].shape}')
print(f'Pixel values shape: {inputs["pixel_values"].shape}')

In [None]:
# move the preprocessed batch to the device the model lives on
inputs = inputs.to(model.device)

# forward pass without gradient tracking
with torch.no_grad():
    outputs = model(**inputs)

# spatial image-text similarity scores, brought back to the CPU
logits = outputs.logits.cpu()
# spatial label probabilities: softmax along dim 0
# (assumes dim 0 indexes the captions — confirm against the printed shape)
probs = torch.softmax(logits, dim=0)

print(f'Logits shape: {logits.shape}')
print(f'Conditional embeddings shape: {outputs.conditional_embeddings.shape}')
print(f'Pooled_output shape: {outputs.pooled_output.shape}')

In [None]:
# visualize the predicted mask for one caption
idx = 0  # position of the caption in `candidate_captions`

# binary mask: locations where the selected caption's probability is at least 0.5
mask = probs[idx].numpy() >= 0.5

fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(6, 4))
ax.imshow(mask)
ax.set_title(candidate_captions[idx])
ax.set_aspect('equal', adjustable='box')
fig.tight_layout()