## PaliGemma Fine-tuning
Fine-tune a pretrained PaliGemma model to build a classifier for images generated with deepfake techniques.

### Environment Setup

In [None]:
!pip install torch
!pip install transformers
!pip install peft
!pip install trl
!pip install -U bitsandbytes
!pip install datasets
!pip install accelerate

### Hugging Face Login

- Grants read access to the PaliGemma model
- Grants the logged-in account permission to upload the fine-tuned model

In [None]:
from huggingface_hub import notebook_login
notebook_login()

### Training Preparation
Load the training dataset.

In [None]:
from datasets import load_dataset

ds = load_dataset("JamieWithofs/Deepfake-and-real-images")

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


Generating train split:   0%|          | 0/140002 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/10905 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/39428 [00:00<?, ? examples/s]

Inspect the dataset structure.

In [None]:
ds['train']

Dataset({
    features: ['image', 'label'],
    num_rows: 140002
})

In [None]:
# The 'test' split (10,905 rows) is used as the training set here
train_ds = ds['test']

In [None]:
train_ds

Dataset({
    features: ['image', 'label'],
    num_rows: 10905
})

In [None]:
# Attach the same question to every example
question_make = ['Is this image made by AI?'] * len(train_ds)
train_ds = train_ds.add_column("question", question_make)

In [None]:
train_ds

Dataset({
    features: ['image', 'label', 'question'],
    num_rows: 10905
})

`PaliGemmaProcessor` is the processor used with the PaliGemma model; it preprocesses the model's inputs and post-processes its outputs.

In [None]:
from transformers import PaliGemmaProcessor
model_id = "google/paligemma-3b-pt-224"
processor = PaliGemmaProcessor.from_pretrained(model_id)

The cache for model files in Transformers v4.22.0 has been updated. Migrating your old cache. This is a one-time only operation. You can interrupt this and resume the migration later on by calling `transformers.utils.move_cache()`.


In [None]:
import torch
device = "cuda"

image_token = processor.tokenizer.convert_tokens_to_ids("<image>")  # ID of the <image> placeholder token
bos_token = processor.tokenizer.bos_token  # BOS token string

def collate_fn(examples):
    # Prompt (prefix): image placeholder, BOS token, then the question
    texts = [f"<image> {bos_token} answer {example['question']}" for example in examples]
    # Target (suffix): the class label as a string
    labels = [str(example['label']) for example in examples]
    images = [example["image"].convert("RGB") for example in examples]
    tokens = processor(
        text=texts,
        images=images,
        suffix=labels,
        return_tensors="pt",
        padding="longest",
    )
    # Cast floating-point tensors to bfloat16 and move the batch to the GPU
    tokens = tokens.to(torch.bfloat16).to(device)
    return tokens
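
To make the text side of `collate_fn` concrete, the following minimal sketch rebuilds the prompt and suffix strings for a toy batch without calling the processor. It assumes the tokenizer's BOS token is the string `<bos>` (typical for Gemma tokenizers, but not confirmed in this notebook); `build_text_pairs` is a hypothetical helper introduced only for illustration.

```python
# Assumption: processor.tokenizer.bos_token evaluates to "<bos>".
BOS_TOKEN = "<bos>"


def build_text_pairs(examples, bos_token=BOS_TOKEN):
    """Return (prompts, suffixes) the same way collate_fn builds them."""
    prompts = [f"<image> {bos_token} answer {ex['question']}" for ex in examples]
    suffixes = [str(ex["label"]) for ex in examples]
    return prompts, suffixes


# Toy batch: only the text fields are needed for this sketch.
toy_batch = [
    {"question": "Is this image made by AI?", "label": 0},
    {"question": "Is this image made by AI?", "label": 1},
]

prompts, suffixes = build_text_pairs(toy_batch)
print(prompts[0])
print(suffixes)
```

The suffix is what the model learns to generate, so after fine-tuning the expected answer is the label digit as text rather than a word such as "yes" or "no".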


In [None]:
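from transformers import PaliGemmaForConditionalGeneration
import torch

# Load the base model in bfloat16.
# Note: `model` is reassigned by the 4-bit load in the QLoRA cell further down,
# so this copy and its frozen flags are not the ones that end up being trained.
model = PaliGemmaForConditionalGeneration.from_pretrained(model_id, torch_dtype=torch.bfloat16).to(device)

# Freeze the vision encoder
for param in model.vision_tower.parameters():
    param.requires_grad = False

# Freeze the multimodal projector
for param in model.multi_modal_projector.parameters():
    param.requires_grad = False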


`config.hidden_act` is ignored, you should use `config.hidden_activation` instead.
Gemma's activation function will be set to `gelu_pytorch_tanh`. Please, use
`config.hidden_activation` if you want to override this behaviour.
See https://github.com/huggingface/transformers/pull/29402 for more details.


Apply QLoRA.

In [None]:
from transformers import BitsAndBytesConfig
from peft import get_peft_model, LoraConfig

# 4-bit NF4 quantization with bfloat16 compute
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

lora_config = LoraConfig(
    r=8,
    target_modules=["q_proj", "o_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj"],
    task_type="CAUSAL_LM",
)
# Reload the model with 4-bit quantization (replaces the bfloat16 model loaded earlier)
model = PaliGemmaForConditionalGeneration.from_pretrained(model_id, quantization_config=bnb_config, device_map="auto")
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()


trainable params: 11,298,816 || all params: 2,934,765,296 || trainable%: 0.3850
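
The trainable-parameter count printed above can be reproduced by hand, which also shows where the adapters sit. The sketch below is a back-of-the-envelope check, not output from the notebook: the layer counts and dimensions are assumptions about the Gemma-2B decoder and the SigLIP vision encoder inside `paligemma-3b-pt-224`, and `lora_params` is a hypothetical helper. LoRA with rank `r` adds `r * (in_features + out_features)` weights to each wrapped linear layer.

```python
LORA_RANK = 8  # r in LoraConfig


def lora_params(in_features, out_features, rank=LORA_RANK):
    """Weights added by one LoRA adapter: A is (rank x in), B is (out x rank)."""
    return rank * (in_features + out_features)


# Assumed Gemma-2B decoder shapes (not stated in this notebook).
LM_LAYERS, LM_HIDDEN, LM_KV_DIM, LM_MLP = 18, 2048, 256, 16384
lm_per_layer = (
    lora_params(LM_HIDDEN, LM_HIDDEN)    # q_proj
    + lora_params(LM_HIDDEN, LM_KV_DIM)  # k_proj
    + lora_params(LM_HIDDEN, LM_KV_DIM)  # v_proj
    + lora_params(LM_HIDDEN, LM_HIDDEN)  # o_proj
    + lora_params(LM_HIDDEN, LM_MLP)     # gate_proj
    + lora_params(LM_HIDDEN, LM_MLP)     # up_proj
    + lora_params(LM_MLP, LM_HIDDEN)     # down_proj
)

# Assumed SigLIP encoder shapes; only q_proj/k_proj/v_proj match target_modules there.
VIT_LAYERS, VIT_HIDDEN = 27, 1152
vit_per_layer = 3 * lora_params(VIT_HIDDEN, VIT_HIDDEN)

trainable = LM_LAYERS * lm_per_layer + VIT_LAYERS * vit_per_layer
total = 2_934_765_296  # "all params" from the printed output

print(trainable)
print(f"{100 * trainable / total:.4f}")
```

Under these assumed shapes the sum matches the printed 11,298,816 exactly, which suggests that the `q_proj`, `k_proj`, and `v_proj` names in `target_modules` also attach adapters to the vision tower's attention layers, not only to the language model.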


Use the `TrainingArguments` class to set the hyperparameters for training.

In [None]:
from transformers import TrainingArguments

args = TrainingArguments(
    num_train_epochs=10,
    remove_unused_columns=False,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    warmup_steps=2,
    learning_rate=2e-5,
    weight_decay=1e-6,
    adam_beta2=0.999,
    logging_steps=100,
    optim="adamw_hf",
    save_strategy="steps",
    push_to_hub=True,
    save_steps=1000,
    save_total_limit=1,
    output_dir="paligemma_deepfake_2024",
    bf16=True,
    dataloader_pin_memory=False,
)
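
The batch-size and epoch settings above determine how many optimizer steps the run takes. The following sketch works through that arithmetic for the 10,905-row training set. It assumes a single GPU and that the `Trainer` floors the number of update steps per epoch (the behavior of the library version this notebook appears to use; newer versions may round differently).

```python
import math

num_examples = 10_905        # rows in train_ds
per_device_batch = 4         # per_device_train_batch_size
grad_accum = 4               # gradient_accumulation_steps
epochs = 10                  # num_train_epochs
num_devices = 1              # assumption: single GPU

# Examples consumed per optimizer update
effective_batch = per_device_batch * grad_accum * num_devices

# Mini-batches per epoch (the last one is partial)
batches_per_epoch = math.ceil(num_examples / (per_device_batch * num_devices))

# Assumption: update steps per epoch are floored, with a minimum of 1
updates_per_epoch = max(batches_per_epoch // grad_accum, 1)
total_updates = updates_per_epoch * epochs

print(effective_batch, batches_per_epoch, updates_per_epoch, total_updates)
```

With these settings the estimate comes to 6,810 update steps, which agrees with the `global_step` reported at the end of the training run.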


Run training.

In [None]:
from transformers import Trainer

trainer = Trainer(
    model=model,
    train_dataset=train_ds,
    data_collator=collate_fn,
    args=args,
)


In [None]:
trainer.train()

| Step | Training Loss |
|------|---------------|
| 100  | 1.984         |
| 200  | 0.6179        |
| 300  | 0.5712        |
| 400  | 0.5345        |
| 500  | 0.4503        |
| 600  | 0.4183        |
| 700  | 0.3393        |
| 800  | 0.3083        |
| 900  | 0.2702        |
| 1000 | 0.2895        |


TrainOutput(global_step=6810, training_loss=0.17923424933625387, metrics={'train_runtime': 12131.4565, 'train_samples_per_second': 8.989, 'train_steps_per_second': 0.561, 'total_flos': 4.218461322778829e+17, 'train_loss': 0.17923424933625387, 'epoch': 9.995599559955995})
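
As a quick sanity check, the throughput figures in `TrainOutput` can be recomputed from the runtime, the step count, and the dataset size. This is a minimal sketch that assumes `train_samples_per_second` is defined as (examples × epochs) / runtime; the numbers are copied from the output above.

```python
train_runtime = 12131.4565   # seconds, from TrainOutput
global_step = 6810           # from TrainOutput
num_examples = 10_905        # rows in train_ds
epochs = 10                  # num_train_epochs

hours = train_runtime / 3600
# Assumption: samples/s counts every example once per epoch
samples_per_second = num_examples * epochs / train_runtime
steps_per_second = global_step / train_runtime

print(round(hours, 2))
print(round(samples_per_second, 3))
print(round(steps_per_second, 3))
```

The recomputed values line up with the reported 8.989 samples per second and 0.561 steps per second, and put the whole run at a little under three and a half hours.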