In [1]:
from scr.dataset import get_data_from_json, get_label_set

# Read every split from the JSON file; only the in-scope 'train'/'val'/'test'
# splits are used here (the printed key list also shows 'oos_*' splits).
dataDict = get_data_from_json('data/data_full.json')
trainList, valList, testList = (dataDict[split] for split in ('train', 'val', 'test'))

# Label vocabulary gathered from all three in-scope splits.
labelSet = get_label_set(trainList, valList, testList)

  from .autonotebook import tqdm as notebook_tqdm


The keys found in this json are:  dict_keys(['oos_val', 'val', 'train', 'oos_test', 'test', 'oos_train'])


In [2]:
from scr.dataset import SentenceLabelDataset

# Dataset over the training split; labelSet is the result of get_label_set
# over train/val/test, so every split shares one label vocabulary.
trainSet = SentenceLabelDataset(
    trainList,
    labelSet,
)

In [3]:
from torch.utils.data import DataLoader
from scr.dataset import collate_dynamic_padding

import torch

# Loader settings pulled out of the call so they are easy to find and tune.
BATCH_SIZE = 8
NUM_WORKERS = 2
SEED = 0  # seeds the shuffle so Restart & Run All reproduces the same batch order

dataloader = DataLoader(trainSet,
                        batch_size=BATCH_SIZE,
                        num_workers=NUM_WORKERS,
                        shuffle=True,
                        collate_fn=collate_dynamic_padding,
                        # Pinned memory only speeds up host->GPU copies; on CPU-only
                        # machines PyTorch warns that it has no effect, so skip it there.
                        pin_memory=torch.cuda.is_available(),
                        generator=torch.Generator().manual_seed(SEED))

In [2]:
from scr.evaluate import MultiLabelEvaluator
import torch

# Hand-checkable toy example: 4 samples x 3 classes of binary labels.
targets = torch.as_tensor([[1, 0, 1], [1, 0, 0], [1, 1, 0], [1, 1, 1]])
preds = torch.as_tensor([[1, 0, 0], [0, 1, 0], [1, 1, 0], [1, 1, 0]])

# Hard 0/1 predictions are passed as `probs`.
# NOTE(review): confirm how MultiLabelEvaluator thresholds `probs`; its truePositive
# matrix equals `targets & preds` on this example, so hard labels appear to pass through.
evaluator = MultiLabelEvaluator(probs=preds, targets=targets)

print('Target matrix is: \n', targets)
print('Prediction matrix is: \n', preds)

# NOTE(review): on this example the per-class "accuracy" equals TP / |target OR pred|
# (3/4, 2/3, 0/2), i.e. a Jaccard score, not (TP + TN) / N (which would be 3/4, 3/4, 2/4).
# Confirm the intended definition in scr.evaluate.
accuracies, avgAcc = evaluator.get_accuracy()
print('Accuracies of each class is: ', accuracies)
print('Average Accuracy: ', avgAcc)

# Class 2 is never predicted, so its precision is 0/0 = nan.
precisions, avgPrecision = evaluator.get_precision()
print('Precision of each class is: ', precisions)
print('Average Precision: ', avgPrecision)

recalls, avgRecall = evaluator.get_recall()
print('Recall of each class is: ', recalls)
print('Average Recall: ', avgRecall)

# NOTE(review): the averages match micro-averaging over all cells (total TP = 5):
# accuracy 5/9, precision 5/6, recall 5/8. The mean of the per-class recalls would be
# 0.583, not 0.625 -- confirm micro-averaging is intended.

# Pin the hand-computed per-class precision/recall so a regression in scr.evaluate fails loudly.
torch.testing.assert_close(precisions, torch.tensor([1.0, 2 / 3, float('nan')]),
                           check_dtype=False, equal_nan=True, rtol=0, atol=1e-4)
torch.testing.assert_close(recalls, torch.tensor([0.75, 1.0, 0.0]),
                           check_dtype=False, rtol=0, atol=1e-4)

Target matrix is: 
 tensor([[1, 0, 1],
        [1, 0, 0],
        [1, 1, 0],
        [1, 1, 1]])
Prediction matrix is: 
 tensor([[1, 0, 0],
        [0, 1, 0],
        [1, 1, 0],
        [1, 1, 0]])
Accuracies of each class is:  tensor([0.7500, 0.6667, 0.0000])
Average Accuracy:  0.5555555820465088
Precision of each class is:  tensor([1.0000, 0.6667,    nan])
Average Accuracy:  tensor(0.8333)
Recall of each class is:  tensor([0.7500, 1.0000, 0.0000])
Average Recall:  tensor(0.6250)


In [3]:
# Inspect the evaluator's true-positive indicator matrix; the output below equals
# `targets & preds` elementwise for the toy example.
evaluator.truePositive

tensor([[1, 0, 0],
        [0, 0, 0],
        [1, 1, 0],
        [1, 1, 0]], dtype=torch.int32)

In [11]:
# Keep predictions only where they are NOT true positives, i.e. the false positives
# (truePositive was displayed as `targets & preds` for this example).
torch.where(evaluator.truePositive.eq(0), preds, 0)

tensor([[0, 0, 0],
        [0, 1, 0],
        [0, 0, 0],
        [0, 0, 0]])

In [4]:
# NOTE(review): despite the name, the output below equals `targets | preds` elementwise
# (e.g. row 1 is [1, 1, 0], while targets row 1 is [1, 0, 0]), i.e. the union of
# target and predicted positives -- confirm the intended meaning in scr.evaluate.
evaluator.positive

tensor([[1, 0, 1],
        [1, 1, 0],
        [1, 1, 0],
        [1, 1, 1]], dtype=torch.int32)

In [46]:
# Show the prediction matrix and its class-2 column.
# NOTE(review): the recorded output below does not match `preds` as defined in this
# notebook (different values, dtype int32); it is presumably left over from an edited
# or deleted cell. Re-run from a fresh kernel to refresh it.
print(preds, preds[:, 2], sep='\n')


tensor([[1, 1, 0],
        [0, 0, 1],
        [0, 0, 1],
        [0, 0, 1]], dtype=torch.int32)
tensor([0, 1, 1, 1], dtype=torch.int32)
