**WARNING**: Remember to run the `ExtractBoundaries` notebook present in this same folder before running all (or specifically the `Import Boundaries` cell)

In [1]:
SAVE_ON_DRIVE = True
TYPE = 'Train'
STYLE_TRANSFER = False

# Dataset initialization


### Download Data

In [2]:
from google.colab import drive
import os
import shutil

drive.mount('/content/drive')
if (os.path.exists("./Train") == False):
    if (os.path.exists("/content/drive/MyDrive/LoveDA/Train.zip")):
        print("Dataset available on own drive, unzipping...")
        !unzip -q /content/drive/MyDrive/LoveDA/Train.zip -d ./
    else:
        print("Downloading dataset...")
        !wget -O Train.zip "https://zenodo.org/records/5706578/files/Train.zip?download=1"
        if(SAVE_ON_DRIVE):
            print("Saving dataset on drive...")
            !cp Train.zip /content/drive/MyDrive/LoveDA/
        !unzip -q Train.zip -d ./

else:
    print("Dataset already in local")


if STYLE_TRANSFER:
  if (os.path.exists("./StyleTransfer") == False):
    if (os.path.exists("/content/drive/MyDrive/LoveDA/StyleTransfer.zip")):
      print("StyleTransfer available on own drive, unzipping...")
      !mkdir ./StyleTransfer
      !unzip -q /content/drive/MyDrive/LoveDA/StyleTransfer.zip -d ./StyleTransfer
    else:
      print("Cant download StyleTransfer")
  else:
    print("StyleTransfer already in local")


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Dataset already in local


### Import Boundaries

In [3]:
# Paths
rural_boundaries_path = "./Train/Rural/boundaries_png"
rural_masks_path = './Train/Rural/masks_png'

urban_boundaries_path = "./Train/Urban/boundaries_png"
urban_masks_path = './Train/Urban/masks_png'
drive_rural_boundaries_path = '/content/drive/MyDrive/LoveDA/boundaries/Rural/boundaries_png'
drive_urban_boundaries_path = '/content/drive/MyDrive/LoveDA/boundaries/Urban/boundaries_png'

boundaries_paths = [rural_boundaries_path, urban_boundaries_path]

# Make dir inside ./Train/...
for boundaries_path in boundaries_paths:
    if (os.path.exists(boundaries_path) == False):
        print(f"Creating {boundaries_path}...")
        os.makedirs(boundaries_path)
    else:
        print(f"{boundaries_path} exists...")


# Check if files are already present
rural_file_count = len([name for name in os.listdir(rural_boundaries_path) if os.path.isfile(os.path.join(rural_boundaries_path, name))])
rural_mask_file_count = len([name for name in os.listdir(rural_masks_path) if os.path.isfile(os.path.join(rural_masks_path, name))])
urban_file_count = len([name for name in os.listdir(urban_boundaries_path) if os.path.isfile(os.path.join(urban_boundaries_path, name))])
urban_mask_file_count = len([name for name in os.listdir(urban_masks_path) if os.path.isfile(os.path.join(urban_masks_path, name))])

if (rural_file_count != rural_mask_file_count):
    print(f"Importing boundaries, as we have {rural_file_count} rural boundaries as of now...")
    shutil.copytree(drive_rural_boundaries_path, rural_boundaries_path, dirs_exist_ok=True)
else:
    print(f"Rural boundaries already present, {rural_file_count} files...")

if (urban_file_count != urban_mask_file_count):
    print(f"Importing boundaries, as we have {urban_file_count} urban boundaries as of now...")
    shutil.copytree(drive_urban_boundaries_path, urban_boundaries_path, dirs_exist_ok=True)
else:
    print(f"Urban boundaries already present, {urban_file_count} files...")

./Train/Rural/boundaries_png exists...
./Train/Urban/boundaries_png exists...
Rural boundaries already present, 1366 files...
Urban boundaries already present, 1156 files...


### Dataset Definition

In [4]:
import os
import numpy as np
from PIL import Image
from torch.utils.data import Dataset
import torchvision.transforms as T
import random
import cv2


def pil_loader(path, color_type):
    # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
    with open(path, 'rb') as f:
        img = Image.open(f)
        return img.convert(color_type)

class LoveDADataset(Dataset):
    def __init__(self, baseTransform, augTransforms, split = 'Urban', typeDataset = 'Train', useBoundaries=True, styleTransfer = False, probStyle=0.0, validation_ratio=0.2, seed=265637):
        # Validate typeDataset input
        if typeDataset not in ['Train', 'Validation', 'Total']:
            raise ValueError("Invalid typeDataset. Expected 'Train' or 'Validation'.")
        self.directory = []
        directory_path = os.path.join('./Train', split, 'images_png')

        # Check if the directory exists
        if not os.path.exists(directory_path):
            raise FileNotFoundError(f"Directory not found: {directory_path}")
        # Get all image paths
        all_images = [os.path.join(directory_path, entry) for entry in os.listdir(directory_path) if os.path.isfile(os.path.join(directory_path, entry))]
        # Shuffle images for random splitting
        random.seed(seed)
        random.shuffle(all_images)
        # Split into training and validation sets
        split_idx = int(len(all_images) * (1 - validation_ratio))
        if typeDataset == 'Train':
            self.directory = all_images[:split_idx]
            if styleTransfer:
              style_directory = []
              listFromStyleToOriginal = []
              # Remove ones that not exist in directory, we have StyleTransfer Only in train
              for el in self.directory:
                name = el.replace('Train/Urban/images_png', 'StyleTransfer/Urban')
                style_directory.append(name)
                listFromStyleToOriginal.append(el)
              self.style_directory = style_directory
              self.probStyle = probStyle
              self.listFromStyleToOriginal = listFromStyleToOriginal

        elif typeDataset == 'Validation':
            self.directory = all_images[split_idx:]
        elif typeDataset == 'Total':
            self.directory = all_images
        else:
            raise ValueError("Invalid typeDataset. Expected 'Train' or 'Validation' or 'Total'.")
        self.baseTransforms = baseTransform
        self.augTransforms = augTransforms
        self.useBoundaries = useBoundaries
        self.typeDataset = typeDataset
        self.styleTransfer = styleTransfer
        self.probStyle = probStyle
        # Print dataset size
        print(f"Dataset size: {len(self.directory)}")

    def __len__(self):
        return len(self.directory)

    def __getitem__(self, idx):

        flagStyle = False
        if self.styleTransfer and random.random() < self.probStyle:
          image_path = self.style_directory[idx]
          original_path = self.listFromStyleToOriginal[idx]
          flagStyle = True
        else:
          image_path = self.directory[idx]
          original_path = image_path

        image = pil_loader(image_path, 'RGB')
        mask_path = original_path.replace('images_png', 'masks_png')
        if not os.path.exists(mask_path):
          raise FileNotFoundError(f"Mask not found: {mask_path}")
        boundaries_path = original_path.replace('images_png', 'boundaries_png')
        if not os.path.exists(boundaries_path):
          raise FileNotFoundError(f"Boundaries not found: {boundaries_path}")

        mask = pil_loader(mask_path, 'L')
        if flagStyle: # This is only because StyleTransfered images are 512x512
          # Resize mask to 512x512
          mask = T.Resize((512,512))(mask)

        if self.useBoundaries:
          boundaries = pil_loader(boundaries_path, 'L')
          if flagStyle:
            boundaries = T.Resize((512,512))(boundaries)
        else:
          boundaries = mask # Just a placeholder to mantain current dimensionality. May cause errors, so use `useBoundaries` only in testing scenario

        base_transformed = self.baseTransforms(image=np.array(image), mask=np.array(mask), boundaries=np.array(boundaries))
        base_image = base_transformed['image']
        base_mask = base_transformed['mask']
        base_boundaries = base_transformed['boundaries']

        base_image = T.Compose([T.ToTensor()])(base_image)
        base_mask = torch.from_numpy(base_mask).long()
        base_mask -= 1
        base_boundaries = torch.from_numpy(base_boundaries)

        if self.typeDataset == 'Train':
          return [base_image], [base_mask], image_path, [base_boundaries]
        else:
          return base_image, base_mask, image_path, base_boundaries

### Dataset Utils

In [5]:
from collections import OrderedDict
COLOR_MAP = OrderedDict(
    Background=(255, 255, 255),
    Building=(255, 0, 0),
    Road=(255, 255, 0),
    Water=(0, 0, 255),
    Barren=(159, 129, 183),
    Forest=(34, 139, 34),
    Agricultural=(255, 195, 128),
)

LABEL_MAP = OrderedDict(
    Background=0,
    Building=1,
    Road=2,
    Water=3,
    Barren=4,
    Forest=5,
    Agricultural=6,
)
inverted_label_map = OrderedDict((v, k) for k, v in LABEL_MAP.items())


def getLabelColor(label):
    # Default color for unclassified labels
    default_color = np.array([128, 128, 128])  # Gray

    # Check if label exists in inverted_label_map
    label_name = inverted_label_map.get(label, None)
    if label_name is None or label_name not in COLOR_MAP:
        return default_color  # Return default color for unclassified

    # Return the mapped color
    label_color = np.array(COLOR_MAP[label_name])
    return label_color


def getLegendHandles():
  handles = [mpatches.Patch(color=getLabelColor(i)/255, label=inverted_label_map[i]) for i in range(0, len(LABEL_MAP))]
  handles.append(mpatches.Patch(color=getLabelColor(-1)/255, label='Unclassified'))
  return handles

def new_colors_mask(mask):
  new_image = np.zeros((mask.shape[0], mask.shape[1], 3), dtype=np.uint8)
  for i, row in enumerate(mask):
    for j, cell in enumerate(row):
      new_image[i][j] = getLabelColor(cell.item())
  return new_image



### Dataset Debug

In [6]:
# # Comment this cell to save GPU time

# import matplotlib.pyplot as plt
# import torch
# from torch.utils.data import DataLoader
# import matplotlib.patches as mpatches

# train_dataset = LoveDADataset(type='Train', seed=222)
# print(train_dataset.__len__())

# # Get item
# image, mask, path, bd = train_dataset.__getitem__(88)

# # Show path
# print(f"Image is at {path}")

# # Show image
# image = image.permute(1, 2, 0)
# image = image.numpy()
# plt.imshow(image)

# # Show mask
# new_image = new_colors_mask(mask)
# plt.imshow(image)
# plt.show()
# plt.legend(handles=getLegendHandles(), loc='upper left', bbox_to_anchor=(1.05, 1), borderaxespad=0.)
# plt.imshow(new_image)
# plt.show()

# # Show boundaries
# # for row in bd:
# #     for col in row:
# #         if col != 0 and col != 1:
# #             print(col)
# bd = bd.numpy()
# plt.imshow(bd)


# Initialize model

### PIDNet Util Modules

In [7]:
# ------------------------------------------------------------------------------
# Written by Jiacong Xu (jiacong.xu@tamu.edu)
# ------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F

BatchNorm2d = nn.BatchNorm2d
bn_mom = 0.1
algc = False

class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, no_relu=False):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn1 = BatchNorm2d(planes, momentum=bn_mom)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
                               padding=1, bias=False)
        self.bn2 = BatchNorm2d(planes, momentum=bn_mom)
        self.downsample = downsample
        self.stride = stride
        self.no_relu = no_relu

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual

        if self.no_relu:
            return out
        else:
            return self.relu(out)

class Bottleneck(nn.Module):
    expansion = 2

    def __init__(self, inplanes, planes, stride=1, downsample=None, no_relu=True):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = BatchNorm2d(planes, momentum=bn_mom)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = BatchNorm2d(planes, momentum=bn_mom)
        self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1,
                               bias=False)
        self.bn3 = BatchNorm2d(planes * self.expansion, momentum=bn_mom)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride
        self.no_relu = no_relu

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        if self.no_relu:
            return out
        else:
            return self.relu(out)

class segmenthead(nn.Module):

    def __init__(self, inplanes, interplanes, outplanes, scale_factor=None):
        super(segmenthead, self).__init__()
        self.bn1 = BatchNorm2d(inplanes, momentum=bn_mom)
        self.conv1 = nn.Conv2d(inplanes, interplanes, kernel_size=3, padding=1, bias=False)
        self.bn2 = BatchNorm2d(interplanes, momentum=bn_mom)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(interplanes, outplanes, kernel_size=1, padding=0, bias=True)
        self.scale_factor = scale_factor

    def forward(self, x):

        x = self.conv1(self.relu(self.bn1(x)))
        out = self.conv2(self.relu(self.bn2(x)))

        if self.scale_factor is not None:
            height = x.shape[-2] * self.scale_factor
            width = x.shape[-1] * self.scale_factor
            out = F.interpolate(out,
                        size=[height, width],
                        mode='bilinear', align_corners=algc)

        return out

class DAPPM(nn.Module):
    def __init__(self, inplanes, branch_planes, outplanes, BatchNorm=nn.BatchNorm2d):
        super(DAPPM, self).__init__()
        bn_mom = 0.1
        self.scale1 = nn.Sequential(nn.AvgPool2d(kernel_size=5, stride=2, padding=2),
                                    BatchNorm(inplanes, momentum=bn_mom),
                                    nn.ReLU(inplace=True),
                                    nn.Conv2d(inplanes, branch_planes, kernel_size=1, bias=False),
                                    )
        self.scale2 = nn.Sequential(nn.AvgPool2d(kernel_size=9, stride=4, padding=4),
                                    BatchNorm(inplanes, momentum=bn_mom),
                                    nn.ReLU(inplace=True),
                                    nn.Conv2d(inplanes, branch_planes, kernel_size=1, bias=False),
                                    )
        self.scale3 = nn.Sequential(nn.AvgPool2d(kernel_size=17, stride=8, padding=8),
                                    BatchNorm(inplanes, momentum=bn_mom),
                                    nn.ReLU(inplace=True),
                                    nn.Conv2d(inplanes, branch_planes, kernel_size=1, bias=False),
                                    )
        self.scale4 = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)),
                                    BatchNorm(inplanes, momentum=bn_mom),
                                    nn.ReLU(inplace=True),
                                    nn.Conv2d(inplanes, branch_planes, kernel_size=1, bias=False),
                                    )
        self.scale0 = nn.Sequential(
                                    BatchNorm(inplanes, momentum=bn_mom),
                                    nn.ReLU(inplace=True),
                                    nn.Conv2d(inplanes, branch_planes, kernel_size=1, bias=False),
                                    )
        self.process1 = nn.Sequential(
                                    BatchNorm(branch_planes, momentum=bn_mom),
                                    nn.ReLU(inplace=True),
                                    nn.Conv2d(branch_planes, branch_planes, kernel_size=3, padding=1, bias=False),
                                    )
        self.process2 = nn.Sequential(
                                    BatchNorm(branch_planes, momentum=bn_mom),
                                    nn.ReLU(inplace=True),
                                    nn.Conv2d(branch_planes, branch_planes, kernel_size=3, padding=1, bias=False),
                                    )
        self.process3 = nn.Sequential(
                                    BatchNorm(branch_planes, momentum=bn_mom),
                                    nn.ReLU(inplace=True),
                                    nn.Conv2d(branch_planes, branch_planes, kernel_size=3, padding=1, bias=False),
                                    )
        self.process4 = nn.Sequential(
                                    BatchNorm(branch_planes, momentum=bn_mom),
                                    nn.ReLU(inplace=True),
                                    nn.Conv2d(branch_planes, branch_planes, kernel_size=3, padding=1, bias=False),
                                    )
        self.compression = nn.Sequential(
                                    BatchNorm(branch_planes * 5, momentum=bn_mom),
                                    nn.ReLU(inplace=True),
                                    nn.Conv2d(branch_planes * 5, outplanes, kernel_size=1, bias=False),
                                    )
        self.shortcut = nn.Sequential(
                                    BatchNorm(inplanes, momentum=bn_mom),
                                    nn.ReLU(inplace=True),
                                    nn.Conv2d(inplanes, outplanes, kernel_size=1, bias=False),
                                    )

    def forward(self, x):
        width = x.shape[-1]
        height = x.shape[-2]
        x_list = []

        x_list.append(self.scale0(x))
        x_list.append(self.process1((F.interpolate(self.scale1(x),
                        size=[height, width],
                        mode='bilinear', align_corners=algc)+x_list[0])))
        x_list.append((self.process2((F.interpolate(self.scale2(x),
                        size=[height, width],
                        mode='bilinear', align_corners=algc)+x_list[1]))))
        x_list.append(self.process3((F.interpolate(self.scale3(x),
                        size=[height, width],
                        mode='bilinear', align_corners=algc)+x_list[2])))
        x_list.append(self.process4((F.interpolate(self.scale4(x),
                        size=[height, width],
                        mode='bilinear', align_corners=algc)+x_list[3])))

        out = self.compression(torch.cat(x_list, 1)) + self.shortcut(x)
        return out

class PAPPM(nn.Module):
    def __init__(self, inplanes, branch_planes, outplanes, BatchNorm=nn.BatchNorm2d):
        super(PAPPM, self).__init__()
        bn_mom = 0.1
        self.scale1 = nn.Sequential(nn.AvgPool2d(kernel_size=5, stride=2, padding=2),
                                    BatchNorm(inplanes, momentum=bn_mom),
                                    nn.ReLU(inplace=True),
                                    nn.Conv2d(inplanes, branch_planes, kernel_size=1, bias=False),
                                    )
        self.scale2 = nn.Sequential(nn.AvgPool2d(kernel_size=9, stride=4, padding=4),
                                    BatchNorm(inplanes, momentum=bn_mom),
                                    nn.ReLU(inplace=True),
                                    nn.Conv2d(inplanes, branch_planes, kernel_size=1, bias=False),
                                    )
        self.scale3 = nn.Sequential(nn.AvgPool2d(kernel_size=17, stride=8, padding=8),
                                    BatchNorm(inplanes, momentum=bn_mom),
                                    nn.ReLU(inplace=True),
                                    nn.Conv2d(inplanes, branch_planes, kernel_size=1, bias=False),
                                    )
        self.scale4 = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)),
                                    BatchNorm(inplanes, momentum=bn_mom),
                                    nn.ReLU(inplace=True),
                                    nn.Conv2d(inplanes, branch_planes, kernel_size=1, bias=False),
                                    )

        self.scale0 = nn.Sequential(
                                    BatchNorm(inplanes, momentum=bn_mom),
                                    nn.ReLU(inplace=True),
                                    nn.Conv2d(inplanes, branch_planes, kernel_size=1, bias=False),
                                    )

        self.scale_process = nn.Sequential(
                                    BatchNorm(branch_planes*4, momentum=bn_mom),
                                    nn.ReLU(inplace=True),
                                    nn.Conv2d(branch_planes*4, branch_planes*4, kernel_size=3, padding=1, groups=4, bias=False),
                                    )


        self.compression = nn.Sequential(
                                    BatchNorm(branch_planes * 5, momentum=bn_mom),
                                    nn.ReLU(inplace=True),
                                    nn.Conv2d(branch_planes * 5, outplanes, kernel_size=1, bias=False),
                                    )

        self.shortcut = nn.Sequential(
                                    BatchNorm(inplanes, momentum=bn_mom),
                                    nn.ReLU(inplace=True),
                                    nn.Conv2d(inplanes, outplanes, kernel_size=1, bias=False),
                                    )


    def forward(self, x):
        width = x.shape[-1]
        height = x.shape[-2]
        scale_list = []

        x_ = self.scale0(x)
        scale_list.append(F.interpolate(self.scale1(x), size=[height, width],
                        mode='bilinear', align_corners=algc)+x_)
        scale_list.append(F.interpolate(self.scale2(x), size=[height, width],
                        mode='bilinear', align_corners=algc)+x_)
        scale_list.append(F.interpolate(self.scale3(x), size=[height, width],
                        mode='bilinear', align_corners=algc)+x_)
        scale_list.append(F.interpolate(self.scale4(x), size=[height, width],
                        mode='bilinear', align_corners=algc)+x_)

        scale_out = self.scale_process(torch.cat(scale_list, 1))

        out = self.compression(torch.cat([x_,scale_out], 1)) + self.shortcut(x)
        return out


class PagFM(nn.Module):
    def __init__(self, in_channels, mid_channels, after_relu=False, with_channel=False, BatchNorm=nn.BatchNorm2d):
        super(PagFM, self).__init__()
        self.with_channel = with_channel
        self.after_relu = after_relu
        self.f_x = nn.Sequential(
                                nn.Conv2d(in_channels, mid_channels,
                                          kernel_size=1, bias=False),
                                BatchNorm(mid_channels)
                                )
        self.f_y = nn.Sequential(
                                nn.Conv2d(in_channels, mid_channels,
                                          kernel_size=1, bias=False),
                                BatchNorm(mid_channels)
                                )
        if with_channel:
            self.up = nn.Sequential(
                                    nn.Conv2d(mid_channels, in_channels,
                                              kernel_size=1, bias=False),
                                    BatchNorm(in_channels)
                                   )
        if after_relu:
            self.relu = nn.ReLU(inplace=True)

    def forward(self, x, y):
        input_size = x.size()
        if self.after_relu:
            y = self.relu(y)
            x = self.relu(x)

        y_q = self.f_y(y)
        y_q = F.interpolate(y_q, size=[input_size[2], input_size[3]],
                            mode='bilinear', align_corners=False)
        x_k = self.f_x(x)

        if self.with_channel:
            sim_map = torch.sigmoid(self.up(x_k * y_q))
        else:
            sim_map = torch.sigmoid(torch.sum(x_k * y_q, dim=1).unsqueeze(1))

        y = F.interpolate(y, size=[input_size[2], input_size[3]],
                            mode='bilinear', align_corners=False)
        x = (1-sim_map)*x + sim_map*y

        return x

class Light_Bag(nn.Module):
    def __init__(self, in_channels, out_channels, BatchNorm=nn.BatchNorm2d):
        super(Light_Bag, self).__init__()
        self.conv_p = nn.Sequential(
                                nn.Conv2d(in_channels, out_channels,
                                          kernel_size=1, bias=False),
                                BatchNorm(out_channels)
                                )
        self.conv_i = nn.Sequential(
                                nn.Conv2d(in_channels, out_channels,
                                          kernel_size=1, bias=False),
                                BatchNorm(out_channels)
                                )

    def forward(self, p, i, d):
        edge_att = torch.sigmoid(d)

        p_add = self.conv_p((1-edge_att)*i + p)
        i_add = self.conv_i(i + edge_att*p)

        return p_add + i_add


class DDFMv2(nn.Module):
    def __init__(self, in_channels, out_channels, BatchNorm=nn.BatchNorm2d):
        super(DDFMv2, self).__init__()
        self.conv_p = nn.Sequential(
                                BatchNorm(in_channels),
                                nn.ReLU(inplace=True),
                                nn.Conv2d(in_channels, out_channels,
                                          kernel_size=1, bias=False),
                                BatchNorm(out_channels)
                                )
        self.conv_i = nn.Sequential(
                                BatchNorm(in_channels),
                                nn.ReLU(inplace=True),
                                nn.Conv2d(in_channels, out_channels,
                                          kernel_size=1, bias=False),
                                BatchNorm(out_channels)
                                )

    def forward(self, p, i, d):
        edge_att = torch.sigmoid(d)

        p_add = self.conv_p((1-edge_att)*i + p)
        i_add = self.conv_i(i + edge_att*p)

        return p_add + i_add

class Bag(nn.Module):
    def __init__(self, in_channels, out_channels, BatchNorm=nn.BatchNorm2d):
        super(Bag, self).__init__()

        self.conv = nn.Sequential(
                                BatchNorm(in_channels),
                                nn.ReLU(inplace=True),
                                nn.Conv2d(in_channels, out_channels,
                                          kernel_size=3, padding=1, bias=False)
                                )


    def forward(self, p, i, d):
        edge_att = torch.sigmoid(d)
        return self.conv(edge_att*p + (1-edge_att)*i)




### PIDNet Definition

In [8]:
# ------------------------------------------------------------------------------
# Written by Jiacong Xu (jiacong.xu@tamu.edu)
# ------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
import time
import logging

BatchNorm2d = nn.BatchNorm2d
bn_mom = 0.1
algc = False



class PIDNet(nn.Module):

    def __init__(self, m=2, n=3, num_classes=19, planes=64, ppm_planes=96, head_planes=128, augment=True):
        super(PIDNet, self).__init__()
        self.augment = augment

        # I Branch
        self.conv1 =  nn.Sequential(
                          nn.Conv2d(3,planes,kernel_size=3, stride=2, padding=1),
                          BatchNorm2d(planes, momentum=bn_mom),
                          nn.ReLU(inplace=True),
                          nn.Conv2d(planes,planes,kernel_size=3, stride=2, padding=1),
                          BatchNorm2d(planes, momentum=bn_mom),
                          nn.ReLU(inplace=True),
                      )

        self.relu = nn.ReLU(inplace=True)
        self.layer1 = self._make_layer(BasicBlock, planes, planes, m)
        self.layer2 = self._make_layer(BasicBlock, planes, planes * 2, m, stride=2)
        self.layer3 = self._make_layer(BasicBlock, planes * 2, planes * 4, n, stride=2)
        self.layer4 = self._make_layer(BasicBlock, planes * 4, planes * 8, n, stride=2)
        self.layer5 =  self._make_layer(Bottleneck, planes * 8, planes * 8, 2, stride=2)

        # P Branch
        self.compression3 = nn.Sequential(
                                          nn.Conv2d(planes * 4, planes * 2, kernel_size=1, bias=False),
                                          BatchNorm2d(planes * 2, momentum=bn_mom),
                                          )

        self.compression4 = nn.Sequential(
                                          nn.Conv2d(planes * 8, planes * 2, kernel_size=1, bias=False),
                                          BatchNorm2d(planes * 2, momentum=bn_mom),
                                          )
        self.pag3 = PagFM(planes * 2, planes)
        self.pag4 = PagFM(planes * 2, planes)

        self.layer3_ = self._make_layer(BasicBlock, planes * 2, planes * 2, m)
        self.layer4_ = self._make_layer(BasicBlock, planes * 2, planes * 2, m)
        self.layer5_ = self._make_layer(Bottleneck, planes * 2, planes * 2, 1)

        # D Branch
        if m == 2:
            self.layer3_d = self._make_single_layer(BasicBlock, planes * 2, planes)
            self.layer4_d = self._make_layer(Bottleneck, planes, planes, 1)
            self.diff3 = nn.Sequential(
                                        nn.Conv2d(planes * 4, planes, kernel_size=3, padding=1, bias=False),
                                        BatchNorm2d(planes, momentum=bn_mom),
                                        )
            self.diff4 = nn.Sequential(
                                     nn.Conv2d(planes * 8, planes * 2, kernel_size=3, padding=1, bias=False),
                                     BatchNorm2d(planes * 2, momentum=bn_mom),
                                     )
            self.spp = PAPPM(planes * 16, ppm_planes, planes * 4)
            self.dfm = Light_Bag(planes * 4, planes * 4)
        else:
            self.layer3_d = self._make_single_layer(BasicBlock, planes * 2, planes * 2)
            self.layer4_d = self._make_single_layer(BasicBlock, planes * 2, planes * 2)
            self.diff3 = nn.Sequential(
                                        nn.Conv2d(planes * 4, planes * 2, kernel_size=3, padding=1, bias=False),
                                        BatchNorm2d(planes * 2, momentum=bn_mom),
                                        )
            self.diff4 = nn.Sequential(
                                     nn.Conv2d(planes * 8, planes * 2, kernel_size=3, padding=1, bias=False),
                                     BatchNorm2d(planes * 2, momentum=bn_mom),
                                     )
            self.spp = DAPPM(planes * 16, ppm_planes, planes * 4)
            self.dfm = Bag(planes * 4, planes * 4)

        self.layer5_d = self._make_layer(Bottleneck, planes * 2, planes * 2, 1)

        # Prediction Head
        if self.augment:
            self.seghead_p = segmenthead(planes * 2, head_planes, num_classes)
            self.seghead_d = segmenthead(planes * 2, planes, 1)

        self.final_layer = segmenthead(planes * 4, head_planes, num_classes)


        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)


    def _make_layer(self, block, inplanes, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion, momentum=bn_mom),
            )

        layers = []
        layers.append(block(inplanes, planes, stride, downsample))
        inplanes = planes * block.expansion
        for i in range(1, blocks):
            if i == (blocks-1):
                layers.append(block(inplanes, planes, stride=1, no_relu=True))
            else:
                layers.append(block(inplanes, planes, stride=1, no_relu=False))

        return nn.Sequential(*layers)

    def _make_single_layer(self, block, inplanes, planes, stride=1):
        downsample = None
        if stride != 1 or inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion, momentum=bn_mom),
            )

        layer = block(inplanes, planes, stride, downsample, no_relu=True)

        return layer

    def forward(self, x):

        width_output = x.shape[-1] // 8
        height_output = x.shape[-2] // 8

        x = self.conv1(x)
        x = self.layer1(x)
        x = self.relu(self.layer2(self.relu(x)))
        x_ = self.layer3_(x)
        x_d = self.layer3_d(x)

        x = self.relu(self.layer3(x))
        x_ = self.pag3(x_, self.compression3(x))
        x_d = x_d + F.interpolate(
                        self.diff3(x),
                        size=[height_output, width_output],
                        mode='bilinear', align_corners=algc)
        if self.augment:
            temp_p = x_

        x = self.relu(self.layer4(x))
        x_ = self.layer4_(self.relu(x_))
        x_d = self.layer4_d(self.relu(x_d))

        x_ = self.pag4(x_, self.compression4(x))
        x_d = x_d + F.interpolate(
                        self.diff4(x),
                        size=[height_output, width_output],
                        mode='bilinear', align_corners=algc)
        if self.augment:
            temp_d = x_d

        x_ = self.layer5_(self.relu(x_))
        x_d = self.layer5_d(self.relu(x_d))
        x = F.interpolate(
                        self.spp(self.layer5(x)),
                        size=[height_output, width_output],
                        mode='bilinear', align_corners=algc)

        x_ = self.final_layer(self.dfm(x_, x, x_d))

        if self.augment:
            x_extra_p = self.seghead_p(temp_p)
            x_extra_d = self.seghead_d(temp_d)
            return [x_extra_p, x_, x_extra_d]
        else:
            return x_

def get_seg_model(cfg, imgnet_pretrained):

    if 's' in cfg.MODEL.NAME:
        model = PIDNet(m=2, n=3, num_classes=cfg.DATASET.NUM_CLASSES, planes=32, ppm_planes=96, head_planes=128, augment=True)
    elif 'm' in cfg.MODEL.NAME:
        model = PIDNet(m=2, n=3, num_classes=cfg.DATASET.NUM_CLASSES, planes=64, ppm_planes=96, head_planes=128, augment=True)
    else:
        model = PIDNet(m=3, n=4, num_classes=cfg.DATASET.NUM_CLASSES, planes=64, ppm_planes=112, head_planes=256, augment=True)

    if imgnet_pretrained:
        pretrained_state = torch.load(cfg.MODEL.PRETRAINED, map_location='cpu')['state_dict']
        model_dict = model.state_dict()
        pretrained_state = {k: v for k, v in pretrained_state.items() if (k in model_dict and v.shape == model_dict[k].shape)}
        model_dict.update(pretrained_state)
        msg = 'Loaded {} parameters!'.format(len(pretrained_state))
        logging.info('Attention!!!')
        logging.info(msg)
        logging.info('Over!!!')
        model.load_state_dict(model_dict, strict = False)
    else:
        pretrained_dict = torch.load(cfg.MODEL.PRETRAINED, map_location='cpu')
        if 'state_dict' in pretrained_dict:
            pretrained_dict = pretrained_dict['state_dict']
        model_dict = model.state_dict()
        pretrained_dict = {k[6:]: v for k, v in pretrained_dict.items() if (k[6:] in model_dict and v.shape == model_dict[k[6:]].shape)}
        msg = 'Loaded {} parameters!'.format(len(pretrained_dict))
        logging.info('Attention!!!')
        logging.info(msg)
        logging.info('Over!!!')
        model_dict.update(pretrained_dict)
        model.load_state_dict(model_dict, strict = False)

    return model

def get_pred_model(name, num_classes):

    if 's' in name:
        model = PIDNet(m=2, n=3, num_classes=num_classes, planes=32, ppm_planes=96, head_planes=128, augment=False)
    elif 'm' in name:
        model = PIDNet(m=2, n=3, num_classes=num_classes, planes=64, ppm_planes=96, head_planes=128, augment=False)
    else:
        model = PIDNet(m=3, n=4, num_classes=num_classes, planes=64, ppm_planes=112, head_planes=256, augment=False)

    return model

### Load PIDNet Model

In [9]:
import gdown
import tarfile

if (os.path.exists("./PIDNet_S_ImageNet.pth.tar") == False):
  url = "https://drive.google.com/uc?id=1hIBp_8maRr60-B3PF0NVtaA6TYBvO4y-"
  output = "./"
  gdown.download(url, output, quiet=False)
# Then keep as tar, as it's already the correct format to feed the model

# Create a config object with required parameters
class Config:
    class MODEL:
        NAME = 'pidnet_s'  # or 'pidnet_m' or 'pidnet_l'
        PRETRAINED = 'PIDNet_S_ImageNet.pth.tar'
    class DATASET:
        NUM_CLASSES = len(LABEL_MAP)

cfg = Config()

model = get_seg_model(cfg, imgnet_pretrained=True)
# model = get_pred_model('s', len(LABEL_MAP))

  pretrained_state = torch.load(cfg.MODEL.PRETRAINED, map_location='cpu')['state_dict']


### Model Debugging

In [10]:
# import torch.nn.functional as F
# from torch.utils.data import DataLoader
# import matplotlib.pyplot as plt

# train_dataset = LoveDADataset(type='Train')
# train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, drop_last=True)

# model = model.train()
# model = model.to('cuda')

# for img, mask, _ in train_loader:
#     print(f"iamge shape: {img.shape}")
#     print(f"mask shape: {mask.shape}")

#     img = img.to('cuda')
#     outputs = model(img)

#     # bilinear interpolation
#     h, w = mask.size(1), mask.size(2)
#     ph, pw = outputs[0].size(2), outputs[0].size(3)
#     if ph != h or pw != w:
#         for i in range(len(outputs)):
#             outputs[i] = F.interpolate(outputs[i], size=(h, w), mode='bilinear',
#                                        align_corners=True)

#     for output in outputs:
#       print(output.shape)
#     break

# print("===================== Original Image =====================")
# plt.imshow(img[0].permute(1, 2, 0).cpu().numpy())
# plt.show()

# print("===================== Ground Truth =====================")
# plt.imshow(mask[0].cpu().numpy())
# plt.show()

# print("===================== Predicted Mask =====================")
# plt.imshow(torch.argmax(outputs[0][0], dim=0).cpu().numpy())
# plt.show()

# Training & Dataset creation

### Ablations and Macros

In [11]:
DEVICE = 'cuda' # 'cuda' or 'cpu'

LR = 2e-3            # The initial Learning Rate -- I increased it using quadratic rule in relation with batch size
MOMENTUM = 0.9       # Hyperparameter for SGD, keep this at 0.9 when using SGD
WEIGHT_DECAY = 5e-5  # Regularization, you can keep this at the default

NUM_EPOCHS = 20      # Total number of training epochs (iterations over dataset)
STEP_SIZE = 21      # How many epochs before decreasing learning rate (if using a step-down policy) -- Trying to keep a 2:3 ratio with NUM_EPOCHS
GAMMA = 0.1          # Multiplicative factor for learning rate step-down

LOG_FREQUENCY = 5
NUM_CLASSES = len(LABEL_MAP)
BATCH_SIZE = 64

In [12]:
#!pip install -U albumentations

### Setup, Create Datasets and DataLoaders. With annexed transforms.

In [13]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torch.nn.functional as F
from albumentations import Compose, HorizontalFlip, RandomRotate90, RandomScale, RandomCrop, GaussNoise, Rotate, Resize, OneOf, Normalize, ColorJitter, GaussianBlur
from albumentations.pytorch import ToTensorV2

#How big should be the image that we feed to the model?
RESIZE = 512
# DEFINE TRANSFORMATIONS HERE
# To Tensor is not needed since its performed inside the getitem


AUGMENTATIONS = {
    'Resize': Compose([
            Resize(RESIZE, RESIZE),
    ], additional_targets={"boundaries": "mask"}),
    'Normalize': Compose([
            Normalize(mean=(123.675, 116.28, 103.53), std=(58.395, 57.12, 57.375), max_pixel_value=1.0, always_apply=True),
            Resize(RESIZE, RESIZE),
            ], additional_targets={"boundaries": "mask"}),
    'RandomCropOrResize': Compose([
            OneOf([
                RandomCrop(RESIZE, RESIZE, p=0.5),  # Random crop to resize
                Resize(RESIZE, RESIZE, p=0.5)
            ], p=1)
            ], additional_targets={"boundaries": "mask"}),
    'Jitter+Resize': Compose([
            ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1, p=0.5),
            Resize(RESIZE, RESIZE),
            ], additional_targets={"boundaries": "mask"}),
    'NormalizeOnRural+Resize': Compose([
            Normalize(mean=(73.532, 80.017, 74.593),
                      std=(41.493, 35.653, 33.747), max_pixel_value=1.0, always_apply=True),
            Resize(RESIZE, RESIZE),
            ], additional_targets={"boundaries": "mask"}),
    'GaussianBlur+Resize': Compose([
            GaussianBlur(p=0.5),
            Resize(RESIZE, RESIZE),
            ], additional_targets={"boundaries": "mask"}),
    'RandomFlipOrRotate+Resize': Compose([
            OneOf([
                HorizontalFlip(p=0.5),
                RandomRotate90(p=0.5),
            ], p=0.5),
            Resize(RESIZE, RESIZE)
            ], additional_targets={"boundaries": "mask"}),
    'RandomCrop600-900+Resize': Compose([
            RandomCrop(900, 900, p=0.3),
            RandomCrop(800, 800, p=0.3),
            RandomCrop(700, 700, p=0.2),
            RandomCrop(600, 600, p=0.2),
            Resize(RESIZE, RESIZE),
            ], additional_targets={"boundaries": "mask"}),
    'Jitter+RotateFlip+Resize': Compose([
            ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1, p=0.5),
            OneOf([
                HorizontalFlip(p=0.5),
                RandomRotate90(p=0.5),
            ], p=0.5),
            Resize(RESIZE, RESIZE),
            ], additional_targets={"boundaries": "mask"}),
    'Jitter+RandomCrop600-900+Resize': Compose([
            ColorJitter(brightness=0.3, contrast=0, saturation=0, hue=0, p=0.5),
            RandomCrop(900, 900, p=0.3),
            RandomCrop(800, 800, p=0.3),
            RandomCrop(700, 700, p=0.2),
            RandomCrop(600, 600, p=0.2),
            Resize(RESIZE, RESIZE),
            ], additional_targets={"boundaries": "mask"}),
    'Brightess+Resize': Compose([
            ColorJitter(brightness=0.3, contrast=0, saturation=0, hue=0, p=0.5),
            Resize(RESIZE, RESIZE),
            ], additional_targets={"boundaries": "mask"}),
    'Jitter+RotateFlip+RandomCrop600-900+Resize_LESSPROB': Compose([
            ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1, p=0.3),
            OneOf([
                HorizontalFlip(p=0.5),
                RandomRotate90(p=0.5),
              ], p=0.3),
            OneOf([
                RandomCrop(900, 900, p=0.3),
                RandomCrop(800, 800, p=0.3),
                RandomCrop(700, 700, p=0.2),
                RandomCrop(600, 600, p=0.2),
                ], p= 0.3),
            Resize(RESIZE, RESIZE),
            ], additional_targets={"boundaries": "mask"})



}

#CHOOSE_TRANSFORM = ['Jitter+Resize', 'GaussianBlur+Resize', 'NormalizeOnRural+Resize',
CHOOSE_TRANSFORM =  ['Jitter+RandomCrop600-900+Resize']
transforms = [AUGMENTATIONS[transform] for transform in CHOOSE_TRANSFORM]

  check_for_updates()


### Losses

In [14]:
def weighted_bce(bd_pre, target):
    n, c, h, w = bd_pre.size()
    log_p = bd_pre.permute(0,2,3,1).contiguous().view(1, -1)
    target_t = target.view(1, -1)

    pos_index = (target_t == 1)
    neg_index = (target_t == 0)

    weight = torch.zeros_like(log_p)
    pos_num = pos_index.sum()
    neg_num = neg_index.sum()
    sum_num = pos_num + neg_num
    weight[pos_index] = neg_num * 1.0 / sum_num
    weight[neg_index] = pos_num * 1.0 / sum_num

    loss = F.binary_cross_entropy_with_logits(log_p, target_t, weight, reduction='mean')

    return loss

def boundary_loss(bd_pre, bd_gt):
    loss = 20.0 * weighted_bce(bd_pre, bd_gt)
    return loss

# TODO EXTRA add weights=class_weights to nn.CrossEntropyLoss()
# TODO EXTRA use OHCE instead of basic one
def cross_entropy(score, target):
    compute_ce_loss = nn.CrossEntropyLoss(ignore_index=-1)

    # See paper for weights. In order of loss index: (0.4, 20, 1, 1) # But on cfg they set everything to 0.5
    balance_weights = [0.4, 1]
    sb_weights = 1

    # print(f"DEBUG: inside cross_entropy: len(score) = {len(score)}")
    if len(balance_weights) == len(score):
        return sum([w * compute_ce_loss(x, target) for (w, x) in zip(balance_weights, score)])
    elif len(score) == 1:
        return sb_weights * compute_ce_loss(score[0], target)
    else:
        raise ValueError("lengths of prediction and target are not identical!")

sem_loss = cross_entropy
bd_loss = boundary_loss

### Training Loop

In [15]:
import matplotlib.pyplot as plt

for transform, name in zip(transforms, CHOOSE_TRANSFORM):
  cfg = Config()
  model = get_seg_model(cfg, imgnet_pretrained=True)

  style_transf = ''
  if STYLE_TRANSFER:
    style_transf = '_styleTransfer'

  SAVE_MODEL_AS = f'best_model_PIDNET{style_transf}_{name}.pth'

  # Dataset and Loader
  train_dataset = LoveDADataset(baseTransform=transform, augTransforms=None ,split='Urban', typeDataset='Train', styleTransfer=STYLE_TRANSFER, probStyle=0.5, validation_ratio=0.2)

  train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True,
                            num_workers=2, drop_last=True, pin_memory=True)

  validation_dataset = LoveDADataset(baseTransform=AUGMENTATIONS['Resize'], augTransforms=None,
                        split='Urban', typeDataset='Validation', validation_ratio=0.2)

  validation_loader = DataLoader(validation_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2, drop_last=True, pin_memory=True)

  # Optimizier and Scheduler
  optimizer = optim.SGD(model.parameters(), lr=LR, momentum=MOMENTUM)
  scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=STEP_SIZE, gamma=GAMMA)

  best_loss = float('inf')
  best_model = model.state_dict()
  model = model.to(DEVICE)
  print(f"DEVICE is {DEVICE}")
  if TYPE == 'Train':
    for epoch in range(NUM_EPOCHS):
      model.train()
      print('Starting epoch {}/{}, LR = {}'.format(epoch+1, NUM_EPOCHS, scheduler.get_lr()))
      epoch_loss = [0.0, 0]
      for (batch_i, batch) in enumerate(train_loader):
          optimizer.zero_grad()

          ### Extract input
          image_list, masks_list, img_path, bd_gts_list = batch

          index = 0
          for images, masks, bd_gts in zip(image_list, masks_list, bd_gts_list):
            images = images.to(DEVICE)
            masks = masks.to(DEVICE)
            bd_gts = bd_gts.float().to(DEVICE)

            ### ===> Forward, Upscale, Compute Losses
            ## Forward
            outputs = model(images) # in model.train() mode batch size must be > 1 I think
                                    # NOTE: we have 3 heads (i.e. 3 outputs) but 4 losses: 2nd head is used for both S and BAS

            ## Upscale (bilinear interpolation - not learned)
            h, w = masks.size(1), masks.size(2)
            ph, pw = outputs[0].size(2), outputs[0].size(3)
            if ph != h or pw != w:
                for i in range(len(outputs)):
                    outputs[i] = F.interpolate(outputs[i], size=(h, w), mode='bilinear', align_corners=True)

            ## Losses
            # Semantic Losses (l_0 and l_2)
            loss_s = sem_loss(outputs[:-1], masks) # output #1 and #2 are segmentation predictions (i.e. low level (P) and high+low level (PI) respectively)
            # Boundary Loss (l_1)
            loss_b = bd_loss(outputs[-1], bd_gts) # output #3 is the boundary prediction

            # Boundary AwareneSS (BAS) Loss (l_3)
            filler = torch.ones_like(masks) * -1
            bd_label = torch.where(F.sigmoid(outputs[-1][:,0,:,:]) > 0.8, masks, filler)
                                # REMEMBER to wrap in list, as the checks in ce use that to know what to do
            loss_sb = sem_loss([outputs[-2]], bd_label) # output #2 is the PI segmentation prediction, done here in BAS mode (see `filler` variable)

            # Complete Loss
            loss = loss_s + loss_b + loss_sb # The coefficients of the sum of the four losses (0.4, 20, 1, 1) are taken into account in the various `sem_loss` and `bd_loss`
            ### <=== Forward, Upscale, Compute Losses

            ### Backprop
            if batch_i % LOG_FREQUENCY == 0:
                if index > 0:
                  print(f'Augmented images loss: {loss.item()}')
                else:
                  print(f'Loss at batch {batch_i}: {loss.item()}')
            loss.backward()

            optimizer.step()
            epoch_loss[0] += loss.item()
            epoch_loss[1] += images.size(0)
            index+=1

      # Evaluate model on the evaluation set and save the parameters if is better than best model
      model.eval()
      total_loss = 0.0
      outputs = []
      with torch.no_grad():
        for (batch_i, batch) in enumerate(validation_loader):
          ### Extract input
          images, masks, img_path, bd_gts = batch
          images = images.float().to(DEVICE)
          masks = masks.to(DEVICE)
          bd_gts = bd_gts.float().to(DEVICE)

          ### ===> Forward, Upscale, Compute Losses
          ## Forward
          outputs = model(images)

          ## Upscale (bilinear interpolation - not learned)
          h, w = masks.size(1), masks.size(2)
          ph, pw = outputs[0].size(2), outputs[0].size(3)
          if ph != h or pw != w:
              for i in range(len(outputs)):
                  outputs[i] = F.interpolate(outputs[i], size=(h, w), mode='bilinear', align_corners=True)

          ## Losses
          # Semantic Losses (l_0 and l_2)
          loss_s = sem_loss(outputs[:-1], masks)

          # Boundary Loss (l_1)
          loss_b = bd_loss(outputs[-1], bd_gts)

          # Boundary AwareneSS (BAS) Loss (l_3)
          filler = torch.ones_like(masks) * -1
          bd_label = torch.where(F.sigmoid(outputs[-1][:,0,:,:]) > 0.8, masks, filler)
          loss_sb = sem_loss([outputs[-2]], bd_label)

          # Complete Loss
          loss = loss_s + loss_b + loss_sb
          ### <=== Forward, Upscale, Compute Losses
          total_loss += loss.item()

      print('Epoch {}, Loss {}'.format(epoch+1, total_loss))
      if total_loss < best_loss:
        best_loss = total_loss
        best_model = model.state_dict()
        #Save in Drive and local
        torch.save(best_model, SAVE_MODEL_AS)
        if SAVE_ON_DRIVE:
          !cp {SAVE_MODEL_AS} /content/drive/MyDrive/LoveDA/{SAVE_MODEL_AS}
          print(f"model succesfully saved on drive. loss went down to {best_loss}")

      scheduler.step()
      print(f'[EPOCH {epoch+1}] Avg. Loss: {epoch_loss[0] / epoch_loss[1]}')

      # Create a figure with 1 row and 3 columns
      fig, axes = plt.subplots(1, 3, figsize=(15, 5))

      # Plot the original image
      axes[0].imshow(images[0].permute(1, 2, 0).cpu().numpy())
      axes[0].set_title("Original Image")
      axes[0].axis('off')

      # Plot the ground truth mask
      axes[1].imshow(masks[0].cpu().numpy())
      axes[1].set_title("Ground Truth")
      axes[1].axis('off')

      # Plot the predicted mask
      axes[2].imshow(torch.argmax(outputs[0][0], dim=0).cpu().numpy())
      axes[2].set_title("Predicted Mask")
      axes[2].axis('off')

      # Display the figure
      plt.tight_layout()
      plt.show()

Output hidden; open in https://colab.research.google.com to view.

In [16]:
# from google.colab import runtime
# runtime.unassign()

## Try styleTransfer dataset


In [17]:
# import matplotlib.pyplot as plt
# for i in range(10):
#   image, mask, path, boundaries = train_dataset.__getitem__(i)
#   image = image[0]
#   mask = mask[0]
#   boundaries = boundaries[0]
#   print(path)
#   # Draw image
#   plt.imshow(image.permute(1, 2, 0).cpu().numpy())
#   plt.show()
#   # Draw mask
#   plt.imshow(mask.cpu().numpy())
#   plt.show()
#   # Draw boundaries in greyscale
#   plt.imshow(boundaries.cpu().numpy(), cmap='gray')
#   plt.show()

# TEST

In [18]:
!pip install torchmetrics ptflops



In [19]:
from torchmetrics import Accuracy
from tqdm import tqdm
import time
import ptflops

TYPE = 'Test'
# TARGET = 'Rural' # ALREADY RUNNING ON BOTH BELOW
#Load best_model
#!cp /content/drive/MyDrive/LoveDA/best_model_step2b.pth /content/best_model_step2b.pth


model = get_seg_model(cfg, imgnet_pretrained=False)
# original model path: /content/drive/MyDrive/LoveDA/best_model_step2b.pth
best_model = torch.load(f'/content/best_model_{CHOOSE_TRANSFORM}.pth', weights_only=True)
model.load_state_dict(best_model)
model = model.to(DEVICE)

accuracy, mIoU = True, True

TARGETs = ['Urban', 'Rural']

for TARGET in TARGETs:
  if TARGET == 'Urban': # Here we just validate on less images (20%) if URBAN, as it's not the focus of step 3b.
      target_type = 'Validation'
  elif TARGET == 'Rural': # While we take the entirety of the Rural folder in case of Rural
      target_type = 'Total'
  else:
      raise ValueError("TARGET must be 'Urban' or 'Rural'")

  if CHOOSE_TRANSFORM == 'RandomCropOrResize':
    test_augmentation = AUGMENTATIONS['RandomCrop512']
  elif CHOOSE_TRANSFORM == 'RandomCropXXX':
    test_augmentation = AUGMENTATIONS['None']
  elif CHOOSE_TRANSFORM == 'Jitter':
    test_augmentation = AUGMENTATIONS['None']
  elif CHOOSE_TRANSFORM == 'GaussianBlur':
    test_augmentation = AUGMENTATIONS['None']
  else:
    test_augmentation = AUGMENTATIONS[CHOOSE_TRANSFORM]

  test_dataset = LoveDADataset(baseTranform=test_augmentation, augTransforms=None, split=TARGET, type=target_type)
  test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2, drop_last=True, pin_memory=True)

  #### TEST LOOP
  model.eval()
  print(f"Testing on domain={TARGET} using augmentations={CHOOSE_TRANSFORM} on a {target_type} split")

  # Latency
  with torch.no_grad():
      start_time = time.time()
      for _ in range(100):
          _ = model(torch.randn(1, 3, RESIZE, RESIZE).to(DEVICE))
      end_time = time.time()
  latency = (end_time - start_time) / 100
  print(f"Latency: {latency:.4f} seconds")

  # FLOPs
  macs, _ = ptflops.get_model_complexity_info(model,
    (3, RESIZE, RESIZE), as_strings=False,
    print_per_layer_stat=False, verbose=False)
  flops = macs * 2  # MACs perform two FLOPs
  print("FLOPs:", flops)

  # Number of parameters
  total_params = sum(p.numel() for p in model.parameters())
  print(f"Total number of parameters: {total_params}")

  if TYPE == 'Test':
    with torch.no_grad():
        total_union = torch.zeros(NUM_CLASSES).to(DEVICE)
        total_intersection = torch.zeros(NUM_CLASSES).to(DEVICE)
        meter = Accuracy(task='multiclass', num_classes=NUM_CLASSES).to(DEVICE)
        for (batch) in tqdm(test_loader):
            ### Extract input
            images, masks, img_path, bd_gts = batch
            images = images.float().to(DEVICE)
            masks = masks.to(DEVICE)

            ### ===> Forward, Upscale, Compute Losses
            ## Forward
            outputs = model(images)

            ## Upscale (bilinear interpolation - not learned)
            h, w = masks.size(1), masks.size(2)
            ph, pw = outputs[0].size(2), outputs[0].size(3)
            if ph != h or pw != w:
                for i in range(len(outputs)):
                    outputs[i] = F.interpolate(outputs[i], size=(h, w), mode='bilinear', align_corners=True)

            # Output 1 is the prediction

            # Shape: NBATCHES x classes x h x w
            class_indices = torch.argmax(outputs[1], dim=1)  # Shape: NBATCHES x h x w

            if accuracy:
            # Create a mask for valid targets (where target is not -1)
              valid_mask = (masks != -1)  # Mask of shape: NBATCHES x h x w
              # Apply the mask to ignore -1 targets when updating the accuracy metric
              meter.update(class_indices[valid_mask], masks[valid_mask])

            if mIoU:
              for predicted, target in zip(class_indices, masks):
                for i in range(NUM_CLASSES):
                  total_intersection[i] += torch.sum(torch.logical_and(predicted == i, target == i))
                  total_union[i] += torch.sum(torch.logical_or(predicted == i, target == i))

    if accuracy:
      accuracy = meter.compute()
      print(f'\nAccuracy on the target domain: {100 * accuracy:.2f}%')

    if mIoU:
      intersection_over_union = total_intersection / total_union

      # Per class IoU
      for i, iou in enumerate(intersection_over_union):
          class_name = list(LABEL_MAP.keys())[list(LABEL_MAP.values()).index(i)]  # Get the class name from LABEL_MAP
          print(f'{class_name} IoU: {iou:.4f}')

      mIoU = torch.mean(intersection_over_union)
      print(f'\nmIoU on the target domain: {mIoU}')

  print("========================================================================")


  pretrained_dict = torch.load(cfg.MODEL.PRETRAINED, map_location='cpu')


FileNotFoundError: [Errno 2] No such file or directory: "/content/best_model_['Jitter+RandomCrop600-900+Resize'].pth"