In [None]:
# libraries from finetuning_parameters.py
from finetuning_parameters import get_args
from future.baseline_trainer import BaselineTuner
from future.modules import ptl2classes
from future.hooks import EvaluationRecorder

from data_loader.wrap_sampler import wrap_sampler
import data_loader.task_configs as task_configs
import data_loader.data_configs as data_configs
from future.collocate_fns import task2collocate_fn

import utils.checkpoint as checkpoint
import utils.logging as logging

import torch
import random
import os

# libraries from future/base.py
from torch.utils.data import SequentialSampler, RandomSampler
from future.hooks import EvaluationRecorder
import utils.eval_meters as eval_meters
from seqeval.metrics import f1_score as f1_score_tagging
import torch

# libraries from future/baseline_trainer.py
import torch
import torch.nn as nn
import numpy as np
from copy import deepcopy
from future.base import BaseTrainer
from future.hooks.base_hook import HookContainer
from future.hooks import EvaluationRecorder
from torch.utils.data import RandomSampler
from collections import defaultdict, Counter
from tqdm import tqdm

# setup helpers from finetuning_baseline.py
from finetuning_baseline import init_config, init_task, init_hooks

# Select which GPU this kernel may see: devices are enumerated in PCI bus
# order and only the device with index 1 is exposed to the process.
# NOTE(review): this runs after `import torch`; it only takes effect if CUDA
# has not been initialised yet — consider moving it above the imports.
os.environ.update(
    CUDA_DEVICE_ORDER="PCI_BUS_ID",
    CUDA_VISIBLE_DEVICES="1",
)

In [None]:
# Define the task and model (same settings as in finetuning_parameters.py).

In [None]:
# Parse an explicitly empty argument list so that `conf` starts from the
# parser's defaults (assuming get_args() returns an argparse parser, this
# also keeps it from reading the notebook kernel's own command line).
parser = get_args()
conf = parser.parse_args([])

In [None]:
# Notebook-specific settings written onto the parsed configuration, applied in
# the order listed.
conf_overrides = {
    "dataset_name": 'pawsx',
    "trn_languages": 'english',
    "eval_languages": 'english',
    "finetune_epochs": 10,
    "finetune_batch_size": 256,
    "eval_every_batch": 50,
    "override": False,
    "train_fast": False,
    "world": '0',
    "finetune_lr": 1e-5,
}
for option_name, option_value in conf_overrides.items():
    setattr(conf, option_name, option_value)

In [None]:
# Run the project's config initialisation, then obtain the task objects from
# init_task: model, tokenizer, per-language datasets (data_iter), the metric
# name and the batch collation function.
init_config(conf)
model, tokenizer, data_iter, metric_name, collocate_batch_fn = init_task(conf)
# One wrap_sampler result per entry of data_iter, keyed by language.
adapt_loaders = {}
# NOTE(review): `language` and `language_dataset` stay bound at notebook scope
# after this loop ends; later cells should not rely on those leftover values.
for language, language_dataset in data_iter.items():
    # NOTE: the language dataset is handed to wrap_sampler as-is — presumably
    # it is referenced rather than copied; TODO confirm in wrap_sampler.
    adapt_loaders[language] = wrap_sampler(
        trn_batch_size=conf.finetune_batch_size,
        # inference_batch_size is not set in this notebook, so it keeps
        # whatever value the parser / init_config gave it.
        infer_batch_size=conf.inference_batch_size,
        language=language,
        language_dataset=language_dataset,
    )
# Hooks are built from the configuration and the task's metric name.
hooks = init_hooks(conf, metric_name)

In [None]:
# Build the tuner from the shared configuration; it receives the task's batch
# collation function and the logger stored on the configuration.
trainer = BaselineTuner(
    conf,
    logger=conf.logger,
    collocate_batch_fn=collocate_batch_fn,
)

In [None]:
# Run one pass over the training loaders with the model in eval mode and
# collect, for every example, its gold label and the model's last hidden
# representation. Results end up in `labels` (N,) and `features` (N, FEATURE_DIM).

# Number of columns of `hidden`; the empty seed array below must match it or
# the final concatenation fails.
FEATURE_DIM = 768

# Batches are accumulated in lists and concatenated once at the end (instead
# of re-copying the growing arrays on every batch). Each list is seeded with
# an empty float64 array so the results keep the float64 dtype and the
# (0,) / (0, FEATURE_DIM) shapes even when no batch is produced.
label_chunks = [np.empty((0,))]
feature_chunks = [np.empty((0, FEATURE_DIM))]

# Same setup step as trainer.train; the optimizer is not used in this cell.
opt, model = trainer._init_model_opt(model)
trainer.model = model
trainer.model.eval()

# trn_languages is assigned as a plain string earlier in this notebook.
# If it is still a string here, split it on commas so the loop below walks
# over language names rather than single characters.
trn_languages = trainer.conf.trn_languages
if isinstance(trn_languages, str):
    trn_languages = [name.strip() for name in trn_languages.split(",") if name.strip()]

trn_iters = []
for language in trn_languages:
    # Fix: the loop variable used to be misspelled, so this lookup silently
    # reused a stale `language` left over from an earlier cell.
    egs = adapt_loaders[language].trn_egs
    assert isinstance(egs.sampler, RandomSampler)
    trn_iters.append(iter(egs))

# Iterate as many steps as the longest loader needs; shorter loaders are
# skipped once exhausted.
batches_per_epoch = max((len(ti) for ti in trn_iters), default=0)
for batch_index in tqdm(range(1, batches_per_epoch + 1), desc="extracting features"):
    for ti in trn_iters:
        try:
            batched = next(ti)
        except StopIteration:
            continue
        batched, golds, uids, _golds_tagging = trainer.collocate_batch_fn(
            batched
        )
        with torch.no_grad():
            hidden = trainer.model.get_last_hidden(**batched)
        label_chunks.append(golds.detach().cpu().numpy())
        feature_chunks.append(hidden.detach().cpu().numpy())

labels = np.concatenate(label_chunks)
features = np.concatenate(feature_chunks, axis=0)

In [None]:
import pickle

def save_pickle(file, data):
    """Pickle ``data`` and write it to the path ``file``.

    The file is opened in binary write mode, so an existing file at that
    path is overwritten. Returns ``None``.
    """
    with open(file, mode='wb') as sink:
        pickle.Pickler(sink).dump(data)
        
def load_pickle(file):
    """Return the object unpickled from the file at path ``file``.

    NOTE: unpickling can execute arbitrary code; only load files from a
    trusted source.
    """
    with open(file, mode='rb') as source:
        restored = pickle.Unpickler(source).load()
    return restored

In [None]:
# For every distinct label, compute the mean vector and covariance matrix of
# the feature rows carrying that label.
base_means, base_covs = [], []
output_dict = {}

for class_id in np.unique(labels):
    class_mask = labels == class_id
    class_features = features[class_mask]
    class_mean = class_features.mean(axis=0)
    # np.cov treats rows as variables by default, hence the transpose so that
    # each feature column is one variable.
    class_cov = np.cov(class_features.T)

    base_means.append(class_mean)
    base_covs.append(class_cov)
    # Keyed by the label as an integer string; the value is a one-element
    # list holding that label's statistics.
    output_dict[str(int(class_id))] = [{"mean": class_mean, "cov": class_cov}]

In [None]:
# Persist the per-label statistics to features.pkl in the working directory.
save_pickle(file='features.pkl', data=output_dict)

In [None]:
# Read the saved file back and display it as the cell output (round-trip check).
load_pickle(file='features.pkl')