In [1]:
import torch
from torch.utils.data import DataLoader
from model import SBERT
from trainer import SBERTFineTuner
from dataset import FinetuneDataset
import numpy as np
import random
import argparse

In [2]:
cl1_names={ 0: 'Other', 1: 'Summer Cereals', 2: 'Winter Cereals', 3: 'Permanant Grasslands'}
cl2_names={0: 'Ungrouped Crops', 1: 'Oilseeds', 2: 'Permanant Crops', 3: 'Permanant Grasslands', 
          4: 'Vegetables & Root Crops', 5: 'Corn', 6: 'Summer Other Cereals', 7: 'Winter Wheat', 
          8: 'Winter Other Cereals'}

In [3]:
train_old_file = '/home/pc4dl/SYM2/SITS/data/L2_Case_II_Old/Train.csv'
valid_old_file = '/home/pc4dl/SYM2/SITS/data/L2_Case_II_Old/Validate.csv'
test_old_file = '/home/pc4dl/SYM2/SITS/data/L2_Case_II_Old/Test.csv'

train_old_dataset = FinetuneDataset(train_old_file, 10, 64)
valid_old_dataset = FinetuneDataset(valid_old_file, 10, 64)
test_old_dataset = FinetuneDataset(test_old_file, 10, 64)

print("training samples: %d, validation samples: %d, testing samples: %d" % (train_old_dataset.TS_num, valid_old_dataset.TS_num, test_old_dataset.TS_num))

training samples: 84139, validation samples: 28378, testing samples: 57458


In [4]:
train_old_data_loader = DataLoader(train_old_dataset, shuffle=False, batch_size=128, drop_last=False)
valid_old_data_loader = DataLoader(valid_old_dataset, shuffle=False, batch_size=128, drop_last=False)
test_old_data_loader = DataLoader(test_old_dataset, shuffle=False, batch_size=128, drop_last=False)

In [5]:
sbert = SBERT(10, hidden=256, n_layers=3, attn_heads=8, dropout=0.1)

In [6]:
trainer_old = SBERTFineTuner(sbert, 9, train_dataloader=train_old_data_loader, valid_dataloader=valid_old_data_loader)

In [7]:
print("Testing SITS-BERT...")
trainer_old.load('../../checkpoints/CP_SR_VI_L2_Case_II_finetune/')

Testing SITS-BERT...
EP:3 Model loaded from: ../../checkpoints/CP_SR_VI_L2_Case_II_finetune/checkpoint.tar


'../../checkpoints/CP_SR_VI_L2_Case_II_finetune/checkpoint.tar'

In [8]:
OA_old, Kappa_old, AA_old, matrix_old, obs_old, pred_old = trainer_old.test(test_old_data_loader)
print('Old Predict Summary\n')
print('test_OA = %.2f, test_kappa = %.3f, test_AA = %.3f' % (OA_old, Kappa_old, AA_old))

Old Predict Summary

test_OA = 77.03, test_kappa = 0.699, test_AA = 0.655


In [9]:
obs_old_all = torch.cat(obs_old,dim=0)
pred_old_all = torch.cat(pred_old,dim=0)

In [10]:
obs_old_np = obs_old_all.numpy()
pred_old_np = pred_old_all.numpy()

In [11]:
from sklearn.metrics import confusion_matrix, accuracy_score, f1_score, classification_report
from collections import Counter



In [12]:
Counter(obs_old_np)

Counter({2: 1371,
         3: 17786,
         5: 19725,
         8: 7095,
         7: 4049,
         0: 5204,
         4: 1205,
         6: 567,
         1: 456})

In [13]:
Counter(pred_old_np)

Counter({2: 2891,
         3: 19240,
         5: 18347,
         7: 2266,
         8: 8249,
         6: 2750,
         4: 2582,
         0: 910,
         1: 223})

In [14]:
np.save('/home/pc4dl/SYM2/SITS/data/Predict_Cases/Case_II_Class_L2_Test_Old_PredArray.npy', pred_old_np)

In [15]:
print("=================================\n")
print('SITS Model with SR+VI CL1 TEST OLD (2010-2011)')
print(confusion_matrix(obs_old_np, pred_old_np))
print("\n")

print(classification_report(y_true=obs_old_np, y_pred=pred_old_np, target_names=list(cl2_names.values()), digits=3))


print("=================================\n")


SITS Model with SR+VI CL1 TEST OLD (2010-2011)
[[  552    19   382  2900   489   292   352    88   130]
 [    8   201     1     1     5     7    15    11   207]
 [   26     0  1067   136    28    33    61     2    18]
 [  101     0   906 15985    44    68   515     6   161]
 [   49     1    46     5   951    81    65     2     5]
 [  101     0   437   128   994 17745   254    11    55]
 [   16     0    17     8    39    18   447     3    19]
 [   13     1     7     2    15    36   574  1529  1872]
 [   44     1    28    75    17    67   467   614  5782]]


                         precision    recall  f1-score   support

        Ungrouped Crops      0.607     0.106     0.181      5204
               Oilseeds      0.901     0.441     0.592       456
        Permanant Crops      0.369     0.778     0.501      1371
   Permanant Grasslands      0.831     0.899     0.863     17786
Vegetables & Root Crops      0.368     0.789     0.502      1205
                   Corn      0.967     0.900 