In [None]:
import os
import torchvision
from torchvision.models.detection import fasterrcnn_resnet50_fpn_v2, FasterRCNN_ResNet50_FPN_V2_Weights
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

def get_model(num_classes):
    """Build a Faster R-CNN (ResNet-50 FPN v2) detector with a resized box head.

    The detector is created with torchvision's ``DEFAULT`` pretrained weights,
    then its box predictor is swapped for a new ``FastRCNNPredictor`` whose
    output size is ``num_classes``.

    Args:
        num_classes: number of output classes for the new box predictor
    Returns:
        model: the torchvision detection model with the replaced predictor
    """
    detector = fasterrcnn_resnet50_fpn_v2(
        weights=FasterRCNN_ResNet50_FPN_V2_Weights.DEFAULT
    )

    # The new head must accept the same feature width the old classifier did.
    head_input_width = detector.roi_heads.box_predictor.cls_score.in_features
    detector.roi_heads.box_predictor = FastRCNNPredictor(head_input_width, num_classes)

    return detector


In [None]:
import torch
from torch.utils.data import Dataset
import torchvision.transforms as T
from PIL import Image
import json
import os

class TACODataset(Dataset):
    """COCO-format object detection dataset over a directory of images.

    Each item is an ``(image, target)`` pair. ``target`` is a dict with the keys
    ``boxes`` (float32, ``[N, 4]`` as ``[x1, y1, x2, y2]``), ``labels`` (int64,
    ``[N]``), ``image_id``, ``area`` and ``iscrowd``. Only images whose file
    exists under ``root_dir`` are exposed.
    """

    def __init__(self, root_dir, annotation_file, transform=None):
        """
        Args:
            root_dir (string): directory that contains the images
            annotation_file (string): path to an annotations_X_train.json or
                annotations_X_test.json file
            transform (callable, optional): transform applied to each image.
                Defaults to ``T.ToTensor()`` only.
        """
        self.root_dir = root_dir
        # torchvision detection models expect images in the 0-1 range and apply
        # mean/std normalisation themselves, so the default transform must not
        # normalise again (doing so would normalise the input twice).
        self.transform = transform or T.ToTensor()

        # Load annotations
        with open(annotation_file, 'r', encoding='utf-8') as f:
            self.coco = json.load(f)

        # Create category mapping
        self.cat_ids = {}
        self.cat_id_to_label = {}
        for i, cat in enumerate(self.coco['categories'], 1):  # start at 1 (0 is background)
            self.cat_ids[cat['id']] = cat['name']
            self.cat_id_to_label[cat['id']] = i

        # Index image records and annotations by image id once, so that
        # __getitem__ does not rescan the whole annotation file per sample.
        # setdefault keeps the first record when an id appears more than once.
        self._img_info_by_id = {}
        for img in self.coco['images']:
            self._img_info_by_id.setdefault(img['id'], img)

        self._anns_by_img_id = {}
        for ann in self.coco.get('annotations', []):
            self._anns_by_img_id.setdefault(ann['image_id'], []).append(ann)

        # Get all valid image ids (only images whose file is actually present)
        self.ids = [
            img['id']
            for img in self.coco['images']
            if os.path.exists(os.path.join(self.root_dir, img['file_name']))
        ]

    def __getitem__(self, idx):
        """Return the ``(image, target)`` pair for the sample at position ``idx``."""
        img_id = self.ids[idx]
        img_info = self._img_info_by_id[img_id]

        # Load image; the context manager closes the underlying file promptly.
        img_path = os.path.join(self.root_dir, img_info['file_name'])
        with Image.open(img_path) as raw_img:
            img = raw_img.convert("RGB")

        boxes = []
        labels = []

        for ann in self._anns_by_img_id.get(img_id, ()):
            x, y, w, h = ann['bbox'][:4]  # [x, y, width, height]
            # Skip degenerate boxes: torchvision detection models reject
            # targets whose width or height is not strictly positive.
            if w <= 0 or h <= 0:
                continue
            boxes.append([x, y, x + w, y + h])
            # Convert category_id to sequential label
            labels.append(self.cat_id_to_label[ann['category_id']])

        # Handle images without (valid) annotations
        if len(boxes) == 0:
            boxes = torch.zeros((0, 4), dtype=torch.float32)
            labels = torch.zeros((0,), dtype=torch.int64)
        else:
            boxes = torch.as_tensor(boxes, dtype=torch.float32)
            labels = torch.as_tensor(labels, dtype=torch.int64)

        image_id = torch.tensor([img_id])
        area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
        iscrowd = torch.zeros((len(boxes),), dtype=torch.int64)

        target = {
            'boxes': boxes,
            'labels': labels,
            'image_id': image_id,
            'area': area,
            'iscrowd': iscrowd
        }

        if self.transform:
            img = self.transform(img)

        return img, target

    def __len__(self):
        """Number of images that have a file on disk."""
        return len(self.ids)

    @property
    def num_classes(self):
        """Number of categories in the annotation file plus one for background."""
        return len(self.cat_ids) + 1  # +1 for background

In [None]:
# train.ipynb

# 1. 필요한 라이브러리 임포트
import torch
import os
from torch.utils.data import DataLoader
import torch.optim as optim

# 2. collate_fn 정의
def collate_fn(batch):
    """Transpose a batch of samples into per-field tuples.

    A batch such as ``[(img0, tgt0), (img1, tgt1)]`` becomes
    ``((img0, img1), (tgt0, tgt1))``; the elements are grouped, not stacked.
    """
    per_field = zip(*batch)
    return tuple(per_field)

def _build_train_loader(args, round_num):
    """Build the training dataset and shuffled DataLoader for a single round."""
    dataset = TACODataset(
        root_dir=args.data_dir,
        annotation_file=os.path.join(args.data_dir, f'annotations_{round_num}_train.json')
    )
    loader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=2,
        collate_fn=collate_fn
    )
    return dataset, loader


def _save_checkpoint(path, epoch, model, optimizer, lr_scheduler, loss):
    """Write model, optimizer and LR-scheduler state to ``path``."""
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        # Needed to restore the learning-rate schedule when resuming training.
        'lr_scheduler_state_dict': lr_scheduler.state_dict(),
        'loss': loss,
    }, path)


def train(args):
    """Train the detector on the training splits of every round.

    In each epoch the model is trained on the loaders of rounds
    ``0 .. args.num_rounds - 1`` in order. A checkpoint is written every
    ``args.save_freq`` epochs and a final one after the last epoch.

    Args:
        args: configuration object providing ``data_dir``, ``output_dir``,
            ``num_rounds``, ``batch_size``, ``epochs``, ``lr`` and ``save_freq``.

    Raises:
        ValueError: if ``args.num_rounds`` is smaller than 1.
        RuntimeError: if none of the rounds contains a usable training image.
    """
    if args.num_rounds < 1:
        raise ValueError(f'num_rounds must be at least 1, got {args.num_rounds}')

    # Checkpoints are written below, so make sure the target directory exists.
    os.makedirs(args.output_dir, exist_ok=True)

    # Device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Prepare the training dataset and loader of every round. No evaluation is
    # run in this function, so the test splits are not loaded here.
    train_datasets = []
    train_loaders = []
    for round_num in range(args.num_rounds):
        train_dataset, train_loader = _build_train_loader(args, round_num)
        train_datasets.append(train_dataset)
        train_loaders.append(train_loader)

    if not any(len(dataset) for dataset in train_datasets):
        raise RuntimeError(
            f'No training images found under {args.data_dir!r}; '
            'check the image files and the annotation file names.'
        )

    # Initialise the model. The head is sized for the largest label space of
    # all rounds so that no round can produce an out-of-range label.
    num_classes = max(dataset.num_classes for dataset in train_datasets)
    model = get_model(num_classes)
    model.to(device)

    # Optimizer
    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.SGD(params, lr=args.lr, momentum=0.9, weight_decay=0.0005)

    # Learning rate scheduler
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                  step_size=3,
                                                  gamma=0.1)

    # Stays None only when args.epochs == 0 (no epoch ran before the final save).
    avg_epoch_loss = None

    # Training loop
    for epoch in range(args.epochs):
        model.train()
        epoch_loss = 0
        total_batches = 0

        # Train on the data of every round within each epoch
        for round_num, train_loader in enumerate(train_loaders):
            round_loss = 0
            round_batches = 0

            for i, (images, targets) in enumerate(train_loader):
                images = list(image.to(device) for image in images)
                targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

                loss_dict = model(images, targets)
                losses = sum(loss for loss in loss_dict.values())

                optimizer.zero_grad()
                losses.backward()
                optimizer.step()

                # .item() forces a device sync, so read the scalar only once.
                loss_value = losses.item()
                round_loss += loss_value
                epoch_loss += loss_value
                round_batches += 1
                total_batches += 1

                if i % 50 == 0:
                    print(f'Epoch [{epoch+1}/{args.epochs}], Round [{round_num}], '
                          f'Step [{i}/{len(train_loader)}], Loss: {loss_value:.4f}')

            if round_batches:
                avg_round_loss = round_loss / round_batches
                print(f'Epoch [{epoch+1}/{args.epochs}], Round [{round_num}] '
                      f'Average Loss: {avg_round_loss:.4f}')
            else:
                # An empty round would otherwise cause a division by zero.
                print(f'Epoch [{epoch+1}/{args.epochs}], Round [{round_num}] '
                      f'has no batches; skipped')

        # Average loss over all batches of the epoch
        avg_epoch_loss = epoch_loss / max(total_batches, 1)
        print(f'Epoch [{epoch+1}/{args.epochs}] Complete, '
              f'Average Loss: {avg_epoch_loss:.4f}')

        lr_scheduler.step()

        # Save a checkpoint. Note: 'epoch' holds the 0-based epoch index while
        # the file name uses the 1-based epoch number.
        if (epoch + 1) % args.save_freq == 0:
            checkpoint_path = os.path.join(args.output_dir,
                                         f'model_epoch_{epoch+1}.pth')
            _save_checkpoint(checkpoint_path, epoch, model, optimizer,
                             lr_scheduler, avg_epoch_loss)
            print(f'Checkpoint saved to {checkpoint_path}')

    # Save the final model
    final_checkpoint_path = os.path.join(args.output_dir, 'model_final.pth')
    _save_checkpoint(final_checkpoint_path, args.epochs, model, optimizer,
                     lr_scheduler, avg_epoch_loss)
    print(f'Final model saved to {final_checkpoint_path}')


In [None]:
# 3. 학습 파라미터 설정 (argparse 대신 직접 설정)
class Args:
    """Training configuration, set directly instead of parsed with argparse.

    Every setting can be overridden through a keyword argument; calling
    ``Args()`` with no arguments yields the notebook's default configuration.
    Constructing an instance creates ``output_dir`` if it does not exist.

    Attributes are plain values read by ``train``: ``output_dir``, ``data_dir``,
    ``num_rounds``, ``batch_size``, ``epochs``, ``lr`` and ``save_freq``.
    """

    def __init__(
        self,
        output_dir='/content/drive/My Drive/TACO_Fasterrcnn_checkpoints',
        data_dir=None,
        num_rounds=9,
        batch_size=2,
        epochs=10,
        lr=0.005,
        save_freq=1,
    ):
        # Directory where checkpoints are written.
        self.output_dir = output_dir

        # Create the directory if needed; exist_ok avoids the race between a
        # separate existence check and the creation call.
        os.makedirs(self.output_dir, exist_ok=True)

        # TACO dataset path. When not given, fall back to the notebook-level
        # `data_path` variable, which must have been defined in an earlier cell.
        self.data_dir = data_path if data_dir is None else data_dir
        # Number of rounds to train on (rounds 0 .. num_rounds - 1).
        self.num_rounds = num_rounds
        self.batch_size = batch_size
        self.epochs = epochs
        self.lr = lr
        # A checkpoint is saved every `save_freq` epochs.
        self.save_freq = save_freq


# Instantiate the training configuration with its default settings.
args = Args()

# 4. Create the output directory.
# Args.__init__ already creates it, so this call is a harmless safeguard
# (exist_ok=True makes it a no-op when the directory is present).
os.makedirs(args.output_dir, exist_ok=True)

In [None]:
# Run training with the configuration object created above; checkpoints are
# written to args.output_dir.
train(args)