# Training with Metric Learning Approach

### This notebook shows how to train the metric learning approach with the Siamese LSTM network and the Artificial Neural Network


In [None]:
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F

from torch.utils.data import Dataset, DataLoader
from torch import optim

In [None]:
# Use the GPU when PyTorch can see one, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [None]:
# Marker value of padded time steps (also handed to the dataset as its
# padding value) and the number of time steps every sample is resampled to.
padded_value = -1
sequencelength = 45

# Names of the columns of a raw sample; the position in this list is the
# column index used by transform().
bands = [
    'B1', 'B10', 'B11', 'B12',
    'B2', 'B3', 'B4', 'B5', 'B6', 'B7', 'B8', 'B8A', 'B9',
    'QA10', 'QA20', 'QA60', 'doa',
]

# Columns that are kept as model input.
selected_bands = [
    'B1', 'B10', 'B11', 'B12',
    'B2', 'B3', 'B4', 'B5', 'B6', 'B7', 'B8', 'B8A', 'B9',
]

# Column indices of the selected bands within `bands`.
selected_band_idxs = np.array(list(map(bands.index, selected_bands)))

def transform(x):
    """Turn one raw, padded time series into a fixed-length float tensor.

    Rows whose first column equals ``padded_value`` are dropped, only the
    columns in ``selected_band_idxs`` are kept and multiplied by 1e-4, and
    ``sequencelength`` time steps are drawn at random and put back into
    ascending index order. The result is a float32 tensor on ``device``.

    Args:
        x: 2-D NumPy array (time steps x columns) in the column order of ``bands``.

    Returns:
        Tensor with ``sequencelength`` rows and one column per selected band.
    """
    # remove padded time steps
    is_observed = x[:, 0] != padded_value
    observed = x[is_observed, :]

    # keep the selected bands; the factor rescales reflectances to 0-1
    reflectances = observed[:, selected_band_idxs] * 1e-4

    # sample with replacement only when fewer than `sequencelength` steps remain
    num_steps = reflectances.shape[0]
    sampled = np.random.choice(
        num_steps, sequencelength, replace=num_steps < sequencelength
    )
    sampled = np.sort(sampled)

    return torch.from_numpy(reflectances[sampled]).float().to(device)

def target_transform(y):
    """Convert a raw class label into its class id as a long tensor on ``device``.

    The id is read from the ``id`` column of ``frh01.mapping`` at row ``y``.
    """
    # NOTE(review): the lookup always goes through frh01's mapping, also for
    # samples coming from frh02 - presumably both regions share it; confirm.
    class_id = frh01.mapping.loc[y]["id"]
    return torch.tensor(class_id, dtype=torch.long, device=device)



### Load the dataset

In [None]:
import dataset

data_path = "path_to_breizhcrops_dataset"

# Both regions are loaded with the same root, transforms and padding value.
_breizh_kwargs = dict(
    root=data_path,
    transform=transform,
    target_transform=target_transform,
    padding_value=padded_value,
)

# load training data
frh01 = dataset.BreizhDataset(region="frh01", **_breizh_kwargs)
frh02 = dataset.BreizhDataset(region="frh02", **_breizh_kwargs)

# Link each region to the other one so that elements of the same classes
# can be selected from different regions.
frh01.setOtherDataset(frh02)
frh02.setOtherDataset(frh01)

In [None]:
from models import LSTM, ClassificationModel
from loss import ContrastiveLoss

# model configurations
hidden_dims = 128
num_layers = 3
num_classes = 13
input_dim = 13  # one input feature per entry of `selected_bands`
bidirectional = True

# create Siamese LSTM model with these configurations
# `bidirectional` is passed through (instead of a literal True) so that the
# LSTM and the classifier input width below cannot get out of sync.
lstm = LSTM(
    input_dim=input_dim,
    hidden_dims=hidden_dims,
    num_classes=num_classes,
    num_layers=num_layers,
    dropout=0.2,
    bidirectional=bidirectional,
    use_layernorm=True,
)

# Width of the LSTM output that is fed to the classifier:
# hidden_dims per direction and per layer.
num_directions = 2 if bidirectional else 1
embedding_dim = hidden_dims * num_directions * num_layers
classifier = ClassificationModel(embedding_dim, hidden_dims, num_classes)

### The cell below shows the training of the Siamese LSTM

In [None]:
# Pair dataset over both regions.
data = torch.utils.data.ConcatDataset([frh01, frh02])
trainingDataLoader = torch.utils.data.DataLoader(data, batch_size=64, shuffle=True, num_workers=0)

# training configuration
epochs = 101
lrDecreaseStep = 5     # epochs without improvement tolerated before the lr is reduced
earlyStoppingStep = 2  # number of lr reductions after which training stops

minLoss = 999
lossNotDecreasedCounter = 0
lrDecreasedCounter = 0

lstm.train()
lstm.to(device)

# create loss function and optimizer
# No learning rate is given, so Adam starts from its default.
lossFn = ContrastiveLoss(margin=3)
optimizer = torch.optim.Adam(
    filter(lambda p: p.requires_grad, lstm.parameters()),
    betas=(0.9, 0.98), eps=1e-09, weight_decay=1e-6)

for epoch in range(epochs):
    epochLoss = 0.0
    for batch_id, batch in enumerate(trainingDataLoader):

        # get paired data with corresponding label (1 for positive, 0 for negative)
        x1, x2, labels = batch
        x1, x2, labels = x1.to(device), x2.to(device), labels.to(device)

        # forward once over the pair; calling the module (not .forward)
        # keeps registered hooks working
        out1, out2 = lstm(x1, x2)
        loss = lossFn(out1, out2, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # accumulate a plain Python float so no autograd history is kept
        # alive across the epoch
        epochLoss += loss.item()

        if batch_id % 100 == 0:
            print("Iteration {}: loss {:.2f}".format(batch_id, loss.item()))

    numBatches = batch_id + 1
    avgLoss = epochLoss / numBatches

    # check the loss to decide early stopping or learning rate decreasing;
    # only an improvement of more than 0.01 counts
    if avgLoss < minLoss - 0.01:
        lossNotDecreasedCounter = 0
        minLoss = avgLoss
    else:
        lossNotDecreasedCounter += 1

    if lossNotDecreasedCounter > lrDecreaseStep:
        lossNotDecreasedCounter = 0
        minLoss = 999  # restart the comparison for the new learning rate
        lrDecreasedCounter += 1
        print("Decrease learning rate...")

        for group in optimizer.param_groups:
            group['lr'] = group['lr'] * 0.1

    stopEarly = lrDecreasedCounter == earlyStoppingStep
    if stopEarly:
        print("Earyl stopping...")
        torch.save(lstm.state_dict(), "lstm_model.pth")
    elif epoch % 10 == 0:
        print("Saving....")
        torch.save(lstm.state_dict(), "lstm_model.pth")

    print("Epoch number {}\n Current loss {}\n".format(epoch, avgLoss))

    # actually leave the training loop once the stopping criterion is met
    if stopEarly:
        break

### The cell below shows the training of the Artificial Neural Network fed by the Siamese LSTM

In [None]:
# load the trained lstm model
# NOTE(review): `load` is not a torch.nn.Module method, so this relies on the
# project's LSTM class providing it - confirm. The path is a placeholder; the
# training cell for the LSTM writes its checkpoint to "lstm_model.pth".
lstm.load("path_to_lstm_model.pth")

# set dataset to return proper label of the classes instead of 1 and 0
frh01.isClassification = True
frh02.isClassification = True

data = torch.utils.data.ConcatDataset([frh01, frh02])
trainingDataLoader = torch.utils.data.DataLoader(data, batch_size=64, shuffle=True, num_workers=0)

# training configuration
epochs = 101
lrDecreaseStep = 5     # epochs without improvement tolerated before the lr is reduced
earlyStoppingStep = 2  # number of lr reductions after which training stops

minLoss = 999
lossNotDecreasedCounter = 0
lrDecreasedCounter = 0

# Only the classifier is optimised here. The LSTM acts as a fixed feature
# extractor, so it is switched to eval mode (no dropout noise in the features).
lstm.eval()
classifier.train()
lstm = lstm.to(device)
classifier = classifier.to(device)

# No learning rate is given, so Adam starts from its default.
optimizer = torch.optim.Adam(
    filter(lambda p: p.requires_grad, classifier.parameters()),
    betas=(0.9, 0.98), eps=1e-09)

for epoch in range(epochs):
    epochLoss = 0.0
    for batch_id, batch in enumerate(trainingDataLoader):

        # get paired data with corresponding label of x1 instead of 1 and 0
        x1, x2, labels = batch
        x1, x2, labels = x1.to(device), x2.to(device), labels.to(device)

        # forward once over LSTM; no gradients are needed for it because its
        # parameters are not part of the optimizer
        with torch.no_grad():
            out1, out2 = lstm(x1, x2)

        # forward first output over ANN
        # assumes the classifier returns log-probabilities, as nll_loss expects - confirm
        logProbs = classifier(out1)
        loss = F.nll_loss(logProbs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # accumulate a plain Python float so no autograd history is kept
        # alive across the epoch
        epochLoss += loss.item()

        if batch_id % 100 == 0:
            print("Iteration {}: loss {:.2f}".format(batch_id, loss.item()))

    numBatches = batch_id + 1
    avgLoss = epochLoss / numBatches

    # check the loss to decide early stopping or learning rate decreasing;
    # only an improvement of more than 0.01 counts
    if avgLoss < minLoss - 0.01:
        lossNotDecreasedCounter = 0
        minLoss = avgLoss
    else:
        lossNotDecreasedCounter += 1

    if lossNotDecreasedCounter > lrDecreaseStep:
        lossNotDecreasedCounter = 0
        minLoss = 999  # restart the comparison for the new learning rate
        lrDecreasedCounter += 1
        print("Decrease learning rate...")

        for group in optimizer.param_groups:
            group['lr'] = group['lr'] * 0.1

    stopEarly = lrDecreasedCounter == earlyStoppingStep
    if stopEarly:
        print("Earyl stopping...")
        print("Saving the model...")
        torch.save(classifier.state_dict(), "ann_model.pth")
    elif epoch % 10 == 0:
        print("Saving....")
        torch.save(classifier.state_dict(), "ann_model.pth")

    print("Epoch number {}\n Current loss {}\n".format(epoch, avgLoss))

    # actually leave the training loop once the stopping criterion is met
    if stopEarly:
        break