In [None]:
# libraries from finetuning_parameters.py
from finetuning_parameters import get_args
from future.baseline_trainer import BaselineTuner
from future.modules import ptl2classes
from future.hooks import EvaluationRecorder

from data_loader.wrap_sampler import wrap_sampler
import data_loader.task_configs as task_configs
import data_loader.data_configs as data_configs
from future.collocate_fns import task2collocate_fn

import utils.checkpoint as checkpoint
import utils.logging as logging

import torch
import random
import os
import matplotlib.pyplot as plt

# libraries from future/base.py
from torch.utils.data import SequentialSampler, RandomSampler
from future.hooks import EvaluationRecorder
import utils.eval_meters as eval_meters
from seqeval.metrics import f1_score as f1_score_tagging
import torch

# libraries from future/baseline_trainer.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from copy import deepcopy
from future.base import BaseTrainer
from future.hooks.base_hook import HookContainer
from future.hooks import EvaluationRecorder
from torch.utils.data import RandomSampler
from collections import defaultdict, Counter
from tqdm import tqdm

# and so on..
from finetuning_baseline import init_config, init_task, init_hooks

In [None]:
# Build the configuration from the project's default arguments. An empty
# argument list is passed so that nothing is parsed from the kernel's own
# command line (presumably an argparse parser -- confirm in finetuning_parameters).
parser = get_args()
conf = parser.parse_args(args=[])

conf.dataset_name = 'mldoc'
conf.trn_languages = 'english'

# Comma-separated evaluation languages for each supported task.
_EVAL_LANGUAGES_BY_TASK = {
    'pawsx': 'english,german,chinese,french,japanese,korean,spanish',
    'xnli': 'english,arabic,bulgarian,chinese,french,german,greek,hindi,russian,spanish,swahili,thai,turkish,urdu,vietnamese',
    'marc': 'english,german,chinese,french,japanese,spanish',
    'mldoc': 'english,german,french,spanish,italian,russian,chinese,japanese',
    'cls': 'english,german,french,japanese',
}
# For a task without an entry, conf.eval_languages is left untouched.
if conf.dataset_name in _EVAL_LANGUAGES_BY_TASK:
    conf.eval_languages = _EVAL_LANGUAGES_BY_TASK[conf.dataset_name]

# Fine-tuning / evaluation settings used by the rest of the notebook.
conf.finetune_epochs = 10
conf.finetune_batch_size = 256
conf.eval_every_batch = 50
conf.override = False
conf.train_fast = False
conf.world = '0'
conf.finetune_lr = 1e-5

In [None]:
# Task groups: the first group enables `test_mt`, the second `trans_test`.
_MT_TEST_TASKS = ('cls', 'marc', 'mldoc')
_TRANS_TEST_TASKS = ('pawsx', 'xnli')

# Reject unsupported tasks before touching the configuration.
if conf.dataset_name not in _MT_TEST_TASKS + _TRANS_TEST_TASKS:
    raise ValueError

conf.use_cache = True
if conf.dataset_name in _MT_TEST_TASKS:
    conf.test_mt = True
    # 'cls' additionally needs a domain; 'music' is the one used here.
    if conf.dataset_name == 'cls':
        conf.domain = 'music'
else:
    conf.trans_test = True

# Finalise the configuration, then build the task objects from it.
init_config(conf)
(
    model,
    tokenizer,
    data_iter,
    metric_name,
    collocate_batch_fn,
) = init_task(conf)

# Display the per-language datasets as the cell output.
data_iter

In [None]:
# One sampler wrapper per language, keyed by language name.
# NOTE: each wrapper receives the dataset object taken directly from
# `data_iter`; no copy is made in this cell.
adapt_loaders = {
    lang: wrap_sampler(
        trn_batch_size=conf.finetune_batch_size,
        infer_batch_size=conf.inference_batch_size,
        language=lang,
        language_dataset=lang_dataset,
    )
    for lang, lang_dataset in data_iter.items()
}

# Hooks are created from the configuration and the task's metric name.
hooks = init_hooks(conf, metric_name)

In [None]:
# The tuner gets the configuration, the task's batch collation function,
# the logger stored on the configuration, and a cross-entropy criterion.
cross_entropy = nn.CrossEntropyLoss()
trainer = BaselineTuner(
    conf,
    collocate_batch_fn=collocate_batch_fn,
    logger=conf.logger,
    criterion=cross_entropy,
)

# Show which languages the trainer's configuration will evaluate on.
trainer.conf.eval_languages

In [None]:
def get_model(model, base_dir, dataset_name, training_tgt, experience_name, choice):
    """Load fine-tuned weights from disk into ``model`` and return it.

    The checkpoint is read from
    ``<base_dir>/<dataset_name>/<experience_name>/<run>/state_dicts/<choice>_state.pt``
    where ``<run>`` is the lexicographically last entry of the experiment
    directory. Entries are sorted explicitly because ``os.listdir`` returns
    them in arbitrary order; with timestamp-prefixed run names this presumably
    selects the most recent run -- confirm against the checkpoint layout.

    Args:
        model: Module whose ``load_state_dict`` receives the weights. It is
            modified in place.
        base_dir: Root directory that holds the checkpoints.
        dataset_name: Task sub-directory below ``base_dir``.
        training_tgt: Accepted for call compatibility only. The active
            directory layout does not use it (earlier layouts used
            ``'<training_tgt>_<experience_name>'`` or
            ``'<dataset_name>_<training_tgt>_<experience_name>'``).
        experience_name: Experiment directory name. If it contains
            ``'supcon'``, only entries whose key starts with ``'encoder.'``
            are kept and that prefix is stripped before loading.
        choice: ``'best'`` (weights stored under the ``'best_state_dict'``
            key of the file) or ``'last'`` (the file itself is the state dict).

    Returns:
        The same ``model`` object, with the loaded weights.

    Raises:
        ValueError: If ``choice`` is neither ``'best'`` nor ``'last'``.
        FileNotFoundError: If the experiment directory is missing or empty.
    """
    # Validate first: otherwise an unknown choice would only surface later as
    # an unbound local variable.
    if choice not in ('best', 'last'):
        raise ValueError(
            "choice must be 'best' or 'last', got {!r}".format(choice))

    experience_dir = os.path.join(base_dir, dataset_name, str(experience_name))

    run_names = sorted(os.listdir(experience_dir))
    if not run_names:
        raise FileNotFoundError(
            'no run directory found in {}'.format(experience_dir))

    model_path = os.path.join(
        experience_dir, run_names[-1], 'state_dicts', '{}_state.pt'.format(choice))
    # Printed before loading so the path is visible even when loading fails.
    print (model_path)

    loaded = torch.load(model_path, map_location='cpu')
    state_dict = loaded['best_state_dict'] if choice == 'best' else loaded

    if 'supcon' in experience_name:
        prefix = 'encoder.'
        state_dict = {
            key[len(prefix):]: value
            for key, value in state_dict.items()
            if key.startswith(prefix)
        }

    model.load_state_dict(state_dict, strict=True)
    return model

# Performance

In [None]:
# Checkpoint to evaluate. get_model() resolves it as
# <base_dir>/<dataset_name>/<experience_name>/<run>/state_dicts/<choice>_state.pt
base_dir = '/data/FSXLT_dataset/checkpoint_baseline/'
dataset_name = 'mldoc'
# One of: or, mt, bt. NOTE(review): get_model() does not use this value when
# building the path -- the experiment is selected by experience_name alone.
training_tgt = 'bt'
# Experiment variants seen so far: ce, mixup, supcon, supcon_mixup.
experience_name = 'mldoc1000_mt_baseline_seed99'
# 'best' or 'last'.
choice = 'best'

# Alternative setup kept for reference:
# base_dir = '/data/FSXLT_dataset/checkpoint_baseline/xnli/xnli_mixup/1654840669_model_task-xnli_flr-3.0E-05_ftbs-32_ftepcs-2_sd-3_trnfast-False_evalevery-3000_tlang-ar-bg-zh-fr-de-e'
# dataset_name = 'xnli'
# training_tgt = 'or' # or, mt, bt
# experience_name = 'xnli_mixup' # ce, mixup, supcon, supcon_mixup
# choice = 'best' # best or last

model = get_model(
    model=model,
    base_dir=base_dir,
    dataset_name=dataset_name,
    training_tgt=training_tgt,
    experience_name=experience_name,
    choice=choice,
)

In [None]:
# Evaluate the loaded checkpoint on every test language, mixing the predictions
# of the two views of each example with weight `alpha`.
opt, model = trainer._init_model_opt(model)
device = 'cuda:0'
trainer.model = model
trainer.model.to(device)
trainer.model.eval()

test_method = 'ensemble_prob' # ensemble_logit, ensemble_prob
# Fail before any forward pass instead of in the middle of the first batch.
if test_method not in ('ensemble_logit', 'ensemble_prob'):
    raise ValueError('unknown test_method: {!r}'.format(test_method))


def _ensemble_accuracy(trainer, loader, device, alpha, method, paired):
    """Return the accuracy (in percent) of ``trainer.model`` over ``loader``.

    Args:
        trainer: Object providing ``collocate_batch_fn`` and ``model``.
        loader: Iterable of raw batches accepted by ``collocate_batch_fn``.
        device: Device the batch tensors are moved to.
        alpha: Weight of the first view (dim-1 index 0); the second view
            (index 1) gets ``1 - alpha``. Ignored when ``paired`` is false.
        method: ``'ensemble_logit'`` mixes raw scores, ``'ensemble_prob'``
            mixes softmax probabilities.
        paired: Whether every collated tensor and the gold labels carry two
            views per example along dim 1 -- assumed from the indexing used
            here; TODO confirm against the data loader.

    Raises:
        ValueError: If ``loader`` yields no examples.
    """
    n_correct = 0
    n_total = 0
    with torch.no_grad():
        for raw_batch in loader:
            batched, golds, _uids, _golds_tagging = trainer.collocate_batch_fn(raw_batch)

            if paired:
                bsz = golds.size(0)
                # Stack both views along the batch axis: rows [0, bsz) are
                # view 0 and rows [bsz, 2 * bsz) are view 1.
                batched = {
                    key: torch.cat([value[:, 0], value[:, 1]], dim=0)
                    for key, value in batched.items()
                }
                golds = golds[:, 0]

            batched = {key: value.to(device) for key, value in batched.items()}
            golds = golds.to(device)
            scores = trainer.model(**batched)[0]

            if paired:
                src_scores, tgt_scores = scores[:bsz], scores[bsz:]
                if method == 'ensemble_prob':
                    src_scores = F.softmax(src_scores, dim=1)
                    tgt_scores = F.softmax(tgt_scores, dim=1)
                scores = alpha * src_scores + (1 - alpha) * tgt_scores

            n_correct += (scores.max(1)[1] == golds).sum().item()
            n_total += len(scores)

    if n_total == 0:
        raise ValueError('the loader produced no examples to evaluate')
    return n_correct / n_total * 100


eval_language_names = ['english', 'german', 'french', 'spanish', 'italian', 'russian', 'chinese', 'japanese']

# alpha = 0.0 -> only target (original) / 1.0 -> only translated source (translate-test)
for alpha in [0.0, 1.0, 0.5]:
    lang_acc_lst = []
    for language in eval_language_names:
        print ("Start language-{}".format(language))
        accuracy = _ensemble_accuracy(
            trainer,
            adapt_loaders[language].tst_egs,
            device,
            alpha,
            test_method,
            # english batches are evaluated as-is, without splitting into views
            paired=(language != 'english'),
        )
        print (accuracy)
        lang_acc_lst.append(accuracy)
    print (np.round(lang_acc_lst, 1), np.round(np.mean(lang_acc_lst), 1))

# Probability analysis: per-example probability assigned to the gold label

In [None]:
# base_dir = '/data/FSXLT_dataset/checkpoint_baseline/'
# dataset_name = 'xnli'
# training_tgt = 'mt' # or, mt, bt
# experience_name = 'baseline' # ce, mixup, supcon, supcon_mixup
# choice = 'best' # best or last

# NOTE(review): the configuration cell sets conf.dataset_name = 'mldoc', while
# this cell loads an 'xnli' checkpoint. The model and the loaders must have
# been built for the same task as the checkpoint -- confirm before running.
base_dir = '/data/FSXLT_dataset/checkpoint_baseline/'
dataset_name = 'xnli'
training_tgt = 'or' # or, mt, bt
experience_name = 'xnli_supcon_mixup' # ce, mixup, supcon, supcon_mixup
choice = 'best' # best or last

model = get_model(model, base_dir, dataset_name, training_tgt, experience_name, choice)

# trainer.train
opt, model = trainer._init_model_opt(model)
device = 'cuda:1'
trainer.model = model
trainer.model.to(device)
trainer.model.eval()

test_method = 'ensemble_prob' # ensemble_logit, ensemble_prob
# Fail before any forward pass instead of in the middle of the first batch.
if test_method not in ('ensemble_logit', 'ensemble_prob'):
    raise ValueError('unknown test_method: {!r}'.format(test_method))


def _collect_gold_probs(trainer, loader, device, alpha, method):
    """Collect per-example gold-label probabilities and correctness flags.

    Every collated tensor and the gold labels are expected to carry two views
    per example along dim 1 (index 0 = "src", index 1 = "tgt") -- assumed from
    the indexing used here; TODO confirm against the data loader.

    Args:
        trainer: Object providing ``collocate_batch_fn`` and ``model``.
        loader: Iterable of raw batches accepted by ``collocate_batch_fn``.
        device: Device the batch tensors are moved to.
        alpha: Weight of the src view in the ensemble; tgt gets ``1 - alpha``.
        method: ``'ensemble_logit'`` or ``'ensemble_prob'``; it only affects
            the ensemble prediction, the recorded probabilities are always
            softmax outputs.

    Returns:
        Five lists of per-batch NumPy arrays: src gold probability, tgt gold
        probability, src correct, tgt correct, ensemble correct.
    """
    src_probs_out, tgt_probs_out = [], []
    src_corr_out, tgt_corr_out, ens_corr_out = [], [], []
    with torch.no_grad():
        for raw_batch in loader:
            batched, golds, _uids, _golds_tagging = trainer.collocate_batch_fn(raw_batch)

            bsz = golds.size(0)
            # Rows [0, bsz) hold the src view, rows [bsz, 2 * bsz) the tgt view.
            batched = {
                key: torch.cat([value[:, 0], value[:, 1]], dim=0).to(device)
                for key, value in batched.items()
            }
            golds = golds[:, 0].to(device)

            logits = trainer.model(**batched)[0]
            src_logits, tgt_logits = logits[:bsz], logits[bsz:]
            src_probs = F.softmax(src_logits, dim=1)
            tgt_probs = F.softmax(tgt_logits, dim=1)
            if method == 'ensemble_logit':
                ensemble = alpha * src_logits + (1 - alpha) * tgt_logits
            else:
                ensemble = alpha * src_probs + (1 - alpha) * tgt_probs

            rows = torch.arange(bsz, device=golds.device)
            src_probs_out.append(src_probs[rows, golds].cpu().numpy())
            tgt_probs_out.append(tgt_probs[rows, golds].cpu().numpy())

            src_corr_out.append((src_probs.max(1)[1] == golds).cpu().numpy())
            tgt_corr_out.append((tgt_probs.max(1)[1] == golds).cpu().numpy())
            ens_corr_out.append((ensemble.max(1)[1] == golds).cpu().numpy())
    return src_probs_out, tgt_probs_out, src_corr_out, tgt_corr_out, ens_corr_out


src_probs_all = []
tgt_probs_all = []
src_corr_all = []
tgt_corr_all = []
ens_corr_all = []

for alpha in [0.5]:
    for language in ['german']:
        print ("Start language-{}".format(language))
        if language == 'english':
            # english batches are not split into two views, so there is
            # nothing to compare for them.
            continue
        collected = _collect_gold_probs(
            trainer, adapt_loaders[language].tst_egs, device, alpha, test_method)
        src_probs_all.extend(collected[0])
        tgt_probs_all.extend(collected[1])
        src_corr_all.extend(collected[2])
        tgt_corr_all.extend(collected[3])
        ens_corr_all.extend(collected[4])

if not src_probs_all:
    raise ValueError('no paired-language examples were collected')

src_probs_all_np = np.hstack(src_probs_all)
tgt_probs_all_np = np.hstack(tgt_probs_all)

src_corr_all_np = np.hstack(src_corr_all)
tgt_corr_all_np = np.hstack(tgt_corr_all)
ens_corr_all_np = np.hstack(ens_corr_all)

In [None]:
# Split the examples by whether the tgt-only prediction and the ensemble
# prediction are correct. Both arrays hold booleans produced by `==`.
#   type1: tgt correct, ensemble correct
#   type2: tgt correct, ensemble wrong   (-)  the ensemble loses the example
#   type3: tgt wrong,   ensemble correct (+)  the ensemble gains the example
#   type4: tgt wrong,   ensemble wrong
type1_idx = np.flatnonzero(tgt_corr_all_np & ens_corr_all_np)
type2_idx = np.flatnonzero(tgt_corr_all_np & ~ens_corr_all_np)
type3_idx = np.flatnonzero(~tgt_corr_all_np & ens_corr_all_np)
type4_idx = np.flatnonzero(~tgt_corr_all_np & ~ens_corr_all_np)

# Only types 2 and 3 carry a sign marker in the legend.
_type_suffix = {2: ' (-)', 3: ' (+)'}

plt.figure(figsize=(12,12))
for type_no, idx in enumerate([type1_idx, type2_idx, type3_idx, type4_idx], start=1):
    legend_label = 'type{}{} ({})'.format(type_no, _type_suffix.get(type_no, ''), len(idx))
    plt.scatter(src_probs_all_np[idx], tgt_probs_all_np[idx], s=10, label=legend_label)
# `language` is the last language processed by the collection loop.
plt.title(language)
plt.xlabel('probability on correct label (src)')
plt.ylabel('probability on correct label (tgt)')
plt.legend()
plt.show()
plt.close()

# Representation analysis: t-SNE of the model features

In [None]:
eval_langs = ['german']

base_dir = '/data/FSXLT_dataset/checkpoint_baseline/'
dataset_name = 'marc'
experience_name = 'baseline' # ce, mixup, supcon, supcon_mixup
choice = 'best' # best or last
device = 'cuda:1'


def _collect_pair_features(trainer, loader, device):
    """Run ``trainer.model`` over ``loader`` and return features for both views.

    Every collated tensor and the gold labels are expected to carry two views
    per example along dim 1 (index 0 = "src", index 1 = "tgt") -- assumed from
    the indexing used here; TODO confirm against the data loader.

    Args:
        trainer: Object providing ``collocate_batch_fn``, ``_model_forward``
            and ``model``.
        loader: Iterable of raw batches accepted by ``collocate_batch_fn``.
        device: Device the batch tensors are moved to.

    Returns:
        Three lists of per-batch NumPy arrays: src features, tgt features and
        gold labels.
    """
    src_chunks, tgt_chunks, gold_chunks = [], [], []
    with torch.no_grad():
        for raw_batch in loader:
            batched, golds, _uids, _golds_tagging = trainer.collocate_batch_fn(raw_batch)

            bsz = golds.size(0)
            # Rows [0, bsz) hold the src view, rows [bsz, 2 * bsz) the tgt view.
            batched = {
                key: torch.cat([value[:, 0], value[:, 1]], dim=0).to(device)
                for key, value in batched.items()
            }
            # The second return value is used as the feature tensor.
            _logits, feats, *_ = trainer._model_forward(trainer.model, **batched)

            src_chunks.append(feats[:bsz].cpu().numpy())
            tgt_chunks.append(feats[bsz:].cpu().numpy())
            gold_chunks.append(golds[:, 0].cpu().numpy())
    return src_chunks, tgt_chunks, gold_chunks


# NOTE(review): get_model() does not use `training_tgt` when building the
# checkpoint path, so both passes below resolve to the same checkpoint unless
# `experience_name` changes -- confirm this is intended.
features_by_training_tgt = {}
for training_tgt in ('or', 'bt'): # or, mt, bt
    model = get_model(model, base_dir, dataset_name, training_tgt, experience_name, choice)

    opt, model = trainer._init_model_opt(model)
    trainer.model = model
    trainer.model.to(device)
    trainer.model.eval()

    src_feats_lst = []
    tgt_feats_lst = []
    golds_lst = []

    for language in eval_langs:
        print ("Start language-{}".format(language))
        if language == 'english':
            # english batches are not split into two views, so they
            # contribute no src/tgt feature pairs.
            continue
        src_chunks, tgt_chunks, gold_chunks = _collect_pair_features(
            trainer, adapt_loaders[language].tst_egs, device)
        src_feats_lst.extend(src_chunks)
        tgt_feats_lst.extend(tgt_chunks)
        golds_lst.extend(gold_chunks)

    if not golds_lst:
        raise ValueError(
            'no paired-language features were collected for {!r}'.format(training_tgt))

    features_by_training_tgt[training_tgt] = (
        np.vstack(src_feats_lst),
        np.vstack(tgt_feats_lst),
        np.hstack(golds_lst),
    )

or_src_feats_np, or_tgt_feats_np, or_golds_np = features_by_training_tgt['or']
bt_src_feats_np, bt_tgt_feats_np, bt_golds_np = features_by_training_tgt['bt']

In [None]:
# from tsnecuda import TSNE
from sklearn.manifold import TSNE
import pandas as pd

# Fixed seed so that re-running the notebook reproduces the same embedding.
TSNE_SEED = 0

# (label prefix, feature matrix, gold labels) for every group that is embedded.
_feature_groups = [
    ('or_src_', or_src_feats_np, or_golds_np),
    ('or_tgt_', or_tgt_feats_np, or_golds_np),
    ('bt_src_', bt_src_feats_np, bt_golds_np),
    ('bt_tgt_', bt_tgt_feats_np, bt_golds_np),
]

feats_all = np.concatenate([feats for _, feats, _ in _feature_groups])
# One string label per row, formatted as '<or|bt>_<src|tgt>_<gold label>'.
golds_all = np.concatenate([
    [prefix + str(a) for a in golds] for prefix, _, golds in _feature_groups
])
assert len(feats_all) == len(golds_all), 'features and labels are misaligned'

embedded_all = TSNE(n_components=2, random_state=TSNE_SEED).fit_transform(feats_all)
# X_embedded = TSNE(n_components=2, perplexity=15, learning_rate=10).fit_transform(feats_all)

In [None]:
# Scatter the 2-D embedding for the selected classes. Labels in `golds_all`
# have the form '<or|bt>_<src|tgt>_<class>'.
plt.figure(figsize=(18,18))
unique_golds = np.unique(golds_all).tolist()

colors = ['tab:blue', 'tab:orange', 'tab:green', 'tab:red', 'tab:purple',
          'tab:brown', 'tab:pink', 'tab:gray', 'tab:olive', 'tab:cyan']

# Class ids to draw.
ls = [0]

# src points are drawn as circles, tgt points as crosses.
_markers = {'src': 'o', 'tgt': 'x'}

# 'or' groups are drawn first, then 'bt' groups; this fixes the legend order.
for regime in ('or', 'bt'):
    for g in unique_golds:
        g_regime, side, class_name = g.split('_', 2)
        if g_regime != regime:
            continue

        # The class id is read from the label itself, so the selection does
        # not depend on how many classes the task has.
        # Assumes integer class labels -- TODO confirm for other tasks.
        class_id = int(class_name)
        if class_id not in ls:
            continue

        # 'or' takes colours from the start of the palette and 'bt' from the
        # end, so the two regimes never share a colour for up to five classes.
        palette_pos = class_id % len(colors)
        color = colors[palette_pos] if regime == 'or' else colors[-palette_pos - 1]

        selected_idx = np.where(golds_all==g)[0]
        plt.scatter(embedded_all[selected_idx, 0], embedded_all[selected_idx, 1],
                    c=color, marker=_markers[side], label=g)

plt.legend()
plt.show()
plt.close()