In [1]:
import argparse
import os
import random

import numpy as np
import torch
from torch.utils.data import DataLoader

from model import SBERT
from trainer import SBERTFineTuner
from dataset import FinetuneDataset

In [2]:
# Class-id -> display-name lookup tables. cl2_names (9 classes) is the scheme
# used by this notebook's classification report; cl1_names (4 classes) is the
# coarser scheme, presumably level 1 of the same hierarchy.
# NOTE(review): 'Permanant' is a misspelling of 'Permanent'; it is kept as-is
# because these strings appear in reports and may be matched elsewhere.
cl1_names = dict(enumerate([
    'Other',
    'Summer Cereals',
    'Winter Cereals',
    'Permanant Grasslands',
]))
cl2_names = dict(enumerate([
    'Ungrouped Crops',
    'Oilseeds',
    'Permanant Crops',
    'Permanant Grasslands',
    'Vegetables & Root Crops',
    'Corn',
    'Summer Other Cereals',
    'Winter Wheat',
    'Winter Other Cereals',
]))

In [3]:
# All three "Old" splits live in the same directory; only the file name differs.
old_split_dir = '/home/pc4dl/SYM2/SITS/data/L2_Case_I_Old/'
train_old_file = old_split_dir + 'Train.csv'
valid_old_file = old_split_dir + 'Validate.csv'
test_old_file = old_split_dir + 'Test.csv'

# NOTE(review): the positional 10 matches SBERT(10, ...) below; 10 and 64 are
# presumably feature count and max sequence length -- confirm against FinetuneDataset.
train_old_dataset, valid_old_dataset, test_old_dataset = [
    FinetuneDataset(csv_path, 10, 64)
    for csv_path in (train_old_file, valid_old_file, test_old_file)
]

print("training samples: %d, validation samples: %d, testing samples: %d"
      % (train_old_dataset.TS_num, valid_old_dataset.TS_num, test_old_dataset.TS_num))

training samples: 28176, validation samples: 84341, testing samples: 57458


In [4]:
# Every loader uses identical settings; shuffle=False keeps batches in dataset
# order so the concatenated predictions line up with the dataset indices.
loader_kwargs = dict(shuffle=False, batch_size=128, drop_last=False)
train_old_data_loader, valid_old_data_loader, test_old_data_loader = (
    DataLoader(split_ds, **loader_kwargs)
    for split_ds in (train_old_dataset, valid_old_dataset, test_old_dataset)
)

In [5]:
# Encoder hyper-parameters; presumably these must match the configuration the
# fine-tuned checkpoint loaded below was trained with.
sbert_kwargs = dict(hidden=256, n_layers=3, attn_heads=8, dropout=0.1)
sbert = SBERT(10, **sbert_kwargs)

In [6]:
# The second argument was hard-coded as 9, which equals the number of entries
# in cl2_names (presumably the number of output classes). Deriving it from the
# label map keeps the classifier head and the report's class names in sync.
trainer_old = SBERTFineTuner(
    sbert,
    len(cl2_names),
    train_dataloader=train_old_data_loader,
    valid_dataloader=valid_old_data_loader,
)

In [7]:
old_checkpoint_dir = '../../checkpoints/CP_SR_VI_L2_Case_I_finetune/'
print("Testing SITS-BERT...")
# load() prints the restored epoch and returns the checkpoint file path,
# which Jupyter echoes as this cell's output.
trainer_old.load(old_checkpoint_dir)

Testing SITS-BERT...
EP:6 Model loaded from: ../../checkpoints/CP_SR_VI_L2_Case_I_finetune/checkpoint.tar


'../../checkpoints/CP_SR_VI_L2_Case_I_finetune/checkpoint.tar'

In [8]:
# Evaluate the loaded model on the test split. Besides the summary metrics,
# test() returns the observed and predicted labels (lists of tensors,
# concatenated in a later cell).
OA_old, Kappa_old, AA_old, matrix_old, obs_old, pred_old = trainer_old.test(test_old_data_loader)

summary_fmt = 'test_OA = %.2f, test_kappa = %.3f, test_AA = %.3f'
print('Old Predict Summary\n')
print(summary_fmt % (OA_old, Kappa_old, AA_old))

Old Predict Summary

test_OA = 75.84, test_kappa = 0.677, test_AA = 0.559


In [9]:
# Stitch the per-batch label tensors into one tensor each, along dim 0.
obs_old_all, pred_old_all = (torch.cat(batch_list, dim=0) for batch_list in (obs_old, pred_old))

In [10]:
# Tensor.numpy() only works on CPU tensors that do not require grad. The
# detach/cpu calls are no-ops for plain CPU tensors, and they keep this cell
# working if the trainer returns labels/predictions that are still on the GPU.
obs_old_np = obs_old_all.detach().cpu().numpy()
pred_old_np = pred_old_all.detach().cpu().numpy()

In [11]:
from sklearn.metrics import confusion_matrix, accuracy_score, f1_score, classification_report
from collections import Counter



In [12]:
# Per-class sample counts of the observed (ground-truth) test labels.
Counter(list(obs_old_np))

Counter({2: 1371,
         3: 17786,
         5: 19725,
         8: 7095,
         7: 4049,
         0: 5204,
         4: 1205,
         6: 567,
         1: 456})

In [13]:
# Per-class counts of the model's predicted test labels.
Counter(list(pred_old_np))

Counter({2: 1665,
         3: 22353,
         5: 18458,
         7: 4622,
         8: 5597,
         6: 647,
         4: 2738,
         0: 1217,
         1: 161})

In [14]:
# Persist the per-sample test predictions. np.save does not create missing
# parent directories, so make sure the output directory exists first.
pred_old_path = '/home/pc4dl/SYM2/SITS/data/Predict_Cases/Case_I_Class_L2_Test_Old_PredArray.npy'
os.makedirs(os.path.dirname(pred_old_path), exist_ok=True)
np.save(pred_old_path, pred_old_np)

In [15]:
# Confusion matrix and per-class report for the 9-class (cl2) scheme.
# Passing `labels` explicitly fixes the row/column order to cl2_names' keys,
# so target_names always line up even if some class is missing from both
# y_true and y_pred. Without it, classification_report raises or mislabels.
cl2_label_ids = list(cl2_names.keys())

print("=================================\n")
print('SITS Model with SR+VI CL2 TEST OLD (2010-2011)')
print(confusion_matrix(obs_old_np, pred_old_np, labels=cl2_label_ids))
print("\n")

print(classification_report(y_true=obs_old_np, y_pred=pred_old_np,
                            labels=cl2_label_ids,
                            target_names=list(cl2_names.values()), digits=3))


print("=================================\n")


SITS Model with SR+VI CL1 TEST OLD (2010-2011)
[[  520    11   120  3430   533   339    45   130    76]
 [    8   126     0     7    12     5     3    44   251]
 [   33     0   698   557    32    32    14     4     1]
 [   80     0    76 17474    27    65    23    25    16]
 [   70     1    61    10   938   105     9     8     3]
 [  242     5   605   404  1035 17026   148   233    27]
 [   92     0    30    22    54    41   184   133    11]
 [   66     1    11    62    29   164   101  2506  1109]
 [  106    17    64   387    78   681   120  1539  4103]]


                         precision    recall  f1-score   support

        Ungrouped Crops      0.427     0.100     0.162      5204
               Oilseeds      0.783     0.276     0.408       456
        Permanant Crops      0.419     0.509     0.460      1371
   Permanant Grasslands      0.782     0.982     0.871     17786
Vegetables & Root Crops      0.343     0.778     0.476      1205
                   Corn      0.922     0.863 