In [3]:
from model import SomDST
from utils.data_utils import prepare_dataset, MultiWozDataset
from utils.data_utils import  domain2id, OP_SET, make_turn_label, postprocessing # make_slot_meta,
from utils.eval_utils import compute_prf, compute_acc, per_domain_join_accuracy
from utils.ckpt_utils import download_ckpt, convert_ckpt_compatible
from evaluation import model_evaluation
from transformers import ElectraTokenizer, ElectraModel, PretrainedConfig, BertConfig, AdamW, get_linear_schedule_with_warmup

import torch
import torch.nn as nn
import torch.optim as optim
import torch.cuda.amp as amp
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
import numpy as np
import argparse
import random
import os
import json
import time
import datetime
import pickle
from tqdm import tqdm

In [4]:
# Torch device: use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f'Device Name : {device}')

Device Name : cuda


In [5]:
def masked_cross_entropy_for_value(logits, target, pad_idx=0):
    """Mean negative log-likelihood of `target` under `logits`, ignoring padding.

    `logits` is treated as a probability distribution over the last dimension
    (the code takes `log(logits + 1e-9)` directly, no softmax), and `target`
    holds the gold index for every position of `logits` except the last dim.

    Args:
        logits: tensor of shape (*target.shape, V) with probabilities over V.
        target: integer tensor of gold indices; positions equal to `pad_idx`
            are excluded from the loss.
        pad_idx: index marking padding positions in `target`.

    Returns:
        Scalar float32 tensor: sum of per-position losses divided by the number
        of non-padding positions (0 when every position is padding).
    """
    mask = target.ne(pad_idx)
    # Compute in float32: in half precision the 1e-9 smoother rounds to 0 and
    # log(0) = -inf would poison the loss. reshape also accepts non-contiguous input.
    probs_flat = logits.reshape(-1, logits.size(-1)).float()
    smoother = 1e-9
    log_probs_flat = torch.log(probs_flat + smoother)
    target_flat = target.reshape(-1, 1)
    losses_flat = -torch.gather(log_probs_flat, dim=1, index=target_flat)
    losses = losses_flat.view(*target.size()) * mask.float()
    # clamp avoids 0/0 = NaN when a batch contains only padding.
    n_tokens = mask.sum().float().clamp(min=1.0)
    return losses.sum() / n_tokens


### args setting

In [6]:
from argparse import Namespace

parser = argparse.ArgumentParser()  # NOTE(review): never used; settings live in the dict below

# Hyper-parameters and paths (notebook stand-in for command-line flags)
_arg_defaults = dict(
    data_root='data',  # default='data/mwz2.1'
    train_data='train_dials.json',
    dev_data='dev_dials.json',
    test_data='test_dials.json',
    ontology_data='ontology.json',
    slot_meta='slot_meta.json',
    # vocab_path='assets/vocab.txt',
    bert_config_path='assets/koelectra_base.json',
    bert_ckpt_path='assets/koelectra_base.bin',
    save_dir='outputs',

    random_seed=2021,
    num_workers=4,
    batch_size=28,
    enc_warmup=0.1,
    dec_warmup=0.1,
    enc_lr=4e-5,  # default 4e-5
    dec_lr=1e-4,
    n_epochs=50,
    eval_epoch=1,

    op_code="4",
    slot_token="[SLOT]",
    dropout=0.1,
    hidden_dropout_prob=0.1,
    attention_probs_dropout_prob=0.1,
    decoder_teacher_forcing=0.8,
    word_dropout=0.1,
    not_shuffle_state=False,
    shuffle_p=0.5,

    n_history=1,
    max_seq_length=512,
    msg=None,
    exclude_domain=True,
)
args = Namespace(**_arg_defaults)

# Resolve the split files against data_root as <split>_data_path;
# ontology_data and slot_meta are overwritten with their full paths.
for _split in ('train', 'dev', 'test'):
    _file_name = getattr(args, _split + '_data')
    setattr(args, _split + '_data_path', os.path.join(args.data_root, _file_name))
args.ontology_data = os.path.join(args.data_root, args.ontology_data)
args.slot_meta = os.path.join(args.data_root, args.slot_meta)
args.shuffle_state = not args.not_shuffle_state
del _split, _file_name

print('pytorch version: ', torch.__version__)

pytorch version:  1.7.0+cu101


In [7]:
def worker_init_fn(worker_id):
    """Seed NumPy and `random` inside a DataLoader worker process.

    Within a worker, `torch.initial_seed()` equals the loader's base seed plus
    `worker_id`. The base seed is drawn from the main-process torch RNG each time
    an iterator is created, so the result is still reproducible from the global
    seed but differs per worker *and* per epoch. Seeding from a fixed
    `args.random_seed + worker_id` would replay the same NumPy stream every epoch.

    Args:
        worker_id: index of the worker (unused; already folded into the torch seed).
    """
    worker_seed = torch.initial_seed() % 2**32  # np.random.seed needs a 32-bit value
    np.random.seed(worker_seed)
    random.seed(worker_seed)

# Seed every RNG the run touches so results are reproducible.
n_gpu = torch.cuda.device_count() if torch.cuda.is_available() else 0
np.random.seed(args.random_seed)
random.seed(args.random_seed)
# Dedicated RNG: handed to MultiWozDataset and used for the teacher-forcing coin flip.
rng = random.Random(args.random_seed)
torch.manual_seed(args.random_seed)
if n_gpu > 0:
    torch.cuda.manual_seed(args.random_seed)
    torch.cuda.manual_seed_all(args.random_seed)
    # Prefer deterministic cuDNN kernels over autotuned (benchmark) ones.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

os.makedirs(args.save_dir, exist_ok=True)

# Read the metadata with `with` so the handles are closed; explicit UTF-8 because the
# files may hold non-ASCII (Korean) text and the platform default encoding may differ.
with open(args.ontology_data, encoding='utf-8') as f:
    ontology = json.load(f)
with open(args.slot_meta, encoding='utf-8') as f:
    slot_meta = json.load(f)
op2id = OP_SET[args.op_code]
tokenizer = ElectraTokenizer.from_pretrained(
    "monologg/koelectra-base-v3-discriminator",
    additional_special_tokens=['[SLOT]', '[NULL]', '[EOS]', ' ; '],
)

Special tokens have been added in the vocabulary, make sure the associated word embedding are fine-tuned or trained.


In [8]:
######## DATA PREPARATION ########
RAW_DATA_DIR = './raw_data'
os.makedirs(RAW_DATA_DIR, exist_ok=True)


def _load_or_prepare_raw(data_path, cache_name):
    """Return pre-processed instances for `data_path`, cached as a pickle in RAW_DATA_DIR.

    On a cache hit the pickle is loaded; otherwise `prepare_dataset` is run with the
    current tokenizer/slot_meta/args and the result is written to the cache.

    NOTE(review): the cache is keyed only by `cache_name`. Delete ./raw_data after
    changing the tokenizer, n_history, max_seq_length or op_code, or stale data is reused.
    Only load caches this notebook wrote itself: unpickling can execute arbitrary code.
    """
    cache_path = os.path.join(RAW_DATA_DIR, cache_name)
    if os.path.exists(cache_path):
        with open(cache_path, 'rb') as f:
            return pickle.load(f)
    data_raw = prepare_dataset(data_path=data_path,
                               tokenizer=tokenizer,
                               slot_meta=slot_meta,
                               n_history=args.n_history,
                               max_seq_length=args.max_seq_length,
                               op_code=args.op_code)
    with open(cache_path, 'wb') as f:
        pickle.dump(data_raw, f)
    return data_raw


print('Making Train_Data_Raw....')
train_data_raw = _load_or_prepare_raw(args.train_data_path, 'train_data_raw')

print('Making Dev_Data_Raw....')
dev_data_raw = _load_or_prepare_raw(args.dev_data_path, 'dev_data_raw')

train_data = MultiWozDataset(train_data_raw,
                             tokenizer,
                             slot_meta,
                             args.max_seq_length,
                             rng,
                             ontology,
                             args.word_dropout,
                             args.shuffle_state,
                             args.shuffle_p)

print("# train examples %d" % len(train_data_raw))
print("# dev examples %d" % len(dev_data_raw))

test_data_raw = _load_or_prepare_raw(args.test_data_path, 'test_data_raw')
print("# test examples %d" % len(test_data_raw))


Making Train_Data_Raw....
Making Dev_Data_Raw....
# train examples 45410
# dev examples 4977
# test examples 14771


In [9]:
######## MODEL CONFIG ########
model_config = PretrainedConfig.from_json_file(args.bert_config_path)
model_config.dropout = args.dropout
model_config.attention_probs_dropout_prob = args.attention_probs_dropout_prob
model_config.hidden_dropout_prob = args.hidden_dropout_prob

model = SomDST(model_config, len(op2id), len(domain2id), op2id['update'], args.exclude_domain)

# NOTE(review): the explicit load of args.bert_ckpt_path is disabled here; confirm that
# SomDST obtains pretrained encoder weights on its own, otherwise the encoder is randomly initialized.

# Re-initialize the embedding rows of the tokens added to the tokenizer
# ([SLOT], [NULL], [EOS], ' ; '), looked up by token id rather than assuming they
# are the last rows of the matrix.
word_embeddings = model.encoder.bert.embeddings.word_embeddings.weight.data
added_token_ids = tokenizer.additional_special_tokens_ids
assert all(token_id < word_embeddings.size(0) for token_id in added_token_ids), (
    'embedding matrix (%d rows) is too small for added token ids %s'
    % (word_embeddings.size(0), added_token_ids))
for token_id in added_token_ids:
    # Integer indexing yields a view, so normal_ writes into the weight in place.
    word_embeddings[token_id].normal_(mean=0.0, std=0.02)
model.to(device)

####### STEPS ########
# The train DataLoader keeps the last partial batch, so one epoch is ceil(N / batch_size)
# steps; flooring the total made the linear schedule reach lr=0 before training ended.
steps_per_epoch = (len(train_data_raw) + args.batch_size - 1) // args.batch_size
num_train_steps = steps_per_epoch * args.n_epochs

Special tokens have been added in the vocabulary, make sure the associated word embedding are fine-tuned or trained.


In [10]:
####### ENC/DEC OPTIMIZER ########
# Encoder parameters whose names contain any of these substrings get no weight decay.
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
enc_param_optimizer = list(model.encoder.named_parameters())

_decay_params, _no_decay_params = [], []
for _param_name, _param in enc_param_optimizer:
    _bucket = _no_decay_params if any(tag in _param_name for tag in no_decay) else _decay_params
    _bucket.append(_param)
enc_optimizer_grouped_parameters = [
    {'params': _decay_params, 'weight_decay': 0.01},
    {'params': _no_decay_params, 'weight_decay': 0.0},
]

# Encoder: AdamW + linear warmup/decay over the whole run.
enc_optimizer = AdamW(enc_optimizer_grouped_parameters, lr=args.enc_lr)
enc_scheduler = get_linear_schedule_with_warmup(
    enc_optimizer,
    num_warmup_steps=int(num_train_steps * args.enc_warmup),
    num_training_steps=num_train_steps,
)

# Decoder: one parameter group, no explicit weight_decay (library default).
dec_param_optimizer = list(model.decoder.parameters())
dec_optimizer = AdamW(dec_param_optimizer, lr=args.dec_lr)
dec_scheduler = get_linear_schedule_with_warmup(
    dec_optimizer,
    num_warmup_steps=int(num_train_steps * args.dec_warmup),
    num_training_steps=num_train_steps,
)

In [11]:
# Multi-GPU: replicate the model across visible GPUs.
# NOTE: a DataParallel-wrapped model's state_dict keys gain a 'module.' prefix.
if n_gpu > 1:
    model = torch.nn.DataParallel(model)

print('Making Train DataLoader...')
train_sampler = RandomSampler(train_data)
train_dataloader = DataLoader(
    train_data,
    batch_size=args.batch_size,
    sampler=train_sampler,
    num_workers=args.num_workers,
    collate_fn=train_data.collate_fn,
    worker_init_fn=worker_init_fn,
)

Making Train DataLoader...


In [12]:
import wandb

# Record the run's hyper-parameters with its metrics so runs can be compared/reproduced.
wandb.init(project="SOM-DST", config=vars(args))
config = wandb.config

[34m[1mwandb[0m: Currently logged in as: [33mtaepd[0m (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: wandb version 0.10.30 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


In [17]:
def get_lr(scheduler):
    """Return the learning rate of the scheduler's first parameter group."""
    return scheduler.get_last_lr()[0]


loss_fnc = nn.CrossEntropyLoss()
best_score = {'epoch': 0, 'joint_acc': 0, 'op_acc': 0, 'final_slot_f1': 0}
# Domain loss is added unless domain prediction is excluded (exclude_domain should stay True).
use_domain_loss = args.exclude_domain is not True
# NOTE(review): fp16 autocast is used without torch.cuda.amp.GradScaler, so small gradients
# may underflow. Before adding a scaler, check whether encoder and decoder optimizers share
# parameters (e.g. a tied embedding): GradScaler would unscale such gradients twice.

######## TRAINING ########
print("Let's Do the Training!")
for epoch in range(args.n_epochs):
    start = time.time()
    batch_loss = []
    model.train()
    for step, batch in enumerate(train_dataloader):
        # Tensors move to the device; plain ints (max_value, max_update) pass through.
        batch = [b.to(device) if not isinstance(b, int) else b for b in batch]
        input_ids, input_mask, segment_ids, state_position_ids, op_ids, \
            domain_ids, gen_ids, max_value, max_update = batch

        # Teacher-force the value decoder with probability args.decoder_teacher_forcing.
        teacher = gen_ids if rng.random() < args.decoder_teacher_forcing else None

        # Forward pass and losses under mixed precision.
        with amp.autocast():
            domain_scores, state_scores, gen_scores = model(input_ids=input_ids,
                                                            token_type_ids=segment_ids,
                                                            state_positions=state_position_ids,
                                                            attention_mask=input_mask,
                                                            max_value=max_value,
                                                            op_ids=op_ids,
                                                            max_update=max_update,
                                                            teacher=teacher)

            loss_s = loss_fnc(state_scores.view(-1, len(op2id)), op_ids.view(-1))
            loss_g = masked_cross_entropy_for_value(gen_scores.contiguous(),
                                                    gen_ids.contiguous(),
                                                    tokenizer.vocab['[PAD]'])
            loss = loss_s + loss_g
            if use_domain_loss:
                loss_d = loss_fnc(domain_scores.view(-1, len(domain2id)), domain_ids.view(-1))
                loss = loss + loss_d
        batch_loss.append(loss.item())

        # Backward outside the autocast region, as the AMP docs recommend.
        loss.backward()
        enc_optimizer.step()
        enc_scheduler.step()
        dec_optimizer.step()
        dec_scheduler.step()
        model.zero_grad()

        if step % 100 == 0:
            mean_loss = np.mean(batch_loss)
            message = "[%d/%d] [%d/%d] mean_loss : %.3f, state_loss : %.3f, gen_loss : %.3f" \
                      % (epoch+1, args.n_epochs, step, len(train_dataloader),
                         mean_loss, loss_s.item(), loss_g.item())
            log_row = {
                "train/mean_loss": mean_loss,
                "train/state_loss": loss_s.item(),
                "train/gen_loss": loss_g.item(),
            }
            if use_domain_loss:
                message += ", dom_loss : %.3f" % loss_d.item()
                log_row["train/dom_loss"] = loss_d.item()
            log_row.update({
                "train/epoch": epoch+1,
                "train/enc_learning rate": get_lr(enc_scheduler),
                "train/dec_learning rate": get_lr(dec_scheduler),
            })
            print(message)
            wandb.log(log_row)
            batch_loss = []  # mean_loss covers the steps since the previous log line

    times = str(datetime.timedelta(seconds=time.time() - start)).split(".")[0]
    print(f'<<<<<<<<<<  {epoch+1} EPOCH spent : {times}  >>>>>>>>>>')

    if (epoch+1) % args.eval_epoch == 0:
        eval_res = model_evaluation(model, dev_data_raw, tokenizer, slot_meta, epoch+1, args.op_code)
        wandb.log({
            'eval/epoch': epoch+1,
            'eval/joint_acc': eval_res['joint_acc'],
            'eval/slot_acc': eval_res['slot_acc'],
            'eval/slot_f1': eval_res['slot_f1'],
            'eval/op_acc': eval_res['op_acc'],
            'eval/op_f1': eval_res['op_f1'],
            'eval/final_slot_f1': eval_res['final_slot_f1'],
        })

        if eval_res['joint_acc'] > best_score['joint_acc']:
            best_score = eval_res
            # Unwrap DataParallel so saved keys have no 'module.' prefix and load into a bare SomDST.
            model_to_save = model.module if hasattr(model, 'module') else model
            save_path = os.path.join(args.save_dir, 'model_best.bin')
            torch.save(model_to_save.state_dict(), save_path)
        print("Best Score : ", best_score)
        print("\n")

Let's Do the Training!
[1/50] [0/1622] mean_loss : 2.146, state_loss : 0.141, gen_loss : 2.005
[1/50] [100/1622] mean_loss : 2.278, state_loss : 0.189, gen_loss : 2.336
[1/50] [200/1622] mean_loss : 2.207, state_loss : 0.190, gen_loss : 1.779
[1/50] [300/1622] mean_loss : 2.162, state_loss : 0.140, gen_loss : 2.230
[1/50] [400/1622] mean_loss : 2.130, state_loss : 0.155, gen_loss : 1.680
[1/50] [500/1622] mean_loss : 2.049, state_loss : 0.216, gen_loss : 1.687
[1/50] [600/1622] mean_loss : 2.019, state_loss : 0.164, gen_loss : 1.652
[1/50] [700/1622] mean_loss : 1.994, state_loss : 0.153, gen_loss : 2.132
[1/50] [800/1622] mean_loss : 1.947, state_loss : 0.167, gen_loss : 1.760
[1/50] [900/1622] mean_loss : 1.884, state_loss : 0.175, gen_loss : 1.777
[1/50] [1000/1622] mean_loss : 1.732, state_loss : 0.198, gen_loss : 1.809
[1/50] [1100/1622] mean_loss : 1.604, state_loss : 0.159, gen_loss : 1.330
[1/50] [1200/1622] mean_loss : 1.482, state_loss : 0.150, gen_loss : 1.210
[1/50] [1300/1

KeyboardInterrupt: 

In [13]:
############### TEST DATA EVALUATION ####################
print("Test using best model...")
# NOTE(review): hand-picked label passed as the `epoch` argument of model_evaluation
# (printed in the report and used in the output file name); the training loop's
# best_score['epoch'] holds the actual best epoch.
best_epoch = 35
ckpt_path = os.path.join(args.save_dir, 'model_best.bin')
model = SomDST(model_config, len(op2id), len(domain2id), op2id['update'], args.exclude_domain)
ckpt = torch.load(ckpt_path, map_location='cpu')
# Checkpoints saved from a DataParallel-wrapped model carry a 'module.' prefix;
# strip it so the keys match the bare SomDST.
ckpt = {(key[len('module.'):] if key.startswith('module.') else key): value
        for key, value in ckpt.items()}
# strict=False tolerates small key differences, but report them: silently skipped
# weights mean evaluating a partly untrained model.
incompatible = model.load_state_dict(ckpt, strict=False)
if incompatible.missing_keys or incompatible.unexpected_keys:
    print('WARNING: checkpoint/model key mismatch')
    print('  missing keys   :', incompatible.missing_keys)
    print('  unexpected keys:', incompatible.unexpected_keys)
model.to(device)

model_evaluation(model, test_data_raw, tokenizer, slot_meta, best_epoch, args.op_code,
                 is_gt_op=False, is_gt_p_state=False, is_gt_gen=False, for_eval=True)



Test using best model...


Special tokens have been added in the vocabulary, make sure the associated word embedding are fine-tuned or trained.


------------------------------
op_code: 4, is_gt_op: False, is_gt_p_state: False, is_gt_gen: False
Epoch 35 joint accuracy :  0.013743145352379664
Epoch 35 slot turn accuracy :  0.8017075500793691
Epoch 35 slot turn F1:  0.013743145352379664
Epoch 35 op accuracy :  0.8017075500793691
Epoch 35 op F1 :  {'delete': 0, 'update': 0, 'dontcare': 0, 'carryover': 0.8905033375639501}
Epoch 35 op hit count :  {'delete': 0, 'update': 0, 'dontcare': 0, 'carryover': 532891}
Epoch 35 op all count :  {'delete': 107137, 'update': 0, 'dontcare': 0, 'carryover': 557558}
Final Joint Accuracy :  0.0
Final slot turn F1 :  0.0
Latency Per Prediction : 32.307839 ms
-----------------------------



NameError: name 'pickle' is not defined

In [16]:
from copy import deepcopy

def model_evaluation(model, test_data, tokenizer, slot_meta, epoch, op_code='4',
                     is_gt_op=False, is_gt_p_state=False, is_gt_gen=False, for_eval=False):
    """Evaluate a SomDST model turn by turn and print/return the metrics.

    NOTE(review): this redefines (shadows) the `model_evaluation` imported from
    `evaluation` at the top of the notebook.

    For each instance the input is rebuilt from the previously *predicted* dialogue
    state (or the gold previous state when `is_gt_p_state`). The model predicts slot
    operations and values, and `postprocessing` folds them into the next state.

    Args:
        model: SomDST model; put in eval mode and run under torch.no_grad().
        test_data: sequence of instances exposing turn_id, gold_state, gold_p_state,
            is_last_turn, id, input_id/input_mask/segment_id/slot_position and make_instance().
        tokenizer: tokenizer used to build inputs and post-process generated values.
        slot_meta: list of slot names.
        epoch: int used in the printed report and in the prediction file names.
        op_code: key into OP_SET selecting the operation set.
        is_gt_op, is_gt_p_state, is_gt_gen: oracle settings that substitute gold
            operations, gold previous state, or gold values respectively.
        for_eval: if True write test_preds_<epoch>.json and pickle the collected
            operation logits / top-100 generation logits and ids under ./outputs;
            otherwise write ./preds/preds_<epoch>.json.

    Returns:
        dict with 'epoch', 'joint_acc', 'slot_acc', 'slot_f1', 'op_acc', 'op_f1',
        'final_slot_f1'.
    """
    model.eval()
    op2id = OP_SET[op_code]
    id2op = {v: k for k, v in op2id.items()}

    def ratio(num, den):
        # 0 instead of ZeroDivisionError for empty inputs (e.g. no last turns).
        return num / den if den else 0

    slot_turn_acc, joint_acc, slot_F1_pred, slot_F1_count = 0, 0, 0, 0
    final_joint_acc, final_count, final_slot_F1_pred, final_slot_F1_count = 0, 0, 0, 0
    op_acc = 0
    op_F1_count = {k: 0 for k in op2id}       # correct predictions per gold op
    all_op_F1_count = {k: 0 for k in op2id}   # occurrences per gold op
    tp_dic = {k: 0 for k in op2id}
    fn_dic = {k: 0 for k in op2id}
    fp_dic = {k: 0 for k in op2id}

    results = {}
    last_dialog_state = {}
    wall_times = []
    # Collected only when for_eval, and moved to the CPU so GPU memory does not grow per turn.
    s_logits, g_logits, g_idx = [], [], []
    MAX_LENGTH = 9  # passed to the model as max_value (presumably the cap on generated value tokens)

    for inst in test_data:
        if inst.turn_id == 0:
            last_dialog_state = {}
        if is_gt_p_state:
            # ground-truth previous dialogue state
            last_dialog_state = deepcopy(inst.gold_p_state)
        inst.last_dialog_state = deepcopy(last_dialog_state)
        inst.make_instance(tokenizer, word_dropout=0.)

        input_ids = torch.LongTensor([inst.input_id]).to(device)
        # LongTensor, matching the mask dtype MultiWozDataset uses in training (was FloatTensor).
        input_mask = torch.LongTensor([inst.input_mask]).to(device)
        segment_ids = torch.LongTensor([inst.segment_id]).to(device)
        state_position_ids = torch.LongTensor([inst.slot_position]).to(device)

        # Gold operations relative to the state the model actually sees this turn.
        # NOTE(review): open question from the author - why dynamic=True here?
        d_gold_op, _, _ = make_turn_label(slot_meta, last_dialog_state, inst.gold_state,
                                          tokenizer, op_code, dynamic=True)
        gold_op_ids = torch.LongTensor([d_gold_op]).to(device)

        start = time.perf_counter()
        with torch.no_grad():
            # ground-truth state operation when is_gt_op, else let the model predict
            gold_op_inputs = gold_op_ids if is_gt_op else None
            d, s, g = model(input_ids=input_ids,
                            token_type_ids=segment_ids,
                            state_positions=state_position_ids,
                            attention_mask=input_mask,
                            max_value=MAX_LENGTH,
                            op_ids=gold_op_inputs)

        _, op_ids = s.view(-1, len(op2id)).max(-1)
        generated = g.squeeze(0).max(-1)[1].tolist() if g.size(1) > 0 else []

        if is_gt_op:
            pred_ops = [id2op[a] for a in gold_op_ids[0].tolist()]
        else:
            pred_ops = [id2op[a] for a in op_ids.tolist()]
        gold_ops = [id2op[a] for a in d_gold_op]

        if is_gt_gen:
            # ground-truth generation: "<domain>-<slot>-<value>" -> {"<domain>-<slot>": value}
            gold_gen = {'-'.join(ii.split('-')[:2]): ii.split('-')[-1] for ii in inst.gold_state}
        else:
            gold_gen = {}
        generated, last_dialog_state = postprocessing(slot_meta, pred_ops, last_dialog_state,
                                                      generated, tokenizer, op_code, gold_gen)
        end = time.perf_counter()
        wall_times.append(end - start)

        if for_eval:
            g_logit, g_id = torch.topk(g, 100, dim=-1)
            s_logits.append(s.cpu())
            g_logits.append(g_logit.cpu())
            g_idx.append(g_id.cpu())

        pred_state = ['-'.join([k, v]) for k, v in last_dialog_state.items()]
        pred_set, gold_set = set(pred_state), set(inst.gold_state)
        is_joint_hit = pred_set == gold_set
        if is_joint_hit:
            joint_acc += 1
        results[str(inst.id) + '_' + str(inst.turn_id)] = [pred_state, inst.gold_state]

        # Compute prediction slot accuracy
        slot_turn_acc += compute_acc(gold_set, pred_set, slot_meta)

        # Compute prediction F1 score
        temp_f1, temp_r, temp_p, count = compute_prf(inst.gold_state, pred_state)
        slot_F1_pred += temp_f1
        slot_F1_count += count

        # Compute operation accuracy
        op_acc += ratio(sum(p == q for p, q in zip(pred_ops, gold_ops)), len(pred_ops))

        if inst.is_last_turn:
            final_count += 1
            if is_joint_hit:
                final_joint_acc += 1
            final_slot_F1_pred += temp_f1
            final_slot_F1_count += count

        # Compute operation F1 score
        for pred_op, gold_op in zip(pred_ops, gold_ops):
            all_op_F1_count[gold_op] += 1
            if pred_op == gold_op:
                tp_dic[gold_op] += 1
                op_F1_count[gold_op] += 1
            else:
                fn_dic[gold_op] += 1
                fp_dic[pred_op] += 1

    joint_acc_score = ratio(joint_acc, len(test_data))
    turn_acc_score = ratio(slot_turn_acc, len(test_data))
    slot_F1_score = ratio(slot_F1_pred, slot_F1_count)
    op_acc_score = ratio(op_acc, len(test_data))
    final_joint_acc_score = ratio(final_joint_acc, final_count)
    final_slot_F1_score = ratio(final_slot_F1_pred, final_slot_F1_count)
    latency = np.mean(wall_times) * 1000
    op_F1_score = {}
    for k in op2id.keys():
        tp, fn, fp = tp_dic[k], fn_dic[k], fp_dic[k]
        precision = ratio(tp, tp + fp)
        recall = ratio(tp, tp + fn)
        op_F1_score[k] = ratio(2 * precision * recall, float(precision + recall))

    print("------------------------------")
    print('op_code: %s, is_gt_op: %s, is_gt_p_state: %s, is_gt_gen: %s' % \
          (op_code, str(is_gt_op), str(is_gt_p_state), str(is_gt_gen)))
    print("Epoch %d joint accuracy : " % epoch, joint_acc_score)
    print("Epoch %d slot turn accuracy : " % epoch, turn_acc_score)
    print("Epoch %d slot turn F1: " % epoch, slot_F1_score)
    print("Epoch %d op accuracy : " % epoch, op_acc_score)
    print("Epoch %d op F1 : " % epoch, op_F1_score)
    print("Epoch %d op hit count : " % epoch, op_F1_count)
    print("Epoch %d op all count : " % epoch, all_op_F1_count)
    print("Final Joint Accuracy : ", final_joint_acc_score)
    print("Final slot turn F1 : ", final_slot_F1_score)
    print("Latency Per Prediction : %f ms" % latency)
    print("-----------------------------\n")
    if for_eval:
        with open('test_preds_%d.json' % epoch, 'w') as f:
            json.dump(results, f, ensure_ascii=False)
        # Save the collected logits (pickled lists of CPU tensors), each to its own file.
        os.makedirs('./outputs', exist_ok=True)
        for name, tensors in (('s_logits', s_logits), ('g_logits', g_logits), ('g_idx', g_idx)):
            with open(os.path.join('./outputs', name), 'wb') as f:
                pickle.dump(tensors, f)
    else:
        os.makedirs('./preds', exist_ok=True)
        with open('./preds/preds_%d.json' % epoch, 'w') as f:
            json.dump(results, f, ensure_ascii=False)
    per_domain_join_accuracy(results, slot_meta)

    scores = {'epoch': epoch, 'joint_acc': joint_acc_score,
              'slot_acc': turn_acc_score, 'slot_f1': slot_F1_score,
              'op_acc': op_acc_score, 'op_f1': op_F1_score, 'final_slot_f1': final_slot_F1_score}
    return scores

In [19]:
# Load predictions written by an earlier for_eval run of model_evaluation.
# NOTE(review): the test-evaluation cell in this notebook labels its run 35; confirm 42 is intended.
with open('test_preds_42.json') as preds_file:
    pred = json.load(preds_file)

In [20]:
def _submission_key(result_key):
    """Rewrite a '<dialogue id>_<turn>' result key as '<dialogue id>_<turn>-<turn>'.

    The text after the last '_' is cut at its first '-' and that piece is used on
    both sides of the new '-'.
    NOTE(review): the turn id is duplicated; confirm this matches the expected
    submission key format (the second piece may have been meant to differ).
    """
    parts = result_key.rsplit('_', 1)
    head, tail = parts[0], parts[1]
    turn = tail.split('-')[0]
    return head + '_' + turn + '-' + turn


# Keep only the predicted state (index 0 of each [pred_state, gold_state] pair).
predictions = {_submission_key(k): v[0] for k, v in pred.items()}

# NOTE(review): the file content is JSON despite the .csv extension.
with open("predictions.csv", "w") as out_file:
    json.dump(predictions, out_file, indent=2, ensure_ascii=False)

[34m[1mwandb[0m: Network error resolved after 0:00:58.485487, resuming normal operation.
