46 changes: 25 additions & 21 deletions neuralcoref/train/conllparser.py
@@ -119,13 +119,13 @@ def load_file(full_name, debug=False):
load a *._conll file
Input: full_name: path to the file
Output: list of tuples for each conll doc in the file, where the tuple contains:
(utts_text ([str]): list of the utterances in the document
utts_tokens ([[str]]): list of the tokens (conll words) in the document
utts_corefs: list of coref objects (dicts) with the following properties:
coref['label']: id of the coreference cluster,
coref['start']: start index (index of first token in the utterance),
coref['end']: end index (index of last token in the utterance).
utts_speakers ([str]): list of the speakers associated with each utterance in the document
name (str): name of the document
part (str): part of the document
)
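For illustration, each coref object described above is a plain dict; the values below are hypothetical:

coref = {'label': 3, 'start': 0, 'end': 2}  # cluster 3 covers tokens 0 through 2 of its utterance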
@@ -377,11 +377,11 @@ def get_feature_array(self, doc_id, feature=None, compressed=True, debug=False):
mentions_features: (N, Fs)
mentions_labels: (N, 1)
mentions_pairs_start_index: (N, 1) index of beginning of pair list in pair_labels
mentions_pairs_length: (N, 1) number of pairs (i.e. nb of antecedents) for each mention
pairs_features: (P, Fp)
pairs_labels: (P, 1)
pairs_ant_idx: (P, 1) => indexes of the antecedent mention for each pair (mention index in doc)
"""
if not self.mentions:
if debug: print("No mention in this doc !")
return {}
@@ -552,7 +552,7 @@ def list_undetected_mentions(self, data_path, save_file, debug=True):
out_file.write(out_str)
if debug: print(out_str)

def read_corpus(self, data_path, debug=False):
def read_corpus(self, data_path, model=None, debug=False):
print("🌋 Reading files")
for dirpath, _, filenames in os.walk(data_path):
print("In", dirpath, os.path.abspath(dirpath))
@@ -595,20 +595,23 @@ def read_corpus(self, data_path, debug=False):
embedding_extractor=self.embed_extractor,
conll=CONLL_GENRES[name[:2]]))
print("🌋 Loading spacy model")
model_options = ['en_core_web_lg', 'en_core_web_md', 'en_core_web_sm', 'en']
model = None
for model_option in model_options:
if not model:
try:
spacy.info(model_option)
model = model_option
print("Loaded model", model_option)
except:
print("Could not detect model", model_option)
if not model:
print("Could not detect any suitable English model")
return

if model is None:
model_options = ['en_core_web_lg', 'en_core_web_md', 'en_core_web_sm', 'en']
for model_option in model_options:
if not model:
try:
spacy.info(model_option)
model = model_option
print("Loading model", model_option)
except:
print("Could not detect model", model_option)
if not model:
print("Could not detect any suitable English model")
return
else:
spacy.info(model)
print("Loading model", model)
nlp = spacy.load(model)
print("🌋 Parsing utterances and filling docs with use_gold_mentions=" + (str(bool(self.gold_mentions))))
doc_iter = (s for s in self.utts_text)
@@ -618,7 +621,7 @@ def read_corpus(self, data_path, debug=False):
spacy_tokens, conll_tokens, corefs, speaker, doc_id = utt_tuple
if debug: print(unicode_(self.docs_names[doc_id]), "-", spacy_tokens)
doc = spacy_tokens
if debug:
out_str = "utterance " + unicode_(doc) + " corefs " + unicode_(corefs) + \
" speaker " + unicode_(speaker) + " doc_id " + unicode_(doc_id)
print(out_str.encode('utf-8'))
@@ -722,6 +725,7 @@ def _vocabulary_to_file(path, vocabulary):
parser.add_argument('--n_jobs', type=int, default=1, help='Number of parallel jobs (default 1)')
parser.add_argument('--gold_mentions', type=int, default=0, help='Use gold mentions (1) or not (0, default)')
parser.add_argument('--blacklist', type=int, default=0, help='Use blacklist (1) or not (0, default)')
parser.add_argument('--spacy_model', type=str, default=None, help='spaCy model to load (default: auto-detect an installed English model)')
args = parser.parse_args()
if args.key is None:
args.key = args.path + "/key.txt"
@@ -738,7 +742,7 @@ def _vocabulary_to_file(path, vocabulary):
print(file)
os.remove(SAVE_DIR + file)
start_time = time.time()
CORPUS.read_corpus(args.path)
CORPUS.read_corpus(args.path, model=args.spacy_model)
print('=> read_corpus time elapsed', time.time() - start_time)
if not CORPUS.docs:
print("Could not parse any valid docs")
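With this change a spaCy model can be named explicitly through the new --spacy_model option instead of relying on the lg/md/sm/en fallback loop. A minimal sketch of the resulting Python-level call (the data path is hypothetical; CORPUS stands for the corpus object built in this script's __main__ block):

# Parse a CoNLL directory with an explicitly chosen spaCy model
CORPUS.read_corpus("path/to/conll-2012/train/", model="en_core_web_md")
# With model=None (the default), read_corpus probes en_core_web_lg, en_core_web_md,
# en_core_web_sm and en in that order and loads the first one it finds.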
45 changes: 38 additions & 7 deletions neuralcoref/train/dataset.py
@@ -28,6 +28,32 @@ def load_embeddings_from_file(name):
voc = [line.strip() for line in f]
return embed, voc


class _DictionaryDataLoader(object):
def __init__(self, dict_object, order):
self.dict_object = dict_object
self.order = order

def __len__(self):
return len(self.dict_object[self.order[0]])

def __getitem__(self, idx):
if isinstance(idx, slice):
data = []
for i in range(idx.start, idx.stop, idx.step if idx.step is not None else 1):
temp_data = []
for key in self.order:
temp_data.append(self.dict_object[key][i])
data.append(temp_data)

else:
data = []
for key in self.order:
data.append(self.dict_object[key][idx])

return data
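
Illustrative use of the helper above, as a standalone sketch with toy arrays:

import numpy as np

loader = _DictionaryDataLoader({"a": np.arange(10), "b": np.arange(10, 20)},
                               order=("a", "b"))
item = loader[3]     # [a[3], b[3]] -> values 3 and 13, one entry per key in `order`
batch = loader[1:3]  # [[a[1], b[1]], [a[2], b[2]]]
n = len(loader)      # 10, the length of the first array named in `order`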


class NCDataset(Dataset):
def __init__(self, data_path, params, no_targets=False):
print("🏝 Loading Dataset at", data_path)
@@ -44,13 +70,18 @@ def __init__(self, data_path, params, no_targets=False):
continue
numpy_files_found = True
print(file_name, end=', ')
datas[file_name.split(u'.')[0]] = np.load(data_path + file_name)
datas[file_name.split(u'.')[0]] = np.load(data_path + file_name, mmap_mode="r" if params.lazy else None)
if not numpy_files_found:
raise ValueError("Can't find numpy files in {}".format(data_path))

# Gather arrays in two lists of tuples for mention and pairs
self.mentions = list(zip(*(arr for key, arr in sorted(datas.items()) if key.startswith(u"mentions"))))
self.pairs = list(zip(*(arr for key, arr in sorted(datas.items()) if key.startswith(u"pairs"))))
if not params.lazy:
self.mentions = list(zip(*(arr for key, arr in sorted(datas.items()) if key.startswith(u"mentions"))))
self.pairs = list(zip(*(arr for key, arr in sorted(datas.items()) if key.startswith(u"pairs"))))
else:
self.mentions = _DictionaryDataLoader(datas, order=('mentions_features', 'mentions_labels', 'mentions_pairs_length', 'mentions_pairs_start_index', 'mentions_spans', 'mentions_words'))
self.pairs = _DictionaryDataLoader(datas, order=('pairs_ant_index', 'pairs_features', 'pairs_labels'))

self.mentions_pair_length = datas[FEATURES_NAMES[2]]
assert [arr.shape[0] for arr in self.mentions[0]] == [6, 1, 1, 1, 250, 8] # Cf order of FEATURES_NAMES in conllparser.py
assert [arr.shape[0] for arr in self.pairs[0]] == [1, 9, 1] # Cf order of FEATURES_NAMES in conllparser.py
@@ -148,7 +179,7 @@ def __getitem__(self, mention_idx, debug=False):
ant_features[:, 15] = ant_features_raw[:, 2].astype(float) / ant_features_raw[:, 3].astype(float)
ant_features[:, 16] = ant_features_raw[:, 4]
pairs_features[:, 29:46] = ant_features
# Here we keep the genre
ana_features = np.tile(features, (pairs_length, 1))
pairs_features[:, 46:] = ana_features

@@ -213,7 +244,7 @@ def __init__(self, mentions_pairs_length, batchsize=600,
shuffle=False, debug=False):
""" Create and feed batches of mentions having close number of antecedents
The batch are padded and collated by the padder_collate function

# Arguments:
mentions_pairs_length array of shape (N, 1): list/array of the number of pairs for each mention
batchsize: Number of pairs of each batch will be capped at this
@@ -232,7 +263,7 @@ def __init__(self, mentions_pairs_length, batchsize=600,
num = 0
for length, mention_idx in sorted_lengths:
if num > batchsize or (num == len(batch) and length != 0): # We keep the no_pairs batches pure
if debug: print("Added batch number", len(self.batches),
if debug: print("Added batch number", len(self.batches),
"with", len(batch), "mentions and", num, "pairs")
self.batches.append(batch)
self.batches_size.append(num) # We don't count the max 7 additional mentions that are repeated
@@ -281,7 +312,7 @@ def __len__(self):

def padder_collate(batch, debug=False):
""" Puts each data field into a tensor with outer dimension batch size
Pad variable length input tensors and add a weight tensor to the target
"""
transposed_inputs = tuple(zip(*batch))
if len(transposed_inputs) == 2:
4 changes: 2 additions & 2 deletions neuralcoref/train/learn.py
@@ -303,10 +303,10 @@ def run_epochs(start_epoch, end_epoch, loss_func, optim_func, save_name, lr, g_s
parser.add_argument('--min_lr', type=float, default=2e-8, help='min learning rate')
parser.add_argument('--on_eval_decrease', type=str, default='nothing',
help='What to do when evaluation decreases ("nothing", "divide_lr", "next_stage", "divide_then_next")')
parser.add_argument('--lazy', type=int, default=1, choices=(0, 1), help='Use lazy loading (1, default) or not (0) while loading the npy files')
args = parser.parse_args()

args.costs = {'FN': args.costfn, 'FL': args.costfl, 'WL' : args.costwl }

args.lazy = bool(args.lazy)
current_time = datetime.now().strftime('%b%d_%H-%M-%S')
args.save_path = os.path.join(PACKAGE_DIRECTORY, 'checkpoints', current_time + '_' + socket.gethostname() + '_')

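The new --lazy flag is consumed by NCDataset in dataset.py above, where it only toggles the mmap_mode argument of np.load. A minimal standalone sketch (the file name is hypothetical):

import numpy as np

lazy_arr = np.load("data/mentions_features.npy", mmap_mode="r")  # memory-mapped, pages read on demand
eager_arr = np.load("data/mentions_features.npy")                # fully loaded into RAM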
4 changes: 2 additions & 2 deletions neuralcoref/train/model.py
@@ -71,7 +71,7 @@ def forward(self, inputs, concat_axis=1):
else:
spans, words, single_features = inputs
words = words.type(torch.LongTensor)
if self.cuda:
if torch.cuda.is_available():
words = words.cuda()
embed_words = self.drop(self.word_embeds(words).view(words.size()[0], -1))
single_input = torch.cat([spans, embed_words, single_features], 1)
@@ -80,7 +80,7 @@ def forward(self, inputs, concat_axis=1):
batchsize, pairs_num, _ = ana_spans.size()
ant_words_long = ant_words.view(batchsize, -1).type(torch.LongTensor)
ana_words_long = ana_words.view(batchsize, -1).type(torch.LongTensor)
if self.cuda:
if torch.cuda.is_available():
ant_words_long = ant_words_long.cuda()
ana_words_long = ana_words_long.cuda()
ant_embed_words = self.drop(self.word_embeds(ant_words_long).view(batchsize, pairs_num, -1))
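The two hunks above replace a stored self.cuda flag with a runtime check before moving tensors to the GPU. A minimal standalone sketch of the same pattern:

import torch

words = torch.zeros(4, 8, dtype=torch.long)
if torch.cuda.is_available():
    words = words.cuda()  # move to GPU only when one is actually present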