In [1]:
# Import necessary libraries
import os
from PIL import Image
import torch
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models


In [2]:
# Define the custom dataset class for segmentation
class SegmentationDataset(Dataset):
    def __init__(self, image_dir, mask_dir, transform=None, target_transform=None):
        """
        Args:
            image_dir (str): Directory with all the images.
            mask_dir (str): Directory with all the masks.
            transform (callable, optional): Optional transform to be applied on the images.
            target_transform (callable, optional): Optional transform to be applied on the masks.
        """
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.target_transform = target_transform
        self.images = os.listdir(image_dir)

    def __len__(self):
        # Return the total number of images
        return len(self.images)

    def __getitem__(self, idx):
        """
        Args:
            idx (int): Index of the image to be fetched.
        
        Returns:
            tuple: (image, mask) where image is the input image and mask is the segmentation mask.
        """
        # Get the image name from the list
        img_name = self.images[idx]
        # Construct the full path for the image
        img_path = os.path.join(self.image_dir, img_name)
        
        # Find the corresponding mask file
        mask_path = None
        for ext in ['.png', '.jpg', '.jpeg']:
            potential_mask_path = os.path.join(self.mask_dir, img_name.replace('.jpg', ext).replace('.jpeg', ext))
            if os.path.isfile(potential_mask_path):
                mask_path = potential_mask_path
                break
        
        # Raise an error if the mask file is not found
        if mask_path is None:
            raise FileNotFoundError(f"Mask file not found for image: {img_path}")

        # Open the image and mask
        image = Image.open(img_path).convert("RGB")
        mask = Image.open(mask_path).convert("L")
        
        # Apply transformations if specified
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            mask = self.target_transform(mask)
        
        # Squeeze the mask to remove the channel dimension and convert to long tensor
        mask = mask.squeeze(0).long()
        
        return image, mask

In [3]:
# Define the transformations for images and masks
image_transform = transforms.Compose([
    transforms.Resize((128, 128)),  # Resize images to 128x128
    transforms.ToTensor()           # Convert images to PyTorch tensors
])

mask_transform = transforms.Compose([
    transforms.Resize((128, 128)),  # Resize masks to 128x128
    transforms.ToTensor(),          # Convert masks to PyTorch tensors
    transforms.Lambda(lambda x: torch.squeeze(x, 0).long())  # Squeeze channel dimension and convert to long tensor
])

In [4]:
# Create dataset objects for training and validation sets
train_dataset = SegmentationDataset('data/train/images', 'data/train/masks', transform=image_transform, target_transform=mask_transform)
val_dataset = SegmentationDataset('data/val/images', 'data/val/masks', transform=image_transform, target_transform=mask_transform)

# Create DataLoader objects for batching and shuffling
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False)

# Load a pre-trained DeepLabV3 model
model = models.segmentation.deeplabv3_resnet50(pretrained=True)




In [5]:
# Modify the classifier to match the number of classes in your dataset
num_classes = 2  # Assuming two classes: background and rocks
model.classifier[4] = torch.nn.Conv2d(256, num_classes, kernel_size=(1, 1), stride=(1, 1))

# Move the model to the GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# Define the loss function
criterion = torch.nn.CrossEntropyLoss()

# Define the optimizer
optimizer = optim.Adam(model.parameters(), lr=0.001)


In [6]:
# Training loop
num_epochs = 25

for epoch in range(num_epochs):
    model.train()  # Set the model to training mode
    running_loss = 0.0
    
    for images, masks in train_loader:
        images = images.to(device)
        masks = masks.to(device)

        optimizer.zero_grad()  # Zero the parameter gradients
        
        outputs = model(images)['out']  # Forward pass
        loss = criterion(outputs, masks)  # Compute loss
        loss.backward()  # Backward pass
        optimizer.step()  # Update weights
        
        running_loss += loss.item()
    
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(train_loader):.4f}")

    # Validation step (optional but recommended)
    model.eval()  # Set the model to evaluation mode
    val_loss = 0.0
    with torch.no_grad():
        for images, masks in val_loader:
            images = images.to(device)
            masks = masks.to(device)
            
            outputs = model(images)['out']
            loss = criterion(outputs, masks)
            
            val_loss += loss.item()
    
    print(f"Validation Loss: {val_loss/len(val_loader):.4f}")

Epoch [1/25], Loss: 0.7623
Validation Loss: 0.5779
Epoch [2/25], Loss: 0.6942
Validation Loss: 1.1191
Epoch [3/25], Loss: 0.6178
Validation Loss: 0.9698
Epoch [4/25], Loss: 0.5349
Validation Loss: 0.0216
Epoch [5/25], Loss: 0.4636
Validation Loss: 0.0236
Epoch [6/25], Loss: 0.3688
Validation Loss: 0.1501
Epoch [7/25], Loss: 0.3096
Validation Loss: 0.2765
Epoch [8/25], Loss: 0.2613
Validation Loss: 0.3547
Epoch [9/25], Loss: 0.2209
Validation Loss: 0.3120
Epoch [10/25], Loss: 0.1897
Validation Loss: 0.2500
Epoch [11/25], Loss: 0.1637
Validation Loss: 0.1928
Epoch [12/25], Loss: 0.1418
Validation Loss: 0.1513
Epoch [13/25], Loss: 0.1241
Validation Loss: 0.1349
Epoch [14/25], Loss: 0.1097
Validation Loss: 0.1334
Epoch [15/25], Loss: 0.0976
Validation Loss: 0.1321
Epoch [16/25], Loss: 0.0874
Validation Loss: 0.1293
Epoch [17/25], Loss: 0.0778
Validation Loss: 0.1258
Epoch [18/25], Loss: 0.0701
Validation Loss: 0.1202
Epoch [19/25], Loss: 0.0636
Validation Loss: 0.1119
Epoch [20/25], Loss: 

In [7]:
# Save the trained model to a file
torch.save(model.state_dict(), 'deeplabv3_rock_detection.pth')
print("Model saved to 'deeplabv3_rock_detection.pth'")


Model saved to 'deeplabv3_rock_detection.pth'


In [8]:
# Load the trained model for inference (example usage)
model.load_state_dict(torch.load('deeplabv3_rock_detection.pth'))
model.eval()
print("Model loaded for inference")

Model loaded for inference
