In [1]:
"""
Copyright 2018 NAVER Corp.
All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

1. Redistributions of source code must retain the above copyright
   notice, this list of conditions and the following disclaimer.

2. Redistributions in binary form must reproduce the above copyright
   notice, this list of conditions and the following disclaimer in the
   documentation and/or other materials provided with the distribution.

3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America
   and IDIAP Research Institute nor the names of its contributors may be
   used to endorse or promote products derived from this software without
   specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
POSSIBILITY OF SUCH DAMAGE.
"""
import argparse
import time
from datetime import datetime
import numpy as np
import random
import json
import logging
import torch
import os, sys
parentPath = os.path.abspath("..")
sys.path.insert(0, parentPath)# add parent folder to path so as to import common modules
from helper import timeSince, sent2indexes, indexes2sent, gData, gVar
import models, experiments, configs, data
from experiments import Metrics
from sample import evaluate

from tensorboardX import SummaryWriter # install tensorboardX (pip install tensorboardX) before importing this package

parser = argparse.ArgumentParser(description='DialogWAE Pytorch')
# Path Arguments
parser.add_argument('--data_path', type=str, default='/home/prakhar/Desktop/DialogWAE/data/', help='location of the data corpus')
parser.add_argument('--dataset', type=str, default='DailyDial', help='name of dataset. SWDA or DailyDial')
parser.add_argument('--model', type=str, default='DialogWAE_GMP', help='model name')
parser.add_argument('--expname', type=str, default='basic', help='experiment name, for disinguishing different parameter settings')
parser.add_argument('--visual', action='store_true', default=False, help='visualize training status in tensorboard')
parser.add_argument('--reload_from', type=int, default=-1, help='reload from a trained ephoch')
parser.add_argument('--gpu_id', type=int, default=0, help='GPU ID')

# Evaluation Arguments
parser.add_argument('--sample', action='store_true', help='sample when decoding for generation')
parser.add_argument('--log_every', type=int, default=100, help='interval to log training results')
parser.add_argument('--valid_every', type=int, default=400, help='interval to validation')
parser.add_argument('--eval_every', type=int, default=2, help='interval to evaluate on the validation set')
parser.add_argument('--seed', type=int, default=1111, help='random seed')

args = parser.parse_args()
print(vars(args))

# make output directory if it doesn't already exist
if not os.path.isdir('./output'):
    os.makedirs('./output')
if not os.path.isdir('./output/{}'.format(args.model)):
    os.makedirs('./output/{}'.format(args.model))
if not os.path.isdir('./output/{}/{}'.format(args.model, args.expname)):
    os.makedirs('./output/{}/{}'.format(args.model, args.expname))
if not os.path.isdir('./output/{}/{}/{}'.format(args.model, args.expname, args.dataset)):
    os.makedirs('./output/{}/{}/{}'.format(args.model, args.expname, args.dataset))
if not os.path.isdir('./output/{}/{}/{}/models'.format(args.model, args.expname, args.dataset)):
    os.makedirs('./output/{}/{}/{}/models'.format(args.model, args.expname, args.dataset))
if not os.path.isdir('./output/{}/{}/{}/tmp_results'.format(args.model, args.expname, args.dataset)):
    os.makedirs('./output/{}/{}/{}/tmp_results'.format(args.model, args.expname, args.dataset))

# save arguments
json.dump(vars(args), open('./output/{}/{}/{}/args.json'.format(args.model, args.expname, args.dataset), 'w'))

# LOG #
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.DEBUG, format="%(message)s")#,format="%(asctime)s: %(name)s: %(levelname)s: %(message)s")
fh = logging.FileHandler("./output/{}/{}/{}/logs.txt".format(args.model, args.expname, args.dataset))
                                  # create file handler which logs even debug messages
logger.addHandler(fh)# add the handlers to the logger

# Set the random seed manually for reproducibility.
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
use_cuda = torch.cuda.is_available()
if use_cuda:
    torch.cuda.set_device(args.gpu_id) # set gpu device
    torch.cuda.manual_seed(args.seed)

def save_model(model, epoch):
    print("Saving models")
    torch.save(f='./output/{}/{}/{}/models/model_epo{}.pckl'.format(args.model, args.expname, args.dataset, epoch),obj=model)
def load_model(epoch):
    print("Loading models")
    model = torch.load(f='./output/{}/{}/{}/models/model_epo{}.pckl'.format(args.model, args.expname, args.dataset, epoch))
    return model

config = getattr(configs, 'config_'+args.model)()

###############################################################################
# Load data
###############################################################################
data_path=args.data_path+args.dataset+'/'
#
corpus = getattr(data, args.dataset+'Corpus')(data_path, wordvec_path='/media/prakhar/Local Disk/glove/glove.twitter.27B.200d.txt', wordvec_dim=config['emb_size'])
dials = corpus.get_dialogs()
metas = corpus.get_metas()
train_dial, valid_dial, test_dial = dials.get("train"), dials.get("valid"), dials.get("test")
train_meta, valid_meta, test_meta = metas.get("train"), metas.get("valid"), metas.get("test")
train_loader = getattr(data, args.dataset+'DataLoader')("Train", train_dial, train_meta, config['maxlen'])
valid_loader = getattr(data, args.dataset+'DataLoader')("Valid", valid_dial, valid_meta, config['maxlen'])
test_loader = getattr(data, args.dataset+'DataLoader')("Test", test_dial, test_meta, config['maxlen'])

vocab = corpus.ivocab
ivocab = corpus.vocab
n_tokens = len(ivocab)

metrics=Metrics(corpus.word2vec)

print("Loaded data!")

###############################################################################
# Define the models
###############################################################################

model = getattr(models, args.model)(config, n_tokens) if args.reload_from<0 else load_model(args.reload_from)
if use_cuda:
    model=model.cuda()
    
if corpus.word2vec is not None and args.reload_from<0:
    print("Loaded word2vec")
    model.embedder.weight.data.copy_(torch.from_numpy(corpus.word2vec))
    model.embedder.weight.data[0].fill_(0)

tb_writer = SummaryWriter("./output/{}/{}/{}/logs/".format(args.model, args.expname, args.dataset)\
                          +datetime.now().strftime('%Y%m%d%H%M')) if args.visual else None


usage: ipykernel_launcher.py [-h] [--data_path DATA_PATH] [--dataset DATASET]
                             [--model MODEL] [--expname EXPNAME] [--visual]
                             [--reload_from RELOAD_FROM] [--gpu_id GPU_ID]
                             [--sample] [--log_every LOG_EVERY]
                             [--valid_every VALID_EVERY]
                             [--eval_every EVAL_EVERY] [--seed SEED]
ipykernel_launcher.py: error: unrecognized arguments: -f /run/user/1000/jupyter/kernel-2530a822-ad6c-45fd-9c1e-26838f6e848c.json


SystemExit: 2

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)


[(['<s>', '<d>', '</s>'], 0, None), (['<s>', 'what', 'a', 'nice', 'day', '!', '</s>'], 1, '1'), (['<s>', 'yes', '.', 'how', 'about', 'going', 'out', 'and', 'enjoying', 'the', 'sunshine', 'on', 'the', 'grass', '?', '</s>'], 0, '3'), (['<s>', 'great', ',', 'let', "'", 's', 'go', '!', '</s>'], 1, '4'), (['<s>', 'hey', ',', 'darling', ',', 'i', 'think', 'i', 'might', 'have', 'a', 'little', 'heatstroke', 'from', 'being', 'in', 'the', 'sun', 'all', 'day', '.', 'i', 'was', 'so', 'relaxed', '.', 'it', 'felt', 'as', 'if', 'i', 'were', 'in', 'another', 'world', '.', '</s>'], 0, '1'), (['<s>', 'exactly', '.', 'you', 'know', ',', 'the', 'sunshine', 'and', 'wind', 'remind', 'me', 'of', 'our', 'honeymoon', '.', 'you', 'remember', '?', 'the', 'island', ',', 'the', 'sound', 'of', 'the', 'waves', ',', 'the', 'salty', 'sea', 'air', 'and', 'the', 'sunshine', '...', '</s>'], 1, '1'), (['<s>', 'yes', ',', 'it', 'was', 'wonderful', 'but', 'it', "'", 's', 'already', 'been', 'a', 'year', '.', 'how', 'time', '

In [21]:
er[0]

[[(['<s>', '<d>', '</s>'], 0, None),
  (['<s>',
    'hey',
    'man',
    ',',
    'you',
    'wanna',
    'buy',
    'some',
    'weed',
    '?',
    '</s>'],
   1,
   '3'),
  (['<s>', 'some', 'what', '?', '</s>'], 0, '2'),
  (['<s>',
    'weed',
    '!',
    'you',
    'know',
    '?',
    'pot',
    ',',
    'ganja',
    ',',
    'mary',
    'jane',
    'some',
    'chronic',
    '!',
    '</s>'],
   1,
   '3'),
  (['<s>', 'oh', ',', 'umm', ',', 'no', 'thanks', '.', '</s>'], 0, '4'),
  (['<s>',
    'i',
    'also',
    'have',
    'blow',
    'if',
    'you',
    'prefer',
    'to',
    'do',
    'a',
    'few',
    'lines',
    '.',
    '</s>'],
   1,
   '3'),
  (['<s>', 'no', ',', 'i', 'am', 'ok', ',', 'really', '.', '</s>'], 0, '4'),
  (['<s>',
    'come',
    'on',
    'man',
    '!',
    'i',
    'even',
    'got',
    'dope',
    'and',
    'acid',
    '!',
    'try',
    'some',
    '!',
    '</s>'],
   1,
   '3'),
  (['<s>',
    'do',
    'you',
    'really',
    'have',
   