In [None]:
import copy
import os
import sys
sys.path.append('../src')

import datasets
import decouple
import numpy as np
import pandas as pd
import torch
import tqdm
import transformers as tfs

import analyze_utils
import data_collator
import eval_utils
import hierarchical_models as hm
import input_utils
import modeling_bert
import utils

# Parent directories are resolved through python-decouple, which looks the
# keys up in the environment (or a .env / settings.ini file).
SCRATCH_DIR, NFS_DIR = (
    decouple.config(key) for key in ('SCRATCH_PARENT_DIR', 'NFS_PARENT_DIR'))

# Loads data and initializes paths.

In [None]:
# Experiment configuration: every value a reader may want to change.
data_name = 'abcd'
target_split = 'test'
batch_size = 64
cuda_device = 0

# Structure: sizes used to locate the cluster/state assignments and the
# checkpoint below.
num_clusters = 60
num_splits = 9
num_states = 12
num_merging = 0

fp16 = True
debug_mode = False
task_name = 'finetune'
pos_type = 'absolute'
embedding_name = 'bert_mean_pooler_output'
model_name = 'theta-cls-hibert'

# Checkpoint selection. The path is assembled from the settings above so that
# changing `num_clusters` / `num_states` / `model_name` cannot silently keep
# pointing at a checkpoint trained with different values.
train_seed = 9
checkpoint_step = 2012
model_path = (
    f'{NFS_DIR}/exp/{data_name}/{task_name}/'
    f'{model_name}-{pos_type}-pos_init_from_bert-base-uncased-{data_name}-wwm'
    f'_job_name-{data_name}-final-version/'
    f'{data_name}-cluster-{num_clusters}-state-{num_states}/'
    f'seed-{train_seed}/checkpoint-{checkpoint_step}/fp32')

data_config_path = f'../config/data/{data_name}.yaml'
coordinator_config_path = (
    f'../config/model/{model_name}/'
    f'cls-cluster-state-structure-hibert-{pos_type}-pos-config-1layer-2head.json')

In [None]:
# Reads the dataset config, resolves the input/output directories, loads the
# tokenizer and model config, and builds the dataset with cluster and state
# assignments attached.
data_config = utils.read_yaml(data_config_path)
path_templates = data_config['path']

eval_fn = eval_utils.get_classification_finetuning_post_process_function(
    data_config['config']['num_labels'])

# Not referenced by the cells below.
pop_keys = ['mlm_labels', 'labels', 'sentence_masked_idxs']

assert os.path.isdir(model_path), model_path
emb_dir = path_templates['embedding_dir'].format(nfs_dir=NFS_DIR)
assignment_dir = os.path.join(
    path_templates['assignment_dir'].format(nfs_dir=NFS_DIR),
    f'{embedding_name}/num_clusters_{num_clusters}')
for needed_dir in (emb_dir, assignment_dir):
    os.makedirs(needed_dir, exist_ok=True)

print(f'{model_path = }')
print(f'{emb_dir = }')
print(f'{assignment_dir = }')

tokenizer = tfs.AutoTokenizer.from_pretrained(model_path)
model_config = tfs.AutoConfig.from_pretrained(model_path)

dataset_dir = path_templates['dataset_dir'].format(scratch_dir=SCRATCH_DIR)
local_ds_dir = utils.get_dataset_dir_map(
    task_name, dataset_dir, model_path, debug_mode)
state_assignment_dir = os.path.join(
    path_templates['state_assignment_dir'].format(nfs_dir=NFS_DIR),
    f'{embedding_name}/num_clusters_{num_clusters}/'
    f'num_splits_{num_splits}_num_states_{num_states}_num_merging_{num_merging}')

# Attach cluster assignments first, then state assignments.
raw_ds = datasets.DatasetDict.load_from_disk(local_ds_dir['raw_ds'])
for assignment_kind, kind_dir in (('cluster', assignment_dir),
                                  ('state', state_assignment_dir)):
    raw_ds = input_utils.add_assignments(raw_ds, kind_dir, assignment_kind)

# Columns exposed as torch tensors. The `cluster_*` / `state_*` columns are
# presumably the ones added by `add_assignments` — TODO confirm.
columns = [
    'input_ids',
    'attention_mask',
    'num_turns',
    data_config['config']['label_name'],
    'cluster_input_ids',
    'cluster_attention_mask',
    'state_input_ids',
    'state_attention_mask',
]
# `with_format` returns a formatted copy, so `raw_ds` keeps its original
# (python-object) format for the example cell.
ds = raw_ds.with_format(type='torch', columns=columns)

# Shows an example dialogue from the training split.

In [None]:
# Prints one training dialogue turn by turn, with the speaker right-aligned.
# Indexing the row first (`[idx]['dialogue']`) reads a single example instead
# of materializing the whole `dialogue` column of the split.
idx = 4
example_dialogue = raw_ds['train'][idx]['dialogue']
for turn in example_dialogue:
    print(f"{turn['party']:>10}: {turn['turn']}")

In [None]:
# Builds the hierarchical conversation classifier from the fine-tuned
# checkpoint. The coordinator config decides whether the cluster- and
# state-sequence classifiers are enabled.
coordinator_config = tfs.AutoConfig.from_pretrained(
    coordinator_config_path,
    use_cache=False,
    num_labels=data_config['config']['num_labels'])

use_state_sequence_classifier = getattr(
    coordinator_config, 'use_state_sequence_classifier', False)
use_cluster_sequence_classifier = getattr(
    coordinator_config, 'use_cluster_sequence_classifier', False)

# Model-side sizes live in their own variables instead of overwriting
# `num_clusters` / `num_states`. Those settings also build the assignment
# paths in the data-loading cell, so mutating them here (doubling, or setting
# to None) made re-running this cell or the data cell produce wrong values.
# `None` is passed when the corresponding sequence classifier is disabled.
# NOTE(review): the cluster count is doubled for the model — presumably one
# set of cluster ids per dialogue party; confirm against the training setup.
model_num_clusters = num_clusters * 2 if use_cluster_sequence_classifier else None
model_num_states = num_states if use_state_sequence_classifier else None

model = hm.HierarchicalBertModelForConversationClassification.from_pretrained(
    model_path,
    coordinator_config=coordinator_config,
    num_states=model_num_states,
    num_clusters=model_num_clusters,
    use_state_sequence_classifier=use_state_sequence_classifier,
    use_cluster_sequence_classifier=use_cluster_sequence_classifier,
    state_sequence_encoder_type='transformer',
    cluster_sequence_encoder_type='transformer',
    num_labels=data_config['config']['num_labels'])

In [None]:
# Wraps the target split in an ordered (non-shuffled) loader; the collator is
# given the label column name and the tokenizer's pad id.
collate_fn = data_collator.ConversationDataCollator(
    data_config['config']['label_name'],
    tokenizer.pad_token_id,
)
eval_split = ds[target_split]
data_loader = torch.utils.data.DataLoader(
    eval_split,
    batch_size=batch_size,
    shuffle=False,
    collate_fn=collate_fn,
)

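# Loads the model and evaluates it on `target_split`.
The next cells move the model to the GPU and compute classification metrics. (If you only need previously extracted embeddings, you can skip this section.)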

In [None]:
# Moves the model to `cuda_device`, puts it in evaluation mode (e.g. disables
# dropout), and casts the weights to half precision when `fp16` is set.
model.cuda(cuda_device).eval()
if fp16:
    model = model.half()

In [None]:
# Runs the model over `target_split`, gathers logits and labels on the CPU,
# and scores them with `eval_fn`; each metric is rounded to three decimals.
logit_batches = []
label_batches = []
with torch.no_grad():
    for batch in tqdm.tqdm(data_loader):
        # `num_turns` is removed so it is not passed to the model.
        batch.pop('num_turns')
        # Return value unused: assumes obj_to_device moves `batch` in place
        # — TODO confirm.
        utils.obj_to_device(batch, device=cuda_device)
        output = model(**batch)
        # Upcast to fp32 so scoring never sees half-precision logits.
        logit_batches.append(output['logits'].cpu().float())
        label_batches.append(batch['labels'].cpu())

all_logits = torch.cat(logit_batches).numpy()
all_labels = torch.cat(label_batches).numpy()
predictions = {'logits': all_logits}
inputs = {'labels': all_labels}
metrics = eval_fn(inputs, None, predictions, None, None)
results = {name: round(score, 3) for name, score in metrics.items()}

In [None]:
# Evaluation metrics returned by `eval_fn`, each rounded to three decimals.
results