In [None]:
import os
import numpy as np
from PIL import Image
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import matplotlib.pyplot as plt


class ShadowDataset(Dataset):
    """Paired image / shadow-mask dataset backed by two directories.

    Every entry of ``images_dir`` is treated as an image. Its mask is looked up
    in ``masks_dir`` under the same base name with a ``.png`` extension.

    Args:
        images_dir: Directory containing the input images.
        masks_dir: Directory containing the masks.
        transform: Optional callable applied to the RGB image.
        mask_transform: Optional callable applied to the single-channel mask.
            Defaults to ``transform`` so existing three-argument callers keep
            their behaviour. Image and mask are transformed by separate calls,
            so a randomised ``transform`` would not stay in sync between the
            two; pass a dedicated ``mask_transform`` in that case.
    """

    def __init__(self, images_dir, masks_dir, transform=None, mask_transform=None):
        self.images_dir = images_dir
        self.masks_dir = masks_dir
        self.transform = transform
        self.mask_transform = transform if mask_transform is None else mask_transform
        # os.listdir returns entries in arbitrary order; sorting keeps the
        # index -> sample mapping stable across runs and machines.
        self.image_filenames = sorted(os.listdir(images_dir))

    def __len__(self):
        """Return the number of entries found in ``images_dir``."""
        return len(self.image_filenames)

    def _mask_path_for(self, image_name):
        """Return the expected mask path for ``image_name``.

        Only the final extension is swapped for ``.png``; a plain
        ``str.replace`` would also rewrite ".jpg" occurring earlier in the name.
        """
        stem, _ = os.path.splitext(image_name)
        return os.path.join(self.masks_dir, stem + ".png")

    def __getitem__(self, idx):
        """Return ``(image, mask)`` for ``idx``.

        If the mask for ``idx`` is missing, a warning is printed and the next
        entry (wrapping around) is tried instead.

        Raises:
            IndexError: ``idx`` is outside the list of image filenames.
            FileNotFoundError: no entry in the dataset has a mask on disk.
        """
        names = self.image_filenames
        image_name = names[idx]  # out-of-range idx raises IndexError here
        mask_path = self._mask_path_for(image_name)

        # Bounded search instead of unbounded recursion: visit each entry at
        # most once, then fail loudly rather than overflowing the stack.
        attempts = 1
        while not os.path.exists(mask_path):
            print(f"Warning: Mask not found for {image_name}. Expected: {mask_path}")
            if attempts >= len(names):
                raise FileNotFoundError(
                    f"No masks found in {self.masks_dir} for any image in {self.images_dir}"
                )
            idx = (idx + 1) % len(names)
            image_name = names[idx]
            mask_path = self._mask_path_for(image_name)
            attempts += 1

        image_path = os.path.join(self.images_dir, image_name)

        # convert() returns a fully loaded copy, so the underlying file can be
        # closed right away instead of waiting for garbage collection.
        with Image.open(image_path) as image_file:
            image = image_file.convert("RGB")
        with Image.open(mask_path) as mask_file:
            mask = mask_file.convert("L")

        if self.transform:
            image = self.transform(image)
        if self.mask_transform:
            mask = self.mask_transform(mask)

        return image, mask


# Shared preprocessing: resize to 256x256, then convert the PIL image to a tensor.
transform = transforms.Compose(
    [
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
    ]
)

# Dataset locations, relative to the notebook's working directory.
train_images_dir = "sbu/SBUTrain4KRecoveredSmall/ShadowImages"
train_masks_dir = "sbu/SBUTrain4KRecoveredSmall/ShadowMasks"
test_images_dir = "sbu/SBU-Test/ShadowImages"
test_masks_dir = "sbu/SBU-Test/ShadowMasks"

# Datasets first, then their loaders; only the training loader shuffles.
train_dataset = ShadowDataset(train_images_dir, train_masks_dir, transform=transform)
test_dataset = ShadowDataset(test_images_dir, test_masks_dir, transform=transform)

train_loader = DataLoader(train_dataset, shuffle=True, batch_size=8)
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=8)

print(f"Number of training images: {len(train_dataset)}")
print(f"Number of testing images: {len(test_dataset)}")

Number of training images: 4085
Number of testing images: 638


In [None]:
import torch.nn.functional as F


class UNet(nn.Module):
    """Small convolutional encoder-decoder producing per-pixel logits.

    Three encoder stages (64, 128, 256 channels) are separated by 2x2 max
    pooling. Two decoder stages (128, 64 channels) upsample back to the
    resolution of the matching encoder stage, and a 1x1 convolution maps to
    ``out_channels``. The encoder activations are used only for their spatial
    size: no skip connections are concatenated into the decoder.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels of the returned logits.
    """

    def __init__(self, in_channels=3, out_channels=1):
        super().__init__()
        make_stage = self.conv_block

        # Contracting path.
        self.encoder1 = make_stage(in_channels, 64)
        self.encoder2 = make_stage(64, 128)
        self.encoder3 = make_stage(128, 256)

        # Expanding path.
        self.decoder1 = make_stage(256, 128)
        self.decoder2 = make_stage(128, 64)

        # 1x1 projection to the requested number of output channels.
        self.final_conv = nn.Conv2d(64, out_channels, kernel_size=1)

    def conv_block(self, in_channels, out_channels):
        """Return two 3x3 same-padding convolutions, each followed by ReLU."""
        layers = [
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        ]
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return raw logits (no sigmoid) with the same spatial size as ``x``."""
        full_res = self.encoder1(x)
        half_res = self.encoder2(F.max_pool2d(full_res, 2))
        quarter_res = self.encoder3(F.max_pool2d(half_res, 2))

        # Upsample with F.interpolate's default mode to each encoder stage's size.
        up_half = F.interpolate(quarter_res, size=half_res.shape[2:])
        up_half = self.decoder1(up_half)
        up_full = F.interpolate(up_half, size=full_res.shape[2:])
        up_full = self.decoder2(up_full)

        return self.final_conv(up_full)


# Seed PyTorch's global RNG before the model is built so weight
# initialisation (and anything else drawing from that RNG) repeats across runs.
SEED = 42
torch.manual_seed(SEED)

# Device preference: CUDA, then Apple's MPS backend, then CPU.
# `torch.backends.mps` is looked up defensively because PyTorch builds that
# predate the backend do not expose the attribute at all.
_mps_backend = getattr(torch.backends, "mps", None)
if torch.cuda.is_available():
    device = torch.device("cuda")
elif _mps_backend is not None and _mps_backend.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

model = UNet().to(device)

# Training hyperparameters.
num_epochs = 25
learning_rate = 0.001

# The model returns raw logits, so the loss applies the sigmoid itself.
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Training: one optimiser step per batch, mean batch loss reported per epoch.
model.train()
print(f"Starting training with {len(train_loader)} batches.", flush=True)
for epoch in range(num_epochs):
    epoch_loss = 0.0
    for images, masks in train_loader:
        images, masks = images.to(device), masks.to(device).float()

        # Clear stale gradients, then forward, backward and update.
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, masks)
        loss.backward()
        optimizer.step()

        # .item() extracts a Python float so the running sum holds no graph.
        epoch_loss += loss.item()

    print(
        f"Epoch [{epoch + 1}/{num_epochs}], Loss: {epoch_loss / len(train_loader):.4f}",
        flush=True,
    )

# Qualitative check: predict on the first test batch and show its first sample.
model.eval()
with torch.no_grad():
    print(f"Starting evaluation with {len(test_loader)} batches.", flush=True)
    first_batch = next(iter(test_loader), None)  # None when the loader is empty
    if first_batch is not None:
        images, masks = first_batch
        images = images.to(device)
        masks = masks.to(device)
        outputs = model(images)
        # Logits -> probabilities -> boolean mask at a 0.5 threshold.
        predicted_masks = torch.sigmoid(outputs) > 0.5

        fig, axes = plt.subplots(1, 3)

        # Images are channel-first tensors; imshow expects channel-last.
        axes[0].imshow(images[0].permute(1, 2, 0).cpu())
        axes[0].set_title("Input Image")

        # Masks and predictions are (1, H, W) per sample. imshow only accepts
        # (H, W), (H, W, 3) or (H, W, 4), so the channel axis is dropped first.
        axes[1].imshow(masks[0].squeeze(0).cpu(), cmap="gray")
        axes[1].set_title("Ground Truth Mask")

        axes[2].imshow(predicted_masks[0].squeeze(0).cpu(), cmap="gray")
        axes[2].set_title("Predicted Mask")

        plt.show()

Starting training with 511 batches.
Epoch [1/25], Loss: 0.2875
Epoch [2/25], Loss: 0.2233
Epoch [3/25], Loss: 0.2118


In [None]:
import torch
import torchvision.transforms as transforms
from PIL import Image
import matplotlib.pyplot as plt

# Single-image inference.
# NOTE: this cell reuses `model` and `device` from the training cell above, so
# that cell must have been run first. To use a saved checkpoint instead,
# uncomment the lines below and adjust the path:
# model = UNet().to(device)  # Use the same model definition as earlier
# model.load_state_dict(torch.load("path/to/your/model.pth"))  # Load your trained model state
# model.eval()

# Load the test image. convert() returns a loaded copy, so the file handle
# can be closed immediately.
image_path = "crosswalk.jpg"  # Update the path to your test image
with Image.open(image_path) as source_image:
    image = source_image.convert("RGB")

# Same preprocessing as during training.
transform = transforms.Compose(
    [
        transforms.Resize((256, 256)),  # Make sure this matches your training transform
        transforms.ToTensor(),
    ]
)

# Add a batch dimension and move the input to the model's device.
input_image = transform(image).unsqueeze(0).to(device)

# Run inference in evaluation mode without tracking gradients.
model.eval()
with torch.no_grad():
    output = model(input_image)
    predicted_mask = torch.sigmoid(output) > 0.5  # Binarize output

# The prediction has a batch axis and a channel axis in front of (H, W).
# Both are dropped, because imshow rejects a (1, H, W) array.
predicted_mask = predicted_mask[0, 0].cpu().numpy()

# Visualization: original image next to the predicted mask.
fig, (image_ax, mask_ax) = plt.subplots(1, 2, figsize=(12, 6))

image_ax.imshow(image)
image_ax.set_title("Original Image")
image_ax.axis("off")

mask_ax.imshow(predicted_mask, cmap="gray")
mask_ax.set_title("Predicted Shadow Mask")
mask_ax.axis("off")

plt.show()