In [1]:
import zipfile
from PIL import Image
import io
import os
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader, Subset
from torchvision import transforms
from tqdm import tqdm


In [None]:
# Dataset class for loading images and masks
class CustomDataset(Dataset):
    """Paired image / segmentation-mask dataset backed by two directories.

    Samples are paired by position after sorting each directory listing by
    file name, so the image and the mask of a sample must sort to the same
    index (typically by sharing a file name).

    Args:
        images_path: Directory containing the input images.
        masks_path: Directory containing the masks.
        transform: Optional callable applied to every image. It is also
            applied to the mask unless ``mask_transform`` is given.
        mask_transform: Optional callable applied to every mask instead of
            ``transform`` (for example a pipeline that resizes with
            nearest-neighbour interpolation so mask values are not blended).

    Raises:
        ValueError: If the two directories hold a different number of files,
            because index-based pairing would then be wrong.
    """

    def __init__(self, images_path, masks_path, transform=None, mask_transform=None):
        self.images_path = images_path
        self.masks_path = masks_path
        self.transform = transform
        # Fall back to the image transform to stay compatible with callers
        # that only pass ``transform``.
        self.mask_transform = mask_transform if mask_transform is not None else transform
        self.images = self._list_files(images_path)  # Ensure the images are sorted
        self.masks = self._list_files(masks_path)    # Ensure the masks are sorted

        if len(self.images) != len(self.masks):
            raise ValueError(
                f"Found {len(self.images)} images in {images_path!r} but "
                f"{len(self.masks)} masks in {masks_path!r}; "
                "images and masks are paired by sorted position and must match."
            )

    @staticmethod
    def _list_files(directory):
        """Return the sorted names of the visible regular files in ``directory``."""
        # Hidden entries and sub-directories (e.g. Jupyter's
        # ``.ipynb_checkpoints``) would shift the pairing or fail to open.
        return sorted(
            name
            for name in os.listdir(directory)
            if not name.startswith('.')
            and os.path.isfile(os.path.join(directory, name))
        )

    def __len__(self):
        """Return the number of image/mask pairs."""
        return len(self.images)

    def __getitem__(self, index):
        """Load and return the ``(image, mask)`` pair stored at ``index``."""
        img_path = os.path.join(self.images_path, self.images[index])
        mask_path = os.path.join(self.masks_path, self.masks[index])

        # ``convert`` returns an independent, fully loaded copy, so the file
        # handles can be released as soon as the ``with`` blocks exit.
        with Image.open(img_path) as img_file:
            image = img_file.convert('RGB')
        with Image.open(mask_path) as mask_file:
            mask = mask_file.convert('L')

        # NOTE: the two transforms are invoked separately, so a transform with
        # random behaviour would not be synchronised between image and mask.
        if self.transform is not None:
            image = self.transform(image)
        if self.mask_transform is not None:
            mask = self.mask_transform(mask)

        return image, mask


In [None]:

# Preprocessing pipeline: resize to 224x224, then convert to a tensor
transform = transforms.Compose(
    [transforms.Resize((224, 224)), transforms.ToTensor()]
)

# Locations of the image and mask subfolders
images_path = './images'
masks_path = './masks'

# Build the dataset, then derive the sizes of a 90% / 10% train/test split
dataset = CustomDataset(
    images_path=images_path,
    masks_path=masks_path,
    transform=transform,
)
train_size = int(0.9 * len(dataset))
test_size = len(dataset) - train_size  # remainder, so the two sizes always sum to len(dataset)

In [None]:
# Seed the split so the train/test partition is identical on every
# Restart & Run All; otherwise the held-out set changes between runs and
# results (or a reloaded checkpoint) cannot be compared fairly.
split_generator = torch.Generator().manual_seed(42)
train_dataset, test_dataset = torch.utils.data.random_split(
    dataset, [train_size, test_size], generator=split_generator
)
# Training batches are reshuffled every epoch; test order is kept fixed.
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False)


In [None]:
# Define the MobileNetV2 U-Net Model
class mobilenetV2_Unet(nn.Module):
    """Small U-Net with a simplified MobileNetV2-style encoder (binary segmentation).

    The encoder reduces the spatial resolution by a factor of 16 in three
    stages (x4, x2, x2). The decoder restores it with four stride-2
    transposed convolutions, so the output has the same height and width as
    the input. The first two decoder stages concatenate the encoder feature
    map of matching resolution as a skip connection.

    Input:
        Tensor of shape (N, 3, H, W) with H and W divisible by 16.
    Output:
        Raw logits of shape (N, 1, H, W). No sigmoid is applied here; use a
        logits-based loss or apply ``torch.sigmoid`` to get probabilities.
    """

    # Total encoder downsampling (4 * 2 * 2). The four stride-2 transposed
    # convolutions of the decoder undo exactly this factor.
    _TOTAL_STRIDE = 16

    def __init__(self):
        super().__init__()

        # Encoder (MobileNetV2-style) - simplified for the purpose of this example
        # Stem: overlapping 7x7 convolution, stride 4 -> (N, 32, H/4, W/4)
        self.Conv2d_1 = nn.Conv2d(3, 32, kernel_size=7, stride=4, padding=3)
        # Pointwise -> depthwise (stride 2) -> pointwise projection
        # -> (N, 16, H/8, W/8)
        self.Bottleneck_1 = nn.Sequential(
            nn.Conv2d(32, 32, kernel_size=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, 32, groups=32, kernel_size=3, padding=1, stride=2),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, 16, kernel_size=1),
            nn.BatchNorm2d(16)
        )
        # Expand (x6) -> depthwise (stride 2) -> project
        # -> (N, 24, H/16, W/16)
        self.Bottleneck_2 = nn.Sequential(
            nn.Conv2d(16, 96, kernel_size=1),
            nn.BatchNorm2d(96),
            nn.ReLU(),
            nn.Conv2d(96, 96, groups=96, kernel_size=3, padding=1, stride=2),
            nn.BatchNorm2d(96),
            nn.ReLU(),
            nn.Conv2d(96, 24, kernel_size=1),
            nn.BatchNorm2d(24)
        )

        # Decoder (U-Net-style)
        # Each transposed convolution doubles the resolution. The first
        # convolution of a stage with a skip connection must accept the
        # upsampled channels PLUS the skip channels.
        self.up_ConvTranspose2d_1 = nn.ConvTranspose2d(24, 96, kernel_size=2, stride=2)
        self.Sequential_1 = nn.Sequential(
            nn.Conv2d(96 + 16, 48, kernel_size=3, padding=1),  # 96 upsampled + 16 from Bottleneck_1
            nn.ReLU(),
            nn.Conv2d(48, 48, kernel_size=3, padding=1),
            nn.ReLU()
        )
        self.up_ConvTranspose2d_2 = nn.ConvTranspose2d(48, 24, kernel_size=2, stride=2)
        self.Sequential_2 = nn.Sequential(
            nn.Conv2d(24 + 32, 24, kernel_size=3, padding=1),  # 24 upsampled + 32 from Conv2d_1
            nn.ReLU(),
            nn.Conv2d(24, 24, kernel_size=3, padding=1),
            nn.ReLU()
        )
        self.up_ConvTranspose2d_3 = nn.ConvTranspose2d(24, 12, kernel_size=2, stride=2)
        self.Sequential_3 = nn.Sequential(
            nn.Conv2d(12, 12, kernel_size=3, padding=1),  # no skip connection at this stage
            nn.ReLU(),
            nn.Conv2d(12, 12, kernel_size=3, padding=1),
            nn.ReLU()
        )
        self.up_ConvTranspose2d_4 = nn.ConvTranspose2d(12, 1, kernel_size=2, stride=2)
        self.final_conv = nn.Conv2d(1, 1, kernel_size=1)

    def forward(self, x):
        """Return per-pixel logits of shape (N, 1, H, W) for input (N, 3, H, W).

        Raises:
            ValueError: If H or W is not divisible by 16; the decoder output
                would then not line up with the input resolution.
        """
        height, width = x.shape[-2:]
        if height % self._TOTAL_STRIDE or width % self._TOTAL_STRIDE:
            raise ValueError(
                f"Input height and width must be divisible by {self._TOTAL_STRIDE}, "
                f"got {height}x{width}"
            )

        # Encoder
        x1 = self.Conv2d_1(x)        # (N, 32, H/4,  W/4)
        x2 = self.Bottleneck_1(x1)   # (N, 16, H/8,  W/8)
        x3 = self.Bottleneck_2(x2)   # (N, 24, H/16, W/16)

        # Decoder
        x4 = self.up_ConvTranspose2d_1(x3)   # (N, 96, H/8, W/8)
        x5 = torch.cat([x4, x2], 1)  # Skip connection -> 112 channels
        x6 = self.Sequential_1(x5)           # (N, 48, H/8, W/8)

        x7 = self.up_ConvTranspose2d_2(x6)   # (N, 24, H/4, W/4)
        x8 = torch.cat([x7, x1], 1)  # Skip connection -> 56 channels
        x9 = self.Sequential_2(x8)           # (N, 24, H/4, W/4)

        x10 = self.up_ConvTranspose2d_3(x9)  # (N, 12, H/2, W/2)
        x11 = self.Sequential_3(x10)

        x12 = self.up_ConvTranspose2d_4(x11)  # (N, 1, H, W)
        x13 = self.final_conv(x12)

        return x13

In [None]:

# Pick the compute device, then build the model, optimizer and loss
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
model = mobilenetV2_Unet()
model = model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
loss_fn = nn.BCEWithLogitsLoss()  # consumes raw logits; applies the sigmoid itself

# Train for a fixed number of epochs, reporting the mean batch loss of each
epochs = 5
for epoch in range(epochs):
    model.train()
    running_loss = 0.0
    progress = tqdm(train_loader, desc=f'Epoch {epoch+1}/{epochs}')
    for i, (inputs, labels) in enumerate(progress):
        inputs = inputs.to(device)
        labels = labels.to(device)

        # Standard step: clear stale gradients, forward, backward, update
        optimizer.zero_grad()
        outputs = model(inputs)
        # Drop the channel dimension (dim 1) from both logits and targets
        loss = loss_fn(outputs.squeeze(1), labels.squeeze(1))  # Binary cross-entropy loss
        loss.backward()
        optimizer.step()

        running_loss = running_loss + loss.item()

    print(f'Epoch {epoch+1} Loss: {running_loss / len(train_loader)}')


In [None]:
# Save the model
torch.save(model.state_dict(), 'mobilenetv2_unet.pth')

# Qualitative evaluation: visualise the first sample of the first test batch
model.eval()
with torch.no_grad():
    for inputs, labels in test_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = model(inputs)
        # The model is trained with BCEWithLogitsLoss, so `outputs` are raw
        # logits. Map them to foreground probabilities in [0, 1] before
        # plotting; otherwise the "mask" panel shows unbounded scores.
        probabilities = torch.sigmoid(outputs)

        fig, axes = plt.subplots(1, 3, figsize=(12, 4))
        axes[0].imshow(inputs[0].cpu().permute(1, 2, 0))  # CHW -> HWC for matplotlib
        axes[0].set_title("Original Image")
        axes[0].axis("off")
        # Fixed 0..1 grey scale so prediction and ground truth are comparable
        # instead of each being auto-normalised to its own min/max.
        axes[1].imshow(probabilities[0].cpu().squeeze(), cmap='gray', vmin=0, vmax=1)
        axes[1].set_title("Predicted Mask")
        axes[1].axis("off")
        axes[2].imshow(labels[0].cpu().squeeze(), cmap='gray', vmin=0, vmax=1)
        axes[2].set_title("Ground Truth Mask")
        axes[2].axis("off")
        plt.show()
        # Only the first batch is visualised; an empty loader simply shows nothing.
        break