In [1]:
# Mount Google Drive at /content/drive so the dataset stored there is reachable.
# NOTE(review): the data directories configured later in this notebook are
# relative ("./drive/MyDrive/..."), so they only resolve against this mount
# point while the working directory is /content — presumably the Colab
# default; confirm before running elsewhere.
from google.colab import drive
drive.mount('/content/drive')

import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms.functional as TF
import torchvision.transforms as T
import albumentations as A
from albumentations.pytorch import ToTensorV2
from tqdm import tqdm
import os
from PIL import Image
import numpy as np

Mounted at /content/drive


In [2]:
# Dataset
class CarvanaDataset(Dataset):
    """Image/mask pairs for binary segmentation, read lazily from two directories.

    Every regular file in ``image_dir`` is treated as an input image. Its mask
    is looked up in ``mask_dir`` under the image's file name with ``.jpg``
    replaced by ``_mask.gif``.

    Args:
        image_dir: Directory holding the input images.
        mask_dir: Directory holding the corresponding mask files.
        transform: Optional callable invoked as ``transform(image=..., mask=...)``
            that returns a mapping with ``"image"`` and ``"mask"`` entries
            (the calling convention used by the pipelines in this notebook).
    """

    def __init__(self, image_dir, mask_dir, transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        # os.listdir returns entries in arbitrary order and also lists
        # sub-directories (e.g. a stray ".ipynb_checkpoints" folder). Keeping
        # only regular files and sorting them makes the index -> sample mapping
        # reproducible and avoids handing a directory to Image.open.
        self.images = sorted(
            entry
            for entry in os.listdir(image_dir)
            if os.path.isfile(os.path.join(image_dir, entry))
        )

    def __len__(self):
        """Return the number of image files found in ``image_dir``."""
        return len(self.images)

    def __getitem__(self, index):
        """Load the sample at ``index``.

        Returns:
            ``(image, mask)``. Without a transform, ``image`` is the RGB image
            as a NumPy array and ``mask`` is a float32 NumPy array in which the
            value 255 has been remapped to 1. With a transform, both are
            whatever the transform returns under ``"image"`` and ``"mask"``.
        """
        image_name = self.images[index]
        img_path = os.path.join(self.image_dir, image_name)
        mask_path = os.path.join(self.mask_dir, image_name.replace(".jpg", "_mask.gif"))

        # Context managers release the underlying file handles as soon as the
        # pixel data has been copied into NumPy arrays.
        with Image.open(img_path) as image_file:
            image = np.array(image_file.convert("RGB"))
        with Image.open(mask_path) as mask_file:
            mask = np.array(mask_file.convert("L"), dtype=np.float32)

        # Masks are stored as 0/255; map the foreground value to 1.0.
        mask[mask == 255.0] = 1.0

        if self.transform is not None:
            augmentations = self.transform(image=image, mask=mask)
            image = augmentations["image"]
            mask = augmentations["mask"]

        return image, mask

In [14]:
# Dataset loader
def get_loaders(
    train_dir,
    train_maskdir,
    val_dir,
    val_maskdir,
    batch_size,
    train_transform,
    val_transform,
    num_workers=2,
    pin_memory=True,
):
    """Build the training and validation ``DataLoader`` objects.

    Args:
        train_dir: Directory with the training images.
        train_maskdir: Directory with the training masks.
        val_dir: Directory with the validation images.
        val_maskdir: Directory with the validation masks.
        batch_size: Batch size used by both loaders.
        train_transform: Transform passed to the training dataset.
        val_transform: Transform passed to the validation dataset.
        num_workers: Worker processes per loader (default 2, the value that
            was previously hard-coded).
        pin_memory: Forwarded to both loaders (default True, as before).

    Returns:
        ``(train_loader, val_loader)``. The training loader shuffles its
        samples; the validation loader keeps dataset order.
    """
    train_dataset = CarvanaDataset(
        image_dir=train_dir,
        mask_dir=train_maskdir,
        transform=train_transform,
    )
    val_dataset = CarvanaDataset(
        image_dir=val_dir,
        mask_dir=val_maskdir,
        transform=val_transform,
    )

    # Settings shared by both loaders; only `shuffle` differs between them.
    shared_options = dict(
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )
    train_loader = DataLoader(train_dataset, shuffle=True, **shared_options)
    val_loader = DataLoader(val_dataset, shuffle=False, **shared_options)

    return train_loader, val_loader

In [4]:
class DoubleConv(nn.Module):
    """Two back-to-back stages of 3x3 convolution -> batch norm -> ReLU.

    Both convolutions use stride 1, padding 1 and no bias term; the first maps
    ``in_channels`` to ``out_channels`` and the second keeps ``out_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        def stage(channels_in):
            # One conv/bn/relu triple producing `out_channels` feature maps.
            return [
                nn.Conv2d(
                    channels_in,
                    out_channels,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                    bias=False,
                ),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ]

        # Kept as a single Sequential named `conv` so parameter names
        # (conv.0.weight, conv.1.weight, ...) stay the same.
        self.conv = nn.Sequential(*stage(in_channels), *stage(out_channels))

    def forward(self, x):
        """Run ``x`` through both conv/bn/relu stages and return the result."""
        out = self.conv(x)
        return out

class UNET(nn.Module):
    """Encoder/decoder network with skip connections built from ``DoubleConv`` blocks.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels produced by the final 1x1 convolution.
            No activation is applied to that output inside this module.
        features: Channel widths of the encoder stages, first stage first. The
            decoder mirrors them in reverse and the bottleneck uses twice the
            last width. Any sequence is accepted.
    """

    def __init__(
            self, in_channels=3, out_channels=1, features=(64, 128, 256, 512),
    ):
        super(UNET, self).__init__()
        # The default is an immutable tuple (a shared mutable list default is a
        # classic Python pitfall); copy into a list so indexing and reversed()
        # work for any sequence the caller passes.
        features = list(features)

        self.ups = nn.ModuleList()
        self.downs = nn.ModuleList()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Down part of UNET
        for feature in features:
            self.downs.append(DoubleConv(in_channels, feature))
            in_channels = feature

        # Up part of UNET: `ups` alternates [upsample, DoubleConv, upsample, ...].
        for feature in reversed(features):
            self.ups.append(
                nn.ConvTranspose2d(
                    feature*2, feature, kernel_size=2, stride=2,
                )
            )
            # Input is 2*feature because the skip connection is concatenated.
            self.ups.append(DoubleConv(feature*2, feature))

        self.bottleneck = DoubleConv(features[-1], features[-1]*2)
        self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)

    def forward(self, x):
        """Run the encoder, bottleneck and decoder and return the final conv output."""
        skip_connections = []

        for down in self.downs:
            x = down(x)
            # Saved before pooling so the decoder can reuse it at this scale.
            skip_connections.append(x)
            x = self.pool(x)

        x = self.bottleneck(x)
        # Deepest skip first, matching the order of the decoder stages.
        skip_connections = skip_connections[::-1]

        # Step by 2: even index = transposed conv, odd index = DoubleConv.
        for idx in range(0, len(self.ups), 2):
            x = self.ups[idx](x)
            skip_connection = skip_connections[idx//2]

            # Shapes can disagree after upsampling (presumably when an input
            # side is not divisible by the total pooling factor); resize to
            # the skip connection's height/width so concatenation works.
            if x.shape != skip_connection.shape:
                x = TF.resize(x, size=skip_connection.shape[2:])

            concat_skip = torch.cat((skip_connection, x), dim=1)
            x = self.ups[idx+1](concat_skip)

        return self.final_conv(x)

In [5]:
### Target metric ###
def calc_iou(
    prediction: np.ndarray,
    ground_truth: np.ndarray,
    empty_score: float = 1.0,
) -> float:
    """Intersection-over-union pooled over a batch of masks.

    A pixel counts as foreground when its value is ``> 0``. Intersection and
    union pixel counts are summed over all images first and divided once, so
    the result is the IoU of the whole batch, not the mean of per-image IoUs.

    Args:
        prediction: Sequence/array of predicted masks, indexed by image.
        ground_truth: Sequence/array of reference masks, indexed the same way.
        empty_score: Value returned when the union is empty (no foreground in
            any prediction or reference, or no images at all). Defaults to 1.0,
            i.e. two empty masks are treated as a perfect match.

    Returns:
        The IoU as a Python float in ``[0, 1]``, or ``empty_score`` when the
        union is empty.
    """
    n_images = len(prediction)
    # Exact integer pixel counts: float32 accumulation stops being exact past
    # 2**24 pixels, which a batch of full-resolution masks exceeds.
    intersection, union = 0, 0
    for i in range(n_images):
        predicted_fg = np.asarray(prediction[i]) > 0
        reference_fg = np.asarray(ground_truth[i]) > 0
        intersection += int(np.count_nonzero(np.logical_and(predicted_fg, reference_fg)))
        union += int(np.count_nonzero(np.logical_or(predicted_fg, reference_fg)))

    if union == 0:
        # Avoids NaN / ZeroDivisionError when there is nothing to compare.
        return float(empty_score)
    return intersection / union

In [7]:
def loss_function(name):
    """Create a loss module from its class name.

    Args:
        name: One of ``"BCEWithLogitsLoss"``, ``"BCELoss"`` or ``"MSELoss"``.

    Returns:
        A freshly constructed instance of the requested ``torch.nn`` loss.

    Raises:
        ValueError: If ``name`` is not one of the supported names. Previously
            the function fell through and returned ``None``, which only failed
            later when the "loss" was called.
    """
    supported = {
        "BCEWithLogitsLoss": nn.BCEWithLogitsLoss,
        "BCELoss": nn.BCELoss,
        "MSELoss": nn.MSELoss,
    }
    if name not in supported:
        raise ValueError(
            f"Unknown loss function {name!r}; expected one of {sorted(supported)}"
        )
    return supported[name]()

In [12]:
def check_accuracy(val_loader, model, device="cuda"):
    """Print and return the pixel accuracy of ``model`` over ``val_loader``.

    The model output is passed through a sigmoid and thresholded at 0.5; a
    pixel is correct when the thresholded value equals the target.

    Args:
        val_loader: Iterable yielding ``(data, target)`` batches.
        model: Module to evaluate. It is put in eval mode for the duration of
            the call and returned to whatever mode it was in beforehand.
        device: Device the batches are moved to before the forward pass.

    Returns:
        Fraction of correctly classified pixels as a float, or ``nan`` when
        the loader yields no pixels.
    """
    num_correct = 0
    num_pixels = 0
    # Remember the caller's mode instead of unconditionally forcing train mode.
    was_training = model.training
    model.eval()

    try:
        with torch.no_grad():
            for data, target in val_loader:
                data = data.to(device)
                # assumes targets arrive as (N, H, W) and the model emits
                # (N, 1, H, W), hence the added channel axis — TODO confirm
                target = target.to(device).unsqueeze(1)

                prediction = torch.sigmoid(model(data))
                prediction = (prediction > 0.5).float()

                num_correct += (prediction == target).sum()
                num_pixels += torch.numel(prediction)
    finally:
        # Restore the mode even if evaluation raised.
        model.train(was_training)

    if num_pixels == 0:
        # Empty loader: accuracy is undefined; avoid a ZeroDivisionError.
        print("Accuracy: n/a (validation loader yielded no pixels)")
        return float("nan")

    accuracy = float(num_correct) / num_pixels
    print(f"Accuracy: {accuracy:.2f}")
    return accuracy

In [6]:
# --- Optimisation hyper-parameters ---
learning_rate, batch_size, num_epochs = 1e-4, 16, 3

# --- Image geometry (height, width) ---
# "origin" is the size predictions/targets are resized to for the IoU metric;
# "image" is the size the augmentation pipelines resize every sample to.
origin_height, origin_width = 1280, 1918
image_height, image_width = 128, 192

# --- Dataset locations ---
# The validation folders hold 48 image/mask pairs; the training folders hold
# the remaining images/masks (the 48 validation pairs are excluded from them).
_DATA_ROOT = "./drive/MyDrive/Task2/data"
train_img_dir = _DATA_ROOT + "/train_images/train_hq/"
train_mask_dir = _DATA_ROOT + "/train_masks/train_masks/"
val_img_dir = _DATA_ROOT + "/val_images/"
val_mask_dir = _DATA_ROOT + "/val_masks/"

# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

In [8]:
# Preprocessing pipelines. Both resize to (image_height, image_width), apply
# the same Normalize step and convert to tensors; only the training pipeline
# adds random geometric augmentation in between.
def _resize_step():
    """Return a fresh resize transform targeting the training resolution."""
    return A.Resize(height=image_height, width=image_width)


def _normalize_and_tensorize():
    """Return fresh copies of the two closing steps shared by both pipelines.

    With zero mean and unit std, Normalize is configured only to rescale by
    ``max_pixel_value``; ToTensorV2 then converts the result to tensors.
    """
    return [
        A.Normalize(
            mean=[0.0, 0.0, 0.0],
            std=[1.0, 1.0, 1.0],
            max_pixel_value=255.0,
        ),
        ToTensorV2(),
    ]


train_transform = A.Compose(
    [
        _resize_step(),
        # Random geometry: always rotate within +/-35 degrees, flip
        # horizontally half of the time and vertically 10% of the time.
        A.Rotate(limit=35, p=1.0),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.1),
        *_normalize_and_tensorize(),
    ],
)

val_transforms = A.Compose(
    [
        _resize_step(),
        *_normalize_and_tensorize(),
    ],
)

In [9]:
# Network, loss and optimiser. The loss works on raw model outputs
# (BCEWithLogitsLoss), matching a model whose last layer has no activation.
model = UNET(in_channels=3, out_channels=1).to(DEVICE)
criterion = loss_function(name="BCEWithLogitsLoss")
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Keyword arguments make the image-dir / mask-dir pairing explicit.
train_loader, val_loader = get_loaders(
    train_dir=train_img_dir,
    train_maskdir=train_mask_dir,
    val_dir=val_img_dir,
    val_maskdir=val_mask_dir,
    batch_size=batch_size,
    train_transform=train_transform,
    val_transform=val_transforms,
)

# Gradient scaling for mixed precision only applies on CUDA. DEVICE may be
# "cpu", so enable the scaler explicitly instead of relying on it to disable
# itself (with a warning) when no GPU is present.
scaler = torch.cuda.amp.GradScaler(enabled=(DEVICE == "cuda"))

In [11]:
# Resize op used only for the IoU metric. It does not depend on the batch, so
# build it once instead of on every training step.
upsample_to_origin = T.Resize((origin_height, origin_width))
# Mixed precision is only meaningful on CUDA; DEVICE may be "cpu".
use_amp = DEVICE == "cuda"

for epoch in range(num_epochs):
    loop = tqdm(train_loader)
    train_ious = []

    for data, targets in loop:
        optimizer.zero_grad()

        data = data.to(device=DEVICE)
        # Add a channel axis so targets line up with the model output.
        targets = targets.float().unsqueeze(1).to(device=DEVICE)

        # forward
        with torch.cuda.amp.autocast(enabled=use_amp):
            predictions = model(data)
            loss = criterion(predictions, targets)

        # backward
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # Metric only: detach first so the full-resolution upsample is not
        # recorded by autograd (it is never back-propagated through).
        with torch.no_grad():
            predictions = upsample_to_origin(predictions.detach())
            targets = upsample_to_origin(targets)

        # NOTE(review): the 0.1 cut-off is applied to raw model outputs (no
        # sigmoid here), whereas check_accuracy thresholds sigmoid
        # probabilities at 0.5 — confirm this is intended. Resized targets may
        # also hold fractional values that calc_iou counts as foreground
        # (> 0), presumably because Resize interpolates — verify.
        iou = calc_iou(predictions.cpu().numpy() > 0.1,
                       targets.cpu().numpy())
        train_ious.append(iou)

        # update tqdm loop with loss value
        loop.set_postfix(loss=loss.item())

    # check accuracy
    check_accuracy(val_loader, model, device=DEVICE)

    train_iou = np.mean(np.array(train_ious))
    print("IoU on Train set: %.3f" % train_iou)

100%|██████████| 315/315 [14:33<00:00,  2.77s/it, loss=0.132]


Accuracy: 0.99
IoU on Train set: 0.907


100%|██████████| 315/315 [06:36<00:00,  1.26s/it, loss=0.0795]


Accuracy: 0.99
IoU on Train set: 0.940


100%|██████████| 315/315 [06:55<00:00,  1.32s/it, loss=0.0551]


Accuracy: 0.99
IoU on Train set: 0.944


In [13]:
  # Persist only the model's state_dict (no optimizer or grad-scaler state),
  # so training cannot be resumed exactly from this file alone.
  # NOTE(review): the path is relative, so the file lands in the current
  # working directory rather than on the mounted Drive; the ".tar" suffix is
  # just a name — torch.save presumably does not write a tar archive; confirm.
  torch.save(model.state_dict(), "weights.tar")