<a href="https://colab.research.google.com/github/donbcolab/AIE3/blob/main/paligemma_cnmc_finetune_v6.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Paligemma Fine Tuning using CNMC dataset

### Setting Up

In [1]:
base_model_name = "google/paligemma-3b-pt-224"
adapter_version = "paligemma-cnmc-ft"
adapter_model_name = f"dwb2023/{adapter_version}"

In [2]:
!pip install -q -U git+https://github.com/huggingface/transformers.git datasets accelerate bitsandbytes peft hf_transfer

  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone


In [3]:
import os
from google.colab import userdata

HF_TOKEN = userdata.get('HF_TOKEN')
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

## Load Dataset

In [4]:
from datasets import load_dataset, DatasetDict

# Load CNMC Dataset
ds = load_dataset("dwb2023/cnmc-leukemia-2019", split="train")

In [5]:
# Filter records to only include those from fold 0
ds_fold_0 = ds.filter(lambda example: example['fold'] == 0)

# Define the percentage you want to retrieve (e.g., 10%)
percentage = 0.10

# Use train_test_split to get the subset
cnmc_ds = ds_fold_0.train_test_split(test_size=percentage)["test"]

# Columns to remove
cols_remove = ["subject_id", "image_number", "fold", "original_image_name", "relative_file_path"]
cnmc_ds = cnmc_ds.remove_columns(cols_remove)

In [6]:
# create train test split with test_size=0.2
train_ds = cnmc_ds.train_test_split(test_size=0.2)

# create test val split
test_val_ds = train_ds["test"].train_test_split(test_size=0.5)

cnmc_ds_dict = DatasetDict({
    "train" : train_ds["train"],
    "test" : test_val_ds["test"],
    "validation" : test_val_ds["train"]
})

cnmc_ds_dict

DatasetDict({
    train: Dataset({
        features: ['cell_count', 'image', 'label', 'class_label'],
        num_rows: 282
    })
    test: Dataset({
        features: ['cell_count', 'image', 'label', 'class_label'],
        num_rows: 36
    })
    validation: Dataset({
        features: ['cell_count', 'image', 'label', 'class_label'],
        num_rows: 35
    })
})

In [56]:
from PIL import Image
import io
from datasets import Dataset

def bytes_to_pil(image_data):
    try:
        # Explicitly specify that we're opening a BMP file
        return Image.open(io.BytesIO(image_data), formats=["BMP"])
    except Exception as e:
        print(f"Error converting BMP image: {e}")
        return None

def convert_images_ds(ds):
    def process_image(example):
        image_data = example['image']
        if isinstance(image_data, dict) and 'bytes' in image_data:
            pil_image = bytes_to_pil(image_data['bytes'])
            if pil_image is not None:
                example['image'] = pil_image
            else:
                print("Failed to convert BMP to PIL Image")
        else:
            print(f"Unexpected image data format: {type(image_data)}")
        return example

    return ds.map(process_image, batched=False, load_from_cache_file=False)

# Use the original dataset
train_ds = cnmc_ds_dict['train'].shuffle(seed=42).select(range(10))
test_ds = cnmc_ds_dict['test'].shuffle(seed=42).select(range(10))
validation_ds = cnmc_ds_dict['validation'].shuffle(seed=42).select(range(10))

train_ds = convert_images_ds(train_ds)
test_ds = convert_images_ds(test_ds)
validation_ds = convert_images_ds(validation_ds)

# Debugging: Check all images in the train dataset
for i, example in enumerate(train_ds):
    image = example['image']
    if isinstance(image, Image.Image):
        print(f"Image {i}: PIL Image, mode: {image.mode}, size: {image.size}")
    elif isinstance(image, dict) and 'bytes' in image:
        print(f"Image {i}: Bytes data, length: {len(image['bytes'])}, first 8 bytes: {image['bytes'][:8]}")
    else:
        print(f"Image {i}: Unexpected type: {type(image)}")

# Print other features
print("\nOther features in the first example:")
for key, value in train_ds[0].items():
    if key != 'image':
        print(f"{key}: {value}")

Map:   0%|          | 0/10 [00:00<?, ? examples/s]

Map:   0%|          | 0/10 [00:00<?, ? examples/s]

Map:   0%|          | 0/10 [00:00<?, ? examples/s]

Image 0: Bytes data, length: 608454, first 8 bytes: b'BM\xc6H\t\x00\x00\x00'
Image 1: Bytes data, length: 608454, first 8 bytes: b'BM\xc6H\t\x00\x00\x00'
Image 2: Bytes data, length: 608454, first 8 bytes: b'BM\xc6H\t\x00\x00\x00'
Image 3: Bytes data, length: 608454, first 8 bytes: b'BM\xc6H\t\x00\x00\x00'
Image 4: Bytes data, length: 608454, first 8 bytes: b'BM\xc6H\t\x00\x00\x00'
Image 5: Bytes data, length: 608454, first 8 bytes: b'BM\xc6H\t\x00\x00\x00'
Image 6: Bytes data, length: 608454, first 8 bytes: b'BM\xc6H\t\x00\x00\x00'
Image 7: Bytes data, length: 608454, first 8 bytes: b'BM\xc6H\t\x00\x00\x00'
Image 8: Bytes data, length: 608454, first 8 bytes: b'BM\xc6H\t\x00\x00\x00'
Image 9: Bytes data, length: 608454, first 8 bytes: b'BM\xc6H\t\x00\x00\x00'

Other features in the first example:
cell_count: 3
label: healthy
class_label: hem


In [62]:
print(f"cnmc_ds_dict['train'][0].keys(): {cnmc_ds_dict['train'][0].keys()}")
print(f"cnmc_ds_dict['train'][0]['image'].keys(): {cnmc_ds_dict['train'][0]['image'].keys()}")
print(f"cnmc_ds_dict['train'][0]['image']['bytes'][:8]: {cnmc_ds_dict['train'][0]['image']['bytes'][:8]}")

cnmc_ds_dict['train'][0].keys(): dict_keys(['cell_count', 'image', 'label', 'class_label'])
cnmc_ds_dict['train'][0]['image'].keys(): dict_keys(['bytes', 'path'])
cnmc_ds_dict['train'][0]['image']['bytes'][:8]: b'BM\xc6H\t\x00\x00\x00'


In [61]:
print(f"train_ds[0]['image']['bytes'][:8]: {train_ds[0]['image']['bytes'][:8]}")

train_ds[0]['image']['bytes'][:8]: b'BM\xc6H\t\x00\x00\x00'


In [81]:
from PIL import Image, UnidentifiedImageError
import io
from datasets import Dataset, DatasetDict

# Function to convert bytes to PIL Image with error handling
def bytes_to_pil(image_data):
    try:
        image_stream = io.BytesIO(image_data)
        image = Image.open(image_stream)
        image.verify()  # Verify if the image is not corrupted
        image_stream.seek(0)  # Reset the stream pointer to the beginning
        image = Image.open(image_stream)  # Reopen image after verification
        return image
    except (UnidentifiedImageError, IOError) as e:
        print(f"Error: {e}")
        return None

# Convert images in a Hugging Face Dataset
def convert_images_ds(ds):
    def process_image(example):
        image_data = example['image']
        if isinstance(image_data, dict) and 'bytes' in image_data:
            pil_image = bytes_to_pil(image_data['bytes'])
            if pil_image is not None:
                example['image'] = pil_image
                print(f"Converted image at index: {example['image']}")  # Debugging line
        return example

    # Use map with batched=False and load_from_cache_file=False to force re-processing
    ds = ds.map(process_image, batched=False, load_from_cache_file=False)
    return ds

In [82]:
# Assuming cnmc_ds_dict is already loaded with 'train', 'test', and 'validation' datasets
# and contains Hugging Face Datasets

# Random selection of 10 records from each dataset
train_ds = Dataset.from_dict(cnmc_ds_dict['train'].shuffle(seed=42)[:10])
test_ds = Dataset.from_dict(cnmc_ds_dict['test'].shuffle(seed=42)[:10])
validation_ds = Dataset.from_dict(cnmc_ds_dict['validation'].shuffle(seed=42)[:10])

# Convert images for each Dataset and save the new datasets
train_ds_converted = convert_images_ds(train_ds)
test_ds_converted = convert_images_ds(test_ds)
validation_ds_converted = convert_images_ds(validation_ds)

Map:   0%|          | 0/10 [00:00<?, ? examples/s]

Converted image at index: <PIL.BmpImagePlugin.BmpImageFile image mode=RGB size=450x450 at 0x790599118A90>
Converted image at index: <PIL.BmpImagePlugin.BmpImageFile image mode=RGB size=450x450 at 0x790599119840>
Converted image at index: <PIL.BmpImagePlugin.BmpImageFile image mode=RGB size=450x450 at 0x79059911B2E0>
Converted image at index: <PIL.BmpImagePlugin.BmpImageFile image mode=RGB size=450x450 at 0x79059911B550>
Converted image at index: <PIL.BmpImagePlugin.BmpImageFile image mode=RGB size=450x450 at 0x790599119870>
Converted image at index: <PIL.BmpImagePlugin.BmpImageFile image mode=RGB size=450x450 at 0x7905991197B0>
Converted image at index: <PIL.BmpImagePlugin.BmpImageFile image mode=RGB size=450x450 at 0x79059911BD60>
Converted image at index: <PIL.BmpImagePlugin.BmpImageFile image mode=RGB size=450x450 at 0x790599118700>
Converted image at index: <PIL.BmpImagePlugin.BmpImageFile image mode=RGB size=450x450 at 0x79059911BA60>
Converted image at index: <PIL.BmpImagePlugin.

Map:   0%|          | 0/10 [00:00<?, ? examples/s]

Converted image at index: <PIL.BmpImagePlugin.BmpImageFile image mode=RGB size=450x450 at 0x790599119810>
Converted image at index: <PIL.BmpImagePlugin.BmpImageFile image mode=RGB size=450x450 at 0x7907582D7490>
Converted image at index: <PIL.BmpImagePlugin.BmpImageFile image mode=RGB size=450x450 at 0x7907582D5B40>
Converted image at index: <PIL.BmpImagePlugin.BmpImageFile image mode=RGB size=450x450 at 0x7907582D4C70>
Converted image at index: <PIL.BmpImagePlugin.BmpImageFile image mode=RGB size=450x450 at 0x7907582D59F0>
Converted image at index: <PIL.BmpImagePlugin.BmpImageFile image mode=RGB size=450x450 at 0x7907582D4760>
Converted image at index: <PIL.BmpImagePlugin.BmpImageFile image mode=RGB size=450x450 at 0x7907582D5DB0>
Converted image at index: <PIL.BmpImagePlugin.BmpImageFile image mode=RGB size=450x450 at 0x7907582D49D0>
Converted image at index: <PIL.BmpImagePlugin.BmpImageFile image mode=RGB size=450x450 at 0x790599119FC0>
Converted image at index: <PIL.BmpImagePlugin.

Map:   0%|          | 0/10 [00:00<?, ? examples/s]

Converted image at index: <PIL.BmpImagePlugin.BmpImageFile image mode=RGB size=450x450 at 0x790599119B10>
Converted image at index: <PIL.BmpImagePlugin.BmpImageFile image mode=RGB size=450x450 at 0x7905992EDFC0>
Converted image at index: <PIL.BmpImagePlugin.BmpImageFile image mode=RGB size=450x450 at 0x7905992EE5C0>
Converted image at index: <PIL.BmpImagePlugin.BmpImageFile image mode=RGB size=450x450 at 0x7905992EE380>
Converted image at index: <PIL.BmpImagePlugin.BmpImageFile image mode=RGB size=450x450 at 0x7905992EE2F0>
Converted image at index: <PIL.BmpImagePlugin.BmpImageFile image mode=RGB size=450x450 at 0x7905992EEE00>
Converted image at index: <PIL.BmpImagePlugin.BmpImageFile image mode=RGB size=450x450 at 0x7905992EDC00>
Converted image at index: <PIL.BmpImagePlugin.BmpImageFile image mode=RGB size=450x450 at 0x7905992ECB20>
Converted image at index: <PIL.BmpImagePlugin.BmpImageFile image mode=RGB size=450x450 at 0x79059927E380>
Converted image at index: <PIL.BmpImagePlugin.

In [83]:
# Debugging: Check the first element in the new dataset to see if the image was converted
print(f"Train dataset first image type: {type(train_ds_converted[0]['image'])}")

# Verify the conversion by checking if an example image is of PIL.Image.Image type
if isinstance(train_ds_converted[0]['image'], Image.Image):
    print("The conversion to PIL.Image.Image was successful.")
else:
    print("The conversion to PIL.Image.Image failed.")

Train dataset first image type: <class 'dict'>
The conversion to PIL.Image.Image failed.


In [77]:
print(f"train_ds_converted[0].keys():\n{train_ds_converted[0].keys()}")

print(f"\ntrain_ds_converted[0]['image'].keys():\n{train_ds_converted[0]['image'].keys()}")

print(f"\ntrain_ds_converted[0]['image']['bytes'][:8]:\n{train_ds_converted[0]['image']['bytes'][:8]}")


train_ds_converted[0].keys():
dict_keys(['cell_count', 'image', 'label', 'class_label'])

train_ds_converted[0]['image'].keys():
dict_keys(['bytes', 'path'])

train_ds_converted[0]['image']['bytes'][:8]:
b'BM\xc6H\t\x00\x00\x00'


In [84]:
from PIL import Image, UnidentifiedImageError
import io
from datasets import Dataset, DatasetDict

# Function to convert bytes to PIL Image with error handling
def bytes_to_pil(image_data):
    try:
        image_stream = io.BytesIO(image_data)
        image = Image.open(image_stream)
        image.verify()  # Verify if the image is not corrupted
        image_stream.seek(0)  # Reset the stream pointer to the beginning
        image = Image.open(image_stream)  # Reopen image after verification
        return image
    except (UnidentifiedImageError, IOError) as e:
        print(f"Error: {e}")
        return None

# Convert images in a Hugging Face Dataset
def convert_images_ds(ds):
    def process_image(example):
        image_data = example['image']
        if isinstance(image_data, dict) and 'bytes' in image_data:
            pil_image = bytes_to_pil(image_data['bytes'])
            if pil_image is not None:
                example['image'] = pil_image
                print(f"Converted image at index: {example['image']}")  # Debugging line
        return example

    # Use map with batched=False and load_from_cache_file=False to force re-processing
    ds = ds.map(process_image, batched=False, load_from_cache_file=False)
    return ds

# Assuming cnmc_ds_dict is already loaded with 'train', 'test', and 'validation' datasets
# and contains Hugging Face Datasets

# Random selection of 10 records from each dataset
train_ds = Dataset.from_dict(cnmc_ds_dict['train'].shuffle(seed=42)[:10])
test_ds = Dataset.from_dict(cnmc_ds_dict['test'].shuffle(seed=42)[:10])
validation_ds = Dataset.from_dict(cnmc_ds_dict['validation'].shuffle(seed=42)[:10])

# Convert images for each Dataset and save the new datasets
train_ds_converted = convert_images_ds(train_ds)
test_ds_converted = convert_images_ds(test_ds)
validation_ds_converted = convert_images_ds(validation_ds)

# Debugging: Check the first element in the new dataset to see if the image was converted
print(f"Train dataset first image type: {type(train_ds_converted[0]['image'])}")

# Verify the conversion by checking if an example image is of PIL.Image.Image type
if isinstance(train_ds_converted[0]['image'], Image.Image):
    print("The conversion to PIL.Image.Image was successful.")
else:
    print("The conversion to PIL.Image.Image failed.")


Map:   0%|          | 0/10 [00:00<?, ? examples/s]

Converted image at index: <PIL.BmpImagePlugin.BmpImageFile image mode=RGB size=450x450 at 0x790599159720>
Converted image at index: <PIL.BmpImagePlugin.BmpImageFile image mode=RGB size=450x450 at 0x7905992ECBB0>
Converted image at index: <PIL.BmpImagePlugin.BmpImageFile image mode=RGB size=450x450 at 0x79059927FCA0>
Converted image at index: <PIL.BmpImagePlugin.BmpImageFile image mode=RGB size=450x450 at 0x7906623E3C40>
Converted image at index: <PIL.BmpImagePlugin.BmpImageFile image mode=RGB size=450x450 at 0x79066245C460>
Converted image at index: <PIL.BmpImagePlugin.BmpImageFile image mode=RGB size=450x450 at 0x7905991F5E10>
Converted image at index: <PIL.BmpImagePlugin.BmpImageFile image mode=RGB size=450x450 at 0x7907816ED900>
Converted image at index: <PIL.BmpImagePlugin.BmpImageFile image mode=RGB size=450x450 at 0x7907582D4820>
Converted image at index: <PIL.BmpImagePlugin.BmpImageFile image mode=RGB size=450x450 at 0x7907582D4790>
Converted image at index: <PIL.BmpImagePlugin.

Map:   0%|          | 0/10 [00:00<?, ? examples/s]

Converted image at index: <PIL.BmpImagePlugin.BmpImageFile image mode=RGB size=450x450 at 0x79059935D3F0>
Converted image at index: <PIL.BmpImagePlugin.BmpImageFile image mode=RGB size=450x450 at 0x79059935F220>
Converted image at index: <PIL.BmpImagePlugin.BmpImageFile image mode=RGB size=450x450 at 0x79059935E410>
Converted image at index: <PIL.BmpImagePlugin.BmpImageFile image mode=RGB size=450x450 at 0x79059935E920>
Converted image at index: <PIL.BmpImagePlugin.BmpImageFile image mode=RGB size=450x450 at 0x79059935E140>
Converted image at index: <PIL.BmpImagePlugin.BmpImageFile image mode=RGB size=450x450 at 0x79059935C580>
Converted image at index: <PIL.BmpImagePlugin.BmpImageFile image mode=RGB size=450x450 at 0x7907582D7520>
Converted image at index: <PIL.BmpImagePlugin.BmpImageFile image mode=RGB size=450x450 at 0x79059935E950>
Converted image at index: <PIL.BmpImagePlugin.BmpImageFile image mode=RGB size=450x450 at 0x79059935E290>
Converted image at index: <PIL.BmpImagePlugin.

Map:   0%|          | 0/10 [00:00<?, ? examples/s]

Converted image at index: <PIL.BmpImagePlugin.BmpImageFile image mode=RGB size=450x450 at 0x79059935CDF0>
Converted image at index: <PIL.BmpImagePlugin.BmpImageFile image mode=RGB size=450x450 at 0x790599118E50>
Converted image at index: <PIL.BmpImagePlugin.BmpImageFile image mode=RGB size=450x450 at 0x790599119E10>
Converted image at index: <PIL.BmpImagePlugin.BmpImageFile image mode=RGB size=450x450 at 0x79059911AB30>
Converted image at index: <PIL.BmpImagePlugin.BmpImageFile image mode=RGB size=450x450 at 0x79059911B880>
Converted image at index: <PIL.BmpImagePlugin.BmpImageFile image mode=RGB size=450x450 at 0x7905991182B0>
Converted image at index: <PIL.BmpImagePlugin.BmpImageFile image mode=RGB size=450x450 at 0x79059911AEC0>
Converted image at index: <PIL.BmpImagePlugin.BmpImageFile image mode=RGB size=450x450 at 0x7905991187C0>
Converted image at index: <PIL.BmpImagePlugin.BmpImageFile image mode=RGB size=450x450 at 0x790599118940>
Converted image at index: <PIL.BmpImagePlugin.

In [85]:
"""
Objective: Convert BMP image bytes in the CNMC dataset to PIL Image objects
while preserving the rest of the dataset structure. This prepares the dataset
for use with vision-language models like PaliGemma.

Steps:
1. Load a sample of the original dataset
2. Convert the 'image' field from bytes to PIL Image objects
3. Preserve all other fields in the dataset
4. Provide debugging information about the conversion process
"""

from PIL import Image
import io
from datasets import Dataset

def bytes_to_pil(image_bytes):
    """
    Convert image bytes to a PIL Image object.

    Args:
    image_bytes (bytes): Raw bytes of a BMP image.

    Returns:
    PIL.Image.Image or None: A PIL Image object if conversion is successful, None otherwise.
    """
    try:
        image = Image.open(io.BytesIO(image_bytes))
        image.load()  # Ensure the image data is loaded
        return image
    except Exception as e:
        print(f"Error converting BMP image: {e}")
        return None

def convert_image_field(example):
    """
    Convert the 'image' field in a dataset example from bytes to PIL Image.

    Args:
    example (dict): A single example from the dataset.

    Returns:
    dict: The same example with the 'image' field converted to PIL Image if successful.
    """
    image_data = example['image']
    if isinstance(image_data, dict) and 'bytes' in image_data:
        pil_image = bytes_to_pil(image_data['bytes'])
        if pil_image is not None:
            example['image'] = pil_image
        else:
            print("Failed to convert BMP to PIL Image")
    else:
        print(f"Unexpected image data format: {type(image_data)}")
    return example

def convert_dataset_images(ds):
    """
    Convert all images in a dataset from bytes to PIL Image objects.

    Args:
    ds (datasets.Dataset): The input dataset with image bytes.

    Returns:
    datasets.Dataset: A new dataset with converted PIL Image objects.
    """
    return ds.map(convert_image_field, batched=False, load_from_cache_file=False)

def sample_and_convert_datasets(cnmc_ds_dict, sample_size=10):
    """
    Sample and convert images for all splits in the CNMC dataset.

    Args:
    cnmc_ds_dict (dict): Dictionary containing train, test, and validation datasets.
    sample_size (int): Number of examples to sample from each split.

    Returns:
    dict: Dictionary containing converted datasets for each split.
    """
    converted_ds_dict = {}
    for split, dataset in cnmc_ds_dict.items():
        sampled_ds = dataset.shuffle(seed=42).select(range(sample_size))
        converted_ds = convert_dataset_images(sampled_ds)
        converted_ds_dict[split] = converted_ds
    return converted_ds_dict

def print_dataset_info(dataset, split_name):
    """
    Print information about a dataset split, including image and other feature details.

    Args:
    dataset (datasets.Dataset): The dataset to inspect.
    split_name (str): The name of the dataset split (e.g., 'train', 'test').
    """
    print(f"\n{split_name} Dataset Information:")
    for i, example in enumerate(dataset):
        image = example['image']
        if isinstance(image, Image.Image):
            print(f"Image {i}: PIL Image, mode: {image.mode}, size: {image.size}")
        elif isinstance(image, dict) and 'bytes' in image:
            print(f"Image {i}: Bytes data, length: {len(image['bytes'])}, first 8 bytes: {image['bytes'][:8]}")
        else:
            print(f"Image {i}: Unexpected type: {type(image)}")

    print("\nOther features in the first example:")
    for key, value in dataset[0].items():
        if key != 'image':
            print(f"{key}: {value}")

# Main execution
if __name__ == "__main__":
    # Assume cnmc_ds_dict is already loaded
    converted_datasets = sample_and_convert_datasets(cnmc_ds_dict)

    for split_name, dataset in converted_datasets.items():
        print_dataset_info(dataset, split_name)

Map:   0%|          | 0/10 [00:00<?, ? examples/s]

Map:   0%|          | 0/10 [00:00<?, ? examples/s]

Map:   0%|          | 0/10 [00:00<?, ? examples/s]


train Dataset Information:
Image 0: Bytes data, length: 608454, first 8 bytes: b'BM\xc6H\t\x00\x00\x00'
Image 1: Bytes data, length: 608454, first 8 bytes: b'BM\xc6H\t\x00\x00\x00'
Image 2: Bytes data, length: 608454, first 8 bytes: b'BM\xc6H\t\x00\x00\x00'
Image 3: Bytes data, length: 608454, first 8 bytes: b'BM\xc6H\t\x00\x00\x00'
Image 4: Bytes data, length: 608454, first 8 bytes: b'BM\xc6H\t\x00\x00\x00'
Image 5: Bytes data, length: 608454, first 8 bytes: b'BM\xc6H\t\x00\x00\x00'
Image 6: Bytes data, length: 608454, first 8 bytes: b'BM\xc6H\t\x00\x00\x00'
Image 7: Bytes data, length: 608454, first 8 bytes: b'BM\xc6H\t\x00\x00\x00'
Image 8: Bytes data, length: 608454, first 8 bytes: b'BM\xc6H\t\x00\x00\x00'
Image 9: Bytes data, length: 608454, first 8 bytes: b'BM\xc6H\t\x00\x00\x00'

Other features in the first example:
cell_count: 3
label: healthy
class_label: hem

test Dataset Information:
Image 0: Bytes data, length: 608454, first 8 bytes: b'BM\xc6H\t\x00\x00\x00'
Image 1: Bytes

In [88]:
from PIL import Image, UnidentifiedImageError
import io
from datasets import Dataset

# Function to convert bytes to PIL Image with error handling
def bytes_to_pil(image_data):
    try:
        image_stream = io.BytesIO(image_data)
        image = Image.open(image_stream)
        image.verify()
        image_stream.seek(0)
        image = Image.open(image_stream)
        return image
    except (UnidentifiedImageError, IOError) as e:
        print(f"Error: {e}")
        return None

# Convert images in a Hugging Face Dataset using with_indices
def convert_images_ds(ds):
    def process_image(example, idx):
        image_data = example['image']
        if isinstance(image_data, dict) and 'bytes' in image_data:
            pil_image = bytes_to_pil(image_data['bytes'])
            if pil_image is not None:
                ds[idx]['image'] = pil_image
        return example

    ds = ds.map(process_image, with_indices=True, batched=False, load_from_cache_file=False)
    return ds

# Assuming cnmc_ds_dict is already loaded with 'train', 'test', and 'validation' datasets
train_ds = Dataset.from_dict(cnmc_ds_dict['train'].shuffle(seed=42)[:10])
test_ds = Dataset.from_dict(cnmc_ds_dict['test'].shuffle(seed=42)[:10])
validation_ds = Dataset.from_dict(cnmc_ds_dict['validation'].shuffle(seed=42)[:10])

# Convert images for each Dataset and save the new datasets
train_ds_converted = convert_images_ds(train_ds)
test_ds_converted = convert_images_ds(test_ds)
validation_ds_converted = convert_images_ds(validation_ds)


Map:   0%|          | 0/10 [00:00<?, ? examples/s]

Map:   0%|          | 0/10 [00:00<?, ? examples/s]

Map:   0%|          | 0/10 [00:00<?, ? examples/s]

In [89]:
# Debugging: Check the first element in the new dataset to see if the image was converted
print(f"Train dataset first image type: {type(train_ds_converted[0]['image'])}")

# Verify the conversion by checking if an example image is of PIL.Image.Image type
if isinstance(train_ds_converted[0]['image'], Image.Image):
    print("The conversion to PIL.Image.Image was successful.")
else:
    print("The conversion to PIL.Image.Image failed.")

Train dataset first image type: <class 'dict'>
The conversion to PIL.Image.Image failed.


In [90]:
from PIL import Image, UnidentifiedImageError
import io
from datasets import Dataset

# Function to convert bytes to PIL Image with error handling
def bytes_to_pil(image_data):
    try:
        image_stream = io.BytesIO(image_data)
        image = Image.open(image_stream)
        image.verify()
        image_stream.seek(0)
        image = Image.open(image_stream)
        return image
    except (UnidentifiedImageError, IOError) as e:
        print(f"Error: {e}")
        return None

# Convert images in a Hugging Face Dataset
def convert_images_ds(ds):
    def process_image(example):
        image_data = example['image']
        if isinstance(image_data, dict) and 'bytes' in image_data:
            pil_image = bytes_to_pil(image_data['bytes'])
            if pil_image is not None:
                example['image'] = pil_image
        return example

    ds = ds.map(process_image, batched=False, load_from_cache_file=False)
    return ds

# Assuming cnmc_ds_dict is already loaded with 'train', 'test', and 'validation' datasets
train_ds = Dataset.from_dict(cnmc_ds_dict['train'].shuffle(seed=42)[:10])
test_ds = Dataset.from_dict(cnmc_ds_dict['test'].shuffle(seed=42)[:10])
validation_ds = Dataset.from_dict(cnmc_ds_dict['validation'].shuffle(seed=42)[:10])

# Convert images for each Dataset and save the new datasets
train_ds_converted = convert_images_ds(train_ds)
test_ds_converted = convert_images_ds(test_ds)
validation_ds_converted = convert_images_ds(validation_ds)

# Debugging: Check the first element in the new dataset to see if the image was converted
print(f"Train dataset first image type: {type(train_ds_converted[0]['image'])}")

# Verify the conversion by checking if an example image is of PIL.Image.Image type
if isinstance(train_ds_converted[0]['image'], Image.Image):
    print("The conversion to PIL.Image.Image was successful.")
else:
    print("The conversion to PIL.Image.Image failed.")


Map:   0%|          | 0/10 [00:00<?, ? examples/s]

Map:   0%|          | 0/10 [00:00<?, ? examples/s]

Map:   0%|          | 0/10 [00:00<?, ? examples/s]

Train dataset first image type: <class 'dict'>
The conversion to PIL.Image.Image failed.


## Collate Data

In [None]:
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained(base_model_name)

In [None]:
max_seq_length = 128
output_dir = adapter_version

In [None]:
from PIL import Image
import io

def collate_fn(examples):
    texts = ["Are these cells healthy or cancerous?" for _ in range(len(examples))]
    labels = [example['label'] for example in examples]
    images = [Image.open(io.BytesIO(example['image']['bytes'])).convert("RGB") for example in examples]
    tokens = processor(text=texts, images=images, return_tensors="pt", padding="longest")
    tokens["labels"] = processor.tokenizer(labels, padding="longest", return_tensors="pt").input_ids
    tokens = tokens.to(torch.bfloat16).to("cuda")
    return tokens

## Load and Quatize the base Model (bitsandbytes)

In [None]:
import torch

from transformers import PaliGemmaForConditionalGeneration, BitsAndBytesConfig
from peft import get_peft_model, LoraConfig

bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_type=torch.bfloat16
)

lora_config = LoraConfig(
    r=8,
    target_modules=["q_proj", "o_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj"],
    task_type="CAUSAL_LM",
)
model = PaliGemmaForConditionalGeneration.from_pretrained(base_model_name, quantization_config=bnb_config, device_map={"":0})
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
#trainable params: 11,298,816 || all params: 2,934,634,224 || trainable%: 0.38501616002417344


## Train the Adapter Model (trl)

In [None]:
model

In [None]:
model.config

In [None]:
from transformers import TrainingArguments

args = TrainingArguments(
    num_train_epochs=1,  # Reduced to 1 for quicker demonstration
    remove_unused_columns=False,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=2,  # Reduced to speed up training
    warmup_steps=2,
    learning_rate=2e-5,
    weight_decay=1e-6,
    adam_beta2=0.999,
    logging_steps=50,  # More frequent logging
    optim="adamw_hf",
    save_strategy="epoch",  # Set to epoch to match evaluation strategy
    push_to_hub=False,  # Disable pushing to hub for this demo
    output_dir="paligemma_vqav2",
    bf16=True,
    report_to=["tensorboard"],
    dataloader_pin_memory=False,
    load_best_model_at_end=True,  # Required for EarlyStoppingCallback
    evaluation_strategy="epoch",  # Set to epoch for periodic evaluation
)


In [None]:
from transformers import Trainer, EarlyStoppingCallback

# Define EarlyStoppingCallback
early_stopping = EarlyStoppingCallback(
    early_stopping_patience=1,
    early_stopping_threshold=0.01,
)

# Define Trainer with EarlyStoppingCallback
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=small_train_ds,
    eval_dataset=small_validation_ds,
    data_collator=collate_fn,
    callbacks=[early_stopping]
)


In [None]:
# Start fine-tuning
trainer.train()