In [None]:
#pip install prettyprinter

In [None]:
#pip install ruamel.yaml

In [1]:
import argparse
import json
import os
import pickle
import random
import numpy as np

import torch
import torch.nn as nn

from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from tqdm import tqdm
from transformers import AdamW, get_linear_schedule_with_warmup, AutoTokenizer, AutoConfig

from data_utils import (YamlConfigManager, WOSDataset, get_examples_from_dialogues, load_dataset,
                        set_seed, custom_to_mask, custom_get_examples_from_dialogues, custom_load_dataset)

from evaluation import _evaluation
from inference import inference
from model import TRADE, masked_cross_entropy_for_value
from preprocessor import TRADEPreprocessor
from prettyprinter import cpprint

from pathlib import Path
import glob
import re

import wandb
import time

from torch.cuda.amp import GradScaler, autocast

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
cfg = YamlConfigManager('./config.yml', 'base').values
cpprint(cfg)

easydict.EasyDict({
    'data_dir': '../input/data/train_dataset',
    'model_dir': 'results',
    'train_batch_size': 4,
    'eval_batch_size': 8,
    'learning_rate': 3e-05,
    'adam_epsilon': 1e-08,
    'max_grad_norm': 1.0,
    'num_train_epochs': 30,
    'warmup_ratio': 0.0,
    'random_seed': 42,
    'n_gate': 5,
    'teacher_forcing_ratio': 0.5,
    'model_name_or_path': 'dsksd/bert-ko-small-minimal',
    'proj_dim': 'None',
    'tag': ['trade'],
    'use_kfold': False,
    'num_k': 0,
    'val_ratio': 0.1,
    'scheduler': 'Linear',
    'mask': True
})


In [4]:
# Get current learning rate
def get_lr(scheduler):
    return scheduler.get_last_lr()[0]

In [5]:
# random seed 고정
set_seed(cfg.random_seed)

# Data Loading
train_data_file = f"{cfg.data_dir}/train_dials.json"
slot_meta = json.load(open(f"{cfg.data_dir}/slot_meta.json"))
# train_data, dev_data, dev_labels = load_dataset(train_data_file, cfg.val_ratio)
train_data, dev_data, dev_labels = custom_load_dataset(train_data_file, cfg.val_ratio, k=8)

train_examples = custom_get_examples_from_dialogues(
    train_data, user_first=False, dialogue_level=False
)
dev_examples = custom_get_examples_from_dialogues(
    dev_data, user_first=False, dialogue_level=False
)

 17%|█▋        | 1086/6300 [00:00<00:00, 10859.75it/s]

k  8
[ 149 4765 4335 2340 4677 6176 5227 6556 5449  657]


100%|██████████| 6300/6300 [00:00<00:00, 8423.36it/s] 
100%|██████████| 700/700 [00:00<00:00, 2595.12it/s]


In [6]:
train_examples[0]

DSTInputExample(guid='snowy-hat-8324:관광_식당_11-0', context_turns=[], current_turn=['', ' # ', '서울 중앙에 있는 박물관을 찾아주세요', ' * '], label=['관광-종류-박물관', '관광-지역-서울 중앙'])

In [7]:
# Define Preprocessor
tokenizer = AutoTokenizer.from_pretrained(cfg.model_name_or_path)

# Dealing with long texts The maximum sequence length of BERT is 512.
processor = TRADEPreprocessor(slot_meta, tokenizer, max_seq_length=512, n_gate=cfg.n_gate)

In [8]:
# Extracting Featrues
# cpprint('Extracting Features...')
# train_features = processor.sep_custom_convert_examples_to_features(train_examples)
# dev_features = processor.sep_custom_convert_examples_to_features(dev_examples)

In [9]:
# # 전체 train data InputFeatur 저장
# with open('custom_train_features.txt', 'wb') as f:
#     pickle.dump(train_features, f)
# with open('custom_dev_features.txt', 'wb') as f:
#     pickle.dump(dev_features, f)

In [10]:
# 저장된 파일 사용
with open('custom_train_features.txt', 'rb') as f:
    train_features = pickle.load(f)
with open('custom_dev_features.txt', 'rb') as f:
    dev_features = pickle.load(f)

In [11]:
# Slot Meta tokenizing for the decoder initial inputs
tokenized_slot_meta = []
for slot in slot_meta:
    tokenized_slot_meta.append(
        tokenizer.encode(slot.replace("-", " "), add_special_tokens=False)
    )

In [12]:
# Model 선언
config = AutoConfig.from_pretrained('dsksd/bert-ko-small-minimal')
config.model_name_or_path = 'dsksd/bert-ko-small-minimal'
config.n_gate = cfg.n_gate
config.proj_dim = None

model = TRADE(config, tokenized_slot_meta)

model.to(device)
print("Model is initialized")

  "num_layers={}".format(dropout, num_layers))


Model is initialized


In [13]:
# --wandb initialize with configuration
wandb.init(project='DST', tags=cfg.tag, config=cfg)

[34m[1mwandb[0m: Currently logged in as: [33mtaepd[0m (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: wandb version 0.10.30 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


In [14]:
train_data = WOSDataset(train_features)
train_sampler = RandomSampler(train_data)
train_loader = DataLoader(
    train_data,
    batch_size=cfg.train_batch_size,
    sampler=train_sampler,
    collate_fn=processor.collate_fn,
    num_workers=4,  # num_worker = 4 * num_GPU
    pin_memory=True,
)
print("# train:", len(train_data))

dev_data = WOSDataset(dev_features)
dev_sampler = SequentialSampler(dev_data)
dev_loader = DataLoader(
    dev_data,
    batch_size=cfg.eval_batch_size,
    sampler=dev_sampler,
    collate_fn=processor.collate_fn,
    num_workers=4,
    pin_memory=True,
)
print("# dev:", len(dev_data))




# train: 46127
# dev: 5118


In [15]:
# Optimizer 및 Scheduler 선언
n_epochs = cfg.num_train_epochs

no_decay = ["bias", "LayerNorm.weight"]
optimizer_grouped_parameters = [
    {
        "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
        "weight_decay": 0.01,
    },
    {
        "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
        "weight_decay": 0.0,
    },
]

t_total = len(train_loader) * n_epochs
optimizer = AdamW(optimizer_grouped_parameters, lr=cfg.learning_rate, eps=cfg.adam_epsilon)
warmup_steps = int(t_total * cfg.warmup_ratio)
# learning rate decreases linearly from the initial lr set in the optimizer to 0
scheduler = get_linear_schedule_with_warmup(
    optimizer, num_warmup_steps=warmup_steps, num_training_steps=t_total
)
teacher_forcing = cfg.teacher_forcing_ratio

loss_fnc_1 = masked_cross_entropy_for_value  # generation
loss_fnc_2 = nn.CrossEntropyLoss()  # gating
loss_fnc_pretrain = nn.CrossEntropyLoss()  # MLM pretrain

In [16]:
# 모델 저장될 파일 위치 생성
if not os.path.exists(f"{cfg.model_dir}"):
    os.mkdir(f"{cfg.model_dir}")
if not os.path.exists(f"{cfg.model_dir}/{wandb.run.name}"):
    os.mkdir(f"{cfg.model_dir}/{wandb.run.name}")

In [17]:
json.dump(
    vars(cfg),
    open(f"{cfg.model_dir}/{wandb.run.name}/exp_config.json", "w"),
    indent=2,
    ensure_ascii=False,
)
json.dump(
    slot_meta,
    open(f"{cfg.model_dir}/slot_meta.json", "w"),
    indent=2,
    ensure_ascii=False,
)

### Pretraining

In [18]:
eval_data = json.load(open(f"../input/data/eval_dataset/eval_dials.json", "r"))

eval_examples = get_examples_from_dialogues(
    eval_data, user_first=False, dialogue_level=False
)

# Extracting Featrues
eval_features = processor.convert_examples_to_features(eval_examples)
eval_data = WOSDataset(eval_features)
eval_sampler = SequentialSampler(eval_data)
eval_loader = DataLoader(
    eval_data,
    batch_size=8,
    sampler=eval_sampler,
    collate_fn=processor.collate_fn,
)

100%|██████████| 2000/2000 [00:00<00:00, 15715.12it/s]
Token indices sequence length is longer than the specified maximum sequence length for this model (539 > 512). Running this sequence through the model will result in indexing errors


In [19]:
MLM_PRE = True

scaler = GradScaler()
n_pretrain_epochs = 10

def mlm_pretrain(loader, n_epochs):
    model.train()
    for step, batch in enumerate(tqdm(loader)):
        input_ids, segment_ids, input_masks, gating_ids, target_ids, guids = [b.to(device) if not isinstance(b, list) else b for b in batch]
        
        with autocast(): # 밑에 해당하는 코드를 자동으로 mixed precision으로 변환시켜서 실행
            logits, labels = model.forward_pretrain(input_ids, tokenizer)
            loss = loss_fnc_pretrain(logits.view(-1, config.vocab_size), labels.view(-1))

        scaler.scale(loss).backward()
        nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()
        optimizer.zero_grad()

        if step % 100 == 0:
            print('[%d/%d] [%d/%d] %f' % (epoch, n_epochs, step, len(loader), loss.item()))

if MLM_PRE:
    for epoch in range(n_pretrain_epochs):
        mlm_pretrain(eval_loader, n_pretrain_epochs)

  0%|          | 2/1847 [00:00<02:37, 11.70it/s]

[0/10] [0/1847] 10.665379


  6%|▌         | 102/1847 [00:08<02:21, 12.36it/s]

[0/10] [100/1847] 9.335777


 11%|█         | 202/1847 [00:16<02:17, 12.00it/s]

[0/10] [200/1847] 8.538736


 16%|█▋        | 302/1847 [00:24<02:19, 11.11it/s]

[0/10] [300/1847] 7.021785


 22%|██▏       | 402/1847 [00:32<02:05, 11.48it/s]

[0/10] [400/1847] 6.846267


 27%|██▋       | 502/1847 [00:41<01:53, 11.85it/s]

[0/10] [500/1847] 5.674328


 33%|███▎      | 602/1847 [00:49<01:39, 12.52it/s]

[0/10] [600/1847] 6.264143


 38%|███▊      | 702/1847 [00:57<01:31, 12.45it/s]

[0/10] [700/1847] 4.851836


 43%|████▎     | 802/1847 [01:05<01:35, 10.92it/s]

[0/10] [800/1847] 4.647188


 49%|████▉     | 902/1847 [01:13<01:13, 12.92it/s]

[0/10] [900/1847] 4.425786


 54%|█████▍    | 1002/1847 [01:21<01:10, 12.00it/s]

[0/10] [1000/1847] 4.906311


 60%|█████▉    | 1102/1847 [01:29<01:05, 11.40it/s]

[0/10] [1100/1847] 4.086486


 65%|██████▌   | 1202/1847 [01:38<00:53, 12.12it/s]

[0/10] [1200/1847] 4.000368


 70%|███████   | 1302/1847 [01:46<00:45, 11.97it/s]

[0/10] [1300/1847] 3.090740


 76%|███████▌  | 1404/1847 [01:54<00:31, 14.14it/s]

[0/10] [1400/1847] 3.678691


 81%|████████▏ | 1502/1847 [02:02<00:27, 12.54it/s]

[0/10] [1500/1847] 3.257125


 87%|████████▋ | 1602/1847 [02:10<00:19, 12.38it/s]

[0/10] [1600/1847] 3.567274


 92%|█████████▏| 1702/1847 [02:18<00:11, 12.91it/s]

[0/10] [1700/1847] 3.341910


 98%|█████████▊| 1802/1847 [02:27<00:03, 12.96it/s]

[0/10] [1800/1847] 3.799031


100%|██████████| 1847/1847 [02:30<00:00, 12.25it/s]
  0%|          | 2/1847 [00:00<02:16, 13.54it/s]

[1/10] [0/1847] 2.596624


  6%|▌         | 102/1847 [00:08<02:19, 12.55it/s]

[1/10] [100/1847] 2.063027


 11%|█         | 204/1847 [00:16<02:06, 12.98it/s]

[1/10] [200/1847] 3.133280


 16%|█▋        | 302/1847 [00:24<02:18, 11.18it/s]

[1/10] [300/1847] 3.293401


 22%|██▏       | 402/1847 [00:32<02:05, 11.50it/s]

[1/10] [400/1847] 3.286769


 27%|██▋       | 502/1847 [00:40<01:52, 12.01it/s]

[1/10] [500/1847] 2.455362


 33%|███▎      | 602/1847 [00:49<01:51, 11.13it/s]

[1/10] [600/1847] 2.783832


 38%|███▊      | 702/1847 [00:57<01:30, 12.62it/s]

[1/10] [700/1847] 2.380018


 43%|████▎     | 802/1847 [01:05<01:33, 11.12it/s]

[1/10] [800/1847] 2.640423


 49%|████▉     | 902/1847 [01:13<01:12, 13.03it/s]

[1/10] [900/1847] 2.031015


 54%|█████▍    | 1004/1847 [01:21<01:06, 12.67it/s]

[1/10] [1000/1847] 2.957128


 60%|█████▉    | 1102/1847 [01:30<01:12, 10.32it/s]

[1/10] [1100/1847] 2.324261


 65%|██████▌   | 1202/1847 [01:38<00:52, 12.22it/s]

[1/10] [1200/1847] 2.474546


 70%|███████   | 1302/1847 [01:46<00:45, 12.08it/s]

[1/10] [1300/1847] 1.785127


 76%|███████▌  | 1404/1847 [01:54<00:31, 14.07it/s]

[1/10] [1400/1847] 2.714947


 81%|████████▏ | 1502/1847 [02:02<00:27, 12.65it/s]

[1/10] [1500/1847] 2.519907


 87%|████████▋ | 1602/1847 [02:10<00:19, 12.46it/s]

[1/10] [1600/1847] 2.657184


 92%|█████████▏| 1702/1847 [02:18<00:11, 12.80it/s]

[1/10] [1700/1847] 2.265561


 98%|█████████▊| 1802/1847 [02:26<00:03, 12.97it/s]

[1/10] [1800/1847] 2.479088


100%|██████████| 1847/1847 [02:30<00:00, 12.27it/s]
  0%|          | 2/1847 [00:00<02:17, 13.38it/s]

[2/10] [0/1847] 1.355536


  6%|▌         | 102/1847 [00:08<02:18, 12.58it/s]

[2/10] [100/1847] 1.829472


 11%|█         | 204/1847 [00:16<02:06, 12.97it/s]

[2/10] [200/1847] 3.072183


 16%|█▋        | 302/1847 [00:24<02:17, 11.20it/s]

[2/10] [300/1847] 2.776514


 22%|██▏       | 402/1847 [00:32<02:04, 11.58it/s]

[2/10] [400/1847] 2.509109


 27%|██▋       | 502/1847 [00:40<01:52, 12.00it/s]

[2/10] [500/1847] 2.494088


 33%|███▎      | 602/1847 [00:48<01:39, 12.58it/s]

[2/10] [600/1847] 2.614070


 38%|███▊      | 702/1847 [00:56<01:30, 12.59it/s]

[2/10] [700/1847] 2.153256


 43%|████▎     | 802/1847 [01:05<01:34, 11.04it/s]

[2/10] [800/1847] 2.136165


 49%|████▉     | 902/1847 [01:13<01:12, 13.04it/s]

[2/10] [900/1847] 1.631295


 54%|█████▍    | 1002/1847 [01:20<01:09, 12.16it/s]

[2/10] [1000/1847] 2.528520


 60%|█████▉    | 1102/1847 [01:29<01:05, 11.33it/s]

[2/10] [1100/1847] 2.466581


 65%|██████▌   | 1202/1847 [01:37<00:52, 12.25it/s]

[2/10] [1200/1847] 2.268657


 70%|███████   | 1302/1847 [01:45<00:45, 12.11it/s]

[2/10] [1300/1847] 1.461890


 76%|███████▌  | 1404/1847 [01:53<00:31, 14.12it/s]

[2/10] [1400/1847] 2.195203


 81%|████████▏ | 1502/1847 [02:01<00:27, 12.71it/s]

[2/10] [1500/1847] 1.952316


 87%|████████▋ | 1604/1847 [02:09<00:18, 13.31it/s]

[2/10] [1600/1847] 2.421629


 92%|█████████▏| 1702/1847 [02:17<00:11, 13.00it/s]

[2/10] [1700/1847] 2.228364


 98%|█████████▊| 1802/1847 [02:25<00:03, 13.03it/s]

[2/10] [1800/1847] 2.585068


100%|██████████| 1847/1847 [02:29<00:00, 12.36it/s]
  0%|          | 2/1847 [00:00<02:16, 13.48it/s]

[3/10] [0/1847] 1.229987


  6%|▌         | 102/1847 [00:08<02:18, 12.62it/s]

[3/10] [100/1847] 1.617241


 11%|█         | 204/1847 [00:16<02:07, 12.93it/s]

[3/10] [200/1847] 2.109239


 16%|█▋        | 302/1847 [00:24<02:17, 11.25it/s]

[3/10] [300/1847] 2.617620


 22%|██▏       | 402/1847 [00:32<02:05, 11.49it/s]

[3/10] [400/1847] 2.407933


 27%|██▋       | 502/1847 [00:40<01:52, 11.97it/s]

[3/10] [500/1847] 1.823072


 33%|███▎      | 602/1847 [00:48<01:38, 12.58it/s]

[3/10] [600/1847] 2.416474


 38%|███▊      | 702/1847 [00:56<01:30, 12.61it/s]

[3/10] [700/1847] 1.905880


 43%|████▎     | 802/1847 [01:05<01:34, 11.08it/s]

[3/10] [800/1847] 1.844514


 49%|████▉     | 902/1847 [01:12<01:12, 13.00it/s]

[3/10] [900/1847] 1.378018


 54%|█████▍    | 1002/1847 [01:20<01:09, 12.14it/s]

[3/10] [1000/1847] 2.200979


 60%|█████▉    | 1102/1847 [01:29<01:04, 11.55it/s]

[3/10] [1100/1847] 1.725828


 65%|██████▌   | 1202/1847 [01:37<00:52, 12.19it/s]

[3/10] [1200/1847] 1.494859


 70%|███████   | 1302/1847 [01:45<00:45, 12.05it/s]

[3/10] [1300/1847] 1.143719


 76%|███████▌  | 1404/1847 [01:53<00:31, 14.11it/s]

[3/10] [1400/1847] 1.735090


 81%|████████▏ | 1502/1847 [02:01<00:27, 12.71it/s]

[3/10] [1500/1847] 1.646793


 87%|████████▋ | 1604/1847 [02:09<00:18, 13.29it/s]

[3/10] [1600/1847] 1.577593


 92%|█████████▏| 1702/1847 [02:17<00:11, 13.04it/s]

[3/10] [1700/1847] 1.652744


 98%|█████████▊| 1802/1847 [02:25<00:03, 13.01it/s]

[3/10] [1800/1847] 1.953696


100%|██████████| 1847/1847 [02:29<00:00, 12.36it/s]
  0%|          | 2/1847 [00:00<02:15, 13.64it/s]

[4/10] [0/1847] 0.749555


  6%|▌         | 102/1847 [00:08<02:18, 12.60it/s]

[4/10] [100/1847] 1.342145


 11%|█         | 204/1847 [00:16<02:05, 13.09it/s]

[4/10] [200/1847] 1.947136


 16%|█▋        | 302/1847 [00:24<02:17, 11.27it/s]

[4/10] [300/1847] 1.758126


 22%|██▏       | 402/1847 [00:32<02:04, 11.63it/s]

[4/10] [400/1847] 1.888599


 27%|██▋       | 502/1847 [00:40<01:51, 12.03it/s]

[4/10] [500/1847] 1.453396


 33%|███▎      | 602/1847 [00:48<01:38, 12.58it/s]

[4/10] [600/1847] 1.770952


 38%|███▊      | 702/1847 [00:56<01:31, 12.58it/s]

[4/10] [700/1847] 1.651220


 43%|████▎     | 802/1847 [01:05<01:33, 11.13it/s]

[4/10] [800/1847] 1.634023


 49%|████▉     | 902/1847 [01:12<01:12, 13.02it/s]

[4/10] [900/1847] 1.487540


 54%|█████▍    | 1002/1847 [01:20<01:09, 12.14it/s]

[4/10] [1000/1847] 1.730968


 60%|█████▉    | 1102/1847 [01:28<01:04, 11.57it/s]

[4/10] [1100/1847] 1.625255


 65%|██████▌   | 1202/1847 [01:37<00:52, 12.23it/s]

[4/10] [1200/1847] 1.150649


 70%|███████   | 1302/1847 [01:45<00:45, 12.05it/s]

[4/10] [1300/1847] 0.925997


 76%|███████▌  | 1404/1847 [01:53<00:31, 14.17it/s]

[4/10] [1400/1847] 1.388767


 81%|████████▏ | 1502/1847 [02:01<00:26, 12.95it/s]

[4/10] [1500/1847] 1.447459


 87%|████████▋ | 1602/1847 [02:09<00:19, 12.40it/s]

[4/10] [1600/1847] 1.389138


 92%|█████████▏| 1702/1847 [02:17<00:11, 13.07it/s]

[4/10] [1700/1847] 1.678451


 98%|█████████▊| 1802/1847 [02:25<00:03, 13.03it/s]

[4/10] [1800/1847] 1.781510


100%|██████████| 1847/1847 [02:29<00:00, 12.38it/s]
  0%|          | 2/1847 [00:00<02:15, 13.62it/s]

[5/10] [0/1847] 1.074857


  6%|▌         | 102/1847 [00:08<02:18, 12.61it/s]

[5/10] [100/1847] 0.999336


 11%|█         | 204/1847 [00:16<02:06, 13.04it/s]

[5/10] [200/1847] 1.949517


 16%|█▋        | 302/1847 [00:24<02:17, 11.23it/s]

[5/10] [300/1847] 1.735719


 22%|██▏       | 402/1847 [00:32<02:03, 11.67it/s]

[5/10] [400/1847] 1.537054


 27%|██▋       | 502/1847 [00:40<01:51, 12.06it/s]

[5/10] [500/1847] 1.257279


 33%|███▎      | 602/1847 [00:48<01:38, 12.63it/s]

[5/10] [600/1847] 1.858058


 38%|███▊      | 702/1847 [00:56<01:31, 12.49it/s]

[5/10] [700/1847] 1.386567


 43%|████▎     | 802/1847 [01:04<01:33, 11.12it/s]

[5/10] [800/1847] 1.534361


 49%|████▉     | 902/1847 [01:12<01:12, 13.03it/s]

[5/10] [900/1847] 1.056734


 54%|█████▍    | 1004/1847 [01:20<01:06, 12.67it/s]

[5/10] [1000/1847] 1.532127


 60%|█████▉    | 1102/1847 [01:28<01:04, 11.48it/s]

[5/10] [1100/1847] 1.204298


 65%|██████▌   | 1202/1847 [01:36<00:52, 12.25it/s]

[5/10] [1200/1847] 1.055848


 70%|███████   | 1302/1847 [01:45<00:45, 12.08it/s]

[5/10] [1300/1847] 0.812912


 76%|███████▌  | 1404/1847 [01:53<00:31, 14.00it/s]

[5/10] [1400/1847] 1.283862


 81%|████████▏ | 1502/1847 [02:01<00:27, 12.61it/s]

[5/10] [1500/1847] 0.910932


 87%|████████▋ | 1602/1847 [02:09<00:19, 12.43it/s]

[5/10] [1600/1847] 1.296483


 92%|█████████▏| 1702/1847 [02:17<00:11, 13.01it/s]

[5/10] [1700/1847] 1.444272


 98%|█████████▊| 1802/1847 [02:25<00:03, 13.06it/s]

[5/10] [1800/1847] 1.708743


100%|██████████| 1847/1847 [02:29<00:00, 12.38it/s]
  0%|          | 2/1847 [00:00<02:18, 13.30it/s]

[6/10] [0/1847] 0.853365


  6%|▌         | 102/1847 [00:08<02:20, 12.43it/s]

[6/10] [100/1847] 1.096057


 11%|█         | 202/1847 [00:16<02:15, 12.16it/s]

[6/10] [200/1847] 1.667867


 16%|█▋        | 302/1847 [00:24<02:18, 11.18it/s]

[6/10] [300/1847] 1.372652


 22%|██▏       | 402/1847 [00:32<02:08, 11.28it/s]

[6/10] [400/1847] 1.459770


 27%|██▋       | 502/1847 [00:41<01:53, 11.83it/s]

[6/10] [500/1847] 0.669698


 33%|███▎      | 602/1847 [00:49<01:39, 12.48it/s]

[6/10] [600/1847] 1.673789


 38%|███▊      | 702/1847 [00:57<01:31, 12.47it/s]

[6/10] [700/1847] 1.068219


 43%|████▎     | 802/1847 [01:05<01:35, 10.91it/s]

[6/10] [800/1847] 1.259984


 49%|████▉     | 902/1847 [01:13<01:13, 12.82it/s]

[6/10] [900/1847] 0.763822


 54%|█████▍    | 1002/1847 [01:21<01:10, 11.97it/s]

[6/10] [1000/1847] 1.203903


 60%|█████▉    | 1102/1847 [01:30<01:05, 11.42it/s]

[6/10] [1100/1847] 1.279786


 65%|██████▌   | 1202/1847 [01:38<00:53, 12.14it/s]

[6/10] [1200/1847] 0.747877


 70%|███████   | 1302/1847 [01:46<00:45, 11.93it/s]

[6/10] [1300/1847] 0.628024


 76%|███████▌  | 1404/1847 [01:54<00:31, 14.05it/s]

[6/10] [1400/1847] 0.961440


 81%|████████▏ | 1502/1847 [02:02<00:27, 12.67it/s]

[6/10] [1500/1847] 0.892373


 87%|████████▋ | 1602/1847 [02:10<00:19, 12.46it/s]

[6/10] [1600/1847] 1.175471


 92%|█████████▏| 1702/1847 [02:19<00:11, 12.89it/s]

[6/10] [1700/1847] 0.922353


 98%|█████████▊| 1802/1847 [02:27<00:03, 12.84it/s]

[6/10] [1800/1847] 1.079594


100%|██████████| 1847/1847 [02:30<00:00, 12.24it/s]
  0%|          | 2/1847 [00:00<02:19, 13.22it/s]

[7/10] [0/1847] 0.665298


  6%|▌         | 102/1847 [00:08<02:19, 12.48it/s]

[7/10] [100/1847] 0.780562


 11%|█         | 204/1847 [00:16<02:07, 12.85it/s]

[7/10] [200/1847] 1.690113


 16%|█▋        | 302/1847 [00:24<02:18, 11.13it/s]

[7/10] [300/1847] 0.937351


 22%|██▏       | 402/1847 [00:32<02:05, 11.52it/s]

[7/10] [400/1847] 1.477572


 27%|██▋       | 502/1847 [00:41<01:53, 11.86it/s]

[7/10] [500/1847] 1.191249


 33%|███▎      | 602/1847 [00:49<01:40, 12.42it/s]

[7/10] [600/1847] 1.432845


 38%|███▊      | 702/1847 [00:57<01:31, 12.47it/s]

[7/10] [700/1847] 1.101409


 43%|████▎     | 802/1847 [01:05<01:35, 10.97it/s]

[7/10] [800/1847] 1.167340


 49%|████▉     | 902/1847 [01:13<01:13, 12.81it/s]

[7/10] [900/1847] 1.023558


 54%|█████▍    | 1002/1847 [01:21<01:10, 12.03it/s]

[7/10] [1000/1847] 1.202716


 60%|█████▉    | 1102/1847 [01:30<01:05, 11.40it/s]

[7/10] [1100/1847] 1.064623


 65%|██████▌   | 1202/1847 [01:38<00:53, 12.11it/s]

[7/10] [1200/1847] 0.643183


 70%|███████   | 1302/1847 [01:46<00:45, 12.02it/s]

[7/10] [1300/1847] 0.653715


 76%|███████▌  | 1404/1847 [01:54<00:31, 13.88it/s]

[7/10] [1400/1847] 1.171494


 81%|████████▏ | 1502/1847 [02:03<00:27, 12.58it/s]

[7/10] [1500/1847] 0.792394


 87%|████████▋ | 1604/1847 [02:11<00:18, 13.22it/s]

[7/10] [1600/1847] 1.078761


 92%|█████████▏| 1702/1847 [02:19<00:11, 12.30it/s]

[7/10] [1700/1847] 1.202161


 98%|█████████▊| 1802/1847 [02:27<00:03, 12.90it/s]

[7/10] [1800/1847] 1.181513


100%|██████████| 1847/1847 [02:31<00:00, 12.20it/s]
  0%|          | 2/1847 [00:00<02:16, 13.50it/s]

[8/10] [0/1847] 0.864137


  6%|▌         | 102/1847 [00:08<02:19, 12.51it/s]

[8/10] [100/1847] 0.625619


 11%|█         | 204/1847 [00:16<02:08, 12.83it/s]

[8/10] [200/1847] 1.453105


 16%|█▋        | 302/1847 [00:24<02:19, 11.11it/s]

[8/10] [300/1847] 1.169810


 22%|██▏       | 402/1847 [00:32<02:05, 11.49it/s]

[8/10] [400/1847] 0.953350


 27%|██▋       | 502/1847 [00:41<01:53, 11.86it/s]

[8/10] [500/1847] 1.155094


 33%|███▎      | 602/1847 [00:49<01:40, 12.42it/s]

[8/10] [600/1847] 1.418007


 38%|███▊      | 702/1847 [00:57<01:32, 12.34it/s]

[8/10] [700/1847] 1.042644


 43%|████▎     | 802/1847 [01:06<01:34, 11.00it/s]

[8/10] [800/1847] 1.212447


 49%|████▉     | 902/1847 [01:14<01:13, 12.88it/s]

[8/10] [900/1847] 1.010972


 54%|█████▍    | 1002/1847 [01:22<01:10, 11.96it/s]

[8/10] [1000/1847] 0.947098


 60%|█████▉    | 1102/1847 [01:30<01:05, 11.35it/s]

[8/10] [1100/1847] 0.985352


 65%|██████▌   | 1202/1847 [01:38<00:53, 12.00it/s]

[8/10] [1200/1847] 0.572532


 70%|███████   | 1302/1847 [01:46<00:45, 11.95it/s]

[8/10] [1300/1847] 0.737755


 76%|███████▌  | 1404/1847 [01:55<00:31, 13.98it/s]

[8/10] [1400/1847] 0.878077


 81%|████████▏ | 1502/1847 [02:03<00:27, 12.61it/s]

[8/10] [1500/1847] 0.618264


 87%|████████▋ | 1602/1847 [02:11<00:19, 12.38it/s]

[8/10] [1600/1847] 1.442960


 92%|█████████▏| 1702/1847 [02:19<00:11, 12.91it/s]

[8/10] [1700/1847] 1.156918


 98%|█████████▊| 1802/1847 [02:27<00:03, 12.89it/s]

[8/10] [1800/1847] 1.475024


100%|██████████| 1847/1847 [02:31<00:00, 12.19it/s]
  0%|          | 2/1847 [00:00<02:19, 13.19it/s]

[9/10] [0/1847] 0.731225


  6%|▌         | 102/1847 [00:08<02:18, 12.56it/s]

[9/10] [100/1847] 0.597182


 11%|█         | 204/1847 [00:16<02:06, 12.98it/s]

[9/10] [200/1847] 1.378540


 16%|█▋        | 302/1847 [00:24<02:17, 11.23it/s]

[9/10] [300/1847] 0.948219


 22%|██▏       | 402/1847 [00:32<02:05, 11.52it/s]

[9/10] [400/1847] 1.230427


 27%|██▋       | 502/1847 [00:40<01:52, 11.95it/s]

[9/10] [500/1847] 0.830818


 33%|███▎      | 602/1847 [00:49<01:39, 12.51it/s]

[9/10] [600/1847] 1.252064


 38%|███▊      | 702/1847 [00:57<01:32, 12.39it/s]

[9/10] [700/1847] 1.209735


 43%|████▎     | 802/1847 [01:05<01:35, 11.00it/s]

[9/10] [800/1847] 1.149152


 49%|████▉     | 902/1847 [01:13<01:13, 12.88it/s]

[9/10] [900/1847] 0.662149


 54%|█████▍    | 1004/1847 [01:21<01:06, 12.62it/s]

[9/10] [1000/1847] 0.806920


 60%|█████▉    | 1102/1847 [01:29<01:04, 11.51it/s]

[9/10] [1100/1847] 1.013276


 65%|██████▌   | 1202/1847 [01:37<00:53, 12.15it/s]

[9/10] [1200/1847] 0.583889


 70%|███████   | 1302/1847 [01:45<00:45, 11.96it/s]

[9/10] [1300/1847] 0.620404


 76%|███████▌  | 1402/1847 [01:53<00:32, 13.60it/s]

[9/10] [1400/1847] 0.735287


 81%|████████▏ | 1502/1847 [02:02<00:27, 12.56it/s]

[9/10] [1500/1847] 0.667326


 87%|████████▋ | 1604/1847 [02:10<00:18, 13.22it/s]

[9/10] [1600/1847] 1.082116


 92%|█████████▏| 1702/1847 [02:18<00:11, 12.83it/s]

[9/10] [1700/1847] 0.832226


 98%|█████████▊| 1802/1847 [02:26<00:03, 12.96it/s]

[9/10] [1800/1847] 1.297751


100%|██████████| 1847/1847 [02:30<00:00, 12.27it/s]


### Training

In [20]:
# backward pass시 gradient 정보가 손실되지 않게 하려고 사용(loss에 scale factor를 곱해서 gradient 값이 너무 작아지는 것을 방지)
scaler = GradScaler()
best_score, best_checkpoint = 0, 0

for epoch in range(n_epochs):
    start_time = time.time()
    batch_loss = []
    model.train()
    for step, batch in enumerate(train_loader):
        input_ids, segment_ids, input_masks, gating_ids, target_ids, guids = [
            b.to(device) if not isinstance(b, list) else b for b in batch
        ]
        # mask
        if cfg.mask:
            change_mask_prop = 0.8
            mask_p = random.random()
            if cfg.mask and mask_p < change_mask_prop:
                input_ids = custom_to_mask(input_ids)
        # teacher forcing
        if (
            teacher_forcing > 0.0
            and random.random() < teacher_forcing
        ):
            tf = target_ids
        else:
            tf = None

        optimizer.zero_grad()  # optimizer는 input으로 model parameter를 가진다 -> zero_grad()로 파라미터 컨드롤 가능

        with autocast():  # 밑에 해당하는 코드를 자동으로 mixed precision으로 변환시켜서 실행
            all_point_outputs, all_gate_outputs = model(
                input_ids, segment_ids, input_masks, target_ids.size(-1), tf
            )
            # generation loss
            loss_1 = loss_fnc_1(
                all_point_outputs.contiguous(),
                target_ids.contiguous().view(-1),
                tokenizer.pad_token_id,
            )
            # gating loss
            loss_2 = loss_fnc_2(
                all_gate_outputs.contiguous().view(-1, cfg.n_gate),
                gating_ids.contiguous().view(-1),
            )
            loss = loss_1 + loss_2
        batch_loss.append(loss.item())

        scaler.scale(loss).backward()
        nn.utils.clip_grad_norm_(model.parameters(), cfg.max_grad_norm)
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()

        # global_step 추가 부분
        wandb.log({"train/learning_rate": get_lr(scheduler),
                   "train/epoch": epoch
                   })
        if step % 100 == 0:
            print(
                f"[{epoch}/{n_epochs}] [{step}/{len(train_loader)}] loss: {loss.item()} gen: {loss_1.item()} gate: {loss_2.item()}"
            )

            # -- train 단계에서 Loss, Accuracy 로그 저장
            wandb.log({
                "train/loss": loss.item(),
                "train/gen_loss": loss_1.item(),
                "train/gate_loss": loss_2.item(),
            })

    predictions, p_logits, p_idx, g_logits = inference(model, dev_loader, processor, device, cfg.n_gate)
    eval_result = _evaluation(predictions, dev_labels, slot_meta)

    # -- eval 단계에서 Loss, Accuracy 로그 저장
    wandb.log({
        "eval/join_goal_acc": eval_result["joint_goal_accuracy"],
        "eval/turn_slot_f1": eval_result["turn_slot_f1"],
        "eval/turn_slot_acc": eval_result["turn_slot_accuracy"],
    })

    for k, v in eval_result.items():
        print(f"{k}: {v}")

    if best_score < eval_result['joint_goal_accuracy']:
        cpprint(f"--Update Best checkpoint!, epoch: {epoch+1}")
        best_score = eval_result['joint_goal_accuracy']
        best_checkpoint = epoch
        if not os.path.isdir(cfg.model_dir):
            os.makedirs(cfg.model_dir)
        print("--Saving best model checkpoint")
        torch.save(model.state_dict(), f"{cfg.model_dir}/{wandb.run.name}/best.pth")
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler.state_dict': scheduler.state_dict(),
            'loss': loss.item(),
            'gen_loss': loss_1.item(),
            'gate_loss': loss_2.item(),
        }, os.path.join(f"{cfg.model_dir}/{wandb.run.name}", "training_best_checkpoint.bin"))
        
        # logit 저장
        np.save(os.path.join(f"{cfg.model_dir}/{wandb.run.name}", r'p_logits.npy'), p_logits)
        np.save(os.path.join(f"{cfg.model_dir}/{wandb.run.name}", r'p_idx.npy'), p_idx)
        np.save(os.path.join(f"{cfg.model_dir}/{wandb.run.name}", r'g_logits.npy'), g_logits)

        
    torch.save(model.state_dict(), f"{cfg.model_dir}/{wandb.run.name}/last.pth")
    print(f"time for 1 epoch: {time.time() - start_time}")

[0/30] [0/11532] loss: 11.108061790466309 gen: 9.425432205200195 gate: 1.6826297044754028
[0/30] [100/11532] loss: 2.7578675746917725 gen: 2.32616925239563 gate: 0.4316982626914978
[0/30] [200/11532] loss: 2.347721815109253 gen: 1.716150164604187 gate: 0.6315717101097107
[0/30] [300/11532] loss: 2.3156487941741943 gen: 1.7092682123184204 gate: 0.6063806414604187
[0/30] [400/11532] loss: 2.379779577255249 gen: 1.5815484523773193 gate: 0.7982310652732849
[0/30] [500/11532] loss: 1.5391714572906494 gen: 1.0021620988845825 gate: 0.5370092988014221
[0/30] [600/11532] loss: 1.3581907749176025 gen: 0.846183717250824 gate: 0.5120069980621338
[0/30] [700/11532] loss: 1.8442423343658447 gen: 1.2138514518737793 gate: 0.6303908228874207
[0/30] [800/11532] loss: 1.0953497886657715 gen: 0.7124114036560059 gate: 0.382938414812088
[0/30] [900/11532] loss: 1.4426168203353882 gen: 0.9204893708229065 gate: 0.5221274495124817
[0/30] [1000/11532] loss: 1.4483888149261475 gen: 0.9628758430480957 gate: 0.485

100%|██████████| 640/640 [02:01<00:00,  5.28it/s]


{'joint_goal_accuracy': 0.30656506447831183, 'turn_slot_accuracy': 0.9659024792670887, 'turn_slot_f1': 0.8526494571044264}
joint_goal_accuracy: 0.30656506447831183
turn_slot_accuracy: 0.9659024792670887
turn_slot_f1: 0.8526494571044264
'--Update Best checkpoint!, epoch: 1'
--Saving best model checkpoint




time for 1 epoch: 2073.628895998001
[1/30] [0/11532] loss: 0.09361106157302856 gen: 0.06249314546585083 gate: 0.031117917969822884
[1/30] [100/11532] loss: 0.20754575729370117 gen: 0.15954691171646118 gate: 0.04799885302782059
[1/30] [200/11532] loss: 0.22125911712646484 gen: 0.15587329864501953 gate: 0.06538582593202591
[1/30] [300/11532] loss: 0.20811662077903748 gen: 0.18028445541858673 gate: 0.027832161635160446
[1/30] [400/11532] loss: 0.4564060568809509 gen: 0.2856338322162628 gate: 0.1707722246646881
[1/30] [500/11532] loss: 0.1410551518201828 gen: 0.12343443930149078 gate: 0.017620714381337166
[1/30] [600/11532] loss: 0.16064515709877014 gen: 0.1015995666384697 gate: 0.05904558673501015
[1/30] [700/11532] loss: 0.2400299608707428 gen: 0.22909818589687347 gate: 0.010931774973869324
[1/30] [800/11532] loss: 0.13029824197292328 gen: 0.10959453135728836 gate: 0.020703714340925217
[1/30] [900/11532] loss: 0.20407912135124207 gen: 0.1551753580570221 gate: 0.04890377074480057
[1/30] [

100%|██████████| 640/640 [01:59<00:00,  5.34it/s]


{'joint_goal_accuracy': 0.5341930441578742, 'turn_slot_accuracy': 0.9830749858885974, 'turn_slot_f1': 0.9208502338396105}
joint_goal_accuracy: 0.5341930441578742
turn_slot_accuracy: 0.9830749858885974
turn_slot_f1: 0.9208502338396105
'--Update Best checkpoint!, epoch: 2'
--Saving best model checkpoint
time for 1 epoch: 2014.5532901287079
[2/30] [0/11532] loss: 0.06734965741634369 gen: 0.033441610634326935 gate: 0.033908046782016754
[2/30] [100/11532] loss: 0.1004311665892601 gen: 0.09774937480688095 gate: 0.00268178922124207
[2/30] [200/11532] loss: 0.07389026880264282 gen: 0.07244566082954407 gate: 0.0014446059940382838
[2/30] [300/11532] loss: 0.15821059048175812 gen: 0.11706220358610153 gate: 0.041148390620946884
[2/30] [400/11532] loss: 0.07149287313222885 gen: 0.07047785818576813 gate: 0.0010150144807994366
[2/30] [500/11532] loss: 0.01769321970641613 gen: 0.017462829127907753 gate: 0.00023038990912027657
[2/30] [600/11532] loss: 0.06118971109390259 gen: 0.06071899086236954 gate: 

100%|██████████| 640/640 [02:08<00:00,  4.97it/s]


{'joint_goal_accuracy': 0.5726846424384525, 'turn_slot_accuracy': 0.9852372888715322, 'turn_slot_f1': 0.9324274957224725}
joint_goal_accuracy: 0.5726846424384525
turn_slot_accuracy: 0.9852372888715322
turn_slot_f1: 0.9324274957224725
'--Update Best checkpoint!, epoch: 3'
--Saving best model checkpoint
time for 1 epoch: 2021.0467655658722
[3/30] [0/11532] loss: 0.029385985806584358 gen: 0.028227413073182106 gate: 0.0011585726169869304
[3/30] [100/11532] loss: 0.084633007645607 gen: 0.05609429255127907 gate: 0.028538716956973076
[3/30] [200/11532] loss: 0.05198601260781288 gen: 0.03538653999567032 gate: 0.016599472612142563
[3/30] [300/11532] loss: 0.020681306719779968 gen: 0.020358925685286522 gate: 0.00032238088897429407
[3/30] [400/11532] loss: 0.049501948058605194 gen: 0.04915971681475639 gate: 0.0003422326117288321
[3/30] [500/11532] loss: 0.026569433510303497 gen: 0.02637084759771824 gate: 0.0001985859271371737
[3/30] [600/11532] loss: 0.014844677411019802 gen: 0.013720237649977207

100%|██████████| 640/640 [02:03<00:00,  5.18it/s]


{'joint_goal_accuracy': 0.6016021883548262, 'turn_slot_accuracy': 0.9865138291867598, 'turn_slot_f1': 0.9384114234690889}
joint_goal_accuracy: 0.6016021883548262
turn_slot_accuracy: 0.9865138291867598
turn_slot_f1: 0.9384114234690889
'--Update Best checkpoint!, epoch: 4'
--Saving best model checkpoint
time for 1 epoch: 2008.4485597610474
[4/30] [0/11532] loss: 0.023678027093410492 gen: 0.02345253899693489 gate: 0.00022548857668880373
[4/30] [100/11532] loss: 0.009143614210188389 gen: 0.008843515999615192 gate: 0.0003000978904310614
[4/30] [200/11532] loss: 0.017038311809301376 gen: 0.016969460994005203 gate: 6.885123730171472e-05
[4/30] [300/11532] loss: 0.059454675763845444 gen: 0.04054412245750427 gate: 0.01891055330634117
[4/30] [400/11532] loss: 0.017412766814231873 gen: 0.017315365374088287 gate: 9.740069799590856e-05
[4/30] [500/11532] loss: 0.002851261757314205 gen: 0.0021231130231171846 gate: 0.0007281487341970205
[4/30] [600/11532] loss: 0.08145368099212646 gen: 0.075400255620

100%|██████████| 640/640 [02:08<00:00,  4.98it/s]


{'joint_goal_accuracy': 0.6324736225087925, 'turn_slot_accuracy': 0.9875602448873364, 'turn_slot_f1': 0.943029939673893}
joint_goal_accuracy: 0.6324736225087925
turn_slot_accuracy: 0.9875602448873364
turn_slot_f1: 0.943029939673893
'--Update Best checkpoint!, epoch: 5'
--Saving best model checkpoint
time for 1 epoch: 2008.9846594333649
[5/30] [0/11532] loss: 0.03273043781518936 gen: 0.031737715005874634 gate: 0.0009927229257300496
[5/30] [100/11532] loss: 0.038450051099061966 gen: 0.021898755803704262 gate: 0.016551295295357704
[5/30] [200/11532] loss: 0.06076347827911377 gen: 0.05023500323295593 gate: 0.010528476908802986
[5/30] [300/11532] loss: 0.0210843738168478 gen: 0.020620200783014297 gate: 0.00046417213161475956
[5/30] [400/11532] loss: 0.009211939759552479 gen: 0.008962114341557026 gate: 0.0002498255344107747
[5/30] [500/11532] loss: 0.07514254748821259 gen: 0.06244382634758949 gate: 0.012698722071945667
[5/30] [600/11532] loss: 0.014777335338294506 gen: 0.014735832810401917 g

100%|██████████| 640/640 [01:57<00:00,  5.47it/s]


{'joint_goal_accuracy': 0.6483001172332943, 'turn_slot_accuracy': 0.9883808779471245, 'turn_slot_f1': 0.9470401107793761}
joint_goal_accuracy: 0.6483001172332943
turn_slot_accuracy: 0.9883808779471245
turn_slot_f1: 0.9470401107793761
'--Update Best checkpoint!, epoch: 6'
--Saving best model checkpoint
time for 1 epoch: 1977.2742235660553
[6/30] [0/11532] loss: 0.02199436165392399 gen: 0.021661851555109024 gate: 0.0003325092839077115
[6/30] [100/11532] loss: 0.005162607878446579 gen: 0.00494813546538353 gate: 0.0002144725585822016
[6/30] [200/11532] loss: 0.005406362470239401 gen: 0.0051657031290233135 gate: 0.00024065923935268074
[6/30] [300/11532] loss: 0.0504966638982296 gen: 0.018211327493190765 gate: 0.032285336405038834
[6/30] [400/11532] loss: 0.02641472965478897 gen: 0.021930547431111336 gate: 0.004484181758016348
[6/30] [500/11532] loss: 0.047193560749292374 gen: 0.037473272532224655 gate: 0.009720289148390293
[6/30] [600/11532] loss: 0.01963650994002819 gen: 0.0193142723292112

100%|██████████| 640/640 [02:03<00:00,  5.19it/s]


{'joint_goal_accuracy': 0.6604142243063696, 'turn_slot_accuracy': 0.9889800703399866, 'turn_slot_f1': 0.9497509843280312}
joint_goal_accuracy: 0.6604142243063696
turn_slot_accuracy: 0.9889800703399866
turn_slot_f1: 0.9497509843280312
'--Update Best checkpoint!, epoch: 7'
--Saving best model checkpoint
time for 1 epoch: 2016.9318068027496
[7/30] [0/11532] loss: 0.01290210336446762 gen: 0.012819147668778896 gate: 8.295550651382655e-05
[7/30] [100/11532] loss: 0.029746131971478462 gen: 0.028667930513620377 gate: 0.0010782015742734075
[7/30] [200/11532] loss: 0.001558623742312193 gen: 0.0015254704048857093 gate: 3.315329013275914e-05
[7/30] [300/11532] loss: 0.017276136204600334 gen: 0.017193207517266273 gate: 8.292892744066194e-05
[7/30] [400/11532] loss: 0.019255371764302254 gen: 0.01919250749051571 gate: 6.286500138230622e-05
[7/30] [500/11532] loss: 0.013023057952523232 gen: 0.013000217266380787 gate: 2.284087531734258e-05
[7/30] [600/11532] loss: 0.007248000241816044 gen: 0.0071524917

100%|██████████| 640/640 [02:09<00:00,  4.95it/s]


{'joint_goal_accuracy': 0.7172723720203205, 'turn_slot_accuracy': 0.990647388302732, 'turn_slot_f1': 0.9583635043139233}
joint_goal_accuracy: 0.7172723720203205
turn_slot_accuracy: 0.990647388302732
turn_slot_f1: 0.9583635043139233
'--Update Best checkpoint!, epoch: 8'
--Saving best model checkpoint
time for 1 epoch: 1997.1662514209747
[8/30] [0/11532] loss: 0.02223101258277893 gen: 0.01801660656929016 gate: 0.004214406944811344
[8/30] [100/11532] loss: 0.000232583362958394 gen: 0.00018156530859414488 gate: 5.101805800222792e-05
[8/30] [200/11532] loss: 0.01742294616997242 gen: 0.017396945506334305 gate: 2.600128391350154e-05
[8/30] [300/11532] loss: 0.004052742384374142 gen: 0.004015588201582432 gate: 3.715432103490457e-05
[8/30] [400/11532] loss: 0.014855742454528809 gen: 0.012403873726725578 gate: 0.002451868262141943
[8/30] [500/11532] loss: 0.012446306645870209 gen: 0.012392020784318447 gate: 5.4286127124214545e-05
[8/30] [600/11532] loss: 0.009309790097177029 gen: 0.0092166839167

100%|██████████| 640/640 [02:11<00:00,  4.85it/s]


{'joint_goal_accuracy': 0.738569753810082, 'turn_slot_accuracy': 0.9913724979375682, 'turn_slot_f1': 0.9620611235362687}
joint_goal_accuracy: 0.738569753810082
turn_slot_accuracy: 0.9913724979375682
turn_slot_f1: 0.9620611235362687
'--Update Best checkpoint!, epoch: 9'
--Saving best model checkpoint
time for 1 epoch: 1996.9108972549438
[9/30] [0/11532] loss: 0.02936474233865738 gen: 0.029148560017347336 gate: 0.00021618213213514537
[9/30] [100/11532] loss: 0.028821256011724472 gen: 0.02226240746676922 gate: 0.006558848079293966
[9/30] [200/11532] loss: 0.0015236546751111746 gen: 0.0014091130578890443 gate: 0.00011454167542979121
[9/30] [300/11532] loss: 0.01535082794725895 gen: 0.015076580457389355 gate: 0.0002742476062849164
[9/30] [400/11532] loss: 0.010925752110779285 gen: 0.010886790230870247 gate: 3.8961879909038544e-05
[9/30] [500/11532] loss: 0.03836457431316376 gen: 0.03820694983005524 gate: 0.00015762625844217837
[9/30] [600/11532] loss: 0.0028338374104350805 gen: 0.0027240323

100%|██████████| 640/640 [01:56<00:00,  5.51it/s]


{'joint_goal_accuracy': 0.7624071903087143, 'turn_slot_accuracy': 0.9922669445529981, 'turn_slot_f1': 0.9654928710598026}
joint_goal_accuracy: 0.7624071903087143
turn_slot_accuracy: 0.9922669445529981
turn_slot_f1: 0.9654928710598026
'--Update Best checkpoint!, epoch: 10'
--Saving best model checkpoint
time for 1 epoch: 1983.5019569396973
[10/30] [0/11532] loss: 4.777665890287608e-05 gen: 3.16936093440745e-05 gate: 1.6083047739812173e-05
[10/30] [100/11532] loss: 0.019606903195381165 gen: 0.00464093592017889 gate: 0.014965968206524849
[10/30] [200/11532] loss: 0.010011544451117516 gen: 0.010003702715039253 gate: 7.841753358661663e-06
[10/30] [300/11532] loss: 0.0024100830778479576 gen: 0.0023936403449624777 gate: 1.6442625565105118e-05
[10/30] [400/11532] loss: 0.00016718638653401285 gen: 0.00015398017421830446 gate: 1.3206206858740188e-05
[10/30] [500/11532] loss: 0.013650394976139069 gen: 0.01362286601215601 gate: 2.752900581981521e-05
[10/30] [600/11532] loss: 0.024291303008794785 g

100%|██████████| 640/640 [01:56<00:00,  5.50it/s]


{'joint_goal_accuracy': 0.7672919109026963, 'turn_slot_accuracy': 0.9924970691676485, 'turn_slot_f1': 0.9664243307691688}
joint_goal_accuracy: 0.7672919109026963
turn_slot_accuracy: 0.9924970691676485
turn_slot_f1: 0.9664243307691688
'--Update Best checkpoint!, epoch: 11'
--Saving best model checkpoint
time for 1 epoch: 1992.540019750595
[11/30] [0/11532] loss: 0.0016310367500409484 gen: 0.0016184357227757573 gate: 1.2601005437318236e-05
[11/30] [100/11532] loss: 0.013379117473959923 gen: 0.013303330168128014 gate: 7.57871093810536e-05
[11/30] [200/11532] loss: 0.04287368431687355 gen: 0.042850397527217865 gate: 2.3285476345336065e-05
[11/30] [300/11532] loss: 0.03472943976521492 gen: 0.029522690922021866 gate: 0.0052067493088543415
[11/30] [400/11532] loss: 0.007184793706983328 gen: 0.006946433335542679 gate: 0.0002383602550253272
[11/30] [500/11532] loss: 0.035782590508461 gen: 0.03537411615252495 gate: 0.0004084755200892687
[11/30] [600/11532] loss: 0.022640926763415337 gen: 0.02261

100%|██████████| 640/640 [02:00<00:00,  5.33it/s]


{'joint_goal_accuracy': 0.7549824150058617, 'turn_slot_accuracy': 0.9921757631019116, 'turn_slot_f1': 0.9659891856309657}
joint_goal_accuracy: 0.7549824150058617
turn_slot_accuracy: 0.9921757631019116
turn_slot_f1: 0.9659891856309657
time for 1 epoch: 1972.5795938968658
[12/30] [0/11532] loss: 0.0056601883843541145 gen: 0.005530955735594034 gate: 0.00012923266331199557
[12/30] [100/11532] loss: 0.008746468462049961 gen: 0.008606940507888794 gate: 0.0001395275758113712
[12/30] [200/11532] loss: 0.07519650459289551 gen: 0.012590725906193256 gate: 0.06260577589273453
[12/30] [300/11532] loss: 0.002152983797714114 gen: 0.002116201678290963 gate: 3.678204302559607e-05
[12/30] [400/11532] loss: 0.006280471105128527 gen: 0.006179462652653456 gate: 0.0001010083215078339
[12/30] [500/11532] loss: 0.0004375980352051556 gen: 0.0004067017580382526 gate: 3.089627716690302e-05
[12/30] [600/11532] loss: 0.002478567650541663 gen: 0.002454051049426198 gate: 2.451671207381878e-05
[12/30] [700/11532] los

100%|██████████| 640/640 [02:08<00:00,  4.99it/s]


{'joint_goal_accuracy': 0.7817506838608832, 'turn_slot_accuracy': 0.9929616603708086, 'turn_slot_f1': 0.969484086548804}
joint_goal_accuracy: 0.7817506838608832
turn_slot_accuracy: 0.9929616603708086
turn_slot_f1: 0.969484086548804
'--Update Best checkpoint!, epoch: 13'
--Saving best model checkpoint
time for 1 epoch: 1989.906635761261
[13/30] [0/11532] loss: 0.00028258844395168126 gen: 0.0002656223368830979 gate: 1.6966105249593966e-05
[13/30] [100/11532] loss: 0.022801868617534637 gen: 0.013682534918189049 gate: 0.009119332768023014
[13/30] [200/11532] loss: 0.0063471863977611065 gen: 0.006316629704087973 gate: 3.0556599085684866e-05
[13/30] [300/11532] loss: 0.0002373235474806279 gen: 0.0002129882195731625 gate: 2.4335322450497188e-05
[13/30] [400/11532] loss: 0.024991385638713837 gen: 0.024965688586235046 gate: 2.5696372176753357e-05
[13/30] [500/11532] loss: 0.01323141623288393 gen: 0.013122064992785454 gate: 0.00010935126192634925
[13/30] [600/11532] loss: 0.002837582491338253 ge

100%|██████████| 640/640 [01:59<00:00,  5.38it/s]


{'joint_goal_accuracy': 0.780969128565846, 'turn_slot_accuracy': 0.9930181060310056, 'turn_slot_f1': 0.9695407043832462}
joint_goal_accuracy: 0.780969128565846
turn_slot_accuracy: 0.9930181060310056
turn_slot_f1: 0.9695407043832462
time for 1 epoch: 1988.9818758964539
[14/30] [0/11532] loss: 0.005322791635990143 gen: 0.005307556129992008 gate: 1.5235423234116752e-05
[14/30] [100/11532] loss: 0.10342341661453247 gen: 0.05519573390483856 gate: 0.04822768270969391
[14/30] [200/11532] loss: 0.008139114826917648 gen: 0.008130906149744987 gate: 8.209026418626308e-06
[14/30] [300/11532] loss: 0.030233271420001984 gen: 0.030150294303894043 gate: 8.297725435113534e-05
[14/30] [400/11532] loss: 0.007879591546952724 gen: 0.007856134325265884 gate: 2.3456821509171277e-05
[14/30] [500/11532] loss: 0.0019798539578914642 gen: 0.001967934425920248 gate: 1.1919555618078448e-05
[14/30] [600/11532] loss: 0.0035060844384133816 gen: 0.003485727356746793 gate: 2.0357112589408644e-05
[14/30] [700/11532] loss

100%|██████████| 640/640 [02:06<00:00,  5.07it/s]


{'joint_goal_accuracy': 0.7887846815162173, 'turn_slot_accuracy': 0.9933133602535742, 'turn_slot_f1': 0.9708208953776947}
joint_goal_accuracy: 0.7887846815162173
turn_slot_accuracy: 0.9933133602535742
turn_slot_f1: 0.9708208953776947
'--Update Best checkpoint!, epoch: 15'
--Saving best model checkpoint
time for 1 epoch: 2003.8503375053406
[15/30] [0/11532] loss: 5.8988349337596446e-05 gen: 2.7140822567162104e-05 gate: 3.1847528589423746e-05
[15/30] [100/11532] loss: 0.003949010744690895 gen: 0.003944139927625656 gate: 4.870978955295868e-06
[15/30] [200/11532] loss: 3.035817098862026e-05 gen: 2.5777248083613813e-05 gate: 4.580923359753797e-06
[15/30] [300/11532] loss: 0.00019546897965483367 gen: 0.00018555382848717272 gate: 9.915152077155653e-06
[15/30] [400/11532] loss: 0.00029034106410108507 gen: 0.00023255572887137532 gate: 5.778532795375213e-05
[15/30] [500/11532] loss: 0.0008849863079376519 gen: 0.00087677629198879 gate: 8.209998668462504e-06
[15/30] [600/11532] loss: 4.96501415909

100%|██████████| 640/640 [02:10<00:00,  4.92it/s]


{'joint_goal_accuracy': 0.7891754591637359, 'turn_slot_accuracy': 0.9932221788024879, 'turn_slot_f1': 0.9708425625183984}
joint_goal_accuracy: 0.7891754591637359
turn_slot_accuracy: 0.9932221788024879
turn_slot_f1: 0.9708425625183984
'--Update Best checkpoint!, epoch: 16'
--Saving best model checkpoint
time for 1 epoch: 1989.5540425777435
[16/30] [0/11532] loss: 0.011355064809322357 gen: 0.011335636489093304 gate: 1.9427894585533068e-05
[16/30] [100/11532] loss: 0.024336298927664757 gen: 0.024297770112752914 gate: 3.8529204175574705e-05
[16/30] [200/11532] loss: 0.059497613459825516 gen: 0.014189337380230427 gate: 0.045308277010917664
[16/30] [300/11532] loss: 0.0036841135006397963 gen: 0.0036643242929130793 gate: 1.9789165889960714e-05
[16/30] [400/11532] loss: 0.0006004691240377724 gen: 0.0005920789553783834 gate: 8.390145012526773e-06
[16/30] [500/11532] loss: 0.0007414468564093113 gen: 0.0007326130289584398 gate: 8.833805622998625e-06
[16/30] [600/11532] loss: 0.010539093054831028 

100%|██████████| 640/640 [02:06<00:00,  5.08it/s]


{'joint_goal_accuracy': 0.798358733880422, 'turn_slot_accuracy': 0.9934566453909978, 'turn_slot_f1': 0.9719830708356535}
joint_goal_accuracy: 0.798358733880422
turn_slot_accuracy: 0.9934566453909978
turn_slot_f1: 0.9719830708356535
'--Update Best checkpoint!, epoch: 17'
--Saving best model checkpoint
time for 1 epoch: 1983.456375837326
[17/30] [0/11532] loss: 0.02573118358850479 gen: 0.01787540689110756 gate: 0.007855777628719807
[17/30] [100/11532] loss: 0.0004559470689855516 gen: 0.0004378654411993921 gate: 1.8081638700095937e-05
[17/30] [200/11532] loss: 0.008719815872609615 gen: 0.008701490238308907 gate: 1.8325838027521968e-05
[17/30] [300/11532] loss: 0.0009391955682076514 gen: 0.000923798477742821 gate: 1.5397068636957556e-05
[17/30] [400/11532] loss: 0.0019269516924396157 gen: 0.0018937239656224847 gate: 3.322774136904627e-05
[17/30] [500/11532] loss: 0.00016450832481496036 gen: 0.00012903843889944255 gate: 3.5469889553496614e-05
[17/30] [600/11532] loss: 0.027782145887613297 g

100%|██████████| 640/640 [01:59<00:00,  5.37it/s]


{'joint_goal_accuracy': 0.7967956232903478, 'turn_slot_accuracy': 0.9935434848682242, 'turn_slot_f1': 0.972443299298665}
joint_goal_accuracy: 0.7967956232903478
turn_slot_accuracy: 0.9935434848682242
turn_slot_f1: 0.972443299298665
time for 1 epoch: 1969.5665414333344
[18/30] [0/11532] loss: 0.00025117784389294684 gen: 0.00022903864737600088 gate: 2.2139203792903572e-05
[18/30] [100/11532] loss: 0.00818981695920229 gen: 0.008180505596101284 gate: 9.311233952757902e-06
[18/30] [200/11532] loss: 0.00033003787393681705 gen: 0.0003105640644207597 gate: 1.9473796783131547e-05
[18/30] [300/11532] loss: 0.0021255137398838997 gen: 0.0021125751081854105 gate: 1.2938515283167362e-05
[18/30] [400/11532] loss: 0.00639112014323473 gen: 0.006382703315466642 gate: 8.41674227558542e-06
[18/30] [500/11532] loss: 0.0018805591389536858 gen: 0.000719814735930413 gate: 0.0011607444612309337
[18/30] [600/11532] loss: 0.010171903297305107 gen: 0.01016136072576046 gate: 1.0542368727328721e-05
[18/30] [700/115

100%|██████████| 640/640 [01:59<00:00,  5.38it/s]


{'joint_goal_accuracy': 0.8077373974208675, 'turn_slot_accuracy': 0.9938387390907935, 'turn_slot_f1': 0.9733725348690715}
joint_goal_accuracy: 0.8077373974208675
turn_slot_accuracy: 0.9938387390907935
turn_slot_f1: 0.9733725348690715
'--Update Best checkpoint!, epoch: 19'
--Saving best model checkpoint
time for 1 epoch: 1972.8341464996338
[19/30] [0/11532] loss: 0.004622033331543207 gen: 0.004547036252915859 gate: 7.499721687054262e-05
[19/30] [100/11532] loss: 0.03966108709573746 gen: 0.022906433790922165 gate: 0.016754651442170143
[19/30] [200/11532] loss: 0.01984308660030365 gen: 0.019602205604314804 gate: 0.00024088069039862603
[19/30] [300/11532] loss: 0.00031970959389582276 gen: 0.00030636152951046824 gate: 1.3348076208785642e-05
[19/30] [400/11532] loss: 0.14422547817230225 gen: 0.06358149647712708 gate: 0.08064398169517517
[19/30] [500/11532] loss: 0.0015934567200019956 gen: 0.0015282620443031192 gate: 6.519471935462207e-05
[19/30] [600/11532] loss: 5.736116145271808e-05 gen: 4

100%|██████████| 640/640 [02:06<00:00,  5.07it/s]


{'joint_goal_accuracy': 0.8034388432981634, 'turn_slot_accuracy': 0.9936824280317866, 'turn_slot_f1': 0.9728772329544656}
joint_goal_accuracy: 0.8034388432981634
turn_slot_accuracy: 0.9936824280317866
turn_slot_f1: 0.9728772329544656
time for 1 epoch: 1995.6757457256317
[20/30] [0/11532] loss: 0.01542236004024744 gen: 0.015409071929752827 gate: 1.3288233276398387e-05
[20/30] [100/11532] loss: 0.00023828927078284323 gen: 0.00021984227350912988 gate: 1.8446999092702754e-05
[20/30] [200/11532] loss: 0.00028930892585776746 gen: 0.0002728438121266663 gate: 1.6465104636154138e-05
[20/30] [300/11532] loss: 0.0267057865858078 gen: 0.007682624273002148 gate: 0.019023163244128227
[20/30] [400/11532] loss: 0.0006217620684765279 gen: 0.0006078876904211938 gate: 1.3874356227461249e-05
[20/30] [500/11532] loss: 0.06843580305576324 gen: 0.06820762157440186 gate: 0.00022818311117589474
[20/30] [600/11532] loss: 0.021174483001232147 gen: 0.009450078941881657 gate: 0.011724403128027916
[20/30] [700/1153

100%|██████████| 640/640 [01:59<00:00,  5.38it/s]


{'joint_goal_accuracy': 0.8114497850722938, 'turn_slot_accuracy': 0.9939125526464359, 'turn_slot_f1': 0.9733549156118104}
joint_goal_accuracy: 0.8114497850722938
turn_slot_accuracy: 0.9939125526464359
turn_slot_f1: 0.9733549156118104
'--Update Best checkpoint!, epoch: 21'
--Saving best model checkpoint
time for 1 epoch: 2002.7902505397797
[21/30] [0/11532] loss: 0.00026055844500660896 gen: 0.0002555967657826841 gate: 4.9616901378612965e-06
[21/30] [100/11532] loss: 0.026624459773302078 gen: 0.008500421419739723 gate: 0.018124038353562355
[21/30] [200/11532] loss: 0.0004209235485177487 gen: 0.0004159154777880758 gate: 5.008062544220593e-06
[21/30] [300/11532] loss: 0.007128195371478796 gen: 0.0071222297847270966 gate: 5.965678155916976e-06
[21/30] [400/11532] loss: 0.003183758584782481 gen: 0.0031488963868469 gate: 3.4862088796216995e-05
[21/30] [500/11532] loss: 0.009976658970117569 gen: 0.00979809183627367 gate: 0.00017856679914984852
[21/30] [600/11532] loss: 0.0073741585947573185 ge

100%|██████████| 640/640 [02:03<00:00,  5.16it/s]


{'joint_goal_accuracy': 0.8079327862446268, 'turn_slot_accuracy': 0.9938995267248537, 'turn_slot_f1': 0.9741001865066419}
joint_goal_accuracy: 0.8079327862446268
turn_slot_accuracy: 0.9938995267248537
turn_slot_f1: 0.9741001865066419
time for 1 epoch: 1970.3435521125793
[22/30] [0/11532] loss: 0.00964546762406826 gen: 0.00957574788480997 gate: 6.972001574467868e-05
[22/30] [100/11532] loss: 0.0031917495653033257 gen: 0.0031628513243049383 gate: 2.8898215532535687e-05
[22/30] [200/11532] loss: 0.000593931064940989 gen: 0.0004484574601519853 gate: 0.00014547361934091896
[22/30] [300/11532] loss: 0.007261476945132017 gen: 0.007227960973978043 gate: 3.351603663759306e-05
[22/30] [400/11532] loss: 0.0032853770535439253 gen: 0.003277236595749855 gate: 8.14054783404572e-06
[22/30] [500/11532] loss: 0.022472474724054337 gen: 0.022443797439336777 gate: 2.8677975933533162e-05
[22/30] [600/11532] loss: 0.0010901853675022721 gen: 0.0008389090653508902 gate: 0.00025127630215138197
[22/30] [700/1153

100%|██████████| 640/640 [02:10<00:00,  4.89it/s]


{'joint_goal_accuracy': 0.8114497850722938, 'turn_slot_accuracy': 0.9940211019929693, 'turn_slot_f1': 0.974453282760635}
joint_goal_accuracy: 0.8114497850722938
turn_slot_accuracy: 0.9940211019929693
turn_slot_f1: 0.974453282760635
time for 1 epoch: 1982.6032795906067
[23/30] [0/11532] loss: 1.787017026799731e-05 gen: 9.78732896328438e-06 gate: 8.082841304712929e-06
[23/30] [100/11532] loss: 0.0007423225906677544 gen: 0.0007366608479060233 gate: 5.661727300321218e-06
[23/30] [200/11532] loss: 6.4322471189370845e-06 gen: 4.400277759941673e-07 gate: 5.992219485051464e-06
[23/30] [300/11532] loss: 0.02212708815932274 gen: 0.022114889696240425 gate: 1.219891055370681e-05
[23/30] [400/11532] loss: 0.007484872359782457 gen: 0.00745773408561945 gate: 2.7138457880937494e-05
[23/30] [500/11532] loss: 0.0002525685995351523 gen: 0.0002457340306136757 gate: 6.834571649960708e-06
[23/30] [600/11532] loss: 0.009444788098335266 gen: 0.009421619586646557 gate: 2.3168844563770108e-05
[23/30] [700/11532

100%|██████████| 640/640 [02:03<00:00,  5.18it/s]


{'joint_goal_accuracy': 0.8139898397811646, 'turn_slot_accuracy': 0.9941470192349479, 'turn_slot_f1': 0.9748495400137208}
joint_goal_accuracy: 0.8139898397811646
turn_slot_accuracy: 0.9941470192349479
turn_slot_f1: 0.9748495400137208
'--Update Best checkpoint!, epoch: 24'
--Saving best model checkpoint
time for 1 epoch: 2007.5360128879547
[24/30] [0/11532] loss: 0.00894647091627121 gen: 0.008921812288463116 gate: 2.4659007976879366e-05
[24/30] [100/11532] loss: 6.995934381848201e-05 gen: 6.43479434074834e-05 gate: 5.6113972277671564e-06
[24/30] [200/11532] loss: 0.01434845756739378 gen: 0.014337568543851376 gate: 1.0889461009355728e-05
[24/30] [300/11532] loss: 0.020426807925105095 gen: 0.019656715914607048 gate: 0.0007700911955907941
[24/30] [400/11532] loss: 4.296658153180033e-05 gen: 3.763994391192682e-05 gate: 5.326639438862912e-06
[24/30] [500/11532] loss: 0.0003026046615559608 gen: 0.0002792944433167577 gate: 2.3310207325266674e-05
[24/30] [600/11532] loss: 0.00014195650874171406

100%|██████████| 640/640 [02:11<00:00,  4.88it/s]


{'joint_goal_accuracy': 0.8194607268464243, 'turn_slot_accuracy': 0.9941904389735607, 'turn_slot_f1': 0.9748568053275218}
joint_goal_accuracy: 0.8194607268464243
turn_slot_accuracy: 0.9941904389735607
turn_slot_f1: 0.9748568053275218
'--Update Best checkpoint!, epoch: 25'
--Saving best model checkpoint
time for 1 epoch: 2001.6690287590027
[25/30] [0/11532] loss: 0.0036198473535478115 gen: 0.0035992315970361233 gate: 2.061571467493195e-05
[25/30] [100/11532] loss: 0.00020452216267585754 gen: 0.0001799416058929637 gate: 2.4580562239862047e-05
[25/30] [200/11532] loss: 5.3616207878803834e-05 gen: 3.5891804145649076e-05 gate: 1.772440373315476e-05
[25/30] [300/11532] loss: 7.980680675245821e-05 gen: 4.120877929381095e-05 gate: 3.8598027458647266e-05
[25/30] [400/11532] loss: 0.00872435886412859 gen: 0.008532984182238579 gate: 0.0001913748128572479
[25/30] [500/11532] loss: 0.006098997313529253 gen: 0.006027907133102417 gate: 7.109028229024261e-05
[25/30] [600/11532] loss: 0.005193844437599

100%|██████████| 640/640 [02:13<00:00,  4.79it/s]


{'joint_goal_accuracy': 0.8120359515435717, 'turn_slot_accuracy': 0.9940514958099979, 'turn_slot_f1': 0.9741202436431604}
joint_goal_accuracy: 0.8120359515435717
turn_slot_accuracy: 0.9940514958099979
turn_slot_f1: 0.9741202436431604
time for 1 epoch: 1988.0891082286835
[26/30] [0/11532] loss: 0.0014102484565228224 gen: 0.0012875802349299192 gate: 0.00012266816338524222
[26/30] [100/11532] loss: 0.0005505651934072375 gen: 0.000499619753099978 gate: 5.094544030725956e-05
[26/30] [200/11532] loss: 0.0033519843127578497 gen: 0.002938586985692382 gate: 0.0004133972979616374
[26/30] [300/11532] loss: 0.06652822345495224 gen: 0.06556817889213562 gate: 0.000960043165832758
[26/30] [400/11532] loss: 0.00021695424220524728 gen: 0.0001453057921025902 gate: 7.164845010265708e-05
[26/30] [500/11532] loss: 0.002669818699359894 gen: 0.00257530203089118 gate: 9.451664664084092e-05
[26/30] [600/11532] loss: 0.0015080576995387673 gen: 0.0013840713072568178 gate: 0.0001239864359376952
[26/30] [700/11532

100%|██████████| 640/640 [01:59<00:00,  5.35it/s]


{'joint_goal_accuracy': 0.8116451738960532, 'turn_slot_accuracy': 0.9940949155486112, 'turn_slot_f1': 0.974747954737966}
joint_goal_accuracy: 0.8116451738960532
turn_slot_accuracy: 0.9940949155486112
turn_slot_f1: 0.974747954737966
time for 1 epoch: 1954.4034388065338
[27/30] [0/11532] loss: 2.3634231183677912e-05 gen: 1.636406705074478e-05 gate: 7.270164132933132e-06
[27/30] [100/11532] loss: 0.0002303244109498337 gen: 0.00022266995802056044 gate: 7.654458386241458e-06
[27/30] [200/11532] loss: 0.0003118542372249067 gen: 0.0002835374907590449 gate: 2.831675737979822e-05
[27/30] [300/11532] loss: 0.013522692024707794 gen: 0.012901538982987404 gate: 0.0006211534491740167
[27/30] [400/11532] loss: 0.0014884148258715868 gen: 0.0014762838836759329 gate: 1.2130899449402932e-05
[27/30] [500/11532] loss: 3.547004598658532e-05 gen: 2.721954842854757e-05 gate: 8.250496648543049e-06
[27/30] [600/11532] loss: 0.0003113968705292791 gen: 0.0001477089972468093 gate: 0.0001636878732824698
[27/30] [70

100%|██████████| 640/640 [01:59<00:00,  5.38it/s]


{'joint_goal_accuracy': 0.8151621727237202, 'turn_slot_accuracy': 0.9941730710781139, 'turn_slot_f1': 0.97498999566925}
joint_goal_accuracy: 0.8151621727237202
turn_slot_accuracy: 0.9941730710781139
turn_slot_f1: 0.97498999566925
time for 1 epoch: 1967.4125971794128
[28/30] [0/11532] loss: 6.876568659208715e-05 gen: 4.9007889174390584e-05 gate: 1.9757793779717758e-05
[28/30] [100/11532] loss: 5.688960663974285e-05 gen: 3.107321390416473e-05 gate: 2.5816390916588716e-05
[28/30] [200/11532] loss: 0.007588536944240332 gen: 0.00757877342402935 gate: 9.76329010882182e-06
[28/30] [300/11532] loss: 0.006140152923762798 gen: 0.006122842896729708 gate: 1.73100306710694e-05
[28/30] [400/11532] loss: 5.80010237172246e-05 gen: 5.373614840209484e-05 gate: 4.264875769877108e-06
[28/30] [500/11532] loss: 0.0006017561536282301 gen: 0.0005877602961845696 gate: 1.3995844710734673e-05
[28/30] [600/11532] loss: 0.007207480259239674 gen: 0.0071927872486412525 gate: 1.469305880164029e-05
[28/30] [700/11532]

100%|██████████| 640/640 [02:03<00:00,  5.17it/s]


{'joint_goal_accuracy': 0.8155529503712388, 'turn_slot_accuracy': 0.9941730710781143, 'turn_slot_f1': 0.9749586696482246}
joint_goal_accuracy: 0.8155529503712388
turn_slot_accuracy: 0.9941730710781143
turn_slot_f1: 0.9749586696482246
time for 1 epoch: 1978.1460721492767
[29/30] [0/11532] loss: 5.125667667016387e-05 gen: 3.9952679799171165e-05 gate: 1.1303997780487407e-05
[29/30] [100/11532] loss: 0.0003210804716218263 gen: 0.0003079766174778342 gate: 1.3103862329444382e-05
[29/30] [200/11532] loss: 0.00025593183818273246 gen: 0.00023967168817762285 gate: 1.6260142729151994e-05
[29/30] [300/11532] loss: 0.011781229637563229 gen: 0.011753227561712265 gate: 2.8002301405649632e-05


KeyboardInterrupt: 

In [23]:
!python inference.py

100%|█████████████████████████████████████| 2000/2000 [00:00<00:00, 6211.51it/s]
Token indices sequence length is longer than the specified maximum sequence length for this model (541 > 512). Running this sequence through the model will result in indexing errors
# eval: 14771
  "num_layers={}".format(dropout, num_layers))
Model is loaded
100%|███████████████████████████████████████| 1847/1847 [05:50<00:00,  5.27it/s]
