# Multi-Temporal Crop Image Segmentation

## Ideas

Data
- Data augmentation for segmentation

- Compute certain features as inputs [Link](https://towardsdatascience.com/satellite-imagery-analysis-using-python-9f389569862c)
    - NDVI = (NIR - RED) / (NIR + RED)
    - EVI
    - RVI = NIR / RED
    - NIR + SWIR1 + Blue (Healthy Vegetation Band Combination)

Approaches
- Self-supervised learning (SSL): the same chip at different timestamps should map to the same features

Model
- U-Net
- U-Net with additive attention gates on the skip connections

Others
- torchgeo

In [2]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from osgeo import gdal
import os
import sys
import copy
import time

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torchvision
import torch.distributed as dist
import torchmetrics
from torchinfo import summary

In [3]:
# Put the local ./code directory first on sys.path so modules stored there can be
# imported by later cells (presumably `transforms` and `utils` — confirm).
module_path = os.path.abspath(os.path.join('./code'))
sys.path.insert(0, module_path)

In [4]:
# Number of segmentation classes; passed to the models as their output channel count.
n_classes = 14

# Pick the compute device: CUDA if available, otherwise Apple's MPS backend, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
# Bare expression so the notebook displays the selected device.
device

device(type='mps')

## Data

### Data Augmentations

In [5]:
# Project-local transforms (importable because ./code was added to sys.path above).
from transforms import SegmentationTrainTransform, SegmentationValTransform

# One transform per split; each is later called as transform(imgs, mask) by the dataset.
transform_train = SegmentationTrainTransform()
transform_val = SegmentationValTransform()

### Dataset

In [6]:
class GeoCropDataset(Dataset):
    """Dataset of image chips and their segmentation masks, read with GDAL.

    The chip identifiers are read from ``training_data.txt`` or
    ``validation_data.txt`` inside ``root`` (one identifier per line, blank
    lines ignored). For each identifier the image is expected at
    ``<root>/hls/<chip>_merged.tif`` and the mask at
    ``<root>/masks/<chip>.mask.tif``.

    Args:
        root: Directory containing the split files and the ``hls`` / ``masks`` folders.
        is_train: Selects the training split when true, the validation split otherwise.
        transform: Optional callable invoked as ``transform(imgs, mask)`` that
            returns the transformed ``(imgs, mask)`` pair.
    """

    # Array shapes every chip must have. 18 channels is presumably
    # 3 timestamps x 6 bands — TODO confirm against the data description.
    IMG_SHAPE = (18, 224, 224)
    MASK_SHAPE = (224, 224)

    def __init__(self, root='data/', is_train=True, transform=None):
        super().__init__()
        file_name = 'training_data.txt' if is_train else 'validation_data.txt'
        with open(os.path.join(root, file_name)) as f:
            stripped = (line.rstrip() for line in f)
            # Skip blank lines (e.g. a trailing newline) so they do not become bogus paths.
            chip_list = [chip for chip in stripped if chip]
        self.img_list = [os.path.join(root, 'hls', chip + '_merged.tif') for chip in chip_list]
        self.mask_list = [os.path.join(root, 'masks', chip + '.mask.tif') for chip in chip_list]
        self.transform = transform

    def __len__(self):
        return len(self.img_list)

    @staticmethod
    def _read_raster(path):
        """Read the raster at ``path`` into an array, failing loudly if GDAL cannot open it."""
        raster = gdal.Open(path, gdal.GA_ReadOnly)
        # Without gdal.UseExceptions(), gdal.Open signals failure by returning None.
        if raster is None:
            raise FileNotFoundError(f'GDAL could not open raster: {path}')
        return raster.ReadAsArray()

    def __getitem__(self, index):
        """Return the ``(imgs, mask)`` pair for ``index``, transformed if a transform is set.

        Raises:
            FileNotFoundError: If GDAL cannot open the image or mask file.
            ValueError: If either array does not have the expected shape.
        """
        img_path = self.img_list[index]
        mask_path = self.mask_list[index]
        imgs = self._read_raster(img_path)
        mask = self._read_raster(mask_path)

        # Explicit checks instead of asserts so they still run under `python -O`.
        if imgs.shape != self.IMG_SHAPE:
            raise ValueError(f'{img_path}: expected image shape {self.IMG_SHAPE}, got {imgs.shape}')
        if mask.shape != self.MASK_SHAPE:
            raise ValueError(f'{mask_path}: expected mask shape {self.MASK_SHAPE}, got {mask.shape}')

        if self.transform:
            imgs, mask = self.transform(imgs, mask)

        return imgs, mask


# Training split: augmented by transform_train and reshuffled every epoch.
dataset_train = GeoCropDataset(root='data/', is_train=True, transform=transform_train)
loader_train = DataLoader(
    dataset_train,
    batch_size=16,
    shuffle=True
)

# Validation split: uses transform_val and keeps a fixed order.
dataset_val = GeoCropDataset(root='data/', is_train=False, transform=transform_val)
loader_val = DataLoader(
    dataset_val,
    batch_size=16,
    shuffle=False
)

# Pull one training batch as a sanity check; `imgs` is reused by the model cell below.
imgs, mask = next(iter(loader_train))
imgs.shape

torch.Size([16, 18, 224, 224])

## Model

### U-Net Baseline

In [10]:
class DoubleConvolution(nn.Module):
    """
    ### Pair of 3x3 convolutions, each followed by a ReLU.

    Both convolutions use a padding of 1 (the U-Net paper uses no padding), so the
    spatial size of the feature map is preserved and nothing has to be cropped later.
    """

    def __init__(self, in_channels: int, out_channels: int):
        """
        :param in_channels: number of channels of the incoming feature map
        :param out_channels: number of channels produced by both convolutions
        """
        super().__init__()
        same_size = dict(kernel_size=3, padding=1)
        self.first = nn.Conv2d(in_channels, out_channels, **same_size)
        self.act1 = nn.ReLU()
        self.second = nn.Conv2d(out_channels, out_channels, **same_size)
        self.act2 = nn.ReLU()

    def forward(self, x: torch.Tensor):
        """Apply conv -> ReLU -> conv -> ReLU to `x`."""
        hidden = self.act1(self.first(x))
        return self.act2(self.second(hidden))


class DownSample(nn.Module):
    """
    ### Down-sampling step of the contracting path

    Halves the spatial resolution of the feature map with a 2x2 max pooling layer.
    """

    def __init__(self):
        super().__init__()
        self.pool = nn.MaxPool2d(kernel_size=2)

    def forward(self, x: torch.Tensor):
        """Return the max-pooled version of `x`."""
        pooled = self.pool(x)
        return pooled


class UpSample(nn.Module):
    """
    ### Up-sampling step of the expansive path

    Doubles the spatial resolution of the feature map with a 2x2 transposed
    convolution (stride 2) while mapping `in_channels` to `out_channels`.
    """
    def __init__(self, in_channels: int, out_channels: int):
        """
        :param in_channels: number of channels of the incoming feature map
        :param out_channels: number of channels of the up-sampled feature map
        """
        super().__init__()

        # Kernel size equals the stride, so the output patches do not overlap.
        self.up = nn.ConvTranspose2d(
            in_channels,
            out_channels,
            kernel_size=2,
            stride=2,
        )

    def forward(self, x: torch.Tensor):
        """Return the up-sampled version of `x`."""
        upsampled = self.up(x)
        return upsampled


class CropAndConcat(nn.Module):
    """
    ### Skip-connection merge: centre-crop, then concatenate

    In each step of the expansive path the matching feature map from the
    contracting path is centre-cropped to the spatial size of the current
    feature map and appended to it along the channel dimension.
    """
    def forward(self, x: torch.Tensor, contracting_x: torch.Tensor):
        """
        :param x: current feature map in the expansive path
        :param contracting_x: corresponding feature map from the contracting path
        :return: `x` followed by the cropped `contracting_x`, joined on dim 1
        """
        # Spatial size (height, width) that the skip feature map has to match.
        height, width = x.shape[2], x.shape[3]
        cropped_skip = torchvision.transforms.functional.center_crop(contracting_x, [height, width])
        # Current features come first, skip features second.
        return torch.cat([x, cropped_skip], dim=1)


class UNet(nn.Module):
    """
    ## U-Net baseline

    Four encoder stages (double convolution + max pooling), a bottleneck double
    convolution, and four decoder stages (up-convolution, skip concatenation,
    double convolution), finished by a 1x1 convolution that produces the output channels.
    """
    def __init__(self, in_channels=18, out_channels=14):
        """
        :param in_channels: number of channels in the input image
        :param out_channels: number of channels in the result feature map
        """
        super().__init__()

        # Encoder widths: the channel count doubles at every stage, starting from 64.
        widths = [64, 128, 256, 512]
        bottleneck = 2 * widths[-1]

        # Contracting path: one double convolution and one pooling layer per stage.
        encoder_io = zip([in_channels] + widths[:-1], widths)
        self.down_conv = nn.ModuleList([DoubleConvolution(i, o) for i, o in encoder_io])
        self.down_sample = nn.ModuleList([DownSample() for _ in widths])

        # Bottom of the U: the lowest-resolution double convolution.
        self.middle_conv = DoubleConvolution(widths[-1], bottleneck)

        # Expansive path. Each up-convolution halves the channel count; the following
        # double convolution sees the up-sampled map concatenated with the skip
        # connection, i.e. twice as many input channels as it outputs.
        decoder_io = [(2 * w, w) for w in reversed(widths)]
        self.up_sample = nn.ModuleList([UpSample(i, o) for i, o in decoder_io])
        self.up_conv = nn.ModuleList([DoubleConvolution(i, o) for i, o in decoder_io])
        self.concat = nn.ModuleList([CropAndConcat() for _ in widths])

        # 1x1 convolution mapping the last 64 features to the output channels.
        self.final_conv = nn.Conv2d(widths[0], out_channels, kernel_size=1)


    def forward(self, x: torch.Tensor):
        """
        :param x: input image
        :return: feature map with `out_channels` channels
        """
        # Encoder outputs, kept for the skip connections (consumed last-in, first-out).
        skips = []
        for conv, pool in zip(self.down_conv, self.down_sample):
            x = conv(x)
            skips.append(x)
            x = pool(x)

        x = self.middle_conv(x)

        for up, merge, conv in zip(self.up_sample, self.concat, self.up_conv):
            x = up(x)
            x = merge(x, skips.pop())
            x = conv(x)

        return self.final_conv(x)
    

# Smoke test: run the sample training batch through a fresh baseline U-Net and show the output shape.
model = UNet(in_channels=18, out_channels=n_classes)
model(imgs.float()).shape

# Layer-by-layer summary (output shapes and parameter counts) for a single 18x224x224 input.
summary(model, input_size=(1, 18, 224, 224))

Layer (type:depth-idx)                   Output Shape              Param #
UNet                                     [1, 14, 224, 224]         --
├─ModuleList: 1-7                        --                        (recursive)
│    └─DoubleConvolution: 2-1            [1, 64, 224, 224]         --
│    │    └─Conv2d: 3-1                  [1, 64, 224, 224]         10,432
│    │    └─ReLU: 3-2                    [1, 64, 224, 224]         --
│    │    └─Conv2d: 3-3                  [1, 64, 224, 224]         36,928
│    │    └─ReLU: 3-4                    [1, 64, 224, 224]         --
├─ModuleList: 1-8                        --                        --
│    └─DownSample: 2-2                   [1, 64, 112, 112]         --
│    │    └─MaxPool2d: 3-5               [1, 64, 112, 112]         --
├─ModuleList: 1-7                        --                        (recursive)
│    └─DoubleConvolution: 2-3            [1, 128, 112, 112]        --
│    │    └─Conv2d: 3-6                  [1, 128, 112, 112]

In [14]:
class ConvBlock(nn.Module):
    """Stack of ``num_conv_layers`` conv+BN+ReLU layers.

    The first layer maps ``in_channels`` to ``out_channels``; every further
    layer maps ``out_channels`` to ``out_channels`` and, when ``drop_rate > 0``,
    is followed by a dropout layer. Every layer owns its own parameters.

    Args:
        in_channels (int) -- number of input features.
        out_channels (int) -- number of output features.
        kernel_size (int) -- size of the convolution kernel.
        stride (int) -- stride of every convolution in the block.
        padding (int) -- zero padding added to the borders of the input.
        dilation (int) -- dilation ratio for enlarging the receptive field.
        num_conv_layers (int) -- number of conv+BN+ReLU layers in the block.
        drop_rate (float) -- dropout rate applied after each layer except the first.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1,
                 padding=1, dilation=1, num_conv_layers=2, drop_rate=0):
        super(ConvBlock, self).__init__()

        def conv_bn_relu(n_in):
            # Build a fresh conv+BN+ReLU triple; no bias because BatchNorm follows.
            return [nn.Conv2d(n_in, out_channels, kernel_size=kernel_size, stride=stride,
                              padding=padding, dilation=dilation, bias=False),
                    nn.BatchNorm2d(out_channels),
                    nn.ReLU(inplace=True)]

        layers = conv_bn_relu(in_channels)

        # Instantiate new modules on every iteration. Multiplying a list of modules
        # would repeat the *same* instances, so all extra layers would share one set
        # of weights and BatchNorm statistics.
        for _ in range(num_conv_layers - 1):
            layers += conv_bn_relu(out_channels)
            if drop_rate > 0:
                layers.append(nn.Dropout(drop_rate))

        self.block = nn.Sequential(*layers)

    def forward(self, inputs):
        """Run ``inputs`` through the layer stack."""
        outputs = self.block(inputs)
        return outputs


class DUC(nn.Module):
    """
    Dense Upscaling Convolution (DUC) layer.

    A 1x1 convolution expands the channels to ``out_channles * upscale ** 2``,
    followed by BatchNorm, ReLU and a pixel shuffle that trades those extra
    channels for an ``upscale`` times larger spatial resolution. The
    convolution weights are initialised with ICNR (see :meth:`icnr`).

    Args:
        in_channels (int): Number of input channels.
        out_channles (int): Number of output channels after the pixel shuffle
            (parameter name kept as-is for backward compatibility).
        upscale (int): Upscaling factor.
    """
    def __init__(self, in_channels, out_channles, upscale):
        super(DUC, self).__init__()
        # Channel count before the pixel shuffle folds channels into space.
        out_channles = out_channles * (upscale ** 2)
        self.conv = nn.Conv2d(in_channels, out_channles, 1, bias=False)
        self.bn = nn.BatchNorm2d(out_channles)
        self.relu = nn.ReLU(inplace=True)
        self.pixl_shf = nn.PixelShuffle(upscale_factor=upscale)

        # Overwrite the default initialisation with the ICNR kernel.
        kernel = self.icnr(self.conv.weight, scale=upscale)
        with torch.no_grad():
            self.conv.weight.copy_(kernel)

    def forward(self, x):
        """Return ``x`` upscaled by the configured factor."""
        x = self.relu(self.bn(self.conv(x)))
        x = self.pixl_shf(x)
        return x

    def icnr(self, x, scale=2, init=nn.init.kaiming_normal_):
        """
        Build an ICNR ("initialization to convolution NN resize") kernel.

        Args:
            x (torch.Tensor): Weight tensor whose shape the kernel should have;
                its first dimension must be divisible by ``scale ** 2``.
            scale (int): Upscaling factor.
            init (function): Initialization function; it is called on a zero
                tensor and must return the initialised tensor. Defaults to
                ``nn.init.kaiming_normal_`` (the non-deprecated spelling of
                ``nn.init.kaiming_normal``).

        Returns:
            torch.Tensor: Kernel with the shape of ``x`` in which each group of
            ``scale ** 2`` consecutive output channels holds identical weights.
        Note:
            Even with pixel shuffle we still get checkerboard artifacts; the
            remedy is to initialise the scale**2 feature maps of each group with
            the same random weights: https://arxiv.org/pdf/1707.02937.pdf
        """

        # Initialise one sub-kernel per output channel of the pixel shuffle ...
        new_shape = [int(x.shape[0] / (scale ** 2))] + list(x.shape[1:])
        subkernel = torch.zeros(new_shape)
        subkernel = init(subkernel)
        subkernel = subkernel.transpose(0, 1)
        subkernel = subkernel.contiguous().view(subkernel.shape[0],
                                                subkernel.shape[1], -1)
        # ... then replicate it scale**2 times so the shuffled pixels start out equal.
        kernel = subkernel.repeat(1, 1, scale ** 2)
        transposed_shape = [x.shape[1]] + [x.shape[0]] + list(x.shape[2:])
        kernel = kernel.contiguous().view(transposed_shape)
        kernel = kernel.transpose(0, 1)
        return kernel


class UpconvBlock(nn.Module):
    """
    Decoder layer that doubles the spatial resolution along the expansive path.
    Args:
        in_channels (int) -- number of input features.
        out_channels (int) -- number of output features.
        upmode (str) -- Upsampling type:
                        "fixed": bilinear upsampling with scale factor two, followed by
                        BatchNorm and a 1x1 convolution that changes the channel count.
                        "deconv_1": non-overlapping transposed convolution (kernel 2, stride 2).
                        "deconv_2": overlapping transposed convolution (kernel 4, stride 2, padding 1).
                        "DUC": dense upscaling convolution, i.e. a 3x3 convolution to
                        4 * out_channels features, BatchNorm, ReLU and a pixel shuffle.
    Raises:
        ValueError -- if `upmode` is none of the modes above.
    """

    def __init__(self, in_channels, out_channels, upmode="deconv_1"):
        super(UpconvBlock, self).__init__()

        if upmode not in ("fixed", "deconv_1", "deconv_2", "DUC"):
            raise ValueError("Provided upsampling mode is not recognized.")

        def fixed_layers():
            return [
                nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True),
                nn.BatchNorm2d(in_channels),
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False),
            ]

        def deconv_layers(kernel_size, padding):
            return [
                nn.ConvTranspose2d(in_channels, out_channels, kernel_size=kernel_size,
                                   stride=2, padding=padding, dilation=1),
            ]

        def duc_layers():
            factor = 2
            expanded = (factor ** 2) * out_channels
            return [
                nn.Conv2d(in_channels, expanded, kernel_size=3, padding=1),
                nn.BatchNorm2d(expanded),
                nn.ReLU(inplace=True),
                nn.PixelShuffle(factor),
            ]

        # Builders are lazy so that only the layers of the selected mode get created.
        builders = {
            "fixed": fixed_layers,
            "deconv_1": lambda: deconv_layers(kernel_size=2, padding=0),
            "deconv_2": lambda: deconv_layers(kernel_size=4, padding=1),
            "DUC": duc_layers,
        }

        self.block = nn.Sequential(*builders[upmode]())

    def forward(self, inputs):
        """Return the upsampled `inputs`."""
        upsampled = self.block(inputs)
        return upsampled


class AdditiveAttentionBlock(nn.Module):
    r"""
    Additive attention gate (AG) that re-weights a skip-connection feature map.

    Both inputs are projected to `F_inter` channels with a 1x1 convolution and
    BatchNorm, summed and passed through a ReLU; a further 1x1 convolution,
    BatchNorm and sigmoid turn the result into a single-channel attention map
    that is multiplied onto `x`.

    Args:
        F_g (int) -- number of channels of the gating input `g`.
        F_x (int) -- number of channels of the gated input `x`.
        F_inter (int) -- number of channels of the intermediate (summed) representation.

    Note: `g` and `x` are summed element-wise, so they must have matching spatial sizes.
    """

    def __init__(self, F_g, F_x, F_inter):
        super(AdditiveAttentionBlock, self).__init__()

        def project(n_in, n_out):
            # 1x1 convolution followed by BatchNorm.
            return [nn.Conv2d(n_in, n_out, kernel_size=1, stride=1, padding=0, bias=True),
                    nn.BatchNorm2d(n_out)]

        # Projection of the gating signal `g`.
        self.W_g = nn.Sequential(*project(F_g, F_inter))
        # Projection of the feature map `x` that is being gated.
        self.W_x = nn.Sequential(*project(F_x, F_inter))
        # Collapse the fused features into one attention coefficient per pixel.
        self.psi = nn.Sequential(*project(F_inter, 1), nn.Sigmoid())

        self.relu = nn.ReLU(inplace=True)

    def forward(self, g, x):
        """Return `x` scaled by the attention map computed from `g` and `x`."""
        fused = self.relu(self.W_g(g) + self.W_x(x))
        attention = self.psi(fused)
        return x * attention


class Unet(nn.Module):
    def __init__(self, n_classes, in_channels, filter_config=None, use_skipAtt=False, dropout_rate=0):
        """
        Six-level U-Net for semantic segmentation of multispectral satellite images,
        optionally gating every skip connection with an additive attention block.

        Args:
            n_classes (int): Number of output classes.
            in_channels (int): Number of input channels.
            filter_config (tuple, optional): Six channel widths, one per encoder level.
                        Defaults to (64, 128, 256, 512, 1024, 2048) when falsy.
            use_skipAtt (bool, optional): If True, each encoder feature map is passed through
                        an attention gate before it is concatenated in the decoder.
                        Default is False.
            dropout_rate (float, optional): Dropout rate handed to every ConvBlock.
                        Default is 0.

        """
        super(Unet, self).__init__()

        self.in_channels = in_channels
        self.use_skipAtt = use_skipAtt

        if not filter_config:
            filter_config = (64, 128, 256, 512, 1024, 2048)

        assert len(filter_config) == 6
        f0, f1, f2, f3, f4, f5 = filter_config

        # Settings shared by every ConvBlock in the network.
        block_opts = dict(num_conv_layers=2, drop_rate=dropout_rate)

        # Contraction path. The sizes in the comments assume the default
        # filter_config and a 224x224 input.
        self.encoder_1 = ConvBlock(self.in_channels, f0, **block_opts)  # 64x224x224
        self.encoder_2 = ConvBlock(f0, f1, **block_opts)  # 128x112x112
        self.encoder_3 = ConvBlock(f1, f2, **block_opts)  # 256x56x56
        self.encoder_4 = ConvBlock(f2, f3, **block_opts)  # 512x28x28
        self.encoder_5 = ConvBlock(f3, f4, **block_opts)  # 1024x14x14
        self.encoder_6 = ConvBlock(f4, f5, **block_opts)  # 2048x7x7
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Expansion path: every stage upsamples, then a ConvBlock reduces the
        # concatenation of skip and upsampled features back to the stage width.
        self.decoder_1 = UpconvBlock(f5, f4, upmode="deconv_2")  # 1024x14x14
        self.conv1 = ConvBlock(f4 * 2, f4, **block_opts)

        self.decoder_2 = UpconvBlock(f4, f3, upmode="deconv_2")  # 512x28x28
        self.conv2 = ConvBlock(f3 * 2, f3, **block_opts)

        self.decoder_3 = UpconvBlock(f3, f2, upmode="deconv_2")  # 256x56x56
        self.conv3 = ConvBlock(f2 * 2, f2, **block_opts)

        self.decoder_4 = UpconvBlock(f2, f1, upmode="deconv_2")  # 128x112x112
        self.conv4 = ConvBlock(f1 * 2, f1, **block_opts)

        self.decoder_5 = UpconvBlock(f1, f0, upmode="deconv_2")  # 64x224x224
        self.conv5 = ConvBlock(f0 * 2, f0, **block_opts)

        # Attention gates exist only when requested; each one works at its stage's
        # width and uses the next smaller width as its intermediate size.
        if self.use_skipAtt:
            self.Att1 = AdditiveAttentionBlock(F_g=f4, F_x=f4, F_inter=f3)
            self.Att2 = AdditiveAttentionBlock(F_g=f3, F_x=f3, F_inter=f2)
            self.Att3 = AdditiveAttentionBlock(F_g=f2, F_x=f2, F_inter=f1)
            self.Att4 = AdditiveAttentionBlock(F_g=f1, F_x=f1, F_inter=f0)
            self.Att5 = AdditiveAttentionBlock(F_g=f0, F_x=f0, F_inter=int(f0 / 2))

        self.classifier = nn.Conv2d(f0, n_classes, kernel_size=1, stride=1, padding=0)  # classNumx224x224

    def _decode_stage(self, features, skip, upconv, conv, att_name):
        """Upsample `features`, merge them with `skip` (attention-gated if enabled) and convolve."""
        upsampled = upconv(features)
        if self.use_skipAtt:
            # Looked up by name because the gates are only created when attention is enabled.
            skip = getattr(self, att_name)(g=upsampled, x=skip)
        # Skip features first, upsampled decoder features second.
        merged = torch.cat((skip, upsampled), dim=1)
        return conv(merged)

    def forward(self, inputs):
        """
        Forward pass of the UNet model.

        Args:
            inputs (torch.Tensor): Input tensor of shape (batch_size, in_channels, height, width).

        Returns:
            torch.Tensor: Output tensor of shape (batch_size, n_classes, height, width).

        """
        # Encoder; the shapes below assume the default filters and a 224x224 input.
        e1 = self.encoder_1(inputs)  # batch size x 64 x 224 x 224
        e2 = self.encoder_2(self.pool(e1))  # batch size x 128 x 112 x 112
        e3 = self.encoder_3(self.pool(e2))  # batch size x 256 x 56 x 56
        e4 = self.encoder_4(self.pool(e3))  # batch size x 512 x 28 x 28
        e5 = self.encoder_5(self.pool(e4))  # batch size x 1024 x 14 x 14
        e6 = self.encoder_6(self.pool(e5))  # batch size x 2048 x 7 x 7

        # Decoder: walk back up, pairing each stage with its encoder feature map.
        x = self._decode_stage(e6, e5, self.decoder_1, self.conv1, "Att1")  # batch size x 1024 x 14 x 14
        x = self._decode_stage(x, e4, self.decoder_2, self.conv2, "Att2")  # batch size x 512 x 28 x 28
        x = self._decode_stage(x, e3, self.decoder_3, self.conv3, "Att3")  # batch size x 256 x 56 x 56
        x = self._decode_stage(x, e2, self.decoder_4, self.conv4, "Att4")  # batch size x 128 x 112 x 112
        x = self._decode_stage(x, e1, self.decoder_5, self.conv5, "Att5")  # batch size x 64 x 224 x 224

        return self.classifier(x)  # batch size x classNum x 224 x 224


## Training

In [9]:
def train_model(
        model, 
        criterion, 
        optimizer, 
        dataloaders, 
        num_epochs, 
        device = 'cpu',
        num_classes = 14,
        checkpoint_path = './checkpoints/model_unet_best.pth',
):
    """Train `model` and validate it after every epoch, checkpointing the best weights.

    Each epoch runs a 'train' phase (with gradient updates) followed by a 'val'
    phase (no gradients). For both phases the sample-weighted mean loss and the
    multiclass Jaccard index over the whole split are recorded and printed.
    Whenever the validation Jaccard index improves, the model's state dict is
    saved to `checkpoint_path`.

    Args:
        model: Network mapping a float input batch to per-class logits.
        criterion: Loss called as `criterion(outputs, targets.long())`.
        optimizer: Optimizer over the parameters of `model`.
        dataloaders: Dict with the keys 'train' and 'val', each a DataLoader
            yielding `(inputs, targets)` batches.
        num_epochs: Number of epochs to run.
        device: Device the model and the batches are moved to.
        num_classes: Number of classes for the Jaccard index (default 14).
        checkpoint_path: File the best state dict is written to; its parent
            directory is created if it does not exist.

    Returns:
        Tuple `(losses, jaccard_indices)`: two dicts with the keys 'train' and
        'val' holding one entry per epoch. Losses are floats; Jaccard entries
        are the values returned by the torchmetrics metric.
    """
    t_start = time.time()
    model.to(device)

    best_jac = 0.0
    losses = {'train': [], 'val': []}
    jaccard_indices = {'train': [], 'val': []}
    # The metric accumulates a fixed-size state on the CPU, so the predicted and
    # target masks of a whole epoch never have to be kept in memory, and no
    # particular chip size is assumed.
    jaccard = torchmetrics.JaccardIndex(task="multiclass", num_classes=num_classes)

    for epoch in range(num_epochs):
        t1 = time.time()

        for phase in ['train', 'val']:
            is_train = phase == 'train'
            # Switches dropout / batch norm between training and evaluation behaviour.
            model.train(is_train)

            running_loss = 0.0
            jaccard.reset()  # start the epoch-level metric from a clean state

            # Iterate over data
            for inputs, targets in dataloaders[phase]:
                inputs = inputs.to(device)
                targets = targets.to(device)

                if is_train:
                    optimizer.zero_grad()
                with torch.set_grad_enabled(is_train):
                    outputs = model(inputs.float())
                    loss = criterion(outputs, targets.long())
                    if is_train:
                        loss.backward()
                        optimizer.step()

                # Hard class predictions for the metric: argmax over the class dimension.
                preds = outputs.detach().argmax(dim=1).cpu()
                jaccard.update(preds, targets.detach().long().cpu())
                # Weight by batch size so a smaller last batch does not skew the mean.
                running_loss += loss.item() * inputs.size(0)

            epoch_loss = running_loss / len(dataloaders[phase].dataset)
            epoch_jac = jaccard.compute()

            losses[phase].append(epoch_loss)
            jaccard_indices[phase].append(epoch_jac)

            if is_train:
                print(f'Epoch [{epoch+1:4d}/{num_epochs:4d}]   Train Loss: {epoch_loss:.3e}, Jaccard: {epoch_jac:.4f}', end='')
            else:
                print(f'   Val Loss: {epoch_loss:.3e}, Jaccard: {epoch_jac:.4f}   Time: {(time.time()-t1):.0f}s')

            # save best model
            if phase == 'val' and epoch_jac > best_jac:
                best_jac = epoch_jac
                best_model_weights = copy.deepcopy(model.state_dict())
                checkpoint_dir = os.path.dirname(checkpoint_path)
                if checkpoint_dir:
                    # torch.save does not create missing directories.
                    os.makedirs(checkpoint_dir, exist_ok=True)
                torch.save(best_model_weights, checkpoint_path)

    time_elapsed = time.time() - t_start
    print(f'Training completed in {time_elapsed // 60:.0f} min {time_elapsed % 60:.0f} s')
    print(f'Best Validation Jaccard: {best_jac:.4f}')

    return losses, jaccard_indices



# Fresh baseline U-Net trained with cross-entropy loss and Adam (with a small weight decay).
model = UNet(in_channels=18, out_channels=n_classes)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)
# train_model looks the loaders up by phase name.
dataloaders = {
    "train": loader_train, 
    "val": loader_val
}
num_epochs = 3

# Returns the per-epoch loss and Jaccard histories for both phases.
losses, jaccard_indices = train_model(
    model, 
    criterion, 
    optimizer, 
    dataloaders,
    num_epochs,
    device
)

KeyboardInterrupt: 

### Evaluation

In [16]:
from utils import do_accuracy_evaluation, Evaluator
import torch.nn.functional as F


# Rebuild the baseline U-Net and restore the best checkpoint written during training.
# map_location lets a checkpoint saved on one device be loaded on another.
model = UNet(in_channels=18, out_channels=n_classes)
model.load_state_dict(torch.load('./checkpoints/model_unet_best.pth', map_location=device))
model.to(device)

# Class index (as a string) -> human-readable class name.
class_mapping = {
    "0": "Unknown",
    "1": "Natural Vegetation",
    "2": "Forest",
    "3": "Corn",
    "4": "Soybeans",
    "5": "Wetlands",
    "6": "Developed/Barren",
    "7": "Open Water",
    "8": "Winter Wheat",
    "9": "Alfalfa",
    "10": "Fallow/Idle Cropland",
    "11": "Cotton",
    "12": "Sorghum",
    "13": "Other"
}

# do_accuracy_evaluation(model, loader_val, n_classes, class_mapping, device)

evaluator = Evaluator(n_classes)

# Accumulate the predictions for the whole validation split.
model.eval()
with torch.no_grad():
    for images, labels in loader_val:
        images = images.to(device)
        labels = labels.to(device)

        # Cast to float, as in training, so integer-typed rasters are accepted too.
        outputs = model(images.float())
        if torch.isnan(outputs).any():
            print("NaN value found in model outputs!")
        probabilities = F.softmax(outputs, 1)
        preds = probabilities.argmax(dim=1)

        # add batch to evaluator
        evaluator.add_batch(labels.cpu().numpy(),
                            preds.cpu().numpy())

# calculate evaluation metrics once, after every batch has been added
overall_accuracy = evaluator.overall_accuracy()
classwise_overal_accuracy = evaluator.classwise_overal_accuracy()
# nanmean: skip NaN entries in the per-class arrays
# (presumably classes with no samples — confirm in utils.Evaluator).
mean_accuracy = np.nanmean(classwise_overal_accuracy)
IoU = evaluator.intersection_over_union()
mean_IoU = np.nanmean(IoU)
precision = evaluator.precision()
mean_precision = np.nanmean(precision)
recall = evaluator.recall()
mean_recall = np.nanmean(recall)
f1_score = evaluator.f1_score()
mean_f1_score = np.nanmean(f1_score)

metrics = {
    "Overall Accuracy": overall_accuracy,
    "Mean Accuracy": mean_accuracy,
    "Mean IoU": mean_IoU,
    "mean Precision": mean_precision,
    "mean Recall": mean_recall,
    "Mean F1 Score": mean_f1_score
}

# print confusion matrix
# evaluator.plot_confusion_matrix()

  acc = np.diag(self.confusion_matrix) / self.confusion_matrix.sum(axis=1)
  tp / (tp + fp + fn),
  tp / (tp + fp),
  tp / (tp + fn),
  2 * (precision * recall) / (precision + recall),


In [17]:
# Display the metrics dictionary built by the evaluation cell.
metrics

{'Overall Accuracy': 0.5262502967505095,
 'Mean Accuracy': 0.47177132954687007,
 'Mean IoU': 0.3000807281347857,
 'mean Precision': 0.4651827487605651,
 'mean Recall': 0.4380733774363793,
 'Mean F1 Score': 0.4388545105428932}

## Inference