In [1]:
# ============================================
# Fine-tune BLIP-base on anime-captions (Kaggle, Streaming)
# Handles NoneType + dynamic padding
# ============================================


# !pip install transformers datasets --quiet
import torch
from torch.nn.utils.rnn import pad_sequence
from transformers import BlipProcessor, BlipForConditionalGeneration, Trainer, TrainingArguments
from datasets import load_dataset,IterableDataset
from itertools import islice
from PIL import Image

# --------------------------------------------
# Check GPU
# --------------------------------------------
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device:", device)

# --------------------------------------------
# Load BLIP-base
# --------------------------------------------
model_name = "Salesforce/blip-image-captioning-base"
processor = BlipProcessor.from_pretrained(model_name)
model = BlipForConditionalGeneration.from_pretrained(model_name).to(device)

# --------------------------------------------
# Training arguments
# --------------------------------------------
training_args = TrainingArguments(
    output_dir="./blip-anime",
    per_device_train_batch_size=4,   # T4/P100 safe
    fp16=True,                       # mixed precision
    save_strategy="steps",
    save_steps=10000,
    save_total_limit=2, # save every ~2000 steps
    logging_steps=500,
    max_steps=40000,                 # ~6–7h on Kaggle
    report_to="none",
    remove_unused_columns=False
)

# --------------------------------------------
# Preprocessing function
# --------------------------------------------
def preprocess(example):
    try:
        image = example["image"]
        if isinstance(image, str):
            image = Image.open(image).convert("RGB")
        else:
            image = image.convert("RGB")

        inputs = processor(images=image, text=example["text"], return_tensors="pt")

        return {
            "input_ids": inputs["input_ids"][0],
            "attention_mask": inputs["attention_mask"][0],
            "pixel_values": inputs["pixel_values"][0],
        }
    except Exception:
        return {}  # return empty dict, will be dropped later

# --------------------------------------------
# Collator (pads captions dynamically)
# --------------------------------------------
def collate_fn(batch):
    batch = [x for x in batch if x]  # drop empties

    input_ids = [x["input_ids"] for x in batch]
    attention_masks = [x["attention_mask"] for x in batch]
    pixel_values = torch.stack([x["pixel_values"] for x in batch])

    input_ids = pad_sequence(input_ids, batch_first=True, padding_value=processor.tokenizer.pad_token_id)
    attention_masks = pad_sequence(attention_masks, batch_first=True, padding_value=0)

    return {
        "input_ids": input_ids,
        "attention_mask": attention_masks,
        "pixel_values": pixel_values,
        "labels": input_ids.clone(),  # teacher forcing
    }

# --------------------------------------------
# Load dataset in streaming mode
# --------------------------------------------
dataset = load_dataset("none-yet/anime-captions", split="train", streaming=True)
total = 337000  # replace with your dataset size
half = total // 2
half1 = IterableDataset.from_generator(
    lambda: islice(load_dataset("none-yet/anime-captions", split="train", streaming=True), 0, half)
)

half2 = IterableDataset.from_generator(
    lambda: islice(load_dataset("none-yet/anime-captions", split="train", streaming=True), half, total)
)
# Map preprocessing lazily (streaming-safe)
half1 = half1.map(preprocess)

# --------------------------------------------
# Setup Trainer
# --------------------------------------------
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=half1,
    data_collator=collate_fn,
)

# --------------------------------------------
# Train
# --------------------------------------------
trainer.train()

# --------------------------------------------
# Save final model
# --------------------------------------------
model.save_pretrained("./blip-anime-half1-final")
processor.save_pretrained("./blip-anime-half1-final")

print("\n✅ Fine-tuning complete! Final model saved at ./blip-anime-half1-final")


2025-09-17 06:15:07.988373: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1758089708.343792      19 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1758089708.445334      19 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


Using device: cuda


preprocessor_config.json:   0%|          | 0.00/287 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/506 [00:00<?, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/125 [00:00<?, ?B/s]

config.json: 0.00B [00:00, ?B/s]

pytorch_model.bin:   0%|          | 0.00/990M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/990M [00:00<?, ?B/s]

README.md:   0%|          | 0.00/391 [00:00<?, ?B/s]

Resolving data files:   0%|          | 0/57 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/57 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/57 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/57 [00:00<?, ?it/s]



Step,Training Loss
500,1.0654
1000,0.8179
1500,0.7817
2000,0.7559
2500,0.7405
3000,0.7307
3500,0.7161
4000,0.7009
4500,0.6929
5000,0.6889




Resolving data files:   0%|          | 0/57 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/57 [00:00<?, ?it/s]




✅ Fine-tuning complete! Final model saved at ./blip-anime-half1-final
