In [2]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '1'
from ClipAdapter import DirDataset
from ClipAdapter import ClipAdapter
import clip
import torch
from tqdm import tqdm
from PIL import Image
from torch.utils.data import DataLoader
from sklearn import metrics
from torchvision import transforms
# os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
device = "cuda:0" if torch.cuda.is_available() else "cpu"
clip_model, preprocess = clip.load("ViT-B/32", device=device)

In [13]:
#zero shot full dataset
root=r'/CV/gaobiaoli/dataset/rebar_tying'
class_dir_map={
    "a photo of a worker squatting or bending to tie steel bars":"12",
    "a photo of a worker doing non-rebar work or taking a break":"3",
}

dataset_val=DirDataset(root=root,class_dir_map=class_dir_map,transform=preprocess)
dataloader_val=DataLoader(dataset=dataset_val,batch_size=16,num_workers=4,shuffle=True)
all_targets = []
all_predictions = []
with torch.no_grad():
    text = clip.tokenize(list(class_dir_map.keys()))
    text=text.to(device)
    text_features = clip_model.encode_text(text)
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)

    for i,(imgs,label,_) in enumerate(tqdm(dataloader_val)):
        imgs= imgs.to(device)
        image_features = clip_model.encode_image(imgs)
        image_features_norm = image_features / image_features.norm(dim=-1, keepdim=True)
        logits_per_image = 100. * image_features_norm @ text_features.t()

        probs = logits_per_image.softmax(dim=-1).cpu()
        pred_label=(probs.argmax(dim=1))

        all_targets.extend(label.cpu().numpy())
        all_predictions.extend(pred_label.cpu().numpy())

accuracy = metrics.accuracy_score(all_targets, all_predictions)
precision = metrics.precision_score(all_targets, all_predictions, average=None)
recall = metrics.recall_score(all_targets, all_predictions,average=None)
f1 = metrics.f1_score(all_targets, all_predictions,average=None)
print("\n**** Zero-shot CLIP's val accuracy: {:.2f}. ****\n".format(accuracy*100))
print(precision)
print(recall)
print(f1)

100%|██████████| 70/70 [00:01<00:00, 44.62it/s]


**** Zero-shot CLIP's val accuracy: 45.04. ****

[0.48135593 0.33333333]
[0.73195876 0.1452514 ]
[0.5807771  0.20233463]





In [17]:
len(dataset_val.imgs_list)
len(os.listdir("/CV/gaobiaoli/dataset/rebar_tying/3"))

537

In [18]:
# clip_model, preprocess = clip.load("ViT-L/14", device=device)
clip_model, preprocess = clip.load("ViT-B/16", device=device)
# clip_model, preprocess = clip.load("ViT-B/32", device=device)
class_dir_map={
    "a photo of a worker tying steel bars":"12",
    "a photo of a worker doing non-rebar work":"3",
}
if True:
    seed = 1
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    train_tranform = transforms.Compose([
            transforms.RandomResizedCrop(size=224, scale=(0.8, 1), interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.RandomHorizontalFlip(p=0.5),
            # transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711))
        ])
    root=r'/CV/gaobiaoli/dataset/rebar_tying'
    seed=1000
    dataset_shot=DirDataset(root=root,class_dir_map=class_dir_map,transform=train_tranform,few_shot=16,random_seed=seed)
    dataloader_shot=DataLoader(dataset=dataset_shot,batch_size=16,num_workers=4,shuffle=False)
    clip_adapter=ClipAdapter(model=clip_model,dataloader=dataloader_shot,classnames=dataset_shot.classnames,augment_epoch=10,alpha=10,beta=1,device=device)
    dataset_test=DirDataset(root=root,class_dir_map=class_dir_map,transform=train_tranform,few_shot=16,random_seed=seed,reverse=True)
    dataloader_test=DataLoader(dataset=dataset_test,batch_size=16,num_workers=4,shuffle=True)
    clip_adapter.pre_load_features(dataloader_test)
    clip_adapter.search_hp(beta_search=False)
    all_predictions, all_targets,(accuracy0,precision,recall,f1)=clip_adapter.eval(adapt=False)
    print("\n**** Zero-shot CLIP's val accuracy: {:.2f}. ****\n".format(accuracy0*100))
    print(precision)
    print(recall)
    print(f1)
    all_predictions, all_targets,(accuracy1,precision,recall,f1)=clip_adapter.eval(adapt=True)
    print("\n**** Few-shot CLIP's val accuracy: {:.2f}. ****\n".format(accuracy1*100))
    print(precision)
    print(recall)
    print(f1)

100%|██████████| 10/10 [00:03<00:00,  2.80it/s]
100%|██████████| 68/68 [00:08<00:00,  7.84it/s]

New best HP, beta: 1.00, alpha: 7.58; accuracy: 80.86

**** Zero-shot CLIP's val accuracy: 37.44. ****

[0.321875   0.39634941]
[0.1819788  0.58349328]
[0.23250564 0.47204969]

**** Few-shot CLIP's val accuracy: 80.86. ****

[0.78870968 0.83511777]
[0.8639576  0.74856046]
[0.82462057 0.78947368]



