### som-dst `train.py`

## 변화된 점
- mixed precision 추가

In [1]:
import sys
# Add the parent directory to the module search path
# (presumably so the project-local `model`, `utils` and `evaluation` imports resolve — confirm the repo layout).
sys.path.append('..')

In [2]:
"""
SOM-DST
Copyright (c) 2020-present NAVER Corp.
MIT license
"""

from model import SomDST
# from pytorch_transformers import BertTokenizer, AdamW, WarmupLinearSchedule, BertConfig
from transformers import BertModel, BertTokenizer, BertConfig, AdamW, get_linear_schedule_with_warmup, get_cosine_with_hard_restarts_schedule_with_warmup
from utils.data_utils import prepare_dataset, MultiWozDataset
from utils.data_utils import  domain2id, OP_SET, make_turn_label, postprocessing # make_slot_meta,
from utils.eval_utils import compute_prf, compute_acc, per_domain_join_accuracy
# from utils.ckpt_utils import download_ckpt, convert_ckpt_compatible
from evaluation import model_evaluation

import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
import torch.cuda.amp as amp

import numpy as np
import argparse
import random
import os
import json
import time
import pickle
from tqdm import tqdm
from pathlib import Path

# Enable autograd anomaly detection so the backward op that produces NaN values is reported.
# NOTE(review): this is a debugging aid and is expected to slow training — consider disabling it for real runs.
torch.autograd.set_detect_anomaly(True)

<torch.autograd.anomaly_mode.set_detect_anomaly at 0x7f6df8233910>

In [3]:
# Prefer the GPU when CUDA is usable; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f'Device Name : {device}')

# torch.cuda.empty_cache()

Device Name : cuda


In [4]:
def masked_cross_entropy_for_value(logits, target, pad_idx=0):
    """Average negative log-likelihood of `target` tokens, ignoring padding.

    Despite its name, `logits` is treated as a tensor of *probabilities*:
    the logarithm is applied to it directly, with no softmax.

    Args:
        logits: probabilities whose last dimension indexes the vocabulary and
            whose leading dimensions match `target`.
        target: integer token ids; positions equal to `pad_idx` are ignored.
        pad_idx: id of the padding token.

    Returns:
        Scalar tensor: the summed loss over non-padding positions divided by
        their count (0 when every position is padding).
    """
    mask = target.ne(pad_idx)
    probs_flat = logits.reshape(-1, logits.size(-1))
    target_flat = target.reshape(-1, 1)

    # Select the probability of the gold token first, so the log is only taken
    # on the values that are actually needed.
    picked = torch.gather(probs_flat, dim=1, index=target_flat)

    # A probability of exactly 0 would give log(0) = -inf; multiplied by a zero
    # mask that becomes NaN in both the loss and its gradient. Clamp it away.
    eps = max(1e-12, torch.finfo(picked.dtype).tiny)
    losses = -torch.log(picked.clamp_min(eps)).view(*target.size())

    # Zero out padded positions without multiplying (inf * 0 would be NaN).
    losses = losses.masked_fill(~mask, 0.0)

    # Guard against a batch that contains only padding (0 / 0).
    loss = losses.sum() / mask.sum().clamp_min(1).float()
    return loss

In [5]:
def increment_output_dir(output_path, exist_ok=False):
    """Return an output path that does not clash with an existing one.

    Args:
        output_path: desired path (str or path-like).
        exist_ok: when True, return `output_path` even if it already exists.

    Returns:
        `output_path` as a string when it is free (or `exist_ok` is set);
        otherwise `output_path` with a numeric suffix appended. The suffix is
        one more than the largest existing one, starting at 2.
    """
    # Imported locally because the notebook's import cell does not provide them.
    import glob
    import re

    path = Path(output_path)
    if exist_ok or not path.exists():
        return str(path)

    # Escape the stem so characters such as '.' or '+' are matched literally.
    suffix_pattern = re.compile(rf"{re.escape(path.stem)}(\d+)")
    suffixes = []
    for candidate in glob.glob(f"{glob.escape(str(path))}*"):
        # Match on the entry name only, so digits in parent directories
        # cannot be mistaken for a suffix.
        match = suffix_pattern.search(Path(candidate).name)
        if match:
            suffixes.append(int(match.group(1)))

    next_index = max(suffixes) + 1 if suffixes else 2
    return f"{path}{next_index}"

In [6]:
import wandb
# !wandb login  # run once

In [7]:
# Config that would be passed as `parameters` when creating a wandb sweep (currently disabled).
# hyperparameter_defaults = dict(
#     batch_size = args.batch_size,
#     learning_rate = args.learning_rate,
#     epochs = args.num_train_epochs,
#     weight_decay = args.weight_decay,
#     attn_head = args.attn_head,
#     distance_metric = args.distance_metric,

#     dropout = 0.1,
#     smoothing = 0.2
#     model_name = 'BertForSequenceClassification',
#     tokenizer_name = 'BertTokenizer',
#     )

# wandb.init(config=hyperparameter_defaults, project="TRADE")
# Start a wandb run in the "SOM-DST" project and keep a handle on its config object.
wandb.init(project="SOM-DST")
config = wandb.config

[34m[1mwandb[0m: Currently logged in as: [33mtaepd[0m (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: wandb version 0.10.30 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


### args setting

In [8]:
from argparse import Namespace

# No arguments are registered on this parser; the settings below are hard-coded instead.
parser = argparse.ArgumentParser()

# Experiment settings, declared directly as a Namespace.
args = Namespace(
    # ---- data / asset locations ----
    data_root='data',  # default='data/mwz2.1'
    train_data='train_dials.json',
    dev_data='dev_dials.json',
    test_data='test_dials.json',
    ontology_data='ontology.json',
    slot_meta='slot_meta.json',
    # "--vocab_path":'assets/vocab.txt',
    bert_config_path='assets/bert_config_base_uncased.json',
    bert_ckpt_path='assets/bert-base-uncased-pytorch_model.bin',
    save_dir='outputs',

    # ---- optimisation ----
    random_seed=42,
    num_workers=4,
    batch_size=32,
    enc_warmup=0.1,
    dec_warmup=0.1,
    enc_lr=4e-5,  # default 4e-5
    dec_lr=1e-4,
    n_epochs=100,
    eval_epoch=1,

    # ---- model / regularisation ----
    op_code="4",
    slot_token="[SLOT]",
    dropout=0.1,
    hidden_dropout_prob=0.1,
    attention_probs_dropout_prob=0.1,
    decoder_teacher_forcing=1,
    word_dropout=0.1,
    not_shuffle_state=False,
    shuffle_p=0.5,

    # ---- input construction ----
    n_history=1,
    max_seq_length=512,
    msg=None,
    exclude_domain=True,
)


def _under_data_root(filename):
    """Join `filename` onto the configured data root."""
    return os.path.join(args.data_root, filename)


# Resolve every data file against the data root.
args.train_data_path = _under_data_root(args.train_data)
args.dev_data_path = _under_data_root(args.dev_data)
args.test_data_path = _under_data_root(args.test_data)
args.ontology_data = _under_data_root(args.ontology_data)
args.slot_meta = _under_data_root(args.slot_meta)
# `shuffle_state` is simply the negation of the `not_shuffle_state` switch.
args.shuffle_state = not args.not_shuffle_state


print('pytorch version: ', torch.__version__)
# print(args)

pytorch version:  1.7.0+cu101


In [9]:
def worker_init_fn(worker_id):
    """Seed NumPy in a DataLoader worker with the base seed offset by the worker id."""
    worker_seed = args.random_seed + worker_id
    np.random.seed(worker_seed)


# Number of visible CUDA devices (0 when CUDA is unavailable).
n_gpu = torch.cuda.device_count() if torch.cuda.is_available() else 0

# Seed every random number generator used by the run with the same base seed.
rng = random.Random(args.random_seed)
random.seed(args.random_seed)
np.random.seed(args.random_seed)
torch.manual_seed(args.random_seed)
if n_gpu > 0:
    # Deterministic cuDNN kernels and no autotuning.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch.cuda.manual_seed(args.random_seed)
    torch.cuda.manual_seed_all(args.random_seed)

# Create the save directory on first use.
save_dir_path = Path(args.save_dir)
if not save_dir_path.exists():
    save_dir_path.mkdir()

In [10]:
# Load the ontology and the slot metadata, closing each file once it has been parsed.
with open(args.ontology_data, encoding='utf-8') as ontology_file:
    ontology = json.load(ontology_file)
# slot_meta, ontology = make_slot_meta(ontology)
with open(args.slot_meta, encoding='utf-8') as slot_meta_file:
    slot_meta = json.load(slot_meta_file)

# Mapping from operation name to class id for the configured op code.
op2id = OP_SET[args.op_code]
print(op2id)


{'delete': 0, 'update': 1, 'dontcare': 2, 'carryover': 3}


### MODEL CONFIG

In [11]:
# Build the BERT config from the 'dsksd/bert-ko-small-minimal' pretrained configuration
# and override its dropout settings with the values from `args`.
# model_config = BertConfig.from_json_file(args.bert_config_path)
model_config = BertConfig.from_pretrained('dsksd/bert-ko-small-minimal')
model_config.dropout = args.dropout
model_config.attention_probs_dropout_prob = args.attention_probs_dropout_prob
model_config.hidden_dropout_prob = args.hidden_dropout_prob
model_config.vocab_size += 3  # account for the added special tokens
model = SomDST(model_config, len(op2id), len(domain2id), op2id['update'], args.exclude_domain)

# Register the model with wandb.
wandb.watch(model)
# if not os.path.exists(args.bert_ckpt_path):
#     args.bert_ckpt_path = download_ckpt(args.bert_ckpt_path, args.bert_config_path, 'assets')

# ckpt = torch.load(args.bert_ckpt_path, map_location='cpu')
# model.encoder.bert.load_state_dict(ckpt)


# re-initialize added special tokens ([SLOT], [NULL], [EOS])
# model.encoder.bert.embeddings.word_embeddings.weight.data[1].normal_(mean=0.0, std=0.02)
# model.encoder.bert.embeddings.word_embeddings.weight.data[2].normal_(mean=0.0, std=0.02)
# model.encoder.bert.embeddings.word_embeddings.weight.data[3].normal_(mean=0.0, std=0.02)


# Reuse the tokenizer owned by the model's encoder instead of building a separate one.
# tokenizer = BertTokenizer(args.vocab_path, do_lower_case=True)
tokenizer = model.encoder.tokenizer
# tokenizer = BertTokenizer.from_pretrained('dsksd/bert-ko-small-minimal', additional_special_tokens = ['[SLOT]', '[NULL]','[EOS]'])
# tokenizer = BertTokenizer.from_pretrained('dsksd/bert-ko-small-minimal')

# Move the model to the selected device (as the last expression, this also displays the model).
model.to(device)

Special tokens have been added in the vocabulary, make sure the associated word embedding are fine-tuned or trained.


SomDST(
  (encoder): Encoder(
    (bert): BertModel(
      (embeddings): BertEmbeddings(
        (word_embeddings): Embedding(35003, 768)
        (position_embeddings): Embedding(512, 768)
        (token_type_embeddings): Embedding(2, 768)
        (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (encoder): BertEncoder(
        (layer): ModuleList(
          (0): BertLayer(
            (attention): BertAttention(
              (self): BertSelfAttention(
                (query): Linear(in_features=768, out_features=768, bias=True)
                (key): Linear(in_features=768, out_features=768, bias=True)
                (value): Linear(in_features=768, out_features=768, bias=True)
                (dropout): Dropout(p=0.1, inplace=False)
              )
              (output): BertSelfOutput(
                (dense): Linear(in_features=768, out_features=768, bias=True)
                (LayerNorm): LayerNorm(

In [12]:
# Inspect the tokenizer's special tokens.
# NOTE(review): a notebook only displays the last expression of a cell, so this first line shows nothing.
tokenizer.all_special_tokens

# Ids of the additional special tokens (this is the value the cell displays).
tokenizer.additional_special_tokens_ids



[35000, 35001, 35002]

### DATA PREPARATION

In [13]:
def _load_or_prepare_raw(cache_path, data_path):
    """Return the preprocessed examples for one split.

    If `cache_path` exists, the pickled examples are read from it. Otherwise
    they are built with `prepare_dataset` from `data_path` and pickled to
    `cache_path` before being returned.
    """
    if os.path.exists(cache_path):
        # NOTE: unpickling can execute arbitrary code; only use cache files you created yourself.
        with open(cache_path, 'rb') as f:
            return pickle.load(f)

    data_raw = prepare_dataset(data_path=data_path,
                               tokenizer=tokenizer,
                               slot_meta=slot_meta,
                               n_history=args.n_history,
                               max_seq_length=args.max_seq_length,
                               op_code=args.op_code)
    with open(cache_path, 'wb') as f:
        pickle.dump(data_raw, f)
    return data_raw


# Directory holding the pickled, preprocessed splits.
if not os.path.exists('raw_data'):
    os.mkdir('raw_data')

# ---- train ----
print('Making Train_Data_Raw....')
train_data_raw = _load_or_prepare_raw('./raw_data/train_data_raw', args.train_data_path)

train_data = MultiWozDataset(train_data_raw,
                             tokenizer,
                             slot_meta,
                             args.max_seq_length,
                             rng,
                             ontology,
                             args.word_dropout,
                             args.shuffle_state,
                             args.shuffle_p)
print("# train examples %d" % len(train_data_raw))

# ---- dev ----
print('Making Dev_Data_Raw....')
# print(f'max_seq_length : {args.max_seq_length}')
dev_data_raw = _load_or_prepare_raw('./raw_data/dev_data_raw', args.dev_data_path)
print("# dev examples %d" % len(dev_data_raw))

# ---- test ----
test_data_raw = _load_or_prepare_raw('./raw_data/test_data_raw', args.test_data_path)
print("# test examples %d" % len(test_data_raw))

Making Train_Data_Raw....
# train examples 45320
Making Dev_Data_Raw....
# dev examples 5067
# test examples 14771


### Optimizer & Scheduler

In [14]:
######## STEPS ########
# One optimizer/scheduler step is taken per batch. The final partial batch of an
# epoch still counts as a step, so steps per epoch is ceil(num_examples / batch_size).
# Truncating the fractional product instead would make the linear schedule reach
# a learning rate of 0 before the last epoch finishes.
steps_per_epoch = (len(train_data_raw) + args.batch_size - 1) // args.batch_size
num_train_steps = steps_per_epoch * args.n_epochs


######## ENC/DEC OPTIMIZER ########
# Encoder: weight decay on everything except biases and LayerNorm parameters.
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
enc_param_optimizer = list(model.encoder.named_parameters())
enc_optimizer_grouped_parameters = [
    {'params': [p for n, p in enc_param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
    {'params': [p for n, p in enc_param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
    ]

enc_optimizer = AdamW(enc_optimizer_grouped_parameters, lr=args.enc_lr)
# NOTE(review): args.enc_warmup / args.dec_warmup are defined, but warmup is currently disabled
# (num_warmup_steps=0). The warmup variant is kept below for reference.
# enc_scheduler = get_linear_schedule_with_warmup(enc_optimizer, num_warmup_steps=int(num_train_steps * args.enc_warmup),
#                                      num_training_steps=num_train_steps)
enc_scheduler = get_linear_schedule_with_warmup(enc_optimizer, num_training_steps=num_train_steps, num_warmup_steps=0)

# Decoder: a separate optimizer with its own learning rate and no parameter grouping.
dec_param_optimizer = list(model.decoder.parameters())
dec_optimizer = AdamW(dec_param_optimizer, lr=args.dec_lr)
# dec_scheduler = get_linear_schedule_with_warmup(dec_optimizer, num_warmup_steps=int(num_train_steps * args.dec_warmup),
#                                     num_training_steps=num_train_steps)
dec_scheduler = get_linear_schedule_with_warmup(dec_optimizer, num_training_steps=num_train_steps, num_warmup_steps=0)

# Wrap the model for multi-GPU training only after the optimizers hold its parameters.
if n_gpu > 1:
    model = torch.nn.DataParallel(model)

### Train DataLoading

In [15]:
print('Making Train DataLoader...')
# Draw training examples in random order; batches are assembled by the dataset's own collate_fn,
# and each worker process is seeded through worker_init_fn.
train_sampler = RandomSampler(train_data)
train_dataloader = DataLoader(
    dataset=train_data,
    batch_size=args.batch_size,
    sampler=train_sampler,
    num_workers=args.num_workers,
    collate_fn=train_data.collate_fn,
    worker_init_fn=worker_init_fn,
)

Making Train DataLoader...


### Train

In [16]:
def get_lr(scheduler):
    """Return the first entry of the scheduler's most recently computed learning rates."""
    last_lrs = scheduler.get_last_lr()
    return last_lrs[0]

In [17]:
# Derive the output directory name from the wandb run name; increment_output_dir appends
# a numeric suffix when a path with that name already exists.
output_dir = increment_output_dir(wandb.run.name)

In [19]:
######## TRAINING ########
# Train for args.n_epochs epochs. Every args.eval_epoch epochs the model is evaluated on
# the dev set, and a checkpoint is written whenever the joint accuracy improves. A final
# checkpoint is written after the last epoch.
print("Let's Do the Training!")
cnt = 0
loss_fnc = nn.CrossEntropyLoss()
best_score = {'epoch': 0, 'joint_acc': 0, 'op_acc': 0, 'final_slot_f1': 0}
# NOTE(review): the uniqueness check looks at ./<run name>, while checkpoints are written
# under checkpoint/<run name> — confirm this is intended.
output_dir = increment_output_dir(wandb.run.name)
os.makedirs(f"checkpoint/{output_dir}", exist_ok=True)


for epoch in range(args.n_epochs):
    batch_loss = []
    model.train()
    for step, batch in enumerate(tqdm(train_dataloader)):
        # Move tensors to the device; plain ints in the batch are passed through.
        batch = [b.to(device) if not isinstance(b, int) else b for b in batch]
        (input_ids, input_mask, segment_ids, state_position_ids, op_ids,
         domain_ids, gen_ids, max_value, max_update) = batch

        # Teacher forcing: feed the gold value tokens to the decoder with the configured probability.
        teacher = gen_ids if rng.random() < args.decoder_teacher_forcing else None

#         with amp.autocast():
        domain_scores, state_scores, gen_scores = model(input_ids=input_ids,
                                                        token_type_ids=segment_ids,
                                                        state_positions=state_position_ids,
                                                        attention_mask=input_mask,
                                                        max_value=max_value,
                                                        op_ids=op_ids,
                                                        max_update=max_update,
                                                        teacher=teacher)

        # State-operation classification loss plus value-generation loss.
        loss_s = loss_fnc(state_scores.view(-1, len(op2id)), op_ids.view(-1))
        loss_g = masked_cross_entropy_for_value(gen_scores.contiguous(),
                                                gen_ids.contiguous(),
                                                tokenizer.vocab['[PAD]'])

        loss = loss_s + loss_g
        if args.exclude_domain is not True:
            # Optional domain classification loss.
            loss_d = loss_fnc(domain_scores.view(-1, len(domain2id)), domain_ids.view(-1))
            loss = loss + loss_d
        batch_loss.append(loss.item())

        loss.backward()

#         nn.utils.clip_grad_norm_(model.parameters(), 1.0)  # added while trying to resolve the NaN loss

        # Encoder and decoder have separate optimizers and schedules.
        enc_optimizer.step()
        enc_scheduler.step()
        dec_optimizer.step()
        dec_scheduler.step()
        model.zero_grad()

        if step % 100 == 0:
            log_payload = {
                "train/mean_loss": np.mean(batch_loss),
                "train/state_loss": loss_s.item(),
                "train/gen_loss": loss_g.item(),
                "train/epoch": epoch+1,
                "train/enc_learning rate": get_lr(enc_scheduler),
                "train/dec_learning rate": get_lr(dec_scheduler)
            }
            if args.exclude_domain is not True:
                print("[%d/%d] [%d/%d] mean_loss : %.3f, state_loss : %.3f, gen_loss : %.3f, dom_loss : %.3f" \
                      % (epoch+1, args.n_epochs, step,
                         len(train_dataloader), np.mean(batch_loss),
                         loss_s.item(), loss_g.item(), loss_d.item()))
                log_payload["train/dom_loss"] = loss_d.item()
            else:
                print("[%d/%d] [%d/%d] mean_loss : %.3f, state_loss : %.3f, gen_loss : %.3f" \
                      % (epoch+1, args.n_epochs, step,
                         len(train_dataloader), np.mean(batch_loss),
                         loss_s.item(), loss_g.item()))
            wandb.log(log_payload)
            # The running mean covers only the steps since the previous report.
            batch_loss = []

    if (epoch+1) % args.eval_epoch == 0:
        eval_res = model_evaluation(model, dev_data_raw, tokenizer, slot_meta, epoch+1, args.op_code)
        wandb.log({
            'eval/epoch': epoch+1,
            'eval/joint_acc': eval_res['joint_acc'],
            'eval/slot_acc':eval_res['slot_acc'],
            'eval/slot_f1': eval_res['slot_f1'],
            'eval/op_acc': eval_res['op_acc'],
            'eval/op_f1': eval_res['op_f1'],
            'eval/final_slot_f1':eval_res['final_slot_f1'],
        })
        if eval_res['joint_acc'] > best_score['joint_acc']:
            best_score = eval_res
            # Save the unwrapped module so the keys carry no DataParallel 'module.' prefix
            # and can be loaded straight into a plain SomDST instance.
            model_to_save = model.module if hasattr(model, 'module') else model
            output_path = f"checkpoint/{output_dir}/model_best.bin"
            torch.save({
                    'epoch': epoch,
                    'model_state_dict': model_to_save.state_dict(),
                    'enc_optimizer_state_dict': enc_optimizer.state_dict(),
                    'dec_optimizer_state_dict': dec_optimizer.state_dict(),
                    # Stored as a plain float so the checkpoint holds no device tensor for the loss.
                    'loss': loss.item(),
                    }, output_path)
        print("Best Score : ", best_score)
        print("\n")

# Final checkpoint after the last epoch, again from the unwrapped module.
model_to_save = model.module if hasattr(model, 'module') else model
torch.save({
            'epoch': epoch,
            'model_state_dict': model_to_save.state_dict(),
            'enc_optimizer_state_dict': enc_optimizer.state_dict(),
            'dec_optimizer_state_dict': dec_optimizer.state_dict(),
            'loss': loss.item(),
            }, f"checkpoint/{output_dir}/model_last.bin")

  0%|          | 0/1417 [00:00<?, ?it/s]

Let's Do the Training!
domain_scores: tensor(False, device='cuda:0')
state_scores: tensor(False, device='cuda:0')
decoder_input: tensor(False, device='cuda:0')
seqeunce_output: tensor(False, device='cuda:0')
pooled_output: tensor(False, device='cuda:0')
gen_scores: tensor(False, device='cuda:0')


  0%|          | 1/1417 [00:02<1:01:19,  2.60s/it]

[1/100] [0/1417] mean_loss : 0.188, state_loss : 0.009, gen_loss : 0.180
domain_scores: tensor(False, device='cuda:0')
state_scores: tensor(True, device='cuda:0')
decoder_input: tensor(True, device='cuda:0')
seqeunce_output: tensor(True, device='cuda:0')
pooled_output: tensor(True, device='cuda:0')


  0%|          | 1/1417 [00:03<1:20:41,  3.42s/it]

gen_scores: tensor(True, device='cuda:0')





RuntimeError: Function 'LogBackward' returned nan values in its 0th output.

In [18]:
###### TEST DATA EVALUATION ######
# Rebuild the model and both optimizers, then restore their state from a saved checkpoint.
print("Test using best model...")
# best_epoch = best_score['epoch']
# ckpt_path = os.path.join(args.save_dir, 'model_best.bin')
# NOTE(review): hard-coded run directory — point this at the run you want to evaluate.
ckpt_path = os.path.join('checkpoint/pretty-serenity-42/model_best.bin')
model = SomDST(model_config, len(op2id), len(domain2id), op2id['update'], args.exclude_domain)
######## ENC/DEC OPTIMIZER ########
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
enc_param_optimizer = list(model.encoder.named_parameters())
enc_optimizer_grouped_parameters = [
    {'params': [p for n, p in enc_param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
    {'params': [p for n, p in enc_param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
    ]

# enc_optimizer = AdamW(enc_optimizer_grouped_parameters, lr=args.enc_lr)
# NOTE(review): lr=0 here; presumably overwritten by the optimizer state loaded below — confirm.
enc_optimizer = AdamW(enc_optimizer_grouped_parameters, lr=0)
# enc_scheduler = get_linear_schedule_with_warmup(enc_optimizer, num_warmup_steps=int(num_train_steps * args.enc_warmup),
#                                      num_training_steps=num_train_steps)
enc_scheduler = get_linear_schedule_with_warmup(enc_optimizer, num_training_steps=num_train_steps, num_warmup_steps=0)

dec_param_optimizer = list(model.decoder.parameters())
dec_optimizer = AdamW(dec_param_optimizer, lr=args.dec_lr)
# dec_scheduler = get_linear_schedule_with_warmup(dec_optimizer, num_warmup_steps=int(num_train_steps * args.dec_warmup),
#                                     num_training_steps=num_train_steps)
dec_scheduler = get_linear_schedule_with_warmup(dec_optimizer, num_training_steps=num_train_steps, num_warmup_steps=0)
model.to(device)
# Map the stored tensors onto the active device, so a checkpoint written on a GPU
# can also be loaded when the notebook has fallen back to the CPU.
ckpt = torch.load(ckpt_path, map_location=device)

# model.load_state_dict(ckpt)
model.load_state_dict(ckpt['model_state_dict'])

enc_optimizer.load_state_dict(ckpt['enc_optimizer_state_dict'])
dec_optimizer.load_state_dict(ckpt['dec_optimizer_state_dict'])

# NOTE(review): the model is left in training mode (dropout active). That fits resuming
# training; for evaluation confirm that model_evaluation switches to eval mode itself.
model.train()
# model_evaluation(model, test_data_raw, tokenizer, slot_meta, 999, args.op_code,
#                  is_gt_op=False, is_gt_p_state=False, is_gt_gen=False)
# model_evaluation(model, test_data_raw, tokenizer, slot_meta, best_epoch, args.op_code,
#                  is_gt_op=False, is_gt_p_state=False, is_gt_gen=True)
# model_evaluation(model, test_data_raw, tokenizer, slot_meta, best_epoch, args.op_code,
#                  is_gt_op=False, is_gt_p_state=True, is_gt_gen=False)
# model_evaluation(model, test_data_raw, tokenizer, slot_meta, best_epoch, args.op_code,
#                  is_gt_op=False, is_gt_p_state=True, is_gt_gen=True)
# model_evaluation(model, test_data_raw, tokenizer, slot_meta, best_epoch, args.op_code,
#                  is_gt_op=True, is_gt_p_state=False, is_gt_gen=False)
# model_evaluation(model, test_data_raw, tokenizer, slot_meta, best_epoch, args.op_code,
#                  is_gt_op=True, is_gt_p_state=True, is_gt_gen=False)
# model_evaluation(model, test_data_raw, tokenizer, slot_meta, best_epoch, args.op_code,
#                  is_gt_op=True, is_gt_p_state=False, is_gt_gen=True)
# model_evaluation(model, test_data_raw, tokenizer, slot_meta, best_epoch, args.op_code,
#                  is_gt_op=True, is_gt_p_state=True, is_gt_gen=True)

Test using best model...


Special tokens have been added in the vocabulary, make sure the associated word embedding are fine-tuned or trained.


SomDST(
  (encoder): Encoder(
    (bert): BertModel(
      (embeddings): BertEmbeddings(
        (word_embeddings): Embedding(35003, 768)
        (position_embeddings): Embedding(512, 768)
        (token_type_embeddings): Embedding(2, 768)
        (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (encoder): BertEncoder(
        (layer): ModuleList(
          (0): BertLayer(
            (attention): BertAttention(
              (self): BertSelfAttention(
                (query): Linear(in_features=768, out_features=768, bias=True)
                (key): Linear(in_features=768, out_features=768, bias=True)
                (value): Linear(in_features=768, out_features=768, bias=True)
                (dropout): Dropout(p=0.1, inplace=False)
              )
              (output): BertSelfOutput(
                (dense): Linear(in_features=768, out_features=768, bias=True)
                (LayerNorm): LayerNorm(

In [28]:
# Load the saved predictions, closing the file once it has been parsed.
with open('preds_999.json', encoding='utf-8') as pred_file:
    pred = json.load(pred_file)

In [29]:
# Keep only the first element of every prediction entry.
predictions = {key: value[0] for key, value in pred.items()}

# Write through a context manager so the file is flushed and closed, with an explicit
# UTF-8 encoding because ensure_ascii=False emits non-ASCII characters verbatim.
# NOTE(review): the payload is JSON although the file name ends in .csv; the name is kept unchanged.
# open(f"{args.output_dir}/predictions.csv", "w"),
with open("predictions.csv", "w", encoding="utf-8") as predictions_file:
    json.dump(
        predictions,
        predictions_file,
        indent=2,
        ensure_ascii=False,
    )