In [1]:
import os
import warnings

import numpy as np
import pandas as pd
import torch
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import Dataset, DataLoader, ConcatDataset



# 数据准备

数据读取

In [2]:
# Root of the Kaggle input dataset. Every path below is derived from it, so
# pointing the notebook at another copy of the data is a one-line change.
DATA_ROOT = "/kaggle/input/ship-detect-private"

# SSDD image and label directories
SSDD_IMG_ROOT = f"{DATA_ROOT}/SSDD/SSDD"
SSDD_LABEL_ROOT = f"{DATA_ROOT}/SSDD_labels/SSDD_labels"

SSDD_train_inshore_img_path = f"{SSDD_IMG_ROOT}/JPEGImages_train_inshore"
SSDD_train_offshore_img_path = f"{SSDD_IMG_ROOT}/JPEGImages_train_offshore"
SSDD_test_inshore_img_path = f"{SSDD_IMG_ROOT}/JPEGImages_test_inshore"
SSDD_test_offshore_img_path = f"{SSDD_IMG_ROOT}/JPEGImages_test_offshore"
SSDD_train_label_path = f"{SSDD_LABEL_ROOT}/train_labels"
SSDD_test_inshore_label_path = f"{SSDD_LABEL_ROOT}/test_inshore_labels"
SSDD_test_offshore_label_path = f"{SSDD_LABEL_ROOT}/test_offshore_labels"

# SeaShips image and label directories
S_img_path = f"{DATA_ROOT}/SeaShips(7000)/JPEGImages"
S_label_path = f"{DATA_ROOT}/SeaShips_labels/SeaShips_labels/labels"


torch的Dataset数据集类构建

In [None]:
class YOLODataset(Dataset):
    """Detection dataset pairing image files with YOLO-format ``.txt`` labels.

    Each sample is ``(image, target)`` where ``target`` is a dict with the keys
    ``boxes`` (float32, ``[N, 4]`` as ``x_min, y_min, x_max, y_max`` in pixels
    of a ``target_size`` x ``target_size`` image), ``labels`` (int64, ``[N]``),
    ``image_id``, ``area`` and ``iscrowd``.

    Boxes are only rescaled by ``target_size / original_size``. Any custom
    ``transform`` must therefore resize the image to exactly
    ``target_size`` x ``target_size`` and must not change geometry in any other
    way (no flips, crops or rotations), otherwise boxes and pixels disagree.
    """

    # Image file extensions that are picked up (compared case-insensitively).
    IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png')

    def __init__(self, img_paths, label_paths, transform=None, target_size=640):
        """
        Index all image/label pairs found in the given directories.

        Args:
            img_paths: image directory, or a list of image directories.
            label_paths: label directory, or a list of label directories;
                entry ``i`` holds the labels for ``img_paths[i]``.
            transform: callable applied to the PIL image; when ``None`` a
                resize to ``target_size`` followed by ``ToTensor`` is used.
            target_size: side length (pixels) of the square output image that
                the boxes are rescaled to.

        An image is indexed only if a label file with the same stem and a
        ``.txt`` extension exists in the matching label directory.
        """
        self.img_paths = []
        self.label_paths = []
        self.transform = transform
        self.target_size = target_size
        # Default pipeline is built on first use so that constructing the
        # dataset never touches torchvision when a transform is supplied.
        self._default_transform = None

        # Accept a single directory (str or path-like) as well as a list.
        if isinstance(img_paths, (str, os.PathLike)):
            img_paths = [img_paths]
        if isinstance(label_paths, (str, os.PathLike)):
            label_paths = [label_paths]
        img_paths = list(img_paths)
        label_paths = list(label_paths)

        if len(img_paths) != len(label_paths):
            # zip() below stops at the shorter list; make the data loss visible.
            warnings.warn(
                f"img_paths has {len(img_paths)} entries but label_paths has "
                f"{len(label_paths)}; unmatched entries are ignored",
                stacklevel=2,
            )

        for img_dir, label_dir in zip(img_paths, label_paths):
            if not (os.path.isdir(img_dir) and os.path.isdir(label_dir)):
                warnings.warn(
                    f"Skipping missing directory pair: {img_dir!r}, {label_dir!r}",
                    stacklevel=2,
                )
                continue

            # Sorted so that sample indices are stable across runs/machines.
            for img_file in sorted(os.listdir(img_dir)):
                if not img_file.lower().endswith(self.IMG_EXTENSIONS):
                    continue
                stem = os.path.splitext(img_file)[0]
                label_file = os.path.join(label_dir, stem + '.txt')
                if os.path.isfile(label_file):
                    self.img_paths.append(os.path.join(img_dir, img_file))
                    self.label_paths.append(label_file)

        print(f"Loaded {len(self.img_paths)} samples")

    def __len__(self):
        """Return the number of indexed image/label pairs."""
        return len(self.img_paths)

    def __getitem__(self, idx):
        """
        Load sample ``idx``.

        Returns:
            image: output of ``transform`` (or of the default resize +
                ``ToTensor`` pipeline).
            target: dict with ``boxes``, ``labels``, ``image_id``, ``area``
                and ``iscrowd`` tensors; all are empty (length 0) when the
                label file contains no valid rows.
        """
        # Load the image; the context manager releases the file handle.
        with Image.open(self.img_paths[idx]) as img_file:
            image = img_file.convert('RGB')
        original_size = image.size  # (width, height)

        # Boxes come back in pixel coordinates of the original image.
        boxes, labels = self.parse_yolo_label(self.label_paths[idx], original_size)
        boxes = np.asarray(boxes, dtype=np.float64).reshape(-1, 4)

        if self.transform is not None:
            image = self.transform(image)
        else:
            if self._default_transform is None:
                self._default_transform = transforms.Compose([
                    transforms.Resize((self.target_size, self.target_size)),
                    transforms.ToTensor(),
                ])
            image = self._default_transform(image)

        # Rescale boxes from the original resolution to the square target.
        scale_x = self.target_size / original_size[0]
        scale_y = self.target_size / original_size[1]
        boxes[:, [0, 2]] *= scale_x
        boxes[:, [1, 3]] *= scale_y

        # Explicit dtypes keep empty samples consistent with non-empty ones
        # (labels are int64 even when there are no objects).
        boxes = torch.as_tensor(boxes, dtype=torch.float32)
        labels = torch.as_tensor(labels, dtype=torch.int64)

        target = {
            'boxes': boxes,
            'labels': labels,
            'image_id': torch.tensor([idx]),
            'area': (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0]),
            'iscrowd': torch.zeros((len(labels),), dtype=torch.int64)
        }

        return image, target

    def parse_yolo_label(self, label_path, image_size):
        """
        Parse a YOLO-format label file.

        Every row with at least five whitespace-separated fields is read as
        ``class_id x_center y_center width height`` with the four geometry
        values normalised to ``[0, 1]``; shorter rows are ignored.

        Args:
            label_path: path of the label file; a missing file yields no boxes.
            image_size: image size as ``(width, height)`` in pixels.

        Returns:
            boxes: float64 array of shape ``[N, 4]`` holding
                ``x_min, y_min, x_max, y_max`` in pixels (``[0, 4]`` if empty).
            labels: int64 array of shape ``[N]`` with the class ids.
        """
        img_w, img_h = image_size
        boxes = []
        labels = []

        if os.path.exists(label_path):
            with open(label_path, 'r') as f:
                for line in f:
                    data = line.split()
                    if len(data) < 5:
                        continue

                    class_id = int(data[0])
                    # Normalised centre/size -> absolute pixels.
                    x_center = float(data[1]) * img_w
                    y_center = float(data[2]) * img_h
                    half_w = float(data[3]) * img_w / 2
                    half_h = float(data[4]) * img_h / 2

                    boxes.append([
                        x_center - half_w,
                        y_center - half_h,
                        x_center + half_w,
                        y_center + half_h,
                    ])
                    labels.append(class_id)

        # reshape keeps the [N, 4] layout even when N == 0.
        return (
            np.array(boxes, dtype=np.float64).reshape(-1, 4),
            np.array(labels, dtype=np.int64),
        )

创建数据集实例

In [None]:
# Data augmentation transforms
def get_train_transforms(target_size=640):
    """Build the image-only augmentation pipeline used for training.

    Args:
        target_size: side length (pixels) the image is resized to.

    Returns:
        A ``transforms.Compose`` that resizes to ``target_size`` x
        ``target_size``, applies colour jitter, converts to a tensor and
        normalises with fixed per-channel mean/std.

    Only photometric augmentation is applied. ``YOLODataset`` passes just the
    image through this pipeline and rescales the boxes itself, so a geometric
    augmentation such as ``RandomHorizontalFlip`` would mirror the pixels but
    leave the boxes unflipped, corrupting the labels of every flipped sample.
    """
    return transforms.Compose([
        transforms.Resize((target_size, target_size)),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
        # Intentionally no RandomHorizontalFlip: the boxes are not transformed
        # together with the image, so a flip here would desynchronise them.
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

def get_val_transforms(target_size=640):
    """Build the deterministic validation pipeline (no augmentation).

    Args:
        target_size: side length (pixels) the image is resized to.

    Returns:
        A ``transforms.Compose`` that resizes to ``target_size`` x
        ``target_size``, converts to a tensor and normalises with fixed
        per-channel mean/std.
    """
    resize = transforms.Resize((target_size, target_size))
    to_tensor = transforms.ToTensor()
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    )
    return transforms.Compose([resize, to_tensor, normalize])

# Dataset construction
def create_datasets(target_size=640):
    """Create the training and validation datasets.

    Training data is the concatenation of the SSDD training images (inshore
    and offshore, both labelled from ``SSDD_train_label_path``) and the
    SeaShips images. Validation data is the SSDD test split (inshore and
    offshore).

    Args:
        target_size: side length (pixels) passed to the transforms and to
            every ``YOLODataset``.

    Returns:
        ``(combined_train_dataset, val_dataset)``: a ``ConcatDataset`` and a
        ``YOLODataset``.

    NOTE(review): class ids from the SSDD and SeaShips label files are used
    as-is with no remapping here; confirm both label sets share one class map.
    """
    train_tf = get_train_transforms(target_size)
    val_tf = get_val_transforms(target_size)

    # SSDD training split: two image folders that share one label folder.
    ssdd_train = YOLODataset(
        img_paths=[SSDD_train_inshore_img_path, SSDD_train_offshore_img_path],
        label_paths=[SSDD_train_label_path, SSDD_train_label_path],
        transform=train_tf,
        target_size=target_size,
    )

    # SeaShips: a single image folder / label folder pair.
    seaships = YOLODataset(
        img_paths=S_img_path,
        label_paths=S_label_path,
        transform=train_tf,
        target_size=target_size,
    )

    train_all = ConcatDataset([ssdd_train, seaships])

    # Validation uses the SSDD test split, each scene type with its own labels.
    ssdd_val = YOLODataset(
        img_paths=[SSDD_test_inshore_img_path, SSDD_test_offshore_img_path],
        label_paths=[SSDD_test_inshore_label_path, SSDD_test_offshore_label_path],
        transform=val_tf,
        target_size=target_size,
    )

    return train_all, ssdd_val

# Custom collate function for batching
def collate_fn(batch):
    """Collate ``(image, target)`` samples into one batch.

    Images are stacked along a new leading dimension; targets are kept as a
    plain list because each sample may hold a different number of objects.

    Args:
        batch: sequence of ``(image, target)`` pairs.

    Returns:
        ``(images, targets)``: the stacked image tensor and the list of
        per-sample targets, in batch order.
    """
    image_list = [image for image, _ in batch]
    target_list = [target for _, target in batch]
    return torch.stack(image_list, 0), target_list

# DataLoader construction
def create_dataloaders(batch_size=16, num_workers=4, target_size=640):
    """Create the training and validation data loaders.

    Args:
        batch_size: number of samples per batch for both loaders.
        num_workers: number of worker processes for both loaders.
        target_size: image side length forwarded to ``create_datasets``.

    Returns:
        ``(train_loader, val_loader)``; only the training loader shuffles.
    """
    train_dataset, val_dataset = create_datasets(target_size)

    # Settings common to both loaders; only ``shuffle`` differs.
    shared_kwargs = dict(
        batch_size=batch_size,
        num_workers=num_workers,
        collate_fn=collate_fn,
        pin_memory=True,
    )

    train_loader = DataLoader(train_dataset, shuffle=True, **shared_kwargs)
    val_loader = DataLoader(val_dataset, shuffle=False, **shared_kwargs)

    return train_loader, val_loader


# 构建diffusion模型(待完成)

# 构建DCGAN网络(待完成)

# 在baseline的基础上进行finetuning