In [1]:
import torch
from torch.utils.data import DataLoader
from model import SBERT
from trainer import SBERTFineTuner
from dataset import FinetuneDataset
import numpy as np
import random
import argparse

In [2]:
# Human-readable class names keyed by integer class id, for two label sets:
# a 4-class one (cl1) and a 9-class one (cl2).
# NOTE(review): "Permanant" is a misspelling of "Permanent"; it is kept as-is
# because these strings are emitted verbatim in reports — fix deliberately.
cl1_names = {
    0: 'Other',
    1: 'Summer Cereals',
    2: 'Winter Cereals',
    3: 'Permanant Grasslands',
}
cl2_names = {
    0: 'Ungrouped Crops',
    1: 'Oilseeds',
    2: 'Permanant Crops',
    3: 'Permanant Grasslands',
    4: 'Vegetables & Root Crops',
    5: 'Corn',
    6: 'Summer Other Cereals',
    7: 'Winter Wheat',
    8: 'Winter Other Cereals',
}

In [3]:
# CSV files for the three data splits.
# NOTE(review): absolute, machine-specific paths — consider a configurable data root.
train_new_file = '/home/pc4dl/SYM2/SITS/data/L2_Case_II_New/Train.csv'
valid_new_file = '/home/pc4dl/SYM2/SITS/data/L2_Case_II_New/Validate.csv'
test_new_file = '/home/pc4dl/SYM2/SITS/data/L2_Case_II_New/Test.csv'

# Every split is built with the same two trailing constructor arguments.
# Presumably (feature count, max sequence length) — TODO confirm against
# dataset.FinetuneDataset.
dataset_args = (10, 64)
train_new_dataset = FinetuneDataset(train_new_file, *dataset_args)
valid_new_dataset = FinetuneDataset(valid_new_file, *dataset_args)
test_new_dataset = FinetuneDataset(test_new_file, *dataset_args)

# Report the TS_num attribute of each split as its sample count.
sample_counts = (
    train_new_dataset.TS_num,
    valid_new_dataset.TS_num,
    test_new_dataset.TS_num,
)
print("training samples: %d, validation samples: %d, testing samples: %d" % sample_counts)

training samples: 84139, validation samples: 28378, testing samples: 54923


In [4]:
# All three loaders use identical settings: fixed order (no shuffling),
# batches of 128, and the final partial batch is kept.
loader_kwargs = dict(shuffle=False, batch_size=128, drop_last=False)

train_new_data_loader = DataLoader(train_new_dataset, **loader_kwargs)
valid_new_data_loader = DataLoader(valid_new_dataset, **loader_kwargs)
test_new_data_loader = DataLoader(test_new_dataset, **loader_kwargs)

In [5]:
# Instantiate the SBERT network. The leading positional 10 matches the value
# passed to FinetuneDataset — presumably the per-step feature count; TODO confirm
# against model.SBERT.
sbert = SBERT(
    10,
    hidden=256,
    n_layers=3,
    attn_heads=8,
    dropout=0.1,
)

In [6]:
# Wrap the network in the fine-tuning helper together with the train and
# validation loaders. The positional 9 equals the number of cl2_names entries —
# presumably the class count; TODO confirm against trainer.SBERTFineTuner.
trainer_new = SBERTFineTuner(
    sbert,
    9,
    train_dataloader=train_new_data_loader,
    valid_dataloader=valid_new_data_loader,
)

In [7]:
# Directory handed to the trainer's load(); the recorded cell output shows the
# checkpoint being read from "checkpoint.tar" inside it.
checkpoint_dir = '../../checkpoints/CP_SR_VI_L2_Case_II_finetune/'

print("Testing SITS-BERT...")
# Kept as the cell's last expression so the value returned by load() is displayed.
trainer_new.load(checkpoint_dir)

Testing SITS-BERT...
EP:3 Model loaded from: ../../checkpoints/CP_SR_VI_L2_Case_II_finetune/checkpoint.tar


'../../checkpoints/CP_SR_VI_L2_Case_II_finetune/checkpoint.tar'

In [8]:
# Run the trainer's test pass over the test loader and unpack its six results:
# three scalar scores, a matrix, and the two sequences (obs/pred) that later
# cells concatenate with torch.cat.
test_results = trainer_new.test(test_new_data_loader)
OA_new, Kappa_new, AA_new, matrix_new, obs_new, pred_new = test_results

print('New Predict Summary\n')
headline_scores = (OA_new, Kappa_new, AA_new)
print('test_OA = %.2f, test_kappa = %.3f, test_AA = %.3f' % headline_scores)

New Predict Summary

test_OA = 72.66, test_kappa = 0.649, test_AA = 0.612


In [9]:
def _concat_dim0(tensors):
    """Join a sequence of tensors into one tensor along dimension 0."""
    return torch.cat(tensors, dim=0)


# Collapse the sequences returned by the test pass into single tensors.
obs_new_all = _concat_dim0(obs_new)
pred_new_all = _concat_dim0(pred_new)

In [10]:
# Convert the concatenated tensors to NumPy arrays for scikit-learn / Counter.
# Tensor.numpy() raises for tensors that live on a CUDA device or that require
# grad, so detach and move to host memory first; both calls are no-ops for
# plain CPU tensors, so the resulting arrays are unchanged in that case.
obs_new_np = obs_new_all.detach().cpu().numpy()
pred_new_np = pred_new_all.detach().cpu().numpy()

In [11]:
from sklearn.metrics import confusion_matrix, accuracy_score, f1_score, classification_report
from collections import Counter



In [12]:
# Tally how often each class id occurs in obs_new_np (the array later passed
# as y_true to the scikit-learn metrics); displayed as the cell output.
Counter(obs_new_np)

Counter({5: 18670,
         8: 7153,
         1: 431,
         3: 17545,
         0: 3550,
         7: 3701,
         6: 1038,
         2: 1384,
         4: 1451})

In [13]:
# Tally how often each class id occurs in pred_new_np (the array later passed
# as y_pred to the scikit-learn metrics); displayed as the cell output.
Counter(pred_new_np)

Counter({5: 12027,
         8: 10039,
         1: 176,
         3: 19709,
         7: 1505,
         0: 1008,
         6: 3338,
         4: 5967,
         2: 1154})

In [14]:
# Persist the predicted class-id array as a .npy file.
# NOTE(review): absolute, machine-specific path — consider a configurable output dir.
pred_array_path = '/home/pc4dl/SYM2/SITS/data/Predict_Cases/Case_II_Class_L2_Test_New_PredArray.npy'
np.save(pred_array_path, pred_new_np)

In [15]:
# Print the confusion matrix and per-class report for the 9-class (cl2) test run.
#
# The class ids are passed explicitly as `labels` so that:
#   * every row/column of the confusion matrix and every report line is tied to
#     a fixed class id, even if some class is absent from both arrays, and
#   * `target_names` lines up with `labels` by construction instead of relying
#     on whichever classes happen to occur in the data.
# When all classes in cl2_names occur (as in the recorded run) the printed
# numbers are the same as without `labels`.
class_ids = list(cl2_names.keys())
class_labels = list(cl2_names.values())

print("=================================\n")
# Heading corrected from "CL1": this cell reports against cl2_names.
print('SITS Model with SR+VI CL2 TEST NEW (2017-2018)')
print(confusion_matrix(obs_new_np, pred_new_np, labels=class_ids))
print("\n")

print(classification_report(y_true=obs_new_np, y_pred=pred_new_np, labels=class_ids,
                            target_names=class_labels, digits=3))


print("=================================\n")


SITS Model with SR+VI CL1 TEST NEW (2017-2018)
[[  499     3    92  2308   172    75   271    16   114]
 [   12   170    10     1    22     0    43     3   170]
 [   54     0   878   310    21    23    62     1    35]
 [   60     0   120 16839    31    24   140    14   317]
 [  107     1    21    19  1133    48   102     4    16]
 [  253     1    26   147  4498 11828  1780    17   120]
 [   10     0     4    48    66    15   752    24   119]
 [    3     0     1     1     9     2   107  1120  2458]
 [   10     1     2    36    15    12    81   306  6690]]


                         precision    recall  f1-score   support

        Ungrouped Crops      0.495     0.141     0.219      3550
               Oilseeds      0.966     0.394     0.560       431
        Permanant Crops      0.761     0.634     0.692      1384
   Permanant Grasslands      0.854     0.960     0.904     17545
Vegetables & Root Crops      0.190     0.781     0.305      1451
                   Corn      0.983     0.634 