In [None]:
import os
import torch
import torchvision
import pytorch_lightning as pl
import json
from torch.utils.data import DataLoader
from torchvision.models.detection import retinanet_resnet50_fpn
from torchvision.transforms import functional as F
from PIL import Image

In [None]:
import wandb
import random

# Start a new wandb run to track this script.
wandb.init(
    # wandb project under which this run is logged
    project="SEO_project_01",

    # Disabled: hyperparameters and run metadata to attach to the run
    # config={
    # "learning_rate": 0.02,
    # "epochs": 10,
    # }
)

# Disabled: simulated training loop that logs fake metrics
# epochs = 10
# offset = random.random() / 5
# for epoch in range(2, epochs):
#     acc = 1 - 2 ** -epoch - random.random() / epoch - offset
#     loss = 2 ** -epoch + random.random() / epoch + offset

#     # log metrics to wandb
#     wandb.log({"acc": acc, "loss": loss})

# Finish the wandb run (needed in notebooks).
# NOTE(review): the run is closed here, in the same cell that opened it and
# before the training cell further down executes; the only wandb.log call in
# this cell is commented out, so as written the run records no metrics.
# TODO confirm whether training is meant to be tracked by this run.
wandb.finish()

In [None]:
class COCODataset(torch.utils.data.Dataset):
    """Object-detection dataset backed by a COCO-format annotation file.

    Each item is a pair ``(image, target)`` where ``target`` is a dict with

    * ``'boxes'``  -- float32 tensor of shape ``(N, 4)`` in ``(x1, y1, x2, y2)``
      order, converted from the COCO ``(x, y, w, h)`` layout;
    * ``'labels'`` -- int64 tensor of shape ``(N,)`` holding each annotation's
      ``category_id``.

    ``N`` is the number of annotations of the image and may be zero, in which
    case ``'boxes'`` still has shape ``(0, 4)``.

    Args:
        image_dir: Directory onto which each image entry's ``file_name`` is joined.
        annotation_file: Path to the COCO-format JSON file.
        transforms: Optional callable applied to the loaded image only; the
            boxes are returned untouched.
    """

    def __init__(self, image_dir, annotation_file, transforms=None):
        self.image_dir = image_dir
        self.transforms = transforms
        # JSON is UTF-8 by specification; do not depend on the platform default.
        with open(annotation_file, encoding="utf-8") as f:
            self.coco_data = json.load(f)

        self.images = self.coco_data['images']
        # A file without an 'annotations' key is treated as having none
        # instead of failing with KeyError.
        self.annotations = self.coco_data.get('annotations', [])

        # Mapping from image ID to the list of its annotations.
        self.img_id_to_annotations = {}
        for ann in self.annotations:
            self.img_id_to_annotations.setdefault(ann['image_id'], []).append(ann)

    def __len__(self):
        """Return the number of image entries in the annotation file."""
        return len(self.images)

    def _load_image(self, path):
        """Read the image stored at ``path`` and return it as an RGB PIL image."""
        # The context manager releases the file handle; convert() hands back an
        # independent, fully loaded image, so it remains usable afterwards.
        with Image.open(path) as img:
            return img.convert("RGB")

    def __getitem__(self, idx):
        """Return ``(image, target)`` for the image entry at position ``idx``."""
        img_info = self.images[idx]
        image = self._load_image(os.path.join(self.image_dir, img_info['file_name']))

        # Images without any annotation simply get an empty list.
        annotations = self.img_id_to_annotations.get(img_info['id'], [])

        boxes = []
        labels = []
        for ann in annotations:
            # Convert COCO bbox (x, y, w, h) to corner format (x1, y1, x2, y2).
            x1, y1, w, h = ann['bbox']
            boxes.append([x1, y1, x1 + w, y1 + h])
            labels.append(ann['category_id'])

        target = {
            # reshape keeps the (N, 4) layout even when N == 0; without it an
            # empty list becomes a 1-D tensor of shape (0,).
            'boxes': torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4),
            'labels': torch.as_tensor(labels, dtype=torch.int64),
        }

        if self.transforms:
            image = self.transforms(image)

        return image, target

# Image transforms (hook for optional data augmentation)
def get_transform(train):
    """Build the image preprocessing pipeline.

    ``train`` is accepted but currently unused: the same single-step pipeline
    (``ToTensor``) is returned regardless of its value.
    """
    to_tensor = torchvision.transforms.ToTensor()
    return torchvision.transforms.Compose([to_tensor])

In [None]:
class RetinaNetModel(pl.LightningModule):
    """LightningModule wrapping torchvision's RetinaNet with a ResNet-50 FPN backbone.

    Only the backbone is initialised from pretrained weights
    (``weights_backbone='DEFAULT'``); no detection weights are loaded.

    Args:
        num_classes: Number of classes passed to ``retinanet_resnet50_fpn``.
        lr: Learning rate of the Adam optimizer.
    """

    def __init__(self, num_classes=11, lr=1e-4):
        super().__init__()
        self.lr = lr
        self.model = retinanet_resnet50_fpn(weights_backbone='DEFAULT', num_classes=num_classes)

    def forward(self, images, targets=None):
        """Delegate to the wrapped detector.

        torchvision detection models return a dict of losses in training mode
        (``targets`` required) and a list of per-image detections in eval mode.
        """
        return self.model(images, targets)

    def training_step(self, batch, batch_idx):
        """Compute and log the summed detection loss for one training batch."""
        images, targets = batch
        loss_dict = self.model(images, targets)
        loss = sum(loss_dict.values())
        # The batch is a tuple of per-image items, so state its size explicitly
        # rather than leaving Lightning to infer it for epoch-level averaging.
        self.log('train_loss', loss, prog_bar=True, on_step=True, on_epoch=True,
                 batch_size=len(images))
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log the summed detection loss for one validation batch.

        Lightning runs validation with the module in eval mode, in which the
        detector returns detections instead of losses. The detector is
        therefore switched to training mode for the forward pass only, without
        gradients, and its previous mode is restored afterwards.

        No accuracy is logged: the previous implementation called an undefined
        ``self.val_acc``, and comparing predicted label lists with target label
        lists element-wise is not a meaningful detection metric.
        NOTE(review): a detection metric such as mAP would need an extra
        dependency (e.g. torchmetrics) that this file does not import.
        """
        images, targets = batch

        was_training = self.model.training
        # NOTE(review): assumes train mode has no stateful side effects here
        # (torchvision is expected to build the pretrained backbone with frozen
        # batch norm) -- confirm for the installed torchvision version.
        self.model.train()
        try:
            with torch.no_grad():
                loss_dict = self.model(images, targets)
        finally:
            self.model.train(was_training)

        val_loss = sum(loss_dict.values())

        # on_epoch=True lets Lightning aggregate over the epoch, so no
        # epoch-end hook is needed (``validation_epoch_end`` no longer exists
        # in Lightning 2.x).
        self.log('val_loss', val_loss, prog_bar=True, on_step=False, on_epoch=True,
                 batch_size=len(images))

        return {'val_loss': val_loss}

    def extract_preds(self, images):
        """Return the predicted label tensor of every image in ``images``.

        Inference is run in eval mode without gradients; the detector's
        previous train/eval mode is restored before returning.
        """
        was_training = self.model.training
        self.model.eval()
        try:
            with torch.no_grad():
                preds = self.model(images)
        finally:
            self.model.train(was_training)
        return [pred['labels'] for pred in preds]

    def configure_optimizers(self):
        """Create the Adam optimizer over all module parameters."""
        return torch.optim.Adam(self.parameters(), lr=self.lr)


In [None]:
class COCODataModule(pl.LightningDataModule):
    """LightningDataModule serving a training and a validation ``COCODataset``.

    The dataset built from ``test_dir`` / ``test_ann`` is the one returned by
    ``val_dataloader``.

    Args:
        train_dir: Image directory of the training split.
        train_ann: Annotation file of the training split.
        test_dir: Image directory of the split used for validation.
        test_ann: Annotation file of the split used for validation.
        batch_size: Number of samples per batch for both loaders.
    """

    def __init__(self, train_dir, train_ann, test_dir, test_ann, batch_size=4):
        super().__init__()
        self.train_dir, self.train_ann = train_dir, train_ann
        self.test_dir, self.test_ann = test_dir, test_ann
        self.batch_size = batch_size

    def setup(self, stage=None):
        """Instantiate both datasets; ``stage`` is ignored."""
        self.train_dataset = COCODataset(
            self.train_dir,
            self.train_ann,
            transforms=get_transform(train=True),
        )
        self.test_dataset = COCODataset(
            self.test_dir,
            self.test_ann,
            transforms=get_transform(train=False),
        )

    def _make_loader(self, dataset, shuffle):
        """Wrap ``dataset`` in a DataLoader using the shared batch size and collate function."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            collate_fn=self.collate_fn,
        )

    def train_dataloader(self):
        """Shuffled loader over the training dataset."""
        return self._make_loader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        """Unshuffled loader over the test dataset, used for validation."""
        return self._make_loader(self.test_dataset, shuffle=False)

    @staticmethod
    def collate_fn(batch):
        """Transpose a list of samples into one tuple per sample field.

        For ``(image, target)`` samples this gives ``(images, targets)``, so
        images are grouped in a tuple rather than stacked into one tensor.
        """
        return tuple(zip(*batch))


In [None]:
if __name__ == "__main__":
    # Both splits share one dataset root: images are resolved relative to it
    # and the annotation files sit directly inside it.
    dataset_root = '/data/ephemeral/level2-objectdetection-cv-15/dataset'
    train_dir = test_dir = dataset_root
    train_ann = f'{dataset_root}/train.json'
    test_ann = f'{dataset_root}/test.json'

    # Data module and model (the model uses its built-in class count).
    data_module = COCODataModule(
        train_dir,
        train_ann,
        test_dir,
        test_ann,
        batch_size=4,
    )
    model = RetinaNetModel()

    # Train for 10 epochs on a single GPU.
    trainer = pl.Trainer(accelerator='gpu', devices=1, max_epochs=10)
    trainer.fit(model, datamodule=data_module)
