# 完整项目设置

In [4]:
from PIL import Image
import requests
from transformers import AutoProcessor, Blip2ForConditionalGeneration, BitsAndBytesConfig
import torch
import os
import re
from pycocoevalcap.bleu.bleu import Bleu
from pycocoevalcap.rouge.rouge import Rouge
from pycocoevalcap.cider.cider import Cider
from pycocoevalcap.meteor.meteor import Meteor
from pycocoevalcap.tokenizer.ptbtokenizer import PTBTokenizer
import matplotlib.pyplot as plt
from sentence_transformers import SentenceTransformer, util
import numpy as np
from tqdm import tqdm
import io
from peft import LoraConfig, get_peft_model
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import pandas as pd
import datasets
from sklearn.model_selection import train_test_split


  warn(f"Failed to load image Python extension: {e}")


## 导入模型

In [5]:
device = "cuda" if torch.cuda.is_available() else "cpu"

processor = AutoProcessor.from_pretrained("Salesforce/blip2-opt-2.7b")

model = Blip2ForConditionalGeneration.from_pretrained("ybelkada/blip2-opt-2.7b-fp16-sharded", device_map="auto", load_in_8bit=True)

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.
The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.


Loading checkpoint shards:   0%|          | 0/8 [00:00<?, ?it/s]

## 处理数据集

In [6]:
split_size = 0.2 # 用于减小训练数据集大小
dataset_name = 'sd' # 从这两个数据集中选择一个 sd/coco

In [7]:
if dataset_name == 'sd':
    train_df = pd.read_parquet('sd_train_dataset.parquet')
    test_df = pd.read_parquet('sd_best_test_over0_7.parquet')
elif dataset_name == 'coco':
    train_df = pd.read_parquet('coco_train_dataset.parquet')
    test_df = pd.read_parquet('coco_best_test_over0_8.parquet')

#修改图片路径
train_df['image_path'] = '../' + train_df['image_path']
test_df['image_path'] = '../' + test_df['image_path']

train_split, val_split = train_test_split(train_df, test_size=split_size, random_state=42)
print(f"Train dataset size: {len(train_split)}")
print(f"Validation dataset size: {len(val_split)}")

Train dataset size: 964
Validation dataset size: 242


In [8]:
# Convert parquet file to datasets format
train_dataset = datasets.Dataset.from_pandas(train_split)   # 之后通过划分训练和验证集减少训练数据集结果
test_dataset = datasets.Dataset.from_pandas(test_df)        # 用于最后的验证
val_dataset = datasets.Dataset.from_pandas(val_split)       # 无用，仅为了减少训练数据集大小

In [9]:
# 查看一个例子
train_dataset[0]

{'index': 2210,
 'Prompt_Index': 35438,
 'Prompt': 'Indians eating bananas and laughing , artwork by Craig Mullins,Movie poster, detailed, trending on artstation',
 'image_path': '../StableDiff_Dataset/generated_pics_sd_ext_prompts/35438/1.png',
 '__index_level_0__': 1136}

## 训练配置

In [10]:
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import os
import pandas as pd

class ImageCaptioningDataset(Dataset):
    def __init__(self, dataset, processor):
        self.dataset = dataset
        self.processor = processor

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = self.dataset[idx]
        image = Image.open(item['image_path'])
        encoding = self.processor(images=image, padding="max_length", return_tensors="pt")
        # remove batch dimension
        encoding = {k: v.squeeze() for k, v in encoding.items()}
        if dataset_name == 'sd':
            encoding["text"] = item["Prompt"]
        elif dataset_name == 'coco':
            encoding["text"] = item["caption"]
        return encoding

In [11]:
def collate_fn(batch):
    # pad the input_ids and attention_mask
    processed_batch = {}
    for key in batch[0].keys():
        if key != "text":
            processed_batch[key] = torch.stack([example[key] for example in batch])
        else:
            text_inputs = processor.tokenizer(
                [example["text"] for example in batch], padding=True, return_tensors="pt"
            )
            processed_batch["input_ids"] = text_inputs["input_ids"]
            processed_batch["attention_mask"] = text_inputs["attention_mask"]
    return processed_batch

In [12]:
# Define the LoraConfig
config = LoraConfig(
    r=64,
    lora_alpha=64,
    lora_dropout=0.05,
    bias="none",
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"]
)

model = get_peft_model(model, config)
model.print_trainable_parameters()

trainable params: 31,457,280 || all params: 3,776,137,216 || trainable%: 0.8331


In [13]:
train_ft_dataset = ImageCaptioningDataset(train_dataset, processor)
train_dataloader = DataLoader(train_ft_dataset, batch_size=16, shuffle=True,collate_fn=collate_fn)

subset_size = int(0.1 * len(test_dataset))
test_ft_dataset = ImageCaptioningDataset(test_dataset.take(subset_size), processor)
test_dataloader = DataLoader(test_ft_dataset, batch_size=16,shuffle=True,collate_fn=collate_fn)

## 开始训练并保存最佳模型

In [14]:
def generate_captions_batch(batch, model, processor):
    generated_ids = model.generate(batch['pixel_values'], max_length=77)
    captions = processor.batch_decode(generated_ids, skip_special_tokens=True)
    return captions

In [15]:
from sentence_transformers import SentenceTransformer, util
import numpy as np
import datetime

judge_model = SentenceTransformer('all-MiniLM-L6-v2')


def cal_cossim(reference:list,generated:list,judge_model)->dict:

    scores = []

    for seq in range(len(reference)):
        embedding1 = judge_model.encode(reference[seq], convert_to_tensor=True)
        embedding2 = judge_model.encode(generated[seq], convert_to_tensor=True)

        cosine_similarity = util.pytorch_cos_sim(embedding1, embedding2)
        scores.append(cosine_similarity.item())
    
    mean_score = np.mean(scores)

    return {"mean_score":mean_score,"scores":scores}


In [16]:
def train_and_evaluate(model,processor, train_dataloader, test_dataloader, optimizer, epochs=10):
    best_val_cos =0
    train_loss = []
    val_cos = []
    now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")


    for epoch in range(epochs):
        model.train()
        total_loss = 0

        print(f"Epoch {epoch+1}")

        for batch in tqdm(train_dataloader,total=len(train_dataloader),desc="Train Progress"):
            input_ids = batch.pop("input_ids").to(device)
            pixel_values = batch.pop("pixel_values").to(device, torch.float16)
            attention_mask = batch.pop("attention_mask").to(device)

            optimizer.zero_grad()
            outputs = model(input_ids=input_ids,
                            pixel_values=pixel_values,
                            labels=input_ids,
                            attention_mask=attention_mask)
        
            loss = outputs.loss

            #print("Loss:", loss.item())

            loss.backward()
            optimizer.step()

            total_loss += loss.item()
        
        avg_train_loss = total_loss / len(train_dataloader)
        train_loss.append(avg_train_loss)
        print(f"Epoch {epoch+1} Train Loss: {avg_train_loss}")

        
        # 评估阶段
        model.eval()

        with torch.no_grad():
            total_genval_captions = []
            total_ref_captions = []
            for batch in tqdm(test_dataloader,total=len(test_dataloader),desc="Evaluation Progress"):
                
                captions = generate_captions_batch(batch, model, processor)
                refs = processor.batch_decode(batch['input_ids'], skip_special_tokens=True)
                total_genval_captions.extend(captions)
                total_ref_captions.extend(refs)

            cosscore_val = cal_cossim(reference=total_ref_captions,generated=total_genval_captions,judge_model=judge_model)

        val_cos.append(cosscore_val['mean_score'])
        print(f"Validation Cosine-similarity: {cosscore_val['mean_score']}")

        # 检查是否是最好的模型
        if cosscore_val['mean_score'] >= best_val_cos:
            best_val_cos = cosscore_val['mean_score']
            print(f"Saving new best model with val_cos: {best_val_cos}")
            model_path = f"blip2_ft_{dataset_name}/{now}/blip2_{len(train_dataset)}_{split_size}_{epochs}/best"
            model.save_pretrained(model_path)
            #processor.save_pretrained(model_path)
            similarity_file = open(model_path+'/cosine_similarity.txt', 'w')
            similarity_file.write('Epoch, Cosine Similarity\n')  # 写入标题行
            similarity_file.write(f'{epoch+1}, {best_val_cos}\n')

    return {"bst_model_path":model_path,"train_avg_loss/e":train_loss,"val_avg_cos/e":val_cos}



In [17]:
# 设置参数
num_epoches = 10

In [18]:
optimizer = torch.optim.Adam(model.parameters(), lr=5e-5)
# 记录loss曲线
bast_model_output = train_and_evaluate(model=model,processor=processor,train_dataloader=train_dataloader,test_dataloader=test_dataloader,optimizer=optimizer,epochs=num_epoches) 

Epoch 1


Train Progress:   0%|          | 0/61 [00:00<?, ?it/s]

Expanding inputs for image tokens in BLIP-2 should be done in processing. Please follow instruction here (https://gist.github.com/zucchini-nlp/e9f20b054fa322f84ac9311d9ab67042) to update your BLIP-2 model. Using processors without these attributes in the config is deprecated and will throw an error in v4.50.
Train Progress: 100%|██████████| 61/61 [00:50<00:00,  1.22it/s]


Epoch 1 Train Loss: 4.212202228483607


Evaluation Progress: 100%|██████████| 2/2 [00:21<00:00, 10.81s/it]


Validation Cosine-similarity: 0.506862654350698
Saving new best model with val_cos: 0.506862654350698
Epoch 2


Train Progress: 100%|██████████| 61/61 [00:49<00:00,  1.23it/s]


Epoch 2 Train Loss: 2.1067334784836067


Evaluation Progress: 100%|██████████| 2/2 [00:21<00:00, 10.70s/it]


Validation Cosine-similarity: 0.5517264381051064
Saving new best model with val_cos: 0.5517264381051064
Epoch 3


Train Progress: 100%|██████████| 61/61 [00:49<00:00,  1.22it/s]


Epoch 3 Train Loss: 1.9184650358606556


Evaluation Progress: 100%|██████████| 2/2 [00:21<00:00, 10.73s/it]


Validation Cosine-similarity: 0.5803271159529686
Saving new best model with val_cos: 0.5803271159529686
Epoch 4


Train Progress: 100%|██████████| 61/61 [00:49<00:00,  1.22it/s]


Epoch 4 Train Loss: 1.8015817110655739


Evaluation Progress: 100%|██████████| 2/2 [00:21<00:00, 10.73s/it]


Validation Cosine-similarity: 0.5961369648575783
Saving new best model with val_cos: 0.5961369648575783
Epoch 5


Train Progress: 100%|██████████| 61/61 [00:49<00:00,  1.23it/s]


Epoch 5 Train Loss: 1.800076844262295


Evaluation Progress: 100%|██████████| 2/2 [00:21<00:00, 10.95s/it]


Validation Cosine-similarity: 0.6128123151138425
Saving new best model with val_cos: 0.6128123151138425
Epoch 6


Train Progress: 100%|██████████| 61/61 [00:49<00:00,  1.23it/s]


Epoch 6 Train Loss: 1.7309650358606556


Evaluation Progress: 100%|██████████| 2/2 [00:21<00:00, 10.55s/it]


Validation Cosine-similarity: 0.6424035057425499
Saving new best model with val_cos: 0.6424035057425499
Epoch 7


Train Progress: 100%|██████████| 61/61 [00:49<00:00,  1.23it/s]


Epoch 7 Train Loss: 1.6236712346311475


Evaluation Progress: 100%|██████████| 2/2 [00:21<00:00, 10.78s/it]


Validation Cosine-similarity: 0.6331000244244933
Epoch 8


Train Progress: 100%|██████████| 61/61 [00:50<00:00,  1.22it/s]


Epoch 8 Train Loss: 1.5626761014344261


Evaluation Progress: 100%|██████████| 2/2 [00:21<00:00, 10.59s/it]


Validation Cosine-similarity: 0.6507753515616059
Saving new best model with val_cos: 0.6507753515616059
Epoch 9


Train Progress: 100%|██████████| 61/61 [00:49<00:00,  1.22it/s]


Epoch 9 Train Loss: 1.5502369364754098


Evaluation Progress: 100%|██████████| 2/2 [00:21<00:00, 10.67s/it]


Validation Cosine-similarity: 0.6783425789326429
Saving new best model with val_cos: 0.6783425789326429
Epoch 10


Train Progress: 100%|██████████| 61/61 [00:50<00:00,  1.22it/s]


Epoch 10 Train Loss: 1.4930199795081966


Evaluation Progress: 100%|██████████| 2/2 [00:21<00:00, 10.89s/it]


Validation Cosine-similarity: 0.683305075392127
Saving new best model with val_cos: 0.683305075392127


In [19]:
# # 加载最佳模型
# model = Blip2ForConditionalGeneration.from_pretrained(bast_model_output['bst_model_path'], device_map="auto", load_in_8bit=True)
# # 加载处理器
# processor = AutoProcessor.from_pretrained(bast_model_output['bst_model_path'])

model.save_pretrained(f'blip2_ft_{dataset_name}/final_model/blip2_{len(train_dataset)}_{split_size}_{num_epoches}',safe_serialization=False)
processor.save_pretrained(f'blip2_ft_{dataset_name}/final_model/blip2_{len(train_dataset)}_{split_size}_{num_epoches}')

['blip2_ft_sd/final_model/blip2_964_0.2_10/processor_config.json']

# 以新训练好模型生成初步结果

In [20]:
import requests
from PIL import Image
from tqdm import tqdm
import io

def generate_cap(df,u_model,u_processor)->list:
    u_model.eval()
    generation = []
    for index,row in tqdm(df.iterrows(),total = len(df) ):

        # 使用 PIL 打开图像
        image = Image.open(row['image_path']).convert('RGB')
        inputs = u_processor(images=image, return_tensors="pt").to(device)

        pixel_values = inputs.pixel_values

        generated_ids = u_model.generate(pixel_values=pixel_values)
        generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()

        # generated_ids = model.generate(**inputs)
        # generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
        generation.append(generated_text)
    
    return generation

In [21]:
test_generated_caption = generate_cap(test_df,model,processor)
if dataset_name == 'sd':
    test_reference_prompt = test_df['Prompt'].tolist()
elif dataset_name == 'coco':
    test_reference_prompt = test_df['caption'].tolist()

100%|██████████| 326/326 [13:05<00:00,  2.41s/it]


In [22]:
# 保存在测试集上的生成结果
test_compare_df = pd.DataFrame({'image_path':test_df['image_path'],'generated_caption':test_generated_caption,'reference_prompt':test_reference_prompt})
test_compare_df.to_parquet(f'blip2_ft_{dataset_name}/result/blip2_{split_size}_{num_epoches}_test.parquet')

In [23]:
test_compare_df

Unnamed: 0,image_path,generated_caption,reference_prompt
0,../StableDiff_Dataset/generated_pics_sd_ext_pr...,a beautiful painting of a seal in a snowy cave...,highly detailed painting of cute furry white b...
1,../StableDiff_Dataset/generated_pics_sd_ext_pr...,"a portrait of yoda smoking a joint, by artgerm...",a realistic and atmospheric watercolour fantas...
2,../StableDiff_Dataset/generated_pics_sd_ext_pr...,"a spongebob character with a hat, by artgerm, ...","spongebob trianglepants....., trending on arts..."
3,../StableDiff_Dataset/generated_pics_sd_ext_pr...,a red chinese temple with a glowing green beam...,a beautiful red asian temple with green detail...
4,../StableDiff_Dataset/generated_pics_sd_ext_pr...,a beautiful and detailed portrait of a fat hai...,a realistic and atmospheric watercolour fantas...
...,...,...,...
321,../StableDiff_Dataset/generated_pics_sd_ext_pr...,"a street in a city, concept art, concept art a...","pixel art of an old european city, summer seas..."
322,../StableDiff_Dataset/generated_pics_sd_ext_pr...,"a lion and a cub in the forest, digital art, t...",beautiful aesthetic digital illustration of a ...
323,../StableDiff_Dataset/generated_pics_sd_ext_pr...,"a bonsai tree in a pot, concept art, trending ...",bonsai spruce.. tree but minimalistic concept ...
324,../StableDiff_Dataset/generated_pics_sd_ext_pr...,"a beautiful forest with mushrooms, by artgerm,...","A mushroom house in the foreground, other mush..."
