In [1]:
import os
from ClipAdapter import DirDataset
from ClipAdapter import ClipAdapter
import clip
import torch
from tqdm import tqdm
from PIL import Image
from torch.utils.data import DataLoader
from sklearn import metrics
from torchvision import transforms
from utils.utils import inin_random
model_name="ViT-L/14"
device = "cuda:0" if torch.cuda.is_available() else "cpu"

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
#zero shot full dataset
clip_model, preprocess = clip.load(model_name, device=device)
root='dataset/rebar_tying'
class_dir_map={
    "a photo of a worker squatting or bending to tie steel bars":"12",
    "a photo of a worker doing non-rebar work or taking a break":"3",
}

dataset_val=DirDataset(root=root,class_dir_map=class_dir_map,transform=preprocess)
dataloader_val=DataLoader(dataset=dataset_val,batch_size=16,num_workers=4,shuffle=True)
all_targets = []
all_predictions = []
with torch.no_grad():
    text = clip.tokenize(list(class_dir_map.keys()))
    text=text.to(device)
    text_features = clip_model.encode_text(text)
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)

    for i,(imgs,label,_) in enumerate(tqdm(dataloader_val)):
        imgs= imgs.to(device)
        image_features = clip_model.encode_image(imgs)
        image_features_norm = image_features / image_features.norm(dim=-1, keepdim=True)
        logits_per_image = 100. * image_features_norm @ text_features.t()

        probs = logits_per_image.softmax(dim=-1).cpu()
        pred_label=(probs.argmax(dim=1))

        all_targets.extend(label.cpu().numpy())
        all_predictions.extend(pred_label.cpu().numpy())

accuracy = metrics.accuracy_score(all_targets, all_predictions)
precision = metrics.precision_score(all_targets, all_predictions, average=None)
recall = metrics.recall_score(all_targets, all_predictions,average=None)
f1 = metrics.f1_score(all_targets, all_predictions,average=None)
print("\n**** Zero-shot CLIP's val accuracy: {:.2f}. ****\n".format(accuracy*100))
print(precision)
print(recall)
print(f1)

100%|██████████| 70/70 [00:08<00:00,  8.61it/s]


**** Zero-shot CLIP's val accuracy: 48.44. ****

[0.5030525  0.43333333]
[0.70790378 0.24208566]
[0.58815132 0.31063321]





In [None]:
from torchvision.models import resnet50
import torch
import clip
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm import tqdm

import torch.optim as optim
from torchvision.models import resnet50
model=resnet50(num_classes=13)

In [None]:
def train_model(model, train_loader, loss_fn, optimizer, epoch):
    model.train()
    total_loss = 0.
    total_corrects = 0.
    total = 0.
    for idx, (inputs, labels) in enumerate(tqdm(train_loader)):
        inputs = inputs.to(device)
        labels = labels.to(device)
        outputs = model(inputs)
        loss = loss_fn(outputs, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        preds = outputs.argmax(dim=1)
        total_corrects += torch.sum(preds.eq(labels))
        total_loss += loss.item() * inputs.size(0)
        total += labels.size(0)
    total_loss = total_loss / total
    acc = 100 * total_corrects / total
    print("轮次:%4d|训练集损失:%.5f|训练集准确率:%6.2f%%" % (epoch + 1, total_loss, acc))
    return total_loss, acc
 
 
def test_model(model, test_loader, loss_fn, optimizer, epoch):
    model.eval()
    total_loss = 0.
    total_corrects = 0.
    total = 0.
    with torch.no_grad():
        for idx, (inputs, labels) in enumerate(tqdm(test_loader)):
            inputs = inputs.to(device)
            labels = labels.to(device)
            outputs = model(inputs)
            loss = loss_fn(outputs, labels)
            preds = outputs.argmax(dim=1)
            total += labels.size(0)
            total_loss += loss.item() * inputs.size(0)
            total_corrects += torch.sum(preds.eq(labels))
 
        loss = total_loss / total
        accuracy = 100 * total_corrects / total
        print("轮次:%4d|训练集损失:%.5f|训练集准确率:%6.2f%%" % (epoch + 1, loss, accuracy))
        return loss, accuracy
 
 
loss_fn = nn.CrossEntropyLoss().to(device)
optimizer = optim.Adam(model.parameters(), lr=0.0001)
epoches=10
for epoch in range(0, epoches):
    loss1, acc1 = train_model(model, dataloader_train, loss_fn, optimizer, epoch)
    loss2, acc2 = test_model(model, dataloader_val, loss_fn, optimizer, epoch)
model.train()

In [None]:
clip_model, preprocess = clip.load(model_name, device=device)
class_dir_map={
    "a photo of a worker tying steel bars":"12",
    "a photo of a worker doing non-rebar work":"3",
}
if True:
    seed = 1
    inin_random(seed)
    train_tranform = transforms.Compose([
            transforms.RandomResizedCrop(size=224, scale=(0.8, 1), interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711))
        ])
    # root='dataset/rebar_tying'
    root=r'/CV/gaobiaoli/dataset/rebar_tying'
    seed=1000
    dataset_shot=DirDataset(root=root,class_dir_map=class_dir_map,transform=train_tranform,few_shot=16,random_seed=seed)
    dataloader_shot=DataLoader(dataset=dataset_shot,batch_size=16,num_workers=4,shuffle=False)
    clip_adapter=ClipAdapter(model=clip_model,dataloader=dataloader_shot,classnames=dataset_shot.classnames,augment_epoch=10,alpha=10,beta=1,device=device)
    dataset_test=DirDataset(root=root,class_dir_map=class_dir_map,transform=train_tranform,few_shot=16,random_seed=seed,reverse=True)
    dataloader_test=DataLoader(dataset=dataset_test,batch_size=16,num_workers=4,shuffle=True)
    clip_adapter.pre_load_features(dataloader_test)
    clip_adapter.search_hp(beta_search=False)
    all_predictions, all_targets,(accuracy0,precision,recall,f1)=clip_adapter.eval(adapt=False)
    print("\n**** Zero-shot CLIP's val accuracy: {:.2f}. ****\n".format(accuracy0*100))
    print(precision)
    print(recall)
    print(f1)
    all_predictions, all_targets,(accuracy1,precision,recall,f1)=clip_adapter.eval(adapt=True)
    print("\n**** Few-shot CLIP's val accuracy: {:.2f}. ****\n".format(accuracy1*100))
    print(precision)
    print(recall)
    print(f1)

    

100%|██████████| 10/10 [00:03<00:00,  2.65it/s]
100%|██████████| 68/68 [00:06<00:00, 11.08it/s]


New best HP, beta: 1.00, alpha: 5.09; accuracy: 82.89

**** Zero-shot CLIP's val accuracy: 53.54. ****

[0.68944099 0.50863931]
[0.19611307 0.90403071]
[0.30536451 0.65100207]

**** Few-shot CLIP's val accuracy: 82.89. ****

[0.86964981 0.79232112]
[0.78975265 0.87140115]
[0.82777778 0.82998172]
