In [1]:
import os
import cv2
import numpy as np
import torch
import torchvision
from torchvision.models.detection import maskrcnn_resnet50_fpn
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms as T
from PIL import Image

In [2]:
class CustomDataset(Dataset):
    """Instance-segmentation dataset read from ``<root>/images`` and ``<root>/masks``.

    Images and masks are paired by position after sorting the file names of
    each directory. Every 8-connected non-zero region of a mask becomes one
    object instance with class label 1.

    Args:
        root: Directory containing the ``images`` and ``masks`` sub-directories.
        transform: Optional callable applied to the PIL image only; the target
            dictionary is returned unchanged.
        device: Device on which the target tensors are created. Keep the
            default ``'cpu'`` when the DataLoader uses worker processes;
            PyTorch advises against returning CUDA tensors from workers.

    Raises:
        ValueError: If the two directories contain a different number of files,
            since positional pairing would then be wrong.
    """

    def __init__(self, root, transform=None, device='cpu'):
        self.root = root
        self.transform = transform
        self.imgs = sorted(os.listdir(os.path.join(root, "images")))
        self.masks = sorted(os.listdir(os.path.join(root, "masks")))
        self.device = device

        # Pairing is positional, so a count mismatch means at least one image
        # would be matched with the wrong mask (or with none at all).
        if len(self.imgs) != len(self.masks):
            raise ValueError(
                f"Found {len(self.imgs)} images but {len(self.masks)} masks under {root!r}; "
                "each image needs exactly one mask."
            )

    def __getitem__(self, idx):
        """Return ``(image, target)`` for sample ``idx``.

        ``target`` holds the keys ``boxes`` (N, 4) as ``[x1, y1, x2, y2]``,
        ``labels`` (N,), ``masks`` (N, H, W), ``image_id`` (1,), ``area`` (N,)
        and ``iscrowd`` (N,), where N is the number of connected components
        found in the mask (possibly 0).
        """
        img_path = os.path.join(self.root, "images", self.imgs[idx])
        mask_path = os.path.join(self.root, "masks", self.masks[idx])

        # Context managers close the underlying files as soon as the pixel
        # data has been decoded by convert().
        with Image.open(img_path) as img_file:
            img = img_file.convert("RGB")
        with Image.open(mask_path) as mask_file:
            mask = np.array(mask_file.convert("L"))

        num_components, component_map, stats, _ = cv2.connectedComponentsWithStats(
            mask, connectivity=8
        )

        instance_masks = []
        boxes = []
        # Component 0 is the background, so instances start at 1.
        for component in range(1, num_components):
            x, y, w, h, pixel_count = stats[component]
            if pixel_count > 0:
                boxes.append([x, y, x + w, y + h])
                instance_masks.append((component_map == component).astype(np.uint8))

        if len(boxes) > 0:
            boxes = torch.as_tensor(boxes, dtype=torch.float32, device=self.device)
            labels = torch.ones((len(boxes),), dtype=torch.int64, device=self.device)
            masks = torch.as_tensor(np.array(instance_masks), dtype=torch.uint8, device=self.device)
            area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
            iscrowd = torch.zeros((len(boxes),), dtype=torch.int64, device=self.device)
        else:
            # No object in this mask: return correctly shaped empty tensors.
            boxes = torch.empty((0, 4), dtype=torch.float32, device=self.device)
            labels = torch.empty((0,), dtype=torch.int64, device=self.device)
            masks = torch.empty((0, mask.shape[0], mask.shape[1]), dtype=torch.uint8, device=self.device)
            area = torch.empty((0,), dtype=torch.float32, device=self.device)
            iscrowd = torch.zeros((0,), dtype=torch.int64, device=self.device)

        image_id = torch.tensor([idx], device=self.device)

        target = {
            "boxes": boxes,
            "labels": labels,
            "masks": masks,
            "image_id": image_id,
            "area": area,
            "iscrowd": iscrowd,
        }

        if self.transform is not None:
            # The transform only sees the image; the target is not modified.
            img = self.transform(img)

        return img, target

    def __len__(self):
        """Return the number of image files found under ``<root>/images``."""
        return len(self.imgs)

In [3]:
def get_transform(train):
    """Build the image-only preprocessing pipeline.

    Args:
        train: Kept for interface compatibility; the pipeline is currently the
            same for training and validation.

    Returns:
        A ``T.Compose`` that converts the PIL image to a tensor.
    """
    # No random horizontal flip here: this pipeline is applied to the image
    # alone (the dataset calls transform(img)), so flipping it would leave the
    # boxes and masks unflipped and train the model on mismatched targets.
    # Geometric augmentation has to transform the image and its target together.
    return T.Compose([T.ToTensor()])

In [4]:
def get_model_instance_segmentation(num_classes):
    """Create a pretrained Mask R-CNN whose prediction heads output ``num_classes`` classes.

    Both the box predictor and the mask predictor of the ROI heads are replaced
    by freshly constructed ones sized for ``num_classes``.

    Args:
        num_classes: Number of output classes for both new heads.

    Returns:
        The model with its two predictors replaced.
    """
    net = maskrcnn_resnet50_fpn(pretrained=True)
    heads = net.roi_heads

    # New box head, fed by the same number of features as the old one.
    box_in_features = heads.box_predictor.cls_score.in_features
    heads.box_predictor = FastRCNNPredictor(box_in_features, num_classes)

    # New mask head: same input channels, 256 hidden channels.
    mask_in_channels = heads.mask_predictor.conv5_mask.in_channels
    heads.mask_predictor = MaskRCNNPredictor(mask_in_channels, 256, num_classes)

    return net

In [5]:
def train_model(model, data_loader_train, data_loader_val, optimizer, device,
                num_epochs=30, save_dir='/kaggle/working'):
    """Train ``model``, report train/validation loss per epoch and save a checkpoint per epoch.

    Args:
        model: Detection model that returns a dict of losses when called as
            ``model(images, targets)`` in training mode.
        data_loader_train: Iterable of ``(images, targets)`` training batches.
        data_loader_val: Iterable of ``(images, targets)`` validation batches.
        optimizer: Optimizer over the model parameters.
        device: Device every image and target tensor is moved to.
        num_epochs: Number of passes over the training data.
        save_dir: Directory that receives ``model_epoch_<epoch>.pth`` after
            each epoch (``<epoch>`` is zero-based); created if missing.
    """
    print(f"Training on device: {device}")
    os.makedirs(save_dir, exist_ok=True)

    for epoch in range(num_epochs):
        model.train()
        train_loss = 0.0

        for i, (images, targets) in enumerate(data_loader_train):
            images = [image.to(device) for image in images]
            targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

            optimizer.zero_grad()
            loss_dict = model(images, targets)
            losses = sum(loss for loss in loss_dict.values())
            batch_loss = losses.item()
            train_loss += batch_loss

            losses.backward()
            optimizer.step()

            if i % 10 == 0:
                print(f'Epoch: {epoch + 1}, Batch: {i}, Loss: {batch_loss:.4f}')

        # Validation loss.
        # torchvision detection models return losses only in training mode; in
        # eval mode they return predictions, so the validation loss would
        # always stay at 0. The model is therefore left in training mode here
        # and torch.no_grad() guarantees that no gradient is computed.
        # NOTE(review): this presumes training mode has no side effects on the
        # model (e.g. no running batch-norm statistics) - confirm if a model
        # other than the torchvision Mask R-CNN is passed in.
        val_loss = 0.0
        with torch.no_grad():
            for images, targets in data_loader_val:
                images = [image.to(device) for image in images]
                targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

                loss_dict = model(images, targets)
                val_loss += sum(loss for loss in loss_dict.values()).item()

        # Leave the model in eval mode after validation, as before.
        model.eval()

        # max(..., 1) avoids a ZeroDivisionError for an empty loader.
        avg_train_loss = train_loss / max(len(data_loader_train), 1)
        avg_val_loss = val_loss / max(len(data_loader_val), 1)

        print(f"Epoch {epoch + 1}: Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}")

        print('-'*20)

        # Save the model
        torch.save(model.state_dict(), os.path.join(save_dir, f'model_epoch_{epoch}.pth'))

In [6]:
# Select the compute device: CUDA when available, otherwise the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

Using device: cuda


In [7]:
# Locations: the dataset is an input of this notebook, the weights directory is an output
DATA_ROOT = '/kaggle/input/mask-dts/dataset'
WEIGHTS_DIR = '/kaggle/working/model_weights'

# Fail fast when the dataset is missing. Creating the input directories here
# would only replace a missing dataset with empty folders and hide the problem.
for split in ('train', 'val'):
    for sub_dir in ('images', 'masks'):
        required_dir = os.path.join(DATA_ROOT, split, sub_dir)
        if not os.path.isdir(required_dir):
            raise FileNotFoundError(f"Dataset directory not found: {required_dir}")

os.makedirs(WEIGHTS_DIR, exist_ok=True)

# Datasets keep their default CPU device: the training loop moves every batch
# to `device` itself, and CPU tensors are what DataLoader worker processes
# should return.
dataset_train = CustomDataset(os.path.join(DATA_ROOT, 'train'), get_transform(train=True))
dataset_val = CustomDataset(os.path.join(DATA_ROOT, 'val'), get_transform(train=False))


def collate_batch(batch):
    """Turn a list of (image, target) pairs into a tuple of images and a tuple of targets."""
    return tuple(zip(*batch))


# Load batches in background workers; capped so large machines do not spawn
# one process per core.
num_workers = min(4, os.cpu_count() or 0)

data_loader_train = DataLoader(dataset_train, batch_size=2, shuffle=True,
                               num_workers=num_workers, collate_fn=collate_batch)
data_loader_val = DataLoader(dataset_val, batch_size=1, shuffle=False,
                             num_workers=num_workers, collate_fn=collate_batch)


# Build the model
num_classes = 2  # Background + 1 class (Bishop)
model = get_model_instance_segmentation(num_classes)
model.to(device)

# Build the optimizer over the trainable parameters only
params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(params, lr=0.005, momentum=0.9, weight_decay=0.0005)

Downloading: "https://download.pytorch.org/models/maskrcnn_resnet50_fpn_coco-bf2d0c1e.pth" to /root/.cache/torch/hub/checkpoints/maskrcnn_resnet50_fpn_coco-bf2d0c1e.pth
100%|██████████| 170M/170M [00:00<00:00, 198MB/s]  


In [8]:
# Launch training with the default number of epochs
train_model(
    model=model,
    data_loader_train=data_loader_train,
    data_loader_val=data_loader_val,
    optimizer=optimizer,
    device=device,
)

Training on device: cuda
Epoch: 1, Batch: 0, Loss: 4.4809
Epoch: 1, Batch: 10, Loss: 0.8894
Epoch: 1, Batch: 20, Loss: 0.8143
Epoch: 1, Batch: 30, Loss: 0.7659
Epoch: 1, Batch: 40, Loss: 0.7725
Epoch: 1, Batch: 50, Loss: 0.7012
Epoch: 1, Batch: 60, Loss: 0.6534
Epoch: 1, Batch: 70, Loss: 0.6371
Epoch: 1, Batch: 80, Loss: 0.7365
Epoch: 1, Batch: 90, Loss: 0.6524
Epoch: 1, Batch: 100, Loss: 0.5968
Epoch: 1, Batch: 110, Loss: 0.6046
Epoch: 1, Batch: 120, Loss: 0.4175
Epoch: 1, Batch: 130, Loss: 0.4659
Epoch: 1, Batch: 140, Loss: 0.4299
Epoch: 1, Batch: 150, Loss: 0.4955
Epoch: 1, Batch: 160, Loss: 0.7073
Epoch: 1, Batch: 170, Loss: 0.7239
Epoch: 1, Batch: 180, Loss: 0.4992
Epoch: 1, Batch: 190, Loss: 0.4992
Epoch: 1, Batch: 200, Loss: 0.3534
Epoch: 1, Batch: 210, Loss: 0.6470
Epoch: 1, Batch: 220, Loss: 0.7737
Epoch: 1, Batch: 230, Loss: 0.5027
Epoch: 1, Batch: 240, Loss: 0.6511
Epoch: 1, Batch: 250, Loss: 0.5185
Epoch: 1, Batch: 260, Loss: 0.3981
Epoch: 1, Batch: 270, Loss: 0.4808
Epoch: