In [1]:
import argparse
import random
from pathlib import Path

import numpy as np
import torch
from torch.utils.data import DataLoader

from dataset import FinetuneDataset
from model import SBERT
from trainer import SBERTFineTuner

In [2]:
# Display names for the integer class labels.
# CL1 (4 classes) is the label set evaluated in this notebook: its keys 0-3 are the labels that
# appear in the test observations/predictions, and its values become the row names of the
# classification report.
cl1_names = {0: 'Other', 1: 'Summer Cereals', 2: 'Winter Cereals', 3: 'Permanent Grasslands'}
# CL2 (9 classes) is not used in the cells below; presumably the finer-grained label set of a
# companion experiment — confirm before relying on it.
cl2_names = {
    0: 'Ungrouped Crops',
    1: 'Oilseeds',
    2: 'Permanent Crops',
    3: 'Permanent Grasslands',
    4: 'Vegetables & Root Crops',
    5: 'Corn',
    6: 'Summer Other Cereals',
    7: 'Winter Wheat',
    8: 'Winter Other Cereals',
}

In [3]:
# Case I / L1 splits. All three CSVs live in one directory and every split is built with the
# same FinetuneDataset arguments, so each is defined exactly once here.
CASE_I_L1_DIR = '/home/pc4dl/SYM2/SITS/data/L1_Case_I_New'
# NOTE(review): presumably (features per time step, maximum sequence length) — confirm against
# FinetuneDataset's signature. The first value presumably has to agree with SBERT's first
# constructor argument (also 10 in this notebook).
DATASET_ARGS = (10, 64)

train_new_file = f'{CASE_I_L1_DIR}/Train.csv'
valid_new_file = f'{CASE_I_L1_DIR}/Validate.csv'
test_new_file = f'{CASE_I_L1_DIR}/Test.csv'

train_new_dataset = FinetuneDataset(train_new_file, *DATASET_ARGS)
valid_new_dataset = FinetuneDataset(valid_new_file, *DATASET_ARGS)
test_new_dataset = FinetuneDataset(test_new_file, *DATASET_ARGS)

print("training samples: %d, validation samples: %d, testing samples: %d" % (train_new_dataset.TS_num, valid_new_dataset.TS_num, test_new_dataset.TS_num))

training samples: 28176, validation samples: 84341, testing samples: 54923


In [4]:
# One deterministic loader per split: no shuffling, 128 samples per batch, and the trailing
# partial batch is kept so every sample is visited.
_loader_opts = {'shuffle': False, 'batch_size': 128, 'drop_last': False}
train_new_data_loader = DataLoader(train_new_dataset, **_loader_opts)
valid_new_data_loader = DataLoader(valid_new_dataset, **_loader_opts)
test_new_data_loader = DataLoader(test_new_dataset, **_loader_opts)

In [5]:
# Transformer backbone. These hyperparameters presumably have to match the ones the loaded
# checkpoint was trained with — confirm against the fine-tuning run's configuration.
sbert = SBERT(
    10,
    hidden=256,
    n_layers=3,
    attn_heads=8,
    dropout=0.1,
)

In [6]:
# Wrap the backbone with the fine-tuning/evaluation driver. The second positional argument is
# presumably the number of output classes (confirm against SBERTFineTuner); it is derived from
# cl1_names so it stays in step with the label names used in the classification report.
trainer_new = SBERTFineTuner(sbert, len(cl1_names), train_dataloader=train_new_data_loader, valid_dataloader=valid_new_data_loader)

In [7]:
# Restore the fine-tuned weights. load() already logs the epoch and the checkpoint file it read,
# so its return value (the checkpoint path) is bound to a name rather than echoed again as the
# cell's output.
print("Testing SITS-BERT...")
checkpoint_path = trainer_new.load('../../checkpoints/CP_SR_VI_L1_Case_I_finetune/')

Testing SITS-BERT...
EP:7 Model loaded from: ../../checkpoints/CP_SR_VI_L1_Case_I_finetune/checkpoint.tar


'../../checkpoints/CP_SR_VI_L1_Case_I_finetune/checkpoint.tar'

In [8]:
# Evaluate on the test loader. Judging by the names, test() yields overall accuracy, kappa,
# average accuracy and a confusion matrix, followed by the observed and predicted labels
# (presumably collected per batch — they are concatenated in a later cell).
test_results = trainer_new.test(test_new_data_loader)
OA_new, Kappa_new, AA_new, matrix_new, obs_new, pred_new = test_results

print('New Predict Summary\n')
summary = 'test_OA = %.2f, test_kappa = %.3f, test_AA = %.3f' % (OA_new, Kappa_new, AA_new)
print(summary)

New Predict Summary

test_OA = 78.12, test_kappa = 0.698, test_AA = 0.741


In [9]:
# Stitch the collected label/prediction tensors together along the first dimension.
obs_new_all, pred_new_all = (torch.cat(chunks, dim=0) for chunks in (obs_new, pred_new))

In [10]:
# Convert to numpy for sklearn. Tensor.numpy() refuses CUDA tensors and tensors that track
# gradients; detach() and cpu() are no-ops for plain CPU tensors, so this works whichever
# device test() produced its outputs on.
obs_new_np = obs_new_all.detach().cpu().numpy()
pred_new_np = pred_new_all.detach().cpu().numpy()

In [11]:
from sklearn.metrics import confusion_matrix, accuracy_score, f1_score, classification_report
from collections import Counter



In [12]:
# Frequency of each observed (ground-truth) label in the test set.
Counter(label for label in obs_new_np)

Counter({1: 19708, 2: 10854, 0: 6816, 3: 17545})

In [13]:
# Frequency of each predicted label, for comparison with the observed distribution above.
Counter(label for label in pred_new_np)

Counter({1: 13972, 2: 11350, 0: 8238, 3: 21363})

In [16]:
# Persist the test-set predictions for later analysis. The output directory is created first so
# the save does not fail on a machine where it does not exist yet.
pred_save_path = Path('/home/pc4dl/SYM2/SITS/data/Predict_Cases/Case_I_Class_L1_Test_New_PredArray.npy')
pred_save_path.parent.mkdir(parents=True, exist_ok=True)
np.save(pred_save_path, pred_new_np)

In [14]:
# Per-class evaluation of the CL1 test predictions. Labels are passed explicitly so that the
# confusion-matrix rows/columns and the report rows are in ascending label order with names
# looked up per label (instead of relying on dict insertion order), and so the 4x4 layout is
# kept even if some class never occurs in either obs or pred.
cl1_labels = sorted(cl1_names)
cl1_target_names = [cl1_names[label] for label in cl1_labels]

print("=================================\n")
print('SITS Model with SR+VI CL1 TEST NEW (2017-2018)')
print(confusion_matrix(obs_new_np, pred_new_np, labels=cl1_labels))
print("\n")

print(classification_report(y_true=obs_new_np, y_pred=pred_new_np, labels=cl1_labels, target_names=cl1_target_names, digits=3))


print("=================================\n")


SITS Model with SR+VI CL1 TEST NEW (2017-2018)
[[ 2795   370   436  3215]
 [ 5214 13068   764   662]
 [   31   496  9941   386]
 [  198    38   209 17100]]


                      precision    recall  f1-score   support

               Other      0.339     0.410     0.371      6816
      Summer Cereals      0.935     0.663     0.776     19708
      Winter Cereals      0.876     0.916     0.895     10854
Permanant Grasslands      0.800     0.975     0.879     17545

            accuracy                          0.781     54923
           macro avg      0.738     0.741     0.730     54923
        weighted avg      0.807     0.781     0.782     54923


