-
Notifications
You must be signed in to change notification settings - Fork 0
/
PennToPCFG.py
213 lines (170 loc) · 8.17 KB
/
PennToPCFG.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
#!/usr/bin/env python
# -*- coding: utf-8 -*-
__author__ = 'Johannes Gontrum <gontrum@uni-potsdam.de>'
# Learns an unlexicalized PCFG from a Penn Treebank-like corpus (e.g. the Wall Street Journal sections).
# Features:
# * Has the ability to create a file containing the sentences from the given treebank file.
# * Can also create a file containing the trees from the treebank - one tree per line.
# * The sets of terminal symbols and non-terminal symbols in the grammar will be distinct.
#   If a tree has a 'rule' 'NNP -> NNP' where the 'NNP' on the rhs is a terminal symbol in the
#   unlexicalized grammar, it will be replaced by the rule '_NNP_ -> NNP'.
###### Imports #########################################################################
import argparse
from nltk.corpus import BracketParseCorpusReader
import nltk
import sys
from os import path
###### Variables #########################################################################
start_symbol = "" #< Start symbol of the grammar (string); set to the root label of the first tree read.
terminals = set() #< Set of all terminal symbols (leaf strings) found in the treebank.
nonterminals = set() #< Set of all nonterminal symbols (node label strings added by findSymbolsInTree).
pos = set() #< Set of all POS tags. A POS tag is a direct parent of a leaf in the tree.
symbolMap = dict() #< Maps ambiguous nonterminals (used both as POS tag and phrase label) to a unique replacement.
alteredTrees = list() #< NOTE(review): never used in this script - candidate for removal.
productions = dict() #< Maps each production (nltk.grammar.Production) to its occurrence count.
ntCounter = dict() #< Maps each lhs nonterminal to the total count of productions it heads.
grammar = list() #< NOTE(review): never used in this script - candidate for removal.
rhs_set = set() #< Collects the rhs tuple of every production written to the grammar file.
sentences = list() #< NOTE(review): never used in this script - candidate for removal.
###### Functions #########################################################################
## Sets up the command line options
def createArgParser():
    """Build and return the argparse parser for this script's command-line interface.

    All options are optional; the main block decides which modes to run
    based on which of them were supplied.
    """
    cli = argparse.ArgumentParser(
        description='Learns an unlexicalised PCFG from a Penn Treebank file')
    # Input treebank and grammar output (both required for grammar learning mode).
    cli.add_argument("-p", "--penn", required=False,
                     help="The Penn Treebank file.")
    cli.add_argument("-g", "--grammar", required=False, type=argparse.FileType('w'),
                     help="File to write the PCFG to.")
    # Optional separate treebank for producing evaluation data.
    cli.add_argument("-pe", "--pennEval", required=False,
                     help="The Penn Treebank file that is used to read the sentences and the trees from. If not specified it uses the file to create the grammar from.")
    cli.add_argument("-s", "--sentences", required=False, type=argparse.FileType('w'),
                     help="File to write the sentences to.")
    cli.add_argument("-t", "--trees", required=False, type=argparse.FileType('w'),
                     help="File to write the trees to.")
    cli.add_argument("-l", "--length", required=False, type=int, default=30,
                     help="Maximum length of the sentences for the evaluation (default=30)")
    # Debinarize mode takes two positional values: input file and output file.
    cli.add_argument("-b", "--debinarize", required=False, nargs=2,
                     help="Saves the trees from the input file as unbinarized trees in the output file.")
    return cli
## Split symbols in terminals and nonterminals
def findSymbolsInTree (tree):
    """Recursively collect the symbols of *tree* into the module-level sets.

    Node labels go into ``nonterminals``, leaf strings into ``terminals``,
    and labels of preterminals (direct parents of a leaf) into ``pos``.

    NOTE(review): uses the NLTK 2.x Tree API (``tree.node``); NLTK 3
    renamed this accessor to ``tree.label()`` - confirm the targeted
    NLTK version before upgrading.
    """
    # Recursive case: an interior tree node.
    # (isinstance instead of `type(...) ==` also accepts Tree subclasses.)
    if isinstance(tree, nltk.Tree):
        nonterminals.add(tree.node)
        # A node with exactly one child and height 2 sits directly above
        # a leaf, so its label is a POS tag.
        if len(tree) == 1 and tree.height() == 2:
            pos.add(tree.node)
        for subtree in tree:
            findSymbolsInTree(subtree) # Go on recursively
    # Base case: a leaf, i.e. a plain token string.
    else:
        terminals.add(tree)
## Replace ambiguous symbols
def replaceSymbolsInTree (tree, sent):
    """Walk *tree* in place, renaming ambiguous labels and unlexicalizing leaves.

    Node labels found in the module-level ``symbolMap`` are replaced by their
    unique alternative. Every leaf word is replaced by its parent's POS tag
    (with any ``_..._`` disambiguation markers stripped via ``revertPOS``),
    and the replacement tokens are appended to *sent* in left-to-right order.

    NOTE(review): uses the NLTK 2.x Tree API (``tree.node``); NLTK 3
    renamed this accessor to ``tree.label()``.
    """
    if tree.node in symbolMap:
        tree.node = symbolMap[tree.node]
    # Iterate over all children; we need the index to assign back into the tree.
    for i, child in enumerate(tree):
        # Case: the child is a leaf (a plain string).
        if not isinstance(child, nltk.Tree):
            # Unlexicalize: replace the word by its POS tag.
            tree[i] = revertPOS(tree.node)
            sent.append(tree[i])
        # Recursive case: descend into the subtree.
        else:
            replaceSymbolsInTree(child, sent)
## Turns a _A_ symbol back to A
def revertPOS(symbol):
    """Return *symbol* with its first and last character removed.

    Used to strip the ``_..._`` disambiguation markers that were added
    to make POS tags distinct from phrase labels.
    """
    without_first = symbol[1:]
    return without_first[:-1]
###### Main #########################################################################
if __name__ == '__main__':
    clArgs = createArgParser().parse_args()
    # Tracks whether any mode of operation ran; if none, print usage at the end.
    active = False
    if clArgs.penn != None and clArgs.grammar != None:
        active = True
        ## Set up the treebank reader over the single input file.
        ptb = BracketParseCorpusReader(path.dirname(clArgs.penn), [path.basename(clArgs.penn)])
        ## First pass: collect all terminal and nonterminal symbols.
        # NOTE(review): tree.node is the NLTK 2.x API; NLTK 3 renamed it to tree.label().
        for tree in ptb.parsed_sents(ptb.fileids()[0]):
            # Also set the start symbol to the root of the first tree.
            if len(start_symbol) == 0:
                start_symbol = tree.node
            findSymbolsInTree(tree)
        ## Find ambiguous symbols (used both as POS tag and phrase label)
        ## and map each one to a unique alternative of the form _SYMBOL_.
        for symbol in nonterminals.intersection(pos):
            replacement = "_" + symbol + "_"
            symbolMap[symbol] = replacement
            # Give up if the replacement itself already occurs in the grammar.
            if replacement in pos or replacement in nonterminals:
                print "Cannot make nonterminal unambiguous: ", symbol
                sys.exit(-1)
        ## Second pass: transform each tree and count its productions.
        for tree in ptb.parsed_sents(ptb.fileids()[0]):
            newTree = tree.copy(True)  # work on a copy so the original parse stays intact
            # Remove unary rules and convert to CNF (horizontal Markovization of order 2).
            newTree.chomsky_normal_form(horzMarkov=2)
            newTree.collapse_unary(collapsePOS=False)
            replaceSymbolsInTree(newTree, [])
            # newTree.draw()
            ## Count production occurrences.
            for production in newTree.productions():
                # NOTE(review): an empty rhs is rewritten as the lexical rule
                # LHS -> revertPOS(LHS); presumably this restores preterminal
                # rules after unlexicalization - confirm against the NLTK version used.
                if len(production.rhs()) == 0:
                    production = nltk.grammar.Production(production.lhs(), [revertPOS(production.lhs())])
                # Update the per-lhs symbol counter and the per-rule counter.
                if production.lhs() in ntCounter:
                    ntCounter[production.lhs()] += 1
                else:
                    ntCounter[production.lhs()] = 1
                if production in productions:
                    productions[production] += 1
                else:
                    productions[production] = 1
        # The start symbol may itself have been renamed during disambiguation.
        if start_symbol in symbolMap:
            start_symbol = symbolMap[start_symbol]
        ## Write the PCFG: first line is the start symbol, then one weighted rule per line.
        clArgs.grammar.write(start_symbol + "\n")
        for prod in productions:
            rhs_set.add(prod.rhs())
            ret = "{0} -> ".format(prod.lhs())
            for sym in prod.rhs():
                ret += ("{0} ".format(sym))
            # Relative-frequency estimate: count(rule) / count(lhs).
            ret += ("[" + str(productions[prod] / float(ntCounter[prod.lhs()])) + "]")
            clArgs.grammar.write(ret + "\n")
        ## Check which optional evaluation outputs were requested.
        evalSentences = False
        evalTrees = False
        if clArgs.sentences != None:
            evalSentences = True
        if clArgs.trees != None:
            evalTrees = True
        ## Go on creating evaluation data (sentences and/or gold trees).
        if evalSentences or evalTrees:
            # Default to the training treebank unless a separate eval file was given.
            evalFile = clArgs.penn
            if clArgs.pennEval:
                evalFile = clArgs.pennEval
            ptb_eval = BracketParseCorpusReader(path.dirname(evalFile), [path.basename(evalFile)])
            for tree in ptb_eval.parsed_sents(ptb_eval.fileids()[0]):
                newTree = tree.copy(True)
                # Transform the trees for evaluation in exactly the same way as done
                # for the grammar (binarization intentionally left disabled here).
                # newTree.chomsky_normal_form(horzMarkov=2)
                newTree.collapse_unary(collapsePOS=False)
                sent = []
                replaceSymbolsInTree(newTree, sent)
                # Skip sentences longer than the -l/--length limit.
                if len(sent) <= clArgs.length:
                    if evalSentences:
                        clArgs.sentences.write(" ".join(sent) + "\n")
                    if evalTrees:
                        clArgs.trees.write(newTree._pprint_flat('', "()", False) + "\n")
    ## Debinarize mode: read one binarized tree per line and undo the CNF transform.
    if clArgs.debinarize != None:
        active = True
        bintreefile = open(clArgs.debinarize[0], "r")
        outtreefile = open(clArgs.debinarize[1], "w")
        for stringtree in bintreefile:
            tree = nltk.Tree(stringtree)
            tree.un_chomsky_normal_form()
            outtreefile.write(tree._pprint_flat('', "()", False) + "\n")
        bintreefile.close()
        outtreefile.close()
    # No recognised combination of arguments: show the usage message.
    if not active:
        createArgParser().print_usage()