In [1]:
import sys
import os
import time
import importlib
import argparse

import numpy as np

import torch
from torch import nn, optim

from data import MonoTextData
from modules import VAE
from modules import LSTMEncoder, LSTMDecoder
from logger import Logger

In [2]:
# Table-driven CLI definition. Each entry is (flag, keyword arguments passed to
# parser.add_argument); entries are registered in the order written here.
_ARGUMENT_SPECS = [
    # model hyperparameters
    ('--dataset', dict(type=str, default='dataset', help='dataset to use')),

    # optimization parameters
    ('--momentum', dict(type=float, default=0, help='sgd momentum')),
    ('--nsamples', dict(type=int, default=1, help='number of samples for training')),
    ('--iw_nsamples', dict(
        type=int,
        default=500,
        help='number of samples to compute importance weighted estimate',
    )),

    # mode selection
    ('--eval', dict(action='store_true', default=False, help='compute iw nll')),
    ('--load_path', dict(type=str, default='')),

    # decoding
    ('--decode_from', dict(type=str, default="", help="pretrained model path")),
    ('--decoding_strategy', dict(
        type=str,
        choices=["greedy", "beam", "sample"],
        default="greedy",
    )),
    ('--decode_input', dict(
        type=str,
        default="",
        help="input text file to perform reconstruction",
    )),

    # annealing parameters
    ('--warm_up', dict(type=int, default=10, help="number of annealing epochs")),
    ('--kl_start', dict(type=float, default=1.0, help="starting KL weight")),

    # inference parameters
    ('--aggressive', dict(
        type=int,
        default=0,
        help='apply aggressive training when nonzero, reduce to vanilla VAE when aggressive is 0',
    )),

    # others
    ('--seed', dict(type=int, default=783435, metavar='S', help='random seed')),

    # slurm bookkeeping, used when naming the saved model
    ('--jobid', dict(type=int, default=0, help='slurm job id')),
    ('--taskid', dict(type=int, default=0, help='slurm task id')),
]

parser = argparse.ArgumentParser(description='VAE mode collapse study')
for _flag, _options in _ARGUMENT_SPECS:
    parser.add_argument(_flag, **_options)

# Parse an explicit empty argument list (not sys.argv), so every option takes
# its declared default.
args = parser.parse_args([])

In [3]:
# Bare expression: the notebook renders the parsed Namespace as this cell's output.
args

Namespace(aggressive=0, dataset='dataset', decode_from='', decode_input='', decoding_strategy='greedy', eval=False, iw_nsamples=500, jobid=0, kl_start=1.0, load_path='', momentum=0, nsamples=1, seed=783435, taskid=0, warm_up=10)

In [4]:

# Run configuration: device flag, output locations, run identifier and seeding.

# Record on args whether torch reports an available CUDA device.
args.cuda = torch.cuda.is_available()

# Per-dataset output directories. Built with os.path.join so the separators
# are consistent with the file paths joined below on every platform.
save_dir = os.path.join("models", args.dataset)
log_dir = os.path.join("logs", args.dataset)

# exist_ok=True makes creation idempotent and avoids the check-then-create race.
os.makedirs(save_dir, exist_ok=True)
os.makedirs(log_dir, exist_ok=True)

# The seed is selected by task id and unconditionally replaces args.seed.
# args.taskid must be a valid index into this 10-entry list.
seed_set = [783435, 101, 202, 303, 404, 505, 606, 707, 808, 909]
args.seed = seed_set[args.taskid]

# Run identifier shared by the checkpoint and log file names.
id_ = "%s_aggressive%d_kls%.2f_warm%d_%d_%d_%d" % \
        (args.dataset, args.aggressive, args.kl_start,
         args.warm_up, args.jobid, args.taskid, args.seed)

save_path = os.path.join(save_dir, id_ + '.pt')

args.save_path = save_path
print("save path", args.save_path)

args.log_path = os.path.join(log_dir, id_ + ".log")
print("log path", args.log_path)

# Config-file loading is currently disabled; kept here for reference.
#     # load config file into args
#     config_file = "config.config_%s" % args.dataset
#     params = importlib.import_module(config_file).params
#     args = argparse.Namespace(**vars(args), **params)

#     if 'label' in params:
#         args.label = params['label']
#     else:
#         args.label = False

# Seed numpy and torch (plus the CUDA generator when a device is available)
# with the selected seed, and request deterministic cuDNN behaviour.
np.random.seed(args.seed)
torch.manual_seed(args.seed)
if args.cuda:
    torch.cuda.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.deterministic = True



save path models/dataset\dataset_aggressive0_kls1.00_warm10_0_0_783435.pt
log path logs/dataset\dataset_aggressive0_kls1.00_warm10_0_0_783435.log


In [5]:
# Record the training-corpus path on args, then construct MonoTextData from it.
train_path = "datasets/yahoo_data/yahoo.train.txt"
args.train_data = train_path
train_data = MonoTextData(train_path)

In [8]:
# Bare expression: display the vocab attribute of the loaded training data.
train_data.vocab

<data.text_data.VocabEntry at 0x23b550103d0>

In [14]:
# Peek at the first few lines of the raw training text, printing each line as
# its list of whitespace-separated tokens.
PREVIEW_LINES = 3  # number of leading lines to print

# fname="datasets/yahoo_data/yahoo.train.txt"
fname="../../Data/train.txt"
with open(fname, encoding='UTF-8') as fin:
    # enumerate supplies the line index, so no hand-maintained counter is needed.
    for line_no, line in enumerate(fin):
        if line_no >= PREVIEW_LINES:
            break
        split_line = line.split()
        print(split_line)

['"', 'The', 'thickness', 'of', 'the', 'dielectric', 'material', 'of', 'an', 'integrated', 'circuit', 'on', 'top', 'of', 'which', 'is', 'provided', 'a', 'semiconductor', 'layer,', 'is', 'selected', 'to', 'be', 'an', 'integer', 'multiple', 'of', 'one-half', 'the', 'wavelength', 'of', 'the', 'laser', 'light', 'in', 'the', 'dielectric', 'material', 'in', 'order', 'to', 'make', 'the', 'dielectric', 'material', 'layer', 'invisible', 'to', 'the', 'laser-trimming', 'light.', '"']
['Individually', 'produced', 'cams', 'and', 'journals', 'are', 'fastened', 'to', 'a', 'hollow', 'tube', 'to', 'form', 'a', 'camshaft', 'by', 'outwordly', 'deforming', 'the', 'tube', 'with', 'a', 'lost', 'mandrel', 'which', 'is', 'left', 'in', 'the', 'tube', 'to', 'form', 'a', 'seal.']
['"', 'Corrosion', 'inhibiting', 'calcium-containing', 'amorphous', 'precipitated', 'silica', 'is', 'described.', 'The', 'silica', 'is', 'prepared', 'by', 'admixing', 'simultaneously', 'in', 'a', 'reactor', 'aqueous', 'alkali', 'metal',