In [1]:
import os
# Process-level settings; they must be in place before any library that reads
# them is imported (torch is imported in a later cell).
#   SM_FRAMEWORK         - NOTE(review): nothing visible in this notebook reads
#                          it; looks like a leftover from a Keras setup.
#   CUDA_VISIBLE_DEVICES - expose only the GPU with ID 0 to this process.
os.environ.update(
    SM_FRAMEWORK="tf.keras",
    CUDA_VISIBLE_DEVICES="0",
)

import numpy as np
from PIL import Image
import matplotlib.pyplot as plt

In [2]:
def convert_to_binary(img, threshold=128):
    """Convert a colour-coded mask image into an integer label map.

    Despite the name, the result has three classes:
      * 1 for "red" pixels  (R > threshold, G < threshold, B < threshold)
      * 2 for "blue" pixels (R < threshold, G < threshold, B > threshold)
      * 0 for everything else.
    A channel value exactly equal to `threshold` counts as neither high nor low.

    Args:
        img: PIL image or array of shape (H, W, C) with C >= 3 (RGB or RGBA).
            Any alpha channel is ignored.
        threshold: per-channel intensity cut-off.

    Returns:
        np.uint8 array of shape (H, W) with values in {0, 1, 2}.

    Raises:
        ValueError: if `img` is not an (H, W, C>=3) image.
    """
    img_array = np.asarray(img)
    if img_array.ndim != 3 or img_array.shape[2] < 3:
        raise ValueError(
            f"expected an RGB/RGBA image of shape (H, W, C>=3), got {img_array.shape}"
        )

    red = img_array[:, :, 0]
    green = img_array[:, :, 1]
    blue = img_array[:, :, 2]
    red_mask = (red > threshold) & (green < threshold) & (blue < threshold)
    blue_mask = (red < threshold) & (green < threshold) & (blue > threshold)

    # Size the label map from H and W directly, so RGB masks work as well as
    # RGBA ones. The previous version indexed channel 3 just to get the shape.
    binary_image = np.zeros(img_array.shape[:2], dtype=np.uint8)
    binary_image[red_mask] = 1
    binary_image[blue_mask] = 2
    return binary_image

In [3]:
# Load image/mask pairs and cut them into TILE_SIZE x TILE_SIZE tiles.
# Then shuffle, split into train/test, and save the tiles as .npy arrays.
real_images_dir = 'data/imagery/'
mask_images_dir = 'data/masks/'
TILE_SIZE = 576        # tile edge in pixels; later cells resize to 576x576
TRAIN_FRACTION = 0.8   # share of tiles kept for training
SEED = 42              # makes the shuffle (and therefore the split) reproducible


def tile_array(array, tile_size=TILE_SIZE):
    """Split an (H, W, ...) array into full tile_size x tile_size tiles, row-major.

    Partial tiles along the bottom/right edges are dropped. They would have a
    different shape and make the later np.array(...) call fail on a ragged list.
    """
    height, width = array.shape[:2]
    return [
        array[row:row + tile_size, col:col + tile_size]
        for row in range(0, height - tile_size + 1, tile_size)
        for col in range(0, width - tile_size + 1, tile_size)
    ]


x_train = []
y_train = []

# sorted() fixes the file order so the seeded shuffle below is reproducible.
for item in sorted(os.listdir(real_images_dir)):
    if item.endswith('.png') and not item.startswith('test'):
        real_path = os.path.join(real_images_dir, item)
        mask_path = os.path.join(mask_images_dir, item)
        # `with` closes both files once their pixels have been read.
        with Image.open(real_path) as real_img, Image.open(mask_path) as mask_img:
            if real_img.size != mask_img.size:
                raise ValueError(f'{item}: image size {real_img.size} != mask size {mask_img.size}')
            # The U-Net's first conv takes 3 input channels: drop alpha / expand grey.
            real_img_array = np.array(real_img.convert('RGB'))
            mask_img_array = convert_to_binary(mask_img)
        x_train.extend(tile_array(real_img_array))
        y_train.extend(tile_array(mask_img_array))

x_train = np.array(x_train)
y_train = np.array(y_train)
assert len(x_train) == len(y_train), 'every image tile needs exactly one mask tile'

# Shuffle images and masks with the same permutation.
rng = np.random.default_rng(SEED)
permutation = rng.permutation(len(x_train))
x_train = x_train[permutation]
y_train = y_train[permutation]

# Split the data into training and testing sets
split = int(TRAIN_FRACTION * len(x_train))
x_test = x_train[split:]
y_test = y_train[split:]
x_train = x_train[:split]
y_train = y_train[:split]

# Save as numpy arrays
np.save('data/train_images.npy', x_train)
np.save('data/train_masks.npy', y_train)
np.save('data/test_images.npy', x_test)
np.save('data/test_masks.npy', y_test)

In [None]:
# Show one training tile next to its label mask.
sample_idx = min(41, len(x_train) - 1)  # 41 was an arbitrary pick; stay in range
# Print which labels occur instead of dumping the whole mask array.
print(np.unique(y_train[sample_idx]))
fig, ax = plt.subplots(1, 2)
ax[0].imshow(x_train[sample_idx])
ax[0].set_title('Image')
ax[1].imshow(y_train[sample_idx])
ax[1].set_title('Mask (0=other, 1=red, 2=blue)')
plt.show()

In [None]:
import gc
# Free the in-memory tiles; the training cells reload them from the .npy files.
# Popping from globals() keeps this cell safe to re-run, whereas a plain `del`
# raises NameError the second time.
for _name in ('x_train', 'y_train', 'x_test', 'y_test'):
    globals().pop(_name, None)
gc.collect()

In [31]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
import torchmetrics

# Define the U-Net model
class UNet(nn.Module):
    """Two-level U-Net that produces a sigmoid-activated per-pixel map.

    Architecture:
    - Encoder: double-conv(64) -> pool -> double-conv(128) -> pool.
    - Bottleneck: double-conv(256).
    - Decoder: two up-convolution stages, each concatenating the matching
      encoder features (skip connections).
    - Head: a 1x1 conv followed by a sigmoid.

    Args:
        in_channels: channels of the input image (default 3, as before).
        out_channels: channels of the output map (default 1, as before).

    Forward input: (B, in_channels, H, W). Output: (B, out_channels, H, W),
    values in [0, 1]. H and W need not be divisible by 4.
    """

    def __init__(self, in_channels=3, out_channels=1):
        super().__init__()
        # Attribute names, registration order and Sequential layouts match the
        # original, so existing state_dicts ('encoder1.0.weight', ...) still load.
        self.encoder1 = self._double_conv(in_channels, 64)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.encoder2 = self._double_conv(64, 128)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.middle = self._double_conv(128, 256)

        self.upconv2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.decoder2 = self._double_conv(256, 128)

        self.upconv1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.decoder1 = self._double_conv(128, 64)

        self.final_conv = nn.Conv2d(64, out_channels, kernel_size=1)

    @staticmethod
    def _double_conv(in_ch, out_ch):
        """Build two 3x3 same-padded convolutions, each followed by ReLU."""
        return nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )

    @staticmethod
    def _match_size(upsampled, skip):
        """Resize `upsampled` to the spatial size of `skip` when they differ.

        Max-pooling floors odd sizes (e.g. 9 -> 4), so the transposed conv can
        come back one pixel short. Without this, torch.cat would fail.
        """
        if upsampled.shape[-2:] != skip.shape[-2:]:
            upsampled = nn.functional.interpolate(
                upsampled, size=skip.shape[-2:], mode='bilinear', align_corners=False
            )
        return upsampled

    def forward(self, x):
        enc1 = self.encoder1(x)
        enc2 = self.encoder2(self.pool1(enc1))
        middle = self.middle(self.pool2(enc2))
        up2 = self._match_size(self.upconv2(middle), enc2)
        dec2 = self.decoder2(torch.cat([up2, enc2], dim=1))
        up1 = self._match_size(self.upconv1(dec2), enc1)
        dec1 = self.decoder1(torch.cat([up1, enc1], dim=1))
        return torch.sigmoid(self.final_conv(dec1))

class SegmentationDataset(Dataset):
    """(image, mask) pairs read from two .npy arrays indexed by sample on axis 0.

    Args:
        images_path: path to the .npy array of images.
        masks_path: path to the .npy array of per-pixel label masks.
        transform: optional callable applied to each image (e.g. ToTensor).
        mask_transform: optional callable applied to each mask. When omitted
            and `transform` is given, masks become float32 tensors of shape
            (1, H, W) with their label values unchanged. Passing the image
            transform (ToTensor) to uint8 masks would divide the labels by 255.
            With no transform at all, raw NumPy arrays are returned.

    Raises:
        ValueError: if the two arrays hold a different number of samples.
    """

    def __init__(self, images_path, masks_path, transform=None, mask_transform=None):
        self.images = np.load(images_path)
        self.masks = np.load(masks_path)
        if len(self.images) != len(self.masks):
            raise ValueError(
                f'{len(self.images)} images but {len(self.masks)} masks'
            )
        self.transform = transform
        self.mask_transform = mask_transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = self.images[idx]
        mask = self.masks[idx]

        if self.transform is not None:
            image = self.transform(image)
        if self.mask_transform is not None:
            mask = self.mask_transform(mask)
        elif self.transform is not None:
            mask = self._mask_to_tensor(mask)

        return image, mask

    @staticmethod
    def _mask_to_tensor(mask):
        """Convert an (H, W) or (H, W, C) mask to a channel-first float32 tensor, unscaled."""
        tensor = torch.as_tensor(np.asarray(mask), dtype=torch.float32)
        if tensor.ndim == 2:
            tensor = tensor.unsqueeze(0)          # (H, W) -> (1, H, W), like ToTensor
        elif tensor.ndim == 3:
            tensor = tensor.permute(2, 0, 1)      # (H, W, C) -> (C, H, W)
        return tensor

# Function to preprocess the images
def preprocess_images(images_path, masks_path, batch_size=6, image_size=(576, 576), shuffle=True):
    """Build a DataLoader over (image, mask) pairs stored as .npy files.

    Images are converted with `transforms.ToTensor()` (HxWxC uint8 -> CxHxW
    float in [0, 1]). Masks are handled by SegmentationDataset.

    Args:
        images_path: path to the .npy array of images.
        masks_path: path to the .npy array of masks.
        batch_size: samples per batch.
        image_size: currently unused; inputs are NOT resized here. Kept so
            existing calls keep working. NOTE(review): implement or remove.
        shuffle: reshuffle every epoch (default True, the original behaviour).
            Pass False for evaluation so batches come out in a fixed order.

    Returns:
        torch.utils.data.DataLoader yielding (images, masks) batches.
    """
    transform = transforms.Compose([
        transforms.ToTensor()
    ])
    dataset = SegmentationDataset(images_path, masks_path, transform=transform)
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)

# Function to post-process the outputs
def postprocess_outputs(outputs, threshold=None):
    """Split a batch of model outputs into per-sample NumPy maps.

    Args:
        outputs: tensor of shape (B, 1, H, W) (or an iterable of (1, H, W)
            tensors), e.g. sigmoid outputs of UNet.
        threshold: if given, each map is binarised as
            `(map > threshold).astype('uint8')`. The default None returns the
            raw values (the original behaviour).

    Returns:
        List of B NumPy arrays of shape (H, W).
    """
    masks = []
    for output in outputs:
        # Iterating over the batch already drops the batch axis. squeeze(0)
        # removes the singleton *channel* axis, leaving (H, W).
        mask = output.squeeze(0).detach().cpu().numpy()
        if threshold is not None:
            mask = (mask > threshold).astype('uint8')
        masks.append(mask)
    return masks

# Function to test the model on a batch of images
def test_model_batch(model, dataloader, output_size=(576, 576), device=None):
    """Run `model` over every batch of `dataloader` and collect per-sample results.

    Args:
        model: segmentation network. It is switched to eval mode here.
        dataloader: yields `(images, targets)` batches; targets are passed
            through untouched.
        output_size: (H, W) that predictions are bilinearly resized to. None
            keeps the model's native output size.
        device: where to run inference. Defaults to the device of the model's
            parameters, or 'cuda' if the model has none.

    Returns:
        (masks, targets):
        - masks: list of per-sample NumPy maps from `postprocess_outputs`.
        - targets: list of the matching per-sample second batch elements, in
          the same order. For SegmentationDataset these are the ground-truth
          masks, not image names.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else 'cuda'
    model.eval()
    all_masks = []
    all_targets = []
    with torch.no_grad():
        for images, targets in dataloader:
            outputs = model(images.to(device))
            if output_size is not None:
                outputs = nn.functional.interpolate(
                    outputs, size=output_size, mode='bilinear', align_corners=False
                )
            all_masks.extend(postprocess_outputs(outputs))
            all_targets.extend(targets)
    return all_masks, all_targets

# Seed torch first so weight initialisation and DataLoader shuffling are
# reproducible between runs.
torch.manual_seed(42)

# Load the training tiles. These are paths to .npy files, not directories.
image_dir = 'data/train_images.npy'
mask_dir = 'data/train_masks.npy'
train_loader = preprocess_images(image_dir, mask_dir, batch_size=2)  # reduced batch size (presumably for GPU memory)

# Initialize the model on the GPU.
model = UNet()
model = model.to('cuda')
# Define Dice Loss
class DiceLoss(nn.Module):
    """Soft Dice loss over all elements of the batch.

    loss = 1 - (2 * sum(p * t) + smooth) / (sum(p) + sum(t) + smooth),
    where p are the predictions (`inputs`) and t the targets, both flattened.
    """

    def __init__(self):
        super().__init__()

    def forward(self, inputs, targets, smooth=1):
        # reshape (not view): view(-1) raises on non-contiguous tensors, e.g.
        # after permute/transpose; reshape copies only when it has to.
        inputs = inputs.reshape(-1)
        targets = targets.reshape(-1)
        intersection = (inputs * targets).sum()
        dice = (2. * intersection + smooth) / (inputs.sum() + targets.sum() + smooth)
        return 1 - dice

# Define Jaccard Loss
class JaccardLoss(nn.Module):
    """Soft Jaccard (IoU) loss over all elements of the batch.

    loss = 1 - (sum(p * t) + smooth) / (sum(p) + sum(t) - sum(p * t) + smooth),
    where p are the predictions (`inputs`) and t the targets, both flattened.
    """

    def __init__(self):
        super().__init__()

    def forward(self, inputs, targets, smooth=1):
        # reshape (not view): view(-1) raises on non-contiguous tensors.
        inputs = inputs.reshape(-1)
        targets = targets.reshape(-1)
        intersection = (inputs * targets).sum()
        total = (inputs + targets).sum()
        union = total - intersection
        jaccard = (intersection + smooth) / (union + smooth)
        return 1 - jaccard

# Initialize the loss function and optimizer.
criterion = JaccardLoss()  # alternative: DiceLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)


In [None]:
# Training loop
num_epochs = 20
losses = []
device = next(model.parameters()).device  # train wherever the model was placed

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for images, masks in train_loader:
        images = images.to(device)
        masks = masks.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        # Resample only if prediction and target sizes differ, and match the
        # target rather than a hard-coded 576x576.
        if outputs.shape[-2:] != masks.shape[-2:]:
            outputs = nn.functional.interpolate(
                outputs, size=masks.shape[-2:], mode='bilinear', align_corners=False
            )
        loss = criterion(outputs, masks)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    epoch_loss = running_loss / max(len(train_loader), 1)  # guard an empty loader
    losses.append(epoch_loss)
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss}')

# Save the trained model
torch.save(model.state_dict(), 'trained_model.pth')

# Plot the training loss
fig, ax = plt.subplots()
ax.plot(range(1, num_epochs + 1), losses, marker='o')
ax.set(title='Training Loss Over Epochs', xlabel='Epoch', ylabel='Loss')
ax.grid(True)
plt.show()


In [None]:
import torch
# Report the CUDA allocator state (taken before the release), then hand unused
# cached GPU memory back to the driver.
gpu_memory_report = torch.cuda.memory_summary()
print(gpu_memory_report)
torch.cuda.empty_cache()

In [None]:
# Load the trained weights onto the GPU and switch to inference mode.
model = UNet()
model.load_state_dict(torch.load('trained_model.pth', map_location='cuda'))
model = model.to('cuda')
model.eval()

# Load the held-out test tiles.
image_dir = 'data/test_images.npy'
mask_dir = 'data/test_masks.npy'
test_loader = preprocess_images(image_dir, mask_dir)

# Predict on the exact batch we display. The loader shuffles, so predictions
# from a separate pass would belong to different images than this batch.
images, actual_masks = next(iter(test_loader))
with torch.no_grad():
    batch_predictions = model(images.to('cuda')).cpu()

n_rows = min(3, len(images))
fig, ax = plt.subplots(n_rows, 3, figsize=(15, 5 * n_rows), squeeze=False)
for row in range(n_rows):
    ax[row, 0].imshow(images[row].cpu().numpy().transpose(1, 2, 0))
    ax[row, 1].imshow(batch_predictions[row].squeeze(0).numpy(), cmap='gray')
    ax[row, 2].imshow(actual_masks[row].cpu().numpy().squeeze(), cmap='gray')
for col_ax, title in zip(ax[0], ('Image', 'Predicted mask', 'Actual mask')):
    col_ax.set_title(title)
plt.show()

# Count foreground ("white") pixels over the whole test set.
# - test_model_batch returns predictions and ground-truth masks from one pass,
#   in the same order, so the two lists stay paired.
# - A predicted pixel counts when > 0.5; a ground-truth pixel counts when its
#   label is non-zero.
masks, target_masks = test_model_batch(model, test_loader)
white_pixels_predicted = sum(int((mask > 0.5).sum()) for mask in masks)
white_pixels_actual = sum(int((target > 0).sum()) for target in target_masks)

print(f'Number of white pixels in predicted masks: {white_pixels_predicted}')
print(f'Number of white pixels in actual masks: {white_pixels_actual}')