Sample Crop Classification using EuroCrops Data

Model architecture adapted from [Garnot et al., 2020](https://openaccess.thecvf.com/content_CVPR_2020/papers/Garnot_Satellite_Image_Time_Series_Classification_With_Pixel-Set_Encoders_and_Temporal_CVPR_2020_paper.pdf)

Additional Layerwise Relevance Propagation component from [Chan et al., 2023](https://ieeexplore.ieee.org/document/10281498)


If you are using TinyEuroCrops, only the cell below needs to be adjusted to set the desired variables and directories.

In [14]:
# Training related parameters 
# number of epochs of training
n_epochs = 10
# size of the batches
batch_size = 128
# focal loss:  focusing parameter
gamma = 1
# adam: learning rate
learning_rate = 1e-5
# adam: decay of first order momentum of gradient
b1 = 0.9
# adam: decay of second order momentum of gradient
b2 = 0.999
# adam: weight decay (L2 penalty)
weight_decay = 1e-6
# print frequency of progress meter
print_freq = 50


# name of checkpoint 
CP_name= 'checkpoint.pth.tar'
# initialize a dummy best accuracy 
best_acc1 = 0
# location to save current tensorboard session
current_run = '/Users/ayshahchan/Desktop/PhD/runs/psetae'

# Data Directory
# this section is specific to TinyEuroCrops, should other data be used please adjust the Data Loader Section for the correct directories as well
# root location of data
root = '/Users/ayshahchan/Desktop/Education/ESPACE/thesis/codes/data'
partition = "train"
# This notebook uses the train section of Austrian TinyEuroCrops
country='AT_T33UWP'

Data Loader for the EuroCrops Demo Data (TinyEuroCrops)

The data loader loads the data from 4 different files: one containing the spectral reflectances for training, one containing the spectral reflectances for testing, one containing the labels for training and the last containing the labels for testing.

The code assumes the files are saved in different folders under the same root directory, following the TinyEuroCrops file structure. If your paths differ, please adjust the code in this section accordingly.


Note: although the processing is similar to the webinar code, some column names are different. When applying this code to newly processed datasets, note that the column names crpgrpc and crpgrpn may be hcat_c and hcat_n instead.
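One way to make downstream code robust to this naming difference is to resolve the label columns once, up front. The sketch below is an illustration only: the helper `resolve_label_columns` is hypothetical and not part of the data loader, the example column list is made up, and it assumes the code/name pairs are exactly (`crpgrpc`, `crpgrpn`) or (`hcat_c`, `hcat_n`).

```python
def resolve_label_columns(columns):
    """Return the (code, name) label column pair present in `columns`.

    Hypothetical helper: it only checks the two naming schemes mentioned in the note.
    """
    candidates = [("crpgrpc", "crpgrpn"), ("hcat_c", "hcat_n")]
    available = set(columns)
    for code_col, name_col in candidates:
        if code_col in available and name_col in available:
            return code_col, name_col
    raise KeyError(f"none of the expected label column pairs found: {candidates}")


# Usage with a plain list of column names (for a DataFrame, pass list(df.columns))
example_columns = ["id", "hcat_c", "hcat_n", "geometry"]  # made-up column list
code_col, name_col = resolve_label_columns(example_columns)
print(code_col, name_col)
```

Resolving the names in one place means the rest of the loader can refer to `code_col` and `name_col` regardless of which dataset version is in use.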

Defining the Model
There are three main parts of the model: the multilayer perceptron encoder, the attention layer and the multilayer perceptron decoder. 

![model diagram](webinar_demo/model.png)
![attention diagram](webinar_demo/attention_block.png)


You can adjust the model as you wish, whether it is adding more layers or removing the dropout layers. Removing the dropout layers will lead to faster training but may also lead to overfitting.



In [None]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.autograd import Variable

Multilayer Perceptrons

See mlps.py

Spectral Encoder

A simple encoder for initial feature extraction mainly in the spectral domain.

See pse.py

Attention Layer

For Positional Encoding: one of the inputs is days. This is the number of days since the first data point, used instead of the sequence position because the data are acquired at irregular intervals. It should be readily available as an output of the data loader.
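As a rough illustration of how day offsets can feed a positional encoding, the sketch below converts date strings into days since the first acquisition and then applies a standard sinusoidal encoding to those offsets. The function names, the encoding dimension `d_model`, the `period` constant and the example dates are all assumptions made for this sketch; positional_encoding.py may use a different formulation.

```python
import math
from datetime import datetime


def days_since_first(dates, fmt="%Y-%m-%d"):
    """Convert ordered date strings to day offsets from the first date."""
    parsed = [datetime.strptime(d, fmt) for d in dates]
    return [(p - parsed[0]).days for p in parsed]


def sinusoidal_encoding(days, d_model=8, period=1000.0):
    """Encode each day offset as interleaved sin/cos features.

    d_model and period are illustrative values, not taken from positional_encoding.py.
    """
    table = []
    for day in days:
        row = []
        for j in range(d_model):
            # Each sin/cos pair shares one frequency
            angle = day / (period ** (2 * (j // 2) / d_model))
            row.append(math.sin(angle) if j % 2 == 0 else math.cos(angle))
        table.append(row)
    return table


dates = ["2019-01-06", "2019-01-11", "2019-01-26"]  # made-up acquisition dates
days = days_since_first(dates)
print(days)  # [0, 5, 20]
encoding = sinusoidal_encoding(days)
print(len(encoding), len(encoding[0]))  # 3 8
```

Because the encoding depends on the day offset rather than the index, two observations that are 15 days apart are encoded differently from two that are 5 days apart, even if they are adjacent in the sequence.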

See positional_encoding.py, attention_layer.py, lrp_ln.py

Attention layer based on Garnot and Marc's implementation

forward_lrp omits the dropout layer and detaches the softmax and the variance (in the layer norm), following the LRP propagation rules adapted from [Ali et al., 2022](https://arxiv.org/abs/2202.07304)
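The two detach rules can be sketched in isolation as follows. This is a minimal illustration of the idea described by Ali et al., not the contents of lrp_ln.py or attention_layer.py: the function names are hypothetical, the layer norm has no affine parameters, and the attention is single-head without masking.

```python
import torch


def layer_norm_lrp(x, eps=1e-5):
    """Layer norm (no affine parameters) with the normaliser detached."""
    mean = x.mean(dim=-1, keepdim=True)
    var = x.var(dim=-1, unbiased=False, keepdim=True)
    # Detaching makes the backward pass treat the denominator as a constant
    denom = torch.sqrt(var + eps).detach()
    return (x - mean) / denom


def attention_lrp(q, k, v):
    """Scaled dot-product attention with the softmax weights detached."""
    scores = q @ k.transpose(-2, -1) / (q.shape[-1] ** 0.5)
    # Detached weights act as a fixed gate, so gradients flow only through v
    weights = torch.softmax(scores, dim=-1).detach()
    return weights @ v


x = torch.randn(2, 4, requires_grad=True)
y = layer_norm_lrp(x)
reference = torch.nn.functional.layer_norm(x, (4,))
# The forward values agree with the standard layer norm; only the backward pass differs
print(torch.allclose(y, reference, atol=1e-5))
```

The forward output is unchanged by either detach, so predictions are identical to the normal forward pass; only the gradients used to compute relevance are affected.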

Temporal Attention Encoder

Encoder that primarily extracts features in the temporal domain by leveraging the attention mechanism

See tae.py


Combining all the individual modules

See pse_tae.py

Define a Loss Function

See focalloss.py
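For orientation, a minimal multi-class focal loss can be written as below. This is a sketch under assumptions rather than a copy of focalloss.py: it has no class-weighting (alpha) term, it averages over the batch, and the function name `focal_loss_sketch` is hypothetical. The focusing parameter corresponds to `gamma` in the parameter cell.

```python
import torch
import torch.nn.functional as F


def focal_loss_sketch(logits, target, gamma=1.0):
    """Mean multi-class focal loss: -(1 - p_t)^gamma * log(p_t)."""
    log_probs = F.log_softmax(logits, dim=1)
    # Log-probability assigned to the true class of each sample
    log_pt = log_probs.gather(1, target.unsqueeze(1)).squeeze(1)
    pt = log_pt.exp()
    loss = -((1.0 - pt) ** gamma) * log_pt
    return loss.mean()


torch.manual_seed(0)
logits = torch.randn(4, 3)            # made-up scores: 4 samples, 3 classes
target = torch.tensor([0, 2, 1, 2])   # made-up labels
print(focal_loss_sketch(logits, target, gamma=1.0).item())
# With gamma = 0 the focal loss reduces to the usual cross-entropy
print(torch.allclose(focal_loss_sketch(logits, target, gamma=0.0),
                     F.cross_entropy(logits, target)))
```

Larger values of `gamma` shrink the contribution of samples that are already classified with high confidence, which is useful when a few crop classes dominate the dataset.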

Setting up for training

In [None]:
import os
import shutil
import time

import numpy as np
import torch
from torch.utils.data import DataLoader, random_split
from torch.utils.tensorboard import SummaryWriter

# TensorBoard writer for the current run
writer = SummaryWriter(current_run)

from focalloss import FocalLoss
from pse_tae import PSE_TAE
from eurocrops_dataloader import EuroCropsDataset

In [25]:
def train(train_loader,
          pse_tae,
          focal_loss,
          optimizer,
          epoch,
          print_freq,
          device):

    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')
    # pse_time = AverageMeter('PSE', ':6.3f')
    # tae_time = AverageMeter('TAE', ':6.3f')
    # decode_time = AverageMeter('Decode', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@2', ':6.2f')
    progress = ProgressMeter(len(train_loader),
                             [batch_time, data_time, losses, top1, top5],
                             prefix="Epoch: [{}]".format(epoch))

    # -------------------------
    #  Put Models in Train Mode
    # -------------------------
    pse_tae.train()

    end = time.time()
    # running_loss = 0.0
    # running_correct = 0
    for i, batch in enumerate(train_loader):
        data_time.update(time.time() - end)

        # Get training data
        data = batch['data'].to(device)
        label = batch['label'].to(device)
        days = batch["dates"].to(device)
        
        
        # ---------------------------------
        #  Train Everything together
        # ---------------------------------
        optimizer.zero_grad()
        output = pse_tae(data,days)
        _, prediction = torch.max(output.data, 1)

        # ---------------------------------
        #  Loss
        # ---------------------------------

        # Focal Loss between output and target
        loss = focal_loss(output.to(device), label)

        # ---------------------------------
        #  Record Stats
        # ---------------------------------
        # Measure accuracy and record loss
        acc1, acc5 = accuracy(output, label, topk=(1, 2))
        losses.update(loss.item(), data.size(0))
        top1.update(acc1[0], data.size(0))
        top5.update(acc5[0], data.size(0))

        # ---------------------------------
        #  Backward pass & optimizer step
        # ---------------------------------
        loss.backward()
        optimizer.step()

        # ---------------------------------
        #  Time
        # ---------------------------------
        batch_time.update(time.time() - end)
        end = time.time()


        if i % print_freq == 0:
            progress.display(i)
            # for tensorboard
            writer.add_scalar('train loss', loss.item(), epoch * len(train_loader) + i)
            
            writer.add_scalar('accuracy best', acc1, epoch * len(train_loader) + i)
            

    



        


def validate(val_loader, pse_tae, focal_loss, epoch, print_freq, device):
    batch_time = AverageMeter('Time', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    progress = ProgressMeter(len(val_loader), [batch_time, losses, top1, top5], prefix='Test: ')

    # -------------------------
    #  Put Models in Eval Mode
    # -------------------------
    pse_tae.eval()

    with torch.no_grad():
        end = time.time()
        for i, val_batch in enumerate(val_loader):
            # Get validation data
            data_val = val_batch['data'].to(device)
            label_val = val_batch['label'].to(device)
            days_val = val_batch["dates"].to(device)
     

            # -------------------------
            #  Compute Predictions
            # -------------------------
            output = pse_tae(data_val, days_val)
            loss = focal_loss(output.to(device), label_val)
            _, prediction = torch.max(output.data, 1)

            # ---------------------------------
            #  Record Stats
            # ---------------------------------
            # measure accuracy and record loss
            acc1, acc5 = accuracy(output, label_val, topk=(1, 5))
            losses.update(loss.item(), data_val.size(0))
            top1.update(acc1[0], data_val.size(0))
            top5.update(acc5[0], data_val.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % print_freq == 0:
                progress.display(i)
                writer.add_scalar('val loss', loss.item(), epoch * len(val_loader) + i)
            #writer.add_scalar('accuracy', running_correct/100, epoch * len(train_loader) + i)
                writer.add_scalar('val accuracy', acc1, epoch * len(val_loader) + i)

        print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'.format(top1=top1, top5=top5))

    return top1.avg, losses.avg


def save_checkpoint(state, is_best, filename=CP_name):
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, 'best_'+filename)

def load_checkpoint(model, optimizer,  filename=CP_name):
    # Note: Input model & optimizer should be pre-defined.  This routine only updates their states.
    start_epoch = 0
    if os.path.isfile(filename):
        print("=> loading checkpoint '{}'".format(filename))
        checkpoint = torch.load(filename)
        start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        
        print("=> loaded checkpoint '{}' (epoch {})"
                  .format(filename, checkpoint['epoch']))
    else:
        print("=> no checkpoint found at '{}'".format(filename))

    return model, optimizer, start_epoch

class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)


class ProgressMeter(object):
    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        entries = [self.prefix + self.batch_fmtstr.format(batch)]
        entries += [str(meter) for meter in self.meters]
        print('\t'.join(entries))

    def _get_batch_fmtstr(self, num_batches):
        num_digits = len(str(num_batches // 1))
        fmt = '{:' + str(num_digits) + 'd}'
        return '[' + fmt + '/' + fmt.format(num_batches) + ']'


def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k"""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res

class EarlyStopper:
    def __init__(self, patience=10, min_delta=0.2):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.min_validation_loss = np.inf

    def early_stop(self, validation_loss):
        if validation_loss < self.min_validation_loss:
            self.min_validation_loss = validation_loss
            self.counter = 0
        elif validation_loss > (self.min_validation_loss + self.min_delta):
            self.counter += 1
            if self.counter >= self.patience:
                return True
        return False
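To make the stopping rule concrete, the sketch below restates the class (using `math.inf` in place of `np.inf` so that it runs standalone) and feeds it a made-up sequence of validation losses. The loss values and the `patience=2` setting are chosen purely for illustration.

```python
import math


class EarlyStopper:
    def __init__(self, patience=10, min_delta=0.2):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.min_validation_loss = math.inf

    def early_stop(self, validation_loss):
        if validation_loss < self.min_validation_loss:
            # New best loss: remember it and reset the counter
            self.min_validation_loss = validation_loss
            self.counter = 0
        elif validation_loss > (self.min_validation_loss + self.min_delta):
            # Loss is worse than the best by more than min_delta
            self.counter += 1
            if self.counter >= self.patience:
                return True
        return False


stopper = EarlyStopper(patience=2, min_delta=0.2)
losses = [1.0, 0.9, 1.0, 1.2, 1.3]  # made-up validation losses
decisions = [stopper.early_stop(loss) for loss in losses]
print(decisions)  # [False, False, False, False, True]
```

Note that the third loss (1.0) is within `min_delta` of the best value (0.9), so it neither resets nor increments the counter; only the last two losses count towards `patience`.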

Actual Training

In [26]:



dataset = EuroCropsDataset(root=root,partition=partition,country=country)
           
fold_len = int(len(dataset) / 5)
n_train = len(dataset) -  fold_len
train_set, val_set =random_split(dataset,(n_train, fold_len),generator=torch.Generator().manual_seed(42))  


train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False)


if torch.cuda.is_available():
    device = torch.device("cuda")

else:
    device = "cpu"

pse_tae = PSE_TAE(device).to(device)

focal_loss = FocalLoss(gamma).to(device)

# Adam Optimizer
optimizer = torch.optim.Adam(pse_tae.parameters(), lr=learning_rate, betas=(b1, b2), weight_decay=weight_decay)

# Stops training once the validation loss has exceeded the best value by more than min_delta
# in `patience` epochs since the last improvement.
early_stopper = EarlyStopper(patience=5, min_delta=0.2)



pse_tae, optimizer, start_epoch = load_checkpoint(pse_tae, optimizer)


for epoch in range(start_epoch, n_epochs):

    # ----------
    #  Training
    # ----------

    train(train_loader,
            pse_tae,
            focal_loss,
            optimizer,
            epoch,
            print_freq,
            device)

    # -----------
    #  Validation
    # -----------

    acc1, val_loss = validate(val_loader,
                    pse_tae,
                    focal_loss,
                    epoch,
                    print_freq,
                    device)

    # -----------
    #  Remember best acc@1 and save checkpoint
    # -----------
    is_best = acc1 > best_acc1
    best_acc1 = max(acc1, best_acc1)
    
    if early_stopper.early_stop(val_loss):             
        break

    save_checkpoint({
            'epoch': epoch + 1,
            'state_dict': pse_tae.state_dict(),
            'best_acc1': best_acc1,
            'optimizer' : optimizer.state_dict(),
        }, is_best)

345970 parcels in file with 44 classes 
=> no checkpoint found at 'checkpoint.pth.tar'




Epoch: [0][   0/2163]	Time  1.982 ( 1.982)	Data  1.762 ( 1.762)	Loss 3.7913e+00 (3.7913e+00)	Acc@1   2.34 (  2.34)	Acc@2   3.91 (  3.91)
Epoch: [0][  50/2163]	Time  0.239 ( 0.281)	Data  0.183 ( 0.218)	Loss 3.7034e+00 (3.8050e+00)	Acc@1   7.03 (  3.74)	Acc@2  10.94 (  6.43)
Epoch: [0][ 100/2163]	Time  0.242 ( 0.263)	Data  0.184 ( 0.202)	Loss 3.7096e+00 (3.7637e+00)	Acc@1   5.47 (  4.56)	Acc@2   9.38 (  7.80)
Epoch: [0][ 150/2163]	Time  0.241 ( 0.257)	Data  0.184 ( 0.197)	Loss 3.6428e+00 (3.7221e+00)	Acc@1   5.47 (  5.74)	Acc@2  11.72 (  9.41)
Epoch: [0][ 200/2163]	Time  0.240 ( 0.254)	Data  0.184 ( 0.194)	Loss 3.5550e+00 (3.6899e+00)	Acc@1   7.81 (  6.72)	Acc@2  12.50 ( 10.73)
Epoch: [0][ 250/2163]	Time  0.242 ( 0.252)	Data  0.185 ( 0.192)	Loss 3.5681e+00 (3.6571e+00)	Acc@1  10.16 (  7.88)	Acc@2  16.41 ( 12.34)
Epoch: [0][ 300/2163]	Time  0.240 ( 0.251)	Data  0.184 ( 0.192)	Loss 3.4271e+00 (3.6268e+00)	Acc@1  19.53 (  9.09)	Acc@2  24.22 ( 13.87)
Epoch: [0][ 350/2163]	Time  0.257 ( 0.250

To derive R values from LRP

In [None]:

import h5py
import joypy

import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
import datetime

import sys
sys.path.append("/Users/ayshahchan/Desktop/ESPACE/thesis/codes/XAI_thesis/train")
import os
import pandas as pd
import numpy as np
import mlps as mlps
import matplotlib.cm as cm
import pickle

from collections import Counter



In [None]:
PATH = '/Users/ayshahchan/Desktop/ESPACE/thesis/codes/XAI_thesis/models/best_checkpoint.pth.tar'


data = dataset.data
classes = dataset.crpgrpn
crpNames = dataset.classes 
sample = data.loc[1].sort_index().fillna(0)
sorteddf = data.sort_index(axis=1)
date_values = pd.to_datetime(sample.index)


dates_json = sample.index
max_len = len(sample)
# Instead of taking the position, the number of days since the first observation is used
days = torch.zeros(max_len)
date_0 = dates_json[0]
date_0 = datetime.datetime.strptime(str(date_0), "%Y-%m-%d")
days[0] = 0
for i in range(max_len - 1):
    date = dates_json[i + 1]
    date = datetime.datetime.strptime(str(date), "%Y-%m-%d")
    days[i + 1] = (date - date_0).days
days = days.unsqueeze(1)
days = days.unsqueeze(0).to(device)

In [None]:


model = PSE_TAE(device)
# map_location lets a checkpoint saved on GPU load on a CPU-only machine
checkpoint = torch.load(PATH, map_location=device)
model.load_state_dict(checkpoint['state_dict'])
optimizer = torch.optim.Adam(model.parameters(),  lr=1e-5, betas=(0.9, 0.999),weight_decay=1e-6)
optimizer.load_state_dict(checkpoint['optimizer'])
epoch = checkpoint['epoch']
best_acc1 = checkpoint['best_acc1']  # best validation Acc@1 stored with the checkpoint
model.to(device)
model.eval()  # disable dropout for the explanation pass

Rdict = []
iddict = []
crpdict = []
cropNames = []
predictid = []
# NOTE: the crop name [0], the single argmax and the fixed `days` tensor suggest
# that this loop expects one parcel per batch
for i, batch in enumerate(train_loader):

    x = batch['data'].to(device)
    x_label = batch['label'].to(device)
    # days = batch["dates"].to(device)

    # leaf tensor with gradients enabled (replaces the deprecated torch.autograd.Variable)
    x_with_grad = x.detach().requires_grad_(True)
    parcel_id = batch['ids'].squeeze().detach().cpu().numpy()
    crop_name = batch['crop name'][0]

    outs = model.forward_and_explain(x_with_grad, x_label, days)
    # move the relevance scores to CPU NumPy so they can be stacked after the loop
    attribution = outs['R'].squeeze().detach().cpu().numpy()
    predictid.append(outs['logits'].argmax().item())
    Rdict.append(attribution)
    iddict.append(parcel_id)
    crpdict.append(x_label.cpu().numpy())
    cropNames.append(crop_name)


Rdict = np.array(Rdict)
predictid = np.array(predictid)


Sample code to plot the relevance scores at each timestep and the corresponding band-specific relevance scores 

In [None]:
NORMALIZING_FACTOR = 1e-4
BANDS = ['B1', 'B2', 'B3', 'B4', 'B5', 'B6', 'B7', 'B8','B9','B10', 'B11', 'B12',
       'B8A']

parcel_id = 113
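# iddict is a Python list, so convert it to an array before the element-wise comparison
dict_id = np.where(np.asarray(iddict) == parcel_id)[0].item()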
attribution = Rdict[dict_id].squeeze()
dataval = sorteddf.loc[parcel_id]
print(cropNames[dict_id])
# fill missing arrays with 0 and normalize
a = [np.array(ii) for ii in dataval ] 
x = np.empty([len(a), 13])
for ii in range(len(a)):
    if a[ii] is None or np.all(a[ii] == None):
        x[ii,:] =  np.zeros(13)
    else:
        x[ii,:] = np.array(a[ii]* NORMALIZING_FACTOR)
x = np.nan_to_num(x)


r_df = pd.DataFrame(attribution, columns=BANDS, index = date_values)
# change date range according to data
idx = pd.date_range('01-06-2019', '12-27-2019')
test = r_df.reindex(idx, fill_value=0) 
x_range = list(range(len(test)))
fig, axes = joypy.joyplot(test,overlap=2, kind="values", x_range=x_range, linecolor="black", colormap=cm.tab20,
grid=True, figsize=(10,8))
# day-of-year of the first of each month, shifted by 6 because the date range starts on 6 January
xtick_range = [datetime.date(2019, y, 1).timetuple().tm_yday-6 for y in range(2,13)]
xlabel = [ y.isoformat()[0:7] for y in test.index.date[xtick_range]]
axes[-1].set_xticks(xtick_range)
axes[-1].set_xticklabels(xlabel)

test = np.sum(attribution,axis=1)
center_function = lambda x: x - np.nanmean(x)
normR =(test - min(test)) / ( max(test) - min(test) ) -0.5
centeredR = center_function(normR)*2


fig, axs = plt.subplots(2,1, figsize=(12,8))

colors = cm.tab20(np.linspace(0, 1, len(attribution.T)))
for z, c in zip(x.T, colors):
    axs[0].plot(date_values, z,'o-',color=c)
axs[0].set_ylabel("Input", fontsize=18)
axs[0].set_xlim([date_values[0], date_values[len(date_values)-1]])

axs[0].legend(BANDS)
axs[1].grid()
axs[1].plot(date_values, centeredR, 'o-',color='black')
axs[1].set_ylabel("R", fontsize=18)
#axs[1].set_ylim([-1, 1])
axs[1].set_xlim([date_values[0], date_values[len(date_values)-1]])
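The temporal relevance curve in the second panel is obtained by summing the band relevances at each timestep, min-max scaling the result and centring it on its mean. The sketch below replays those steps on a small made-up attribution matrix (4 timesteps, 3 bands) so that the effect of each step is easier to follow; the variable names are chosen for this example only.

```python
import numpy as np

# Made-up attribution matrix: 4 timesteps x 3 bands
attribution = np.array([[0.1, 0.2, 0.0],
                        [0.4, 0.1, 0.1],
                        [-0.2, 0.0, 0.1],
                        [0.0, 0.3, 0.3]])

# Total relevance per timestep
per_step = attribution.sum(axis=1)
# Min-max scale to the range [-0.5, 0.5]
scaled = (per_step - per_step.min()) / (per_step.max() - per_step.min()) - 0.5
# Subtract the mean and double the result
centered = (scaled - np.nanmean(scaled)) * 2

print(np.round(centered, 3))
```

After these steps the curve has zero mean and a total spread (maximum minus minimum) of 2, but it is not confined to [-1, 1] unless the scaled values happen to be symmetric about zero. The scaling also divides by zero if every timestep has the same total relevance.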