In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import models, transforms
from PIL import Image
import os
import glob
import numpy as np
import random

In [2]:
class SementicSegmentationDrone(Dataset):
    """Paired image/mask dataset for semantic segmentation.

    Images are the ``*.jpg`` files found in ``image_dir`` and masks are the
    ``*.png`` files found in ``mask_dir``. Both listings are sorted and then
    paired by position, so the two directories must sort into the same order.

    Args:
        image_dir: Directory holding the input ``*.jpg`` images.
        mask_dir: Directory holding the ``*.png`` label masks.
        transform: Optional callable applied to the PIL image. When it is
            given, the mask is also converted to a ``torch.long`` tensor;
            when it is omitted, the PIL image and the NumPy mask are returned
            as they are.
        tile_size: Side length, in pixels, of the square samples.
        resize: When True (the default) every image/mask pair is resized to
            ``tile_size x tile_size``. When False the pair keeps its native
            resolution and, if it is larger than ``tile_size`` in either
            dimension, one randomly chosen tile is returned instead. Tiles on
            the right/bottom border can be smaller than ``tile_size``.

    Raises:
        ValueError: If the two directories hold different numbers of files.
    """

    def __init__(self, image_dir, mask_dir, transform=None, tile_size=512, resize=True):
        self.images = sorted(glob.glob(os.path.join(image_dir, "*.jpg")))
        self.masks = sorted(glob.glob(os.path.join(mask_dir, "*.png")))
        # Pairing is positional, so a count mismatch would silently attach
        # masks to the wrong images (or fail later with an IndexError).
        if len(self.images) != len(self.masks):
            raise ValueError(
                f"Found {len(self.images)} images in {image_dir!r} but "
                f"{len(self.masks)} masks in {mask_dir!r}; counts must match."
            )
        self.transform = transform
        self.tile_size = tile_size
        self.resize = resize

    def __len__(self):
        """Return the number of image/mask pairs."""
        return len(self.images)

    def __getitem__(self, idx):
        """Load, resize or tile, and optionally transform the pair at ``idx``.

        Returns:
            ``(image, mask)``. With a transform, ``image`` is whatever the
            transform returns and ``mask`` is a ``torch.long`` tensor; without
            one, ``image`` is a PIL image and ``mask`` a NumPy array.
        """
        target_size = (self.tile_size, self.tile_size)

        with Image.open(self.images[idx]) as image_file:
            # Force three channels so a 3-channel transform also works for
            # grayscale, palettised, CMYK or RGBA files.
            image = image_file.convert("RGB")
        if self.resize:
            image = image.resize(target_size)

        with Image.open(self.masks[idx]) as mask_file:
            if self.resize:
                # Nearest-neighbour keeps label ids intact (no interpolated values).
                mask = np.array(mask_file.resize(target_size, Image.NEAREST))
            else:
                mask = np.array(mask_file)

        # After a resize the pair is exactly tile_size square, so this branch
        # only triggers when resize=False and the native image is larger.
        if image.size[0] > self.tile_size or image.size[1] > self.tile_size:
            image_tiles, mask_tiles = self.split_into_tiles(image, mask)
            idx_tile = random.randrange(len(image_tiles))
            image, mask = image_tiles[idx_tile], mask_tiles[idx_tile]

        if self.transform:
            image = self.transform(image)
            mask = torch.tensor(mask, dtype=torch.long)

        return image, mask

    def split_into_tiles(self, image, mask):
        """Cut an image and its mask into matching tiles.

        Tiles are generated column by column (outer loop over x, inner loop
        over y) with a stride of ``tile_size``; tiles touching the right or
        bottom border are clipped to the image bounds.

        Args:
            image: Object exposing ``size`` as ``(width, height)`` and
                ``crop((left, top, right, bottom))`` (a PIL image).
            mask: Array-like label mask indexed as ``[row, column]``.

        Returns:
            ``(image_tiles, mask_tiles)``: two lists of equal length, where
            each mask tile is an independent NumPy array.
        """
        image_width, image_height = image.size
        mask = np.asarray(mask)
        image_tiles = []
        mask_tiles = []

        for left in range(0, image_width, self.tile_size):
            right = min(left + self.tile_size, image_width)
            for top in range(0, image_height, self.tile_size):
                bottom = min(top + self.tile_size, image_height)
                image_tiles.append(image.crop((left, top, right, bottom)))
                # Slice the array directly instead of rebuilding a PIL image
                # per tile; copy so tiles do not alias the full mask.
                mask_tiles.append(mask[top:bottom, left:right].copy())

        return image_tiles, mask_tiles

In [3]:
# Image and mask directories for each split, keyed as "<split>_images" / "<split>_masks".
directory = dict(
    train_images="advanced_data/x_train",
    train_masks="advanced_data/y_train",
    val_images="advanced_data/x_valid",
    val_masks="advanced_data/y_valid",
    test_images="advanced_data/x_test",
    test_masks="advanced_data/y_test",
)

# Image preprocessing: convert to a tensor, then normalise each channel.
_to_tensor = transforms.ToTensor()
_normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
transform = transforms.Compose([_to_tensor, _normalize])

# One dataset per split, all sharing the same transform and tile size.
train_data, valid_data, test_data = (
    SementicSegmentationDrone(
        directory[f"{split}_images"],
        directory[f"{split}_masks"],
        transform=transform,
        tile_size=512,
    )
    for split in ("train", "val", "test")
)

# Only the training loader shuffles; validation and test keep file order.
train_dataloader, valid_dataloader, test_dataloader = (
    DataLoader(dataset, batch_size=4, shuffle=shuffle)
    for dataset, shuffle in (
        (train_data, True),
        (valid_data, False),
        (test_data, False),
    )
)

In [4]:
import torch
import torch.nn as nn
import torchvision.models.segmentation as models

# Build DeepLabV3-ResNet101 without downloading any weights. `weights=None`
# replaces the deprecated `pretrained=False` and matches the `weights_backbone`
# keyword already used here.
base_model = models.deeplabv3_resnet101(weights=None, weights_backbone=None)

# Load the custom checkpoint onto the CPU first, so a file saved from a GPU
# session can still be read on a CPU-only machine; the model is moved to the
# target device further down.
state_dict = torch.load(
    '/home/almon004/DroneSegmentationModel/deeplabv3_model/deeplabv3_resnet101.pth',
    map_location="cpu",
)

# Drop the auxiliary-classifier entries before loading (presumably because the
# model built above has no auxiliary head; the strict load below would
# otherwise reject those keys).
state_dict = {k: v for k, v in state_dict.items() if 'aux_classifier' not in k}

# Load the filtered state dict into the model
base_model.load_state_dict(state_dict)

# Replace the final 1x1 convolution of the head so it predicts `num_classes`
# channels. The input width is read from the layer being replaced instead of
# being hard-coded.
num_classes = 24
head_in_channels = base_model.classifier[4].in_channels
base_model.classifier[4] = nn.Conv2d(head_in_channels, num_classes, kernel_size=1)

# Move the model to the GPU when one is available, otherwise stay on the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
base_model = base_model.to(device)

# Set the model to evaluation mode
base_model.eval()



DeepLabV3(
  (backbone): IntermediateLayerGetter(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (downsample): Se

In [5]:
# Cross-entropy loss with default settings, applied to the model's 'out' logits.
criterion = nn.CrossEntropyLoss()

# Adam over every model parameter, with a fixed learning rate.
optimizer = optim.Adam(params=base_model.parameters(), lr=1e-3)

In [6]:
epochs = 10
for epoch in range(epochs):
    # ---- Training pass: one optimiser step per mini-batch ----
    base_model.train()
    run_loss = 0.0
    for images, masks in train_dataloader:
        images = images.to(device)
        masks = masks.to(device)

        optimizer.zero_grad()
        outputs = base_model(images)['out']
        loss = criterion(outputs, masks)
        loss.backward()
        optimizer.step()

        # Accumulate the per-batch loss as a plain Python float.
        run_loss += loss.item()

    mean_train_loss = run_loss / len(train_dataloader)
    print(f'Epoch {epoch+1}/{epochs}, Loss: {mean_train_loss}')

    # ---- Validation pass: no gradient tracking, no weight updates ----
    base_model.eval()
    valid_loss = 0.0
    with torch.no_grad():
        for images, masks in valid_dataloader:
            images = images.to(device)
            masks = masks.to(device)
            outputs = base_model(images)['out']
            loss = criterion(outputs, masks)
            valid_loss += loss.item()

    mean_valid_loss = valid_loss / len(valid_dataloader)
    print(f'Validation Loss after Epoch {epoch+1}: {mean_valid_loss}')

Epoch 1/10, Loss: 1.8524987374033246
Validation Loss after Epoch 1: 1.7307247459888457
Epoch 2/10, Loss: 1.4131367359842573
Validation Loss after Epoch 2: 1.6625065118074418
Epoch 3/10, Loss: 1.2521659093243735
Validation Loss after Epoch 3: 1.1425019711256028
Epoch 4/10, Loss: 1.16347639134952
Validation Loss after Epoch 4: 1.3327440917491913
Epoch 5/10, Loss: 1.0952675938606262
Validation Loss after Epoch 5: 1.017799898982048
Epoch 6/10, Loss: 1.0104293993541174
Validation Loss after Epoch 6: 1.032787349820137
Epoch 7/10, Loss: 0.9507571177823203
Validation Loss after Epoch 7: 0.855544313788414
Epoch 8/10, Loss: 0.9052147771630968
Validation Loss after Epoch 8: 1.0243705302476882
Epoch 9/10, Loss: 0.8309015853064401
Validation Loss after Epoch 9: 0.8596504718065262
Epoch 10/10, Loss: 0.768984147480556
Validation Loss after Epoch 10: 0.8024978205561638


In [7]:
def calculate_accuracy(preds, labels):
    """Return the fraction of positions whose predicted class equals the label.

    Args:
        preds: Score tensor whose class dimension is dim 1; the predicted
            class is the arg-max along that dimension.
        labels: Tensor of class indices compared element-wise with the
            arg-max result.

    Returns:
        A 0-dim tensor: number of matching positions divided by the total
        number of positions.
    """
    predicted_classes = preds.argmax(dim=1)
    matches = predicted_classes.eq(labels).float()
    total = matches.numel()
    return matches.sum() / total

# Evaluate on the validation loader: accumulate per-batch loss and accuracy,
# then average both over the number of batches.
base_model.eval()
valid_loss = 0.0
valid_acc = 0.0

with torch.no_grad():
    for images, masks in valid_dataloader:
        images = images.to(device)
        masks = masks.to(device)

        outputs = base_model(images)['out']
        loss = criterion(outputs, masks)
        acc = calculate_accuracy(outputs, masks)

        valid_loss += loss.item()
        valid_acc += acc.item()

num_batches = len(valid_dataloader)
avg_val_loss = valid_loss / num_batches
avg_val_acc = valid_acc / num_batches
print(f'Validation Loss: {avg_val_loss:.4f}, Accuracy: {avg_val_acc:.4f}')

Validation Loss: 0.8025, Accuracy: 0.7539
