In [15]:
import os
import clip
import torch
import PIL
import numpy as np

In [2]:
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("/hy-tmp/clip_model/ViT-B-32.pt", device)
model.cuda().eval()

CLIP(
  (visual): VisionTransformer(
    (conv1): Conv2d(3, 768, kernel_size=(32, 32), stride=(32, 32), bias=False)
    (ln_pre): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (transformer): Transformer(
      (resblocks): Sequential(
        (0): ResidualAttentionBlock(
          (attn): MultiheadAttention(
            (out_proj): _LinearWithBias(in_features=768, out_features=768, bias=True)
          )
          (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (mlp): Sequential(
            (c_fc): Linear(in_features=768, out_features=3072, bias=True)
            (gelu): QuickGELU()
            (c_proj): Linear(in_features=3072, out_features=768, bias=True)
          )
          (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        )
        (1): ResidualAttentionBlock(
          (attn): MultiheadAttention(
            (out_proj): _LinearWithBias(in_features=768, out_features=768, bias=True)
          )
          (ln_1): LayerNorm((7

In [20]:
from torch.utils.data import DataLoader

class ImageDataset(torch.utils.data.Dataset):
    
    def __init__(self, img_dir):
        self.imgs = []
        self.labels = []
        label_dict = {'Aircraft_Carrier': 0, 
                      'Amphibious_Assault_Ship': 1, 
                      'Fast_Combat_Support_Ships': 2,
                      'Guided_Missile_Cruiser': 3,
                      'Guided_Missile_Destroyer': 4}
        for label in os.listdir(img_dir):
            dir_path = os.path.join(img_dir, label)
            for img in os.listdir(dir_path):
                if img.endswith((".jpg", ".png")):
                    img_path = os.path.join(dir_path, img)
                    img_label = label
                    self.imgs.append(img_path)
                    self.labels.append(img_label)
        
    def __len__(self):
        return len(self.imgs)
    
    def __getitem__(self, index):
        img = self.imgs[index]
        label = self.labels[index]
        return img, label

In [26]:
test_ships = ImageDataset("/hy-tmp/5_types_ships_small/test")
test_loader = torch.utils.data.DataLoader(test_ships, batch_size=1, shuffle=False, drop_last=False, num_workers=32)

In [47]:
def get_key(dic, value):
    swapped_dict = {value: key for key, value in label_dict.items()}
    return swapped_dict[value.item()]

In [62]:
top_1_correct = 0
top_2_correct = 0
top_3_correct = 0
total = 0
with torch.no_grad():
    for data in test_loader:
        inputs, labels = data
        
        # images = [preprocess(PIL.Image.open(image)) for image in inputs]
        # images = torch.tensor(np.stack(images))
        image = inputs[0]
        label_dict = {'Aircraft_Carrier': 0, 
                      'Amphibious_Assault_Ship': 1, 
                      'Fast_Combat_Support_Ships': 2,
                      'Guided_Missile_Cruiser': 3,
                      'Guided_Missile_Destroyer': 4}
        image_input = preprocess(PIL.Image.open(image)).unsqueeze(0).to(device)
        text_inputs = torch.cat([clip.tokenize(f"a photo of a {label}") for label in label_dict.keys()]).to(device)

        # image_input = torch.tensor(images).half().to(device)
        # text_input = torch.cat([clip.tokenize(f"a photo of a {label}") for label in labels]).to(device)
        
        # print(image_input.size())
        # print(text_inputs.size())
        image_features = model.encode_image(image_input)
        text_features = model.encode_text(text_inputs)
        
        image_features /= image_features.norm(dim=-1, keepdim=True)
        text_features /= text_features.norm(dim=-1, keepdim=True)
        similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)
        values, indices = similarity[0].topk(5)
        # print(values, indices)
        
        # Print the result
        # print("\nTop predictions:\n")
        total += 1
        for value, index in zip(values, indices):
            # print(f"True label: {labels[0]}")
            # print(f"{get_key(label_dict, index):>16s}: {100 * value.item():.2f}%")
            pass
        for i, top in enumerate(list(indices)[:3]):
            if label_dict[labels[0]] == top.item():
                if i == 0:
                    top_1_correct += 1 
                    top_2_correct += 1 
                    top_3_correct += 1
                elif i == 1:
                    top_2_correct += 1
                    top_3_correct += 1
                elif i == 2:
                    top_3_correct += 1
                        
    print('Top-1 accuracy of the network on the %d test images: %.2f %%' % (len(test_loader.dataset), 100 * top_1_correct / total))
    print('Top-2 accuracy of the network on the %d test images: %.2f %%' % (len(test_loader.dataset), 100 * top_2_correct / total))
    print('Top-3 accuracy of the network on the %d test images: %.2f %%' % (len(test_loader.dataset), 100 * top_3_correct / total))

Top-1 accuracy of the network on the 449 test images: 42.54 %
Top-2 accuracy of the network on the 449 test images: 71.71 %
Top-3 accuracy of the network on the 449 test images: 86.41 %
