In [1]:
import torch
from torch.utils.data import DataLoader
from model import SBERT
from trainer import SBERTFineTuner
from dataset import FinetuneDataset
import numpy as np
import random
import argparse

In [2]:
cl1_names={ 0: 'Other', 1: 'Summer Cereals', 2: 'Winter Cereals', 3: 'Permanant Grasslands'}
cl2_names={0: 'Ungrouped Crops', 1: 'Oilseeds', 2: 'Permanant Crops', 3: 'Permanant Grasslands', 
          4: 'Vegetables & Root Crops', 5: 'Corn', 6: 'Summer Other Cereals', 7: 'Winter Wheat', 
          8: 'Winter Other Cereals'}

In [3]:
train_old_file = '/home/pc4dl/SYM2/SITS/data/L1_Case_I_Old/Train.csv'
valid_old_file = '/home/pc4dl/SYM2/SITS/data/L1_Case_I_Old/Validate.csv'
test_old_file = '/home/pc4dl/SYM2/SITS/data/L1_Case_I_Old/Test.csv'

train_old_dataset = FinetuneDataset(train_old_file, 10, 64)
valid_old_dataset = FinetuneDataset(valid_old_file, 10, 64)
test_old_dataset = FinetuneDataset(test_old_file, 10, 64)

print("training samples: %d, validation samples: %d, testing samples: %d" % (train_old_dataset.TS_num, valid_old_dataset.TS_num, test_old_dataset.TS_num))

training samples: 28176, validation samples: 84341, testing samples: 57458


In [4]:
train_old_data_loader = DataLoader(train_old_dataset, shuffle=False, batch_size=128, drop_last=False)
valid_old_data_loader = DataLoader(valid_old_dataset, shuffle=False, batch_size=128, drop_last=False)
test_old_data_loader = DataLoader(test_old_dataset, shuffle=False, batch_size=128, drop_last=False)

In [5]:
sbert = SBERT(10, hidden=256, n_layers=3, attn_heads=8, dropout=0.1)

In [6]:
trainer_old = SBERTFineTuner(sbert, 4, train_dataloader=train_old_data_loader, valid_dataloader=valid_old_data_loader)

In [7]:
print("Testing SITS-BERT...")
trainer_old.load('../../checkpoints/CP_SR_VI_L1_Case_I_finetune/')

Testing SITS-BERT...
EP:7 Model loaded from: ../../checkpoints/CP_SR_VI_L1_Case_I_finetune/checkpoint.tar


'../../checkpoints/CP_SR_VI_L1_Case_I_finetune/checkpoint.tar'

In [8]:
OA_old, Kappa_old, AA_old, matrix_old, obs_old, pred_old = trainer_old.test(test_old_data_loader)
print('Old Predict Summary\n')
print('test_OA = %.2f, test_kappa = %.3f, test_AA = %.3f' % (OA_old, Kappa_old, AA_old))

Old Predict Summary

test_OA = 83.21, test_kappa = 0.766, test_AA = 0.781


In [9]:
obs_old_all = torch.cat(obs_old,dim=0)
pred_old_all = torch.cat(pred_old,dim=0)

In [10]:
obs_old_np = obs_old_all.numpy()
pred_old_np = pred_old_all.numpy()

In [11]:
from sklearn.metrics import confusion_matrix, accuracy_score, f1_score, classification_report
from collections import Counter



In [12]:
Counter(obs_old_np)

Counter({0: 8236, 3: 17786, 1: 20292, 2: 11144})

In [13]:
Counter(pred_old_np)

Counter({0: 6336, 3: 20773, 1: 18791, 2: 11558})

In [14]:
np.save('/home/pc4dl/SYM2/SITS/data/Predict_Cases/Case_I_Class_L1_Test_Old_PredArray.npy', pred_old_np)

In [15]:
print("=================================\n")
print('SITS Model with SR+VI CL1 TEST OLD (2010-2011)')
print(confusion_matrix(obs_old_np, pred_old_np))
print("\n")

print(classification_report(y_true=obs_old_np, y_pred=pred_old_np, target_names=list(cl1_names.values()), digits=3))


print("=================================\n")


SITS Model with SR+VI CL1 TEST OLD (2010-2011)
[[ 3169   763   767  3537]
 [ 2627 17271   232   162]
 [   88   534 10408   114]
 [  452   223   151 16960]]


                      precision    recall  f1-score   support

               Other      0.500     0.385     0.435      8236
      Summer Cereals      0.919     0.851     0.884     20292
      Winter Cereals      0.901     0.934     0.917     11144
Permanant Grasslands      0.816     0.954     0.880     17786

            accuracy                          0.832     57458
           macro avg      0.784     0.781     0.779     57458
        weighted avg      0.824     0.832     0.825     57458


