In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
from PIL import Image
import os
import numpy as np
from sklearn.metrics import jaccard_score, precision_score, recall_score, f1_score
from skimage.exposure import match_histograms

**Implementing Partial Cross-Entropy Loss** - only labeled points would contribute to the loss

In [2]:
class PartialCrossEntropyLoss(nn.Module):
    def __init__(self):
        super(PartialCrossEntropyLoss, self).__init__()
        self.cross_entropy = nn.CrossEntropyLoss(reduction='none')  # No reduction means we get per-pixel loss

    def forward(self, predictions, targets, mask):
        # Apply the mask on the pixel-wise loss to only keep labeled points
        loss = self.cross_entropy(predictions, targets)  # Per-pixel loss
        masked_loss = loss * mask  # Only keep loss for labeled points
        return masked_loss.mean()  # Average the loss

**Dataset Preprocessing** - images and labels in the same directory

In [3]:
class LandcoverDataset(Dataset):
    def __init__(self, data_dir, transform=None):
        self.data_dir = data_dir
        self.transform = transform
        self.image_list = [f for f in os.listdir(data_dir) if '_sat' in f]  # Only take satellite images (jpg)

    def __len__(self):
        return len(self.image_list)  # Number of images

    def __getitem__(self, idx):
        # Get the satellite image and corresponding mask
        image_name = self.image_list[idx]
        mask_name = image_name.replace('_sat', '_mask').replace('.jpg', '.png')

        image_path = os.path.join(self.data_dir, image_name)
        mask_path = os.path.join(self.data_dir, mask_name)

        # Load the satellite image and mask
        image = Image.open(image_path).convert("RGB")
        label = Image.open(mask_path).convert('RGB')  # Convert label to RGB

        if self.transform:
            image = self.transform(image)  # Apply image transformation (resize, normalize, etc.)

        # Resize the label to match the image size (256x256)
        label = label.resize((256, 256), Image.NEAREST)  # Resize mask using nearest neighbor (to keep label values)

        # Convert RGB labels to class indices using the provided class mapping
        label = np.array(label)  # Convert the PIL image to a NumPy array
        label = self.rgb_to_class_indices(label)  # Convert RGB mask to class indices

        # Convert label to tensor (shape: [256, 256])
        label = torch.tensor(label, dtype=torch.long)

        return image, label  # Return the image and 2D label

    def rgb_to_class_indices(self, rgb_mask):

        """Convert an RGB mask to class indices based on predefined colors."""

        # Define the RGB to class index mapping based on the provided class_dict.csv
        rgb_to_class = {
            (0, 255, 255): 0,       # Urban Land
            (255, 255, 0): 1,       # Agriculture Land
            (255, 0, 255): 2,       # Rangeland
            (0, 255, 0): 3,         # Forest Land
            (0, 0, 255): 4,         # Water
            (255, 255, 255): 5,     # Barren Land
            (0, 0, 0): 6            # Unknown
        }

        class_mask = np.zeros((rgb_mask.shape[0], rgb_mask.shape[1]), dtype=np.uint8)

        for rgb, class_index in rgb_to_class.items():
            match = np.all(rgb_mask == rgb, axis=-1)
            class_mask[match] = class_index
        return class_mask



**Data preprocessing** -resize images and convert to tensors

In [4]:
image_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),  # Convert image to tensor
])

# Define paths for the dataset directory (train, test, valid directories)
train_dir = 'land-cover-classification-dataset/train'

# Create datasets and dataloaders
train_dataset = LandcoverDataset(train_dir, transform=image_transform)

train_dataloader = DataLoader(train_dataset, batch_size=6, shuffle=True)

**Sample random point labels**

In [5]:
def sample_point_labels(targets, num_points=10000):
    mask = torch.zeros_like(targets)  # Create a mask with all zeros
    h, w = targets.shape[-2:]  # Get height and width of the label image
    indices = torch.randperm(h * w)[:num_points]  # Randomly pick num_points pixels
    mask.view(-1)[indices] = 1  # Mark the selected points as labeled
    return mask

**Loading pre-trained segmentation model** (DeepLabV3)


In [15]:
model = torch.hub.load('pytorch/vision:v0.10.0', 'deeplabv3_resnet50', pretrained=True)
model.classifier[4] = nn.Conv2d(256, 7, kernel_size=1)  # Adjust model for 7 classes (based on your dataset)

optimizer = optim.Adam(model.parameters(), lr=0.001)  # Adam optimizer
criterion = PartialCrossEntropyLoss()  # Use our custom loss function

Using cache found in C:\Users\ADMIN/.cache\torch\hub\pytorch_vision_v0.10.0


**Training model with point labels**

In [7]:
def train_model(model, dataloader, optimizer, criterion, epochs=22, num_points=10000):
    model.train()
    for epoch in range(epochs):
        running_loss = 0.0
        for images, targets in dataloader:
            optimizer.zero_grad()  # Reset gradients

            masks = torch.stack([sample_point_labels(target) for target in targets])  # Create masks for labeled points
            outputs = model(images)['out']  # Forward pass (get model predictions)

            # Ensure the outputs have the shape (batch_size, num_classes, H, W)
            loss = criterion(outputs, targets, masks)  # Calculate loss based on labeled points

            loss.backward()  # Backward pass
            optimizer.step()  # Update model weights

            running_loss += loss.item()

        print(f'Epoch {epoch+1}, Loss: {running_loss/len(dataloader)}')

**Run the model training**

In [8]:
train_model(model, train_dataloader, optimizer, criterion)

Epoch 1, Loss: 0.21456626516122085
Epoch 2, Loss: 0.1720361623626489
Epoch 3, Loss: 0.15826412233022544
Epoch 4, Loss: 0.14930987033324364
Epoch 5, Loss: 0.1497581056677378
Epoch 6, Loss: 0.14935774776415947
Epoch 7, Loss: 0.13795044254033995
Epoch 8, Loss: 0.13834679642548928
Epoch 9, Loss: 0.13270225452306944
Epoch 10, Loss: 0.12733492981164884
Epoch 11, Loss: 0.12507808762483108
Epoch 12, Loss: 0.12011546049362574
Epoch 13, Loss: 0.12418570082921249
Epoch 14, Loss: 0.12350879752865204
Epoch 15, Loss: 0.10909216086833905
Epoch 16, Loss: 0.11320377007508889
Epoch 17, Loss: 0.11199842288325994
Epoch 18, Loss: 0.1201382380647537
Epoch 19, Loss: 0.11169733231266339
Epoch 20, Loss: 0.10709816026381958
Epoch 21, Loss: 0.10782228161891301
Epoch 22, Loss: 0.10534701744715373


In [13]:
# After training, save the model
torch.save(model, 'trained_landcover_model_2.pth')


In [14]:
torch.save(model.state_dict(), 'trained_model_weights_2.pth')