In [20]:
import os
from PIL import Image
import numpy as np
import clip
from loguru import logger
from torch.utils.data import Dataset, DataLoader, ConcatDataset
import torch.optim as optim
from torch.optim import lr_scheduler
import torch.nn as nn
import torch
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


In [31]:
class YourDataset(Dataset):
    def __init__(self,img_root,meta_root,is_train,preprocess):
        # 1.根目录(根据自己的情况更改)
        self.img_root = img_root
        self.meta_root = meta_root
        # 2.训练图片和测试图片地址(根据自己的情况更改)
        self.train_set_file = os.path.join(meta_root,'train.txt')
        self.test_set_file = os.path.join(meta_root,'test.txt')
        # 3.训练 or 测试(根据自己的情况更改)
        self.is_train = is_train
        # 4.处理图像
        self.img_process = preprocess
        # 5.获得数据(根据自己的情况更改)
        self.samples = []
        self.sam_labels = []
        # 5.1 训练还是测试数据集
        self.read_file = ""
        if is_train:
            self.read_file = self.train_set_file
        else:
            self.read_file = self.test_set_file
		# 5.2 获得所有的样本(根据自己的情况更改)
        idx = 0
        with open(self.read_file,'r') as f:

            for line in f:
                
                img_path = os.path.join(self.img_root,line.strip() + '.jpeg')
                
                label = line.strip().split('/')[0]
                label = label.replace("_"," ")
                label = "photo if " + label
                self.samples.append(img_path)
                self.sam_labels.append(label)
                print("idx: ", idx, img_path, " ", label)
        # 转换为token
        self.tokens = clip.tokenize(self.sam_labels)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        img_path = self.samples[idx]
        token = self.tokens[idx]
        # 加载图像
        image = Image.open(img_path).convert('RGB')
        # 对图像进行转换
        image = self.img_process(image)
        return image,token

In [None]:
from torch.utils.data import DataLoader 
from datasets import load_dataset 
from torchvision import transforms
from PIL import Image 
import torch
from PIL import Image 

class Flickr30kDataset(torch.utils.data.Dataset):
    def __init__(self):
        self.dataset = load_dataset("nlphuji/flockr30k", cache_dir="~/.huggingface_data")
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
        ])
        self.cap_per_image = 2

def __len__(self):
    return self.dataset.num_rows["test"] * self.cap_per_image

def __getitem__(self, idx):
    original_idx = idx // self.cap_per_image
    image = self.dataset["test"][original_idx]["image"].convert("RGB")
    image = self.transform(image)

    caption = self.dataset["test"][original_idx][idx % self.cap_per_image]
    return {"image": image, "caption": caption}



In [60]:
net, preprocess = clip.load("RN50",device=device,jit=False)



In [46]:
mydataset = YourDataset("./data/images", "./data", is_train=True, preprocess=preprocess)
print(mydataset)
clip_dataloader = DataLoader(mydataset, batch_size=4, shuffle=True, num_workers=4)
print(len(mydataset))
data_size = len(mydataset)


idx:  0 ./data/images/cat.jpeg   photo if cat
idx:  0 ./data/images/cock.jpeg   photo if cock
idx:  0 ./data/images/dog.jpeg   photo if dog
idx:  0 ./data/images/bus.jpeg   photo if bus
<__main__.YourDataset object at 0x73975094d910>
4


In [63]:
image_path="./data/images/bus.jpeg"
labels = ["photo if cat", "photo if cock", "photo if dog", "photo if bus"]

In [61]:
def evaluate(net, preprocess, image_path, labels, use_my=False):

    image = Image.open(image_path)
    print(type(image))
    image = preprocess(image)
    print(type(image), image.size(), image.device)
    image = image.unsqueeze(0)
    image.size()
    image = image.to(device)
    print(image.device)


    text = clip.tokenize(labels).to(device)
    if use_my == True:
        state_dict = torch.load("my_model_epoch_49.pth")
        net.load_state_dict(state_dict)
        net.eval()


    with torch.no_grad():
        image_features = net.encode_image(image)
        text_features = net.encode_text(text)
        logits_per_image, logits_per_text = net(image, text)
        probs = logits_per_image.softmax(dim=-1).cpu()
        print(probs)
        

In [64]:
evaluate(net, preprocess, image_path, labels, use_my=False)


<class 'PIL.JpegImagePlugin.JpegImageFile'>
<class 'torch.Tensor'> torch.Size([3, 224, 224]) cpu
cuda:0
tensor([[1.3232e-04, 1.6856e-04, 1.1498e-04, 9.9951e-01]], dtype=torch.float16)


In [47]:

optimizer = optim.Adam(net.parameters(), lr=1e-6,betas=(0.9,0.98),eps=1e-6,weight_decay=0.001)
scheduler = lr_scheduler.StepLR(
        optimizer, step_size=10, gamma=0.1)

# 创建损失函数
loss_img = nn.CrossEntropyLoss()
loss_txt = nn.CrossEntropyLoss()

In [48]:
phase = "train"
model_name = "my_model"
ckt_gap = 4
epoches = 50

In [49]:

for epoch in range(epoches):
    scheduler.step()
    total_loss = 0
    batch_num = 0
    # 使用混合精度，占用显存更小
    with torch.cuda.amp.autocast(enabled=True):
        for images,label_tokens in clip_dataloader:
            # 将图片和标签token转移到device设备
            images = images.to(device)
            label_tokens = label_tokens.to(device)
            batch_num += 1
            # 优化器梯度清零
            optimizer.zero_grad()
            with torch.set_grad_enabled(phase == "train"):
                logits_per_image, logits_per_text = net(images, label_tokens)
                ground_truth = torch.arange(len(images),dtype=torch.long,device=device)
                cur_loss = (loss_img(logits_per_image,ground_truth) + loss_txt(logits_per_text,ground_truth))/2
                total_loss += cur_loss
                if phase == "train":
                    cur_loss.backward()
                    if device == "cpu":
                        optimizer.step()
                    else:
                        optimizer.step()
                        clip.model.convert_weights(net) 
            if batch_num % 4 == 0:
                logger.info('{} epoch:{} loss:{}'.format(phase,epoch,cur_loss))
        epoch_loss = total_loss / data_size
        torch.save(net.state_dict(),f"{model_name}_epoch_{epoch}.pth")
        logger.info(f"weights_{epoch} saved")
        if epoch % ckt_gap == 0:
            checkpoint_path = f"{model_name}_ckt.pth"
            checkpoint = {
                'it': epoch,
                'network': net.state_dict(),
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict()}
            torch.save(checkpoint, checkpoint_path)
            logger.info(f"checkpoint_{epoch} saved")
        logger.info('{} Loss: {:.4f}'.format(
            phase, epoch_loss))


[32m2024-04-28 02:58:04.828[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m30[0m - [1mweights_0 saved[0m
[32m2024-04-28 02:58:06.087[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m39[0m - [1mcheckpoint_0 saved[0m
[32m2024-04-28 02:58:06.088[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m40[0m - [1mtrain Loss: 0.0033[0m
[32m2024-04-28 02:58:08.325[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m30[0m - [1mweights_1 saved[0m
[32m2024-04-28 02:58:08.326[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m40[0m - [1mtrain Loss: 0.5756[0m
[32m2024-04-28 02:58:10.471[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m30[0m - [1mweights_2 saved[0m
[32m2024-04-28 02:58:10.472[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m40[0m - [1mtrain Loss: 0.5282[0m
[32m2024-04-28 02:58:12.431[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[

### **test


In [None]:
evaluate(net, preprocess, image_path, labels, use_my=True)

<class 'PIL.JpegImagePlugin.JpegImageFile'>
<class 'torch.Tensor'> torch.Size([3, 224, 224]) cpu
cuda:0


In [52]:

with torch.no_grad():
    image_features = net.encode_image(image)
    text_features = net.encode_text(text)
    logits_per_image, logits_per_text = net(image, text)
    probs = logits_per_image.softmax(dim=-1).cpu()
    print(probs)

tensor([[2.9206e-06, 2.7299e-05, 2.3842e-07, 1.0000e+00]], dtype=torch.float16)
