In [None]:
import json
import os

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as T
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision.models.detection import fasterrcnn_resnet50_fpn
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from tqdm import tqdm

In [None]:
from google.colab import drive

# Mount Google Drive at /content/drive so files stored there are reachable by path.
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
class CustomDataset(Dataset):
    """Detection dataset backed by a JSON index file.

    The JSON file holds a mapping whose values are records with three keys:
    ``"image"`` (path to the image file), ``"bbox"`` (list of boxes, treated
    as ``(x1, y1, x2, y2)``) and ``"label"`` (one integer class id per box).

    Args:
        json_file: Path to the JSON index.
        transforms: Optional callable applied to the PIL image only. The
            target boxes are NOT passed through it, so it must not change
            the image geometry.
    """

    def __init__(self, json_file, transforms=None):
        with open(json_file) as f:
            self.data = json.load(f)
        # Snapshot the key order once; rebuilding the list on every
        # __getitem__ call costs O(len(dataset)) per sample.
        self._keys = list(self.data.keys())
        self.transforms = transforms

    def __len__(self):
        return len(self._keys)

    def __getitem__(self, idx):
        """Return ``(image, target)`` for sample ``idx``.

        ``target`` is a dict with ``"boxes"`` (float32, shape ``(N, 4)``) and
        ``"labels"`` (int64, shape ``(N,)``); boxes without strictly positive
        width and height are dropped together with their labels.

        When the image file does not exist the miss is reported on stdout and
        ``(None, None)`` is returned so the collate function can discard it.
        """
        record = self.data[self._keys[idx]]
        img_path = record["image"]

        try:
            # The context manager releases the file handle as soon as the
            # pixel data has been decoded into the converted copy.
            with Image.open(img_path) as raw_img:
                img = raw_img.convert("RGB")
        except FileNotFoundError:
            print(f"File not found: {img_path}")
            return None, None

        # reshape() guarantees (N, 4) / (N,) even for an empty annotation list.
        boxes = torch.as_tensor(record["bbox"], dtype=torch.float32).reshape(-1, 4)
        labels = torch.as_tensor(record["label"], dtype=torch.int64).reshape(-1)

        # Pair boxes with labels the way zip() would: extras on either side
        # are ignored.
        count = min(len(boxes), len(labels))
        boxes, labels = boxes[:count], labels[:count]

        # Keep only boxes with x2 > x1 and y2 > y1.
        keep = (boxes[:, 2] > boxes[:, 0]) & (boxes[:, 3] > boxes[:, 1])
        target = {"boxes": boxes[keep], "labels": labels[keep]}

        if self.transforms:
            img = self.transforms(img)

        return img, target

In [None]:
def get_transform(train):
    """Build the image preprocessing pipeline.

    Args:
        train: Kept for call-site compatibility. No training-only
            augmentation is applied (see the note below).

    Returns:
        A ``T.Compose`` that converts a PIL image to a tensor.
    """
    # NOTE: T.RandomHorizontalFlip was removed on purpose. It mirrors only the
    # image, while CustomDataset hands just the image to this pipeline and
    # returns the boxes untouched, so flipped samples ended up with boxes that
    # no longer covered their objects. A flip may only be reintroduced through
    # a transform that updates image and boxes together.
    return T.Compose([T.ToTensor()])

In [None]:
def collate_fn(batch):
    """Group samples into per-field tuples, discarding unusable samples.

    A sample is unusable when its image or its target is ``None``.
    Returns ``(images, targets)`` as two tuples, or ``([], [])`` when no
    usable sample is left.
    """
    usable = []
    for sample in batch:
        if sample[0] is None or sample[1] is None:
            continue
        usable.append(sample)

    if not usable:
        return ([], [])
    return tuple(zip(*usable))

In [None]:
# Build one dataset per split; only the training split asks for train-time transforms.
train_dataset = CustomDataset(
    json_file='/content/drive/MyDrive/preprocessed_data/dataset_train.json',
    transforms=get_transform(train=True),
)
valid_dataset = CustomDataset(
    json_file='/content/drive/MyDrive/preprocessed_data/dataset_valid.json',
    transforms=get_transform(train=False),
)
test_dataset = CustomDataset(
    json_file='/content/drive/MyDrive/preprocessed_data/dataset_test.json',
    transforms=get_transform(train=False),
)

# Wrap each split in a loader; only the training loader shuffles.
train_loader = DataLoader(dataset=train_dataset, batch_size=16, shuffle=True, collate_fn=collate_fn)
valid_loader = DataLoader(dataset=valid_dataset, batch_size=16, shuffle=False, collate_fn=collate_fn)
test_loader = DataLoader(dataset=test_dataset, batch_size=16, shuffle=False, collate_fn=collate_fn)

In [None]:
# Model
# Start from the pretrained detector and replace only its box-prediction head.
# NOTE(review): `pretrained=True` is the legacy spelling; newer torchvision
# releases prefer `weights=...` — confirm against the installed version.
model = fasterrcnn_resnet50_fpn(pretrained=True)
num_classes = 16  # Assuming 15 classes + background
in_features = model.roi_heads.box_predictor.cls_score.in_features
# Build just the new head, sized from the existing one, instead of
# constructing a second complete Faster R-CNN only to borrow its predictor.
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

Downloading: "https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth" to /root/.cache/torch/hub/checkpoints/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth
100%|██████████| 160M/160M [00:02<00:00, 79.9MB/s]
Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth
100%|██████████| 97.8M/97.8M [00:01<00:00, 82.7MB/s]


In [None]:
# Training function
def train_one_epoch(model, data_loader, optimizer, device, epoch):
    """Run one optimisation pass over ``data_loader``.

    Args:
        model: Detection model that, in train mode, returns a dict of loss
            tensors when called as ``model(images, targets)``.
        data_loader: Iterable yielding ``(images, targets)`` batches.
        optimizer: Optimizer stepping the model parameters.
        device: Device the batch tensors are moved to.
        epoch: Epoch number, used only for the progress-bar label.

    Returns:
        Mean summed loss over the batches that were actually processed, or
        ``0.0`` when no batch was processed.
    """
    model.train()
    running_loss = 0.0
    processed_batches = 0
    for images, targets in tqdm(data_loader, desc=f"Epoch {epoch}"):
        # collate_fn yields ([], []) when every sample of a batch was
        # dropped; the model cannot run on an empty image list.
        if len(images) == 0:
            continue
        images = [image.to(device) for image in images]
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
        loss_dict = model(images, targets)
        losses = sum(loss_dict.values())
        optimizer.zero_grad()
        losses.backward()
        optimizer.step()
        running_loss += losses.item()
        processed_batches += 1
    # Average over processed batches so skipped ones do not dilute the mean,
    # and an empty loader does not divide by zero.
    return running_loss / processed_batches if processed_batches else 0.0

In [None]:
# Evaluation function
def evaluate(model, data_loader, device, iou_threshold=0.5):
    """Measure the fraction of ground-truth boxes the model detects.

    Each ground-truth box can be claimed by at most one prediction, and only
    by a prediction carrying the same label whose IoU with it exceeds
    ``iou_threshold``. The result is therefore always within ``[0, 1]``.

    Args:
        model: Detection model that, in eval mode, returns one dict per image
            with ``"boxes"`` and ``"labels"`` (and optionally ``"scores"``).
        data_loader: Iterable yielding ``(images, targets)`` batches.
        device: Device the batch tensors are moved to.
        iou_threshold: Minimum IoU (exclusive) for a match.

    Returns:
        ``matched ground-truth boxes / total ground-truth boxes``, or ``0``
        when there are no ground-truth boxes.
    """
    model.eval()
    total_boxes = 0
    correct_boxes = 0
    with torch.no_grad():
        for images, targets in data_loader:
            # collate_fn yields ([], []) when every sample of a batch was
            # dropped; there is nothing to score in that case.
            if len(images) == 0:
                continue
            images = [image.to(device) for image in images]
            targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
            outputs = model(images)

            for target, output in zip(targets, outputs):
                total_boxes += len(target["boxes"])
                correct_boxes += _count_matched_boxes(target, output, iou_threshold)

    accuracy = correct_boxes / total_boxes if total_boxes > 0 else 0
    return accuracy


def _count_matched_boxes(target, output, iou_threshold):
    """Greedily match predictions to same-label ground-truth boxes, one-to-one."""
    gt_boxes, gt_labels = target["boxes"], target["labels"]
    pred_boxes, pred_labels = output["boxes"], output["labels"]
    if len(gt_boxes) == 0 or len(pred_boxes) == 0:
        return 0

    # Let the most confident predictions pick their ground-truth box first.
    scores = output.get("scores")
    if scores is not None:
        order = scores.argsort(descending=True)
        pred_boxes, pred_labels = pred_boxes[order], pred_labels[order]

    ious = box_iou(pred_boxes, gt_boxes)  # (num_pred, num_gt)
    ious = torch.nan_to_num(ious, nan=-1.0)
    # Pairs with different labels can never match.
    same_label = pred_labels[:, None] == gt_labels[None, :]
    ious = torch.where(same_label, ious, torch.full_like(ious, -1.0))

    claimed = torch.zeros(len(gt_boxes), dtype=torch.bool, device=ious.device)
    hits = 0
    for pred_ious in ious:
        # Hide ground-truth boxes that an earlier prediction already took.
        best_iou, best_gt = pred_ious.masked_fill(claimed, -1.0).max(dim=0)
        if best_iou.item() > iou_threshold:
            claimed[best_gt] = True
            hits += 1
    return hits

In [None]:
def box_iou(box1, box2):
    """Compute the pairwise Intersection over Union (IoU) of two sets of boxes.

    The box order must be (xmin, ymin, xmax, ymax).

    Args:
        box1: Tensor of shape ``(..., N, 4)``.
        box2: Tensor of shape ``(..., M, 4)`` with matching leading dims.

    Returns:
        Tensor of shape ``(..., N, M)`` where entry ``[i, j]`` is the IoU of
        ``box1[i]`` and ``box2[j]``. Pairs whose union has no area yield 0
        instead of NaN.
    """
    top_left = torch.max(box1[..., :, None, :2], box2[..., None, :, :2])
    bottom_right = torch.min(box1[..., :, None, 2:], box2[..., None, :, 2:])
    inter = (bottom_right - top_left).clamp(min=0).prod(-1)

    # Clamp side lengths so an inverted box counts as empty rather than
    # contributing a negative (or spuriously positive) area.
    area1 = (box1[..., 2:] - box1[..., :2]).clamp(min=0).prod(-1)
    area2 = (box2[..., 2:] - box2[..., :2]).clamp(min=0).prod(-1)
    union = area1[..., :, None] + area2[..., None, :] - inter

    iou = inter / union
    # 0 / 0 happens only when both boxes are degenerate; report no overlap.
    return torch.where(union > 0, iou, torch.zeros_like(iou))

In [None]:
# Prefer the GPU when one is available.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# Bind the result (Module.to returns the module itself) so the cell does not
# echo the entire network structure as its output.
model = model.to(device)

FasterRCNN(
  (transform): GeneralizedRCNNTransform(
      Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
      Resize(min_size=(800,), max_size=1333, mode='bilinear')
  )
  (backbone): BackboneWithFPN(
    (body): IntermediateLayerGetter(
      (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (bn1): FrozenBatchNorm2d(64, eps=0.0)
      (relu): ReLU(inplace=True)
      (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (layer1): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): FrozenBatchNorm2d(64, eps=0.0)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): FrozenBatchNorm2d(64, eps=0.0)
          (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): FrozenBatchNorm2d(256, eps=0.0)
          (relu): ReLU(

In [None]:
# Optimise only the parameters that are not frozen.
params = list(filter(lambda p: p.requires_grad, model.parameters()))
optimizer = optim.SGD(
    params,
    lr=0.005,
    momentum=0.9,
    weight_decay=0.0005,
)

In [None]:
num_epochs = 1
for epoch in range(num_epochs):
    # One optimisation pass over the training split.
    train_loss = train_one_epoch(
        model=model,
        data_loader=train_loader,
        optimizer=optimizer,
        device=device,
        epoch=epoch,
    )
    # Score the updated weights on the validation split.
    val_accuracy = evaluate(model=model, data_loader=valid_loader, device=device)
    print(f"Epoch {epoch}: Train Loss: {train_loss:.4f}, Validation Accuracy: {val_accuracy:.4f}")

Epoch 0: 100%|██████████| 907/907 [3:30:23<00:00, 13.92s/it]


File not found: /content/drive/MyDrive/preprocessed_data/valid/images/rotate0_mask_in_the_ocean_39_jpg.rf.c61ba35d9450c0dc0483dc3bee4bb188.jpg
File not found: /content/drive/MyDrive/preprocessed_data/valid/images/rotate90_mask_in_the_ocean_39_jpg.rf.c61ba35d9450c0dc0483dc3bee4bb188.jpg
File not found: /content/drive/MyDrive/preprocessed_data/valid/images/rotate180_mask_in_the_ocean_39_jpg.rf.c61ba35d9450c0dc0483dc3bee4bb188.jpg
File not found: /content/drive/MyDrive/preprocessed_data/valid/images/rotate270_mask_in_the_ocean_39_jpg.rf.c61ba35d9450c0dc0483dc3bee4bb188.jpg
File not found: /content/drive/MyDrive/preprocessed_data/valid/images/rotate0_uwg_g-838__flipv_jpg.rf.4cb0a5ecfe2561ad5979ba148a56b6b4.jpg
File not found: /content/drive/MyDrive/preprocessed_data/valid/images/rotate90_uwg_g-838__flipv_jpg.rf.4cb0a5ecfe2561ad5979ba148a56b6b4.jpg
File not found: /content/drive/MyDrive/preprocessed_data/valid/images/rotate180_uwg_g-838__flipv_jpg.rf.4cb0a5ecfe2561ad5979ba148a56b6b4.jpg
Fil

In [None]:
# Persist only the learned weights (the state dict), not the whole module object.
torch.save(
    obj=model.state_dict(),
    f='/content/drive/MyDrive/Colab Notebooks/detection_model.pth',
)

In [None]:
# Load the model for testing
# map_location places the stored tensors on the active device, so a checkpoint
# written on a GPU runtime can also be restored on a CPU-only one.
model.load_state_dict(
    torch.load('/content/drive/MyDrive/Colab Notebooks/detection_model.pth', map_location=device)
)
test_accuracy = evaluate(model, test_loader, device)
print(f"Test Accuracy: {test_accuracy:.4f}")

Test Accuracy: 0.4601
