# Training a 3D Convolutional Neural Network for hyperspectral data classification

In this notebook, you will train and apply a three-dimensional convolutional neural network for classification of hyperspectral data. The notebook supports two datasets: our imagery from Luční Hora, Krkonoše mountains, Czechia, and the Pavia city centre benchmark, which the code cells below load by default.

Pavia city centre is a common benchmark for hyperspectral data classification and can be obtained from http://www.ehu.eus/ccwintco/index.php/Hyperspectral_Remote_Sensing_Scenes#Pavia_Centre_and_University

Our dataset from Luční Hora is currently not publicly available, but we are working on providing it in the future.

First, we need to import external libraries:

- __torch, torch.nn, torch.optim, torchnet__ - PyTorch-related libraries for deep learning
- __numpy__ - Arrays to hold our data
- __matplotlib.pyplot__ - Draw images
- __sklearn.model_selection__ - Cross-validation implemented in scikit-learn
- __time.perf_counter__ - Track how long individual functions take to run
- __os.path__ - Path manipulation
- __tqdm__ - Show progress bars during training

- __image_preprocessing__ - Our library holding functions for image tiling, preprocessing, etc.
- __inference_utils__ - Our library for correctly exporting classified images
- __visualisation_utils__ - Our library for visualising the data

Two external libraries are not imported directly in this notebook, but are used by functions in _image_preprocessing_ and _inference_utils_:

- __gdal__ - Manipulates spatial data
- __scipy.io__ - Reads .mat files

In [None]:
import torch
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.optim as optim
from time import perf_counter
from os.path import join

from sklearn.metrics import precision_recall_fscore_support, accuracy_score, jaccard_score
import torchnet as tnt

from sklearn.model_selection import KFold, StratifiedKFold
from tqdm import notebook as tqdm

import image_preprocessing
import inference_utils
import visualisation_utils

# GLOBAL SETTINGS
matplotlib.rcParams['figure.figsize'] = [5, 5]  
np.set_printoptions(precision=2, suppress=True)  # Array print precision

Please fill in the correct paths to your training and reference rasters:

In [None]:
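# PATHS TO TRAINING DATA
# Luční Hora (GeoTIFF, read with image_preprocessing.read_gdal):
# trainingdata_path = '../data/LH_202008_54bands_9cm.tif'
# referencedata_path = '../data/LH_202008_reference.tif'

# Pavia city centre (.mat files, read with image_preprocessing.read_pavia_centre):
trainingdata_path = '../../data/Pavia_centre/Pavia.mat'
referencedata_path = '../../data/Pavia_centre/Pavia_gt.mat'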

## 1. Loading and preprocessing training data

### 1.1. Data loading into NumPy
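Let's start by reading an image into a NumPy array. For GeoTIFF rasters this is done in the background using GDAL; the Pavia scene is stored in .mat files and is read with _read_pavia_centre_.

The result of our function is a dictionary named _loaded_raster_, which contains two NumPy arrays under the keys _imagery_ and _reference_. The Luční Hora dataset has 1847 by 1563 pixels with 54 spectral bands. With the Pavia settings used below (_out_shape=(1088, 1088, 102)_), the imagery is requested as 1088 by 1088 pixels with 102 spectral bands. The raster containing our reference data has the same height and width as the imagery.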

In [None]:
# loaded_raster = image_preprocessing.read_gdal(trainingdata_path, referencedata_path)
loaded_raster = image_preprocessing.read_pavia_centre(trainingdata_path, referencedata_path, out_shape=(1088, 1088, 102))

print(f'Loaded imagery shape {loaded_raster["imagery"].shape}')
print(f'Loaded reference shape {loaded_raster["reference"].shape}')

In [None]:
visualisation_utils.show_img_ref(loaded_raster['imagery'][:, :, [25, 15, 5]], loaded_raster['reference'])

### 1.2. Image tiling
We have our data loaded into a NumPy array. The next step is to divide the image into individual tiles, which will be the input for our neural network.

Because the network convolves over spatial neighbourhoods, we need to divide the hyperspectral image into tiles of a given shape. Commonly used tile sizes are powers of two, for example 2^6 = 64. The tile shape is set through the variable _tile_shape_, here (64, 64).

_overlap_ sets how many pixels neighbouring tiles share, and _offset_ shifts the starting point of the tiling.

This process creates a stack of 64 by 64 pixel tiles with the same number of spectral bands as earlier. The resulting array shapes are printed below.

In [None]:
tile_shape = (64, 64)
overlap = 32
offset = (0, 0)

dataset_tiles = image_preprocessing.tile_training(loaded_raster, tile_shape, overlap, offset)
print(f'Tiled imagery shape {dataset_tiles["imagery"].shape}')
print(f'Tiled reference shape {dataset_tiles["reference"].shape}')
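
The effect of _tile_shape_ and _overlap_ can be pictured with a small sketch. It is not the code of `image_preprocessing.tile_training`: the helper name `tile_origins`, the stride rule (tile size minus overlap) and the choice to keep only tiles that fit fully inside the image are assumptions made for illustration.

```python
def tile_origins(length, tile, overlap):
    """Return the start indices of tiles along one image axis.

    Hypothetical helper: assumes neighbouring tiles share `overlap` pixels,
    so the stride is `tile - overlap`, and keeps only tiles that fit fully
    inside the image.
    """
    stride = tile - overlap
    if stride <= 0:
        raise ValueError("overlap must be smaller than the tile size")
    return list(range(0, length - tile + 1, stride))


# Settings taken from the notebook: 1088 px image side, 64 px tiles, 32 px overlap
rows = tile_origins(1088, 64, 32)
cols = tile_origins(1088, 64, 32)
print(f"{len(rows)} x {len(cols)} tile positions, first row origins: {rows[:3]}")
```

Under these assumptions every pixel away from the image border is covered by several tiles. The real function may treat the image edges differently, so the tile count it prints can differ from this sketch.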

### 1.3. Tile filtration
However, most of the created tiles do not contain training data. We therefore filter them and keep only the tiles with a field-collected reference.

For the Luční Hora imagery, this significantly reduces the size of our dataset, from 2 886 861 to 49 842 pixels: training data is available for less than 2 percent of the dataset.

In [None]:
filtered_tiles = image_preprocessing.filter_useful_tiles(dataset_tiles, nodata_vals=[65535], is_training=True)
print(f'Filtered imagery shape {filtered_tiles["imagery"].shape}')
print(f'Filtered reference shape {filtered_tiles["reference"].shape}')
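
As a rough illustration of the filtering idea, the sketch below keeps only tiles whose reference contains at least one labelled pixel. It is a stand-in written for this explanation, not `image_preprocessing.filter_useful_tiles`; the function name `keep_labelled_tiles` and the array layout (tiles first, bands last) are assumptions.

```python
import numpy as np


def keep_labelled_tiles(imagery, reference, nodata=65535):
    """Keep tiles whose reference holds at least one labelled pixel.

    Assumed layout: imagery (n_tiles, H, W, bands), reference (n_tiles, H, W).
    """
    has_labels = (reference != nodata).any(axis=(1, 2))
    return imagery[has_labels], reference[has_labels]


# Toy data: four tiles, only tiles 1 and 3 carry reference labels
rng = np.random.default_rng(0)
imagery = rng.random((4, 8, 8, 5))
reference = np.full((4, 8, 8), 65535, dtype=np.uint16)
reference[1, 2:4, 2:4] = 3
reference[3, 0, 0] = 7

img_kept, ref_kept = keep_labelled_tiles(imagery, reference)
print(img_kept.shape, ref_kept.shape)
```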

### 1.4. Data normalization
After filtering the tiles to only include training data, we can move on to the final part of the preprocessing: data normalization. In machine learning, it is common to normalize all data before classification.

The resulting dictionary _preprocessed_tiles_ is subsequently transformed from numpy arrays into pytorch tensors for the training.

In [None]:
preprocessed_tiles, unique, counts = image_preprocessing.normalize_tiles_3d(filtered_tiles, nodata_vals=[65535], is_training=True)
print(f'Preprocessed imagery shape {preprocessed_tiles["imagery"].shape}')
print(f'Preprocessed reference shape {preprocessed_tiles["reference"].shape}')
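
The sketch below shows one possible form of this step: scaling every band to the range 0 to 1 and reordering the axes into the (tiles, channel, bands, height, width) layout that a 3D convolution expects. Min-max scaling, the function name `normalise_for_conv3d` and the input layout are assumptions; `normalize_tiles_3d` may normalise differently and also handles nodata values, which this sketch ignores.

```python
import numpy as np


def normalise_for_conv3d(tiles):
    """Scale each band to [0, 1] and reorder axes for a 3D CNN.

    Assumed input layout: (n_tiles, H, W, bands).
    Output layout: (n_tiles, 1, bands, H, W).
    """
    tiles = tiles.astype(np.float32)
    band_min = tiles.min(axis=(0, 1, 2), keepdims=True)
    band_max = tiles.max(axis=(0, 1, 2), keepdims=True)
    # Guard against division by zero for constant bands
    scaled = (tiles - band_min) / np.maximum(band_max - band_min, 1e-12)
    # Move bands in front of the spatial axes, then add a channel axis
    return np.moveaxis(scaled, -1, 1)[:, np.newaxis]


toy_tiles = np.random.default_rng(1).integers(0, 8000, size=(3, 8, 8, 5))
normalised = normalise_for_conv3d(toy_tiles)
print(normalised.shape, normalised.min(), normalised.max())
```

The single channel axis is why _n_channel_ is set to 1 later in the notebook: the spectral bands form the depth axis of the 3D convolution rather than its input channels.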

In [None]:
dataset = tnt.dataset.TensorDataset([preprocessed_tiles['imagery'], preprocessed_tiles['reference']])
print(dataset)

print(f'Class labels: \n{unique}\n')
print(f'Number of pixels in a class: \n{counts}')
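
Class counts like the ones printed above are the usual starting point for class weights in the loss function. The sketch below uses inverse-frequency weighting with a zero weight for the background class. This is only one common recipe: the weights used later in this notebook were not necessarily obtained this way, and the function name and the example counts are made up.

```python
import numpy as np


def inverse_frequency_weights(counts, ignore_index=0):
    """Weights proportional to 1 / class count, normalised to sum to 1.

    The class at `ignore_index` (assumed to be background) gets weight 0.
    """
    counts = np.asarray(counts, dtype=np.float64)
    weights = np.zeros_like(counts)
    valid = counts > 0
    valid[ignore_index] = False
    weights[valid] = 1.0 / counts[valid]
    return weights / weights.sum()


example_counts = [5000, 100, 400, 500]  # made-up pixel counts per class
example_weights = inverse_frequency_weights(example_counts)
print(example_weights)
```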

## 2. Neural network definition
After preprocessing our data, we can move on to defining our neural network and the functions for training. You can either train your own neural network or use the one we already trained for you (_SpectroSpatialNet_pretrained.pt_). If you are using the pretrained network, please run only the following code snippet (2.1.) and skip to section 3.

### 2.1. Network structure
Our network is named SpectroSpatialNet, and its structure is defined in the SpectroSpatialNet class, which has three methods:
- **__init__** - This method runs automatically when creating an instance of the class. It defines the individual layers of the network (3D convolutions in the encoder, 2D convolutions and transposed convolutions in the decoder, max pooling and a dropout layer).
- **init_weights** - Randomly initialises the convolution weights from a normal distribution (Kaiming initialisation).
- **forward** - Defines how data should flow through the network during a forward pass (network structure definition). The PyTorch library automatically creates a method for backward passes based on this structure.

In [None]:
class SpectroSpatialNet(nn.Module):
    """3D Spectral-Spatial CNN for semantic segmentation."""

    def __init__(self, args):
        """
        Initialize the SpectroSpatial model.

        n_channels, int, number of input channel
        size_e, int list, size of the feature maps of convs for the encoder
        size_d, int list, size of the feature maps of convs for the decoder
        n_class = int,  the number of classes
        """
        # necessary for all classes extending the module class
        super(SpectroSpatialNet, self).__init__()

        self.maxpool = nn.MaxPool3d(2, 2, return_indices=False)
        self.dropout = nn.Dropout3d(p=0.5, inplace=True)

        self.n_channels = args['n_channel']
        self.size_e = args['size_e']
        self.size_d = args['size_d']
        self.n_class = args['n_class']

        # Encoder layer definitions
        def c_en_3d(in_ch, out_ch, k_size=3, pad=1, pad_mode='zeros',
                    bias=False):
            """Create default conv layer for the encoder."""
            return nn.Sequential(nn.Conv3d(in_ch, out_ch, kernel_size=k_size,
                                           padding=pad, padding_mode=pad_mode,
                                           bias=bias),
                                 nn.BatchNorm3d(out_ch), nn.ReLU())

        def c_de_2d(in_ch, out_ch, k_size=3, pad=1, pad_mode='zeros',
                    bias=False):
            """Create default conv layer for the decoder."""
            return nn.Sequential(nn.Conv2d(in_ch, out_ch, kernel_size=k_size,
                                           padding=pad, padding_mode=pad_mode,
                                           bias=bias),
                                 nn.BatchNorm2d(out_ch), nn.ReLU())

        self.c1 = c_en_3d(self.n_channels, self.size_e[0])
        self.c2 = c_en_3d(self.size_e[0], self.size_e[1])
        self.c3 = c_en_3d(self.size_e[1], self.size_e[2])
        self.c4 = c_en_3d(self.size_e[2], self.size_e[3])
        self.c5 = c_en_3d(self.size_e[3], self.size_e[4])
        self.c6 = c_en_3d(self.size_e[4], self.size_e[5])

        self.trans1 = nn.ConvTranspose2d(self.size_d[0], self.size_d[1],
                                      kernel_size=2, stride=2)
        self.c7 = c_de_2d(self.size_d[1], self.size_d[2])
        self.c8 = c_de_2d(self.size_d[2], self.size_d[3])
        self.trans2 = nn.ConvTranspose2d(self.size_d[3], self.size_d[4],
                                      kernel_size=2, stride=2)
        self.c9 = c_de_2d(self.size_d[4], self.size_d[5])
        self.c10 = c_de_2d(self.size_d[5], self.size_d[6])

        # Final classifying layer
        self.classifier = nn.Conv2d(self.size_d[6], self.n_class,
                                    1, padding=0)

        # Weight initialization
        self.c1[0].apply(self.init_weights)
        self.c2[0].apply(self.init_weights)
        self.c3[0].apply(self.init_weights)
        self.c4[0].apply(self.init_weights)
        self.c5[0].apply(self.init_weights)
        self.c6[0].apply(self.init_weights)

        self.c7[0].apply(self.init_weights)
        self.c8[0].apply(self.init_weights)

        self.c9[0].apply(self.init_weights)
        self.c10[0].apply(self.init_weights)
        self.classifier.apply(self.init_weights)

        # Put the model on GPU memory
        if torch.cuda.is_available():
            self.cuda()
            torch.backends.cudnn.enabled = True
            torch.backends.cudnn.benchmark = True

    def init_weights(self, layer):  # gaussian init for the conv layers
        """Initialise layer weights."""
        nn.init.kaiming_normal_(
            layer.weight, mode='fan_out', nonlinearity='relu')

    def forward(self, input_data):
        """Define model structure."""
        # Encoder
        # level 1
        x1 = self.c2(self.c1(input_data))
        x2 = self.maxpool(x1)
        # level 2
        x3 = self.c4(self.c3(x2))
        x4 = self.maxpool(x3)
        # level 3
        x5 = self.c6(self.c5(x4))
        # Decoder
        # Level 3
        y5 = torch.flatten(x5, start_dim=1, end_dim=2)
        # level 2
        y4 = self.trans1(y5)
        y3 = self.c8(self.c7(y4))
        # level 1
        y2 = self.trans2(y3)
        y1 = self.c10(self.c9(y2))
        # Output
        out = self.classifier(self.dropout(y1))
        return out

### 2.2. Functions for network training
Training the network is handled by four functions:
- __augment__ - Augments the training data by adding random noise and applying a random rotation.
- __train__ - Trains the network for one epoch. This function contains a for loop, which loads the training data in individual batches. Each batch of training data goes through the network, after which we compute the loss function (cross-entropy). The last step of training is an optimiser step, which changes the network's weights.
- __eval__ - Evaluates the results on a validation set. This should be done periodically during training to check for overfitting.
- __train_full__ - Performs the full training loop.

__augment__ takes in the training tile and the corresponding reference labels. It then adds a random value (taken from a normal distribution) at each wavelength and rotates the tile together with its labels by a random multiple of 90 degrees, thus slightly modifying the training data. Change _tile_number_ to see the augmentation effect for different tiles.

In [None]:
class SpectroSpatialNet(nn.Module):
    """3D Spectral-Spatial CNN for semantic segmentation."""

    def __init__(self, args):
        """
        Initialize the SpectroSpatial model.

        n_channels, int, number of input channel
        size_e, int list, size of the feature maps of convs for the encoder
        size_d, int list, size of the feature maps of convs for the decoder
        n_class = int,  the number of classes
        """
        # necessary for all classes extending the module class
        super(SpectroSpatialNet, self).__init__()

        self.maxpool = nn.MaxPool3d(2, 2, return_indices=False)
        self.dropout = nn.Dropout3d(p=0.5, inplace=True)

        self.n_channels = args['n_channel']
        self.size_e = args['size_e']
        self.size_d = args['size_d']
        self.n_class = args['n_class']

        # Encoder layer definitions
        def c_en_3d(in_ch, out_ch, k_size=3, pad=(0,1,1), pad_mode='zeros',
                    bias=False):
            """Create default conv layer for the encoder."""
            return nn.Sequential(nn.Conv3d(in_ch, out_ch, kernel_size=k_size,
                                           padding=pad, padding_mode=pad_mode,
                                           bias=bias),
                                 nn.BatchNorm3d(out_ch), nn.ReLU())

        def c_de_2d(in_ch, out_ch, k_size=3, pad=1, pad_mode='zeros',
                    bias=False):
            """Create default conv layer for the decoder."""
            return nn.Sequential(nn.Conv2d(in_ch, out_ch, kernel_size=k_size,
                                           padding=pad, padding_mode=pad_mode,
                                           bias=bias),
                                 nn.BatchNorm2d(out_ch), nn.ReLU())

        self.c1 = c_en_3d(self.n_channels, self.size_e[0])
        self.c2 = c_en_3d(self.size_e[0], self.size_e[1])
        self.c3 = c_en_3d(self.size_e[1], self.size_e[2])
        
        self.c4 = c_en_3d(self.size_e[2], self.size_e[3])
        self.c5 = c_en_3d(self.size_e[3], self.size_e[4])
        self.c6 = c_en_3d(self.size_e[4], self.size_e[5])
        
        self.c7 = c_en_3d(self.size_e[5], self.size_e[6])
        self.c8 = c_en_3d(self.size_e[6], self.size_e[7])
        self.c9 = c_en_3d(self.size_e[7], self.size_e[8])

        self.trans1 = nn.ConvTranspose2d(self.size_d[0], self.size_d[1],
                                      kernel_size=2, stride=2)
        self.c10 = c_de_2d(self.size_d[1], self.size_d[2])
        self.c11 = c_de_2d(self.size_d[2], self.size_d[3])
        self.c12 = c_de_2d(self.size_d[3], self.size_d[4])
        self.trans2 = nn.ConvTranspose2d(self.size_d[4], self.size_d[5],
                                      kernel_size=2, stride=2)
        
        self.c13 = c_de_2d(self.size_d[5], self.size_d[6])
        self.c14 = c_de_2d(self.size_d[6], self.size_d[7])
        self.c15 = c_de_2d(self.size_d[7], self.size_d[8])

        # Final classifying layer
        self.classifier = nn.Conv2d(self.size_d[8], self.n_class,
                                    1, padding=0)

        # Weight initialization
        self.c1[0].apply(self.init_weights)
        self.c2[0].apply(self.init_weights)
        self.c3[0].apply(self.init_weights)
        
        self.c4[0].apply(self.init_weights)
        self.c5[0].apply(self.init_weights)
        self.c6[0].apply(self.init_weights)
        
        self.c7[0].apply(self.init_weights)
        self.c8[0].apply(self.init_weights)
        self.c9[0].apply(self.init_weights)

        self.c10[0].apply(self.init_weights)
        self.c11[0].apply(self.init_weights)
        self.c12[0].apply(self.init_weights)
        
        self.c13[0].apply(self.init_weights)
        self.c14[0].apply(self.init_weights)
        self.c15[0].apply(self.init_weights)

        self.classifier.apply(self.init_weights)

        # Put the model on GPU memory
        if torch.cuda.is_available():
            self.cuda()
            torch.backends.cudnn.enabled = True
            torch.backends.cudnn.benchmark = True

    def init_weights(self, layer):  # gaussian init for the conv layers
        """Initialise layer weights."""
        nn.init.kaiming_normal_(
            layer.weight, mode='fan_out', nonlinearity='relu')

    def forward(self, input_data):
        """Define model structure."""
        # Encoder
        # 102-100-98-96 .. 48-46-44-42 .. 21-19-17-15
        # 54-52-50-48   .. 24-22-20-18 .. 9-7-5-3
        # level 1
        x1 = self.c3(self.c2(self.c1(input_data)))
        x2 = self.maxpool(x1)
        # level 2
        x3 = self.c6(self.c5(self.c4(x2)))
        x4 = self.maxpool(x3)
        # level 3
        x5 = self.c9(self.c8(self.c7(x4)))
        # Decoder
        # Level 3
        y5 = torch.flatten(x5, start_dim=1, end_dim=2)
        # level 2
        y4 = self.trans1(y5)
        y3 = self.c12(self.c11(self.c10(y4)))
        # level 1
        y2 = self.trans2(y3)
        y1 = self.c15(self.c14(self.c13(y2)))
        # Output
        out = self.classifier(self.dropout(y1))
        return out

In [None]:
def augment(obs, g_t):
    """Data augmentation: adds random noise and applies a random rotation."""
    sigma, clip = 0.01, 0.03

    # Random noise, one value per band and pixel, shared by all tiles in the batch.
    # The noise shape follows the input ([1, 1, bands, height, width]) instead of
    # relying on a hard-coded band count and the global tile_shape.
    rand = torch.clamp(torch.mul(sigma, torch.randn([1, 1, *obs.shape[2:]])), -clip, clip)
    obs = torch.add(obs, rand)

    # Random rotation by 0, 90, 180 or 270 degrees
    n_turn = np.random.randint(4)  # number of 90 degree turns, random int between 0 and 3
    obs = torch.rot90(obs, n_turn, dims=(3, 4))
    g_t = torch.rot90(g_t, n_turn, dims=(1, 2))

    return obs, g_t

In [None]:
tile_number = 46
visualisation_utils.show_augment_spectro_spatial(preprocessed_tiles, tile_number, augment)
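
The important property of the rotation is that imagery and labels are turned by the same amount, so every pixel keeps its label. The NumPy re-implementation below is a sketch written only to check that property on toy data; the name `augment_np` and the toy arrays are not part of the notebook.

```python
import numpy as np


def augment_np(obs, g_t, rng, sigma=0.01, clip=0.03):
    """NumPy sketch of the augmentation.

    obs: (N, 1, bands, H, W) imagery, g_t: (N, H, W) labels.
    """
    # Clipped Gaussian noise, shared by all tiles in the batch
    noise = np.clip(sigma * rng.standard_normal(obs.shape[1:]), -clip, clip)
    obs = obs + noise
    # Same number of quarter turns for imagery and labels
    n_turn = int(rng.integers(4))
    obs = np.rot90(obs, n_turn, axes=(3, 4))
    g_t = np.rot90(g_t, n_turn, axes=(1, 2))
    return obs, g_t


rng = np.random.default_rng(42)
labels = rng.integers(0, 5, size=(2, 8, 8))
# Toy imagery in which every band equals the label map, so alignment is easy to test
tiles = np.repeat(labels[:, None, None].astype(np.float64), 4, axis=2)
tiles_aug, labels_aug = augment_np(tiles, labels, rng, sigma=0.0)
print("still aligned:", np.array_equal(tiles_aug[:, 0, 0], labels_aug))
```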

__train__ trains the network for one epoch. This function contains a for loop, which loads the training data in individual batches. Each batch of training data goes through the network, after which we compute the loss function (cross-entropy). The last step of training is an optimiser step, which changes the network's weights.

__eval__ evaluates the results on a validation set. This should be done periodically during training to check for overfitting.

__train_full__ performs the full training loop.
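
To make the loss more concrete, here is a NumPy sketch of a class-weighted cross-entropy over a batch of tiles. It follows the weighted-mean convention (sum of weighted pixel losses divided by the sum of the weights of the pixels); the function name and the toy inputs are illustrative, and in the notebook itself the loss is computed by PyTorch.

```python
import numpy as np


def weighted_cross_entropy(logits, target, class_weights):
    """Class-weighted cross-entropy averaged over all pixels.

    logits: (N, C, H, W) raw network outputs
    target: (N, H, W) integer class labels
    class_weights: (C,) weight per class
    """
    # Numerically stable log-softmax over the class axis
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_prob = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    # Log-probability of the true class at every pixel
    picked = np.take_along_axis(log_prob, target[:, None], axis=1)[:, 0]
    pixel_weights = class_weights[target]
    return float(-(pixel_weights * picked).sum() / pixel_weights.sum())


toy_logits = np.zeros((1, 3, 2, 2))
toy_target = np.array([[[0, 1], [2, 1]]])
toy_weights = np.array([0.0, 1.0, 1.0])  # class 0 is ignored
print(weighted_cross_entropy(toy_logits, toy_target, toy_weights))
```

A weight of zero removes a class from the loss entirely, which is how a background or nodata class can be excluded from training.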

In [None]:
def train(model, optimizer, args):
    """train for one epoch"""
    model.train() #switch the model in training mode
  
    #the loader function will take care of the batching
    loader = torch.utils.data.DataLoader(dataset, batch_size=args['batch_size'], sampler=args['train_subsampler'])
    loader = tqdm.tqdm(loader, ncols=500)
  
    #will keep track of the loss
    loss_meter = tnt.meter.AverageValueMeter()

    for index, (tiles, gt) in enumerate(loader):
    
        optimizer.zero_grad() #put gradient to zero

        tiles, gt = augment(tiles, gt)

        pred = model(tiles.cuda()) #compute the prediction

        loss = nn.functional.cross_entropy(pred.cpu(),gt, weight=args['class_weights'])
        loss.backward() #compute gradients

        # clamp every gradient value to [-1, 1] (element-wise clamping,
        # not norm clipping) to limit the size of a single update
        for p in model.parameters():
            p.grad.data.clamp_(-1, 1)

        optimizer.step() #one optimiser (Adam) step
        loss_meter.add(loss.item())
        
    return loss_meter.value()[0]

def eval(model, sampler):
    """eval on test/validation set"""
  
    model.eval() #switch in eval mode
  
    loader = torch.utils.data.DataLoader(dataset, batch_size=1, sampler=sampler)
    loader = tqdm.tqdm(loader, ncols=500)
  
    loss_meter = tnt.meter.AverageValueMeter()

    with torch.no_grad():
        for index, (tiles, gt) in enumerate(loader):
            pred = model(tiles.cuda())
            loss = nn.functional.cross_entropy(pred.cpu(), gt)
            loss_meter.add(loss.item())

    return loss_meter.value()[0]


def train_full(args):
    """The full training loop"""

    #initialize the model
    model = SpectroSpatialNet(args)

    print(f'Total number of parameters: {sum([p.numel() for p in model.parameters()])}')
  
    #define the Adam optimizer
    optimizer = optim.Adam(model.parameters(), lr=args['lr'])
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=args['scheduler_milestones'],
                                               gamma=args['scheduler_gamma'])
  
    train_loss = np.empty(args['n_epoch'])
    test_epochs = []
    test_loss = []

    for i_epoch in range(args['n_epoch']):
        #train one epoch
        print(f'Epoch #{str(i_epoch+1)}')
        train_loss[i_epoch] = train(model, optimizer, args)
        scheduler.step()

        # Periodic testing on the validation set
        if (i_epoch == args['n_epoch'] - 1) or ((i_epoch + 1) % args['n_epoch_test'] == 0):
            print('Evaluation')
            loss_test = eval(model, args['test_subsampler'])
            test_epochs.append(i_epoch + 1)
            test_loss.append(loss_test)
            
    plt.figure(figsize=(10, 10))
    plt.subplot(1,1,1,ylim=(0,5), xlabel='Epoch #', ylabel='Loss')
    plt.plot([i+1 for i in range(args['n_epoch'])], train_loss, label='Training loss')
    plt.plot(test_epochs, test_loss, label='Validation loss')
    plt.legend()
    plt.show()
    print(train_loss)
    print(test_loss)
    args['loss_test'] = test_loss[-1]
    
    return model
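
The training loop limits every gradient element to the range [-1, 1] with `clamp_`, which is not the same as clipping the gradient norm. The NumPy sketch below contrasts the two on a single made-up gradient vector; the function names are illustrative. PyTorch offers `torch.nn.utils.clip_grad_value_` and `torch.nn.utils.clip_grad_norm_` for the two variants.

```python
import numpy as np


def clamp_values(grad, limit=1.0):
    """Element-wise clamping: each entry is cut to [-limit, limit]."""
    return np.clip(grad, -limit, limit)


def clip_norm(grad, max_norm=1.0):
    """Norm clipping: the whole vector is rescaled if its norm is too large."""
    norm = np.linalg.norm(grad)
    scale = min(1.0, max_norm / (norm + 1e-6))
    return grad * scale


g = np.array([3.0, -0.5, 0.2])  # made-up gradient
print("clamped:     ", clamp_values(g))
print("norm-clipped:", clip_norm(g))
```

Clamping changes the direction of the update when only some entries are large, while norm clipping keeps the direction and shortens the step.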

### 2.3. Hyperparameter definition
Training a network requires first setting several hyperparameters. Please feel free to play around with them and try different values for the number of training epochs, learning rate or batch size.

- __n_channel__ - number of channels, set to 1 for our task
- __n_class__ - number of classification classes
- __size_e__ - number of filters in each layer of the encoder
- __size_d__ - number of filters in each layer of the decoder; the first entry is the number of channels entering the decoder after the encoder output is flattened along the spectral axis
- __crossval_nfolds__ - number of folds for cross-validation
- __n_epoch_test__ - after how many training epochs we evaluate on the validation set
- __scheduler_milestones__ - after how many epochs we reduce the learning rate
- __scheduler_gamma__ - by what factor we reduce the learning rate
- __class_weights__ - training weights for individual classes, used to offset imbalanced class distribution

- __n_epoch__ - how many epochs are performed during training
- __lr__ - learning rate, which controls how much the network parameters can change in a single optimiser step
- __batch_size__ - how many tiles should be included in each gradient descent step
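
The first entry of _size_d_ has to match the number of channels that remain once the encoder output is flattened along the spectral axis. The sketch below reproduces that arithmetic, assuming the encoder variant that leaves the spectral axis unpadded (`pad=(0,1,1)`), three 3x3x3 convolutions per level and a `MaxPool3d(2, 2)` between levels. The helper names are made up for illustration.

```python
def spectral_depth(n_bands, convs_per_level=3, n_levels=3, kernel=3):
    """Length of the spectral axis at the end of the encoder."""
    depth = n_bands
    for level in range(n_levels):
        # Each unpadded convolution shortens the spectral axis by kernel - 1
        depth -= convs_per_level * (kernel - 1)
        # Max pooling halves the axis between levels (not after the last one)
        if level < n_levels - 1:
            depth //= 2
    return depth


def decoder_in_channels(n_bands, last_encoder_filters=32):
    """Channels after merging the filter and spectral axes with flatten."""
    return last_encoder_filters * spectral_depth(n_bands)


for bands in (102, 54):
    print(bands, "bands ->", spectral_depth(bands), "spectral steps,",
          decoder_in_channels(bands), "decoder input channels")
```

For the 102 Pavia bands this gives the value 480 used in _size_d_. For a dataset with a different number of bands, such as the 54-band Luční Hora imagery, the first entry of _size_d_ would need to change accordingly.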

In [None]:
args = { #Dict to store all model parameters
    'n_channel': 1,
    'n_class': len(unique),
    'size_e': [16,16,16,16,16,16,32,32,32],
    'size_d': [480,32,32,32,32,32,32,32,32,32],
    
    'crossval_nfolds': 3,
    'n_epoch_test': 2,          #periodicity of evaluation on test set
    'scheduler_milestones': [60,80,90],
    'scheduler_gamma': 0.3,
    'class_weights': torch.tensor([0.0, 0.2, 0.34, 0.033, 0.16, 0.14, 0.03, 0.014, 0.023, 0.06]),

    'n_epoch': 5,
    'lr': 1e-6,
    'batch_size': 4,
}
model_save_folder = '../../models/Pavia/3D'

print(f'''Number of models to be trained:
    {args['crossval_nfolds']}
Number of spectral channels:
    {args['n_channel']}
Initial learning rate:
    {args['lr']}
Batch size:
    {args['batch_size']}
Number of training epochs:
    {args['n_epoch']}''')
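
The learning-rate schedule implied by _scheduler_milestones_ and _scheduler_gamma_ can be previewed without training. The sketch below mirrors the behaviour of a multi-step schedule, assuming the scheduler is stepped once after every epoch as in `train_full`; the function name is illustrative.

```python
def multistep_lr(initial_lr, milestones, gamma, epoch):
    """Learning rate in effect during `epoch` (0-based).

    The rate is multiplied by `gamma` once for every milestone
    that has already been reached.
    """
    passed = sum(1 for m in milestones if epoch >= m)
    return initial_lr * gamma ** passed


# Values taken from the args dictionary above
for epoch in (0, 59, 60, 80, 90):
    print(epoch, multistep_lr(1e-6, [60, 80, 90], 0.3, epoch))
```

With _n_epoch_ set to 5, none of the milestones at 60, 80 and 90 epochs is reached, so the learning rate stays at its initial value unless the number of epochs is increased.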

### 2.4 Network training

In [None]:
## Training a 3D network
kfold = KFold(n_splits = args['crossval_nfolds'], shuffle=True)
trained_models = []
for fold, (train_ids, test_ids) in enumerate(kfold.split(dataset)):
    print(f'Training starts for model number {str(fold+1)}')
    
    a = perf_counter()
    args['train_subsampler'] = torch.utils.data.SubsetRandomSampler(train_ids)
    args['test_subsampler'] = torch.utils.data.SubsetRandomSampler(test_ids)
    
    trained_models.append((train_full(args), args['loss_test']))
    
    state_dict_path = join(model_save_folder, f'fold_{str(fold)}.pt')
    torch.save(trained_models[fold][0].state_dict(), state_dict_path)
    print(f'Model saved to: {state_dict_path}')
    print(f'Training finished in {str(perf_counter()-a)}s')
    print('\n\n')

print(f'Resulting loss for individual folds: \n{[i for _, i in trained_models]}')
print(f'Mean loss across all folds: \n{np.mean([i for _, i in trained_models])}')
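
What the cross-validation split provides can be shown with a few lines of NumPy. The sketch below imitates a shuffled k-fold split: every tile index lands in the validation part of exactly one fold. It is an illustration with a made-up function name, not scikit-learn's implementation.

```python
import numpy as np


def kfold_indices(n_samples, n_folds, seed=0):
    """Yield (train_ids, test_ids) pairs for a shuffled k-fold split."""
    rng = np.random.default_rng(seed)
    order = rng.permutation(n_samples)
    folds = np.array_split(order, n_folds)
    for i in range(n_folds):
        test_ids = folds[i]
        train_ids = np.concatenate([f for j, f in enumerate(folds) if j != i])
        yield train_ids, test_ids


splits = list(kfold_indices(10, 3))
for fold, (train_ids, test_ids) in enumerate(splits):
    print(f"fold {fold}: {len(train_ids)} training tiles, {len(test_ids)} validation tiles")
```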

## 3. Applying the network

### 3.1. Loading a trained model

In [None]:
# Parameters for model definition
args = {
    'n_class': 10,
    'n_channel': 1,
    'size_e': [16,16,16,16,16,16,32,32,32],
    'size_d': [480,32,32,32,32,32,32,32,32,32],
}
# Path to the state_dictionary
state_dict_path = '../../models/Pavia/3D/fold_0.pt'

In [None]:
model = SpectroSpatialNet(args)
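# map_location lets a model saved on a GPU be loaded on a CPU-only machine
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model.load_state_dict(torch.load(state_dict_path, map_location=device))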
model.eval()

### 3.2. Loading and preprocessing the data

In [None]:
# Luční Hora (GeoTIFF):
# source_path = '../../data/LH_202008_54bands_9cm.tif'
# Pavia city centre (.mat):
source_path = '../../data/Pavia_centre/Pavia.mat'

tile_shape = (128, 128)
overlap = 64
offset_topleft = (0, 0)

In [None]:
start = perf_counter()
#raster_orig = image_preprocessing.read_gdal_with_geoinfo(source_path, offset_topleft)
raster_orig = image_preprocessing.read_pavia_centre(source_path, out_shape=(1088, 1088, 102))

dataset_full_tiles = image_preprocessing.run_tiling_dims(raster_orig['imagery'], out_shape=tile_shape, 
                                                    out_overlap=overlap, offset=offset_topleft)
dataset_full = image_preprocessing.normalize_tiles_3d(dataset_full_tiles, nodata_vals=[0])
dataset = tnt.dataset.TensorDataset(dataset_full['imagery'])
end = perf_counter()
print(f'Loading the imagery took {end - start} seconds.')

print(dataset_full_tiles['imagery'].shape)
print(dataset_full_tiles['dimensions'])

### 3.3. Applying the CNN and exporting results
The following snippet applies the CNN and exports the resulting classified raster into output_path for further analysis (e.g. validation in GIS):

In [None]:
output_path = '../../results/test_result_3d.tif'

start = perf_counter()
arr_class = inference_utils.combine_tiles_2d(model, dataset, tile_shape, overlap, dataset_full_tiles['dimensions'])
print(np.unique(arr_class, return_counts=True))
inference_utils.export_result(output_path, arr_class, raster_orig['geoinfo'])
print(f'The processing took {perf_counter() - start} seconds.')

You can also visualise the result:

In [None]:
visualisation_utils.show_classified(raster_orig['imagery'][:, :, [25, 15, 5]],
                                    loaded_raster['reference'], arr_class)