From 5829aee35c00610105d9ed5441827d486912a137 Mon Sep 17 00:00:00 2001 From: alsuhr Date: Tue, 12 Jun 2018 14:58:16 -0400 Subject: [PATCH] more files --- snippets.py | 105 +++++++++++ sql_util.py | 410 +++++++++++++++++++++++++++++++++++++++++ token_predictor.py | 387 ++++++++++++++++++++++++++++++++++++++ tokenizers.py | 82 +++++++++ util.py | 16 ++ utterance.py | 115 ++++++++++++ utterance_model.py | 171 +++++++++++++++++ visualize_attention.py | 113 ++++++++++++ vocabulary.py | 77 ++++++++ 9 files changed, 1476 insertions(+) create mode 100644 snippets.py create mode 100644 sql_util.py create mode 100644 token_predictor.py create mode 100644 tokenizers.py create mode 100644 util.py create mode 100644 utterance.py create mode 100644 utterance_model.py create mode 100644 visualize_attention.py create mode 100644 vocabulary.py diff --git a/snippets.py b/snippets.py new file mode 100644 index 0000000..43ff233 --- /dev/null +++ b/snippets.py @@ -0,0 +1,105 @@ +""" Contains the Snippet class and methods for handling snippets. + +Attributes: + SNIPPET_PREFIX: string prefix for snippets. +""" + +SNIPPET_PREFIX = "SNIPPET_" + + +def is_snippet(token): + """ Determines whether a token is a snippet or not. + + Inputs: + token (str): The token to check. + + Returns: + bool, indicating whether it's a snippet. + """ + return token.startswith(SNIPPET_PREFIX) + +def expand_snippets(sequence, snippets): + """ Given a sequence and a list of snippets, expand the snippets in the sequence. + + Inputs: + sequence (list of str): Query containing snippet references. + snippets (list of Snippet): List of available snippets. + + return list of str representing the expanded sequence + """ + snippet_id_to_snippet = {} + for snippet in snippets: + assert snippet.name not in snippet_id_to_snippet + snippet_id_to_snippet[snippet.name] = snippet + expanded_seq = [] + for token in sequence: + if token in snippet_id_to_snippet: + expanded_seq.extend(snippet_id_to_snippet[token].sequence) + else: + assert not is_snippet(token) + expanded_seq.append(token) + + return expanded_seq + +def snippet_index(token): + """ Returns the index of a snippet. + + Inputs: + token (str): The snippet to check. + + Returns: + integer, the index of the snippet. + """ + assert is_snippet(token) + return int(token.split("_")[-1]) + + +class Snippet(): + """ Contains a snippet. """ + def __init__(self, + sequence, + startpos, + sql, + age=0): + self.sequence = sequence + self.startpos = startpos + self.sql = sql + + # TODO: age vs. index? + self.age = age + self.index = 0 + + self.name = "" + self.embedding = None + + self.endpos = self.startpos + len(self.sequence) + assert self.endpos < len(self.sql), "End position of snippet is " + str( + self.endpos) + " which is greater than length of SQL (" + str(len(self.sql)) + ")" + assert self.sequence == self.sql[self.startpos:self.endpos], \ + "Value of snippet (" + " ".join(self.sequence) + ") " \ + "is not the same as SQL at the same positions (" \ + + " ".join(self.sql[self.startpos:self.endpos]) + ")" + + def __str__(self): + return self.name + "\t" + \ + str(self.age) + "\t" + " ".join(self.sequence) + + def __len__(self): + return len(self.sequence) + + def increase_age(self): + """ Ages a snippet by one. """ + self.index += 1 + + def assign_id(self, number): + """ Assigns the name of the snippet to be the prefix + the number. """ + self.name = SNIPPET_PREFIX + str(number) + + def set_embedding(self, embedding): + """ Sets the embedding of the snippet. 
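+        (Presumably called by the model's snippet encoder once a DyNet
+        expression for the snippet's token sequence has been computed;
+        score_snippets in token_predictor.py reads this embedding when
+        scoring snippets.)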
+ + Inputs: + embedding (dy.Expression) + + """ + self.embedding = embedding diff --git a/sql_util.py b/sql_util.py new file mode 100644 index 0000000..2aa0150 --- /dev/null +++ b/sql_util.py @@ -0,0 +1,410 @@ +import copy +import pymysql +import random +import signal +import sqlparse +import util + +from snippets import Snippet +from sqlparse import tokens as token_types +from sqlparse import sql as sql_types + +interesting_selects = ["DISTINCT", "MAX", "MIN", "count"] +ignored_subtrees = [["1", "=", "1"]] + +# strip_whitespace_front +# Strips whitespace and punctuation from the front of a SQL token list. +# +# Inputs: +# token_list: the token list. +# +# Outputs: +# new token list. + + +def strip_whitespace_front(token_list): + new_token_list = [] + found_valid = False + + for token in token_list: + if not (token.is_whitespace or token.ttype == + token_types.Punctuation) or found_valid: + found_valid = True + new_token_list.append(token) + + return new_token_list + +# strip_whitespace +# Strips whitespace from a token list. +# +# Inputs: +# token_list: the token list. +# +# Outputs: +# new token list with no whitespace/punctuation surrounding. + + +def strip_whitespace(token_list): + subtokens = strip_whitespace_front(token_list) + subtokens = strip_whitespace_front(subtokens[::-1])[::-1] + return subtokens + +# token_list_to_seq +# Converts a Token list to a sequence of strings, stripping out surrounding +# punctuation and all whitespace. +# +# Inputs: +# token_list: the list of tokens. +# +# Outputs: +# list of strings + + +def token_list_to_seq(token_list): + subtokens = strip_whitespace(token_list) + + seq = [] + flat = sqlparse.sql.TokenList(subtokens).flatten() + for i, token in enumerate(flat): + strip_token = str(token).strip() + if len(strip_token) > 0: + seq.append(strip_token) + if len(seq) > 0: + if seq[0] == "(" and seq[-1] == ")": + seq = seq[1:-1] + + return seq + +# TODO: clean this up +# find_subtrees +# Finds subtrees for a subsequence of SQL. +# +# Inputs: +# sequence: sequence of SQL tokens. +# current_subtrees: current list of subtrees. +# +# Optional inputs: +# where_parent: whether the parent of the current sequence was a where clause +# keep_conj_subtrees: whether to look for a conjunction in this sequence and +# keep its arguments + + +def find_subtrees(sequence, + current_subtrees, + where_parent=False, + keep_conj_subtrees=False): + # If the parent of the subsequence was a WHERE clause, keep everything in the + # sequence except for the beginning WHERE and any surrounding parentheses. + if where_parent: + # Strip out the beginning WHERE, and any punctuation or whitespace at the + # beginning or end of the token list. + seq = token_list_to_seq(sequence.tokens[1:]) + if len(seq) > 0 and seq not in current_subtrees: + current_subtrees.append(seq) + + # If the current sequence has subtokens, i.e. if it's a node that can be + # expanded, check for a conjunction in its subtrees, and expand its subtrees. + # Also check for any SELECT statements and keep track of what follows. + if sequence.is_group: + if keep_conj_subtrees: + subtokens = strip_whitespace(sequence.tokens) + + # Check if there is a conjunction in the subsequence. If so, keep the + # children. 
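+            # For example (illustrative), a WHERE clause like
+            #   departure_time < 1200 AND from_airport = 'BOS'
+            # yields the subtrees ["departure_time", "<", "1200"] and
+            # ["from_airport", "=", "'BOS'"].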
Also make sure you don't split where AND is used within a + # child -- the subtokens sequence won't treat those ANDs differently (a + # bit hacky but it works) + has_and = False + for i, token in enumerate(subtokens): + if token.value == "OR" or token.value == "AND": + has_and = True + break + + if has_and: + and_subtrees = [] + current_subtree = [] + for i, token in enumerate(subtokens): + if token.value == "OR" or (token.value == "AND" and i - 4 >= 0 and i - 4 < len( + subtokens) and subtokens[i - 4].value != "BETWEEN"): + and_subtrees.append(current_subtree) + current_subtree = [] + else: + current_subtree.append(token) + and_subtrees.append(current_subtree) + + for subtree in and_subtrees: + seq = token_list_to_seq(subtree) + if len(seq) > 0 and seq[0] == "WHERE": + seq = seq[1:] + if seq not in current_subtrees: + current_subtrees.append(seq) + + in_select = False + select_toks = [] + for i, token in enumerate(sequence.tokens): + # Mark whether this current token is a WHERE. + is_where = (isinstance(token, sql_types.Where)) + + # If you are in a SELECT, start recording what follows until you hit a + # FROM + if token.value == "SELECT": + in_select = True + elif in_select: + select_toks.append(token) + if token.value == "FROM": + in_select = False + + seq = [] + if len(sequence.tokens) > i + 2: + seq = token_list_to_seq( + select_toks + [sequence.tokens[i + 2]]) + + if seq not in current_subtrees and len( + seq) > 0 and seq[0] in interesting_selects: + current_subtrees.append(seq) + + select_toks = [] + + # Recursively find subtrees in the children of the node. + find_subtrees(token, + current_subtrees, + is_where, + where_parent or keep_conj_subtrees) + +# get_subtrees + + +def get_subtrees(sql, oldsnippets=[]): + parsed = sqlparse.parse(" ".join(sql))[0] + + subtrees = [] + find_subtrees(parsed, subtrees) + + final_subtrees = [] + for subtree in subtrees: + if subtree not in ignored_subtrees: + final_version = [] + keep = True + + parens_counts = 0 + for i, token in enumerate(subtree): + if token == ".": + newtoken = final_version[-1] + "." 
+ subtree[i + 1] + final_version = final_version[:-1] + [newtoken] + keep = False + elif keep: + final_version.append(token) + else: + keep = True + + if token == "(": + parens_counts -= 1 + elif token == ")": + parens_counts += 1 + + if parens_counts == 0: + final_subtrees.append(final_version) + + snippets = [] + sql = [str(tok) for tok in sql] + for subtree in final_subtrees: + startpos = -1 + for i in range(len(sql) - len(subtree) + 1): + if sql[i:i + len(subtree)] == subtree: + startpos = i + if startpos >= 0 and startpos + len(subtree) < len(sql): + age = 0 + for prevsnippet in oldsnippets: + if prevsnippet.sequence == subtree: + age = prevsnippet.age + 1 + snippet = Snippet(subtree, startpos, sql, age=age) + snippets.append(snippet) + + return snippets + + +conjunctions = {"AND", "OR", "WHERE"} + + +def get_all_in_parens(sequence): + if sequence[-1] == ";": + sequence = sequence[:-1] + + if not "(" in sequence: + return [] + + if sequence[0] == "(" and sequence[-1] == ")": + in_parens = sequence[1:-1] + return [in_parens] + get_all_in_parens(in_parens) + else: + paren_subseqs = [] + current_seq = [] + num_parens = 0 + in_parens = False + for token in sequence: + if in_parens: + current_seq.append(token) + if token == ")": + num_parens -= 1 + if num_parens == 0: + in_parens = False + paren_subseqs.append(current_seq) + current_seq = [] + elif token == "(": + in_parens = True + current_seq.append(token) + if token == "(": + num_parens += 1 + + all_subseqs = [] + for subseq in paren_subseqs: + all_subseqs.extend(get_all_in_parens(subseq)) + return all_subseqs + + +def split_by_conj(sequence): + num_parens = 0 + current_seq = [] + subsequences = [] + + for token in sequence: + if num_parens == 0: + if token in conjunctions: + subsequences.append(current_seq) + current_seq = [] + break + current_seq.append(token) + if token == "(": + num_parens += 1 + elif token == ")": + num_parens -= 1 + + assert num_parens >= 0 + + return subsequences + + +def get_sql_snippets(sequence): + # First, get all subsequences of the sequence that are surrounded by + # parentheses. + all_in_parens = get_all_in_parens(sequence) + all_subseq = [] + + # Then for each one, split the sequence on conjunctions (AND/OR). + for seq in all_in_parens: + subsequences = split_by_conj(seq) + all_subseq.append(seq) + all_subseq.extend(subsequences) + + # Finally, also get "interesting" selects + + for i, seq in enumerate(all_subseq): + print(str(i) + "\t" + " ".join(seq)) + exit() + +# add_snippets_to_query + + +def add_snippets_to_query(snippets, ignored_entities, query, prob_align=1.): + query_copy = copy.copy(query) + + # Replace the longest snippets first, so sort by length descending. + sorted_snippets = sorted(snippets, key=lambda s: len(s.sequence))[::-1] + + for snippet in sorted_snippets: + ignore = False + snippet_seq = snippet.sequence + + # TODO: continue here + # If it contains an ignored entity, then don't use it. + for entity in ignored_entities: + ignore = ignore or util.subsequence(entity, snippet_seq) + + # No NL entities found in snippet, then see if snippet is a substring of + # the gold sequence + if not ignore: + snippet_length = len(snippet_seq) + + # Iterate through gold sequence to see if it's a subsequence. + for start_idx in range(len(query_copy) - snippet_length + 1): + if query_copy[start_idx:start_idx + + snippet_length] == snippet_seq: + align = random.random() < prob_align + + if align: + prev_length = len(query_copy) + + # At the start position of the snippet, replace with an + # identifier. 
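+                        # For example (illustrative), if snippet.name is "SNIPPET_0"
+                        # and its sequence is ["city.city_code", "=", "'BOS'"], then
+                        #   ... WHERE city.city_code = 'BOS' AND ...
+                        # becomes
+                        #   ... WHERE SNIPPET_0 AND ...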
+ query_copy[start_idx] = snippet.name + + # Then cut out the indices which were collapsed into + # the snippet. + query_copy = query_copy[:start_idx + 1] + \ + query_copy[start_idx + snippet_length:] + + # Make sure the length is as expected + assert len(query_copy) == prev_length - \ + (snippet_length - 1) + + return query_copy + + +def execution_results(query, username, password, timeout=3): + connection = pymysql.connect(user=username, password=password) + + class TimeoutException(Exception): + pass + + def timeout_handler(signum, frame): + raise TimeoutException + + signal.signal(signal.SIGALRM, timeout_handler) + + syntactic = True + semantic = True + + table = [] + + with connection.cursor() as cursor: + signal.alarm(timeout) + try: + cursor.execute("SET sql_mode='IGNORE_SPACE';") + cursor.execute("use atis3;") + cursor.execute(query) + table = cursor.fetchall() + cursor.close() + except TimeoutException: + signal.alarm(0) + cursor.close() + except pymysql.err.ProgrammingError: + syntactic = False + semantic = False + cursor.close() + except pymysql.err.InternalError: + semantic = False + cursor.close() + except Exception as e: + signal.alarm(0) + signal.alarm(0) + cursor.close() + signal.alarm(0) + + connection.close() + + return (syntactic, semantic, sorted(table)) + + +def executable(query, username, password, timeout=2): + return execution_results(query, username, password, timeout)[1] + + +def fix_parentheses(sequence): + num_left = sequence.count("(") + num_right = sequence.count(")") + + if num_right < num_left: + fixed_sequence = sequence[:-1] + \ + [")" for _ in range(num_left - num_right)] + [sequence[-1]] + return fixed_sequence + + return sequence diff --git a/token_predictor.py b/token_predictor.py new file mode 100644 index 0000000..4a9bb42 --- /dev/null +++ b/token_predictor.py @@ -0,0 +1,387 @@ +"""Predicts a token.""" + +from collections import namedtuple + +import dynet as dy +import dynet_utils as du + +from attention import Attention + +class PredictionInput(namedtuple('PredictionInput', + ('decoder_state', + 'input_hidden_states', + 'snippets', + 'input_sequence'))): + """ Inputs to the token predictor. """ + __slots__ = () + + +class TokenPrediction(namedtuple('TokenPrediction', + ('scores', + 'aligned_tokens', + 'attention_results', + 'decoder_state'))): + + """A token prediction. + + Attributes: + scores (dy.Expression): Scores for each possible output token. + aligned_tokens (list of str): The output tokens, aligned with the scores. + attention_results (AttentionResult): The result of attending on the input + sequence. + """ + __slots__ = () + + +def score_snippets(snippets, scorer): + """ Scores snippets given a scorer. + + Inputs: + snippets (list of Snippet): The snippets to score. + scorer (dy.Expression): Dynet vector against which to score the snippets. + + Returns: + dy.Expression, list of str, where the first is the scores and the second + is the names of the snippets that were scored. + """ + snippet_expressions = [snippet.embedding for snippet in snippets] + all_snippet_embeddings = dy.concatenate(snippet_expressions, d=1) + + if du.is_vector(scorer): + scorer = du.add_dim(scorer) + + scores = dy.transpose(dy.transpose(scorer) * all_snippet_embeddings) + + if scores.dim()[0][0] != len(snippets): + raise ValueError("Got " + str(scores.dim()[0][0]) + " scores for " + + str(len(snippets)) + " snippets") + + return scores, [snippet.name for snippet in snippets] + + +class TokenPredictor(): + """ Predicts a token given a (decoder) state. 
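+
+    At each decoding step, the decoder state is concatenated with an attention
+    vector over the input hidden states, passed through a tanh layer, and scored
+    against the output vocabulary (see __call__). Illustrative usage (the
+    constructor arguments are assumed to be set up elsewhere):
+
+        predictor = TokenPredictor(model, params, vocab, attention_key_size)
+        prediction = predictor(
+            PredictionInput(state, hidden_states, [], input_tokens))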
+ + Attributes: + vocabulary (Vocabulary): A vocabulary object for the output. + attention_module (Attention): An attention module. + state_transformation_weights (dy.Parameters): Transforms the input state + before predicting a token. + vocabulary_weights (dy.Parameters): Final layer weights. + vocabulary_biases (dy.Parameters): Final layer biases. + """ + + def __init__(self, model, params, vocabulary, attention_key_size): + self.vocabulary = vocabulary + self.attention_module = Attention(model, + params.decoder_state_size, + attention_key_size, + attention_key_size) + self.state_transform_weights = du.add_params( + model, + (params.decoder_state_size + + attention_key_size, + params.decoder_state_size), + "weights-state-transform") + self.vocabulary_weights = du.add_params( + model, (params.decoder_state_size, len(vocabulary)), "weights-vocabulary") + self.vocabulary_biases = du.add_params(model, + tuple([len(vocabulary)]), + "biases-vocabulary") + + def _get_intermediate_state(self, state, dropout_amount=0.): + intermediate_state = dy.tanh( + du.linear_layer( + state, self.state_transform_weights)) + return dy.dropout(intermediate_state, dropout_amount) + + def _score_vocabulary_tokens(self, state): + scores = dy.transpose(du.linear_layer(state, + self.vocabulary_weights, + self.vocabulary_biases)) + if scores.dim()[0][0] != len(self.vocabulary.inorder_tokens): + raise ValueError("Got " + + str(scores.dim()[0][0]) + + " scores for " + + str(len(self.vocabulary.inorder_tokens)) + + " vocabulary items") + + return scores, self.vocabulary.inorder_tokens + + def __call__(self, + prediction_input, + dropout_amount=0.): + decoder_state = prediction_input.decoder_state + input_hidden_states = prediction_input.input_hidden_states + + attention_results = self.attention_module(decoder_state, + input_hidden_states) + + state_and_attn = dy.concatenate( + [decoder_state, attention_results.vector]) + + intermediate_state = self._get_intermediate_state( + state_and_attn, dropout_amount=dropout_amount) + vocab_scores, vocab_tokens = self._score_vocabulary_tokens( + intermediate_state) + + return TokenPrediction(vocab_scores, vocab_tokens, attention_results, decoder_state) + + +class SnippetTokenPredictor(TokenPredictor): + """ Token predictor that also predicts snippets. + + Attributes: + snippet_weights (dy.Parameter): Weights for scoring snippets against some + state. 
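+
+    Snippets are scored with a bilinear product: the intermediate prediction
+    state is projected through snippet_weights and dotted with each snippet's
+    embedding, and the resulting scores are concatenated after the vocabulary
+    scores (see score_snippets and __call__).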
+ """ + + def __init__( + self, + model, + params, + vocabulary, + attention_key_size, + snippet_size): + TokenPredictor.__init__(self, + model, + params, + vocabulary, + attention_key_size) + if snippet_size <= 0: + raise ValueError("Snippet size must be greater than zero; was " \ + + str(snippet_size)) + self.snippet_weights = du.add_params(model, + (params.decoder_state_size, + snippet_size), + "weights-snippet") + + def _get_snippet_scorer(self, state): + return dy.transpose(du.linear_layer(dy.transpose(state), + self.snippet_weights)) + + def __call__(self, + prediction_input, + dropout_amount=0.): + decoder_state = prediction_input.decoder_state + input_hidden_states = prediction_input.input_hidden_states + snippets = prediction_input.snippets + + attention_results = self.attention_module(decoder_state, + input_hidden_states) + + state_and_attn = dy.concatenate( + [decoder_state, attention_results.vector]) + + intermediate_state = self._get_intermediate_state( + state_and_attn, dropout_amount=dropout_amount) + vocab_scores, vocab_tokens = self._score_vocabulary_tokens( + intermediate_state) + + final_scores = vocab_scores + aligned_tokens = [] + aligned_tokens.extend(vocab_tokens) + + if snippets: + snippet_scores, snippet_tokens = score_snippets( + snippets, + self._get_snippet_scorer(intermediate_state)) + + final_scores = dy.concatenate([final_scores, snippet_scores]) + aligned_tokens.extend(snippet_tokens) + + return TokenPrediction(final_scores, + aligned_tokens, + attention_results, + decoder_state) + + +class AnonymizationTokenPredictor(TokenPredictor): + """ Token predictor that also predicts anonymization tokens. + + Attributes: + anonymizer (Anonymizer): The anonymization object. + + """ + + def __init__(self, + model, + params, + vocabulary, + attention_key_size, + anonymizer): + TokenPredictor.__init__(self, + model, + params, + vocabulary, + attention_key_size) + if not anonymizer: + raise ValueError("Expected an anonymizer, but was None") + self.anonymizer = anonymizer + + def _score_anonymized_tokens(self, + input_sequence, + attention_scores): + scores = [] + tokens = [] + for i, token in enumerate(input_sequence): + if self.anonymizer.is_anon_tok(token): + scores.append(attention_scores[i]) + tokens.append(token) + + if len(scores) > 0: + if len(scores) != len(tokens): + raise ValueError("Got " + str(len(scores)) + " scores for " + + str(len(tokens)) + " anonymized tokens") + + return dy.concatenate(scores), tokens + else: + return None, [] + + def __call__(self, + prediction_input, + dropout_amount=0.): + decoder_state = prediction_input.decoder_state + input_hidden_states = prediction_input.input_hidden_states + input_sequence = prediction_input.input_sequence + assert input_sequence + + attention_results = self.attention_module(decoder_state, + input_hidden_states) + + state_and_attn = dy.concatenate( + [decoder_state, attention_results.vector]) + + intermediate_state = self._get_intermediate_state( + state_and_attn, dropout_amount=dropout_amount) + vocab_scores, vocab_tokens = self._score_vocabulary_tokens( + intermediate_state) + + final_scores = vocab_scores + aligned_tokens = [] + aligned_tokens.extend(vocab_tokens) + + anonymized_scores, anonymized_tokens = self._score_anonymized_tokens( + input_sequence, + attention_results.scores) + + if anonymized_scores: + final_scores = dy.concatenate([final_scores, anonymized_scores]) + aligned_tokens.extend(anonymized_tokens) + + return TokenPrediction(final_scores, + aligned_tokens, + attention_results, + 
decoder_state) + + +class SnippetAnonymizationTokenPredictor( + SnippetTokenPredictor, + AnonymizationTokenPredictor): + """ Token predictor that both anonymizes and scores snippets.""" + + def __init__(self, + model, + params, + vocabulary, + attention_key_size, + snippet_size, + anonymizer): + SnippetTokenPredictor.__init__(self, + model, + params, + vocabulary, + attention_key_size, + snippet_size) + AnonymizationTokenPredictor.__init__(self, + model, + params, + vocabulary, + attention_key_size, + anonymizer) + + def __call__(self, + prediction_input, + dropout_amount=0.): + decoder_state = prediction_input.decoder_state + assert prediction_input.input_sequence + + snippets = prediction_input.snippets + + attention_results = self.attention_module(decoder_state, + prediction_input.input_hidden_states) + + intermediate_state = self._get_intermediate_state( + dy.concatenate([decoder_state, attention_results.vector]), + dropout_amount=dropout_amount) + + # Vocabulary tokens + final_scores, vocab_tokens = self._score_vocabulary_tokens( + intermediate_state) + + aligned_tokens = [] + aligned_tokens.extend(vocab_tokens) + + # Snippets + if snippets: + snippet_scores, snippet_tokens = score_snippets( + snippets, + self._get_snippet_scorer(intermediate_state)) + + final_scores = dy.concatenate([final_scores, snippet_scores]) + aligned_tokens.extend(snippet_tokens) + + # Anonymized tokens + anonymized_scores, anonymized_tokens = self._score_anonymized_tokens( + prediction_input.input_sequence, + attention_results.scores) + + if anonymized_scores: + final_scores = dy.concatenate([final_scores, anonymized_scores]) + aligned_tokens.extend(anonymized_tokens) + + return TokenPrediction(final_scores, + aligned_tokens, + attention_results, + decoder_state) + + +def construct_token_predictor(parameter_collection, + params, + vocabulary, + attention_key_size, + snippet_size, + anonymizer=None): + """ Constructs a token predictor given the parameters. + + Inputs: + parameter_collection (dy.ParameterCollection): Contains the parameters. + params (dictionary): Contains the command line parameters/hyperparameters. + vocabulary (Vocabulary): Vocabulary object for output generation. + attention_key_size (int): The size of the attention keys. + anonymizer (Anonymizer): An anonymization object. + """ + if params.use_snippets and anonymizer and not params.previous_decoder_snippet_encoding: + return SnippetAnonymizationTokenPredictor(parameter_collection, + params, + vocabulary, + attention_key_size, + snippet_size, + anonymizer) + elif params.use_snippets and not params.previous_decoder_snippet_encoding: + return SnippetTokenPredictor(parameter_collection, + params, + vocabulary, + attention_key_size, + snippet_size) + elif anonymizer: + return AnonymizationTokenPredictor(parameter_collection, + params, + vocabulary, + attention_key_size, + anonymizer) + else: + return TokenPredictor(parameter_collection, + params, + vocabulary, + attention_key_size) diff --git a/tokenizers.py b/tokenizers.py new file mode 100644 index 0000000..f5f5a41 --- /dev/null +++ b/tokenizers.py @@ -0,0 +1,82 @@ +"""Tokenizers for natural language SQL queries, and lambda calculus.""" +import nltk +import sqlparse + +def nl_tokenize(string): + """Tokenizes a natural language string into tokens. + + Inputs: + string: the string to tokenize. + Outputs: + a list of tokens. + + Assumes data is space-separated (this is true of ZC07 data in ATIS2/3). 
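+
+    Example (illustrative):
+        >>> nl_tokenize("show me flights from boston to denver")
+        ['show', 'me', 'flights', 'from', 'boston', 'to', 'denver']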
+ """ + return nltk.word_tokenize(string) + +def sql_tokenize(string): + """ Tokenizes a SQL statement into tokens. + + Inputs: + string: string to tokenize. + + Outputs: + a list of tokens. + """ + tokens = [] + statements = sqlparse.parse(string) + + # SQLparse gives you a list of statements. + for statement in statements: + # Flatten the tokens in each statement and add to the tokens list. + flat_tokens = sqlparse.sql.TokenList(statement.tokens).flatten() + for token in flat_tokens: + strip_token = str(token).strip() + if len(strip_token) > 0: + tokens.append(strip_token) + + newtokens = [] + keep = True + for i, token in enumerate(tokens): + if token == ".": + newtoken = newtokens[-1] + "." + tokens[i + 1] + newtokens = newtokens[:-1] + [newtoken] + keep = False + elif keep: + newtokens.append(token) + else: + keep = True + + return newtokens + +def lambda_tokenize(string): + """ Tokenizes a lambda-calculus statement into tokens. + + Inputs: + string: a lambda-calculus string + + Outputs: + a list of tokens. + """ + + space_separated = string.split(" ") + + new_tokens = [] + + # Separate the string by spaces, then separate based on existence of ( or + # ). + for token in space_separated: + tokens = [] + + current_token = "" + for char in token: + if char == ")" or char == "(": + tokens.append(current_token) + tokens.append(char) + current_token = "" + else: + current_token += char + tokens.append(current_token) + new_tokens.extend([tok for tok in tokens if tok]) + + return new_tokens diff --git a/util.py b/util.py new file mode 100644 index 0000000..57d773b --- /dev/null +++ b/util.py @@ -0,0 +1,16 @@ +"""Contains various utility functions.""" +def subsequence(first_sequence, second_sequence): + """ + Returns whether the first sequence is a subsequence of the second sequence. + + Inputs: + first_sequence (list): A sequence. + second_sequence (list): Another sequence. + + Returns: + Boolean indicating whether first_sequence is a subsequence of second_sequence. + """ + for startidx in range(len(second_sequence) - len(first_sequence) + 1): + if second_sequence[startidx:startidx + len(first_sequence)] == first_sequence: + return True + return False diff --git a/utterance.py b/utterance.py new file mode 100644 index 0000000..ab15798 --- /dev/null +++ b/utterance.py @@ -0,0 +1,115 @@ +""" Contains the Utterance class. """ + +import sql_util +import tokenizers + +ANON_INPUT_KEY = "cleaned_nl" +OUTPUT_KEY = "sql" + +class Utterance: + """ Utterance class. """ + def process_input_seq(self, + anonymize, + anonymizer, + anon_tok_to_ent): + assert not anon_tok_to_ent or anonymize + assert not anonymize or anonymizer + + if anonymize: + assert anonymizer + + self.input_seq_to_use = anonymizer.anonymize( + self.original_input_seq, anon_tok_to_ent, ANON_INPUT_KEY, add_new_anon_toks=True) + else: + self.input_seq_to_use = self.original_input_seq + + def process_gold_seq(self, + output_sequences, + nl_to_sql_dict, + available_snippets, + anonymize, + anonymizer, + anon_tok_to_ent): + # Get entities in the input sequence: + # anonymized entity types + # othe recognized entities (this includes "flight") + entities_in_input = [ + [tok] for tok in self.input_seq_to_use if tok in anon_tok_to_ent] + entities_in_input.extend( + nl_to_sql_dict.get_sql_entities( + self.input_seq_to_use)) + + # Get the shortest gold query (this is what we use to train) + shortest_gold_and_results = min(output_sequences, + key=lambda x: len(x[0])) + + # Tokenize and anonymize it if necessary. 
+ self.original_gold_query = shortest_gold_and_results[0] + self.gold_sql_results = shortest_gold_and_results[1] + + self.contained_entities = entities_in_input + + # Keep track of all gold queries and the resulting tables so that we can + # give credit if it predicts a different correct sequence. + self.all_gold_queries = output_sequences + + self.anonymized_gold_query = self.original_gold_query + if anonymize: + self.anonymized_gold_query = anonymizer.anonymize( + self.original_gold_query, anon_tok_to_ent, OUTPUT_KEY, add_new_anon_toks=False) + + # Add snippets to it. + self.gold_query_to_use = sql_util.add_snippets_to_query( + available_snippets, entities_in_input, self.anonymized_gold_query) + + def __init__(self, + example, + available_snippets, + nl_to_sql_dict, + params, + anon_tok_to_ent={}, + anonymizer=None): + # Get output and input sequences from the dictionary representation. + output_sequences = example[OUTPUT_KEY] + self.original_input_seq = tokenizers.nl_tokenize(example[params.input_key]) + self.available_snippets = available_snippets + self.keep = False + + pruned_output_sequences = [] + for sequence in output_sequences: + if len(sequence[0]) > 3: + pruned_output_sequences.append(sequence) + + output_sequences = pruned_output_sequences + if len(output_sequences) > 0 and len(self.original_input_seq) > 0: + # Only keep this example if there is at least one output sequence. + self.keep = True + if len(output_sequences) == 0 or len(self.original_input_seq) == 0: + return + + # Process the input sequence + self.process_input_seq(params.anonymize, + anonymizer, + anon_tok_to_ent) + + # Process the gold sequence + self.process_gold_seq(output_sequences, + nl_to_sql_dict, + self.available_snippets, + params.anonymize, + anonymizer, + anon_tok_to_ent) + + def __str__(self): + string = "Original input: " + " ".join(self.original_input_seq) + "\n" + string += "Modified input: " + " ".join(self.input_seq_to_use) + "\n" + string += "Original output: " + " ".join(self.original_gold_query) + "\n" + string += "Modified output: " + " ".join(self.gold_query_to_use) + "\n" + string += "Snippets:\n" + for snippet in self.available_snippets: + string += str(snippet) + "\n" + return string + + def length_valid(self, input_limit, output_limit): + return (len(self.input_seq_to_use) < input_limit \ + and len(self.gold_query_to_use) < output_limit) diff --git a/utterance_model.py b/utterance_model.py new file mode 100644 index 0000000..8dabe8f --- /dev/null +++ b/utterance_model.py @@ -0,0 +1,171 @@ +""" Class for the Sequence to sequence model for ATIS.""" + +import dynet as dy +import dynet_utils as du +import vocabulary as vocab + +def gold_tok_to_id(token, idx_to_token): + """ Maps from a gold token to a list of indices in the probability distribution. + + Inputs: + token (int): The unique ID of the token. + idx_to_token (dict int->str): Maps from indices in the probability + distribution to strings. + """ + if token in idx_to_token: + if len(set(idx_to_token)) == len( + idx_to_token): # no duplicates + return [idx_to_token.index(token)] + else: + indices = [] + for index, check_tok in enumerate(idx_to_token): + if token == check_tok: + indices.append(index) + assert len(indices) == len(set(indices)) + return indices + else: + return [idx_to_token.index(vocab.UNK_TOK)] + +def predict(model, + utterances, + prev_query=None, + snippets=None, + gold_seq=None, + dropout_amount=0., + loss_only=False, + beam_size=1.): + """ Predicts a SQL query given an utterance and other various inputs. 
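+
+    The utterances are first encoded; if snippets are available, they are
+    embedded against the previous query; the decoder is then run either on the
+    gold sequence (to compute a loss) or with beam search (to produce a
+    prediction).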
+ + Inputs: + model (Seq2SeqModel): The model to use to predict. + utterances (list of list of str): The utterances to predict for. + prev_query (list of str, optional): The previously generated query. + snippets (list of Snippet. optional): The snippets available for prediction. + all_snippets (list of Snippet, optional): All snippets so far in the interaction. + gold_seq (list of str, optional): The gold sequence. + dropout_amount (float, optional): How much dropout to apply during predictino. + loss_only (bool, optional): Whether to only return the loss. + beam_size (float, optional): How many items to include in the beam during prediction. + """ + assert len(prev_query) == 0 or model.use_snippets + assert len(snippets) == 0 or model.use_snippets + assert not loss_only or len(gold_seq) > 0 + + (enc_state, enc_outputs), input_seq = model.encode_input_sequences( + utterances, dropout_amount) + + embedded_snippets = [] + if snippets: + embedded_snippets = model.encode_snippets( + prev_query, snippets, dropout_amount=dropout_amount) + assert len(embedded_snippets) == len(snippets) + + if gold_seq: + item = model.decode( + enc_state, + enc_outputs, + input_seq, + snippets=embedded_snippets if model.use_snippets else [], + gold_seq=gold_seq, + dropout_amount=dropout_amount)[0] + scores = item.scores + scores_by_timestep = [score[0] for score in scores] + score_maps_by_timestep = [score[1] for score in scores] + + assert scores_by_timestep[0].dim()[0][0] == len( + score_maps_by_timestep[0]) + assert len(score_maps_by_timestep[0]) >= len(model.output_vocab) + len(snippets) + + loss = du.compute_loss(gold_seq, + scores_by_timestep, + score_maps_by_timestep, + gold_tok_to_id, + noise=0.00000000001) + + if loss_only: + return loss + sequence = du.get_seq_from_scores(scores_by_timestep, + score_maps_by_timestep) + else: + item = model.decode( + enc_state, + enc_outputs, + input_seq, + snippets=embedded_snippets if model.use_snippets else [], + beam_size=beam_size)[0] + scalar_loss = 0 + sequence = item.sequence + + token_acc = 0 + if gold_seq: + token_acc = du.per_token_accuracy(gold_seq, sequence) + + return sequence, scalar_loss, token_acc, item.probability + + +def prepare_and_predict(model, + item, + use_gold=False, + training=False, + dropout=0., + beam_size=1): + utterances = du.get_utterances( + item, model.input_vocab, model.history_length) + assert len(utterances) <= model.history_length + if use_gold: + assert item.flatten_sequence( + item.gold_query()) == item.original_gold_query() + return model.predict( + utterances, + prev_query=item.previous_query() if model.use_snippets else [], + snippets=item.snippets() if model.use_snippets else [], + gold_seq=item.gold_query() if use_gold else [], + loss_only=training, + dropout_amount=dropout, + beam_size=1 if training else beam_size) + + +def train_step(model, batch, lr_coeff, dropout): + dy.renew_cg() + losses = [] + assert not model.prev_decoder_snippet_rep + + num_tokens = 0 + for item in batch.items: + loss = model.prepare_and_predict(item, + use_gold=True, + training=True, + dropout=dropout) + num_tokens += len(item.gold_query()) + losses.append(loss) + + final_loss = dy.esum(losses) / num_tokens + final_loss.forward() + final_loss.backward() + model.trainer.learning_rate = lr_coeff + model.trainer.update() + + return final_loss.npvalue()[0] + + +# eval_step +# Runs an evaluation on the example. +# +# Inputs: +# example: an Utterance. +# use_gold: whether or not to pass gold tokens into the decoder. 
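+# dropout_amount: how much dropout to apply (normally 0. at evaluation time).
+# beam_size: beam width to use during decoding.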
+# +# Outputs: +# information provided by prepare and predict +def eval_step(model, + example, + use_gold=False, + dropout_amount=0., + beam_size=1): + dy.renew_cg() + assert not model.prev_decoder_snippet_rep + return model.prepare_and_predict(example, + use_gold=use_gold, + training=False, + dropout=dropout_amount, + beam_size=beam_size) diff --git a/visualize_attention.py b/visualize_attention.py new file mode 100644 index 0000000..e072c7e --- /dev/null +++ b/visualize_attention.py @@ -0,0 +1,113 @@ +""" +Used for creating a graph of attention over a fixed number of logits over a +sequence. E.g., attention over an input sequence while generating an output +sequence. +""" +import matplotlib.pyplot as plt +import numpy as np +from pylab import rcParams + + +class AttentionGraph(): + """Creates a graph showing attention distributions for inputs and outputs. + + Attributes: + keys (list of str): keys over which attention is done during generation. + generated_values (list of str): keeps track of the generated values. + attentions (list of list of float): keeps track of the probability + distributions. + """ + + def __init__(self, keys): + """ + Initializes the attention graph. + + Args: + keys (list of string): a list of keys over which attention is done + during generation. + """ + if not keys: + raise ValueError("Expected nonempty keys for attention graph.") + + self.keys = keys + self.generated_values = [] + self.attentions = [] + + def add_attention(self, gen_value, probabilities): + """ + Adds attention scores for all item in `self.keys`. + + Args: + gen_value (string): a generated value for this timestep. + probabilities (np.array): probability distribution over the keys. Assumes + the order of probabilities corresponds to the order of the keys. + + Raises: + ValueError if `len(probabilities)` is not the same as `len(self.keys)` + ValueError if `sum(probabilities)` is not 1 + """ + if len(probabilities) != len(self.keys): + raise ValueError("Length of attention keys is " + + str(len(self.keys)) + + " but got probabilities of length " + + str(len(probabilities))) +# if sum(probabilities) != 1.0: +# raise ValueError("Probabilities sum to " + +# str(sum(probabilities)) + "; not 1.0") + + self.generated_values.append(gen_value) + self.attentions.append(probabilities) + + def render(self, filename): + """ + Renders the attention graph over timesteps. + + Args: + filename (string): filename to save the figure to. 
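+
+        Illustrative usage (the keys, values, and file name are arbitrary
+        examples, not taken from this patch):
+            graph = AttentionGraph(["show", "me", "flights"])
+            graph.add_attention("SELECT", np.array([0.1, 0.2, 0.7]))
+            graph.render("attention.png")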
+ """ + figure, axes = plt.subplots() + graph = np.stack(self.attentions) + + axes.imshow(graph, cmap=plt.cm.Blues, interpolation="nearest") + axes.xaxis.tick_top() + axes.set_xticks(range(len(self.keys))) + axes.set_xticklabels(self.keys) + plt.setp(axes.get_xticklabels(), rotation=90) + axes.set_yticks(range(len(self.generated_values))) + axes.set_yticklabels(self.generated_values) + axes.set_aspect(1, adjustable='box') +# axes.grid(b=True, color='w', linestyle='-',linewidth=2,which='minor') +# plt.minorticks_on() + plt.tick_params(axis='x',which='both',bottom='off',top='off') + plt.tick_params(axis='y',which='both',left='off',right='off') + + figure.savefig(filename) + + def render_as_latex(self, filename): + ofile = open(filename, "w") + + ofile.write("\\documentclass{article}\\usepackage[margin=0.5in]{geometry}\\usepackage{tikz}\\begin{document}\\begin{tikzpicture}[scale=0.25]\\begin{tiny}\\begin{scope}<+->;\n") + xstart = 0 + ystart = 0 + xend = len(self.keys) + yend = len(self.generated_values) + + ofile.write("\\draw[step=1cm,gray,very thin] (" + str(xstart) + "," + str(ystart) +") grid (" + str(xend) + ", " + str(yend) + ");\n") + + for i, tok in enumerate(self.keys): + tok = tok.replace("_", "\_") + tok = tok.replace("#", "\#") + ofile.write("\\draw[gray, xshift=" + str(i) + ".5 cm] (0,0.3) -- (0,0) node[below,rotate=90,anchor=east] {" + tok + "};\n") + + for i, tok in enumerate(self.generated_values[::-1]): + tok = tok.replace("_", "\_") + tok = tok.replace("#", "\#") + ofile.write("\\draw[gray, yshift=" + str(i) + ".5 cm] (0.3,0) -- (0,0) node[left] {" + tok + "};\n") + + for i, gentok_atts in enumerate(self.attentions[::-1]): + for j, val in enumerate(gentok_atts): + if val < 0.001: + val = 0 + ofile.write("\\filldraw[thin,red,opacity=" + "%.2f" % val + "] (" + str(j) + ", " + str(i) + ") rectangle (" + str(j+1)+ "," + str(i+1) + ");\n") + + ofile.write("\\end{scope}\\end{tiny}\\end{tikzpicture}{\end{document}") diff --git a/vocabulary.py b/vocabulary.py new file mode 100644 index 0000000..429c213 --- /dev/null +++ b/vocabulary.py @@ -0,0 +1,77 @@ +import operator +import os +import pickle + +# Special sequencing tokens. +UNK_TOK = "_UNK" # Replaces out-of-vocabulary words. +EOS_TOK = "_EOS" # Appended to the end of a sequence to indicate its end. +DEL_TOK = ";" + + +class Vocabulary: + def get_vocab(self, sequences, ignore_fn): + type_counts = {} + + for sequence in sequences: + for token in sequence: + if not ignore_fn(token): + if token not in type_counts: + type_counts[token] = 0 + type_counts[token] += 1 + + # Create sorted list of tokens, by their counts. Reverse so it is in order of + # most frequent to least frequent. + sorted_type_counts = sorted(sorted(type_counts.items()), + key=operator.itemgetter(1))[::-1] + + sorted_types = [typecount[0] + for typecount in sorted_type_counts if typecount[1] >= self.min_occur] + + # Append the necessary functional tokens. 
+        sorted_types = self.functional_types + sorted_types
+
+        # Cut off if max_size is set (nonnegative).
+        if self.max_size >= 0:
+            vocab = sorted_types[:min(self.max_size, len(sorted_types))]
+        else:
+            vocab = sorted_types
+
+        return vocab
+
+    def __init__(self,
+                 sequences,
+                 filename,
+                 functional_types=[],
+                 max_size=-1,
+                 min_occur=0,
+                 ignore_fn=lambda x: False):
+        self.functional_types = functional_types
+        self.max_size = max_size
+        self.min_occur = min_occur
+
+        vocab = self.get_vocab(sequences, ignore_fn)
+
+        self.id_to_token = []
+        self.token_to_id = {}
+
+        for i in range(len(vocab)):
+            self.id_to_token.append(vocab[i])
+            self.token_to_id[vocab[i]] = i
+
+        # Load the previous vocab, if it exists.
+        if os.path.exists(filename):
+            f = open(filename, 'rb')
+            loaded_vocab = pickle.load(f)
+            f.close()
+
+            print("Loaded vocabulary from " + str(filename))
+            if loaded_vocab.id_to_token != self.id_to_token or loaded_vocab.token_to_id != self.token_to_id:
+                print("Loaded vocabulary is different than generated vocabulary.")
+        else:
+            print("Writing vocabulary to " + str(filename))
+            f = open(filename, 'wb')
+            pickle.dump(self, f)
+            f.close()
+
+    def __len__(self):
+        return len(self.id_to_token)
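+
+# Illustrative usage (the sequences, file name, and min_occur value below are
+# assumptions, not part of this patch):
+#
+#     output_vocab = Vocabulary(gold_query_sequences,
+#                               "sql_vocab.pkl",
+#                               functional_types=[UNK_TOK, EOS_TOK, DEL_TOK],
+#                               min_occur=2)
+#     unk_id = output_vocab.token_to_id[UNK_TOK]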