In [None]:

import os
import random
import xml.etree.ElementTree as ET
import numpy as np
import torch
import torch.utils.data
import torchvision
from torchvision.models.detection import maskrcnn_resnet50_fpn
from torchvision.transforms import functional as F
from PIL import Image, ImageDraw
import matplotlib.pyplot as plt

class CustomDataset(torch.utils.data.Dataset):
    """Detection dataset over image files with Pascal-VOC style XML annotations.

    Every entry of ``images_path`` is paired with ``<stem>.xml`` inside
    ``annot_path``. Each ``<object>`` element of the annotation contributes one
    bounding box, one label (the index of its name in ``classes``) and one
    binary mask that is simply the filled bounding box.

    The file list is shuffled and cut 70/30; ``is_train`` selects which side of
    the cut this instance exposes.

    Args:
        images_path: Directory holding the image files.
        annot_path: Directory holding one ``.xml`` annotation per image.
        classes: Lower-case class names; the list index is used as label id.
        transforms: Optional callable applied to the PIL image only (the target
            is not passed through it).
        is_train: ``True`` exposes the first 70% of the shuffled files,
            ``False`` the remaining 30%.
        image_mode: ``'RGB'`` loads colour images; any other value loads
            single-channel ``'L'`` images.
        seed: Seed of the split shuffle. With the same seed, every instance
            built over the same directory sees the same file order, so a
            train instance and a test instance are complementary and the
            split is reproducible. Pass ``None`` to shuffle with the global
            ``random`` state instead.
    """

    # Fraction of the shuffled file list that belongs to the training side.
    TRAIN_FRACTION = 0.7

    def __init__(self, images_path, annot_path, classes, transforms=None, is_train=True, image_mode='RGB', seed=0):
        self.images_path = images_path
        self.annot_path = annot_path
        self.transforms = transforms
        self.classes = classes
        self.is_train = is_train
        self.image_mode = image_mode

        # os.listdir order is arbitrary; sorting first makes the seeded shuffle
        # below produce the same permutation for every instance and every run.
        self.files = sorted(os.listdir(images_path))
        if seed is None:
            random.shuffle(self.files)
        else:
            random.Random(seed).shuffle(self.files)

        split_index = int(np.floor(len(self.files) * self.TRAIN_FRACTION))
        if self.is_train:
            self.files = self.files[:split_index]
        else:
            self.files = self.files[split_index:]

    def __len__(self):
        """Return the number of files on this instance's side of the split."""
        return len(self.files)

    def __getitem__(self, idx):
        """Load one sample.

        Returns:
            A ``(img, target, img_path)`` triple: the (optionally transformed)
            image, the target dict with keys ``boxes``, ``labels``, ``masks``,
            ``image_id``, ``area`` and ``iscrowd``, and the path of the image
            file so callers can re-open it for visualisation.

        Raises:
            Exception: Any error raised while loading or parsing is reported
                together with the file name and then re-raised unchanged.
        """
        filename = self.files[idx]
        # splitext only strips the final extension, so stems that themselves
        # contain dots (e.g. "frame.001.jpg") still map to "frame.001.xml".
        image_id = os.path.splitext(filename)[0]
        img_path = os.path.join(self.images_path, filename)
        ann_path = os.path.join(self.annot_path, image_id + '.xml')

        try:
            pil_mode = 'RGB' if self.image_mode == 'RGB' else 'L'
            # convert() returns an independent, fully loaded image, so the
            # file handle can be released right away.
            with Image.open(img_path) as raw:
                img = raw.convert(pil_mode)

            boxes, labels, masks = self.extract_boxes(ann_path, img.size)
            target = self._build_target(boxes, labels, masks, img.size, idx)

            if self.transforms is not None:
                img = self.transforms(img)

            return img, target, img_path  # Return img_path for visualization
        except Exception as e:
            print(f"Error processing file {filename}: {e}")
            raise

    def _build_target(self, boxes, labels, masks, image_size, idx):
        """Pack parsed annotations into the tensor dict returned by ``__getitem__``.

        Images without any annotated object still get well-formed tensors:
        ``boxes`` of shape ``(0, 4)`` and ``masks`` of shape ``(0, H, W)``.
        """
        width, height = image_size

        # reshape keeps the (N, 4) layout even when N == 0, where as_tensor
        # alone would yield a 1-D tensor and break the column indexing below.
        boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
        labels = torch.as_tensor(labels, dtype=torch.int64)
        if masks:
            # Stack first: converting a Python list of arrays element by
            # element is far slower than converting one contiguous array.
            masks = torch.as_tensor(np.stack(masks), dtype=torch.uint8)
        else:
            masks = torch.zeros((0, height, width), dtype=torch.uint8)

        target = {}
        target["boxes"] = boxes
        target["labels"] = labels
        target["masks"] = masks
        target["image_id"] = torch.tensor([idx])
        target["area"] = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
        target["iscrowd"] = torch.zeros((len(labels),), dtype=torch.int64)
        return target

    def extract_boxes(self, filename, image_size):
        """Parse one annotation file.

        Args:
            filename: Path of the XML annotation.
            image_size: ``(width, height)`` of the matching image; used as the
                size of the generated masks.

        Returns:
            ``(boxes, labels, masks)``: a list of ``[xmin, ymin, xmax, ymax]``
            integer boxes, a list of class indices and a list of ``uint8``
            arrays of shape ``(height, width)`` holding 1 inside the box.

        Raises:
            ValueError: If an object name is not listed in ``self.classes``.
        """
        tree = ET.parse(filename)
        root = tree.getroot()

        boxes = []
        labels = []
        masks = []

        width, height = image_size

        for obj in root.findall('.//object'):
            label = obj.find('name').text.lower().strip()
            labels.append(self.classes.index(label))

            box = obj.find('bndbox')
            # Coordinates may be written as floats ("12.0"); truncate to int.
            xmin = int(float(box.find('xmin').text))
            ymin = int(float(box.find('ymin').text))
            xmax = int(float(box.find('xmax').text))
            ymax = int(float(box.find('ymax').text))
            boxes.append([xmin, ymin, xmax, ymax])

            # No real segmentation is available, so the mask is the filled box.
            mask = Image.new('L', (width, height), 0)
            mask_draw = ImageDraw.Draw(mask)
            mask_draw.rectangle([xmin, ymin, xmax, ymax], outline=1, fill=1)
            masks.append(np.array(mask))

        return boxes, labels, masks

def get_transform(train):
    """Build the image-only preprocessing pipeline.

    Args:
        train: Kept for interface compatibility. Training and evaluation
            currently share the same pipeline.

    Returns:
        A ``torchvision.transforms.Compose`` that converts the PIL image to a
        tensor.
    """
    # No RandomHorizontalFlip here on purpose: this pipeline only ever sees
    # the image, so a random flip would mirror the pixels while the boxes and
    # masks in the target stay where they were, producing wrong labels for
    # about half of the training samples. Geometric augmentation has to be
    # applied to image and target together.
    return torchvision.transforms.Compose([torchvision.transforms.ToTensor()])

def collate_fn(batch):
    """Regroup a batch of samples field by field.

    A list of per-sample tuples becomes one tuple per field, so samples of
    different sizes are kept side by side instead of being stacked.
    """
    per_field = zip(*batch)
    return tuple(per_field)

# Paths to your images and annotations
images_path = r'C:\Users\brunolopez\Downloads\komi_analysis_pascal\JPEGImages'
annot_path = r'C:\Users\brunolopez\Downloads\komi_analysis_pascal\Annotations'
classes = ['background', 'face']  # Add more classes as needed

# Create datasets and data loaders
dataset = CustomDataset(images_path, annot_path, classes, transforms=get_transform(train=True), image_mode='L')
dataset_test = CustomDataset(images_path, annot_path, classes, transforms=get_transform(train=False), image_mode='L')
# The index split below is only disjoint if both wrappers index the very same
# file list. That is not guaranteed when each CustomDataset instance orders its
# files on its own, so the test wrapper reuses the training wrapper's list.
dataset_test.files = list(dataset.files)

# Split the dataset in train and test set
indices = torch.randperm(len(dataset)).tolist()
# Hold out 50 samples; with 50 samples or fewer that would leave nothing to
# train on, so fall back to holding out half of them.
num_test = 50 if len(indices) > 50 else len(indices) // 2
split_at = len(indices) - num_test
dataset = torch.utils.data.Subset(dataset, indices[:split_at])
dataset_test = torch.utils.data.Subset(dataset_test, indices[split_at:])

# Define training and validation data loaders
data_loader = torch.utils.data.DataLoader(
    dataset, batch_size=2, shuffle=True, num_workers=0, collate_fn=collate_fn)

data_loader_test = torch.utils.data.DataLoader(
    dataset_test, batch_size=1, shuffle=False, num_workers=0, collate_fn=collate_fn)

# Get the Mask R-CNN model
# NOTE(review): recent torchvision releases deprecate `pretrained=` in favour
# of `weights=`; kept as-is so older installations keep working.
model = maskrcnn_resnet50_fpn(pretrained=True)
num_classes = len(classes)
# Replace both prediction heads so they output `num_classes` categories.
in_features = model.roi_heads.box_predictor.cls_score.in_features
model.roi_heads.box_predictor = torchvision.models.detection.faster_rcnn.FastRCNNPredictor(in_features, num_classes)
in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
hidden_layer = 256
model.roi_heads.mask_predictor = torchvision.models.detection.mask_rcnn.MaskRCNNPredictor(in_features_mask, hidden_layer, num_classes)

# Move model to the right device
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")
model.to(device)

# Training function
def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq):
    """Run one optimisation pass over ``data_loader``.

    Args:
        model: Module that, in training mode, is called as
            ``model(images, targets)`` and returns a dict of loss tensors.
        optimizer: Optimizer holding the model parameters to update.
        data_loader: Iterable of ``(images, targets, paths)`` batches; the
            third field is ignored.
        device: Device the images and target tensors are moved to.
        epoch: Epoch number, used only in the progress messages.
        print_freq: Print the running mean loss every ``print_freq`` batches;
            ``0`` or ``None`` disables the messages.

    Returns:
        The mean total loss over the epoch as a float, or ``None`` when the
        loader produced no batch.
    """
    model.train()
    running_loss = 0.0
    num_batches = 0
    for step, (images, targets, _) in enumerate(data_loader, start=1):
        images = [image.to(device) for image in images]
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        loss_dict = model(images, targets)
        # The individual loss terms are optimised jointly as a plain sum.
        losses = sum(loss for loss in loss_dict.values())

        optimizer.zero_grad()
        losses.backward()
        optimizer.step()

        running_loss += float(losses)
        num_batches = step
        if print_freq and step % print_freq == 0:
            print(f"Epoch [{epoch}] step {step}: mean loss = {running_loss / step:.4f}")

    if num_batches == 0:
        return None
    return running_loss / num_batches

# Optimizer: SGD with momentum and weight decay over every parameter that
# still requires gradients.
params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(params, lr=0.005, momentum=0.9, weight_decay=0.0005)

# Train for 10 epochs
# NOTE(review): the learning rate is constant and nothing is evaluated between
# epochs.
num_epochs = 10
for epoch in range(num_epochs):
    train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq=10)

# Save the model (weights only, relative to the current working directory)
torch.save(model.state_dict(), "komi_maskrcnn.pth")

# Inference: run the model over the test loader and show one figure per image
model.eval()
for images, targets, img_paths in data_loader_test:
    images = list(img.to(device) for img in images)
    # Gradients are not needed for prediction.
    with torch.no_grad():
        outputs = model(images)
    
    # Access the image_mode attribute from the original dataset
    # (data_loader_test.dataset is the Subset, .dataset the CustomDataset).
    # NOTE(review): image_mode is read here but never used below; the image is
    # always re-opened as 'L'.
    image_mode = data_loader_test.dataset.dataset.image_mode

    for i, output in enumerate(outputs):
        img_path = img_paths[i]
        original_img = Image.open(img_path).convert("L")  # Open original image for visualization
        # Canvas with the same height/width as the image, initially all zero.
        mask_img = np.zeros_like(np.array(original_img))

        # Create mask image from the output masks: the first channel of every
        # predicted mask is scaled by 255 and merged with an element-wise
        # maximum, so overlapping masks keep their strongest value.
        # NOTE(review): every returned detection is used; no score threshold
        # is applied to masks or boxes.
        for mask in output['masks']:
            mask = mask[0, :, :].cpu().numpy()
            mask_img = np.maximum(mask_img, mask * 255)
        
        # Plot side-by-side
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 10))
        
        # Plot original image with bounding boxes (these are the predicted
        # boxes from `output`, drawn in red)
        ax1.imshow(original_img, cmap='gray')
        for box in output['boxes']:
            box = box.cpu().numpy()
            # Rectangle takes the top-left corner plus width and height.
            ax1.add_patch(plt.Rectangle((box[0], box[1]), box[2]-box[0], box[3]-box[1], fill=False, edgecolor='red', linewidth=2))
        ax1.set_title("Original Image with Annotations")
        ax1.axis('off')
        
        # Plot mask image with bounding boxes (same predicted boxes as above)
        ax2.imshow(mask_img, cmap='gray')
        for box in output['boxes']:
            box = box.cpu().numpy()
            ax2.add_patch(plt.Rectangle((box[0], box[1]), box[2]-box[0], box[3]-box[1], fill=False, edgecolor='red', linewidth=2))
        ax2.set_title("Mask with Annotations")
        ax2.axis('off')

        plt.show()


In [None]:
import os
import random
import xml.etree.ElementTree as ET
import numpy as np
import torch
import torch.utils.data
import torchvision
from torchvision.models.detection import maskrcnn_resnet50_fpn
from torchvision.transforms import functional as F
from PIL import Image, ImageDraw
import matplotlib.pyplot as plt
import matplotlib.cm as cm

class CustomDataset(torch.utils.data.Dataset):
    """Detection dataset over image files with Pascal-VOC style XML annotations.

    Every entry of ``images_path`` is paired with ``<stem>.xml`` inside
    ``annot_path``. Each ``<object>`` element of the annotation contributes one
    bounding box, one label (the index of its name in ``classes``) and one
    binary mask that is simply the filled bounding box.

    The file list is shuffled and cut 70/30; ``is_train`` selects which side of
    the cut this instance exposes.

    Args:
        images_path: Directory holding the image files.
        annot_path: Directory holding one ``.xml`` annotation per image.
        classes: Lower-case class names; the list index is used as label id.
        transforms: Optional callable applied to the PIL image only (the target
            is not passed through it).
        is_train: ``True`` exposes the first 70% of the shuffled files,
            ``False`` the remaining 30%.
        image_mode: ``'RGB'`` loads colour images; any other value loads
            single-channel ``'L'`` images.
        seed: Seed of the split shuffle. With the same seed, every instance
            built over the same directory sees the same file order, so a
            train instance and a test instance are complementary and the
            split is reproducible. Pass ``None`` to shuffle with the global
            ``random`` state instead.
    """

    # Fraction of the shuffled file list that belongs to the training side.
    TRAIN_FRACTION = 0.7

    def __init__(self, images_path, annot_path, classes, transforms=None, is_train=True, image_mode='RGB', seed=0):
        self.images_path = images_path
        self.annot_path = annot_path
        self.transforms = transforms
        self.classes = classes
        self.is_train = is_train
        self.image_mode = image_mode

        # os.listdir order is arbitrary; sorting first makes the seeded shuffle
        # below produce the same permutation for every instance and every run.
        self.files = sorted(os.listdir(images_path))
        if seed is None:
            random.shuffle(self.files)
        else:
            random.Random(seed).shuffle(self.files)

        split_index = int(np.floor(len(self.files) * self.TRAIN_FRACTION))
        if self.is_train:
            self.files = self.files[:split_index]
        else:
            self.files = self.files[split_index:]

    def __len__(self):
        """Return the number of files on this instance's side of the split."""
        return len(self.files)

    def __getitem__(self, idx):
        """Load one sample.

        Returns:
            A ``(img, target, img_path)`` triple: the (optionally transformed)
            image, the target dict with keys ``boxes``, ``labels``, ``masks``,
            ``image_id``, ``area`` and ``iscrowd``, and the path of the image
            file so callers can re-open it for visualisation.

        Raises:
            Exception: Any error raised while loading or parsing is reported
                together with the file name and then re-raised unchanged.
        """
        filename = self.files[idx]
        # splitext only strips the final extension, so stems that themselves
        # contain dots (e.g. "frame.001.jpg") still map to "frame.001.xml".
        image_id = os.path.splitext(filename)[0]
        img_path = os.path.join(self.images_path, filename)
        ann_path = os.path.join(self.annot_path, image_id + '.xml')

        try:
            pil_mode = 'RGB' if self.image_mode == 'RGB' else 'L'
            # convert() returns an independent, fully loaded image, so the
            # file handle can be released right away.
            with Image.open(img_path) as raw:
                img = raw.convert(pil_mode)

            boxes, labels, masks = self.extract_boxes(ann_path, img.size)
            target = self._build_target(boxes, labels, masks, img.size, idx)

            if self.transforms is not None:
                img = self.transforms(img)

            return img, target, img_path  # Return img_path for visualization
        except Exception as e:
            print(f"Error processing file {filename}: {e}")
            raise

    def _build_target(self, boxes, labels, masks, image_size, idx):
        """Pack parsed annotations into the tensor dict returned by ``__getitem__``.

        Images without any annotated object still get well-formed tensors:
        ``boxes`` of shape ``(0, 4)`` and ``masks`` of shape ``(0, H, W)``.
        """
        width, height = image_size

        # reshape keeps the (N, 4) layout even when N == 0, where as_tensor
        # alone would yield a 1-D tensor and break the column indexing below.
        boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
        labels = torch.as_tensor(labels, dtype=torch.int64)
        if masks:
            # Stack first: converting a Python list of arrays element by
            # element is far slower than converting one contiguous array.
            masks = torch.as_tensor(np.stack(masks), dtype=torch.uint8)
        else:
            masks = torch.zeros((0, height, width), dtype=torch.uint8)

        target = {}
        target["boxes"] = boxes
        target["labels"] = labels
        target["masks"] = masks
        target["image_id"] = torch.tensor([idx])
        target["area"] = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
        target["iscrowd"] = torch.zeros((len(labels),), dtype=torch.int64)
        return target

    def extract_boxes(self, filename, image_size):
        """Parse one annotation file.

        Args:
            filename: Path of the XML annotation.
            image_size: ``(width, height)`` of the matching image; used as the
                size of the generated masks.

        Returns:
            ``(boxes, labels, masks)``: a list of ``[xmin, ymin, xmax, ymax]``
            integer boxes, a list of class indices and a list of ``uint8``
            arrays of shape ``(height, width)`` holding 1 inside the box.

        Raises:
            ValueError: If an object name is not listed in ``self.classes``.
        """
        tree = ET.parse(filename)
        root = tree.getroot()

        boxes = []
        labels = []
        masks = []

        width, height = image_size

        for obj in root.findall('.//object'):
            label = obj.find('name').text.lower().strip()
            labels.append(self.classes.index(label))

            box = obj.find('bndbox')
            # Coordinates may be written as floats ("12.0"); truncate to int.
            xmin = int(float(box.find('xmin').text))
            ymin = int(float(box.find('ymin').text))
            xmax = int(float(box.find('xmax').text))
            ymax = int(float(box.find('ymax').text))
            boxes.append([xmin, ymin, xmax, ymax])

            # No real segmentation is available, so the mask is the filled box.
            mask = Image.new('L', (width, height), 0)
            mask_draw = ImageDraw.Draw(mask)
            mask_draw.rectangle([xmin, ymin, xmax, ymax], outline=1, fill=1)
            masks.append(np.array(mask))

        return boxes, labels, masks

def get_transform(train):
    """Build the image-only preprocessing pipeline.

    Args:
        train: Kept for interface compatibility. Training and evaluation
            currently share the same pipeline.

    Returns:
        A ``torchvision.transforms.Compose`` that converts the PIL image to a
        tensor.
    """
    # No RandomHorizontalFlip here on purpose: this pipeline only ever sees
    # the image, so a random flip would mirror the pixels while the boxes and
    # masks in the target stay where they were, producing wrong labels for
    # about half of the training samples. Geometric augmentation has to be
    # applied to image and target together.
    return torchvision.transforms.Compose([torchvision.transforms.ToTensor()])

def collate_fn(batch):
    """Regroup a batch of samples field by field.

    A list of per-sample tuples becomes one tuple per field, so samples of
    different sizes are kept side by side instead of being stacked.
    """
    per_field = zip(*batch)
    return tuple(per_field)

def compute_iou(box1, box2):
    """Return the intersection-over-union of two axis-aligned boxes.

    Boxes are ``[xmin, ymin, xmax, ymax]`` in continuous coordinates, so the
    width of a box is ``xmax - xmin`` (no ``+ 1`` pixel correction). Boxes
    that merely touch therefore have an IoU of exactly 0.

    Args:
        box1: First box; any indexable of four numbers.
        box2: Second box; any indexable of four numbers.

    Returns:
        The IoU as a Python float in ``[0, 1]``; ``0.0`` when the union has no
        area (both boxes degenerate).
    """
    inter_w = max(0.0, min(box1[2], box2[2]) - max(box1[0], box2[0]))
    inter_h = max(0.0, min(box1[3], box2[3]) - max(box1[1], box2[1]))
    inter_area = inter_w * inter_h

    # Clamp at zero so an inverted box cannot contribute a negative area.
    box1_area = max(0.0, box1[2] - box1[0]) * max(0.0, box1[3] - box1[1])
    box2_area = max(0.0, box2[2] - box2[0]) * max(0.0, box2[3] - box2[1])

    union_area = box1_area + box2_area - inter_area
    if union_area <= 0:
        return 0.0
    return float(inter_area / union_area)

# Paths to your images and annotations
images_path = r'C:\Users\brunolopez\Downloads\komi_analysis_pascal\JPEGImages'
annot_path = r'C:\Users\brunolopez\Downloads\komi_analysis_pascal\Annotations'
classes = ['background', 'face']  # Add more classes as needed

# Create datasets and data loaders
dataset = CustomDataset(images_path, annot_path, classes, transforms=get_transform(train=True), image_mode='L')
dataset_test = CustomDataset(images_path, annot_path, classes, transforms=get_transform(train=False), image_mode='L')
# The index split below is only disjoint if both wrappers index the very same
# file list. That is not guaranteed when each CustomDataset instance orders its
# files on its own, so the test wrapper reuses the training wrapper's list.
dataset_test.files = list(dataset.files)

# Split the dataset in train and test set
indices = torch.randperm(len(dataset)).tolist()
# Hold out 50 samples; with 50 samples or fewer that would leave nothing to
# train on, so fall back to holding out half of them.
num_test = 50 if len(indices) > 50 else len(indices) // 2
split_at = len(indices) - num_test
dataset = torch.utils.data.Subset(dataset, indices[:split_at])
dataset_test = torch.utils.data.Subset(dataset_test, indices[split_at:])

# Define training and validation data loaders
data_loader = torch.utils.data.DataLoader(
    dataset, batch_size=2, shuffle=True, num_workers=0, collate_fn=collate_fn)

data_loader_test = torch.utils.data.DataLoader(
    dataset_test, batch_size=1, shuffle=False, num_workers=0, collate_fn=collate_fn)

# Get the Mask R-CNN model
# NOTE(review): recent torchvision releases deprecate `pretrained=` in favour
# of `weights=`; kept as-is so older installations keep working.
model = maskrcnn_resnet50_fpn(pretrained=True)
num_classes = len(classes)
# Replace both prediction heads so they output `num_classes` categories.
in_features = model.roi_heads.box_predictor.cls_score.in_features
model.roi_heads.box_predictor = torchvision.models.detection.faster_rcnn.FastRCNNPredictor(in_features, num_classes)
in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
hidden_layer = 256
model.roi_heads.mask_predictor = torchvision.models.detection.mask_rcnn.MaskRCNNPredictor(in_features_mask, hidden_layer, num_classes)

# Move model to the right device
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")
model.to(device)

# Training function
def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq):
    """Run one optimisation pass over ``data_loader``.

    Args:
        model: Module that, in training mode, is called as
            ``model(images, targets)`` and returns a dict of loss tensors.
        optimizer: Optimizer holding the model parameters to update.
        data_loader: Iterable of ``(images, targets, paths)`` batches; the
            third field is ignored.
        device: Device the images and target tensors are moved to.
        epoch: Epoch number, used only in the progress messages.
        print_freq: Print the running mean loss every ``print_freq`` batches;
            ``0`` or ``None`` disables the messages.

    Returns:
        The mean total loss over the epoch as a float, or ``None`` when the
        loader produced no batch.
    """
    model.train()
    running_loss = 0.0
    num_batches = 0
    for step, (images, targets, _) in enumerate(data_loader, start=1):
        images = [image.to(device) for image in images]
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        loss_dict = model(images, targets)
        # The individual loss terms are optimised jointly as a plain sum.
        losses = sum(loss for loss in loss_dict.values())

        optimizer.zero_grad()
        losses.backward()
        optimizer.step()

        running_loss += float(losses)
        num_batches = step
        if print_freq and step % print_freq == 0:
            print(f"Epoch [{epoch}] step {step}: mean loss = {running_loss / step:.4f}")

    if num_batches == 0:
        return None
    return running_loss / num_batches

# Optimizer: SGD with momentum and weight decay over every parameter that
# still requires gradients.
params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(params, lr=0.005, momentum=0.9, weight_decay=0.0005)

# Train for 10 epochs
# NOTE(review): the learning rate is constant and nothing is evaluated between
# epochs.
num_epochs = 10
for epoch in range(num_epochs):
    train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq=10)

# Save the model (weights only, relative to the current working directory)
torch.save(model.state_dict(), "komi_maskrcnn.pth")

# Inference: run the model over the test loader and show one figure per image
model.eval()
for images, targets, img_paths in data_loader_test:
    images = list(img.to(device) for img in images)
    # Gradients are not needed for prediction.
    with torch.no_grad():
        outputs = model(images)
    
    # Access the image_mode attribute from the original dataset
    # (data_loader_test.dataset is the Subset, .dataset the CustomDataset).
    # NOTE(review): image_mode is read here but never used below; the image is
    # always re-opened as 'L'.
    image_mode = data_loader_test.dataset.dataset.image_mode

    for i, output in enumerate(outputs):
        img_path = img_paths[i]
        original_img = Image.open(img_path).convert("L")  # Open original image for visualization
        # Canvas with the same height/width as the image, initially all zero.
        mask_img = np.zeros_like(np.array(original_img))

        # Ground-truth boxes of this image, used to score each predicted box.
        gt_boxes = targets[i]['boxes'].cpu().numpy()

        # Create mask image from the output masks: the first channel of every
        # predicted mask is scaled by 255 and merged with an element-wise
        # maximum, so overlapping masks keep their strongest value.
        # NOTE(review): every returned detection is used; no score threshold
        # is applied to masks or boxes.
        for mask in output['masks']:
            mask = mask[0, :, :].cpu().numpy()
            mask_img = np.maximum(mask_img, mask * 255)
        
        # Plot side-by-side
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 10))
        
        # Plot original image with bounding boxes (these are the predicted
        # boxes from `output`); each box is coloured by its best IoU against
        # the ground-truth boxes.
        ax1.imshow(original_img, cmap='gray')
        for box in output['boxes']:
            box = box.cpu().numpy()
            ious = [compute_iou(box, gt_box) for gt_box in gt_boxes]
            # An image without ground-truth boxes scores every prediction 0.
            max_iou = max(ious) if ious else 0
            color = cm.viridis(max_iou)[:3]  # Get color from viridis colormap
            # Rectangle takes the top-left corner plus width and height.
            ax1.add_patch(plt.Rectangle((box[0], box[1]), box[2]-box[0], box[3]-box[1], fill=False, edgecolor=color, linewidth=2))
        ax1.set_title("Original Image with Annotations")
        ax1.axis('off')
        
        # Plot mask image with bounding boxes (same predicted boxes and the
        # same IoU colouring as above, recomputed for this axis)
        ax2.imshow(mask_img, cmap='gray')
        for box in output['boxes']:
            box = box.cpu().numpy()
            ious = [compute_iou(box, gt_box) for gt_box in gt_boxes]
            max_iou = max(ious) if ious else 0
            color = cm.viridis(max_iou)[:3]  # Get color from viridis colormap
            ax2.add_patch(plt.Rectangle((box[0], box[1]), box[2]-box[0], box[3]-box[1], fill=False, edgecolor=color, linewidth=2))
        ax2.set_title("Mask with Annotations")
        ax2.axis('off')

        plt.show()
