In [1]:
import sys
sys.path.insert(0,'./data')

import random
import time
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
from torch import nn
from experiment.utils import *
#from models.decoder_with_attention import DecoderWithAttention
from models.decoder_with_attention_final import DecoderWithAttention
from transformers import (WEIGHTS_NAME, BertConfig,
                                  BertForSequenceClassification, BertTokenizer,
                                  )
from load_datasets_final import CaptionDataset
from create_inputs_utils import (CaptionProcessor,)
import create_inputs_utils_test

#from data.load_datasets_final import *
# from data.create_inputs_utils import *

from experiment._train_one_epoch import *
from experiment._validation_one_epoch import *

I1130 15:32:40.029527 140544855676736 file_utils.py:39] PyTorch version 0.4.1 available.
I1130 15:32:41.030264 140544855676736 configuration_utils.py:151] loading configuration file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-config.json from cache at /root/.cache/torch/transformers/4dad0251492946e18ac39290fcfe91b89d370fee250efe9521476438fe8ca185.bf3b9ea126d8c0001ee8a1e8b92229871d06d36d8808208cc2449280da87785c
I1130 15:32:41.031545 140544855676736 configuration_utils.py:168] Model config {
  "attention_probs_dropout_prob": 0.1,
  "finetuning_task": "sst-2",
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "num_labels": 1,
  "output_attentions": false,
  "output_hidden_states": false,
  "pruned_heads": {},
  "torchscript": false,
  "type_vocab_size": 2,
  "use_

In [2]:
def _get_available_devices():
    from tensorflow.python.client import device_lib
    local_device_protos = device_lib.list_local_devices()
    return [x.name for x in local_device_protos]

print(_get_available_devices())

['/device:CPU:0', '/device:XLA_CPU:0', '/device:XLA_GPU:0', '/device:XLA_GPU:1', '/device:GPU:0', '/device:GPU:1']


In [3]:
# Data parameters
data_folder = 'preprocessed_data'  # folder with data files saved by create_input_files.py
data_name = 'preprocessed_coco'  # base name shared by data files

# Model parameters
emb_dim = 768  # word-embedding size; set to BERT's embedding size (768 for bert-base), previously 1024
attention_dim = 384  # size of the attention linear layers (previously 1024)
decoder_dim = 384  # hidden size of the decoder RNN (previously 1024)
dropout = 0.5
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # sets device for model and PyTorch tensors
device_0 = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")  # explicitly the first GPU (or CPU)

cudnn.benchmark = True  # set to true only if inputs to model are fixed size; otherwise lot of computational overhead

# Reproducibility: seed every RNG used for weight init, dropout and data handling.
# Note: cudnn.benchmark=True still lets cuDNN pick kernels nondeterministically,
# so runs are repeatable in initialisation but not guaranteed bit-identical.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

# Training parameters
start_epoch = 0
epochs = 40  # number of epochs to train for (if early stopping is not triggered)
epochs_since_improvement = 0  # keeps track of number of epochs since there's been an improvement in validation BLEU
batch_size = 20
workers = 0  # DataLoader worker processes; 0 = load in the main process (multi-worker loading is unreliable with h5py)
best_bleu4 = 0.  # best validation BLEU-4 seen so far
# print_freq = 100  # print training/validation stats every __ batches
#checkpoint = 'ckpt/BEST_1checkpoint_preprocessed_coco.pth.tar'  # path to checkpoint, None if none
checkpoint = None  # path to checkpoint to resume from, None to train from scratch
BERT_VOCA_SIZE = 30522  # size of the bert-base-uncased vocabulary (matches "vocab_size" in the model config)

In [None]:
# The JSON word map from the original pipeline is not used: the decoder's
# output vocabulary is the BERT tokenizer vocabulary (BERT_VOCA_SIZE entries).

# Initialize / load checkpoint.
# `checkpoint` stays the *path* (or None); the loaded dict goes into its own
# name so that re-running this cell never calls torch.load() on a dict.
if checkpoint is None:
    decoder = DecoderWithAttention(attention_dim=attention_dim,
                                   embed_dim=emb_dim,
                                   decoder_dim=decoder_dim,
                                   vocab_size=BERT_VOCA_SIZE,
                                   dropout=dropout)

    decoder_optimizer = torch.optim.Adamax(params=filter(lambda p: p.requires_grad, decoder.parameters()))

else:
    # With CUDA, restore tensors to the device they were saved on, so the
    # pickled optimizer state stays alongside the parameters. Without CUDA,
    # map everything to CPU so a GPU-saved checkpoint can still be opened.
    checkpoint_state = torch.load(checkpoint,
                                  map_location=None if device.type == 'cuda' else 'cpu')
    start_epoch = checkpoint_state['epoch'] + 1
    epochs_since_improvement = checkpoint_state['epochs_since_improvement']
    # NOTE(review): save_checkpoint below receives recent_bleu4, so 'bleu-4' is
    # presumably that epoch's score, not the best so far. Resume from the BEST_
    # checkpoint to get a correct best_bleu4 — confirm in experiment.utils.
    best_bleu4 = checkpoint_state['bleu-4']
    decoder = checkpoint_state['decoder']
    decoder_optimizer = checkpoint_state['decoder_optimizer']


# Move to GPU, if available
decoder = decoder.to(device)

# Loss functions: cross-entropy over the vocabulary plus a multi-label margin loss
criterion_ce = nn.CrossEntropyLoss().to(device)
criterion_dis = nn.MultiLabelMarginLoss().to(device)

# Custom dataloaders. Shuffling was deliberately removed (DataLoader defaults to
# shuffle=False), so batches arrive in dataset order every epoch.
train_loader = torch.utils.data.DataLoader(
    CaptionDataset(data_folder, data_name, 'TRAIN'),
    batch_size=batch_size, num_workers=workers, pin_memory=True)
val_loader = torch.utils.data.DataLoader(
    CaptionDataset(data_folder, data_name, 'VAL'),
    batch_size=batch_size, num_workers=workers, pin_memory=True)


for epoch in range(start_epoch, epochs):

    # Early stopping: quit after 20 consecutive epochs without a BLEU-4 gain;
    # decay the learning rate (factor 0.8) after every 8 such epochs.
    if epochs_since_improvement == 20:
        break
    if epochs_since_improvement > 0 and epochs_since_improvement % 8 == 0:
        adjust_learning_rate(decoder_optimizer, 0.8)

    # One epoch's training
    train(train_loader=train_loader,
          decoder=decoder,
          criterion_ce=criterion_ce,
          criterion_dis=criterion_dis,
          decoder_optimizer=decoder_optimizer,
          epoch=epoch)

    # One epoch's validation; returns the BLEU-4 score used for model selection
    recent_bleu4 = validate(val_loader=val_loader,
                            decoder=decoder,
                            criterion_ce=criterion_ce,
                            criterion_dis=criterion_dis,)

    # Check if there was an improvement
    is_best = recent_bleu4 > best_bleu4
    best_bleu4 = max(recent_bleu4, best_bleu4)
    if not is_best:
        epochs_since_improvement += 1
        print("\nEpochs since last best performance: %d\n" % (epochs_since_improvement,))
    else:
        epochs_since_improvement = 0

    # Save checkpoint (is_best presumably also writes the BEST_ copy)
    save_checkpoint(data_name, epoch, epochs_since_improvement, decoder, decoder_optimizer, recent_bleu4, is_best)

I1130 15:32:43.362420 140544855676736 configuration_utils.py:151] loading configuration file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-config.json from cache at /root/.cache/torch/transformers/4dad0251492946e18ac39290fcfe91b89d370fee250efe9521476438fe8ca185.bf3b9ea126d8c0001ee8a1e8b92229871d06d36d8808208cc2449280da87785c
I1130 15:32:43.363126 140544855676736 configuration_utils.py:168] Model config {
  "attention_probs_dropout_prob": 0.1,
  "finetuning_task": "sst-2",
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "num_labels": 1,
  "output_attentions": false,
  "output_hidden_states": false,
  "pruned_heads": {},
  "torchscript": false,
  "type_vocab_size": 2,
  "use_bfloat16": false,
  "vocab_size": 30522
}

I1130 15:32:44.200376 140544855676736 tokeniza

Epoch: [0][0/28322]	Batch Time 0.320 (0.320)	Data Load Time 0.038 (0.038)	Loss 253.8363 (253.8363)	Top-5 Accuracy 0.000 (0.000)
Epoch: [0][100/28322]	Batch Time 0.156 (0.161)	Data Load Time 0.011 (0.014)	Loss 54.3673 (55.7030)	Top-5 Accuracy 30.290 (27.491)
Epoch: [0][200/28322]	Batch Time 0.148 (0.159)	Data Load Time 0.005 (0.013)	Loss 28.1113 (53.0845)	Top-5 Accuracy 32.895 (29.211)
Epoch: [0][300/28322]	Batch Time 0.172 (0.159)	Data Load Time 0.014 (0.013)	Loss 64.1638 (51.4028)	Top-5 Accuracy 23.293 (29.824)
Epoch: [0][400/28322]	Batch Time 0.135 (0.160)	Data Load Time 0.011 (0.013)	Loss 24.6474 (49.6953)	Top-5 Accuracy 32.759 (30.644)
Epoch: [0][500/28322]	Batch Time 0.140 (0.159)	Data Load Time 0.015 (0.013)	Loss 32.3378 (44.9378)	Top-5 Accuracy 30.263 (31.066)
Epoch: [0][600/28322]	Batch Time 0.147 (0.159)	Data Load Time 0.017 (0.013)	Loss 17.7505 (41.3620)	Top-5 Accuracy 35.169 (31.622)
Epoch: [0][700/28322]	Batch Time 0.153 (0.158)	Data Load Time 0.009 (0.013)	Loss 15.7754 (37