In [43]:
import os
import warnings

import pandas as pd
import pydicom
import torch
from PIL import Image
from torch.utils.data import Dataset
from transformers import (
    EarlyStoppingCallback,
    RTDetrForObjectDetection,
    RTDetrImageProcessor,
    Trainer,
    TrainingArguments,
)

from pipelines.detr.config import *
from global_config import *

In [49]:
warnings.filterwarnings("ignore")
torch.backends.cudnn.benchmark = True

In [50]:
class RSNADataset(Dataset):
    """Object-detection dataset built from the RSNA pneumonia labels CSV.

    There is one item per unique ``patientId``. Each item is the DICOM image
    ``<images_dir>/<patientId>.dcm`` plus every CSV row for that patient with
    ``Target == 1``, converted to a COCO-style annotation
    (``bbox`` = ``[x, y, width, height]``) with the ``"Lung Opacity"`` category.
    Image and annotations are passed through ``image_processor``.

    Args:
        csv_file: Path to a CSV with at least ``patientId``, ``Target``, ``x``, ``y``,
            ``width`` and ``height`` columns.
        images_dir: Directory containing the ``<patientId>.dcm`` files.
        image_processor: Callable invoked as
            ``image_processor(images=..., annotations=..., return_tensors="pt")``,
            e.g. an ``RTDetrImageProcessor``.
        class_names: Ordered class names. Must contain ``"Lung Opacity"`` whenever a
            requested patient has a positive row (otherwise ``KeyError``).
    """

    def __init__(
            self,
            csv_file: str,
            images_dir: str,
            image_processor,
            class_names,
    ) -> None:
        self.df = pd.read_csv(csv_file)
        self.images_dir = images_dir
        self.image_processor = image_processor
        self.class_names = class_names
        self.label2id = {name: i for i, name in enumerate(class_names)}
        # unique() keeps first-appearance order; this order defines the item index.
        self.patient_ids = self.df["patientId"].unique().tolist()
        # Row positions per patient, computed once. Filtering the whole frame with a
        # boolean mask in every __getitem__ made each epoch quadratic in the CSV size.
        self._rows_by_patient = self.df.groupby("patientId", sort=False).indices

    def __len__(self):
        return len(self.patient_ids)

    def _load_image(self, patient_id):
        """Read ``<images_dir>/<patient_id>.dcm`` and return it as an RGB PIL image."""
        dicom_path = os.path.join(self.images_dir, f"{patient_id}.dcm")
        dcm = pydicom.dcmread(dicom_path)
        return Image.fromarray(dcm.pixel_array).convert("RGB")

    def __getitem__(self, idx):
        """Return ``{"pixel_values": ..., "labels": ...}`` for the ``idx``-th patient.

        ``pixel_values`` is the first (only) entry of the processor's batched output.
        ``labels`` is the processor's target for this image, or the raw COCO-style
        target if the processor returned no ``"labels"``.
        """
        patient_id = self.patient_ids[idx]
        image = self._load_image(patient_id)

        records = self.df.iloc[self._rows_by_patient[patient_id]]
        annotations = []
        for row_id, row in records.iterrows():
            if row["Target"] != 1:
                continue  # negative rows carry no box
            width, height = float(row["width"]), float(row["height"])
            annotations.append({
                "image_id": idx,
                "bbox": [float(row["x"]), float(row["y"]), width, height],
                "category_id": self.label2id["Lung Opacity"],
                "area": width * height,
                "iscrowd": 0,
                # read_csv's default RangeIndex is unique across the file,
                # so the row label doubles as the annotation id.
                "id": int(row_id),
            })
        target = {"image_id": idx, "annotations": annotations}
        encoding = self.image_processor(images=image, annotations=target, return_tensors="pt")
        # Drop only the batch dimension; .squeeze() would also drop any size-1 dim.
        pixel_values = encoding["pixel_values"][0]
        labels = encoding["labels"][0] if "labels" in encoding else target
        return {"pixel_values": pixel_values, "labels": labels}

In [51]:
def collate_fn(batch):
    """Merge dataset items into one batch for the Trainer.

    Images are stacked into a single tensor (so they must share a shape). Targets
    stay a plain list, one entry per image, because each image can have a different
    number of boxes.
    """
    images, targets = [], []
    for sample in batch:
        images.append(sample["pixel_values"])
        targets.append(sample["labels"])
    return {"pixel_values": torch.stack(images), "labels": targets}

In [54]:
def setup_fine_tuning(
        csv_file,
        images_dir,
        output_dir,
        num_epochs,
        batch_size,
        learning_rate,
        seed=42,
):
    """Build an RT-DETR model, an 80/20 train/validation split and a configured Trainer.

    Uses names supplied by the notebook's star imports: ``MODEL_NAME``,
    ``class_names``, ``LOG_STEP`` and ``PATIENCE``.

    Args:
        csv_file: Path to the RSNA labels CSV (see ``RSNADataset``).
        images_dir: Directory holding the ``<patientId>.dcm`` files.
        output_dir: Where checkpoints are written; TensorBoard logs go to
            ``<output_dir>/logs``.
        num_epochs: Number of training epochs.
        batch_size: Per-device batch size for both training and evaluation.
        learning_rate: Optimizer learning rate.
        seed: Seed for the train/validation split. It keeps the split identical
            across runs and checkpoint resumes, so validation patients never leak
            into training.

    Returns:
        ``(trainer, model)``: the configured ``Trainer`` and the model it trains.
    """
    image_processor = RTDetrImageProcessor.from_pretrained(MODEL_NAME)
    model = RTDetrForObjectDetection.from_pretrained(
        MODEL_NAME,
        num_labels=len(class_names),
        # The pretrained head has a different class count; let it be re-initialised.
        ignore_mismatched_sizes=True,
        id2label={i: name for i, name in enumerate(class_names)},
        label2id={name: i for i, name in enumerate(class_names)},
    )
    # One dataset item is one patient, so the split is patient-level.
    dataset = RSNADataset(csv_file, images_dir, image_processor, class_names)
    train_size = int(0.8 * len(dataset))
    val_size = len(dataset) - train_size
    train_dataset, val_dataset = torch.utils.data.random_split(
        dataset,
        [train_size, val_size],
        # TrainingArguments' seed only takes effect once the Trainer is built (below),
        # so the split needs its own generator to be reproducible.
        generator=torch.Generator().manual_seed(seed),
    )
    training_args = TrainingArguments(
        output_dir=output_dir,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        num_train_epochs=num_epochs,
        learning_rate=learning_rate,
        eval_strategy="epoch",
        save_strategy="epoch",
        logging_dir=os.path.join(output_dir, "logs"),
        logging_steps=LOG_STEP,
        load_best_model_at_end=True,
        # Explicit criterion for both "best model" and early stopping.
        metric_for_best_model="eval_loss",
        greater_is_better=False,
        remove_unused_columns=False,
        push_to_hub=False,
        disable_tqdm=False,
        report_to="tensorboard",
    )
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        # NOTE(review): newer transformers releases rename this to `processing_class`.
        tokenizer=image_processor,
        data_collator=collate_fn,
        # Trainer has no `patience` argument (passing one raises TypeError).
        # Early stopping is a callback; it needs the per-epoch eval/save and
        # load_best_model_at_end configured above.
        callbacks=[EarlyStoppingCallback(early_stopping_patience=PATIENCE)],
    )
    return trainer, model

In [None]:
!tensorboard --logdir /kaggle/working/rtdetr_rsna_output/logs --port 6006

In [None]:
# NOTE(review): `OUTOUT_DIR` is spelled as it is defined in the imported config.
trainer, model = setup_fine_tuning(
    csv_file=CSV_FILE,
    images_dir=IMAGES_DIR,
    output_dir=OUTOUT_DIR,
    num_epochs=NUM_EPOCHS,
    batch_size=BATCH_SIZE,
    learning_rate=LR,
)

trainer.train()
# save_model() writes the model weights/config *and* the image processor given to the
# Trainer. OUTOUT_DIR can then be reloaded for inference with from_pretrained() alone.
# model.save_pretrained() would save the weights only.
trainer.save_model(OUTOUT_DIR)