In [1]:
import torch
import clip
from PIL import Image
import pandas as pd

## Clip Testing

In [2]:
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)

image = preprocess(Image.open("../data/images/CLIP.png")).unsqueeze(0).to(device)
text = clip.tokenize(["a diagram", "a dog", "a cat"]).to(device)

with torch.no_grad():
    image_features = model.encode_image(image)
    text_features = model.encode_text(text)
    
    logits_per_image, logits_per_text = model(image, text)
    probs = logits_per_image.softmax(dim=-1).cpu().numpy()

print("Label probs:", probs)  # prints: [[0.9927937  0.00421068 0.00299572]]

Label probs: [[0.9927   0.004185 0.002968]]


## Crowd Counting - Testing I 

In [48]:
df = pd.read_csv("../data/images/CrowdCountingKaggle/labels.csv")
df.id = df.id.map(lambda x: str(x) + ".jpg")
df = df.rename(columns={"count": "y_true"})

# create a sample with random id's to interact with CLIP model
df_sample = df.sample(20, random_state=42).sort_values(by="y_true")
display(df_sample)
image_id_list = df_sample.id.to_list()
y_true = list(map(lambda x: x+' people', df_sample["y_true"].astype(str).to_list()))
print(y_true)

Unnamed: 0,id,y_true
584,585.jpg,20
56,57.jpg,23
1289,1290.jpg,25
65,66.jpg,27
1292,1293.jpg,28
1118,1119.jpg,28
374,375.jpg,30
1860,1861.jpg,30
938,939.jpg,31
1273,1274.jpg,33


['20 people', '23 people', '25 people', '27 people', '28 people', '28 people', '30 people', '30 people', '31 people', '33 people', '33 people', '34 people', '34 people', '34 people', '36 people', '40 people', '42 people', '43 people', '44 people', '45 people']


In [49]:
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)

images_path = "../data/images/CrowdCountingKaggle/frames/" 
output = []

for index in range(len(image_id_list)):
    image = preprocess(Image.open(images_path + image_id_list[index])).unsqueeze(0).to(device)
    text = clip.tokenize(y_true).to(device)

    with torch.no_grad():
        image_features = model.encode_image(image)
        text_features = model.encode_text(text)
        
        logits_per_image, logits_per_text = model(image, text)
        probs = logits_per_image.softmax(dim=-1).cpu().numpy()

    # print("Label probs:", probs)

    # max_value and first_position_of_max
    max_value = max(probs.tolist()[0])
    max_index_position = probs.tolist()[0].index(max_value)

    # store the results in the output list
    output.append([image_id_list[index], 
                   max_value, 
                   max_index_position, 
                   y_true[max_index_position]])

output

[['585.jpg', 0.06732177734375, 6, '30 people'],
 ['57.jpg', 0.07696533203125, 0, '20 people'],
 ['1290.jpg', 0.0899658203125, 0, '20 people'],
 ['66.jpg', 0.06298828125, 0, '20 people'],
 ['1293.jpg', 0.084228515625, 0, '20 people'],
 ['1119.jpg', 0.0731201171875, 0, '20 people'],
 ['375.jpg', 0.08941650390625, 0, '20 people'],
 ['1861.jpg', 0.08880615234375, 0, '20 people'],
 ['939.jpg', 0.0711669921875, 2, '25 people'],
 ['1274.jpg', 0.09088134765625, 0, '20 people'],
 ['276.jpg', 0.06988525390625, 0, '20 people'],
 ['354.jpg', 0.0933837890625, 0, '20 people'],
 ['129.jpg', 0.0716552734375, 0, '20 people'],
 ['747.jpg', 0.082763671875, 0, '20 people'],
 ['1324.jpg', 0.08447265625, 0, '20 people'],
 ['1647.jpg', 0.07763671875, 0, '20 people'],
 ['906.jpg', 0.07305908203125, 0, '20 people'],
 ['1853.jpg', 0.07757568359375, 6, '30 people'],
 ['1732.jpg', 0.09222412109375, 0, '20 people'],
 ['1334.jpg', 0.08465576171875, 0, '20 people']]

In [None]:
df_results = pd.DataFrame(output, columns=["id", "probability", "max_index_position", "y_pred"])
df_sample.merge(df_results, on="id")

# remove duplicate values from y_true?

Unnamed: 0,id,y_true,probability,max_index_position,y_pred
0,585.jpg,20,0.067322,6,30 people
1,57.jpg,23,0.076965,0,20 people
2,1290.jpg,25,0.089966,0,20 people
3,66.jpg,27,0.062988,0,20 people
4,1293.jpg,28,0.084229,0,20 people
5,1119.jpg,28,0.07312,0,20 people
6,375.jpg,30,0.089417,0,20 people
7,1861.jpg,30,0.088806,0,20 people
8,939.jpg,31,0.071167,2,25 people
9,1274.jpg,33,0.090881,0,20 people
