In [1]:
import torch
import clip
from PIL import Image
import os
import matplotlib.pyplot as plt  # Import the matplotlib library for image visualization
import json

# Pick the compute device: use the GPU when PyTorch reports CUDA as available,
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [2]:
# Load the CLIP model and preprocessing pipeline.
# `clip.load` returns two values, unpacked here as the model and the image
# preprocessing callable; `device` selects where the model is placed.
# NOTE(review): "ViT-B/32" presumably names the pretrained checkpoint variant — confirm
# against the clip package documentation.
model, preprocess = clip.load("ViT-B/32", device=device)

In [3]:
# Candidate object labels; each one is turned into a text prompt and compared
# against the image.
objects = [
    "airplane",
    "boat",
    "car",
    "bike",
]

In [4]:
# Relative path to the annotation JSON file inside the rsicd/ directory
json_path = "rsicd/annotations.json"


In [21]:
# Parse the annotation file and keep the value stored under its 'images' key
# (later cells index into it by position and read each entry's 'filename').
# The `with` block guarantees the file is closed even if JSON parsing raises.
with open(json_path) as f:
    data = json.load(f)
input_data = data['images']

In [16]:
# Position of the record in `input_data` to classify
index_ = 0

# The selected annotation entry; only its 'filename' field is used in this cell
image_json = input_data[index_]

# Full path to the image file, built from the entry's 'filename'
image_path = os.path.join("rsicd/images/", image_json['filename'])

# Open the image in a context manager so the underlying file handle is released
# as soon as preprocessing has produced the tensor. `unsqueeze(0)` adds a leading
# batch dimension, and the result is moved to the selected device (CPU or GPU).
with Image.open(image_path) as pil_image:
    image = preprocess(pil_image).unsqueeze(0).to(device)

# Build one "a photo of a <label>" prompt per candidate object, tokenize each,
# concatenate the token tensors, and move them to the same device as the image
text = torch.cat([clip.tokenize(f"a photo of a {c}") for c in objects]).to(device)

In [17]:
# Inference only: gradient tracking is disabled for every call in this cell
with torch.no_grad():
    # Separate image and text embeddings (normalized and compared in later cells)
    image_features = model.encode_image(image)
    text_features = model.encode_text(text)

    # Joint forward pass; the two returned values are unpacked as the
    # image-vs-text logits and the text-vs-image logits
    logits_per_image, logits_per_text = model(image, text)

    # Softmax across the last dimension, then copy to host memory as a NumPy array
    probs = torch.softmax(logits_per_image, dim=-1).cpu().numpy()

In [18]:
# Rescale each embedding by its norm along the last dimension, in place,
# so that the matrix product in the next cell compares directions only.
image_norm = image_features.norm(dim=-1, keepdim=True)
image_features.div_(image_norm)

text_norm = text_features.norm(dim=-1, keepdim=True)
text_features.div_(text_norm)

In [19]:
# Convert the scaled image-text similarities into a probability distribution
# over the prompts (softmax along the last dimension).
# NOTE(review): the 100.0 factor presumably mirrors CLIP's logit scale — confirm.
similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)

# Rank the prompts for the first row of `similarity`. k is capped at the number
# of prompts so the cell still runs when fewer than four labels are supplied
# (a fixed topk(4) raises when the dimension has fewer than four entries).
top_k = min(4, similarity.shape[-1])
values, indices = similarity[0].topk(top_k)

In [20]:
# Report every ranked label with its probability expressed as a percentage
print("\nTop predictions:\n")
for value, index in zip(values, indices):
    label = objects[index]          # look up the label name for this rank
    percent = 100 * value.item()    # probability -> percentage
    print(f"{label:>16s}: {percent:.2f}%")


Top predictions:

        airplane: 99.61%
             car: 0.32%
            bike: 0.04%
            boat: 0.02%
