# Training with Ensemble Neural Network Approach

### This notebook shows how to train an ensemble neural network built upon metric-learning and CNN models


In [None]:
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F

from torch.utils.data import Dataset, DataLoader
from torch import optim

In [None]:
# Prefer the GPU whenever PyTorch can see one; otherwise everything runs on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [None]:
padded_value = -1      # sentinel value marking padded rows in the raw time series
sequencelength = 45    # number of time steps `transform` draws per sample

bands = ['B1', 'B10', 'B11', 'B12', 'B2', 'B3', 'B4', 'B5', 'B6', 'B7', 'B8',
         'B8A', 'B9', 'QA10', 'QA20', 'QA60', 'doa']

selected_bands = ['B1', 'B10', 'B11', 'B12', 'B2', 'B3', 'B4', 'B5', 'B6', 'B7', 'B8', 'B8A', 'B9']

# Column positions of the selected bands within the raw `bands` layout.
selected_band_idxs = np.array([bands.index(b) for b in selected_bands])

# Seed the global RNGs. `transform` resamples time steps with np.random, and DataLoader
# shuffling draws from torch's RNG, so unseeded runs are not reproducible.
seed = 42
np.random.seed(seed)
torch.manual_seed(seed)

def transform(x):
    """Turn one raw time series into a fixed-length float tensor for the models.

    Steps:
      1. Drop every row whose first column equals ``padded_value`` (padding).
      2. Keep only the ``selected_band_idxs`` columns and multiply by 1e-4.
      3. Randomly draw exactly ``sequencelength`` rows with ``np.random.choice``.
         Sampling is with replacement only when fewer rows than that remain.
      4. Keep the drawn rows in their original row order.

    Returns a float32 tensor on ``device``.
    NOTE(review): assumes ``x`` is a 2-D numpy array (rows x raw bands) — confirm against the dataset.
    """
    # Padding is detected on the first column only.
    observed = x[x[:, 0] != padded_value, :]
    # Keep the selected bands and scale reflectances to 0-1.
    reflectances = observed[:, selected_band_idxs] * 1e-4

    # Sample with replacement only when there are fewer observations than sequencelength.
    n_obs = reflectances.shape[0]
    picks = np.random.choice(n_obs, sequencelength, replace=n_obs < sequencelength)
    picks = np.sort(picks)  # restore the original row order of the sampled steps

    return torch.from_numpy(reflectances[picks]).to(device=device, dtype=torch.float32)

def target_transform(y):
    """Map a raw label ``y`` to its integer class id, returned as a long tensor on ``device``.

    The lookup always goes through ``frh01.mapping``. ``frh01`` is the dataset created in the
    "Load the dataset" cell, so it must exist before this function is first called.
    NOTE(review): labels from every region (e.g. frh02) are encoded with frh01's mapping,
    which assumes all regions share one mapping — confirm.
    """
    class_id = frh01.mapping.loc[y].id
    return torch.tensor(class_id, device=device, dtype=torch.long)



### Load the dataset

In [None]:
import os

import dataset

# Root directory of the BreizhCrops data. Point the BREIZHCROPS_DATA environment variable at
# it instead of hard-coding a machine-specific absolute path in the notebook.
data_path = os.environ.get("BREIZHCROPS_DATA", "path_to_breizhcrops_dataset")


def load_region(region):
    """Build the BreizhDataset for one region (e.g. "frh01") with the notebook's transforms.

    The returned dataset has ``isClassification`` switched on.
    """
    region_data = dataset.BreizhDataset(region=region, root=data_path, transform=transform,
                                        target_transform=target_transform, padding_value=padded_value)
    region_data.isClassification = True
    return region_data


# load training data
frh01 = load_region("frh01")
frh02 = load_region("frh02")

In [None]:
from models import LSTM, ClassificationModel, CnnNet, EnsembleNet
from torch.nn import CrossEntropyLoss

hidden_dims = 128
num_layers = 3
num_classes = 13
input_dim = len(selected_bands)  # one input feature per selected band (13)
bidirectional = True
lstm_dropout = 0.2

# Every sub-model restores its weights from a checkpoint file before being combined.
# Pass the `bidirectional` setting through so the LSTM and the classifier size below stay in sync.
lstm = LSTM(input_dim=input_dim, hidden_dims=hidden_dims, num_classes=num_classes, num_layers=num_layers,
            dropout=lstm_dropout, bidirectional=bidirectional, use_layernorm=True)
lstm.load("lstm_model_path.pth")

# Classifier input width = hidden_dims per direction, per layer.
# Presumably this covers the concatenated LSTM hidden states — confirm in models.ClassificationModel.
num_directions = 2 if bidirectional else 1
classifier = ClassificationModel(hidden_dims * num_directions * num_layers, hidden_dims, num_classes)
classifier.load("ann_model_path.pth")

cnn = CnnNet()
cnn.load("cnn_model_path.pth")

ensemble = EnsembleNet(lstm, classifier, cnn)

### The cells below train the ensemble neural network

In [None]:
batch_size = 64
data = torch.utils.data.ConcatDataset([frh01, frh02])
trainingDataLoader = torch.utils.data.DataLoader(data, batch_size=batch_size, shuffle=True, num_workers=0)

# training configuration
epochs = 101
lrDecreaseStep = 5        # epochs without improvement before the learning rate is cut
earlyStoppingStep = 2     # number of learning-rate cuts after which training stops
lrDecayFactor = 0.1       # multiplier applied to the learning rate at each cut
checkpointEvery = 10      # save a checkpoint every N epochs

minLoss = float("inf")
lossNotDecreasedCounter = 0
lrDecreasedCounter = 0

ensemble.train()
if torch.cuda.is_available():
    ensemble.cuda()

# define loss function and the optimizer
lossFn = CrossEntropyLoss()
optimizer = optim.SGD(ensemble.parameters(), lr=0.001, momentum=0.5)


def save_checkpoint():
    """Write the trained weights to disk.

    ``cnn_model.pth`` is the CNN-only file this notebook has always produced.
    The optimizer updates every parameter of ``ensemble``, though, and the CNN file alone
    does not capture that. So the full ensemble is also saved, to ``ensemble_model.pth``.
    """
    torch.save(cnn.state_dict(), "cnn_model.pth")
    torch.save(ensemble.state_dict(), "ensemble_model.pth")


for epoch in range(epochs):
    epochLoss = 0.0
    numBatches = 0
    for batch_id, batch in enumerate(trainingDataLoader):
        # consider only one element, instead of pair like in siamese structure
        x1, _, labels = batch

        if torch.cuda.is_available():
            x1 = x1.cuda()
            labels = labels.cuda()

        # Call the module itself (not .forward) so registered hooks also run.
        out1 = ensemble(x1)
        loss = lossFn(out1, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # .item() yields a plain float. Summing the loss tensors instead would keep every
        # batch's autograd graph alive until the end of the epoch.
        epochLoss += loss.item()
        numBatches += 1

        if batch_id % 100 == 0:
            print("Iteration {}: loss {:.2f}".format(batch_id, loss.item()))

    if numBatches == 0:
        raise RuntimeError("trainingDataLoader produced no batches; check the dataset path")
    avgLoss = epochLoss / numBatches
    print("Epoch number {}\n Current loss {}\n".format(epoch, avgLoss))

    # Plateau tracking: count epochs whose average loss did not beat the best seen so far.
    if avgLoss < minLoss:
        lossNotDecreasedCounter = 0
        minLoss = avgLoss
    else:
        lossNotDecreasedCounter += 1

    # After lrDecreaseStep stagnant epochs, shrink the learning rate and restart tracking.
    if lossNotDecreasedCounter == lrDecreaseStep:
        lossNotDecreasedCounter = 0
        minLoss = float("inf")
        lrDecreasedCounter += 1
        print("Decrease learning rate...")

        for g in optimizer.param_groups:
            g['lr'] = g['lr'] * lrDecayFactor

    # Once the learning rate has been cut earlyStoppingStep times, save and actually stop.
    if lrDecreasedCounter == earlyStoppingStep:
        print("Early stopping...")
        print("Saving the model...")
        save_checkpoint()
        break

    if epoch % checkpointEvery == 0:
        print("Saving....")
        save_checkpoint()