In [None]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# import the necessary packages
import torch, torchvision
import torch.optim as optim
import os, glob
# base path of the dataset
IMAGE_PATH = "/content/drive/MyDrive/SegmenData2/images"
MASK_PATH = "/content/drive/MyDrive/SegmenData2/masks"

imgs = glob.glob(IMAGE_PATH+"/*.png")
masks = glob.glob(MASK_PATH+"/*.png")

train_imgs = imgs[:int(len(imgs)*0.75)]
train_masks = masks[:int(len(imgs)*0.75)]

val_imgs = imgs[int(len(imgs)*0.75):]
val_masks = masks[int(len(imgs)*0.75):]


In [None]:
import torch
import torch.nn as nn
import torchvision.transforms.functional as TF

class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(DoubleConv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.conv(x)

class UNET(nn.Module):
    def __init__(
            self, in_channels=3, out_channels=1, features=[64, 128, 256, 512],
    ):
        super(UNET, self).__init__()
        self.ups = nn.ModuleList()
        self.downs = nn.ModuleList()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Down part of UNET
        for feature in features:
            self.downs.append(DoubleConv(in_channels, feature))
            in_channels = feature

        # Up part of UNET
        for feature in reversed(features):
            self.ups.append(
                nn.ConvTranspose2d(
                    feature*2, feature, kernel_size=2, stride=2,
                )
            )
            self.ups.append(DoubleConv(feature*2, feature))

        self.bottleneck = DoubleConv(features[-1], features[-1]*2)
        self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)

    def forward(self, x):
        skip_connections = []

        for down in self.downs:
            x = down(x)
            skip_connections.append(x)
            x = self.pool(x)

        x = self.bottleneck(x)
        skip_connections = skip_connections[::-1]

        for idx in range(0, len(self.ups), 2):
            x = self.ups[idx](x)
            skip_connection = skip_connections[idx//2]

            if x.shape != skip_connection.shape:
                x = TF.resize(x, size=skip_connection.shape[2:])

            concat_skip = torch.cat((skip_connection, x), dim=1)
            x = self.ups[idx+1](concat_skip)

        return self.final_conv(x)

In [None]:
from tqdm import tqdm
best_Acc = 0
def train_fn(loader, model, optimizer, loss_fn, scaler):
    
    loop = tqdm(loader)

    for batch_idx, (data, targets) in enumerate(loop):
        data = data.to(device=DEVICE)
        targets = targets.float().unsqueeze(1).to(device=DEVICE)

        # forward
        with torch.cuda.amp.autocast():
            predictions = model(data)
            loss = loss_fn(predictions, targets)

        # backward
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # update tqdm loop
        loop.set_postfix(loss=loss.item())
    return model

def check_accuracy(loader, model, epoch, device="cuda"):
    global best_Acc
    num_correct = 0
    num_pixels = 0
    dice_score = 0
    model.eval()

    with torch.no_grad():
        for x, y in loader:
            x = x.to(device)
            y = y.to(device).unsqueeze(1)
            preds = torch.sigmoid(model(x))
            preds = (preds > 0.5).float()
            num_correct += (preds == y).sum()
            num_pixels += torch.numel(preds)
            dice_score += (2 * (preds * y).sum()) / (
                (preds + y).sum() + 1e-8
            )

    print(
        f"Got {num_correct}/{num_pixels} with acc {num_correct/num_pixels*100:.2f}"
    )
    print(f"Dice score: {dice_score/len(loader)}")
    if best_Acc < (num_correct/num_pixels*100):
      best_Acc = (num_correct/num_pixels*100)
      save_checkpoint(model,epoch, (num_correct/num_pixels*100),True)
    else:
      save_checkpoint(model,epoch, (num_correct/num_pixels*100),False)
    model.train()

def save_predictions_as_imgs(
    loader, model, folder="saved_images/", device="cuda"
):
    model.eval()
    for idx, (x, y) in enumerate(loader):
        x = x.to(device=device)
        with torch.no_grad():
            preds = torch.sigmoid(model(x))
            preds = (preds > 0.5).float()
        torchvision.utils.save_image(
            preds, f"{folder}/pred_{idx}.png"
        )
        torchvision.utils.save_image(y.unsqueeze(1), f"{folder}{idx}.png")

    model.train()

def save_checkpoint(model,epoch,score, best = False):
    
    if best:
      torch.save(model, "/content/drive/MyDrive/savemodels/bestmodel_epoch_{}_score_{:.2f}.pt".format(epoch,score))
      print("=> Saving checkpoint => bestmodel_epoch_{}_score_{:.2f}.pt".format(epoch,score))
    else:
      torch.save(model, "/content/drive/MyDrive/savemodels/model_epoch_{}_score_{:.2f}.pt".format(epoch,score))
      print("=> Saving checkpoint => model_epoch_{}_score_{:.2f}.pt".format(epoch,score))

In [None]:
import albumentations as A
from albumentations.pytorch import ToTensorV2
from torch.utils.data import Dataset
import cv2
import numpy as np

train_transform = A.Compose(
        [
            A.Resize(height=512, width=512),
            A.HorizontalFlip(p=0.5),
            A.Normalize(
                mean=[0.0, 0.0, 0.0],
                std=[1.0, 1.0, 1.0],
                max_pixel_value=255.0,
            ),
            ToTensorV2(),
        ],
    )

val_transforms = A.Compose(
        [
            A.Resize(height=512, width=512),
            A.Normalize(
                mean=[0.0, 0.0, 0.0],
                std=[1.0, 1.0, 1.0],
                max_pixel_value=255.0,
            ),
            ToTensorV2(),
        ],
    )

class SegDataset(Dataset):
    def __init__(self, image_dir, mask_dir, transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.images = image_dir

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        img_path = self.images[index]
        mask_path = self.images[index].replace("images","masks")
        image = cv2.imread(img_path)
        mask = cv2.imread(mask_path, 0)
        mask[mask == 255.0] = 1.0

        if self.transform is not None:
            augmentations = self.transform(image=image, mask=mask)
            image = augmentations["image"]
            mask = augmentations["mask"]
        return image, mask



In [None]:
LEARNING_RATE = 1e-4
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
Batchsize = 4
NUM_EPOCHS = 50

In [None]:
from torch.utils.data import DataLoader

train_ds = SegDataset(
        image_dir=train_imgs,
        mask_dir=train_masks,
        transform=train_transform,)

val_ds = SegDataset(
        image_dir=val_imgs,
        mask_dir=val_masks,
        transform=val_transforms,
    )
train_loader = DataLoader(
        train_ds,
        batch_size=Batchsize,
        pin_memory=True,
        shuffle=True,
    )
val_loader = DataLoader(
        train_ds,
        batch_size=Batchsize,
        pin_memory=True,
        shuffle=True,
    )

In [None]:
from torch.nn import BCEWithLogitsLoss
from torch.optim import Adam


model = UNET(in_channels=3, out_channels=1).to(DEVICE)
loss_fn = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

In [None]:
scaler = torch.cuda.amp.GradScaler()

for epoch in range(NUM_EPOCHS):
    model = train_fn(train_loader, model, optimizer, loss_fn, scaler)

    check_accuracy(val_loader, model, epoch,device=DEVICE)

100%|██████████| 500/500 [1:14:46<00:00,  8.97s/it, loss=0.443]


Got 418993241/524288000 with acc 79.92
Dice score: 0.665310800075531
=> Saving checkpoint => bestmodel_epoch_0_score_79.92.pt


100%|██████████| 500/500 [05:44<00:00,  1.45it/s, loss=0.485]


Got 436678705/524288000 with acc 83.29
Dice score: 0.7179595828056335
=> Saving checkpoint => bestmodel_epoch_1_score_83.29.pt


100%|██████████| 500/500 [05:44<00:00,  1.45it/s, loss=0.412]


Got 442737923/524288000 with acc 84.45
Dice score: 0.756439208984375
=> Saving checkpoint => bestmodel_epoch_2_score_84.45.pt


100%|██████████| 500/500 [05:45<00:00,  1.45it/s, loss=0.344]


Got 446810672/524288000 with acc 85.22
Dice score: 0.7760874032974243
=> Saving checkpoint => bestmodel_epoch_3_score_85.22.pt


100%|██████████| 500/500 [05:44<00:00,  1.45it/s, loss=0.348]


Got 453546752/524288000 with acc 86.51
Dice score: 0.7596524953842163
=> Saving checkpoint => bestmodel_epoch_4_score_86.51.pt


100%|██████████| 500/500 [05:43<00:00,  1.45it/s, loss=0.308]


Got 449498041/524288000 with acc 85.73
Dice score: 0.7925180792808533
=> Saving checkpoint => model_epoch_5_score_85.73.pt


100%|██████████| 500/500 [05:44<00:00,  1.45it/s, loss=0.173]


Got 467635756/524288000 with acc 89.19
Dice score: 0.8174414038658142
=> Saving checkpoint => bestmodel_epoch_6_score_89.19.pt


100%|██████████| 500/500 [05:46<00:00,  1.44it/s, loss=0.181]


Got 470416700/524288000 with acc 89.72
Dice score: 0.8287442326545715
=> Saving checkpoint => bestmodel_epoch_7_score_89.72.pt


100%|██████████| 500/500 [05:45<00:00,  1.45it/s, loss=0.167]


Got 475749372/524288000 with acc 90.74
Dice score: 0.8534322381019592
=> Saving checkpoint => bestmodel_epoch_8_score_90.74.pt


100%|██████████| 500/500 [05:45<00:00,  1.45it/s, loss=0.146]


Got 478994887/524288000 with acc 91.36
Dice score: 0.8579071164131165
=> Saving checkpoint => bestmodel_epoch_9_score_91.36.pt


100%|██████████| 500/500 [05:46<00:00,  1.44it/s, loss=0.296]


Got 474830208/524288000 with acc 90.57
Dice score: 0.853775680065155
=> Saving checkpoint => model_epoch_10_score_90.57.pt


100%|██████████| 500/500 [05:46<00:00,  1.44it/s, loss=0.185]


Got 478549545/524288000 with acc 91.28
Dice score: 0.8612474203109741
=> Saving checkpoint => model_epoch_11_score_91.28.pt


100%|██████████| 500/500 [05:46<00:00,  1.44it/s, loss=0.367]


Got 475229786/524288000 with acc 90.64
Dice score: 0.8492451906204224
=> Saving checkpoint => model_epoch_12_score_90.64.pt


100%|██████████| 500/500 [05:45<00:00,  1.45it/s, loss=0.546]


Got 476923141/524288000 with acc 90.97
Dice score: 0.8610224723815918
=> Saving checkpoint => model_epoch_13_score_90.97.pt


100%|██████████| 500/500 [05:44<00:00,  1.45it/s, loss=0.18]


Got 487007756/524288000 with acc 92.89
Dice score: 0.8860558271408081
=> Saving checkpoint => bestmodel_epoch_14_score_92.89.pt


100%|██████████| 500/500 [05:44<00:00,  1.45it/s, loss=0.136]


Got 484490254/524288000 with acc 92.41
Dice score: 0.8806789517402649
=> Saving checkpoint => model_epoch_15_score_92.41.pt


100%|██████████| 500/500 [05:44<00:00,  1.45it/s, loss=0.084]


Got 486205620/524288000 with acc 92.74
Dice score: 0.8806965947151184
=> Saving checkpoint => model_epoch_16_score_92.74.pt


100%|██████████| 500/500 [05:46<00:00,  1.44it/s, loss=0.236]


Got 486612721/524288000 with acc 92.81
Dice score: 0.8814024329185486
=> Saving checkpoint => model_epoch_17_score_92.81.pt


100%|██████████| 500/500 [05:46<00:00,  1.44it/s, loss=0.214]


Got 491156649/524288000 with acc 93.68
Dice score: 0.8968191742897034
=> Saving checkpoint => bestmodel_epoch_18_score_93.68.pt


100%|██████████| 500/500 [05:45<00:00,  1.45it/s, loss=0.3]


Got 491918000/524288000 with acc 93.83
Dice score: 0.9001505374908447
=> Saving checkpoint => bestmodel_epoch_19_score_93.83.pt


100%|██████████| 500/500 [05:46<00:00,  1.44it/s, loss=0.237]


Got 490823859/524288000 with acc 93.62
Dice score: 0.8957964777946472
=> Saving checkpoint => model_epoch_20_score_93.62.pt


100%|██████████| 500/500 [05:44<00:00,  1.45it/s, loss=0.151]


Got 490573087/524288000 with acc 93.57
Dice score: 0.891776442527771
=> Saving checkpoint => model_epoch_21_score_93.57.pt


100%|██████████| 500/500 [05:43<00:00,  1.46it/s, loss=0.132]


Got 491358725/524288000 with acc 93.72
Dice score: 0.8983933925628662
=> Saving checkpoint => model_epoch_22_score_93.72.pt


100%|██████████| 500/500 [05:44<00:00,  1.45it/s, loss=0.074]


Got 489585281/524288000 with acc 93.38
Dice score: 0.8902544975280762
=> Saving checkpoint => model_epoch_23_score_93.38.pt


100%|██████████| 500/500 [05:46<00:00,  1.44it/s, loss=0.0746]


Got 480350645/524288000 with acc 91.62
Dice score: 0.8720396757125854
=> Saving checkpoint => model_epoch_24_score_91.62.pt


100%|██████████| 500/500 [05:46<00:00,  1.44it/s, loss=0.151]


Got 492074409/524288000 with acc 93.86
Dice score: 0.9010517597198486
=> Saving checkpoint => bestmodel_epoch_25_score_93.86.pt


100%|██████████| 500/500 [05:42<00:00,  1.46it/s, loss=0.165]


Got 494484708/524288000 with acc 94.32
Dice score: 0.9072550535202026
=> Saving checkpoint => bestmodel_epoch_26_score_94.32.pt


100%|██████████| 500/500 [05:40<00:00,  1.47it/s, loss=0.391]


Got 496891315/524288000 with acc 94.77
Dice score: 0.9144235253334045
=> Saving checkpoint => bestmodel_epoch_27_score_94.77.pt


100%|██████████| 500/500 [05:42<00:00,  1.46it/s, loss=0.0601]


Got 493019523/524288000 with acc 94.04
Dice score: 0.9037988185882568
=> Saving checkpoint => model_epoch_28_score_94.04.pt


100%|██████████| 500/500 [05:41<00:00,  1.46it/s, loss=0.123]


Got 496465198/524288000 with acc 94.69
Dice score: 0.9146831631660461
=> Saving checkpoint => model_epoch_29_score_94.69.pt


 31%|███       | 155/500 [01:45<03:13,  1.78it/s, loss=0.124]