Skip to content

Commit

Permalink
Fix HAN imports and args
Browse files Browse the repository at this point in the history
  • Loading branch information
achyudh committed Mar 16, 2019
1 parent e801957 commit 3c52648
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 11 deletions.
11 changes: 5 additions & 6 deletions models/han/__main__.py
@@ -1,6 +1,6 @@
from copy import deepcopy
import logging
import random
from copy import deepcopy

import numpy as np
import torch
Expand All @@ -10,12 +10,12 @@
from common.train import TrainerFactory
from datasets.aapd import AAPDHierarchical as AAPD
from datasets.imdb import IMDBHierarchical as IMDB
from datasets.reuters import ReutersHierarchical as Reuters
from datasets.sst import SST1
from datasets.sst import SST2
from datasets.reuters import ReutersHierarchical as Reuters
from datasets.yelp2014 import Yelp2014Hierarchical as Yelp2014
from models.han import get_args
from models.han import HAN
from datasets.yelp2014 import Yelp2014Hierarchical as Yelp2014
from models.han.model import HAN
from models.han.args import get_args


class UnknownWordVecCache(object):
Expand Down Expand Up @@ -136,7 +136,6 @@ def evaluate_dataset(split_name, dataset_cls, model, embedding, loader, batch_si
'optimizer': optimizer,
'batch_size': args.batch_size,
'log_interval': args.log_every,
'dev_log_interval': args.dev_every,
'patience': args.patience,
'model_outfile': args.save_path, # actually a directory, using model_outfile to conform to Trainer naming convention
'logger': logger,
Expand Down
5 changes: 2 additions & 3 deletions models/han/args.py
Expand Up @@ -29,9 +29,8 @@ def get_args():
parser.add_argument('--epoch_decay', type=int, default=15)
parser.add_argument('--data_dir', help='word vectors directory',
default=os.path.join(os.pardir, 'Castor-data', 'datasets'))
parser.add_argument('--word_vectors_dir', help='word vectors directory',
default=os.path.join(os.pardir, 'Castor-data', 'embeddings', 'word2vec'))
parser.add_argument('--word_vectors_file', help='word vectors filename', default='/data/GoogleNews-vectors-negative300.txt')
parser.add_argument('--word-vectors-dir', default=os.path.join(os.pardir, 'Castor-data', 'embeddings', 'word2vec'))
parser.add_argument('--word-vectors-file', default='GoogleNews-vectors-negative300.txt')
parser.add_argument('--trained_model', type=str, default="")
parser.add_argument('--weight_decay', type=float, default=0)
parser.add_argument('--onnx', action='store_true', default=False, help='Export model in ONNX format')
Expand Down
4 changes: 2 additions & 2 deletions models/han/model.py
@@ -1,8 +1,8 @@
import torch
import torch.nn as nn

from models.han import SentLevelRNN
from models.han import WordLevelRNN
from models.han.sent_level_rnn import SentLevelRNN
from models.han.word_level_rnn import WordLevelRNN


class HAN(nn.Module):
Expand Down

0 comments on commit 3c52648

Please sign in to comment.