In [None]:
# libraries from finetuning_parameters.py
from finetuning_parameters import get_args
from future.baseline_trainer import BaselineTuner
from future.modules import ptl2classes
from future.hooks import EvaluationRecorder

from data_loader.wrap_sampler import wrap_sampler
import data_loader.task_configs as task_configs
import data_loader.data_configs as data_configs
from future.collocate_fns import task2collocate_fn

import utils.checkpoint as checkpoint
import utils.logging as logging

import torch
import random
import os

# libraries from future/base.py
from torch.utils.data import SequentialSampler, RandomSampler
from future.hooks import EvaluationRecorder
import utils.eval_meters as eval_meters
from seqeval.metrics import f1_score as f1_score_tagging
import torch

# libraries from future/baseline_trainer.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from copy import deepcopy
from future.base import BaseTrainer
from future.hooks.base_hook import HookContainer
from future.hooks import EvaluationRecorder
from torch.utils.data import RandomSampler
from collections import defaultdict, Counter
from tqdm import tqdm

# and so on..
from finetuning_baseline import init_config, init_task, init_hooks

In [None]:
# Build the run configuration from the project's default arguments
# (an empty argument list is parsed, so nothing is read from the command line).
parser = get_args()
conf = parser.parse_args(args=[])

conf.dataset_name = 'xnli'
conf.trn_languages = 'english'

# Comma-separated evaluation languages for each supported dataset.
eval_languages_by_dataset = {
    'pawsx': 'english,german,chinese,french,japanese,korean,spanish',
    'xnli': 'english,arabic,bulgarian,chinese,french,german,greek,hindi,russian,spanish,swahili,thai,turkish,urdu,vietnamese',
    'marc': 'english,german,chinese,french,japanese,spanish',
    'mldoc': 'english,german,french,spanish,italian,russian,chinese,japanese',
    'cls': 'english,german,french,japanese',
}
# A dataset that is not in the table leaves conf.eval_languages untouched.
if conf.dataset_name in eval_languages_by_dataset:
    conf.eval_languages = eval_languages_by_dataset[conf.dataset_name]

# Fine-tuning schedule.
conf.finetune_epochs = 10
conf.finetune_batch_size = 256
conf.finetune_lr = 1e-5
conf.eval_every_batch = 50

# Run-time switches.
conf.override = False
conf.train_fast = False
conf.world = '0'
conf.use_cache = True

In [None]:
# Turn on the evaluation flag that matches the configured dataset;
# any other dataset name is rejected.
if conf.dataset_name in ('pawsx', 'xnli'):
    conf.trans_test = True
elif conf.dataset_name in ('cls', 'marc', 'mldoc'):
    conf.test_mt = True
else:
    raise ValueError

# Finalise the configuration, then build the model, tokenizer, datasets,
# metric name and batch-collation function for the task.
init_config(conf)
(model, tokenizer, data_iter, metric_name, collocate_batch_fn) = init_task(conf)

# Show the language -> dataset mapping returned by init_task.
data_iter

In [None]:
# One sampler-wrapped loader bundle per language.
# NOTE (from the original author): the sample datasets are referenced here
# -- presumably meaning they are not copied; TODO confirm.
adapt_loaders = {
    language: wrap_sampler(
        trn_batch_size=conf.finetune_batch_size,
        infer_batch_size=conf.inference_batch_size,
        language=language,
        language_dataset=language_dataset,
    )
    for language, language_dataset in data_iter.items()
}

# Hooks are built from the configuration and the task's metric name.
hooks = init_hooks(conf, metric_name)

In [None]:
# The tuner receives the configuration, the task's batch-collation function,
# the configured logger and a cross-entropy criterion.
trainer = BaselineTuner(
    conf,
    collocate_batch_fn=collocate_batch_fn,
    logger=conf.logger,
    criterion=nn.CrossEntropyLoss(),
)

# Show the evaluation languages as stored on the trainer's configuration.
trainer.conf.eval_languages

In [None]:
def get_model(model, base_dir, dataset_name, training_tgt, experience_name, choice):
    """Load saved fine-tuned weights into ``model`` and return it.

    The checkpoint file is looked up at::

        <base_dir>/<dataset_name>/<dataset_name>_<experience_name>/<run>/state_dicts/<choice>_state.pt

    where ``<run>`` is the first entry of the experiment directory in sorted
    order (``os.listdir`` alone returns entries in arbitrary order, which
    would make the selection non-deterministic when several runs exist).

    Args:
        model: Module that receives the weights via ``load_state_dict``
            (called with ``strict=True``).
        base_dir: Root directory that contains one sub-directory per dataset.
        dataset_name: Dataset name; used both as a directory name and as the
            prefix of the experiment directory.
        training_tgt: Currently unused. Earlier directory layouts included it
            in the experiment directory name; it is kept so existing callers
            keep working.
        experience_name: Experiment name. When it contains ``'supcon'`` only
            the entries whose key starts with ``'encoder.'`` are kept and the
            prefix is stripped before loading.
        choice: ``'best'`` (the file holds a dict whose ``'best_state_dict'``
            entry is the state dict) or ``'last'`` (the file is the state
            dict itself).

    Returns:
        The same ``model`` object, with the checkpoint weights loaded.

    Raises:
        ValueError: If ``choice`` is neither ``'best'`` nor ``'last'``.
        FileNotFoundError: If the experiment directory is missing or empty.
    """
    # Fail early with a clear message instead of hitting an unbound local later.
    if choice not in ('best', 'last'):
        raise ValueError("choice must be 'best' or 'last', got {!r}".format(choice))

    experience_dir = os.path.join(
        base_dir, dataset_name, '{}_{}'.format(dataset_name, experience_name)
    )

    run_names = sorted(os.listdir(experience_dir))
    if not run_names:
        raise FileNotFoundError(
            'no run directory found under {}'.format(experience_dir)
        )

    model_dir = os.path.join(
        experience_dir, run_names[0], 'state_dicts', '{}_state.pt'.format(choice)
    )

    # NOTE: torch.load unpickles the file; only load checkpoints from a
    # trusted source.
    state_dict = torch.load(model_dir, map_location='cpu')
    if choice == 'best':
        state_dict = state_dict['best_state_dict']

    if 'supcon' in experience_name:
        # Keep only the encoder weights and drop the 'encoder.' prefix so the
        # keys match the bare model.
        prefix = 'encoder.'
        state_dict = {
            key[len(prefix):]: value
            for key, value in state_dict.items()
            if key.startswith(prefix)
        }

    print (model_dir)
    model.load_state_dict(state_dict, strict=True)
    return model

# Performance

In [None]:
# Which checkpoint to evaluate.
base_dir = '/data/FSXLT_dataset/checkpoint_baseline/'  # root of the saved checkpoints
dataset_name = 'xnli'
training_tgt = 'mt'          # 'mt' or 'bt'; not part of the directory layout get_model currently uses
experience_name = 'supcon'   # one of: ce, mixup, supcon, supcon_mixup
choice = 'best'              # 'best' or 'last'

# Load the selected weights into the model built by init_task.
model = get_model(
    model=model,
    base_dir=base_dir,
    dataset_name=dataset_name,
    training_tgt=training_tgt,
    experience_name=experience_name,
    choice=choice,
)

In [None]:
# Prepare the loaded model for inference (mirrors the setup done in trainer.train).
# `opt` is returned by the trainer but not used in this cell.
opt, model = trainer._init_model_opt(model)
device = 'cuda:3'  # hard-coded GPU index; adjust for the machine at hand
trainer.model = model
trainer.model.to(device)
trainer.model.eval()

test_method = 'ensemble_prob' # ensemble_logit, ensemble_prob
# Mixing weight: 0.0 -> only target (original) / 1.0 -> only translated source (translate-test)
alpha = 0.5


def _ensemble_language_accuracy(trainer, loader, language, device, test_method, alpha):
    """Return the test accuracy (in percent, as a tensor) for one language.

    English batches are scored directly. For every other language each
    example comes in two versions stacked along dim 1; both are run through
    the model in a single forward pass and their outputs are blended:

    * ``'ensemble_logit'``: ``alpha * src_logits + (1 - alpha) * tgt_logits``
    * ``'ensemble_prob'``: the same blend applied to softmax probabilities

    Following the original variable names, version 0 is treated as the
    target-language input and version 1 as the translated source.
    NOTE(review): the disabled copy of this logic in the entropy cell names
    the two halves the other way round -- confirm which order the collate
    function actually produces (it only matters when ``alpha != 0.5``).

    Args:
        trainer: Object providing ``collocate_batch_fn`` and ``model``.
        loader: Batch loader; ``len(iter(loader))`` bounds the number of batches read.
        language: Language name; ``'english'`` disables the ensembling.
        device: Device the batch tensors are moved to.
        test_method: ``'ensemble_logit'`` or ``'ensemble_prob'``.
        alpha: Weight of the translated-source prediction.

    Raises:
        ValueError: If ``test_method`` is not one of the two supported values.
        RuntimeError: If the loader yields no examples.
    """
    # Validate up front so a typo fails immediately instead of after the
    # first (English) language has already been evaluated.
    if test_method not in ('ensemble_logit', 'ensemble_prob'):
        raise ValueError(
            "test_method must be 'ensemble_logit' or 'ensemble_prob', got {!r}".format(test_method)
        )

    is_paired = language != 'english'
    correct = 0.
    total = 0.

    batch_iter = iter(loader)
    # Read at most as many batches as the iterator reports, stopping early
    # if it is exhausted sooner.
    for _ in range(len(batch_iter)):
        try:
            batched = next(batch_iter)
        except StopIteration:
            break
        batched, golds, uids, _golds_tagging = trainer.collocate_batch_fn(batched)

        with torch.no_grad():
            if is_paired:
                # Stack both versions along the batch axis: rows [:bsz] are
                # version 0, rows [bsz:] are version 1.
                bsz = golds.size(0)
                for k in batched.keys():
                    batched[k] = torch.cat([batched[k][:, 0], batched[k][:, 1]], dim=0)
                golds = golds[:, 0]

            for k, v in batched.items():
                batched[k] = v.to(device)
            golds = golds.to(device)
            scores = trainer.model(**batched)[0]

            if is_paired:
                tgt_scores = scores[:bsz]
                src_scores = scores[bsz:]
                if test_method == 'ensemble_logit':
                    scores = alpha * src_scores + (1 - alpha) * tgt_scores
                else:  # 'ensemble_prob'
                    scores = (alpha * F.softmax(src_scores, dim=1)
                              + (1 - alpha) * F.softmax(tgt_scores, dim=1))

            correct += (scores.max(1)[1] == golds).sum()
            total += len(scores)

    if total == 0:
        raise RuntimeError("no test examples were read for language {!r}".format(language))
    return correct / total * 100


lang_acc_lst = []
for language in trainer.conf.eval_languages:
    print ("Start language-{}".format(language))
    language_accuracy = _ensemble_language_accuracy(
        trainer, adapt_loaders[language].tst_egs, language, device, test_method, alpha
    )
    print (language_accuracy)
    lang_acc_lst.append(language_accuracy.item())

# Per-language accuracies followed by their mean, rounded to one decimal.
print (np.round(lang_acc_lst, 1), np.round(np.mean(lang_acc_lst), 1))

# Entropy

In [None]:
import matplotlib.pyplot as plt

In [None]:
# Accumulators filled by the entropy evaluation below: one per-language
# entropy dictionary and one per-language accuracy list per evaluation pass.
dict_list, acc_list = [], []

In [None]:
# Which checkpoint to analyse in the entropy study.
base_dir = '../saved/'         # root of the saved checkpoints
dataset_name = 'mldoc'
training_tgt = 'bt'            # 'mt' or 'bt'; not part of the directory layout get_model currently uses
experience_name = 'baseline'
choice = 'best'                # 'best' or 'last'

# Load the selected weights into the existing model object.
model = get_model(
    model=model,
    base_dir=base_dir,
    dataset_name=dataset_name,
    training_tgt=training_tgt,
    experience_name=experience_name,
    choice=choice,
)

In [None]:
# Prepare the loaded model for inference (mirrors the setup done in trainer.train).
# `opt` is returned by the trainer but not used in this cell.
opt, model = trainer._init_model_opt(model)
device = 'cuda:3'  # hard-coded GPU index; adjust for the machine at hand
trainer.model = model
trainer.model.to(device)
trainer.model.eval()


def _entropy_and_accuracy(trainer, loaders, device):
    """Collect per-example predictive entropy and accuracy for every evaluation language.

    Each batch is fed to the model as collated; unlike the accuracy cell, no
    source/target pairing or ensembling is applied here. The entropy of an
    example is ``-sum_c p_c * log p_c`` over the softmax of the model's
    first output along dim 1.

    Args:
        trainer: Object providing ``conf.eval_languages``, ``collocate_batch_fn``
            and ``model``.
        loaders: Mapping from language to an object whose ``tst_egs``
            attribute is the test batch loader.
        device: Device the batch tensors are moved to.

    Returns:
        ``(entropy_by_language, accuracies)``: a dict mapping each language to
        the list of per-example entropies, and a list of accuracies (percent)
        in the order of ``trainer.conf.eval_languages``.

    Raises:
        RuntimeError: If a language's loader yields no examples.
    """
    entropy_by_language = {}
    accuracies = []
    for language in trainer.conf.eval_languages:
        print ("Start language-{}".format(language))
        entropy_by_language[language] = []

        correct = 0.
        total = 0.

        batch_iter = iter(loaders[language].tst_egs)
        # Read at most as many batches as the iterator reports, stopping
        # early if it is exhausted sooner.
        for _ in range(len(batch_iter)):
            try:
                batched = next(batch_iter)
            except StopIteration:
                break
            batched, golds, uids, _golds_tagging = trainer.collocate_batch_fn(batched)

            with torch.no_grad():
                for k, v in batched.items():
                    batched[k] = v.to(device)
                golds = golds.to(device)
                scores = trainer.model(**batched)[0]

                entropies = F.softmax(scores, dim=1) * F.log_softmax(scores, dim=1)
                entropies = -torch.sum(entropies, dim=1)
                entropy_by_language[language] += entropies.tolist()

                correct += (scores.max(1)[1] == golds).sum()
                total += len(scores)

        if total == 0:
            raise RuntimeError("no test examples were read for language {!r}".format(language))
        accuracies.append((correct / total * 100).item())
    return entropy_by_language, accuracies


# Pass 1: the loaders built earlier in this notebook.
entropy_dict, lang_acc_lst = _entropy_and_accuracy(trainer, adapt_loaders, device)
print (np.round(lang_acc_lst, 1), np.round(np.mean(lang_acc_lst), 1))
dict_list.append(entropy_dict)
acc_list.append(lang_acc_lst)

# Pass 2: the `bt_adapt_loaders` loaders.
# NOTE(review): `bt_adapt_loaders` is not defined anywhere in the visible part
# of this notebook; it must exist (presumably built like `adapt_loaders`, from
# back-translated data -- confirm) before this cell can run to completion.
entropy_dict, lang_acc_lst = _entropy_and_accuracy(trainer, bt_adapt_loaders, device)
print (np.round(lang_acc_lst, 1), np.round(np.mean(lang_acc_lst), 1))
dict_list.append(entropy_dict)
acc_list.append(lang_acc_lst)

In [None]:
# Summarise the collected entropies: mean and standard deviation per
# evaluation language for each of the four conditions.
#
# dict_list is expected to hold four entropy dictionaries, in this order:
#   [0] OR-OR, [1] OR-BT, [2] BT-OR, [3] BT-BT
# Each execution of the entropy cell appends two of them (first the
# `adapt_loaders` pass, then the `bt_adapt_loaders` pass), so that cell has to
# be executed twice -- presumably once per checkpoint variant (OR / BT);
# confirm -- before this one.
if len(dict_list) < 4:
    raise IndexError(
        "dict_list holds {} entropy dictionaries but 4 are required "
        "(OR-OR, OR-BT, BT-OR, BT-BT); each execution of the entropy cell "
        "appends two".format(len(dict_list))
    )

OR_OR_ent = []
OR_OR_std = []
OR_BT_ent = []
OR_BT_std = []
BT_OR_ent = []
BT_OR_std = []
BT_BT_ent = []
BT_BT_std = []

for language in trainer.conf.eval_languages:
    OR_OR = dict_list[0][language]
    OR_BT = dict_list[1][language]
    BT_OR = dict_list[2][language]
    BT_BT = dict_list[3][language]

    OR_OR_ent.append(np.mean(OR_OR))
    OR_OR_std.append(np.std(OR_OR))
    OR_BT_ent.append(np.mean(OR_BT))
    OR_BT_std.append(np.std(OR_BT))
    BT_OR_ent.append(np.mean(BT_OR))
    BT_OR_std.append(np.std(BT_OR))
    BT_BT_ent.append(np.mean(BT_BT))
    BT_BT_std.append(np.std(BT_BT))

    # Optional: per-language entropy histograms for the four conditions.
    # plt.hist(OR_OR, bins=50, label='OR-OR {:.4f}'.format(np.mean(OR_OR)))
    # plt.hist(OR_BT, bins=50, label='OR-BT {:.4f}'.format(np.mean(OR_BT)))
    # plt.hist(BT_OR, bins=50, label='BT-OR {:.4f}'.format(np.mean(BT_OR)))
    # plt.hist(BT_BT, bins=50, label='BT-BT {:.4f}'.format(np.mean(BT_BT)))
    # plt.title(language)
    # plt.legend()
    # plt.show()
    # plt.close()

In [None]:
# Mean entropy per evaluation language, one line per condition.
for series, label in (
    (OR_OR_ent, 'OR-OR'),
    (OR_BT_ent, 'OR-BT'),
    (BT_OR_ent, 'BT-OR'),
    (BT_BT_ent, 'BT-BT'),
):
    plt.plot(trainer.conf.eval_languages, series, marker='o', label=label)
plt.legend()
plt.show()
plt.close()

# Optional: the same figure for the per-language accuracies in acc_list.
# plt.plot(trainer.conf.eval_languages, acc_list[0], marker='o', label='OR-OR')
# plt.plot(trainer.conf.eval_languages, acc_list[1], marker='o', label='OR-BT')
# plt.plot(trainer.conf.eval_languages, acc_list[2], marker='o', label='BT-OR')
# plt.plot(trainer.conf.eval_languages, acc_list[3], marker='o', label='BT-BT')
# plt.legend()
# plt.show()
# plt.close()