Skip to content
Branch: master
Find file Copy path
Find file Copy path
Fetching contributors…
Cannot retrieve contributors at this time
167 lines (148 sloc) 6.38 KB
import argparse
import os
import random
from torch import utils
from tqdm import tqdm
import numpy as np
import torch
import torch.nn as nn
import data
import model
class RandomSearch(object):
    """Iterate over every point of a parameter grid exactly once, in random order.

    Args:
        params: list of candidate-value lists, one per hyperparameter (the same
            format ``GridSearch`` accepts).
        rng: optional object with a ``shuffle`` method (e.g. ``random.Random(seed)``)
            used to order the points; ``None`` uses the module-level ``random``
            functions.
    """

    def __init__(self, params, rng=None):
        self.params = params
        self.rng = rng

    def __iter__(self):
        param_space = list(GridSearch(self.params))
        # Without the shuffle the points would come out in plain grid order,
        # which is what GridSearch already provides.
        shuffler = random if self.rng is None else self.rng
        shuffler.shuffle(param_space)
        for param in param_space:
            yield param


class GridSearch(object):
    """Iterate over the Cartesian product of several candidate-value lists.

    Each item is a list holding one value per entry of ``params``. Combinations
    come out in odometer order with the FIRST parameter varying fastest, e.g.
    ``[[1, 2], ["a", "b"]]`` yields ``[1, "a"], [2, "a"], [1, "b"], [2, "b"]``.
    An empty ``params`` yields a single empty combination; if any candidate list
    is empty there are no combinations. Calling ``iter()`` restarts from the
    first combination.
    """

    def __init__(self, params):
        self.params = params
        self.param_lengths = [len(param) for param in self.params]
        self._reset()

    def _reset(self):
        """Rewind the odometer to the first combination."""
        # 1-based position into each candidate list.
        self.indices = [1] * len(self.params)
        # An empty candidate list makes the product empty; stop before indexing it.
        self.stop_next = 0 in self.param_lengths

    def _update(self, carry_idx):
        """Advance the odometer starting at position ``carry_idx``.

        Returns True once every position from ``carry_idx`` on has wrapped
        around, i.e. the iteration is exhausted.
        """
        for idx in range(carry_idx, len(self.params)):
            if self.indices[idx] < self.param_lengths[idx]:
                self.indices[idx] += 1
                return False
            self.indices[idx] = 1  # wrap this digit and carry into the next one
        return True

    def __iter__(self):
        self._reset()
        return self

    def __next__(self):
        if self.stop_next:
            raise StopIteration
        result = [param[idx - 1] for param, idx in zip(self.params, self.indices)]
        self.stop_next = self._update(0)
        return result
def train(**kwargs):
# Train a model.ConvRNNModel on the SST splits and return the best dev accuracy
# recorded by the nested evaluate() (evaluate.best_dev).
# Reads from kwargs: mbatch_size, n_epochs, restore, quiet, lr, weight_decay,
# seed, no_cuda, input_file, output_file; the whole kwargs dict is also
# forwarded to model.ConvRNNModel, so it may consume further keys.
# NOTE(review): this copy was recovered from a scraped web page. Indentation is
# lost and several lines are truncated or missing (marked below), so it is not
# valid Python as shown. Restore it from version control before relying on it.
mbatch_size = kwargs["mbatch_size"]
n_epochs = kwargs["n_epochs"]
restore = kwargs["restore"]
verbose = not kwargs["quiet"]
lr = kwargs["lr"]
weight_decay = kwargs["weight_decay"]
seed = kwargs["seed"]
# NOTE(review): `seed` is never used in the visible lines; presumably a seeding
# call was lost - confirm.
if not kwargs["no_cuda"]:
# NOTE(review): the body of this `if` appears to be missing; the next line does
# not look CUDA-specific - confirm against the original.
embed_loader = data.SSTEmbeddingLoader("data")
if restore:
conv_rnn = torch.load(kwargs["input_file"])
# NOTE(review): as laid out here, conv_rnn is unconditionally rebuilt below and
# would overwrite a restored model; an `else:` was probably lost - confirm.
id_dict, weights, unk_vocab_list = embed_loader.load_embed_data()
word_model = model.SSTWordEmbeddingModel(id_dict, weights, unk_vocab_list)
if not kwargs["no_cuda"]:
conv_rnn = model.ConvRNNModel(word_model, **kwargs)
if not kwargs["no_cuda"]:
# NOTE(review): the two `if not kwargs["no_cuda"]:` lines above have lost their
# bodies or nesting; it is unclear which statements they guarded.
criterion = nn.CrossEntropyLoss()
# Only parameters with requires_grad=True are handed to the optimizer.
parameters = list(filter(lambda p: p.requires_grad, conv_rnn.parameters()))
optimizer = torch.optim.Adam(parameters, lr=lr, weight_decay=weight_decay)
train_set, dev_set, test_set = data.SSTDataset.load_sst_sets("data")
# The model supplies its own batch-collation function.
collate_fn = conv_rnn.convert_dataset
# NOTE(review): the three loader assignments below are truncated after `=`;
# presumably utils.data.DataLoader(<split>, ...) (`utils` is imported from torch).
# Dev/test use batch_size=len(split), i.e. the whole split in one batch.
train_loader =, shuffle=True, batch_size=mbatch_size, collate_fn=collate_fn)
dev_loader =, batch_size=len(dev_set), collate_fn=collate_fn)
test_loader =, batch_size=len(test_set), collate_fn=collate_fn)
# Score every batch of `loader`: compute loss and accuracy. When dev=True and
# accuracy >= the best seen so far (ties included), record it in
# evaluate.best_dev and save the model. Prints the metrics when verbose.
def evaluate(loader, dev=True):
for m_in, m_out in loader:
scores = conv_rnn(*m_in)
loss = criterion(scores, m_out).item()
# NOTE(review): the right-hand side of this comparison is truncated
# (presumably against m_out) - the line is incomplete as shown.
n_correct = (torch.max(scores, 1)[1].view(m_in[0].size(0)).data ==
accuracy = n_correct / m_in[0].size(0)
if dev and accuracy >= evaluate.best_dev:
evaluate.best_dev = accuracy
# NOTE(review): the save call was merged into this print; only the trailing
# `kwargs["output_file"])` argument survives - line is malformed as shown.
print("Saving best model ({})...".format(accuracy)), kwargs["output_file"])
if verbose:
print("{} set accuracy: {}, loss: {}".format("dev" if dev else "test", accuracy, loss))
# Best dev accuracy is kept as a function attribute on evaluate, reset per call of train().
evaluate.best_dev = 0
for epoch in range(n_epochs):
print("Epoch number: {}".format(epoch), end="\r")
if verbose:
i = 0
# zip() with tqdm(range(len(train_loader))) only drives a progress bar; `j` and
# `i` are unused in the visible lines.
for (j, (train_in, train_out)), _ in zip(enumerate(train_loader), tqdm(range(len(train_loader)))):
scores = conv_rnn(*train_in)
loss = criterion(scores, train_out)
# NOTE(review): no backward pass, optimizer step, or dev evaluation is visible,
# so as shown `optimizer` is never used and weights never change - lines lost?
# Whether the test evaluation below ran per epoch or once at the end cannot be
# told from the lost indentation.
evaluate(test_loader, dev=False)
return evaluate.best_dev
def do_random_search(given_params):
    """Train once per point of a fixed hyperparameter grid and report the best.

    Every combination of ``test_grid`` is trained (7 epochs, quiet) in the order
    produced by ``RandomSearch``; the combination with the highest dev accuracy
    returned by ``train`` is printed.

    Args:
        given_params: dict of base keyword arguments for ``train`` (typically the
            parsed command-line options). It is not modified; each trial uses a
            copy with the searched values overlaid.

    Returns:
        ``(best_params, best_dev_accuracy)``. ``best_params`` is the winning grid
        combination as a list, or ``None`` if no trial scored above 0.
    """
    # Columns: sf, gradient_clip, hidden_size, seed, fc_size, n_feature_maps.
    test_grid = [[0.15, 0.2], [4, 5, 6], [150, 200], [3, 4, 5], [200, 300], [200, 250]]
    max_params = None
    max_acc = 0.
    for args in RandomSearch(test_grid):
        sf, gc, hid, seed, fc_size, fmaps = args
        # NOTE(review): `sf` is searched over but never passed to train(), so each
        # sf value repeats an otherwise identical trial - confirm which option it
        # was meant to set.
        print("Testing {}".format(args))
        # Copy so the caller's dict is not mutated across trials.
        trial_params = dict(given_params)
        # "hidden_size" must match the --hidden_size option name; a differently
        # cased key would leave the default hidden size in place for every trial.
        trial_params.update(dict(n_epochs=7, quiet=True, gradient_clip=gc, hidden_size=hid,
                                 seed=seed, n_feature_maps=fmaps, fc_size=fc_size))
        dev_acc = train(**trial_params)
        print("Dev accuracy: {}".format(dev_acc))
        if dev_acc > max_acc:
            print("Found current max")
            max_acc = dev_acc
            max_params = args
    print("Best params: {}".format(max_params))
    return max_params, max_acc
def main(argv=None):
    """Parse command-line options, then either run the grid search or train once.

    Args:
        argv: optional list of argument strings; ``None`` (the default) parses
            ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--dev_per_epoch", default=9, type=int)
    parser.add_argument("--fc_size", default=200, type=int)
    parser.add_argument("--gpu_number", default=0, type=int)
    parser.add_argument("--hidden_size", default=150, type=int)
    # NOTE(review): "saves/" looks like a directory, yet train() hands input_file
    # straight to torch.load - confirm a file name was intended (same for output_file).
    parser.add_argument("--input_file", default="saves/", type=str)
    parser.add_argument("--lr", default=5E-4, type=float)
    parser.add_argument("--mbatch_size", default=16, type=int)
    parser.add_argument("--n_epochs", default=30, type=int)
    # A feature-map count is an integer (the search grid uses 200/250); parsing it
    # as float turned "--n_feature_maps 100" into 100.0.
    parser.add_argument("--n_feature_maps", default=200, type=int)
    parser.add_argument("--n_labels", default=5, type=int)
    parser.add_argument("--no_cuda", action="store_true", default=False)
    parser.add_argument("--output_file", default="saves/", type=str)
    parser.add_argument("--random_search", action="store_true", default=False)
    parser.add_argument("--restore", action="store_true", default=False)
    parser.add_argument("--rnn_type", choices=["lstm", "gru"], default="lstm", type=str)
    parser.add_argument("--seed", default=3, type=int)
    parser.add_argument("--quiet", action="store_true", default=False)
    parser.add_argument("--weight_decay", default=1E-3, type=float)
    args = parser.parse_args(argv)
    # NOTE(review): the dispatch below was truncated in the available copy and is
    # reconstructed. train() may also need options such as gradient_clip, which
    # only do_random_search sets - confirm.
    if args.random_search:
        do_random_search(vars(args))
        return
    train(**vars(args))
if __name__ == "__main__":
    # Run the CLI only when executed as a script, not on import.
    main()
You can’t perform that action at this time.