# Zero Shot Object Localization and Detection with OpenAI's CLIP

### Download the dataset

Let's start by downloading the dataset. We are using the <a href="https://huggingface.co/datasets/jamescalam/image-text-demo">image-text-demo</a> dataset from Hugging Face.

In [None]:
import os
import warnings
import logging

# Configure log/warning suppression BEFORE the heavy third-party imports:
# the TF_* variables are only useful if they are already set when TensorFlow
# (pulled in indirectly by some libraries) is first loaded, and the warning
# filters can only hide import-time warnings if they are installed first.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'  # Suppress TensorFlow logs
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'  # Disable oneDNN warnings

# Suppress specific warnings
warnings.filterwarnings('ignore', category=UserWarning)
warnings.filterwarnings('ignore', category=DeprecationWarning)
warnings.filterwarnings('ignore', category=FutureWarning)

from datasets import load_dataset #pip install datasets
from PIL import Image, ImageDraw, ImageOps, ImageFilter
import matplotlib.pyplot as plt
import matplotlib as mpl

from transformers import CLIPProcessor, CLIPModel
import torch
from tqdm.auto import tqdm
import numpy as np
from absl import logging as absl_logging

# Configure logging for TensorFlow and related libraries
for noisy_logger in ('tensorflow', 'xla', 'cuda', 'absl'):
	logging.getLogger(noisy_logger).setLevel(logging.ERROR)

# Suppress absl logging
absl_logging.set_verbosity(absl_logging.ERROR)
absl_logging.set_stderrthreshold(absl_logging.ERROR)


In [None]:
# Download the "train" split of the demo dataset from the Hugging Face Hub.
data = load_dataset(
	"jamescalam/image-text-demo",
	split="train",
	revision="180fdae",
	trust_remote_code=True  # opt in to running the dataset repository's own loading code
)
# Bare expression: the notebook renders the dataset summary.
data

The dataset includes 21 labelled images. We chose the 3rd image in the dataset, *butterfly landing on the nose of a cat*, to perform our object localization task. Let's have a look at it.

In [None]:
# image
query_image = 2  # zero-based index of the sample used throughout the notebook
# NOTE(review): a notebook cell only renders its final expression, so this
# image expression is evaluated but presumably not shown - confirm, and wrap
# it in display(...) if the picture is meant to appear.
data[query_image]['image']
# label
data[query_image]['text']

In [None]:
# Print every caption in the dataset. Only the text column is read: the
# previous version also pulled the whole image column just to zip over it,
# although the images were never used.
for lbl in data['text']:
	print(lbl)

### Break Image into Equal Patches

The first step consists of transforming our image into a tensor.

In [None]:
from torchvision import transforms

# transform the image into tensor
transform = transforms.ToTensor()

# `img` is the tensor every later cell (patching, detection, plotting) works on
img = transform(data[query_image]["image"])
# alternative: load the picture from a local file instead of the dataset
# img = transform(Image.open("/home/farid/WS_Farid/ImACCESS/TEST_IMGs/cat_butterfly.jpg"))
# bare expression: show the tensor shape
img.data.shape

The generated tensor has $3$ color channels, a height of $5184$, and width of $3456$

H,W = 5184,3456

We now want to add an extra dimension, which will be needed for later calculations. We can use the `unfold` function.

In [None]:
# add extra dimension for later calculations
patches = img.data.unfold(0,3,3)
patches.shape

We can now break the image into patches. More precisely, we want to break the image into patches of $256 \times 256$ pixels. We start by splitting along the height, which gives us $20$ horizontal strips, each $256$ pixels high and spanning the full image width of $3456$ pixels.

In [None]:
# break the image into patches (in height dimension)
patch = 256

patches = patches.unfold(1, patch, patch)
patches.shape

We can visualize it below:

In [None]:
# number of horizontal strips produced by the height unfold
X = patches.shape[1]

fig, ax = plt.subplots(X, 1, figsize=(15, 15))
# one subplot per strip, top to bottom
for x, strip_ax in enumerate(ax):
	# reorder the strip's dims into the layout handed to imshow
	strip = patches[0, x].permute(2, 0, 1)
	print(x)
	print(strip.shape)
	strip_ax.imshow(strip)
	strip_ax.axis("off")
fig.tight_layout()
plt.show()

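We can now use the `unfold` function again, this time along the width. After this operation, we will get an image composed of $20 \times 13$ patches of $256 \times 256$ pixels each (pixels at the bottom and right edges that do not fill a complete patch are dropped).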

In [None]:
# break the image into patches (in width dimension)
patches = patches.unfold(2, patch, patch)
patches.shape

We can visualize it below:

In [None]:
import matplotlib.pyplot as plt

# grid size: X rows and Y columns of patches
X, Y = patches.shape[1], patches.shape[2]

fig, ax = plt.subplots(X, Y, figsize=(Y*2, X*2))
for x, ax_row in enumerate(ax):
	for y, cell in enumerate(ax_row):
		# move the colour channel last before plotting the patch
		cell.imshow(patches[0, x, y].permute(1, 2, 0))
		cell.axis("off")
fig.tight_layout()
plt.show()

### Process Patches using CLIP

The first step is done. We are now almost ready to process those patches using CLIP. Before doing it, we might want to work through these patches by grouping them into a 6x6 window.

<center><div> <img src="https://raw.githubusercontent.com/pinecone-io/examples/master/learn/image-retrieval/clip-object-detection/assets/window.png" alt="Drawing" style="width:300px;"/></div> </center> 

Let's visualize the first patch.

In [None]:
# set the 6x6 window
window = 6

# canvas that will hold the window's patches stitched back together
big_patch = torch.zeros(patch*window, patch*window, 3)
# Take the top-left window: the first `window` rows AND columns of the grid.
# (The previous `patches[0][:window][:window]` sliced the row axis twice and
# never restricted the columns.)
patch_batch = patches[0, :window, :window]

# visualize patch
for y in range(window):
	for x in range(window):
		# copy patch (y, x) into its slot, moving the colour channel last
		big_patch[y*patch:(y+1)*patch, x*patch:(x+1)*patch, :] = patch_batch[y, x].permute(1, 2, 0)

plt.imshow(big_patch)
plt.axis("off")
plt.show()

In [None]:
# bare expression: show the shape of the full patch grid
patches.shape

This is now a patch consisting of 6x6 smaller patches.

We can repeat this process by "sliding" the 6x6 window over the full image. We set the stride, i.e., the number of steps the window moves, to $1$.

In [None]:
# window = 6
# stride = 1

# # window slides from top to bottom
# for Y in range(0, patches.shape[1]-window+1, stride):
#     # window slides from left to right
#     for X in range(0, patches.shape[2]-window+1, stride):
#         # initialize an empty big_patch array
#         big_patch = torch.zeros(patch*window, patch*window, 3)

#         # this gets the current batch of patches that will make big_batch
#         patch_batch = patches[0, Y:Y+window, X:X+window]
#         # loop through each patch in current batch
#         for y in range(patch_batch.shape[1]):
#             for x in range(patch_batch.shape[0]):
#                 # add patch to big_patch
#                 big_patch[
#                     y*patch:(y+1)*patch, x*patch:(x+1)*patch, :
#                 ] = patch_batch[y, x].permute(1, 2, 0)
#         # display current big_patch
#         plt.imshow(big_patch)
#         plt.show()

We need now to process these through CLIP and calculate the similarity between the patch and a prompt. Our first prompt will be `"a fluffy cat"`. Let's first define our processor and model CLIP and move it to device, if possible. 

In [None]:
# define processor and model
model_id = "openai/clip-vit-base-patch32"

# both are loaded from the same checkpoint id
processor = CLIPProcessor.from_pretrained(model_id)
model = CLIPModel.from_pretrained(model_id)

# move model to device if possible
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# bare expression: besides moving the model, the notebook prints its repr
model.to(device)

Now, we can add a step to the previous operation: we calculate similarity scores for each patch (`scores`) and count the number of times the window slides over each patch (`runs`). To do that, we pass the image and the prompt, "a fluffy cat", to the CLIP model.

In [None]:
window = 6
stride = 1

# per-patch accumulators: summed similarity and visit counter
scores = torch.zeros(patches.shape[1], patches.shape[2])
# NOTE(review): the counter starts at one, so the later `scores / runs`
# divides by (number of visits + 1) - confirm this is intended.
runs = torch.ones(patches.shape[1], patches.shape[2])

# Inference only: without no_grad every forward pass would also record an
# autograd graph that is thrown away immediately by .item().
with torch.no_grad():
	for Y in range(0, patches.shape[1]-window+1, stride):
		for X in range(0, patches.shape[2]-window+1, stride):
			big_patch = torch.zeros(patch*window, patch*window, 3)
			patch_batch = patches[0, Y:Y+window, X:X+window]
			for y in range(window):
				for x in range(window):
					big_patch[
						y*patch:(y+1)*patch, x*patch:(x+1)*patch, :
					] = patch_batch[y, x].permute(1, 2, 0)
			# we preprocess the image and class label with the CLIP processor
			inputs = processor(
				images=big_patch,  # big patch image sent to CLIP
				return_tensors="pt",  # tell CLIP to return pytorch tensor
				text="a fluffy cat",  # class label sent to CLIP
				padding=True
			).to(device) # move to device if possible

			# calculate and retrieve similarity score
			score = model(**inputs).logits_per_image.item()
			# sum up similarity scores from current and previous big patches
			# that were calculated for patches within the current window
			scores[Y:Y+window, X:X+window] += score
			# calculate the number of runs on each patch within the current window
			runs[Y:Y+window, X:X+window] += 1

We then want to divide the total score (`scores`) by the number of times the window slid over the patch (`runs`) to get an average score for each patch.

In [None]:
# average score for each patch
# (in-place division of the summed scores by the per-patch counter)
scores /= runs

The initial visual is not very useful...

In [None]:
# put the two grid dims last so the (rows, cols) score map broadcasts over them
adj_patches = patches.squeeze(0).permute(3, 4, 2, 0, 1)
# min-max normalize the scores
score_lo, score_hi = scores.min(), scores.max()
scores = (scores - score_lo) / (score_hi - score_lo)
# weight every patch by its score
adj_patches = adj_patches * scores
# bring the grid dims back to the front for plotting
adj_patches = adj_patches.permute(3, 4, 2, 0, 1)

Y, X = adj_patches.shape[0], adj_patches.shape[1]

fig, ax = plt.subplots(Y, X, figsize=(X*.5, Y*.5))
for y, ax_row in enumerate(ax):
	for x, cell in enumerate(ax_row):
		cell.imshow(adj_patches[y, x].permute(1, 2, 0))
		cell.axis("off")
		cell.set_aspect('equal')
plt.subplots_adjust(wspace=0, hspace=0)
plt.show()

The resulting tensor is characterized by a smooth gradient of scores and little-to-no impact from the scores on our visual. That is why (1) we are going to clip the score's interval edges for each patch so that they are around the average score, and (2) normalize scores for each patch. 

Using the `numpy.clip()` function, we can set all the scores below the average score equal to $0$ while keeping those higher than the average score as they are.

The clipping step can be repeated to sharpen the result: the cell below applies it once, while the butterfly example further down applies it three times.

In [None]:
# clip the scores' interval edges: shift by the mean and zero out everything
# that ends up negative. torch.clamp is used because `scores` is a torch
# tensor; np.clip only worked here through NumPy forwarding to the tensor.
for _ in range(1):
	scores = torch.clamp(scores - scores.mean(), min=0)

scores

After, we normalize the scores using the min-max normalization. For every tensor, the minimum value is transformed into a 0, the maximum value into a 1, and every other value into a decimal between 0 and 1.

In [None]:
# min-max normalization: smallest score -> 0, largest score -> 1
lowest, highest = scores.min(), scores.max()
scores = (scores - lowest) / (highest - lowest)
scores

### Visualize Results

These scores tell us whether a given patch contains "a fluffy cat" or not. The higher the score, the higher the probability that the cat is located in that patch.

We now want to visualize those scores, i.e., the localized object, on the original image. Scores equal to zero will be represented by black patches, so that the localized object can be clearly seen.

To do that, we can multiply our scores by the patches. This requires that scores and patches have the same shape.

In [None]:
# bare expression: show both shapes side by side
scores.shape, patches.shape

Given that they do not have the same shape, we can transform the shape of the patches using `squeeze` and `permute`. `squeeze` removes the size-one dimension, while `permute` reorders the dimensions of the tensor.

In [None]:
# transform the patches tensor 
# drop the leading size-one dim, then move the two grid dims to the end
adj_patches = patches.squeeze(0).permute(3, 4, 2, 0, 1)
adj_patches.shape

We can now multiply the patches by the scores.

In [None]:
# multiply patches by scores
# (scores broadcasts against the two trailing grid dims of adj_patches)
adj_patches = adj_patches * scores

Before plotting the localized object, we can rotate the patch tensor again to make our life easier ahead. 

In [None]:
# rotate patches to visualize
# (the same permutation applied again puts the grid dims back in front)
adj_patches = adj_patches.permute(3, 4, 2, 0, 1)
adj_patches.shape

We can now visualize the localized object. We are expecting to visualize the cat only.

In [None]:
# grid size: Y rows, X columns
Y, X = adj_patches.shape[0], adj_patches.shape[1]

fig, ax = plt.subplots(Y, X, figsize=(X*.5, Y*.5))
# walk the axes grid row by row; (y, x) is the grid position of each subplot
for (y, x), cell in np.ndenumerate(ax):
	cell.imshow(adj_patches[y, x].permute(1, 2, 0))
	cell.axis("off")
	cell.set_aspect('equal')
plt.subplots_adjust(wspace=0, hspace=0)
plt.show()

It worked pretty well! 

Now let's do the same for the butterfly. In this case, the text passed to the CLIP model will be "a butterfly".

In [None]:
window = 6
stride = 1

# per-patch accumulators: summed similarity and visit counter
scores = torch.zeros(patches.shape[1], patches.shape[2])
runs = torch.ones(patches.shape[1], patches.shape[2])

# Inference only: no autograd graph is needed for the similarity scores.
with torch.no_grad():
	# The "+1" lets the window reach the last valid start position; without it
	# the bottom row and right-most column of patches were never scored
	# (and the "fluffy cat" loop above already used the "+1" bounds).
	for Y in range(0, patches.shape[1]-window+1, stride):
		for X in range(0, patches.shape[2]-window+1, stride):
			big_patch = torch.zeros(patch*window, patch*window, 3)
			patch_batch = patches[0, Y:Y+window, X:X+window]
			for y in range(window):
				for x in range(window):
					big_patch[y*patch:(y+1)*patch, x*patch:(x+1)*patch, :] = patch_batch[y, x].permute(1, 2, 0)
			inputs = processor(
				images=big_patch,
				return_tensors="pt",
				text="a butterfly",
				padding=True
			).to(device)
			score = model(**inputs).logits_per_image.item()
			# every patch under the window receives the window's score
			scores[Y:Y+window, X:X+window] += score
			runs[Y:Y+window, X:X+window] += 1

In [None]:
# average the summed scores over the per-patch counter
scores /= runs
# three rounds of "subtract the mean, zero out negatives"; torch.clamp is
# used because `scores` is a torch tensor (np.clip only worked through
# NumPy forwarding the call to the tensor)
for _ in range(3):
	scores = torch.clamp(scores - scores.mean(), min=0)
# normalize scores
scores = (scores - scores.min()) / (scores.max() - scores.min())

In [None]:
# adjust patches
# grid dims last -> multiply by the score map -> grid dims first again
adj_patches = patches.squeeze(0).permute(3, 4, 2, 0, 1) * scores
adj_patches = adj_patches.permute(3, 4, 2, 0, 1)

In [None]:
# grid size: Y rows, X columns (both are reused by the bounding-box plot)
Y, X = adj_patches.shape[:2]

fig, ax = plt.subplots(Y, X, figsize=(X*.5, Y*.5))
for y in range(Y):
	for x, cell in enumerate(ax[y]):
		# colour channel last for plotting
		cell.imshow(adj_patches[y, x].permute(1, 2, 0))
		cell.axis("off")
		cell.set_aspect('equal')
plt.subplots_adjust(wspace=0, hspace=0)
plt.show()

## CLIP for Object Detection

Localization is one step towards object detection, where we might expect to detect multiple objects.

We can extend our current localization code quite easily to achieve this, we just add an extra layer of logic. But, before we do anything, we need to rethink our localization visual — this would be hard to represent when displaying two or more objects.

The typical approach to this is to create a "bounding box".

The scores we are using are from the last run, in which we localized the butterfly. Therefore, we are going to build a bounding box around the butterfly first. We want the bounding box to focus as much as possible on the butterfly only, and nothing else.

Scores higher than $0.5$ seem to give us a more precise bounding box. We therefore define `detection`, which is *True* where the score is higher than $0.5$ (these become the non-zero positions) and *False* otherwise.

In [None]:
# scores higher than 0.5
# boolean mask with one entry per patch
detection = scores > 0.5

We can now detect the non-zero *positions* with the `np.nonzero` function.  These represent the co-ordinates of our patches with scores $>0.5$. 

In [None]:
# non-zero positions
# bare expression: show the grid positions where `detection` is True
np.nonzero(detection)

In [None]:
# row index (column 0 of the result) of every detected patch
hit_rows = np.nonzero(detection)[:,0]
# first detected row, and one past the last detected row
y_min = hit_rows.min().item()
y_max = hit_rows.max().item()+1
y_min, y_max

In [None]:
# column index (column 1 of the result) of every detected patch
hit_cols = np.nonzero(detection)[:,1]
# first detected column, and one past the last detected column
x_min = hit_cols.min().item()
x_max = hit_cols.max().item()+1
x_min, x_max

These give us the bounding box corner co-ordinates based on patches rather than pixel values. To get the pixel co-ordinates we need to multiply by the `patch` size.

In [None]:
# scale the patch-grid corners by the patch edge length to get pixels
y_min, y_max = y_min * patch, y_max * patch
x_min, x_max = x_min * patch, x_max * patch
x_min, y_min

We use `(y_max - y_min)` and `(x_max - x_min)` to calculate the height and width of the bounding box respectively...

In [None]:
# box extent in pixels (the corners were already scaled by the patch size)
height = y_max - y_min
width = x_max - x_min

height, width

Given our patches are $256x256$ pixels, we obtain a total height and width of $256*8=2048$ and $256*4=1024$ pixels, respectively. 

We should now be able to visualize the bounding box on our image. We are going to use `matplotlib`, which needs the color channel to be the last dimension of the image's shape.

In [None]:
# bare expression: shape of the image tensor as a NumPy array
img.data.numpy().shape

In our case, the color channel is in the first position, so we need to move it to the end. We use `moveaxis` to do that.

In [None]:
# move color channel to final dim
# (axis 0 becomes the last axis; the other axes keep their order)
image = np.moveaxis(img.data.numpy(), 0, -1)
image.shape

We can now plot the image.

In [None]:
# figsize is (width, height): X counts patch columns (width) and Y patch rows
# (height), matching the per-patch grid plots above. The two were swapped.
fig, ax = plt.subplots(figsize=(X*0.5, Y*0.5))
ax.imshow(image)
# Create a Rectangle patch anchored at (x_min, y_min), in pixel coordinates
rect = mpl.patches.Rectangle(
	(x_min, y_min),
	width,
	height,
	linewidth=2,
	edgecolor='#FAFF00', 
	facecolor='none',
	alpha=0.8,
)
# Add the patch to the Axes
ax.add_patch(rect)
ax.axis('off')
plt.tight_layout()
plt.show()

We can repeat this process for any number of objects that we'd like CLIP to detect. Let's put everything we've done so far into a few helper functions, then create a new function called `detect` to handle the detection of multiple objects and the visualization of their bounding boxes.

In [None]:
# Edge colours for the bounding boxes, looked up by prompt position.
# Some values repeat, so distinct prompts can share a colour.
colors = [
	'#FF00FF', '#00FF00', '#FAFF00', '#8CF1FF',
	'#FF0000', '#0000FF', '#000000', '#FFFFFF',
	'#808080', '#800000', '#808000', '#008000',
	'#008080', '#000080', '#800080', '#FFA500',
	'#FFC0CB', '#FFD700', '#FF69B4', '#FF4500',
	'#FF1493', '#FF00FF', '#FF0000', '#FF00FF',
]

def get_patches(img, patch_size=256):
	"""Cut an image tensor into a grid of non-overlapping square patches.

	The size-3 leading dimension is unfolded first (adding a leading
	dimension of one), then dims 1 and 2 are each cut into windows of
	`patch_size` with step `patch_size`.

	Returns a tensor of shape (1, rows, cols, 3, patch_size, patch_size)
	for an input of shape (3, height, width); rows and columns that do not
	fill a whole patch are left out.
	"""
	return (
		img.data
		.unfold(0, 3, 3)
		.unfold(1, patch_size, patch_size)
		.unfold(2, patch_size, patch_size)
	)

def get_scores(img_patches, prompt, window=6, stride=1):
	"""Score every patch of the grid against `prompt` with a sliding window.

	A window of `window` x `window` patches is moved over the grid in steps
	of `stride`. Each window is stitched into one image, passed to the
	notebook-level `processor` and `model` together with `prompt`, and the
	resulting `logits_per_image` value is added to every patch under the
	window. The sums are divided by the per-patch counter, passed through
	three rounds of "subtract the mean, zero out negatives" and finally
	min-max normalized.

	Args:
		img_patches: patch grid as returned by `get_patches`.
		prompt: text to compare each window against.
		window: window edge length, in patches.
		stride: window step, in patches.

	Returns:
		A (rows, cols) tensor of normalized scores.
	"""
	# Read the patch edge length from the grid itself instead of the
	# notebook-level `patch` variable, so grids built with any patch_size
	# are handled (get_patches makes the last two dims the patch pixels).
	patch_size = img_patches.shape[-1]
	n_rows, n_cols = img_patches.shape[1], img_patches.shape[2]

	# initialize scores and runs arrays
	scores = torch.zeros(n_rows, n_cols)
	# NOTE(review): the counter starts at one, so the "average" below divides
	# by (number of visits + 1) - confirm this is intended.
	runs = torch.ones(n_rows, n_cols)

	# Inference only: skip building an autograd graph for each forward pass.
	with torch.no_grad():
		# iterate through patches
		for Y in range(0, n_rows-window+1, stride):
			for X in range(0, n_cols-window+1, stride):
				# initialize array to store big patches
				big_patch = torch.zeros(patch_size*window, patch_size*window, 3)
				# get a single big patch
				patch_batch = img_patches[0, Y:Y+window, X:X+window]
				# copy each patch into its slot, colour channel last
				for y in range(window):
					for x in range(window):
						big_patch[
							y*patch_size:(y+1)*patch_size,
							x*patch_size:(x+1)*patch_size,
							:
						] = patch_batch[y, x].permute(1, 2, 0)
				inputs = processor(
					images=big_patch, # image transmitted to the model
					return_tensors="pt", # return pytorch tensor
					text=prompt, # text transmitted to the model
					padding=True
				).to(device) # move to device if possible

				score = model(**inputs).logits_per_image.item()
				# sum up similarity scores
				scores[Y:Y+window, X:X+window] += score
				# calculate the number of runs
				runs[Y:Y+window, X:X+window] += 1
	# calculate average scores
	scores /= runs
	# clip scores (torch.clamp, since `scores` is a torch tensor)
	for _ in range(3):
		scores = torch.clamp(scores - scores.mean(), min=0)
	# normalize scores
	scores = (scores - scores.min()) / (scores.max() - scores.min())
	print(type(scores), scores.shape, scores.min(), scores.max())
	return scores

def get_box(scores, patch_size=256, threshold=0.5):
	"""Return the pixel bounding box around all patches scoring above `threshold`.

	Args:
		scores: 2-D (rows, cols) score map, one value per patch.
		patch_size: patch edge length in pixels, used to scale grid
			positions to pixel coordinates.
		threshold: a patch counts as detected when its score is strictly
			greater than this value.

	Returns:
		`(x_min, y_min, width, height)` in pixels, or `None` when no patch
		exceeds the threshold (this case used to raise while reducing an
		empty index tensor).
	"""
	# torch.as_tensor accepts the torch score map as well as a NumPy array,
	# so the indexing below no longer depends on how np.nonzero treats tensors.
	detection = torch.as_tensor(scores) > threshold
	# one (row, col) pair per detected patch
	hits = torch.nonzero(detection)
	if hits.numel() == 0:
		return None
	# find box corners (max + 1 so the last detected patch is included)
	y_min, y_max = hits[:, 0].min().item(), hits[:, 0].max().item() + 1
	x_min, x_max = hits[:, 1].min().item(), hits[:, 1].max().item() + 1
	# convert from patch co-ords to pixel co-ords
	y_min *= patch_size
	y_max *= patch_size
	x_min *= patch_size
	x_max *= patch_size
	# calculate box height and width
	height = y_max - y_min
	width = x_max - x_min
	return x_min, y_min, width, height

def detect(prompts, img, patch_size=256, window=6, stride=1, threshold=0.5):
	"""Draw one bounding box per detected prompt on top of `img`.

	For every prompt the patch grid is scored with `get_scores`; when at
	least one patch scores above `threshold`, the box from `get_box` is
	drawn, otherwise a message is printed. The figure is shown at the end.

	Args:
		prompts: iterable of text prompts to look for.
		img: image tensor with the colour channel first.
		patch_size: patch edge length in pixels.
		window: sliding-window edge length, in patches.
		stride: sliding-window step, in patches.
		threshold: score above which a patch counts as detected.
	"""
	# build image patches for detection
	img_patches = get_patches(img, patch_size)
	# convert image to format for displaying with matplotlib
	image = np.moveaxis(img.data.numpy(), 0, -1)
	# initialize plot to display image + bounding boxes
	fig, ax = plt.subplots(figsize=(10, 10))
	ax.imshow(image)
	# process image through object detection steps
	for i, prompt in enumerate(tqdm(prompts)):
		scores = get_scores(img_patches, prompt, window, stride)
		# Check if there's any detection above the threshold
		detection = scores.numpy() > threshold
		if np.any(detection):
			x, y, width, height = get_box(scores, patch_size, threshold)
			# Only add the rectangle if the detection meets the threshold
			rect = mpl.patches.Rectangle(
				(x, y),
				width,
				height,
				linewidth=1.5,
				alpha=0.8,
				# wrap around the palette so a prompt list longer than
				# `colors` no longer raises IndexError
				edgecolor=colors[i % len(colors)],
				facecolor='none',
			)
			ax.add_patch(rect)
		else:
			print(f"No {prompt} detected in the image.")
	ax.axis('off')
	plt.show()

In [None]:
# candidate labels handed to detect(), one bounding box attempt per label
labels = ["dog", "ball", "cat", "butterfly", "car", "apple", "banana", "orange", "bird", "flower"]
# smaller window and stricter threshold than the function defaults
detect(labels, img, window=4, stride=1, threshold=0.7)

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as mpl_patches
from tqdm import tqdm

# Bounding-box edge colours, looked up by prompt position.
# Some values repeat, so distinct prompts can share a colour.
colors = [
	'#0000FF', '#000000', '#FFFFFF', '#808080', '#800000', '#808000',
	'#008000', '#008080', '#000080', '#800080', '#FF00FF', '#00FF00',
	'#FAFF00', '#8CF1FF', '#FF0000', '#FFA500', '#FFC0CB', '#FFD700',
	'#FF69B4', '#FF4500', '#FF1493', '#FF00FF', '#FF0000', '#FF00FF',
]

def get_patches(img, patch_size=256):
	"""Split an image tensor into non-overlapping `patch_size` squares.

	Returns the unfolded tensor of shape
	(1, rows, cols, 3, patch_size, patch_size) for a (3, height, width) input.
	"""
	# fold the size-3 leading dim into a trailing dim
	tiles = img.data.unfold(0, 3, 3)
	# then cut dim 1 and dim 2 into patch-sized windows
	for axis in (1, 2):
		tiles = tiles.unfold(axis, patch_size, patch_size)
	return tiles

def get_scores(img_patches, prompt, patch_size, window=6, stride=1):
	"""Score every patch against `prompt` with a sliding window and normalize.

	Each `window` x `window` group of patches is stitched into one image and
	passed, together with `prompt`, to the notebook-level `processor` and
	`model`; the `logits_per_image` value is added to every patch under the
	window. The sums are divided by the per-patch counter, clipped around
	the mean three times and min-max normalized.

	Args:
		img_patches: patch grid as returned by `get_patches`.
		prompt: text to compare each window against.
		patch_size: patch edge length in pixels.
		window: window edge length, in patches.
		stride: window step, in patches.

	Returns:
		A (rows, cols) tensor of normalized scores.
	"""
	n_rows, n_cols = img_patches.shape[1], img_patches.shape[2]
	scores = torch.zeros(n_rows, n_cols)
	# NOTE(review): the counter starts at one, so the division below uses
	# (number of visits + 1) - confirm this is intended.
	runs = torch.ones(n_rows, n_cols)

	# Inference only: do not record an autograd graph for each forward pass.
	with torch.no_grad():
		for Y in range(0, n_rows - window + 1, stride):
			for X in range(0, n_cols - window + 1, stride):
				big_patch = torch.zeros(patch_size * window, patch_size * window, 3)
				patch_batch = img_patches[0, Y:Y + window, X:X + window]
				# copy each patch into its slot, colour channel last
				for y in range(window):
					for x in range(window):
						big_patch[
							y * patch_size:(y + 1) * patch_size,
							x * patch_size:(x + 1) * patch_size,
							:
						] = patch_batch[y, x].permute(1, 2, 0)
				inputs = processor(
					images=big_patch,
					return_tensors="pt",
					text=prompt,
					padding=True
				).to(device)

				score = model(**inputs).logits_per_image.item()
				scores[Y:Y + window, X:X + window] += score
				runs[Y:Y + window, X:X + window] += 1

	scores /= runs
	# shift by the mean and zero out negatives, three times
	# (torch.clamp, since `scores` is a torch tensor)
	for _ in range(3):
		scores = torch.clamp(scores - scores.mean(), min=0)
	scores = (scores - scores.min()) / (scores.max() - scores.min())
	return scores

def get_box(scores, patch_size=256, threshold=0.5):
	"""Bounding box, in pixels, around every patch scoring above `threshold`.

	Returns `(x_min, y_min, width, height)`, or `None` when no patch
	exceeds the threshold.
	"""
	mask = scores.numpy() > threshold
	if not np.any(mask):
		return None
	# grid rows / columns of the detected patches
	rows, cols = np.nonzero(mask)
	top, bottom = rows.min().item(), rows.max().item() + 1
	left, right = cols.min().item(), cols.max().item() + 1
	# scale grid positions to pixel coordinates
	top, bottom = top * patch_size, bottom * patch_size
	left, right = left * patch_size, right * patch_size
	return left, top, right - left, bottom - top

def detect(prompts, img, patch_size=256, window=6, stride=1, threshold=0.5):
	"""Draw a bounding box for every detected prompt and save the figure.

	Each prompt is scored with `get_scores`; when `get_box` finds a box it
	is drawn on the image, otherwise a message is printed. The result is
	written to "output.png".

	Args:
		prompts: iterable of text prompts to look for.
		img: image tensor with the colour channel first.
		patch_size: patch edge length in pixels.
		window: sliding-window edge length, in patches.
		stride: sliding-window step, in patches.
		threshold: score above which a patch counts as detected.
	"""
	img_patches = get_patches(img, patch_size)
	# colour channel last for plotting
	image = np.moveaxis(img.data.numpy(), 0, -1)
	fig, ax = plt.subplots(figsize=(10, 10))
	ax.imshow(image)

	for i, prompt in enumerate(tqdm(prompts)):
		scores = get_scores(img_patches, prompt, patch_size, window, stride)
		box = get_box(scores, patch_size, threshold)
		if box:
			x, y, width, height = box
			rect = mpl_patches.Rectangle(
				(x, y),
				width,
				height,
				linewidth=1.5,
				alpha=0.8,
				# wrap around the palette so a prompt list longer than
				# `colors` no longer raises IndexError
				edgecolor=colors[i % len(colors)],
				facecolor='none',
			)
			ax.add_patch(rect)
		else:
			print(f"No {prompt} detected in the image.")

	ax.axis('off')
	plt.savefig("output.png")
	# plt.show()

# Example usage: try ten candidate labels with a 4x4 window and a stricter
# threshold than the function default
labels = ["dog", "ball", "cat", "butterfly", "car", "apple", "banana", "orange", "bird", "flower"]
detect(labels, img, window=4, stride=1, threshold=0.7)

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as mpl_patches
from tqdm import tqdm

# Bounding-box edge colours; detect() picks one per prompt position.
colors = [
	'#FF00FF', '#00FF00', '#FAFF00', '#8CF1FF',
	'#FF0000', '#0000FF', '#000000', '#FFFFFF',
	'#808080', '#800000', '#808000', '#008000',
	'#008080', '#000080', '#800080', '#FFA500',
	'#FFC0CB', '#FFD700', '#FF69B4', '#FF4500',
]

def get_patches(img, patch_size=256):
	"""Return `img` cut into a grid of `patch_size` x `patch_size` patches.

	The result has shape (1, rows, cols, 3, patch_size, patch_size) for a
	(3, height, width) input tensor.
	"""
	channels_unfolded = img.data.unfold(0, 3, 3)
	row_strips = channels_unfolded.unfold(1, patch_size, patch_size)
	return row_strips.unfold(2, patch_size, patch_size)

def get_scores(img_patches, prompt, patch_size, window=6, stride=1):
	"""Return the averaged sliding-window score of every patch for `prompt`.

	Each `window` x `window` group of patches is stitched into one image and
	passed, together with `prompt`, to the notebook-level `processor` and
	`model`; the `logits_per_image` value is added to every patch under the
	window and the sums are divided by the per-patch counter. No clipping
	or normalization is applied here.

	Args:
		img_patches: patch grid as returned by `get_patches`.
		prompt: text to compare each window against.
		patch_size: patch edge length in pixels.
		window: window edge length, in patches.
		stride: window step, in patches.

	Returns:
		A (rows, cols) tensor of averaged scores.
	"""
	n_rows, n_cols = img_patches.shape[1], img_patches.shape[2]
	scores = torch.zeros(n_rows, n_cols)
	# NOTE(review): the counter starts at one, so the division below uses
	# (number of visits + 1) - confirm this is intended.
	runs = torch.ones(n_rows, n_cols)
	side = patch_size * window

	# Inference only: do not record an autograd graph for each forward pass.
	with torch.no_grad():
		for Y in range(0, n_rows - window + 1, stride):
			for X in range(0, n_cols - window + 1, stride):
				# (win_row, win_col, channel, h, w) block of patches
				patch_batch = img_patches[0, Y:Y + window, X:X + window]
				# Stitch the block into one (side, side, 3) image in a single
				# step: reorder to (win_row, h, win_col, w, channel) and merge
				# the row pair and the column pair. Same pixels as copying
				# the patches one by one.
				big_patch = patch_batch.permute(0, 3, 1, 4, 2).reshape(side, side, 3)

				inputs = processor(
					images=big_patch,
					return_tensors="pt",
					text=prompt,
					padding=True
				).to(device)

				score = model(**inputs).logits_per_image.item()
				scores[Y:Y + window, X:X + window] += score
				runs[Y:Y + window, X:X + window] += 1

	scores /= runs
	return scores

def normalize_scores(scores):
	"""Smooth a score map, drop below-mean values and rescale the rest.

	The map is Gaussian-filtered (sigma 1.0), shifted by its mean with
	negatives set to zero, then min-max scaled; a small constant in the
	denominator avoids dividing by zero. Returns a torch tensor.
	"""
	from scipy.ndimage import gaussian_filter

	# smoothing to reduce noise
	smoothed = gaussian_filter(scores.numpy(), sigma=1.0)

	# keep only what lies above the mean
	above_mean = np.clip(smoothed - smoothed.mean(), 0, np.inf)

	# min-max scaling
	lo, hi = above_mean.min(), above_mean.max()
	rescaled = (above_mean - lo) / (hi - lo + 1e-8)

	return torch.from_numpy(rescaled)

def get_box(scores, patch_size=256, threshold=0.5, min_area=4):
	"""Bounding box and confidence for the patches scoring above `threshold`.

	Returns `((x_min, y_min, width, height), confidence)` with the box in
	pixels and the confidence equal to the mean score inside the box.
	Returns `(None, 0.0)` when fewer than `min_area` patches exceed the
	threshold (a count of patches, not the area of the box) or when none do.
	"""
	scores_np = scores.numpy()
	mask = scores_np > threshold

	# too few detected patches, or none at all
	if np.sum(mask) < min_area:
		return None, 0.0
	if not np.any(mask):
		return None, 0.0

	rows, cols = np.nonzero(mask)
	top, bottom = rows.min().item(), rows.max().item() + 1
	left, right = cols.min().item(), cols.max().item() + 1

	# mean score over the whole box, taken in grid coordinates
	confidence = float(scores_np[top:bottom, left:right].mean())

	# grid coordinates -> image (pixel) coordinates
	top, bottom = top * patch_size, bottom * patch_size
	left, right = left * patch_size, right * patch_size

	return (left, top, right - left, bottom - top), confidence

def detect(prompts, img, patch_size=256, window=6, stride=1, threshold=0.5, conf_threshold=0.3):
	"""Draw labelled bounding boxes for the prompts detected in `img`.

	Every prompt is scored with `get_scores`, post-processed with
	`normalize_scores` and boxed with `get_box`. A box is drawn only when
	one was found and its confidence is greater than `conf_threshold`;
	otherwise a message is printed. A legend is added when at least one
	box was drawn.

	Returns:
		List of `(prompt, confidence)` tuples for the drawn boxes.
	"""
	img_patches = get_patches(img, patch_size)
	# colour channel last for plotting
	image = np.moveaxis(img.data.numpy(), 0, -1)

	fig, ax = plt.subplots(figsize=(12, 8))
	ax.imshow(image)

	# (prompt, confidence) of every box that gets drawn
	detected_objects = []

	for i, prompt in enumerate(tqdm(prompts)):
		scores = normalize_scores(
			get_scores(img_patches, prompt, patch_size, window, stride)
		)
		box, confidence = get_box(scores, patch_size, threshold)

		# nothing found, or not confident enough: report and move on
		if not (box and confidence > conf_threshold):
			print(f"No {prompt} detected with sufficient confidence.")
			continue

		x, y, width, height = box
		ax.add_patch(mpl_patches.Rectangle(
			(x, y),
			width,
			height,
			linewidth=2,
			alpha=0.8,
			edgecolor=colors[i % len(colors)],
			facecolor='none',
			label=f"{prompt} ({confidence:.2f})"
		))
		detected_objects.append((prompt, confidence))

	# the legend is only useful when something was drawn
	if detected_objects:
		ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')

	ax.axis('off')
	plt.tight_layout()
	return detected_objects

# Example usage: the returned list holds one (prompt, confidence) tuple per
# drawn box
labels = ["dog", "ball", "cat", "butterfly", "car", "apple", "banana", "orange", "bird", "flower"]
detected = detect(labels, img, 
				 window=4, 
				 stride=1, 
				 threshold=0.8,
				 conf_threshold=0.5)

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as mpl_patches
from tqdm import tqdm

# Palette for the bounding-box edges; detect() cycles through it by index.
colors = [
	'#FF00FF', '#00FF00', '#FAFF00', '#8CF1FF', '#FF0000', '#0000FF', '#000000',
	'#FFFFFF', '#808080', '#800000', '#808000', '#008000', '#008080', '#000080',
	'#800080', '#FFA500', '#FFC0CB', '#FFD700', '#FF69B4', '#FF4500',
]

def get_patches(img, patch_size=256):
	"""Unfold an image tensor into a grid of square patches.

	Dim 0 is unfolded with size and step 3, dims 1 and 2 with size and step
	`patch_size`. A (3, height, width) input therefore yields a tensor of
	shape (1, rows, cols, 3, patch_size, patch_size).
	"""
	tiles = img.data
	# (dimension, window size); the step always equals the window size
	for dim, size in ((0, 3), (1, patch_size), (2, patch_size)):
		tiles = tiles.unfold(dim, size, size)
	return tiles

def get_initial_score(img, prompt):
	"""Get initial whole-image relevance score for the prompt.

	The full image and `prompt` are passed through the notebook-level
	`processor` and `model`; the single `logits_per_image` value is
	returned as a Python float.
	"""
	inputs = processor(
		images=img,
		return_tensors="pt",
		text=prompt,
		padding=True
	).to(device)

	# Inference only: no autograd graph is needed for a score read via .item().
	with torch.no_grad():
		return model(**inputs).logits_per_image.item()

def get_scores(img_patches, prompt, patch_size, window=6, stride=1):
	"""Return the averaged sliding-window score of every patch for `prompt`.

	A `window` x `window` block of patches is moved over the grid in steps
	of `stride`. Every block is stitched into one image and passed, with
	`prompt`, to the notebook-level `processor` and `model`; the
	`logits_per_image` value is added to each patch under the block. The
	sums are finally divided by the per-patch counter.

	Args:
		img_patches: patch grid as returned by `get_patches`.
		prompt: text to compare each window against.
		patch_size: patch edge length in pixels.
		window: window edge length, in patches.
		stride: window step, in patches.

	Returns:
		A (rows, cols) tensor of averaged (not normalized) scores.
	"""
	n_rows, n_cols = img_patches.shape[1], img_patches.shape[2]
	scores = torch.zeros(n_rows, n_cols)
	# NOTE(review): the counter starts at one, so the division below uses
	# (number of visits + 1) - confirm this is intended.
	runs = torch.ones(n_rows, n_cols)
	side = patch_size * window

	# Inference only: do not record an autograd graph for each forward pass.
	with torch.no_grad():
		for Y in range(0, n_rows - window + 1, stride):
			for X in range(0, n_cols - window + 1, stride):
				# (win_row, win_col, channel, h, w) block of patches
				patch_batch = img_patches[0, Y:Y + window, X:X + window]
				# Reorder to (win_row, h, win_col, w, channel) and merge the
				# row pair and the column pair: this stitches the block into
				# one (side, side, 3) image without copying patch by patch.
				big_patch = patch_batch.permute(0, 3, 1, 4, 2).reshape(side, side, 3)

				inputs = processor(
					images=big_patch,
					return_tensors="pt",
					text=prompt,
					padding=True
				).to(device)

				score = model(**inputs).logits_per_image.item()
				scores[Y:Y + window, X:X + window] += score
				runs[Y:Y + window, X:X + window] += 1

	scores /= runs
	return scores

def normalize_scores(scores):
	"""Denoise and rescale a score map.

	Steps: Gaussian smoothing (sigma 1.0), subtraction of the mean with
	negative values set to zero, and min-max scaling with a small constant
	added to the denominator. The result is returned as a torch tensor.
	"""
	from scipy.ndimage import gaussian_filter

	values = gaussian_filter(scores.numpy(), sigma=1.0)
	# values at or below the mean become zero
	values = np.clip(values - values.mean(), 0, np.inf)
	# scale into the unit range
	span = values.max() - values.min() + 1e-8
	values = (values - values.min()) / span
	return torch.from_numpy(values)

def get_box(scores, patch_size=256, threshold=0.5, min_area=4):
	"""Locate the box enclosing all patches that score above `threshold`.

	Returns `((x_min, y_min, width, height), confidence)`: the box is in
	pixels and the confidence is the mean score inside the box. When fewer
	than `min_area` patches exceed the threshold (a patch count, not the
	box area), or none do, `(None, 0.0)` is returned.
	"""
	scores_np = scores.numpy()
	mask = scores_np > threshold

	if np.sum(mask) < min_area or not np.any(mask):
		return None, 0.0

	# one (row, col) pair per detected patch
	coords = np.argwhere(mask)
	top, left = (int(v) for v in coords.min(axis=0))
	bottom, right = (int(v) + 1 for v in coords.max(axis=0))

	# mean score over the enclosing box, in grid coordinates
	confidence = float(scores_np[top:bottom, left:right].mean())

	# grid coordinates -> pixel coordinates
	x_min, y_min = left * patch_size, top * patch_size
	x_max, y_max = right * patch_size, bottom * patch_size

	return (x_min, y_min, x_max - x_min, y_max - y_min), confidence

def detect(prompts, img, patch_size=256, window=6, stride=1, threshold=0.5, relevance_threshold=0.2):
    """Localize every relevant prompt in ``img`` and draw its bounding box.

    Each prompt is first scored against the whole image. Prompts whose score
    exceeds ``relevance_threshold`` are then localized, highest score first,
    and each resulting box is drawn on a matplotlib figure that is saved to
    ``output_detection.png``.

    Args:
        prompts: Iterable of text labels to look for.
        img: Image tensor; it is displayed with axis 0 moved last.
        patch_size: Side length of one patch, passed on to the helpers.
        window: Sliding-window size in patches, passed to ``get_scores``.
        stride: Sliding-window step in patches, passed to ``get_scores``.
        threshold: Score cut-off handed to ``get_box``.
        relevance_threshold: Whole-image score a prompt must exceed.

    Returns:
        A list of ``(prompt, confidence)`` tuples for every prompt that
        produced a box; an empty list when no prompt is relevant.
    """
    img_patches = get_patches(img, patch_size)
    # Move axis 0 last so matplotlib can display the array.
    image = np.moveaxis(img.data.numpy(), 0, -1)

    # Pass 1: keep (relevance, prompt) pairs that clear the relevance bar.
    scored_prompts = []
    print("Checking image-level relevance for each label...")
    for prompt in tqdm(prompts):
        relevance = get_initial_score(img, prompt)
        if relevance > relevance_threshold:
            scored_prompts.append((relevance, prompt))

    if not scored_prompts:
        print("No relevant objects detected in the image.")
        return []

    # Most relevant prompt first.
    scored_prompts.sort(reverse=True)

    _fig, ax = plt.subplots(figsize=(12, 8))
    ax.imshow(image)

    detected_objects = []

    # Pass 2: localize each surviving prompt and draw its box.
    print("\nDetecting and localizing relevant objects...")
    for i, (_relevance, prompt) in enumerate(tqdm(scored_prompts)):
        scores = normalize_scores(
            get_scores(img_patches, prompt, patch_size, window, stride)
        )
        box, confidence = get_box(scores, patch_size, threshold)
        if not box:
            continue

        x, y, width, height = box
        # The colour index follows the prompt's rank, wrapping around the palette.
        outline = mpl_patches.Rectangle(
            (x, y),
            width,
            height,
            linewidth=2,
            alpha=0.8,
            edgecolor=colors[i % len(colors)],
            facecolor='none',
            label=f"{prompt} ({confidence:.2f})"
        )
        ax.add_patch(outline)
        detected_objects.append((prompt, confidence))

    if detected_objects:
        ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')

    ax.axis('off')
    plt.tight_layout()
    plt.savefig("output_detection.png")

    # Summary of both passes.
    print("\nDetection Summary:")
    print(f"Total labels checked: {len(prompts)}")
    print(f"Relevant labels found: {len(scored_prompts)}")
    print(f"Objects localized: {len(detected_objects)}")

    return detected_objects

# Run detection over the full candidate label set.
labels = [
    "dog", "ball", "cat", "butterfly", "car",
    "apple", "banana", "orange", "bird", "flower",
]
detected = detect(
    labels,
    img,
    window=4,
    stride=1,
    # Patch scores must exceed this to become part of a bounding box.
    threshold=0.86,
    # Whole-image relevance a label must exceed before it is localized.
    relevance_threshold=0.95,
)

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as mpl_patches
from tqdm import tqdm
from scipy.ndimage import gaussian_filter

# Edge colours for bounding boxes; detect() picks one with a modulo index.
colors = [
    '#FF00FF',  # magenta
    '#00FF00',  # green
    '#FAFF00',  # yellow
    '#8CF1FF',  # light cyan
    '#FF0000',  # red
    '#0000FF',  # blue
    '#000000',  # black
    '#FFFFFF',  # white
    '#808080',  # gray
    '#800000',  # maroon
]

def get_patches(img, patch_size=256):
    """Cut an image tensor into a grid of non-overlapping square patches.

    Axis 0 is unfolded in groups of 3, then axes 1 and 2 are unfolded with
    window and step both equal to ``patch_size``. Rows and columns that do
    not fill a complete patch are dropped by ``unfold``.

    Args:
        img: Image tensor; presumably channel-first ``(3, H, W)`` -- confirm
            against the caller.
        patch_size: Side length of each square patch.

    Returns:
        A tensor where, for a ``(3, H, W)`` input, ``result[0, r, c]`` is the
        ``(3, patch_size, patch_size)`` patch at grid row ``r``, column ``c``.
    """
    return (
        img.data
        .unfold(0, 3, 3)
        .unfold(1, patch_size, patch_size)
        .unfold(2, patch_size, patch_size)
    )

def get_initial_score(img, prompt):
    """Score how relevant ``prompt`` is to the whole image.

    The image is scored against ``prompt`` and against the negative prompt
    ``"an image without a {prompt}"``; the result is the positive logit's
    share of the two.

    Args:
        img: Image passed unchanged to the module-level ``processor``.
        prompt: Text label to score.

    Returns:
        ``full_score / (full_score + neg_score + 1e-6)`` as a Python float.
        NOTE(review): this only reads as a 0-1 confidence when both logits
        are positive -- confirm for the model in use.
    """
    negative_prompt = f"an image without a {prompt}"

    # Inference only: without no_grad every forward pass would record an
    # autograd graph that is thrown away as soon as .item() is taken.
    with torch.no_grad():
        # Positive prompt against the whole image.
        inputs_full = processor(
            images=img,
            return_tensors="pt",
            text=prompt,
            padding=True
        ).to(device)
        full_score = model(**inputs_full).logits_per_image.item()

        # Negative prompt establishes the baseline.
        inputs_neg = processor(
            images=img,
            return_tensors="pt",
            text=negative_prompt,
            padding=True
        ).to(device)
        neg_score = model(**inputs_neg).logits_per_image.item()

    # The epsilon guards against an exactly-zero denominator.
    relative_confidence = full_score / (full_score + neg_score + 1e-6)

    return relative_confidence

def get_scores(img_patches, prompt, patch_size, window=6, stride=1):
    """Build a per-patch score map for ``prompt`` with a sliding window.

    For every ``window`` x ``window`` group of patches, the patches are
    stitched into one image and scored against ``prompt`` and against
    ``"an image without a {prompt}"``. The positive logit's share of the two
    is added to every grid cell the window covers, and the sums are finally
    divided by the per-cell counter.

    Args:
        img_patches: Patch tensor indexed as ``[0, row, col]`` giving one
            ``(3, patch_size, patch_size)`` patch, as built by ``get_patches``.
        prompt: Text label to score.
        patch_size: Side length of one patch in pixels.
        window: Window side length, in patches.
        stride: Window step, in patches.

    Returns:
        A ``(rows, cols)`` float tensor of averaged relative scores. When the
        grid is smaller than ``window`` no window fits and the map is all
        zeros.
    """
    n_rows, n_cols = img_patches.shape[1], img_patches.shape[2]
    scores = torch.zeros(n_rows, n_cols)
    # The counter starts at one, so cells no window reaches divide by 1
    # rather than 0; every divisor is therefore one more than the number of
    # windows that covered the cell.
    runs = torch.ones(n_rows, n_cols)

    # Negative counterpart used as the baseline for each window.
    negative_prompt = f"an image without a {prompt}"

    # Inference only: without no_grad each of the two forward passes per
    # window would record an autograd graph that is never used.
    with torch.no_grad():
        for Y in range(0, n_rows - window + 1, stride):
            for X in range(0, n_cols - window + 1, stride):
                # Stitch the window's patches into one height-width-channel image.
                big_patch = torch.zeros(patch_size * window, patch_size * window, 3)
                patch_batch = img_patches[0, Y:Y + window, X:X + window]

                for y in range(window):
                    for x in range(window):
                        big_patch[y * patch_size:(y + 1) * patch_size,
                                  x * patch_size:(x + 1) * patch_size, :] = patch_batch[y, x].permute(1, 2, 0)

                # Positive prompt.
                inputs_pos = processor(
                    images=big_patch,
                    return_tensors="pt",
                    text=prompt,
                    padding=True
                ).to(device)
                pos_score = model(**inputs_pos).logits_per_image.item()

                # Negative prompt.
                inputs_neg = processor(
                    images=big_patch,
                    return_tensors="pt",
                    text=negative_prompt,
                    padding=True
                ).to(device)
                neg_score = model(**inputs_neg).logits_per_image.item()

                # Positive logit's share of the pair; epsilon avoids a zero denominator.
                relative_score = pos_score / (pos_score + neg_score + 1e-6)

                scores[Y:Y + window, X:X + window] += relative_score
                runs[Y:Y + window, X:X + window] += 1

    scores /= runs
    return scores

def verify_detection(img, box, prompt, threshold):
    """Re-score the boxed region on its own to confirm a detection.

    The region is cropped from ``img`` and scored against ``prompt`` and
    against ``"an image without a {prompt}"``; the positive logit's share of
    the two is compared with ``threshold``.

    Args:
        img: Image indexed as ``img[:, rows, cols]``.
        box: ``(x, y, width, height)`` in pixels, or ``None``.
        prompt: Text label to verify.
        threshold: Value the final confidence must strictly exceed.

    Returns:
        ``(is_valid, final_confidence)``; ``(False, 0.0)`` when ``box`` is
        ``None``.
    """
    if box is None:
        return False, 0.0

    x, y, width, height = box
    region = img[:, y:y + height, x:x + width]
    negative_prompt = f"an image without a {prompt}"

    # Inference only: without no_grad each forward pass would record an
    # autograd graph that is discarded once .item() is taken.
    with torch.no_grad():
        # Positive prompt on the cropped region.
        inputs_pos = processor(
            images=region,
            return_tensors="pt",
            text=prompt,
            padding=True
        ).to(device)
        pos_score = model(**inputs_pos).logits_per_image.item()

        # Negative prompt on the same region.
        inputs_neg = processor(
            images=region,
            return_tensors="pt",
            text=negative_prompt,
            padding=True
        ).to(device)
        neg_score = model(**inputs_neg).logits_per_image.item()

    # Positive logit's share of the pair; epsilon avoids a zero denominator.
    final_confidence = pos_score / (pos_score + neg_score + 1e-6)

    return final_confidence > threshold, final_confidence

def detect(prompts, img, patch_size=256, window=6, stride=1, threshold=0.86, relevance_threshold=0.95):
    """Detect, localize and verify each relevant prompt in ``img``.

    Each prompt is first scored against the whole image. Prompts whose score
    exceeds ``relevance_threshold`` are processed highest score first: a
    Gaussian-smoothed patch score map yields a candidate box, and the box is
    drawn only if ``verify_detection`` accepts it.

    Args:
        prompts: Iterable of text labels to look for.
        img: Image tensor; it is displayed with axis 0 moved last.
        patch_size: Side length of one patch, passed on to the helpers.
        window: Sliding-window size in patches, passed to ``get_scores``.
        stride: Sliding-window step in patches, passed to ``get_scores``.
        threshold: Cut-off used both by ``get_box`` and ``verify_detection``.
        relevance_threshold: Whole-image score a prompt must exceed.

    Returns:
        A list of ``(prompt, final_confidence)`` tuples for verified
        detections; an empty list when no prompt is relevant.
    """
    img_patches = get_patches(img, patch_size)
    # Move axis 0 last so matplotlib can display the array.
    image = np.moveaxis(img.data.numpy(), 0, -1)

    # Pass 1: keep (relevance, prompt) pairs that clear the relevance bar.
    candidates = []
    print("Checking image-level relevance for each label...")
    for prompt in tqdm(prompts):
        relevance = get_initial_score(img, prompt)
        if relevance > relevance_threshold:
            candidates.append((relevance, prompt))

    if not candidates:
        print("No relevant objects detected in the image.")
        return []

    # Most relevant prompt first.
    candidates.sort(reverse=True)

    _fig, ax = plt.subplots(figsize=(12, 8))
    ax.imshow(image)

    detected_objects = []

    # Pass 2: localize, then verify, each surviving prompt.
    print("\nDetecting and verifying objects...")
    for _relevance, prompt in tqdm(candidates):
        raw_scores = get_scores(img_patches, prompt, patch_size, window, stride)
        smoothed = torch.from_numpy(gaussian_filter(raw_scores.numpy(), sigma=1.0))

        # Candidate box from the smoothed map.
        box, _box_confidence = get_box(smoothed, patch_size, threshold)
        if box is None:
            continue

        # Re-score the boxed region before accepting it.
        is_valid, final_confidence = verify_detection(img, box, prompt, threshold)
        if not is_valid:
            continue

        x, y, width, height = box
        # The colour index advances only with accepted detections.
        outline = mpl_patches.Rectangle(
            (x, y),
            width,
            height,
            linewidth=2,
            alpha=0.8,
            edgecolor=colors[len(detected_objects) % len(colors)],
            facecolor='none',
            label=f"{prompt} ({final_confidence:.2f})"
        )
        ax.add_patch(outline)
        detected_objects.append((prompt, final_confidence))

    if detected_objects:
        ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')

    ax.axis('off')
    plt.tight_layout()

    # Summary of what was found.
    print("\nDetection Summary:")
    print(f"Total labels checked: {len(prompts)}")
    print(f"Objects detected and verified: {len(detected_objects)}")
    for obj, conf in detected_objects:
        print(f"- {obj}: {conf:.2f}")

    return detected_objects

In [None]:
# Run detection with verification over the full candidate label set.
labels = [
    "dog", "ball", "cat", "butterfly", "car",
    "apple", "banana", "orange", "bird", "flower",
]
detected = detect(
    labels,
    img,
    window=4,
    stride=1,
    # Cut-off for the patch score map; detect() reuses it when verifying a box.
    threshold=0.46,
    # Whole-image relevance a label must exceed before it is localized.
    relevance_threshold=0.45,
)

---