# Pipeline Fine-Tune VLM (Moondream2)

Маликов Денис     
Конвейер обучения VLM     

In [None]:
!pip install einops bitsandbytes transformers==4.41.2

In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from IPython.display import display
import requests
from PIL import Image
import pandas as pd
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader
from bitsandbytes.optim import Adam8bit
import math
from einops import rearrange
from tqdm import tqdm

## Обработка датасета

Я выбрал этот датасет, потому что он содержит изображения различных графиков (диаграмм) и вопросы с ответами к каждому изображению.

In [None]:
def _qa_sample(image, question, answer):
    '''
    Wrap one question/answer pair for `image` in the sample layout that
    collate_fn reads (keys "image" and "qa").
    '''
    return {
        "image": image,
        "qa": [
            {
                "question": question,
                "answer": answer,
            }
        ]
    }


def _split_gpt_answer(gpt):
    '''
    Split a raw answer string into the main answer and follow-up Q/A pairs.

    The first line is the main answer. The remaining non-empty lines are read
    as alternating question / answer lines, with the 'Question: ' and
    'Answer: ' markers removed. A trailing question that has no answer line
    is dropped instead of raising IndexError.

    Returns:
        (main_answer, [(question, answer), ...])
    '''
    main_answer, _, rest = gpt.partition('\n')
    lines = [line for line in rest.split('\n') if line != '']
    questions = [line.replace('Question: ', '') for line in lines[0::2]]
    answers = [line.replace('Answer: ', '') for line in lines[1::2]]
    # zip() pairs only complete question/answer couples.
    return main_answer, list(zip(questions, answers))


def processed_data(df, image_dir='/content/drive/MyDrive/vlm/dataset/'):
    '''
    Turn the conversations table into single-question samples and split them
    into training and test sets.

    Args:
        df: DataFrame with an 'image' column (image file name) and a
            'conversations' column; conversations[0]['value'] is read as the
            question and conversations[1]['value'] as the answer.
        image_dir: prefix that is concatenated with the image file name, so it
            must end with a path separator. Defaults to the original
            hard-coded dataset directory.

    Returns:
        (data_train, data_test): lists of samples shaped as
        {"image": <RGB PIL image>, "qa": [{"question": ..., "answer": ...}]},
        split 90/10 with a fixed random_state.
    '''

    # Group conversations by image so that every image is decoded only once.
    image_conversations = {}
    for _, row in df.iterrows():
        image_conversations.setdefault(row['image'], []).append(row['conversations'])

    all_messages = []
    for image_name, conversations in image_conversations.items():
        image_path = image_dir + image_name
        # convert() returns a fully loaded copy, so the file can be closed
        # right away instead of staying open until garbage collection.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert('RGB')

        for msg_gpt in conversations:
            gpt = msg_gpt[1]['value']
            msg = msg_gpt[0]['value']

            # Strip the image placeholder from the question text.
            if "\n<image>" in msg:
                msg = msg.replace("\n<image>", '')
            else:
                msg = msg.replace("<image>\n", '')

            main_answer, followups = _split_gpt_answer(gpt)

            # Follow-up pairs embedded in the answer go first, then the main
            # pair; this keeps the sample order (and therefore the seeded
            # split) the same as before.
            for question, answer in followups:
                all_messages.append(_qa_sample(image, question, answer))
            all_messages.append(_qa_sample(image, msg, main_answer))

    data_train, data_test = train_test_split(all_messages, test_size=0.1, random_state=1337, shuffle=True)
    return (data_train, data_test)

# Annotation files, one JSON file per chart type.
box_path = '/content/drive/MyDrive/vlm/dataset/box_chart_100examples_simplified_qa.json'
candlestick_path = '/content/drive/MyDrive/vlm/dataset/candlestick_chart_100examples_simplified_qa.json'
funnel_path = '/content/drive/MyDrive/vlm/dataset/funnel_chart_100examples_simplified_qa.json'
gantt_path = '/content/drive/MyDrive/vlm/dataset/gantt_chart_100examples_simplified_qa.json'
heatmap_path = '/content/drive/MyDrive/vlm/dataset/heatmap_chart_100examples_simplified_qa.json'
polar_path = '/content/drive/MyDrive/vlm/dataset/polar_chart_100examples_simplified_qa.json'
scatter_path = '/content/drive/MyDrive/vlm/dataset/scatter_chart_100examples_simplified_qa.json'

all_path = [
    box_path,
    candlestick_path,
    funnel_path,
    gantt_path,
    heatmap_path,
    polar_path,
    scatter_path,
]

# Read every annotation file and merge them into one table with a fresh index.
dfs = [pd.read_json(path) for path in all_path]
df_all = pd.concat(dfs, ignore_index=True)

# Build the train / test sample lists used by the rest of the notebook.
data_train, data_test = processed_data(df_all)

## Загрузка модели

Загружаем модель на CUDA с весами в float16

In [None]:
# The model is placed on the GPU with half-precision weights.
DEVICE = "cuda"
DTYPE = torch.float16
# Revision tag passed to both the tokenizer and the model download.
MD_REVISION = "2024-05-20"

tokenizer = AutoTokenizer.from_pretrained(
    "vikhyatk/moondream2",
    revision=MD_REVISION,
)

# trust_remote_code is required because the model class ships with the checkpoint.
moondream = AutoModelForCausalLM.from_pretrained(
    "vikhyatk/moondream2",
    revision=MD_REVISION,
    trust_remote_code=True,
    torch_dtype=DTYPE,
    device_map={"": DEVICE},
)

## Обучение модели

### 1. Гиперпараметры для модели

In [None]:
# Number of full passes over the training DataLoader in the training loop.
EPOCHS = 15
# Samples per DataLoader batch.
BATCH_SIZE = 8
# Number of batches whose gradients are accumulated before each optimizer step.
GRAD_ACCUM_STEPS = 2
# Peak learning rate: lr_schedule warms up from 0.1 * LR to LR and then decays.
LR = 1e-5


### 2. Импорт библиотек и определение констант

In [None]:
from torch.utils.data import DataLoader
from bitsandbytes.optim import Adam8bit
import math
from einops import rearrange
from tqdm import tqdm

# Marker appended to every answer in collate_fn; its tokens are part of the
# supervised answer labels.
ANSWER_EOS = "<|endoftext|>"

# Number of label / attention-mask positions reserved per image in collate_fn.
# NOTE(review): presumably equals the number of embeddings the vision encoder
# produces for one image — confirm against the model revision in use.
IMG_TOKENS = 729

### 3. Определение функции collate_fn (Data_Collator)

Функция collate_fn используется для преобразования батча данных в формат, подходящий для модели.    
Она принимает батч данных, предобрабатывает изображения с помощью moondream.vision_encoder.preprocess, а затем создает токены и метки для каждой пары «вопрос-ответ» в батче.

In [None]:

def collate_fn(batch):
    """
    Collate a list of samples into model-ready inputs.

    Each sample provides an 'image' and a 'qa' list of question/answer dicts.
    Question tokens are masked with -100 in the labels; answer tokens
    (including ANSWER_EOS) are kept as targets. Labels and attention masks
    start with IMG_TOKENS + 1 positions (image slots plus BOS), while token
    rows start with BOS only, so labels/masks are IMG_TOKENS longer than the
    token rows.

    Returns:
        (images, tokens, labels, attention_mask): the list of preprocessed
        images, followed by three stacked tensors (long, long, bool), each
        padded to the longest row in the batch.
    """
    torch.cuda.empty_cache()
    images = [moondream.vision_encoder.preprocess(sample['image']) for sample in batch]

    tokens_acc, labels_acc = [], []
    for sample in batch:
        token_ids = [tokenizer.bos_token_id]
        # Image positions and BOS never contribute to the loss.
        label_ids = [-100] * (IMG_TOKENS + 1)

        for qa in sample['qa']:
            question_ids = tokenizer(
                f"\n\nQuestion: {qa['question']}\n\nAnswer:",
                add_special_tokens=False
            ).input_ids
            answer_ids = tokenizer(
                f" {qa['answer']}{ANSWER_EOS}",
                add_special_tokens=False
            ).input_ids

            token_ids += question_ids + answer_ids
            # Only the answer part is supervised.
            label_ids += [-100] * len(question_ids) + answer_ids

        tokens_acc.append(token_ids)
        labels_acc.append(label_ids)

    max_len = max((len(label_ids) for label_ids in labels_acc), default=-1)

    attn_mask_acc = []
    for token_ids, label_ids in zip(tokens_acc, labels_acc):
        real_len = len(label_ids)
        pad = max_len - real_len

        # Padding is ignored by the loss (-100) and by attention (mask 0).
        label_ids.extend([-100] * pad)
        token_ids.extend([tokenizer.eos_token_id] * pad)
        attn_mask_acc.append([1] * real_len + [0] * pad)

    def _stack(rows, dtype):
        return torch.stack([torch.tensor(row, dtype=dtype) for row in rows])

    torch.cuda.empty_cache()
    return (
        images,
        _stack(tokens_acc, torch.long),
        _stack(labels_acc, torch.long),
        _stack(attn_mask_acc, torch.bool),
    )

### 4. Определение функции compute_loss

Функция compute_loss вычисляет значение функции потерь для батча данных.  
Она принимает батч, переносит токены, метки и маску внимания на устройство, получает эмбеддинги изображений и текста и передаёт их в moondream.text_model, которая возвращает значение потерь.

In [None]:
def compute_loss(batch):
    """
    Compute the language-model loss for one collated batch.

    `batch` is the tuple produced by collate_fn: (images, tokens, labels,
    attention mask). Image embeddings are computed without gradients and
    inserted between the first token embedding and the remaining ones along
    the sequence dimension.

    Returns:
        The `loss` attribute of the text model output.
    """
    images, tokens, labels, attn_mask = batch
    tokens, labels, attn_mask = (
        tensor.to(DEVICE) for tensor in (tokens, labels, attn_mask)
    )

    # The vision encoder is not trained, so no graph is built for it.
    with torch.no_grad():
        img_embs = moondream.vision_encoder(images)

    embed = moondream.text_model.get_input_embeddings()
    tok_embs = embed(tokens)
    first_emb = tok_embs[:, :1, :]
    rest_embs = tok_embs[:, 1:, :]
    inputs_embeds = torch.cat([first_emb, img_embs, rest_embs], dim=1)

    return moondream.text_model(
        inputs_embeds=inputs_embeds,
        labels=labels,
        attention_mask=attn_mask,
    ).loss

### 5. Определение скорости обучения на шаге обучения

Функция lr_schedule определяет расписание изменения скорости обучения (learning rate) в зависимости от шага обучения: линейный разогрев на первых 10% шагов, затем косинусное затухание.

In [None]:
def lr_schedule(step, max_steps):
    """
    Return the learning rate for `step` out of `max_steps`.

    During the first 10% of training the rate rises linearly from 0.1 * LR to
    LR; afterwards it follows a cosine curve that starts at LR.
    """
    progress = step / max_steps
    floor = 0.1 * LR
    span = 0.9 * LR

    if progress < 0.1:
        # Linear warm-up.
        return floor + span * progress / 0.1

    # Cosine decay from the peak.
    return floor + span * (1 + math.cos(math.pi * (progress - 0.1))) / 2

### 6. Создание DataLoader

Создается DataLoader для тренировочного набора данных с помощью функции collate_fn.

In [None]:
# Only a training loader is needed: batches are reshuffled every epoch and
# assembled by collate_fn.
dataloaders = {}
dataloaders["train"] = DataLoader(
    dataset=data_train,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=collate_fn,
)

### 7. Настройка модели для обучения

In [None]:
# Only the text model is put into training mode; gradient checkpointing is
# enabled on its transformer.
moondream.text_model.train()
moondream.text_model.transformer.gradient_checkpointing_enable()

# The optimizer steps once every GRAD_ACCUM_STEPS batches.
total_batches = EPOCHS * len(dataloaders["train"])
total_steps = total_batches // GRAD_ACCUM_STEPS

# Only the text model parameters are optimized. The initial rate equals
# lr_schedule at step 0.
param_groups = [{"params": moondream.text_model.parameters()}]
optimizer = Adam8bit(
    param_groups,
    lr=LR * 0.1,
    betas=(0.9, 0.95),
    eps=1e-6,
)

### 8. Обучение модели

Модель тренируется в цикле по эпохам.      
В каждой эпохе происходит итерация по батчам данных, вычисление потери, обратное распространение ошибки и обновление параметров модели.

In [None]:
# `i` counts batches across all epochs, so gradient accumulation and the
# learning-rate schedule continue seamlessly from one epoch to the next.
i = 0
for epoch in range(EPOCHS):
    progress_bar = tqdm(dataloaders["train"], desc=f"Epoch {epoch + 1}/{EPOCHS}")
    for batch in progress_bar:
        i += 1

        loss = compute_loss(batch)
        loss.backward()

        # Keep accumulating gradients until GRAD_ACCUM_STEPS batches are seen.
        if i % GRAD_ACCUM_STEPS != 0:
            continue

        optimizer.step()
        optimizer.zero_grad()

        # The new rate is set after the step, so it applies to the next one.
        lr = lr_schedule(i / GRAD_ACCUM_STEPS, total_steps)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr


# Persist the fine-tuned weights.
moondream.save_pretrained("/content/drive/MyDrive/moondream-ft_all_15_epoch")

# Оценка моделей

В наличии 118 тестовых вопросов и ответов к ним.

In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

DEVICE = "cuda"
# CPU doesn't support float16, so float32 is used there.
DTYPE = torch.float16 if DEVICE != "cpu" else torch.float32
MD_REVISION = "2024-05-20"

tokenizer = AutoTokenizer.from_pretrained(
    "vikhyatk/moondream2",
    revision=MD_REVISION,
)

# Baseline: the published checkpoint without any fine-tuning.
moondream_default = AutoModelForCausalLM.from_pretrained(
    "vikhyatk/moondream2",
    revision=MD_REVISION,
    trust_remote_code=True,
    torch_dtype=DTYPE,
    device_map={"": DEVICE},
)

# Fine-tuned checkpoint loaded from Drive.
# NOTE(review): the training cell saves to ".../moondream-ft_all_15_epoch",
# while this path contains "ft_new_all" — confirm which checkpoint is meant.
moondream_finetune = AutoModelForCausalLM.from_pretrained(
    "/content/drive/MyDrive/moondream-ft_new_all_15_epoch",
    revision=MD_REVISION,
    trust_remote_code=True,
    torch_dtype=DTYPE,
    device_map={"": DEVICE},
)

Делаем предсказания

In [None]:
moondream_default.eval()
moondream_finetune.eval()


def _ask(model, sample):
    """Return `model`'s answer to the first question of `sample` about its image."""
    return model.answer_question(
        model.encode_image(sample['image']),
        sample['qa'][0]['question'],
        tokenizer=tokenizer,
        num_beams=4,
        no_repeat_ngram_size=5,
        early_stopping=True,
    )


# Reference answers and the predictions of both models, aligned by index.
right, pred_default, pred_finetune = [], [], []
for sample in data_test:
    md_answer_default = _ask(moondream_default, sample)
    md_answer_finetune = _ask(moondream_finetune, sample)

    right.append(sample['qa'][0]['answer'])
    pred_default.append(md_answer_default)
    pred_finetune.append(md_answer_finetune)

Считаем метрики: Accuracy, F1, BLEU, ROUGE

In [None]:
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score

from sklearn.metrics import accuracy_score, f1_score
from nltk.translate.bleu_score import sentence_bleu
from rouge import Rouge

## Дефолтная модель

In [None]:
# Accuracy: share of predictions that match the reference answer exactly.
accuracy = accuracy_score(right, pred_default)
print(f'Точность: {accuracy}')

# F1 score, micro-averaged over the answer strings treated as class labels.
f1 = f1_score(right, pred_default, average='micro')
print(f'F1-мера: {f1}')

import numpy as np
from nltk.translate.bleu_score import SmoothingFunction

# BLEU: sentence_bleu expects a list of tokenised references and a tokenised
# hypothesis. Passing raw strings makes it score single characters against
# one-character "references", which yields values around 1e-231. Answers are
# short, so smoothing keeps missing higher-order n-grams from zeroing the score.
smoothing = SmoothingFunction().method1
blue_score = [
    sentence_bleu([reference.split()], prediction.split(), smoothing_function=smoothing)
    for reference, prediction in zip(right, pred_default)
]

print(f'BLEU: {np.mean(blue_score)}')

# ROUGE: hypotheses first, references second.
rouge = Rouge()
scores = rouge.get_scores(pred_default, right, avg=True)
print(f'ROUGE: {scores}')

Точность: 0.0    
F1-мера: 0.0      
BLEU: 7.3532446158733e-232      
ROUGE: {'rouge-1': {'r': 0.12017594916754581, 'p': 0.03504597827353245, 'f': 0.047154155157596966}, 'rouge-2': {'r': 0.023289315726290515, 'p': 0.009176020940726823, 'f': 0.011400834649175393}, 'rouge-l': {'r': 0.11731116689099882, 'p': 0.03236279973896073, 'f': 0.04460934638441449}}

## Модель с finetune

In [None]:
# Accuracy: share of predictions that match the reference answer exactly.
accuracy = accuracy_score(right, pred_finetune)
print(f'Точность: {accuracy}')

# F1 score, micro-averaged over the answer strings treated as class labels.
f1 = f1_score(right, pred_finetune, average='micro')
print(f'F1-мера: {f1}')

import numpy as np
from nltk.translate.bleu_score import SmoothingFunction

# BLEU: sentence_bleu expects a list of tokenised references and a tokenised
# hypothesis. Passing raw strings makes it score single characters against
# one-character "references", which yields values around 1e-231. Answers are
# short, so smoothing keeps missing higher-order n-grams from zeroing the score.
smoothing = SmoothingFunction().method1
blue_score = [
    sentence_bleu([reference.split()], prediction.split(), smoothing_function=smoothing)
    for reference, prediction in zip(right, pred_finetune)
]

print(f'BLEU: {np.mean(blue_score)}')

# ROUGE: hypotheses first, references second.
rouge = Rouge()
scores = rouge.get_scores(pred_finetune, right, avg=True)
print(f'ROUGE: {scores}')

Точность: 0.3949579831932773      
F1-мера: 0.39495798319327724      
BLEU: 1.2346845291816845e-231     
ROUGE: {'rouge-1': {'r': 0.44676325075484735, 'p': 0.4451853468660191, 'f': 0.4433325794435508}, 'rouge-2': {'r': 0.05921368547418967, 'p': 0.059803921568627454, 'f': 0.059323729035511356}, 'rouge-l': {'r': 0.4459993088144349, 'p': 0.4437847866419295, 'f': 0.44234394869713456}}