In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms.functional as T
import numpy as np

# Generalized UNet architecture from scratch for learning purposes. See https://arxiv.org/abs/1505.04597

class CNNBlock(nn.Module):

     # In the UNet architecture, this is the smallest possible block.
     # Each of these applies a 2D convolution, normalizes the output and passes it through a ReLU activation function.
     # The result of passing an input of size [N, in, H_in, W_in] through this layer is an output [N, out, H_out, W_out]
     #    where H_out = (H_in + 2*padding - dilation*(kernel_size-1) - 1)/stride + 1 
     #          W_out = (W_in + 2*padding - dilation*(kernel_size-1) - 1)/stride + 1 
     # This output corresponds to the element_wise result of the ReLU activation function of the normalized convolution output.

     # The block is used several times at each "level", which is why we later define a CNNSet, which is composed of several 
     # usages of the CNNBlock sequentially 

     """
     Parameters:
     in_channels (int): Number of channels in the input to the block.
     out_channels (int): Number of channels produced by the convolution of the input within the block.
     kernel_size (int): Size of the convolving kernel. Default = 3.
     stride (int) : Stride of the convolution. Default = 1.
     padding (int) : Padding added to all four sides of the input. Default = 0.
     """

     def __init__(self, in_channels: int, out_channels:int, kernel_size:int=3, stride:int=1, padding:int=0):
          super(CNNBlock, self).__init__()

          self.sequential_block = nn.Sequential(
               nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False),
               nn.BatchNorm2d(out_channels),
               nn.ReLU(inplace=True)
          )

     def forward(self, x):
          x = self.sequential_block(x)
          return x
     


class Encoder(nn.Module):

     # The encoder part of the UNet consists of the downsampling of the initial input through the application of 
     # a sequential series of convolution sets followed by a max pooling layer. The complete operation's objective 
     # is to capture the context and spatial information of the input image at different scales.

     """
     Parameters:
     in_channels (int): Number of input channels of the first CNNSet.
     out_channels (int): Number of output channels of the first CNNSet.
     padding (int): Padding applied in each convolution.
     levels (int): Number times a CNNSet + MaxPool2D layer is applied.
     """

     def __init__(self, input_channels:int, output_channels:int, pool_kernelsize: int, parameters: list):
          super(Encoder, self).__init__()
          self.encoder_layers = nn.ModuleList()
          levels = len(parameters)
          for level in range(levels-1):
               for conv in range(len(parameters[level])):
                    conv_kernelsize = parameters[level][conv][0]
                    conv_stride = parameters[level][conv][1]
                    self.encoder_layers.append(CNNBlock(in_channels=input_channels, out_channels=output_channels, kernel_size=conv_kernelsize, stride=conv_stride))
                    input_channels = output_channels
               output_channels *= 2
               self.encoder_layers.append(nn.MaxPool2d(pool_kernelsize))
          # A final convolution set is applied after all the levels, commonly referred to as the bottleneck
          for conv in range(len(parameters[-1])):
               conv_kernelsize = parameters[-1][conv][0]
               conv_stride = parameters[-1][conv][1]
               self.encoder_layers.append(CNNBlock(in_channels=input_channels, out_channels=output_channels, kernel_size=conv_kernelsize, stride=conv_stride))
               input_channels = output_channels

     def forward(self,x):
          residual_connection = []
          for i, layer in enumerate(self.encoder_layers):
               x = layer(x)
               # After the set CNN is processed, the result is logged to be sent in a connection to the decoder
               if i<len(self.encoder_layers)-1 and isinstance(self.encoder_layers[i+1], nn.MaxPool2d):
                    residual_connection.append(x)
               # If the processed layer is a pooling operation, the result is not logged
          return x, residual_connection



class Decoder(nn.Module):

     """
     Parameters:
     in_channels (int): Number of input channels of the first up-convolution layer.
     out_channels (int): Number of output channels of the first up-convolution layer.
     padding (int): Padding applied in each convolution.
     levels (int): number times an up-convolution + CNNSet is applied.
     """

     # After the encoder has downsampled the information, the decoder now applies an upsampling to match to original features.
     # This is achieved combining up-convolutions followed by convolution sets sequentially, achieving a recovery of the 
     # fine-grained spatial information lost during the downsampling in the encoder.

     def __init__(self, input_channels:int, exit_channels:int, uppool_kernelsize:int, parameters: list):
          super(Decoder, self).__init__()
          self.exit_channels = exit_channels
          self.decoder_layers = nn.ModuleList()

          levels = len(parameters)
          for level in range(levels-1):
               output_channels = int(input_channels/2)
               self.decoder_layers.append(nn.ConvTranspose2d(in_channels=input_channels, out_channels=output_channels, kernel_size=uppool_kernelsize, stride=uppool_kernelsize))
               for conv in range(len(parameters[level])):
                    conv_kernelsize = parameters[level][conv][0]
                    conv_stride = parameters[level][conv][1]
                    self.decoder_layers.append(CNNBlock(in_channels=input_channels, out_channels=output_channels, kernel_size=conv_kernelsize, stride=conv_stride))
                    input_channels = output_channels
          # A final convolution set without the ReLU activation function since the output will be passed through a BCELoss 
          self.decoder_layers.append(nn.Conv2d(in_channels=input_channels, out_channels=exit_channels, kernel_size=1))

     def forward(self, x, residual_connection):
          for i, layer in enumerate(self.decoder_layers):
               # After the previous output is up-sampled, the connection from the equivalent level is concatenated
               if i>0 and isinstance(self.decoder_layers[i-1], nn.ConvTranspose2d):
                    # First we center-crop the route tensor to make the size match
                    residual_connection[-1] = T.center_crop(residual_connection[-1], x.shape[2])
                    # Then we concatenate the tensors in the dimensions of the channels
                    x = torch.cat([x, residual_connection.pop(-1)], dim=1)
                    x = layer(x)
               # If the processed layer is an up-convolution operation, the connection is not performed
               else:
                    x = layer(x)
          return x


class UNetV2(nn.Module):

     """
     Parameters:
     in_channels (int): Number of input channels.
     first_out_channels (int): Number of output channels of the first convolution set.
     exit_channels (int): Number of output channels.
     levels (int): Number of levels for the encoder-decoder architecture.
     padding (int): Padding applied in each convolution operation.
     """

     # After the encoder has downsampled the information, the decoder now applies an upsampling to match to original features.
     # This is achieved combining up-convolutions followed by convolution sets sequentially, achieving a recovery of the 
     # fine-grained spatial information lost during the downsampling in the encoder.

     def __init__(self, in_channels, first_out_channels, exit_channels, pool_kernelsize, down_parameters, up_parameters, augment=False):
          super(UNetV2, self).__init__()
          levels = len(down_parameters)
          self.encoder = Encoder(input_channels=in_channels, output_channels=first_out_channels, pool_kernelsize=pool_kernelsize, parameters=down_parameters)
          self.decoder = Decoder(input_channels=first_out_channels*(2**(levels-1)), exit_channels=exit_channels, uppool_kernelsize=pool_kernelsize, parameters=up_parameters)
          self.augment = augment
        
     def forward(self, x):
          encoder_out, residuals = self.encoder(x)
          decoder_out = self.decoder(encoder_out, residuals)
          if self.augment: 
              return T.center_crop(decoder_out, (512,512))
          else:
              return T.center_crop(decoder_out, (256,256))

In [2]:
import torch
from PIL import Image
import os
import numpy as np
from torch.utils.data import Dataset, DataLoader
import random
import albumentations as A

def create_grid(nc: int, offset=0.5) -> torch.Tensor:
    grid = np.zeros((nc, nc, 2), dtype=np.float32)
    for ix in range(nc):
        for iy in range(nc):
            grid[ix, iy, 1] = -1 + 2 * (ix + 0.5) / nc + offset / 128
            grid[ix, iy, 0] = -1 + 2 * (iy + 0.5) / nc + offset / 128
    grid = torch.from_numpy(grid).unsqueeze(0)
    return grid

def img2tensor(img, dtype: np.dtype = np.float32):
    img = np.transpose(img, (2, 0, 1))
    tensor = torch.from_numpy(img.astype(dtype, copy=False))
    return tensor

class ContrailsDataset(Dataset):

     def __init__(self, path, use='train', soft_labels=False, only_positives=True, repeat=1, augment=False):
          if use == 'train' or use=='metrics':
               train = True
          else:
               train = False
          self.path = os.path.join(path, "train" if train else "validation", "images")
          if only_positives:
               positives_path = '/kaggle/input/positivess'
               positives_file = np.load(os.path.join(positives_path,"positive_train.npy" if train else "positive_validation.npy"))
               positives_fnames = [filename.split("\\")[3] for filename in positives_file]
               if use == 'train' or use == 'cross-validate':
                    self.filenames = [filename.split(".")[0] for filename in os.listdir(self.path) if filename.split(".")[0] in positives_fnames]
               elif use == 'metrics':
                    self.filenames = random.sample([filename.split(".")[0] for filename in os.listdir(self.path) if filename.split(".")[0] in positives_fnames], 500)
          else:
               if use == 'train' or use == 'cross-validate':
                    self.filenames = [filename.split(".")[0] for filename in os.listdir(self.path)]
               elif use == 'metrics':
                    self.filenames = random.sample([filename.split(".")[0] for filename in os.listdir(self.path)],500)
          self.train = train
          self.nc = 3
          self.repeat = repeat
          self.soft_labels = soft_labels
          self.augment = augment 
          if self.augment:
               self.grid = create_grid(512, offset=0.5)

     def __len__(self):
          return self.repeat * len(self.filenames)
     
     def __getitem__(self, index):
          index = index % len(self.filenames)
          try:
               image = np.array(Image.open(os.path.join(self.path, self.filenames[index] + '.png')))
               if self.soft_labels:
                    mask  = np.load(os.path.join(self.path.replace('images','soft_label'), self.filenames[index] + '.npy'))
               else:
                    mask  = np.load(os.path.join(self.path.replace('images','ground_truth'), self.filenames[index] + '.npy'))
               image_tensor, mask_tensor = img2tensor(image/255), img2tensor(mask)   # Sizes 3x256x256 and 1x256x256  
               if self.augment:
                    transform = A.Compose([A.RandomRotate90(p=0.05),A.HorizontalFlip(p=0.05)])
                    image_tensor = T.resize(image_tensor,512)
                    mask_sym = F.grid_sample(mask_tensor.unsqueeze(0), self.grid, mode='bilinear', padding_mode='border', 
                                             align_corners=False).squeeze(0)
                    image_np = image_tensor.permute(1, 2, 0).numpy()
                    mask_np = mask_sym.permute(1, 2, 0).numpy()
                    # Apply augmentation
                    aug = transform(image=image_np, mask=mask_np)
                    transformed_image = aug['image'].transpose(2, 0, 1)
                    transformed_mask = aug['mask'].transpose(2, 0, 1)
                    image_tensor = torch.from_numpy(transformed_image)
                    mask_tensor = torch.from_numpy(transformed_mask)
               return image_tensor, mask_tensor
          except Exception as e:
               print(f"\n Error loading file: {e} \n")
               return None, None

In [3]:
import numpy as np
import matplotlib.pyplot as plt

def validate(net, usage, device, pad, test_batch_size=50, threshold=0.5):    

     testset = ContrailsDataset(path='/kaggle/input/opencontrails-png/SingleFrame_PNG', use=usage, augment=False, only_positives=False)
     testloader = torch.utils.data.DataLoader(testset,batch_size=test_batch_size, shuffle=True, num_workers=0)

     criterion = nn.BCEWithLogitsLoss()

     # Positive pixels labelled as positive
     TP = 0
     # Negative pixels labelled as negative
     TN = 0
     # Positive pixels labelled as negative
     FN = 0
     # Negative pixels labelled as positive
     FP = 0

     running_loss = 0

     with torch.no_grad():
          for i, data in enumerate(testloader):
               images, labels = data
               images, labels = images.to(device), labels.to(device)
               images = F.pad(images,(pad,pad,pad,pad), mode='reflect')
               outputs = torch.sigmoid(net(images))
               outputs = T.center_crop(outputs, (256,256))
               outputs = outputs.view(-1,1,256,256)
               binary_outputs = (outputs > threshold).float()

               loss = criterion(outputs, labels)
               running_loss += loss.item()

               TP += torch.sum((binary_outputs == 1) & (labels == 1)).item()
               TN += torch.sum((binary_outputs == 0) & (labels == 0)).item()
               FN += torch.sum((binary_outputs == 0) & (labels == 1)).item()
               FP += torch.sum((binary_outputs == 1) & (labels == 0)).item()

               print(f'Processing batch {i+1}/{len(testloader)}', end='\r')

     # Pixel Accuracy
     PA = TP/(TP+TN+FP+FN) if (TP+TN+FP+FN)>0 else '-'
     # Jaccard Coefficient
     IoU = TP/(TP+FP+FN) if (TP+FP+FN)>0 else '-'
     # Precision
     precision = TP/(TP+FP) if (TP+FP)>0 else '-'
     # Recall
     recall = TP/(TP+FN) if (TP+FN)>0 else '-'
     # F1 Score
     F1 = 2*(precision*recall/(precision+recall)) if ((TP+FP)>0 and (TP+FN)>0 and TP>0) else '-'
     # Dice Coefficient
     dice = 2*TP/(2*TP+FP+FN) if (TP+FP+FN)>0 else '-'

     return running_loss/len(testset), PA, IoU, precision, recall, F1, dice

In [4]:
# !pip install segmentation-models-pytorch

# import segmentation_models_pytorch as smp

In [5]:
def parameter_generator(levels, convs, kernel_size):
     parameters = [[None for _ in range(convs)] for _ in range(levels)]
     for level in range(levels):
          for conv in range(convs):
               if isinstance(kernel_size,list):
                    parameters[level][conv] = [(kernel_size[level], kernel_size[level]), 1]
               else:
                    parameters[level][conv] = [(kernel_size, kernel_size), 1]
     return parameters


levs = 5
cons = 4
kers = 3
path2trained = '/kaggle/input/trained0705/UNETv2_5431_Positives_20epoch.pth'
pad = 188

state_dict = torch.load(path2trained)
downparameters = parameter_generator(levs,cons,kers)
upparameters = parameter_generator(levs,cons,kers)

In [6]:
net = UNetV2(in_channels=3, first_out_channels=64, exit_channels=1, pool_kernelsize=2, down_parameters=downparameters, up_parameters=upparameters, augment=False).to('cuda')
# net = smp.Unet('tu-maxxvit_rmlp_small_rw_256.sw_in1k',classes=1, encoder_depth=4, decoder_channels=[512,256,128,64]).to('cuda')
# net = smp.Unet('tu-coatnet_rmlp_2_rw_384.sw_in12k_ft_in1k',classes=1, encoder_depth=4, decoder_channels=[512,256,128,64]).to('cuda') 
net.load_state_dict(state_dict)
net.eval()

_, PA, IoU, precision, recall, F1, dice = validate(net, pad=pad, device='cuda', usage='cross-validate', test_batch_size=3)

print(f'Pixel accuracy: {PA}')
print(f'Jaccard Index: {IoU}')
print(f'Precision: {precision}')
print(f'Recall: {recall}')
print(f'F1 Score: {F1}')
print(f'Dice coefficient: {dice}', end='\n')

Pixel accuracy: 0.0010508411335495283
Jaccard Index: 0.28538557769636314
Precision: 0.3599159304002885
Recall: 0.5795069994465765
F1 Score: 0.444046646622291
Dice coefficient: 0.444046646622291
