Sample Crop Classification using Eurocrops Data

Model architecture adapted from [Garnot et al., 2020](https://openaccess.thecvf.com/content_CVPR_2020/papers/Garnot_Satellite_Image_Time_Series_Classification_With_Pixel-Set_Encoders_and_Temporal_CVPR_2020_paper.pdf)


If you are using TinyEuroCrops, only the section below needs to be adjusted to set the desired variables and directories.

In [14]:
# Training related parameters
# number of epochs of training
n_epochs = 10
# size of the batches
batch_size = 128
# focal loss: focusing parameter (gamma = 0 reduces the loss to cross-entropy)
gamma = 1
# adam: learning rate
learning_rate = 1e-5
# adam: decay of first order momentum of gradient (beta1)
b1 = 0.9
# adam: decay of second order momentum of gradient (beta2)
b2 = 0.999
# adam: weight decay (L2 penalty)
weight_decay = 1e-6
# print frequency of progress meter, in batches
print_freq = 50


# name of checkpoint
CP_name = 'checkpoint.pth.tar'
# initialize a dummy best accuracy
best_acc1 = 0
# location to save current tensorboard session
# NOTE: this and `root` below are machine-specific absolute paths; adjust them for your setup
current_run = '/Users/ayshahchan/Desktop/PhD/runs/psetae'

# Data Directory
# this section is specific to TinyEuroCrops, should other data be used please adjust the Data Loader Section for the correct directories as well
# root location of data
root = '/Users/ayshahchan/Desktop/Education/ESPACE/thesis/codes/data'
partition = "train"
# This notebook uses the train section of Austrian TinyEuroCrops
country = 'AT_T33UWP'

Data Loader for EuroCrop Demo Data TinyEuroCrops

The data loader reads the spectral reflectances from one HDF5 file per partition (training or testing), which contains one group per region, and the labels from one CSV file per region. The training and testing files are kept in separate folders.

The code assumes the files are saved in different folders under the same root directory. If your paths differ, please adjust the code in this section accordingly. This code assumes the TinyEuroCrops file structure.


Note: although the processing is similar to the webinar code, some column names are different. When applying this code to newly processed datasets, please note that the columns crpgrpc and crpgrpn may be named hcat_c and hcat_n instead.

In [15]:
import torch
import torch.utils.data
import pandas as pd  
import os
import numpy as np
from numpy import genfromtxt
import datetime
import h5py


In [16]:
# Sentinel-2 band names. NOTE(review): presumably the order of the 13 values in
# each stored observation -- confirm against the HDF5 export.
BANDS = [
    'B1', 'B2', 'B3', 'B4', 'B5', 'B6', 'B7', 'B8', 'B9', 'B10', 'B11', 'B12',
    'B8A',
]

# Crop codes (crpgrpc values) known to the model. A sample's label is its
# position in this list, so the order must never change; its 44 entries match
# the size of the classifier output.
classes = [
    33101011, 33101012, 33101021, 33101022, 33101032, 33101041, 33101042,
    33101051, 33101052, 33101060, 33101071, 33101072, 33101080, 33101100,
    33102020, 33103000, 33104000, 33105000, 33106020, 33106042, 33106050,
    33106060, 33106080, 33106100, 33106120, 33106130, 33107000, 33109000,
    33110000, 33111010, 33111022, 33111023, 33112000, 33114000, 33200000,
    33301010, 33301030, 33301040, 33304000, 33305000, 33402000, 33500000,
    33600000, 33700000,
]

# Raw reflectances are multiplied by this factor before entering the model.
NORMALIZING_FACTOR = 1e-4
# Padding value; not referenced by the code shown in this notebook section.
PADDING_VALUE = -1

class EuroCropsDataset(torch.utils.data.Dataset):
    """Parcel time series and crop labels stored in the TinyEuroCrops file layout.

    For the chosen ``partition`` this reads

    * ``<root>/HDF5s/<partition>/<country>_<partition>.h5``: every HDF5 group in
      it is treated as a region and loaded with ``pd.read_hdf``;
    * ``<root>/csv_labels/<partition>/demo_eurocrops_<region>.csv``: the labels
      of that region, indexed by the first CSV column.

    The label CSVs must provide the ``crpgrpc`` (crop code) and ``crpgrpn``
    (crop name) columns. Newly processed datasets may name them ``hcat_c`` /
    ``hcat_n`` instead.

    Args:
        root: root data directory.
        partition: ``"train"`` or ``"test"``.
        country: prefix of the HDF5 file name, e.g. ``"AT_T33UWP"``.

    Raises:
        ValueError: if ``partition`` is not ``"train"``/``"test"``, or if the
            HDF5 file contains no region groups.
    """

    def __init__(self, root, partition, country):
        # The "data" tensors returned by __getitem__ are moved to this device.
        if torch.cuda.is_available():
            self.device = torch.device("cuda")
        else:
            self.device = "cpu"

        if partition not in ("train", "test"):
            # Any other value used to fail later with an AttributeError because
            # h5_file_path was never assigned.
            raise ValueError(
                "partition must be 'train' or 'test', got {!r}".format(partition))
        self.partition = partition
        self.root = root
        self.h5_file_path = os.path.join(
            self.root, "HDF5s", partition, country + "_" + partition + ".h5")

        # Only the group (region) names are needed from h5py. Open the file
        # read-only and close it before pandas re-opens it below.
        with h5py.File(self.h5_file_path, "r") as h5_file:
            region_all = [name for name, h5obj in h5_file.items()
                          if isinstance(h5obj, h5py.Group)]
        if not region_all:
            raise ValueError(
                "no region groups found in {}".format(self.h5_file_path))

        label_dir = os.path.join(self.root, "csv_labels", partition)
        all_labelsfile = []
        all_data = []
        for region in region_all:
            csv_file_path = os.path.join(
                label_dir, 'demo_eurocrops_' + region + '.csv')
            all_labelsfile.append(pd.read_csv(csv_file_path, index_col=0))
            all_data.append(pd.read_hdf(self.h5_file_path, region))

        self.labelsfile = pd.concat(all_labelsfile)

        self.mapping = self.labelsfile.set_index("crpgrpc")
        # Classes present in this file only. __getitem__ encodes labels with the
        # module-level `classes` list instead, so label indices do not depend on
        # which classes happen to occur here.
        self.classes = self.labelsfile["crpgrpc"].unique()
        self.crpgrpn = self.labelsfile.groupby("crpgrpc").first().crpgrpn.values
        self.nclasses = len(self.classes)

        self.data = pd.concat(all_data)
        self.ids = list(self.data.index)
        print('{} parcels in file with {} classes '.format(len(self.ids), self.nclasses))

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, idx):
        """Return one parcel as a dict.

        Keys:
            ``data``: float32 tensor of shape (n_dates, 13). Each observation is
                multiplied by NORMALIZING_FACTOR; all-NaN observations become
                zero rows.
            ``label``: position of the parcel's crop code in the module-level
                ``classes`` list.
            ``ids``: the parcel id.
            ``crop name``: the ``crpgrpn`` value.
            ``dates``: float32 tensor of shape (n_dates, 1) holding the days
                elapsed since the first date after sorting.
        """
        parcel_id = self.ids[idx]
        # Kept for backward compatibility: the last fetched id was exposed as
        # `self.id`. Prefer the returned "ids" entry.
        self.id = parcel_id

        # One entry per acquisition date, sorted by the date index.
        spectral_data = self.data.loc[parcel_id].sort_index()

        crop_no = self.labelsfile.loc[parcel_id]['crpgrpc']
        label = self.labelsfile.loc[parcel_id]['crpgrpn']
        y_label = classes.index(int(crop_no))

        length = 13  # values per observation (== len(BANDS))

        # Iterating a Series yields its values in positional order. The former
        # `spectral_data[ii]` lookup relied on pandas' deprecated integer
        # fallback. Rows are collected once and stacked (no O(n^2) concatenation).
        rows = []
        for observation in spectral_data:
            if np.isnan(observation).all():
                rows.append(np.zeros(length))
            else:
                rows.append(np.array(observation) * NORMALIZING_FACTOR)
        if rows:
            spectral_data_array = np.stack(rows)
        else:
            spectral_data_array = np.empty((0, length))

        X = torch.tensor(spectral_data_array, dtype=torch.float32).to(self.device)

        # Acquisition dates are irregular, so the model's positional encoding
        # uses the number of days since the first observation instead of the
        # position in the sequence.
        dates = [datetime.datetime.strptime(str(d), "%Y%m%d")
                 for d in spectral_data.index]
        days = torch.tensor([float((d - dates[0]).days) for d in dates],
                            dtype=torch.float32).unsqueeze(1)

        return {'data': X, 'label': y_label, 'ids': parcel_id, 'crop name': label, 'dates': days}

            



Defining the Model
There are three main parts of the model: the multilayer perceptron encoder, the attention layer and the multilayer perceptron decoder. 

![model diagram](webinar_demo/model.png)
![attention diagram](webinar_demo/attention_block.png)


You can adjust the model as you wish, whether it is adding more layers or removing the dropout layers. Removing the dropout layers will lead to faster training but may also lead to overfitting.



In [17]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.autograd import Variable

Multilayer Perceptrons


In [18]:

class MLP1A(nn.Module):
    """First layer of the shared (per time step) spectral MLP.

    Pipeline: Linear(in_channels -> 32) -> BatchNorm1d over the 32 features ->
    Dropout(0.1) -> ReLU. The layer is applied to every time step.

    Input: [batch, seq_len, in_channels]. Output: [batch, seq_len, 32].

    Args:
        in_channels: number of values per observation. Defaults to 13, the
            length of each observation produced by the data loader.
    """

    def __init__(self, in_channels=13):
        super(MLP1A, self).__init__()
        # Submodule names/order unchanged so existing checkpoints still load.
        self.fc1 = nn.Linear(in_channels, 32)
        self.bn1 = nn.BatchNorm1d(num_features=32)
        self.dropout = nn.Dropout(0.1)

    def forward(self, x):
        x = self.fc1(x)

        # BatchNorm1d expects [batch, channels, seq_len], so move features to dim 1 and back.
        x = x.transpose(2, 1)
        x = self.bn1(x)
        x = x.transpose(2, 1)
        x = self.dropout(x)
        x = F.relu(x)
        return x


class MLP1B(nn.Module):
    """Second layer of the shared (per time step) spectral MLP.

    Linear(32 -> 64) -> BatchNorm1d over the 64 features -> ReLU.
    Input: [batch, seq_len, 32]. Output: [batch, seq_len, 64].
    """

    def __init__(self):
        super().__init__()
        self.fc2 = nn.Linear(32, 64)
        self.bn2 = nn.BatchNorm1d(num_features=64)

    def forward(self, x):
        projected = self.fc2(x)
        # BatchNorm1d normalises dim 1, so put the features there temporarily.
        normed = self.bn2(projected.transpose(1, 2)).transpose(1, 2)
        return F.relu(normed)
    
class MLP2A(nn.Module):
    """First layer of the temporal MLP applied to the flattened attention output.

    LazyLinear(? -> 64) -> BatchNorm1d(64) -> Dropout(0.1) -> ReLU.
    Input: [batch, in_features]. Output: [batch, 64].

    ``in_features`` (seq_len * d_e in this model) is inferred from the first
    batch by ``nn.LazyLinear``. After that first call it is fixed, so all later
    inputs must have the same width.
    """

    def __init__(self):
        super().__init__()
        self.fc1 = nn.LazyLinear(64)
        self.bn1 = nn.BatchNorm1d(num_features=64)
        self.dropout = nn.Dropout(0.1)

    def forward(self, x):
        return F.relu(self.dropout(self.bn1(self.fc1(x))))



        
class MLP2B(nn.Module):
    """Second layer of the temporal MLP.

    Linear(64 -> 64) -> BatchNorm1d(64) -> ReLU.
    Input: [batch, 64]. Output: [batch, 64].
    """

    def __init__(self):
        super().__init__()
        self.fc2 = nn.Linear(64, 64)
        self.bn2 = nn.BatchNorm1d(num_features=64)

    def forward(self, x):
        hidden = self.bn2(self.fc2(x))
        return F.relu(hidden)


class MLP3A(nn.Module):
    """Hidden layer of the decoder.

    Linear(64 -> 32) -> BatchNorm1d(32) -> ReLU.
    Input: [batch, 64]. Output: [batch, 32].
    """

    def __init__(self):
        super().__init__()
        self.fc2 = nn.Linear(64, 32)
        self.bn2 = nn.BatchNorm1d(num_features=32)

    def forward(self, x):
        hidden = self.fc2(x)
        hidden = self.bn2(hidden)
        return F.relu(hidden)

class MLP3B(nn.Module):
    """Final linear classification layer of the decoder.

    Maps [batch, 32] features to [batch, n_classes] unnormalised logits. No
    activation is applied; the focal loss applies log-softmax itself.

    Args:
        n_classes: number of output classes. Defaults to 44, the length of the
            module-level ``classes`` list used to encode labels.
    """

    def __init__(self, n_classes=44):
        super(MLP3B, self).__init__()
        self.fc3 = nn.Linear(32, n_classes)

    def forward(self, x):
        return self.fc3(x)

Spectral Encoder

A simple encoder for initial feature extraction mainly in the spectral domain.

In [19]:

    
class SpectralEncoder(nn.Module):
    """Per-time-step spectral encoder (first part of the architecture).

    Applies MLP1A then MLP1B to every observation independently.
    Input: [batch, seq_len, channels]. Output: float32 [batch, seq_len, 64].

    Args:
        device: stored as ``self.device``; not used by ``forward``.
    """

    def __init__(self, device):
        super(SpectralEncoder, self).__init__()
        self.device = device

        self.mlp1a = MLP1A()
        self.mlp1b = MLP1B()

    def forward(self, x):
        batch_size, seq_len, _channels = x.shape

        mlp1_output = self.mlp1a(x)             # [batch, seq_len, 32]
        mlp1_output = self.mlp1b(mlp1_output)   # [batch, seq_len, 64]

        pooled = mlp1_output.contiguous().view(batch_size, seq_len, -1)
        # .float() keeps the tensor on its current device. The previous
        # .type('torch.FloatTensor') silently copied CUDA tensors to the CPU,
        # which broke the downstream (GPU) layers.
        return pooled.float()


Attention Layer

For the positional encoding, one of the inputs is `days`: the number of days since the first observation. Because data are acquired at irregular intervals, this is used instead of the position in the sequence. It is provided by the data loader under the key `dates`.

In [20]:
class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding driven by acquisition days instead of positions.

    ``p[t, 2i] = sin(days_t / 1000**(2i/d_e))`` and
    ``p[t, 2i+1] = cos(days_t / 1000**(2i/d_e))``. ``forward`` adds ``p`` to the
    input, broadcasting over the batch dimension.

    Args:
        days: day offsets, expected as a (max_len, 1) column (PSE_TAE passes
            the first sample's ``dates``). A tensor or array-like.
        d_e: encoding width (must match the last dim of the input).
        max_len: number of time steps (rows of ``days``).
    """

    def __init__(self, days, d_e=64, max_len=80):
        super(PositionalEncoding, self).__init__()

        days = torch.as_tensor(days)
        # Build everything on the device of `days`; creating it on the CPU made
        # `days * div_term` fail for CUDA inputs.
        device = days.device
        p = torch.zeros(max_len, d_e, device=device)

        div_term = torch.exp(
            torch.arange(0, d_e, 2, device=device).float() * (-math.log(1000.0) / d_e))
        angles = days * div_term
        p[:, 0::2] = torch.sin(angles)
        # For odd d_e there is one fewer cosine column than sine columns.
        p[:, 1::2] = torch.cos(angles)[..., : d_e // 2]

        self.register_buffer('p', p.unsqueeze(0))

    def forward(self, x):
        return x + self.p




class AttentionLayer(nn.Module):
    """Multi-head self-attention over the time axis, followed by a residual + LayerNorm.

    Args:
        d_k: width of each head.
        h: number of heads. The model width is ``d_e = d_k * h``.

    The tensors of the last forward pass are kept in ``q``, ``k``, ``v``,
    ``attention_scores``, ``attention_probs`` and ``attention_output`` for
    inspection.
    """

    def __init__(self, d_k, h):
        super(AttentionLayer, self).__init__()
        self.h = h
        self.d_k = d_k
        self.d_e = self.d_k * self.h

        # Registration order unchanged so saved model/optimizer states still line up.
        self.fc1_q = nn.Linear(self.d_e, self.d_e)
        self.fc1_k = nn.Linear(self.d_e, self.d_e)
        self.fc1_v = nn.Linear(self.d_e, self.d_e)

        self.fc1 = nn.Linear(self.d_k * self.h, self.d_e)
        self.dropout = nn.Dropout(0.1)
        self.layer_norm = nn.LayerNorm(self.d_k * self.h)

        self.q = ()
        self.k = ()
        self.v = ()
        self.attention_scores = ()
        self.attention_output = ()
        self.attention_probs = ()

    def _split_heads(self, t, batch_size, seq_len):
        """[batch, seq, h*d_k] -> [batch*h, seq, d_k]; row index is ``b * h + head``."""
        t = t.view(batch_size, seq_len, self.h, self.d_k)
        return t.permute(0, 2, 1, 3).contiguous().view(-1, seq_len, self.d_k)

    def forward(self, e_p, batch_size, seq_len):
        """Attend over time steps.

        Args:
            e_p: [batch_size, seq_len, d_e] position-encoded features.
            batch_size, seq_len: the first two dims of ``e_p``.

        Returns:
            [batch_size, seq_len, d_e] tensor: LayerNorm(dropout(fc1(attention)) + e_p).
        """
        q = self._split_heads(self.fc1_q(e_p), batch_size, seq_len)
        k = self._split_heads(self.fc1_k(e_p), batch_size, seq_len)
        v = self._split_heads(self.fc1_v(e_p), batch_size, seq_len)
        self.q, self.k, self.v = q, k, v

        # [batch*h, seq, seq]
        attention_scores = q.matmul(k.transpose(-2, -1)) / math.sqrt(self.d_k)
        attention_probs = F.softmax(attention_scores, dim=-1)
        context = torch.matmul(attention_probs, v)  # [batch*h, seq, d_k]

        # Merge heads with the same batch-major layout used by _split_heads.
        # The previous view(h, batch, seq, d_k) mixed heads of different samples
        # whenever batch_size > 1. NOTE(review): checkpoints trained with that
        # layout will behave differently; retraining is recommended.
        context = context.view(batch_size, self.h, seq_len, self.d_k)
        context = context.permute(0, 2, 1, 3).contiguous().view(batch_size, seq_len, -1)

        # Dropout + residual LayerNorm to reduce overfitting.
        attention_output = self.dropout(self.fc1(context))
        attention_output = self.layer_norm(attention_output + e_p)

        self.attention_scores = attention_scores
        self.attention_output = attention_output
        self.attention_probs = attention_probs

        return attention_output

Temporal Attention Encoder

An encoder that extracts features primarily in the temporal domain, using the attention mechanism.


In [21]:
class TAE(nn.Module):
    """Temporal MLP plus decoder applied to the flattened attention output.

    The stages run in order: mlp2a -> mlp2b -> mlp3a -> decoder (class logits).
    """

    def __init__(self) -> None:
        super().__init__()
        # Registration order unchanged so saved model/optimizer states still line up.
        self.mlp2a = MLP2A()
        self.mlp2b = MLP2B()
        self.mlp3a = MLP3A()
        self.decoder = MLP3B()

    def forward(self, attention_output):
        hidden = attention_output
        for stage in (self.mlp2a, self.mlp2b, self.mlp3a, self.decoder):
            hidden = stage(hidden)
        return hidden


Combining all the individual modules

In [22]:



class PSE_TAE(nn.Module):
    """Full model: spectral encoder -> day-based positional encoding -> attention -> TAE.

    Args:
        device: forwarded to SpectralEncoder.
        heads: number of attention heads.
        d_e: model width; each head has ``d_e // heads`` features.
    """

    def __init__(self, device, heads=4, d_e=64):
        super().__init__()
        self.spectral_encoder = SpectralEncoder(device)

        self.d_e = d_e
        self.d_k = d_e // heads
        self.h = heads

        self.attn = AttentionLayer(self.d_k, self.h)
        self.tae = TAE()

        self.R0 = ()

    def forward(self, x, days):
        """Return class logits for ``x`` ([batch, seq, channels]) and ``days`` ([batch, seq, 1])."""
        encoding = self.spectral_encoder(x)
        batch_size, seq_len, _hidden = encoding.size()

        # NOTE(review): the encoding is built from the first sample's day offsets
        # only, so every sample in a batch is assumed to share acquisition dates.
        first_days = days[0, :, :].squeeze(0)
        positional = PositionalEncoding(first_days, d_e=self.d_e, max_len=seq_len)
        e_p = positional(encoding)

        attended = self.attn(e_p, batch_size, seq_len)
        # [batch, seq * d_e] -> temporal MLP + decoder
        flattened = attended.contiguous().view(batch_size, -1)
        return self.tae(flattened)

Define a Loss Function

In [23]:

class FocalLoss(nn.Module):
    """Focal loss ``-alpha_t * (1 - p_t) ** gamma * log(p_t)``.

    Adapted from https://github.com/clcarwin/focal_loss_pytorch

    Args:
        gamma: focusing parameter. ``0`` gives plain (optionally alpha-weighted)
            cross-entropy.
        alpha: optional class weights: a scalar ``a`` (expanded to
            ``[a, 1 - a]``) or a list with one weight per class.
        size_average: return the mean over samples if True, else the sum.

    ``p_t`` is detached from the graph, so ``(1 - p_t) ** gamma`` acts as a
    constant weight during backprop (as in the upstream implementation).
    """

    def __init__(self, gamma=0, alpha=None, size_average=True):
        super().__init__()
        self.gamma = gamma
        if isinstance(alpha, (float, int)):
            alpha = torch.Tensor([alpha, 1 - alpha])
        elif isinstance(alpha, list):
            alpha = torch.Tensor(alpha)
        self.alpha = alpha
        self.size_average = size_average

    def forward(self, input, target):
        logits = input
        if logits.dim() > 2:
            # N,C,H,W -> N*H*W,C
            n_batch, n_cls = logits.size(0), logits.size(1)
            logits = logits.view(n_batch, n_cls, -1).transpose(1, 2).contiguous().view(-1, n_cls)
        target = target.view(-1, 1)

        log_pt = F.log_softmax(logits, dim=1).gather(1, target).view(-1)
        pt = log_pt.detach().exp()

        if self.alpha is not None:
            if self.alpha.type() != logits.data.type():
                self.alpha = self.alpha.type_as(logits.data)
            log_pt = log_pt * self.alpha.gather(0, target.view(-1))

        loss = -((1 - pt) ** self.gamma) * log_pt
        return loss.mean() if self.size_average else loss.sum()


Setting up for training

In [24]:
from torch.utils.data import DataLoader, random_split

from torch.utils.tensorboard import SummaryWriter
# TensorBoard writer used as a global by train() and validate(); event files are
# written to `current_run` (set in the configuration cell).
writer = SummaryWriter(current_run)
import time
import shutil
import numpy as np
import os
import torch

In [25]:
def train(train_loader,
          pse_tae,
          focal_loss,
          optimizer,
          epoch,
          print_freq,
          device):
    """Run one training epoch of ``pse_tae`` over ``train_loader``.

    Every ``print_freq`` batches, prints a progress line and logs 'train loss'
    and the batch top-1 accuracy (tag 'accuracy best') to the module-level
    TensorBoard ``writer``.

    Args:
        train_loader: yields dicts with 'data', 'label' and 'dates' tensors.
        pse_tae: the model; called as ``pse_tae(data, days)``.
        focal_loss: criterion called as ``focal_loss(output, label)``.
        optimizer: optimizer stepped once per batch.
        epoch: epoch index, used for the progress prefix and the TensorBoard step.
        print_freq: logging interval in batches.
        device: device the batch tensors are moved to.
    """
    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top2 = AverageMeter('Acc@2', ':6.2f')
    progress = ProgressMeter(len(train_loader),
                             [batch_time, data_time, losses, top1, top2],
                             prefix="Epoch: [{}]".format(epoch))

    # Train mode: enables dropout and batch-statistics BatchNorm.
    pse_tae.train()

    end = time.time()
    for i, batch in enumerate(train_loader):
        data_time.update(time.time() - end)

        data = batch['data'].to(device)
        label = batch['label'].to(device)
        days = batch["dates"].to(device)

        # Forward pass and focal loss
        optimizer.zero_grad()
        output = pse_tae(data, days)
        loss = focal_loss(output.to(device), label)

        # Record stats (accuracy is computed without tracking gradients)
        acc1, acc2 = accuracy(output, label, topk=(1, 2))
        losses.update(loss.item(), data.size(0))
        top1.update(acc1[0], data.size(0))
        top2.update(acc2[0], data.size(0))

        # Backward pass and parameter update
        loss.backward()
        optimizer.step()

        batch_time.update(time.time() - end)
        end = time.time()

        if i % print_freq == 0:
            progress.display(i)
            step = epoch * len(train_loader) + i
            writer.add_scalar('train loss', loss.item(), step)
            # Tag name kept for continuity with earlier runs; this is the batch top-1 accuracy.
            writer.add_scalar('accuracy best', acc1, step)
            

    



        


def validate(val_loader, pse_tae, focal_loss, epoch, print_freq, device):
    """Evaluate ``pse_tae`` on ``val_loader`` in eval mode without gradient tracking.

    Every ``print_freq`` batches, prints a progress line and logs 'val loss' and
    'val accuracy' (batch top-1) to the module-level TensorBoard ``writer``.

    Returns:
        (sample-weighted mean top-1 accuracy in percent, sample-weighted mean loss)
    """
    batch_time = AverageMeter('Time', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    progress = ProgressMeter(len(val_loader), [batch_time, losses, top1, top5], prefix='Test: ')

    # Eval mode: disables dropout and uses running BatchNorm statistics.
    pse_tae.eval()

    with torch.no_grad():
        end = time.time()
        for i, val_batch in enumerate(val_loader):
            data_val = val_batch['data'].to(device)
            label_val = val_batch['label'].to(device)
            days_val = val_batch["dates"].to(device)

            output = pse_tae(data_val, days_val)
            loss = focal_loss(output.to(device), label_val)

            acc1, acc5 = accuracy(output, label_val, topk=(1, 5))
            n_samples = data_val.size(0)
            losses.update(loss.item(), n_samples)
            top1.update(acc1[0], n_samples)
            top5.update(acc5[0], n_samples)

            batch_time.update(time.time() - end)
            end = time.time()

            if i % print_freq == 0:
                progress.display(i)
                step = epoch * len(val_loader) + i
                writer.add_scalar('val loss', loss.item(), step)
                writer.add_scalar('val accuracy', acc1, step)

        print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'.format(top1=top1, top5=top5))

    return top1.avg, losses.avg


def save_checkpoint(state, is_best, filename=CP_name):
    """Save ``state`` to ``filename``; if ``is_best``, also copy it to ``best_<name>``.

    The copy is placed next to ``filename`` (same directory). The old
    ``'best_' + filename`` form produced a wrong path whenever ``filename``
    contained a directory, e.g. ``runs/ckpt.tar`` -> ``best_runs/ckpt.tar``.
    """
    torch.save(state, filename)
    if is_best:
        directory, base = os.path.split(filename)
        shutil.copyfile(filename, os.path.join(directory, 'best_' + base))

def load_checkpoint(model, optimizer, filename=CP_name, map_location=None):
    """Restore model and optimizer states from ``filename`` if that file exists.

    ``model`` and ``optimizer`` must already be constructed; only their states
    are updated. ``best_acc1`` stored in the checkpoint is not returned.

    Args:
        model: module whose ``state_dict`` is restored.
        optimizer: optimizer whose ``state_dict`` is restored.
        filename: checkpoint path.
        map_location: forwarded to ``torch.load``, e.g. ``"cpu"`` to load a
            checkpoint saved on a GPU machine. ``None`` keeps the default behavior.

    Returns:
        (model, optimizer, start_epoch). ``start_epoch`` is 0 when no checkpoint exists.
    """
    start_epoch = 0
    if not os.path.isfile(filename):
        print("=> no checkpoint found at '{}'".format(filename))
        return model, optimizer, start_epoch

    print("=> loading checkpoint '{}'".format(filename))
    # torch.load unpickles the file; only load checkpoints from trusted sources.
    checkpoint = torch.load(filename, map_location=map_location)
    start_epoch = checkpoint['epoch']
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])

    print("=> loaded checkpoint '{}' (epoch {})"
          .format(filename, checkpoint['epoch']))
    return model, optimizer, start_epoch

class AverageMeter(object):
    """Track the latest value and the count-weighted running average of a metric.

    ``str(meter)`` renders ``"<name> <val> (<avg>)"`` using the ``fmt`` spec,
    e.g. ``':6.3f'``.
    """

    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` observed over ``n`` samples."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        template = '{name} {val%s} ({avg%s})' % (self.fmt, self.fmt)
        return template.format(**vars(self))


class ProgressMeter(object):
    """Print one tab-separated progress line: ``<prefix>[batch/total]`` then each meter."""

    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        parts = [self.prefix + self.batch_fmtstr.format(batch)]
        parts.extend(map(str, self.meters))
        print('\t'.join(parts))

    def _get_batch_fmtstr(self, num_batches):
        """Build e.g. ``'[{:3d}/100]'`` so the batch counter is right-aligned to the total's width."""
        width = len(str(num_batches // 1))
        field = '{:%dd}' % width
        return '[{}/{}]'.format(field, field.format(num_batches))


def accuracy(output, target, topk=(1,)):
    """Return the top-k accuracy, in percent, for each k in ``topk``.

    Args:
        output: [batch, n_classes] scores.
        target: [batch] class indices.
        topk: the k values to evaluate.

    Returns:
        A list with one shape-(1,) tensor per k, in the order of ``topk``.
    """
    with torch.no_grad():
        largest_k = max(topk)
        n = target.size(0)

        _, top_idx = output.topk(largest_k, dim=1, largest=True, sorted=True)
        top_idx = top_idx.t()  # [largest_k, batch]
        hits = top_idx.eq(target.view(1, -1).expand_as(top_idx))

        return [hits[:k].reshape(-1).float().sum(0, keepdim=True) * (100.0 / n)
                for k in topk]

class EarlyStopper:
    """Signal early stopping based on validation loss.

    A new best (lowest) loss resets the counter. A loss that exceeds the best by
    more than ``min_delta`` increments it. Losses within ``min_delta`` of the
    best leave it unchanged. ``early_stop`` returns True once the counter
    reaches ``patience``.
    """

    def __init__(self, patience=10, min_delta=0.2):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.min_validation_loss = np.inf

    def early_stop(self, validation_loss):
        if validation_loss < self.min_validation_loss:
            self.min_validation_loss = validation_loss
            self.counter = 0
            return False
        if not (validation_loss > self.min_validation_loss + self.min_delta):
            return False
        self.counter += 1
        return self.counter >= self.patience

Actual Training

In [26]:



# Load the labelled partition and hold out one fifth of it for validation.
dataset = EuroCropsDataset(root=root, partition=partition, country=country)

fold_len = len(dataset) // 5
n_train = len(dataset) - fold_len
# Fixed generator seed -> the same train/validation split on every run (and on resume).
train_set, val_set = random_split(dataset, (n_train, fold_len),
                                  generator=torch.Generator().manual_seed(42))

train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False)

if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = "cpu"

pse_tae = PSE_TAE(device).to(device)

focal_loss = FocalLoss(gamma).to(device)

# Adam Optimizer
optimizer = torch.optim.Adam(pse_tae.parameters(), lr=learning_rate,
                             betas=(b1, b2), weight_decay=weight_decay)

# EarlyStopper counts epochs whose validation loss exceeds the best loss so far
# by more than min_delta (only a new best resets the count) and stops training
# once that count reaches `patience`.
early_stopper = EarlyStopper(patience=5, min_delta=0.2)

# Resume from CP_name if it exists. NOTE(review): best_acc1 is not restored from
# the checkpoint, so after a resume the first epoch is compared against the
# initial best_acc1 from the configuration cell.
pse_tae, optimizer, start_epoch = load_checkpoint(pse_tae, optimizer)


for epoch in range(start_epoch, n_epochs):

    # Training
    train(train_loader,
          pse_tae,
          focal_loss,
          optimizer,
          epoch,
          print_freq,
          device)

    # Validation
    acc1, val_loss = validate(val_loader,
                              pse_tae,
                              focal_loss,
                              epoch,
                              print_freq,
                              device)

    # Remember best acc@1 and save checkpoint. Saving happens before the
    # early-stopping check so the final epoch (possibly the best one) is kept.
    is_best = acc1 > best_acc1
    best_acc1 = max(acc1, best_acc1)

    save_checkpoint({
        'epoch': epoch + 1,
        'state_dict': pse_tae.state_dict(),
        'best_acc1': best_acc1,
        'optimizer': optimizer.state_dict(),
    }, is_best)

    if early_stopper.early_stop(val_loss):
        break

345970 parcels in file with 44 classes 
=> no checkpoint found at 'checkpoint.pth.tar'




Epoch: [0][   0/2163]	Time  1.982 ( 1.982)	Data  1.762 ( 1.762)	Loss 3.7913e+00 (3.7913e+00)	Acc@1   2.34 (  2.34)	Acc@2   3.91 (  3.91)
Epoch: [0][  50/2163]	Time  0.239 ( 0.281)	Data  0.183 ( 0.218)	Loss 3.7034e+00 (3.8050e+00)	Acc@1   7.03 (  3.74)	Acc@2  10.94 (  6.43)
Epoch: [0][ 100/2163]	Time  0.242 ( 0.263)	Data  0.184 ( 0.202)	Loss 3.7096e+00 (3.7637e+00)	Acc@1   5.47 (  4.56)	Acc@2   9.38 (  7.80)
Epoch: [0][ 150/2163]	Time  0.241 ( 0.257)	Data  0.184 ( 0.197)	Loss 3.6428e+00 (3.7221e+00)	Acc@1   5.47 (  5.74)	Acc@2  11.72 (  9.41)
Epoch: [0][ 200/2163]	Time  0.240 ( 0.254)	Data  0.184 ( 0.194)	Loss 3.5550e+00 (3.6899e+00)	Acc@1   7.81 (  6.72)	Acc@2  12.50 ( 10.73)
Epoch: [0][ 250/2163]	Time  0.242 ( 0.252)	Data  0.185 ( 0.192)	Loss 3.5681e+00 (3.6571e+00)	Acc@1  10.16 (  7.88)	Acc@2  16.41 ( 12.34)
Epoch: [0][ 300/2163]	Time  0.240 ( 0.251)	Data  0.184 ( 0.192)	Loss 3.4271e+00 (3.6268e+00)	Acc@1  19.53 (  9.09)	Acc@2  24.22 ( 13.87)
Epoch: [0][ 350/2163]	Time  0.257 ( 0.250