import argparse
import os
import random
from torch import utils
from tqdm import tqdm
import numpy as np
import torch
import torch.nn as nn
import data
import model
class RandomSearch(object):
    """Iterate over every grid combination exactly once, in a random order.

    Args:
        params: sequence of candidate-value sequences, in the same format
            accepted by ``GridSearch``.
        rng: optional ``random.Random`` instance used for shuffling. When
            omitted, the module-level ``random`` generator is used, so calling
            ``random.seed`` beforehand makes the order repeatable.
    """

    def __init__(self, params, rng=None):
        self.params = params
        self.rng = rng

    def __iter__(self):
        param_space = list(GridSearch(self.params))
        # Shuffle so combinations are visited in random rather than grid order;
        # an interrupted search then still samples the whole space evenly.
        shuffler = random if self.rng is None else self.rng
        shuffler.shuffle(param_space)
        for param in param_space:
            yield param
class GridSearch(object):
    """Iterate over every combination of the given candidate-value lists.

    Each combination is a list holding one value from each entry of
    ``params``, in the same order as ``params``. The first list varies
    fastest, like the least-significant digit of an odometer. An empty
    ``params`` yields a single empty combination; any empty candidate list
    yields no combinations at all.

    Args:
        params: sequence of sequences (each must support ``len`` and
            indexing).
    """

    def __init__(self, params):
        self.params = params
        self.param_lengths = [len(param) for param in self.params]
        self._reset()

    def _reset(self):
        """Rewind the iterator to the first combination."""
        # One 1-based "digit" per candidate list; digit d selects params[d][idx - 1].
        self.indices = [1] * len(self.params)
        # With an empty candidate list there is nothing to produce.
        self.stop_next = any(length == 0 for length in self.param_lengths)

    def _update(self, carry_idx):
        """Advance the odometer starting at digit ``carry_idx``.

        Returns:
            True when the carry runs past the last digit, i.e. every
            combination has already been produced; False otherwise.
        """
        for idx in range(carry_idx, len(self.params)):
            if self.indices[idx] < self.param_lengths[idx]:
                self.indices[idx] += 1
                return False
            # This digit wrapped around; carry into the next one.
            self.indices[idx] = 1
        return True

    def __iter__(self):
        self._reset()
        return self

    def __next__(self):
        if self.stop_next:
            raise StopIteration
        result = [param[idx - 1] for param, idx in zip(self.params, self.indices)]
        # Advancing from digit 0 also covers empty ``params``: the single empty
        # combination is returned once, then iteration stops.
        self.stop_next = self._update(0)
        return result
def train(**kwargs):
    """Train a ConvRNNModel on the SST splits and return the best dev accuracy.

    Keys read from ``kwargs``: ``mbatch_size``, ``n_epochs``, ``restore``,
    ``quiet``, ``lr``, ``weight_decay``, ``seed``, ``no_cuda``, ``gpu_number``,
    ``input_file``, ``output_file`` and ``dev_per_epoch``. ``gradient_clip`` is
    optional; gradient-norm clipping is skipped when it is missing or falsy.
    When a new model is built (``restore`` is false), all of ``kwargs`` is
    also forwarded to ``model.ConvRNNModel``.

    During each epoch the model is evaluated on the dev set about
    ``dev_per_epoch`` times. It is saved to ``output_file`` whenever dev
    accuracy matches or beats the best seen so far. After the last epoch, the
    final model (not necessarily the best saved one) is evaluated on the
    test set.

    Returns:
        The best dev-set accuracy observed, as a float.
    """
    mbatch_size = kwargs["mbatch_size"]
    n_epochs = kwargs["n_epochs"]
    restore = kwargs["restore"]
    verbose = not kwargs["quiet"]
    lr = kwargs["lr"]
    weight_decay = kwargs["weight_decay"]
    seed = kwargs["seed"]
    use_cuda = not kwargs["no_cuda"]
    gradient_clip = kwargs.get("gradient_clip")
    output_file = kwargs["output_file"]

    # Seed the Python, NumPy and torch RNGs from the requested seed.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if use_cuda:
        torch.cuda.set_device(kwargs["gpu_number"])
        torch.cuda.manual_seed(seed)

    embed_loader = data.SSTEmbeddingLoader("data")
    if restore:
        conv_rnn = torch.load(kwargs["input_file"])
    else:
        id_dict, weights, unk_vocab_list = embed_loader.load_embed_data()
        word_model = model.SSTWordEmbeddingModel(id_dict, weights, unk_vocab_list)
        if use_cuda:
            word_model.cuda()
        conv_rnn = model.ConvRNNModel(word_model, **kwargs)
    if use_cuda:
        conv_rnn.cuda()
    conv_rnn.train()

    criterion = nn.CrossEntropyLoss()
    # Only optimise trainable parameters (e.g. frozen embeddings are skipped).
    parameters = [p for p in conv_rnn.parameters() if p.requires_grad]
    optimizer = torch.optim.Adam(parameters, lr=lr, weight_decay=weight_decay)
    train_set, dev_set, test_set = data.SSTDataset.load_sst_sets("data")

    # NOTE(review): assumes convert_dataset yields (inputs_tuple, labels) already on
    # the model's device; the first input tensor's dim 0 is the batch size — confirm.
    collate_fn = conv_rnn.convert_dataset
    train_loader = utils.data.DataLoader(train_set, shuffle=True, batch_size=mbatch_size, collate_fn=collate_fn)
    # Dev/test are each evaluated as one full-size batch.
    dev_loader = utils.data.DataLoader(dev_set, batch_size=len(dev_set), collate_fn=collate_fn)
    test_loader = utils.data.DataLoader(test_set, batch_size=len(test_set), collate_fn=collate_fn)

    best_dev = 0.0

    def evaluate(loader, dev=True):
        """Report accuracy/loss on ``loader``; on dev, checkpoint any new best."""
        nonlocal best_dev
        conv_rnn.eval()
        with torch.no_grad():
            for m_in, m_out in loader:
                scores = conv_rnn(*m_in)
                loss = criterion(scores, m_out).item()
                n_examples = m_in[0].size(0)
                predictions = torch.max(scores, 1)[1].view(n_examples)
                # .item() yields Python ints, so the division below is true division.
                n_correct = (predictions == m_out).sum().item()
                accuracy = n_correct / n_examples
                if dev and accuracy >= best_dev:
                    best_dev = accuracy
                    print("Saving best model ({})...".format(accuracy))
                    output_dir = os.path.dirname(output_file)
                    if output_dir:
                        os.makedirs(output_dir, exist_ok=True)
                    torch.save(conv_rnn, output_file)
                if verbose:
                    print("{} set accuracy: {}, loss: {}".format("dev" if dev else "test", accuracy, loss))
        conv_rnn.train()

    n_batches = len(train_loader)
    # The max() guards avoid a zero modulus when there are fewer batches than
    # requested dev evaluations (or dev_per_epoch is 0).
    dev_interval = max(1, n_batches // max(1, kwargs["dev_per_epoch"]))

    for epoch in range(n_epochs):
        print("Epoch number: {}".format(epoch), end="\r")
        if verbose:
            # Move past the "\r"-terminated line before the progress bar starts.
            print()
        for step, (train_in, train_out) in enumerate(tqdm(train_loader, total=n_batches)):
            optimizer.zero_grad()
            scores = conv_rnn(*train_in)
            loss = criterion(scores, train_out)
            loss.backward()
            if gradient_clip:
                nn.utils.clip_grad_norm_(parameters, gradient_clip)
            optimizer.step()
            if step % dev_interval == 0:
                evaluate(dev_loader)

    evaluate(test_loader, dev=False)
    return best_dev
def do_random_search(given_params):
    """Train once per hyper-parameter combination (in random order) and print the best.

    Args:
        given_params: base keyword arguments for ``train``. This dict is copied,
            not modified; each trial overrides the searched keys on the copy.
    """
    # Columns: sf, gradient_clip, hidden_size, seed, fc_size, n_feature_maps.
    test_grid = [[0.15, 0.2], [4, 5, 6], [150, 200], [3, 4, 5], [200, 300], [200, 250]]
    max_params = None
    # Start below any real accuracy so the first trial is always recorded.
    max_acc = float("-inf")
    for args in RandomSearch(test_grid):
        sf, gc, hid, seed, fc_size, fmaps = args
        print("Testing {}".format(args))
        # NOTE(review): ``sf`` is searched over but never forwarded to train(),
        # which doubles the trials without effect — TODO confirm which option it targets.
        trial_params = dict(given_params)
        trial_params.update(dict(n_epochs=7, quiet=True, gradient_clip=gc, hidden_size=hid, seed=seed,
                                 n_feature_maps=fmaps, fc_size=fc_size))
        dev_acc = train(**trial_params)
        print("Dev accuracy: {}".format(dev_acc))
        if dev_acc > max_acc:
            print("Found current max")
            max_acc = dev_acc
            max_params = args
    print("Best params: {}".format(max_params))
def main():
    """Parse command-line options, then either train once or run the hyper-parameter search."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--dev_per_epoch", default=9, type=int)
    parser.add_argument("--fc_size", default=200, type=int)
    parser.add_argument("--gpu_number", default=0, type=int)
    # Max gradient norm forwarded to train(); None by default.
    parser.add_argument("--gradient_clip", default=None, type=float)
    parser.add_argument("--hidden_size", default=150, type=int)
    # Model paths must name a file, not a directory.
    parser.add_argument("--input_file", default="saves/model.pt", type=str)
    parser.add_argument("--lr", default=5E-4, type=float)
    parser.add_argument("--mbatch_size", default=16, type=int)
    parser.add_argument("--n_epochs", default=30, type=int)
    # A count of feature maps, so parse as int (a float breaks layer sizing).
    parser.add_argument("--n_feature_maps", default=200, type=int)
    parser.add_argument("--n_labels", default=5, type=int)
    parser.add_argument("--no_cuda", action="store_true", default=False)
    parser.add_argument("--output_file", default="saves/model.pt", type=str)
    parser.add_argument("--random_search", action="store_true", default=False)
    parser.add_argument("--restore", action="store_true", default=False)
    parser.add_argument("--rnn_type", choices=["lstm", "gru"], default="lstm", type=str)
    parser.add_argument("--seed", default=3, type=int)
    parser.add_argument("--quiet", action="store_true", default=False)
    parser.add_argument("--weight_decay", default=1E-3, type=float)
    args = parser.parse_args()
    if args.random_search:
        do_random_search(vars(args))
    else:
        train(**vars(args))
if __name__ == "__main__":
    main()
