In [9]:
%%writefile requirements.txt
trl
sagemaker
datasets
pillow
torch==2.4.1
peft
h5py

Overwriting requirements.txt


In [10]:
!pip install -r requirements.txt

Collecting torch==2.0.1 (from -r requirements.txt (line 5))
  Downloading torch-2.0.1-cp310-cp310-manylinux1_x86_64.whl.metadata (24 kB)
Collecting nvidia-cuda-nvrtc-cu11==11.7.99 (from torch==2.0.1->-r requirements.txt (line 5))
  Downloading nvidia_cuda_nvrtc_cu11-11.7.99-2-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu11==11.7.99 (from torch==2.0.1->-r requirements.txt (line 5))
  Downloading nvidia_cuda_runtime_cu11-11.7.99-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cuda-cupti-cu11==11.7.101 (from torch==2.0.1->-r requirements.txt (line 5))
  Downloading nvidia_cuda_cupti_cu11-11.7.101-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu11==8.5.0.96 (from torch==2.0.1->-r requirements.txt (line 5))
  Downloading nvidia_cudnn_cu11-8.5.0.96-2-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu11==11.10.3.66 (from torch==2.0.1->-r requirements.txt (line 5))
  Downloading nvidia_

In [10]:
!df -h
!mkdir /home/ec2-user/SageMaker/data

Filesystem      Size  Used Avail Use% Mounted on
devtmpfs         16G     0   16G   0% /dev
tmpfs            16G     0   16G   0% /dev/shm
tmpfs            16G  708K   16G   1% /run
tmpfs            16G     0   16G   0% /sys/fs/cgroup
/dev/nvme0n1p1  135G   88G   48G  65% /
tmpfs           3.1G     0  3.1G   0% /run/user/0
/dev/nvme2n1    296G  224K  281G   1% /home/ec2-user/SageMaker
tmpfs           3.1G     0  3.1G   0% /run/user/1002
tmpfs           3.1G     0  3.1G   0% /run/user/1001
tmpfs           3.1G     0  3.1G   0% /run/user/1000


In [11]:
import boto3
import sagemaker

sess = sagemaker.Session()
sagemaker_bucket = sess.default_bucket()
iam = boto3.client('iam')
role = iam.get_role(RoleName='AmazonSageMaker-ExecutionRole-20241127T090079')['Role']['Arn']

s3_prefix = 'training-data'
ec2_data_prefix = '/home/ec2-user/SageMaker/data'
base_job_name = "blip-finetune-facad"

checkpoint_s3_uri = f's3://{sagemaker_bucket}/{base_job_name}/checkpoints'
train_image_s3_path = f'{s3_prefix}/train/TRAIN_IMAGES.hdf5'
train_caption_s3_path = f'{s3_prefix}/train/TRAIN_CAPTIONS.txt'
eval_image_s3_path = f'{s3_prefix}/eval/VAL_IMAGES.hdf5'
eval_caption_s3_path = f'{s3_prefix}/eval/VAL_CAPTIONS.txt'

In [12]:
import os

# download files from S3 to local SSD
s3 = boto3.client('s3')
finetuning_data_files = [train_image_s3_path, train_caption_s3_path, eval_image_s3_path, eval_caption_s3_path]

for filepath in finetuning_data_files:
    filename = filepath.split('/')[-1]
    data_file = os.path.join(ec2_data_prefix, filename)
    s3.download_file(sagemaker_bucket, filepath, data_file)

In [13]:
!ls /home/ec2-user/SageMaker/data

TRAIN_CAPTIONS.txt  TRAIN_IMAGES.hdf5  VAL_CAPTIONS.txt  VAL_IMAGES.hdf5


In [14]:
%%writefile hdf5.py
import h5py
import datasets


class HDF5Config(datasets.BuilderConfig):

    def __init__(self, key='', **kwargs):
        """BuilderConfig for HDF5 file.
        """
        # Version history:
        # 0.0.1: Initial version.
        super(HDF5Config, self).__init__(version=datasets.Version("0.0.1"), **kwargs)
        self.key = key


class HDF5(datasets.GeneratorBasedBuilder):

    BUILDER_CONFIGS = [
        HDF5Config(
            name="keyed_config",
            description="HDF5 Dataset Generator iterates values of provided key",
            key=''
        )
    ]

    def _info(self):
        return datasets.DatasetInfo(description=self.config.description)

    def _split_generators(self, dl_manager):
        if not self.config.data_files:
            raise ValueError(f"At least one data file must be specified, but got data_files={self.config.data_files}")
        dl_manager.download_config.extract_on_the_fly = True
        data_files = dl_manager.download_and_extract(self.config.data_files)
        splits = []
        for split_name, files in data_files.items():
            splits.append(datasets.SplitGenerator(name=split_name, gen_kwargs={"files": files}))
        return splits

    def _generate_examples(self, files):
        key = self.config.key
        for file in files:
            with h5py.File(file, "r", swmr=True) as data:
                if not key:
                    raise ValueError(f"A key must be specified, but got key={key}")
                else:
                    for idx, value in enumerate(data[key]):
                        yield idx, { key: value }

Writing hdf5.py


In [28]:
%%writefile finetune_blip.py
import logging
import sys
from dataclasses import dataclass
from typing import Optional

import torch
from datasets import load_dataset
from torch.utils.data import DataLoader
from tqdm import tqdm

import trl
import transformers
from transformers import (
    AutoProcessor,
    BlipForConditionalGeneration,
    HfArgumentParser,
    TrainingArguments,
    BitsAndBytesConfig,
)
from transformers.trainer_utils import get_last_checkpoint
from peft import LoraConfig, get_peft_model
from trl import SFTTrainer
from torch.utils.data import DataLoader


logger = logging.getLogger(__name__)
device = "cuda" if torch.cuda.is_available() else "cpu"


# custome trl.trainer.ConstantLengthDataset
class FashionImageCaptioningDataset(torch.utils.data.IterableDataset):
    def __init__(self, dataset, size, processor):
        self.dataset = dataset
        self.size = size
        self.processor = processor
        
    def __len__(self):
        return self.size
    
    def __iter__(self):
        n = len(self.dataset['image'])
        m = len(self.dataset['text'])
        if n != m:
            raise Exception(f'Expects same image and text datasets, but received {n} images and {m} texts.')

        for i in range(n):
            image_iterator = iter(self.dataset['image'][i])
            text_iterator = iter(self.dataset['text'][i])
            while True:
                try:
                    image = next(image_iterator)['images']
                    text = next(text_iterator)['text']
                    example = self.processor(images=torch.tensor(image, dtype=torch.int), padding="max_length", return_tensors="pt")
                    example = {k: v.squeeze() for k, v in example.items()}
                    example['labels'] = text
                    yield example
                except Exception as err:
                    logger.warning(f"Error generating example: {err}")
                    break

@dataclass
class SFTTrainingArguments:
    model_name_or_path: str
    train_data_files: str
    train_data_size: int = 0
    eval_data_files: str = None
    eval_data_size: int = 0
    freeze_vision_model: bool = False
    freeze_text_model: bool = False
    load_in_8bit: bool = False
    load_in_4bit: bool = False
    use_flash_attention_2: bool = False
    use_peft: bool = True
    peft_target_model: Optional[str] = "blip-image-captioning-facad"
    peft_target_modules: Optional[list[str]] = None
    peft_lora_r: int = 16
    peft_lora_alpha: int = 32
    peft_lora_dropout: float = 0.05

    def __post_init__(self):
        if self.load_in_8bit and self.load_in_4bit:
            raise ValueError("load_in_8bit and load_in_4bit are mutually exclusive")
        if self.peft_target_model and self.peft_target_modules is None:
            if self.peft_target_model == "blip-image-captioning-facad":
                self.peft_target_modules = [
                    "self.query",
                    "self.key",
                    "self.value",
                    "output.dense",
                    "self_attn.qkv",
                    "self_attn.projection",
                    "mlp.fc1",
                    "mlp.fc2",
                ]
            else:
                logger.warning(
                    f"peft_target_model '{self.peft_target_model}' is not supported, "
                    f"so peft_target_modules is set to None."
                )

    def from_pretrained_kwargs(self, training_args):
        kwargs = {}
        if self.load_in_8bit:
            kwargs = {"load_in_8bit": True}
        elif self.load_in_4bit:
            kwargs = {
                "quantization_config": BitsAndBytesConfig(
                    load_in_4bit=True,
                    bnb_4bit_use_double_quant=True,
                    bnb_4bit_quant_type="nf4",
                    bnb_4bit_compute_dtype=torch.bfloat16,
                )
            }
        elif training_args.bf16:
            kwargs = {"torch_dtype": torch.bfloat16}
        else:
            kwargs = {"torch_dtype": torch.float16}
        if self.use_flash_attention_2:
            kwargs["attn_implementation"] = "flash_attention_2"
        return kwargs

def load_datasets(data_files):
    datasets = {'image': [], 'text': []}
    for data_file in data_files:
        dataset = None
        if data_file.endswith('.hdf5'):
            dataset = load_dataset("hdf5.py", name="keyed_config", key="images", data_files=data_file, trust_remote_code=True, streaming=True)
            datasets['image'].append(dataset['train'])
        else:
            dataset = load_dataset("text", data_files=data_file, streaming=True)
            datasets['text'].append(dataset['train'])
    return datasets

def main():
    parser = HfArgumentParser((TrainingArguments, SFTTrainingArguments))
    training_args, sft_training_args = parser.parse_args_into_dataclasses()

    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )

    log_level = training_args.get_process_log_level()
    logger.setLevel(log_level)
    logger.info(f"Training parameters {training_args}\nSupervised Fine-Tuning parameters {sft_training_args}")

    processor = AutoProcessor.from_pretrained(sft_training_args.model_name_or_path)
    kwarg = sft_training_args.from_pretrained_kwargs(training_args)
    model = BlipForConditionalGeneration.from_pretrained(sft_training_args.model_name_or_path, **kwarg).to(device)

    peft_config = None
    if sft_training_args.use_peft:
        peft_config = LoraConfig(
            r=sft_training_args.peft_lora_r,
            target_modules=sft_training_args.peft_target_modules,
            lora_alpha=sft_training_args.peft_lora_alpha,
            lora_dropout=sft_training_args.peft_lora_dropout,
            bias="none"
        )
        model = get_peft_model(model, peft_config)
        if training_args.gradient_checkpointing:
            for param in model.parameters():
                param.requires_grad = False
            model.gradient_checkpointing_enable()
            model.enable_input_require_grads()

    def _freeze_params(module):
        for param in module.parameters():
            param.requires_grad = False

    if sft_training_args.freeze_vision_model:
        _freeze_params(model.vision_model)

    if sft_training_args.freeze_text_model:
        _freeze_params(model.text_model)

    train_dataset = None
    if sft_training_args.train_data_files:
        data_files = sft_training_args.train_data_files.split(',')
        train_dataset = load_datasets(data_files)
        train_dataset = FashionImageCaptioningDataset(train_dataset, sft_training_args.train_data_size, processor)

    eval_dataset = None
    if sft_training_args.eval_data_files:
        data_files = sft_training_args.eval_data_files.split(',')
        eval_dataset = load_datasets(data_files)
        eval_dataset = FashionImageCaptioningDataset(eval_dataset, sft_training_args.eval_data_size, processor)

    train_dataloader = DataLoader(train_dataset, batch_size=training_args.per_device_train_batch_size, shuffle=False, pin_memory=True)
    eval_dataloader = DataLoader(eval_dataset, batch_size=training_args.per_device_eval_batch_size, shuffle=False, pin_memory=True)

    optimizer = torch.optim.AdamW(model.parameters(), lr=training_args.learning_rate, weight_decay=training_args.weight_decay)
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9, last_epoch=-1, verbose=False)
    num_epochs = int(training_args.num_train_epochs)
    patience = 10
    min_eval_loss = float("inf")
    early_stopping_hook = 0

    for epoch in range(num_epochs):
        epoch_loss = 0
        model.train()
        for _, batch in zip(tqdm(range(len(train_dataloader)), desc='Training batch: ...'), train_dataloader):
            pixel_values = batch.pop('pixel_values').to(device)
            texts = batch.pop('labels')
            text_inputs = processor.tokenizer(texts, padding=True, return_tensors="pt").to(device)
            input_ids = text_inputs.input_ids
            attention_mask = text_inputs.attention_mask

            outputs = model(input_ids=input_ids,
                            pixel_values=pixel_values,
                            attention_mask=attention_mask,
                            labels=input_ids)

            loss = outputs.loss
            epoch_loss += loss.item()
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

        model.eval()
        eval_loss = 0
        for _, batch in zip(tqdm(range(len(eval_dataloader)), desc='Validating batch: ...'), eval_dataloader):
            pixel_values = batch.pop('pixel_values').to(device)
            texts = batch.pop('labels')
            text_inputs = processor.tokenizer(texts, padding=True, return_tensors="pt").to(device)
            input_ids = text_inputs.input_ids
            attention_mask = text_inputs.attention_mask

            outputs = model(input_ids=input_ids,
                            pixel_values=pixel_values,
                            attention_mask=attention_mask,
                            labels=input_ids)

            loss = outputs.loss
            eval_loss += loss.item()

        print("Epoch: {} - Training loss: {} - Eval Loss: {} - LR: {}".format(epoch+1, epoch_loss/len(train_dataloader), eval_loss/len(eval_dataloader), optimizer.param_groups[0]["lr"]))
        scheduler.step()
        if eval_loss < min_eval_loss:
            model.save_pretrained(training_args.output_dir, from_pt=True)
            min_eval_loss = eval_loss
            early_stopping_hook = 0
        else:
            early_stopping_hook += 1
            if early_stopping_hook > patience:
                break


if __name__ == "__main__":
    print(f'using transformers: {transformers.__version__}, trl: {trl.__version__}')
    main()

Overwriting finetune_blip.py


In [17]:
!mkdir /home/ec2-user/SageMaker/model

mkdir: cannot create directory ‘/home/ec2-user/SageMaker/model’: File exists


In [None]:
!python finetune_blip.py --eval_data_files /home/ec2-user/SageMaker/data/VAL_IMAGES.hdf5,/home/ec2-user/SageMaker/data/VAL_CAPTIONS.txt --eval_data_size 19915 --freeze_vision_model True --learning_rate 1e-05 --model_name_or_path Salesforce/blip-image-captioning-base --num_train_epochs 5 --output_dir /home/ec2-user/SageMaker/model --per_device_eval_batch_size 32 --per_device_train_batch_size 32 --train_data_files /home/ec2-user/SageMaker/data/VAL_IMAGES.hdf5,/home/ec2-user/SageMaker/data/VAL_CAPTIONS.txt --train_data_size 800000 --weight_decay 0.05 --do_train True --do_eval True --gradient_checkpointing True

using transformers: 4.46.3, trl: 0.12.2
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]
Training batch: ...:   1%|▏               | 272/25000 [04:49<7:18:13,  1.06s/it]