In [13]:
import sys
# Extend the module search path; presumably this is the project root that
# provides `precompute_vis` and the `analysis` package imported below.
# NOTE(review): absolute, machine-specific path — adjust when running elsewhere.
sys.path.append('/home/v-runmao/projects/DomShift-ATMF')

from importlib import reload

# Reload the project modules each time this cell runs so that edits made on
# disk are picked up without restarting the kernel.
import precompute_vis
reload(precompute_vis)

from analysis import helper
reload(helper)

# These names are imported after the reload above, so they bind to the
# freshly re-executed `analysis.helper` module.
from analysis.helper import get_model_from_root, get_pred_and_prob, do_for_all_algo_and_tgt
from analysis.helper import cmp_binary_pred_and_label, display_metrics_nicely
import pandas as pd
import numpy as np
import torch

In [16]:
@do_for_all_algo_and_tgt(['ERM'], 'ATMF')
def run(algorithm=None, target=None, seed=0):
    """Compare predictions on original text against predictions on word-shuffled text.

    Loads the rows of one target domain, shuffles the words of every text,
    runs the model on the shuffled texts, and prints metrics for four subsets
    (with / without negation, crossed with positive / negative label). Within
    each subset the metrics of the stored ``ERM_pred`` column are printed
    first, followed by the metrics of the shuffled-text predictions.

    Args:
        algorithm: Algorithm name used to build the checkpoint location.
        target: Single-letter domain code; one of 'A', 'T', 'M', 'F'.
        seed: Seed forwarded to the word shuffling.

    NOTE(review): ``algorithm`` and ``target`` are presumably filled in by the
    decorator for every algorithm/target combination — confirm in
    ``analysis.helper``.
    """
    frame = get_target_shuffled_text(target, seed)

    checkpoint_root = f'../../pt/all_best/{algorithm}_BEST'
    model = get_model_from_root(checkpoint_root, f'{algorithm}_tgt_{target}')
    model = model.to('cuda:0')
    model.eval()

    # Predict on the shuffled variant of every text.
    predictions, _ = get_pred_and_prob(model, frame['shuffled_text'].tolist())
    frame['shuffled_pred'] = predictions

    # Map the string labels of all three label-like columns to 0 / 1.
    label_to_int = {'positive': 1, 'negative': 0}
    label_columns = ['label', 'ERM_pred', 'shuffled_pred']
    frame.replace(to_replace=dict.fromkeys(label_columns, label_to_int), inplace=True)

    domain_names = {'A': 'Amazon', 'T': 'Twitter', 'M': 'MSN', 'F': 'Finance'}
    print('Target: {}'.format(domain_names[target]))

    # Only the very first metrics row of the table carries the header.
    show_header = True
    for negated in (True, False):  # rows with, then without, negations
        for sentiment in (1, 0):  # positive, then negative, gold label
            selected = frame[(frame['label'] == sentiment) & (frame['has_negation'] == negated)]
            gold = selected['label']

            original_metrics = cmp_binary_pred_and_label(selected['ERM_pred'], gold)
            shuffled_metrics = cmp_binary_pred_and_label(selected['shuffled_pred'], gold)

            display_metrics_nicely(original_metrics, header=show_header)
            display_metrics_nicely(shuffled_metrics, header=False)
            show_header = False


def get_target_shuffled_text(target, seed):
    """Load one domain's rows and attach a word-shuffled copy of each text.

    Reads ``./postprocessed.csv``, keeps the rows whose ``domain`` column
    equals ``target``, and adds a ``shuffled_text`` column in which the
    whitespace-separated tokens of ``text`` are randomly permuted and
    re-joined with single spaces.

    Args:
        target: Value to match against the ``domain`` column.
        seed: Seed for the ``numpy.random.RandomState`` driving the shuffle,
            so the same seed yields the same permutations.

    Returns:
        A copy of the matching rows with the extra ``shuffled_text`` column.
    """
    rng = np.random.RandomState(seed)

    def _scramble(text):
        # Permute the tokens in place, then rebuild the sentence.
        words = text.strip().split()
        rng.shuffle(words)
        return ' '.join(words)

    table = pd.read_csv('./postprocessed.csv')
    # NOTE(review): the original comment called this "only test data";
    # presumably the target domain serves as the test set — confirm.
    subset = table[table['domain'] == target].copy()

    # Rows are processed in order, so the generator is consumed deterministically.
    subset['shuffled_text'] = [_scramble(text) for text in subset['text']]
    return subset

The order of the **8** rows of each table is:
- has_negation & label=1 & not shuffled
- has_negation & label=1 & shuffled
- has_negation & label=0 & not shuffled
- has_negation & label=0 & shuffled
- no_negation & label=1 & not shuffled
- no_negation & label=1 & shuffled
- no_negation & label=0 & not shuffled
- no_negation & label=0 & shuffled

In [17]:
# Run the shuffled-text evaluation with shuffle seed 0. `algorithm` and `target`
# are not passed here; presumably the decorator on `run` supplies them for each
# algorithm/target combination — confirm in analysis.helper.
run(seed=0)

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=125.0), HTML(value='')))


Target: Amazon
----------------------------------------------------------------
N       Acc     Precision   Recall  F1      Specificity     
----------------------------------------------------------------
2516    0.881   1.000       0.881   0.937   -1.000          
2516    0.604   1.000       0.604   0.753   -1.000          
3345    0.810   0.000       -1.000  0.000   0.810           
3345    0.807   0.000       -1.000  0.000   0.807           
1484    0.965   1.000       0.965   0.982   -1.000          
1484    0.881   1.000       0.881   0.937   -1.000          
655     0.666   0.000       -1.000  0.000   0.666           
655     0.678   0.000       -1.000  0.000   0.678           


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=136.0), HTML(value='')))


Target: Twitter
----------------------------------------------------------------
N       Acc     Precision   Recall  F1      Specificity     
----------------------------------------------------------------
142     0.859   1.000       0.859   0.924   -1.000          
142     0.577   1.000       0.577   0.732   -1.000          
2294    0.854   0.000       -1.000  0.000   0.854           
2294    0.721   0.000       -1.000  0.000   0.721           
1660    0.965   1.000       0.965   0.982   -1.000          
1660    0.927   1.000       0.927   0.962   -1.000          
4557    0.722   0.000       -1.000  0.000   0.722           
4557    0.617   0.000       -1.000  0.000   0.617           


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=66.0), HTML(value='')))


Target: MSN
----------------------------------------------------------------
N       Acc     Precision   Recall  F1      Specificity     
----------------------------------------------------------------
231     0.688   1.000       0.688   0.815   -1.000          
231     0.481   1.000       0.481   0.649   -1.000          
136     0.926   0.000       -1.000  0.000   0.926           
136     0.882   0.000       -1.000  0.000   0.882           
2879    0.831   1.000       0.831   0.908   -1.000          
2879    0.730   1.000       0.730   0.844   -1.000          
973     0.836   0.000       -1.000  0.000   0.836           
973     0.823   0.000       -1.000  0.000   0.823           


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=31.0), HTML(value='')))


Target: Finance
----------------------------------------------------------------
N       Acc     Precision   Recall  F1      Specificity     
----------------------------------------------------------------
5       1.000   1.000       1.000   1.000   -1.000          
5       0.800   1.000       0.800   0.889   -1.000          
11      0.909   0.000       -1.000  0.000   0.909           
11      0.636   0.000       -1.000  0.000   0.636           
1358    0.951   1.000       0.951   0.975   -1.000          
1358    0.914   1.000       0.914   0.955   -1.000          
593     0.621   0.000       -1.000  0.000   0.621           
593     0.634   0.000       -1.000  0.000   0.634           
