In [None]:
import torch
import torch.nn as nn
from sklearn.metrics import matthews_corrcoef, f1_score

import datetime
from termcolor import colored
import numpy as np
from tqdm import tqdm


## 1. class BASE
- metric 관련 함수

In [None]:
class BASE(nn.Module):
    '''
        Base module: keeps the run arguments on the instance and exposes the
        classification metrics (accuracy, f1, MCC) as static helpers.
    '''
    def __init__(self, args):
        super().__init__()
        self.args = args

        # cached tensor for speed (currently disabled)
        # self.I_way = nn.Parameter(torch.eye(self.args.way, dtype=torch.float),
        #                           requires_grad=False)

    @staticmethod
    def compute_acc(pred, true, dim=1, nomax=False):
        '''
            Fraction of predictions that equal the ground truth.
            @param pred: scores reduced with argmax along `dim`, or already
                predicted labels when nomax is True
            @param true: ground-truth labels
            @param dim: axis reduced by argmax (ignored when nomax is True)
            @param nomax: True when pred already holds labels
            @return: accuracy as a python float
        '''
        predicted = pred if nomax else torch.argmax(pred, dim=dim)
        hits = (predicted == true).float()
        return torch.mean(hits).item()

    @staticmethod
    def compute_f1(y_pred, true, dim=1, nomax=False,  labels=None, average='weighted'):
        '''
            f1 score computed by sklearn's f1_score (weighted by default).
            @param y_pred: scores reduced with max along `dim`, or already
                predicted labels when nomax is True
            @param true: ground-truth labels
            @param dim: axis reduced by max (ignored when nomax is True)
            @param nomax: True when y_pred already holds labels
            @param labels: forwarded to f1_score
            @param average: forwarded to f1_score
            @return: the value returned by f1_score
        '''
        if nomax:
            predicted = y_pred
        else:
            # index 1 of the (values, indices) pair returned by torch.max
            predicted = torch.max(y_pred, dim)[1]

        gold_np = true.cpu().detach().numpy()
        predicted_np = predicted.cpu().detach().numpy()
        return f1_score(gold_np, predicted_np, average=average, labels=labels)

    @staticmethod
    def compute_mcc(y_pred, true, dim=1, nomax=False):
        '''
            Matthews correlation coefficient computed by sklearn.
            @param y_pred: scores reduced with max along `dim`, or already
                predicted labels when nomax is True
            @param true: ground-truth labels
            @param dim: axis reduced by max (ignored when nomax is True)
            @param nomax: True when y_pred already holds labels
            @return: the value returned by matthews_corrcoef
        '''
        if nomax:
            predicted = y_pred
        else:
            # index 1 of the (values, indices) pair returned by torch.max
            predicted = torch.max(y_pred, dim)[1]

        gold_np = true.cpu().detach().numpy()
        predicted_np = predicted.cpu().detach().numpy()
        return matthews_corrcoef(gold_np, predicted_np)

## 2. Test 함수
- model output(=logits)과 label을 비교하여 accuracy, weighted F1, MCC를 계산 (cross-entropy loss는 계산하지 않음)

In [None]:
def test(test_data, model, args, verbose=True, target='val', loader=None):
    '''
        Evaluate the model on every batch yielded by `loader`.

        Each batch is scored with test_one; the per-batch accuracy, weighted
        f1 score and Matthews correlation coefficient are then reduced to
        their mean and standard deviation over the batches. An empty loader
        yields nan statistics (numpy emits a warning).

        @param test_data: not used by this function; kept so existing callers
            keep working
        @param model: mapping holding the 'ebd' and 'clf' modules; both are
            switched to eval mode
        @param args: run arguments; args.embedding and args.classifier
            (strings) are read when verbose is True, and args is forwarded to
            test_one
        @param verbose: print a coloured summary line and a LaTeX table row
        @param target: split name shown in the progress bar; test_one receives
            out=True only when it equals 'test'
        @param loader: sized iterable of batches (required despite the None
            default)
        @return: acc mean, acc std, f1 mean, f1 std, mcc mean, mcc std
        @raise ValueError: if loader is None
    '''
    if loader is None:
        # The None default used to fail later with an opaque AttributeError.
        raise ValueError("test() requires a data loader: pass loader=<iterable of batches>")

    model['ebd'].eval()
    model['clf'].eval()

    acc, f1, mcc = [], [], []

    progress = tqdm(loader,
                    desc=colored('Testing regular on %s' % (target), 'yellow'),
                    total=len(loader))
    for batch in progress:
        res_acc, res_f1, res_mcc = test_one(batch, model, args, out=(target=='test'))
        acc.append(res_acc)
        f1.append(res_f1)
        mcc.append(res_mcc)

    acc, f1, mcc = np.array(acc), np.array(f1), np.array(mcc)

    # Reduce once and reuse for both printouts and the return value.
    acc_mean, acc_std = np.mean(acc), np.std(acc)
    f1_mean, f1_std = np.mean(f1), np.std(f1)
    mcc_mean, mcc_std = np.mean(mcc), np.std(mcc)

    if verbose:
        # %y/%m/%d are already zero-padded; the former '%02y'-style width
        # modifiers are a platform-specific extension and are not portable.
        print("{}, {:s} {:>7.4f} ({:s} {:>7.4f}), {:s} {:>7.4f} ({:s} {:>7.4f}), {:s} {:>7.4f} ({:s} {:>7.4f})".format(
                datetime.datetime.now().strftime('%y/%m/%d %H:%M:%S'),
                colored("acc mean", "blue"),
                acc_mean,
                colored("std", "blue"),
                acc_std,
                colored("f1 mean", "blue"),
                f1_mean,
                colored("std", "blue"),
                f1_std,
                colored("mcc mean", "blue"),
                mcc_mean,
                colored("std", "blue"),
                mcc_std,
                ), flush=True)

        # latex table
        print("{:s} & {:s} & {:>7.4f} \\tiny $\\pm {:>7.4f}$ & {:>7.4f} \\tiny $\\pm {:>7.4f}$ & {:>7.4f} \\tiny $\\pm {:>7.4f}$".format(
                args.embedding.replace('_', '\\_'),
                args.classifier.replace('_', '\\_'),
                acc_mean,
                acc_std,
                f1_mean,
                f1_std,
                mcc_mean,
                mcc_std,
                ), flush=True)

    return acc_mean, acc_std, f1_mean, f1_std, mcc_mean, mcc_std


def test_one(batch, model, args, out=False):
    '''
        Score the model on a single batch.

        @param batch: mapping with 'text' and 'label' tensors; both entries
            are replaced in place by their copies on `device`
        @param model: mapping holding the 'ebd' and 'clf' modules
        @param args: run arguments; args.n_classes sets the width of the
            flattened classifier output
        @param out: accepted for interface compatibility; not used here
        @return: accuracy, weighted f1 score, Matthews correlation coefficient
    '''
    # Move the tensors onto the compute device (`device` is a file-level name).
    for key in ('text', 'label'):
        batch[key] = batch[key].to(device)

    # Embed the document: embedded input data (= calibration data)
    embedded = model['ebd'](batch)
    labels = batch['label']

    # No labels are handed to the classifier.
    raw_scores = model['clf'](embedded, YS=None)

    # Flatten so every row is one prediction over n_classes
    # (original author's example shapes: [32*35, 7] scores, [32*35] labels).
    flat_scores = raw_scores.view(-1, args.n_classes)
    flat_labels = labels.view(-1)

    return (
        BASE.compute_acc(flat_scores, flat_labels),
        BASE.compute_f1(flat_scores, flat_labels),
        BASE.compute_mcc(flat_scores, flat_labels),
    )

In [None]:
# test() returns (acc mean, acc std, f1 mean, f1 std, mcc mean, mcc std); keep
# only the accuracy statistics. Slicing instead of unpacking into `_` avoids
# rebinding IPython's last-output variable `_` in the notebook namespace.
test_acc, test_std = test(test_data, model, args, target='test', loader=test_loader)[:2]