# ファインチューニング

In [None]:
import os
from dataclasses import dataclass
from functools import partial

import albumentations
import numpy as np
import torch
import wandb
from datasets import load_from_disk
from PIL import Image, ImageDraw
from torchmetrics.detection.mean_ap import MeanAveragePrecision
from transformers import (
    AutoImageProcessor,
    AutoModelForObjectDetection,
    Trainer,
    TrainingArguments,
)
from transformers.image_transforms import center_to_corners_format

## 設定

In [None]:
# W&B settings
# wandb_project_name is exported as WANDB_PROJECT before the run starts.
wandb_project_name = "BCCD"
# Name of a W&B model artifact to start from; an empty string skips the
# artifact download.
artifact_name = ""

# Local settings
# When True (and artifact_name is empty) the checkpoint is read from
# output_dir + "/model" instead of the pretrained hub checkpoint.
use_local_model = False
dataset_path = "BCCD_dataset"
output_dir = "BCCD_output"

# Training settings (passed to TrainingArguments)
num_train_epochs = 1
per_device_train_batch_size = 4

## データ & モデルロード

In [None]:
# Class-id <-> class-name maps for the three BCCD categories.
id2label = dict(enumerate(("RBC", "WBC", "Platelets")))
label2id = {name: class_id for class_id, name in id2label.items()}

# Dataset previously saved to disk at dataset_path.
dataset = load_from_disk(dataset_path)

In [None]:
# W&B reads its configuration from the environment, so set it before the run
# is created. NOTE(review): per the transformers W&B integration,
# WANDB_LOG_MODEL="end" uploads the model when training finishes - confirm
# against the installed version.
os.environ.update(
    {
        "WANDB_PROJECT": wandb_project_name,
        "WANDB_LOG_MODEL": "end",
    }
)
run = wandb.init()

In [None]:
# Pick the checkpoint, in priority order:
#   1. a W&B model artifact (downloaded locally),
#   2. the model previously saved under output_dir,
#   3. the pretrained DETR checkpoint.
if artifact_name != "":
    artifact = run.use_artifact(artifact_name, type="model")
    checkpoint = artifact.download()
elif use_local_model:
    checkpoint = f"{output_dir}/model"
else:
    checkpoint = "facebook/detr-resnet-50"

# Processor and model are loaded from the same checkpoint. The label maps are
# passed explicitly, and ignore_mismatched_sizes=True is set so loading does
# not fail when checkpoint weight shapes differ from the configured model.
image_processor = AutoImageProcessor.from_pretrained(checkpoint)
model = AutoModelForObjectDetection.from_pretrained(
    checkpoint,
    ignore_mismatched_sizes=True,
    label2id=label2id,
    id2label=id2label,
)

## 前処理準備

In [None]:
def format_image_annotations_as_coco(image_id, categories, areas, bboxes):
    """Build the COCO-style annotation record for a single image.

    Args:
        image_id: Identifier stored on the record and on every annotation.
        categories: Category id of each object.
        areas: Area of each object.
        bboxes: Bounding box of each object; each one is copied into a list.

    Returns:
        A dict with ``"image_id"`` and ``"annotations"``. The three sequences
        are zipped together, so entries beyond the shortest one are ignored.
    """
    return {
        "image_id": image_id,
        "annotations": [
            {
                "image_id": image_id,
                "category_id": category_id,
                "iscrowd": 0,
                "area": area,
                "bbox": list(box),
            }
            for category_id, area, box in zip(categories, areas, bboxes)
        ],
    }


# Training pipeline: random geometric / colour augmentations. Boxes are given
# in COCO format with their labels in the "category" field; they are clipped
# and filtered with min_area=25.
train_bbox_params = albumentations.BboxParams(
    format="coco", label_fields=["category"], clip=True, min_area=25
)
train_augment_and_transform = albumentations.Compose(
    [
        albumentations.Perspective(p=0.1),
        albumentations.HorizontalFlip(p=0.5),
        albumentations.RandomBrightnessContrast(p=0.5),
        albumentations.HueSaturationValue(p=0.1),
    ],
    bbox_params=train_bbox_params,
)


# Validation pipeline: no augmentation (NoOp); boxes are still clipped.
validation_bbox_params = albumentations.BboxParams(
    format="coco", label_fields=["category"], clip=True
)
validation_transform = albumentations.Compose(
    [albumentations.NoOp()],
    bbox_params=validation_bbox_params,
)


def augment_and_transform_batch(
    examples, transform, image_processor, return_pixel_mask=False
):
    """Apply ``transform`` to a batch and encode it with ``image_processor``.

    Args:
        examples: Batch dict with ``"image_id"``, ``"image"`` and
            ``"objects"`` columns. Each ``"objects"`` entry provides
            ``"bbox"`` and ``"category"`` lists.
        transform: Callable invoked as
            ``transform(image=..., bboxes=..., category=...)`` that returns a
            dict with ``"image"``, ``"bboxes"`` and ``"category"``.
        image_processor: Callable invoked with ``images=``, ``annotations=``
            and ``return_tensors="pt"``.
        return_pixel_mask: When False (default), ``"pixel_mask"`` is removed
            from the result.

    Returns:
        The output of ``image_processor`` for the transformed batch.
    """
    images = []
    annotations = []
    for image_id, image, objects in zip(
        examples["image_id"], examples["image"], examples["objects"]
    ):
        image = np.array(image.convert("RGB"))

        output = transform(
            image=image, bboxes=objects["bbox"], category=objects["category"]
        )
        images.append(output["image"])

        # The transform may drop or resize boxes, so the original
        # objects["area"] values no longer line up with output["bboxes"].
        # Recompute each area from the transformed (x, y, w, h) box instead.
        areas = [box[2] * box[3] for box in output["bboxes"]]

        formatted_annotations = format_image_annotations_as_coco(
            image_id, output["category"], areas, output["bboxes"]
        )
        annotations.append(formatted_annotations)

    result = image_processor(
        images=images, annotations=annotations, return_tensors="pt"
    )

    if not return_pixel_mask:
        result.pop("pixel_mask", None)

    return result


# Batch transforms for Dataset.with_transform: the same batch function with
# the shared image processor, bound to the training or validation pipeline.
train_transform_batch = partial(
    augment_and_transform_batch,
    image_processor=image_processor,
    transform=train_augment_and_transform,
)


validation_transform_batch = partial(
    augment_and_transform_batch,
    image_processor=image_processor,
    transform=validation_transform,
)

## 前処理実行

### tips
trainer.evaluate や trainer.train 中の評価では，batch 内のデータ数が 1 つしかないと予測結果の返り値の形式が変わり，カスタム評価関数がうまく動かない場合があるので注意。

→ 今回は，テストデータ数を TrainingArguments の引数 per_device_eval_batch_size のデフォルト値 8 で割った余りが 1 にならないように，dataset 準備時に test_size=100（100 ÷ 8 の余りは 4）として対策している。

In [None]:
# # Sampling: uncomment to keep only a random 10% subset of the dataset
# dataset = dataset.shuffle(seed=42).select(range(int(0.1 * len(dataset))))

In [None]:
# Hold out 100 examples as the test split (fixed seed), then attach the
# matching batch transform to each split.
dataset = dataset.train_test_split(test_size=100, seed=0)
for split_name, batch_transform in (
    ("train", train_transform_batch),
    ("test", validation_transform_batch),
):
    dataset[split_name] = dataset[split_name].with_transform(batch_transform)

In [None]:
# Notebook cell output: show the resulting train/test splits.
dataset

## ファインチューニング準備

### Preparing function to compute mAP

https://huggingface.co/docs/transformers/main/en/tasks/object_detection

In [None]:
def convert_bbox_yolo_to_pascal(boxes, image_size):
    """Convert center-format boxes to corner format scaled by the image size.

    Args:
        boxes: Boxes accepted by ``center_to_corners_format``; presumably
            normalized (center_x, center_y, width, height) - confirm against
            the caller.
        image_size: ``(height, width)`` of the image.

    Returns:
        The corner-format boxes multiplied by ``[width, height, width, height]``.
    """
    corners = center_to_corners_format(boxes)
    img_height, img_width = image_size
    scale = torch.tensor([[img_width, img_height, img_width, img_height]])
    return corners * scale


@dataclass
class ModelOutput:
    """Minimal container for the two prediction tensors used in evaluation.

    ``compute_metrics`` wraps each prediction batch in this class before
    handing it to ``image_processor.post_process_object_detection``.
    """

    # Class scores of the batch (presumably per query - confirm).
    logits: torch.Tensor
    # Predicted boxes of the batch.
    pred_boxes: torch.Tensor


@torch.no_grad()
def compute_metrics(evaluation_results, image_processor, threshold=0.0, id2label=None):
    """Compute mean-average-precision metrics for one evaluation run.

    Args:
        evaluation_results: Object with ``predictions`` and ``label_ids``,
            each iterated batch by batch. Every target holds ``"boxes"``,
            ``"class_labels"`` and ``"orig_size"``. Elements 1 and 2 of each
            prediction batch are used as logits and predicted boxes.
        image_processor: Provides ``post_process_object_detection``.
        threshold: Score threshold forwarded to the post-processing step.
        id2label: Optional class-id -> name map for the per-class metric
            keys; the raw class id is used when it is None.

    Returns:
        Dict mapping metric name to a float rounded to 4 decimals, including
        ``map_<class>`` and ``mar_100_<class>`` per class.
    """
    predictions = evaluation_results.predictions
    targets = evaluation_results.label_ids

    # Targets: remember the original image sizes per batch and convert every
    # ground-truth box to absolute corner coordinates.
    image_sizes = []
    post_processed_targets = []
    for target_batch in targets:
        image_sizes.append(
            torch.tensor(np.array([target["orig_size"] for target in target_batch]))
        )
        post_processed_targets.extend(
            {
                "boxes": convert_bbox_yolo_to_pascal(
                    torch.tensor(target["boxes"]), target["orig_size"]
                ),
                "labels": torch.tensor(target["class_labels"]),
            }
            for target in target_batch
        )

    # Predictions: let the image processor rescale each batch to the sizes
    # collected above.
    post_processed_predictions = []
    for prediction_batch, target_sizes in zip(predictions, image_sizes):
        model_output = ModelOutput(
            logits=torch.tensor(prediction_batch[1]),
            pred_boxes=torch.tensor(prediction_batch[2]),
        )
        post_processed_predictions.extend(
            image_processor.post_process_object_detection(
                model_output, threshold=threshold, target_sizes=target_sizes
            )
        )

    metric = MeanAveragePrecision(box_format="xyxy", class_metrics=True)
    metric.update(post_processed_predictions, post_processed_targets)
    metrics = metric.compute()

    # Replace the per-class vectors with one scalar entry per class.
    per_class = zip(
        metrics.pop("classes"),
        metrics.pop("map_per_class"),
        metrics.pop("mar_100_per_class"),
    )
    for class_id, class_map, class_mar in per_class:
        class_key = class_id.item()
        class_name = class_key if id2label is None else id2label[class_key]
        metrics[f"map_{class_name}"] = class_map
        metrics[f"mar_100_{class_name}"] = class_mar

    return {name: round(value.item(), 4) for name, value in metrics.items()}


# Metric callback for the Trainer: compute_metrics with the processor, the
# label names and a 0.0 score threshold bound in.
eval_compute_metrics_fn = partial(
    compute_metrics,
    threshold=0.0,
    id2label=id2label,
    image_processor=image_processor,
)

In [None]:
def collate_fn(batch):
    """Collate dataset items into one model input batch.

    Args:
        batch: List of items, each with ``"pixel_values"`` and ``"labels"``.

    Returns:
        Dict with the padded ``"pixel_values"`` and ``"pixel_mask"`` returned
        by ``image_processor.pad`` and the list of per-item ``"labels"``.
    """
    padded = image_processor.pad(
        [sample["pixel_values"] for sample in batch], return_tensors="pt"
    )
    return {
        "pixel_values": padded["pixel_values"],
        "pixel_mask": padded["pixel_mask"],
        "labels": [sample["labels"] for sample in batch],
    }

In [None]:
training_args = TrainingArguments(
    # Run setup
    output_dir=f"{output_dir}/log",
    num_train_epochs=num_train_epochs,
    per_device_train_batch_size=per_device_train_batch_size,
    fp16=True,
    # Optimisation
    learning_rate=5e-5,
    weight_decay=1e-4,
    lr_scheduler_type="cosine",
    max_grad_norm=0.01,
    # Evaluate and checkpoint once per epoch; reload the checkpoint with the
    # highest eval_map when training ends.
    eval_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=2,
    load_best_model_at_end=True,
    metric_for_best_model="eval_map",
    greater_is_better=True,
    # Keep all dataset columns for the on-the-fly transform, and keep the
    # evaluation outputs as separate batches (compute_metrics iterates them
    # batch by batch).
    remove_unused_columns=False,
    eval_do_concat_batches=False,
    # Logging
    report_to="wandb",
    logging_steps=1,
)

## ファインチューニング実行

In [None]:
# Assemble the Trainer from the pieces prepared above and start fine-tuning.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
    data_collator=collate_fn,
    compute_metrics=eval_compute_metrics_fn,
    tokenizer=image_processor,
)

trainer.train()

## モデルの保存

In [None]:
# Save the model and its image processor side by side so the directory can
# be passed to from_pretrained as a checkpoint later.
model_dir = f"{output_dir}/model"
for component in (model, image_processor):
    component.save_pretrained(model_dir)

## W&B Table 作成

In [None]:
image_dir = "BCCD/JPEGImages/"
test_batch_size = 8

# Split the test set into full batches plus one smaller trailing batch when
# the row count is not a multiple of test_batch_size.
num_test_rows = dataset["test"].num_rows
batch_sizes = [test_batch_size] * int(num_test_rows / test_batch_size)
left_batch_size = np.mod(num_test_rows, test_batch_size)
if left_batch_size != 0:
    batch_sizes.append(left_batch_size)

# File names of the source images; matched to image ids when building the table.
image_filenames = os.listdir(image_dir)

In [None]:
# create table
table = wandb.Table(columns=["image_id", "ground truth", "prediction"])

# Outline / text colour per class id; boxes with any other id are not drawn.
label_colors = {0: "red", 1: "blue", 2: "green"}


def _draw_labeled_boxes(image, labeled_boxes):
    """Draw ``(label_id, (x1, y1, x2, y2))`` pairs onto ``image`` in place."""
    draw = ImageDraw.Draw(image)
    for label_id, corners in labeled_boxes:
        color = label_colors.get(label_id)
        if color is None:
            continue
        draw.rectangle(corners, outline=color, width=1)
        draw.text(corners[:2], id2label[label_id], fill=color)


# Everything below is loop-invariant, so it is built once instead of once per
# batch.
# Zero-padded ids of the test images in test-set order. Iterating the split
# runs the validation transform on every example.
test_image_ids = [
    str(x["labels"]["image_id"].item()).zfill(5) for x in dataset["test"]
]
# Image file path keyed by the id embedded in the file name ("..._<id>.<ext>").
image_path_by_id = {
    name.split("_")[-1].split(".")[0]: image_dir + name for name in image_filenames
}
# Untransformed dataset for the ground-truth boxes, indexed by image id.
dataset_org = load_from_disk(dataset_path)
row_by_image_id = {
    image_id: row for row, image_id in enumerate(dataset_org["image_id"])
}

# Inference only: use eval mode and run on whatever device the model is on.
model.eval()
device = model.device

start = 0
for batch_size in batch_sizes:
    batch_ids = test_image_ids[start : start + batch_size]
    start += batch_size

    # Keep only ids that have both an image file and a ground-truth record.
    # Every list below is built in this same order, so the three table
    # columns of a row always belong to the same image.
    image_ids = [
        image_id
        for image_id in batch_ids
        if image_id in image_path_by_id and int(image_id) in row_by_image_id
    ]
    if not image_ids:
        continue

    # prediction
    images = []
    for image_id in image_ids:
        # convert() loads the pixel data, so the file can be closed right away.
        with Image.open(image_path_by_id[image_id]) as image_file:
            images.append(image_file.convert("RGB"))
    # PIL reports (width, height); post-processing expects (height, width).
    target_sizes = torch.tensor([[image.size[1], image.size[0]] for image in images])

    with torch.no_grad():
        inputs = image_processor(images=images, return_tensors="pt")
        outputs = model(**inputs.to(device))
        results = image_processor.post_process_object_detection(
            outputs, threshold=0.5, target_sizes=target_sizes
        )

    for image, result in zip(images, results):
        _draw_labeled_boxes(
            image,
            (
                (label.item(), tuple(round(value, 2) for value in box.tolist()))
                for label, box in zip(result["labels"], result["boxes"])
            ),
        )

    # ground truth
    images_org = []
    for image_id in image_ids:
        record = dataset_org[row_by_image_id[int(image_id)]]
        image_org = record["image"].copy()
        objects_org = record["objects"]
        # Ground-truth boxes are (x, y, w, h); convert to corner coordinates.
        _draw_labeled_boxes(
            image_org,
            (
                (label, (x, y, x + w, y + h))
                for label, (x, y, w, h) in zip(
                    objects_org["category"], objects_org["bbox"]
                )
            ),
        )
        images_org.append(image_org)

    # add datas
    for image_id, groundtruth, prediction in zip(image_ids, images_org, images):
        table.add_data(image_id, wandb.Image(groundtruth), wandb.Image(prediction))

# log table
wandb.log({"Predictions": table})

In [None]:
# Finish the W&B run that was started with wandb.init().
wandb.finish()