In [1]:
import segmentation_models_pytorch as smp

In [4]:
import kagglehub
import os
from PIL import Image
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import albumentations as A
from albumentations.pytorch import ToTensorV2
import segmentation_models_pytorch as smp

In [5]:
# Download latest version
path = kagglehub.dataset_download("gvclsu/water-segmentation-dataset")

print("Path to dataset files:", path)

Path to dataset files: /home/jovyan/.cache/kagglehub/datasets/gvclsu/water-segmentation-dataset/versions/4


In [6]:
img_dir = os.path.join(path,"water_v2", "water_v2", "JPEGImages")
mask_dir = os.path.join(path,"water_v2", "water_v2", "Annotations")

In [27]:
img_dir = os.path.join(path,"water_v2", "water_v2", "JPEGImages" , 'ADE20K')
mask_dir = os.path.join(path,"water_v2", "water_v2", "Annotations" , 'ADE20K')

In [28]:
# --- Step 2: Dataset class ---
class WaterSegmentationDataset(Dataset):
    def __init__(self, img_dir, mask_dir, transform=None):
        self.img_dir = img_dir
        self.mask_dir = mask_dir
        self.img_names = sorted(os.listdir(img_dir))
        self.mask_names = sorted(os.listdir(mask_dir))
        self.transform = transform

    def __len__(self):
        return len(self.img_names)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.img_names[idx])
        mask_path = os.path.join(self.mask_dir, self.mask_names[idx])

        image = np.array(Image.open(img_path).convert("L"))
        mask = np.array(Image.open(mask_path).convert("L"))

        if self.transform:
            augmented = self.transform(image=image, mask=mask)
            image = augmented["image"]
            mask = augmented["mask"]

        return image, mask

In [29]:
# --- Step 3: Transforms ---
transform = A.Compose([
    A.Resize(256, 256),
    A.Normalize(mean=(0.5,), std=(0.5,)),
    ToTensorV2(),
])


In [30]:
# --- Step 5: Setup Dataset and DataLoader ---
train_dataset = WaterSegmentationDataset(img_dir=img_dir, mask_dir=mask_dir, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)


In [40]:
# --- Step 6: Model, loss, optimizer ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = smp.Unet(
    encoder_name="resnet34",
    encoder_weights="imagenet",
    in_channels=1,
    classes=1,
)
model.to(device)

Unet(
  (encoder): ResNetEncoder(
    (conv1): Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track

In [41]:
criterion = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

In [42]:
model.train()
print_every = 100  # print after this many images
image_count = 0    # count of images processed so far

for epoch in range(10):
    epoch_loss = 0.0
    for images, masks in train_loader:
        images = images.to(device)
        masks = masks.float().unsqueeze(1).to(device) / 255.0

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, masks)
        loss.backward()
        optimizer.step()

        batch_size = images.size(0)  # number of images in the current batch
        image_count += batch_size
        epoch_loss += loss.item() * batch_size  # multiply loss by batch size to weight average

        if image_count >= print_every:
            avg_loss = epoch_loss / image_count
            print(f"Epoch [{epoch+1}/10], Images processed: {image_count}, Average Loss: {avg_loss:.4f}")
            image_count = 0
            epoch_loss = 0.0

    # If leftover images after loop ends (less than print_every)
    if image_count > 0:
        avg_loss = epoch_loss / image_count
        print(f"Epoch [{epoch+1}/10], Images processed: {image_count}, Average Loss: {avg_loss:.4f}")
        image_count = 0
        epoch_loss = 0.0

Epoch [1/10], Images processed: 104, Average Loss: 0.7833
Epoch [1/10], Images processed: 104, Average Loss: 0.6719
Epoch [1/10], Images processed: 104, Average Loss: 0.5633
Epoch [1/10], Images processed: 104, Average Loss: 0.5088
Epoch [1/10], Images processed: 104, Average Loss: 0.4953
Epoch [1/10], Images processed: 104, Average Loss: 0.4827
Epoch [1/10], Images processed: 104, Average Loss: 0.4459
Epoch [1/10], Images processed: 104, Average Loss: 0.4459
Epoch [1/10], Images processed: 104, Average Loss: 0.4502
Epoch [1/10], Images processed: 104, Average Loss: 0.4205
Epoch [1/10], Images processed: 104, Average Loss: 0.3831
Epoch [1/10], Images processed: 104, Average Loss: 0.3775
Epoch [1/10], Images processed: 104, Average Loss: 0.3852
Epoch [1/10], Images processed: 104, Average Loss: 0.3710
Epoch [1/10], Images processed: 104, Average Loss: 0.3758
Epoch [1/10], Images processed: 104, Average Loss: 0.3671
Epoch [1/10], Images processed: 104, Average Loss: 0.3585
Epoch [1/10], 