In [None]:
# libraries from finetuning_parameters.py
from finetuning_parameters import get_args
from future.baseline_trainer import BaselineTuner
from future.modules import ptl2classes
from future.hooks import EvaluationRecorder

from data_loader.wrap_sampler import wrap_sampler
import data_loader.task_configs as task_configs
import data_loader.data_configs as data_configs
from future.collocate_fns import task2collocate_fn

import utils.checkpoint as checkpoint
import utils.logging as logging

import torch
import random
import os

# libraries from future/base.py
from torch.utils.data import SequentialSampler, RandomSampler
from future.hooks import EvaluationRecorder
import utils.eval_meters as eval_meters
from seqeval.metrics import f1_score as f1_score_tagging
import torch

# libraries from future/baseline_trainer.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from copy import deepcopy
from future.base import BaseTrainer
from future.hooks.base_hook import HookContainer
from future.hooks import EvaluationRecorder
from torch.utils.data import RandomSampler
from collections import defaultdict, Counter
from tqdm import tqdm

# and so on..
from finetuning_baseline import init_config, init_task, init_hooks

In [None]:
# A notebook has no real command line: parse an empty argv so `conf` holds every default.
parser = get_args()
conf = parser.parse_args(args=list())

In [None]:
# --- Experiment configuration: overrides of the parser defaults ---
conf.dataset_name = 'mldoc'
conf.trn_languages = 'english'

# Comma-separated evaluation languages per dataset. A dataset that is not listed here
# keeps its existing conf.eval_languages untouched.
eval_languages_by_dataset = {
    'pawsx': 'english,german,chinese,french,japanese,korean,spanish',
    'xnli': 'english,arabic,bulgarian,chinese,french,german,greek,hindi,russian,spanish,swahili,thai,turkish,urdu,vietnamese',
    'marc': 'english,german,chinese,french,japanese,spanish',
    'mldoc': 'english,german,french,spanish,italian,russian,chinese,japanese',
}
if conf.dataset_name in eval_languages_by_dataset:
    conf.eval_languages = eval_languages_by_dataset[conf.dataset_name]

# Finetuning / evaluation hyper-parameters.
conf.finetune_epochs = 10
conf.finetune_batch_size = 256
conf.eval_every_batch = 50
conf.override = False
conf.train_fast = False
conf.world = '0'
conf.finetune_lr = 1e-5

# Optional switches (presumably translate-train / translate-test; confirm in finetuning_parameters):
# conf.trans_train = True
# conf.trans_test = True

In [None]:
# Fill in derived configuration, then build everything task-specific from it.
init_config(conf)
(
    model,
    tokenizer,
    data_iter,
    metric_name,
    collocate_batch_fn,
) = init_task(conf)

# Inspect the per-language datasets produced for the task.
data_iter

In [None]:
# One sampler wrapper per language in `data_iter`.
# NOTE: each wrapper is handed the dataset object from `data_iter` itself
# (presumably referenced rather than copied; confirm in wrap_sampler).
adapt_loaders = {
    lang: wrap_sampler(
        trn_batch_size=conf.finetune_batch_size,
        infer_batch_size=conf.inference_batch_size,
        language=lang,
        language_dataset=lang_dataset,
    )
    for lang, lang_dataset in data_iter.items()
}
hooks = init_hooks(conf, metric_name)

In [None]:
# In this notebook the trainer only supplies model setup (_init_model_opt) and batch
# collation; no loss is computed, so the criterion is never actually used here.
trainer = BaselineTuner(
    conf,
    collocate_batch_fn=collocate_batch_fn,
    logger=conf.logger,
    criterion=nn.CrossEntropyLoss(),
)
# Echo the evaluation languages the trainer sees.
trainer.conf.eval_languages

In [None]:
# ---- Select a finetuned run and load its weights into `model` ----
# 'best': the file stores {'best_state_dict': ...}; 'last': the file is the state dict itself.
choice = 'best'

pawsx_base_dir = '/input/jongwooko/xlt/checkpoint_baseline/pawsx/debug1'
marc_base_dir = '/input/jongwooko/xlt/checkpoint_baseline/marc/ce_origin_trans'
mldoc_base_dir = '/input/jongwooko/xlt/checkpoint_baseline/mldoc/ce_origin_trans'

# PAWS-X runs are selected by (use_mix, use_sup); MARC / MLDoc runs by training_type.
use_mix = False
use_sup = False
training_type = 'ce_origin_all'  # 'ce', 'ce_origin_all' or 'ce_trans_all'

pawsx_run_dirs = {
    (True, True): '1652076677_model_task-pawsx_flr-1.0E-05_ftbs-8_ftepcs-2_sd-3_trnfast-False_evalevery-300_tlang-de-zh-fr-ja-ko-es_vlang-en-de-zh-fr-ja-ko-es',
    (True, False): '1652076650_model_task-pawsx_flr-1.0E-05_ftbs-8_ftepcs-2_sd-3_trnfast-False_evalevery-300_tlang-de-zh-fr-ja-ko-es_vlang-en-de-zh-fr-ja-ko-es',
    (False, True): '1652076670_model_task-pawsx_flr-1.0E-05_ftbs-8_ftepcs-2_sd-3_trnfast-False_evalevery-300_tlang-de-zh-fr-ja-ko-es_vlang-en-de-zh-fr-ja-ko-es',
    (False, False): '1652076646_model_task-pawsx_flr-1.0E-05_ftbs-8_ftepcs-2_sd-3_trnfast-False_evalevery-300_tlang-de-zh-fr-ja-ko-es_vlang-en-de-zh-fr-ja-ko-es',
}
marc_run_dirs = {
    'ce': '1653322466_model_task-marc_flr-1.0E-05_ftbs-32_ftepcs-2_sd-3_trnfast-False_evalevery-300_tlang-en_vlang-en',
    'ce_origin_all': '1653322473_model_task-marc_flr-1.0E-05_ftbs-32_ftepcs-2_sd-3_trnfast-False_evalevery-300_tlang-en-de-zh-fr-ja-es_vlang-en',
    'ce_trans_all': '1653371885_model_task-marc_flr-1.0E-05_ftbs-16_ftepcs-2_sd-3_trnfast-False_evalevery-300_tlang-de-zh-fr-ja-es_vlang-en',
}
mldoc_run_dirs = {
    'ce': '1653402142_model_task-mldoc_flr-1.0E-05_ftbs-32_ftepcs-10_sd-3_trnfast-False_evalevery-300_tlang-en_vlang-en',
    'ce_origin_all': '1653402571_model_task-mldoc_flr-1.0E-05_ftbs-32_ftepcs-10_sd-3_trnfast-False_evalevery-300_tlang-en-ja-zh-fr-de-es-ru-it_vlang-en',
    # No 'ce_trans_all' run is recorded for MLDoc yet.
}

if choice not in ('best', 'last'):
    raise ValueError("choice must be 'best' or 'last', got {!r}".format(choice))

# Resolve the run directory explicitly for every case. Unsupported combinations raise
# instead of leaving curr_dir/model_dir undefined, or silently reusing a stale value
# left in the kernel by an earlier execution of this cell.
if conf.dataset_name == 'pawsx':
    base_dir = pawsx_base_dir
    curr_dir = pawsx_run_dirs[(bool(use_mix), bool(use_sup))]
elif conf.dataset_name in ('marc', 'mldoc'):
    base_dir, run_dirs = {
        'marc': (marc_base_dir, marc_run_dirs),
        'mldoc': (mldoc_base_dir, mldoc_run_dirs),
    }[conf.dataset_name]
    if training_type not in run_dirs:
        raise ValueError('No {} checkpoint recorded for training_type={!r}'.format(
            conf.dataset_name, training_type))
    curr_dir = run_dirs[training_type]
else:
    # e.g. 'xnli': no finetuned checkpoint is registered in this notebook.
    raise ValueError('No checkpoint registered for dataset {!r}'.format(conf.dataset_name))
model_dir = os.path.join(base_dir, curr_dir, 'state_dicts', '{}_state.pt'.format(choice))

# Load onto the CPU so the file opens regardless of which GPU it was saved from;
# load_state_dict copies the tensors into the model's own parameters. The result is
# named `state_dict` so the `utils.checkpoint` module imported as `checkpoint` is not shadowed.
state_dict = torch.load(model_dir, map_location='cpu')
if choice == 'best':
    state_dict = state_dict['best_state_dict']

if use_mix or use_sup:
    # Keep only the weights stored under the 'encoder.' prefix and strip that prefix so
    # the keys match `model` (presumably these runs saved a wrapper around the encoder).
    encoder_prefix = 'encoder.'
    state_dict = {
        k[len(encoder_prefix):]: v
        for k, v in state_dict.items()
        if k.startswith(encoder_prefix)
    }

model.load_state_dict(state_dict, strict=True)

In [None]:
# ---- Evaluate the loaded checkpoint: top-1 test accuracy per evaluation language ----
opt, model = trainer._init_model_opt(model)  # optimizer unused: evaluation only

# Prefer the second GPU (as originally configured), but fall back so the cell also
# runs on single-GPU or CPU-only hosts.
if torch.cuda.device_count() > 1:
    device = 'cuda:1'
elif torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'

trainer.model = model
trainer.model.to(device)
trainer.model.eval()


def _language_accuracy(eval_model, loader, collate_fn, target_device):
    """Return the top-1 accuracy (in percent) of `eval_model` over every batch of `loader`.

    Each raw batch goes through `collate_fn`, which returns
    (inputs, golds, uids, golds_tagging). `inputs` is a dict of tensors passed to the
    model as keyword arguments, and the first element of the model output is used as
    the class logits. This assumes one label per example in `golds` (presumably shape
    (batch,)); two-view translate-test batches are not handled.

    Raises:
        ValueError: if `loader` yields no examples (accuracy would be undefined).
    """
    n_correct = 0
    n_seen = 0
    with torch.no_grad():
        for raw_batch in loader:
            inputs, golds, _uids, _golds_tagging = collate_fn(raw_batch)
            inputs = {name: tensor.to(target_device) for name, tensor in inputs.items()}
            logits = eval_model(**inputs)[0]
            preds = logits.argmax(dim=1)
            # .item() keeps the running count as a Python int instead of a device tensor.
            n_correct += (preds == golds.to(target_device)).sum().item()
            n_seen += len(logits)
    if n_seen == 0:
        raise ValueError('loader yielded no examples; accuracy is undefined')
    return 100.0 * n_correct / n_seen


# Cover every configured evaluation language (e.g. all eight for MLDoc) rather than a
# fixed subset. conf.eval_languages is a comma-separated string as set in the config
# cell; a list is accepted too in case init_config already split it.
eval_langs = conf.eval_languages
if isinstance(eval_langs, str):
    eval_langs = eval_langs.split(',')

lang_acc_lst = []
for language in eval_langs:
    print("Start language-{}".format(language))
    acc = _language_accuracy(
        trainer.model,
        adapt_loaders[language].tst_egs,
        trainer.collocate_batch_fn,
        device,
    )
    print("{}: {:.2f}".format(language, acc))
    lang_acc_lst.append(acc)
print(lang_acc_lst, np.mean(lang_acc_lst))