Merge pull request #133 from customprogrammingsolutions/improve-logging
[MRG] Improve logging
oadams committed Mar 31, 2018
2 parents 4a92f4f + 8916d24 commit 8ca2bbb
Showing 11 changed files with 85 additions and 31 deletions.
14 changes: 14 additions & 0 deletions persephone/__init__.py
@@ -1 +1,15 @@
 __version__ = "0.2.0"
+
+import sys
+import logging
+
+def handle_unhandled_exception(exc_type, exc_value, exc_traceback):
+    """Handler for unhandled exceptions that will write to the logs"""
+    if issubclass(exc_type, KeyboardInterrupt):
+        # call the default excepthook saved at __excepthook__
+        sys.__excepthook__(exc_type, exc_value, exc_traceback)
+        return
+    logger = logging.getLogger(__name__) # type: ignore
+    logger.critical("Unhandled exception", exc_info=(exc_type, exc_value, exc_traceback))
+
+sys.excepthook = handle_unhandled_exception
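How the hook behaves, as a minimal standalone sketch (the basicConfig call is an assumption for illustration; persephone itself configures handlers through its logging.ini):

import logging
import sys

logging.basicConfig(filename="log.txt", level=logging.DEBUG)

def handle_unhandled_exception(exc_type, exc_value, exc_traceback):
    # Let Ctrl-C terminate through the default hook, unlogged.
    if issubclass(exc_type, KeyboardInterrupt):
        sys.__excepthook__(exc_type, exc_value, exc_traceback)
        return
    # Every other uncaught exception lands in the log with a full traceback.
    logging.getLogger(__name__).critical(
        "Unhandled exception", exc_info=(exc_type, exc_value, exc_traceback))

sys.excepthook = handle_unhandled_exception

raise ValueError("boom")  # written to log.txt by the hook before the process exits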
25 changes: 22 additions & 3 deletions persephone/corpus.py
Original file line number Diff line number Diff line change
@@ -24,7 +24,7 @@
 from .utterance import Utterance
 from .preprocess.labels import LabelSegmenter
 
-logging.config.fileConfig(config.LOGGING_INI_PATH)
+logger = logging.getLogger(__name__) # type: ignore
 
 CorpusT = TypeVar("CorpusT", bound="Corpus")
 
@@ -82,6 +82,9 @@ def __init__(self, feat_type, label_type, tgt_dir, labels,
         """
 
+        logger.debug("Creating a new Corpus object with feature type %s, label type %s,"
+                     "target directory %s, label set %s, max_samples %d, speakers %s",
+                     feat_type, label_type, tgt_dir, labels, max_samples, speakers)
         #: A string representing the type of speech feature (eg. "fbank"
         #: for log filterbank energies).
         self.feat_type = feat_type
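The %s placeholders in the logger.debug call above use the logging module's lazy formatting: interpolation only happens if the record passes the level check, whereas str.format() builds the string unconditionally. A small standalone sketch of the difference (names hypothetical):

import logging

logger = logging.getLogger("demo")
logger.setLevel(logging.INFO)  # DEBUG records will be discarded

big = list(range(1000000))

# Lazy: the million-element list is never rendered, the record is dropped first.
logger.debug("contents: %s", big)

# Eager: the huge string is built and then thrown away with the record.
logger.debug("contents: {}".format(big))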
@@ -91,21 +94,24 @@ def __init__(self, feat_type, label_type, tgt_dir, labels,
         self.label_type = label_type
 
         # Setting up directories
+        logger.debug("Setting up directories for this Corpus object at %s", tgt_dir)
         self.set_and_check_directories(tgt_dir)
 
         # Label-related stuff
         self.initialize_labels(labels)
-        logging.info("Corpus label set: \n\t{}".format(labels))
+        logger.info("Corpus label set: \n\t{}".format(labels))
 
         # This is a lazy function that assumes wavs are already in the WAV dir
         # but only creates features if necessary
+        logger.debug("Preparing features")
         self.prepare_feats()
         self._num_feats = None
 
         # This is also lazy if the {train,valid,test}_prefixes.txt files exist.
         self.make_data_splits(max_samples=max_samples)
 
         # Sort the training prefixes by size for more efficient training
+        logger.debug("Training prefixes")
         self.train_prefixes = utils.sort_by_size(
             self.feat_dir, self.train_prefixes, feat_type)
 
@@ -217,6 +223,7 @@ def test_prefix_fn(self) -> Path:
 
     def set_and_check_directories(self, tgt_dir: Path) -> None:
 
+        logger.info("Setting up directories for corpus in %s", tgt_dir)
         # Set the directory names
         self.tgt_dir = tgt_dir
         self.feat_dir = self.get_feat_dir()
@@ -450,9 +457,10 @@ def review(self):
             print("Transcription: {}".format(transcript))
             subprocess.run(["play", str(wav_fn)])
 
-    def ensure_no_set_overlap(self):
+    def ensure_no_set_overlap(self) -> None:
         """ Ensures no test set data has creeped into the training set."""
 
+        logger.debug("Ensuring that the training, validation and test data sets have no overlap")
         train = set(self.get_train_fns()[0])
         valid = set(self.get_valid_fns()[0])
         test = set(self.get_test_fns()[0])
@@ -463,16 +471,25 @@ def ensure_no_set_overlap(self):
         assert test - train == test
         assert test - valid == test
 
+        if train & valid:
+            logger.warning("train and valid have overlapping items: {}".format(train & valid))
+        if train & test:
+            logger.warning("train and test have overlapping items: {}".format(train & test))
+        if valid & test:
+            logger.warning("valid and test have overlapping items: {}".format(valid & test))
 
     def pickle(self):
         """ Pickles the Corpus object in a file in tgt_dir. """
 
         pickle_path = self.tgt_dir / "corpus.p"
+        logger.debug("pickling %r object and saving it to path %s", self, pickle_path)
         with pickle_path.open("wb") as f:
             pickle.dump(self, f)
 
     @classmethod
     def from_pickle(cls: Type[CorpusT], tgt_dir: Path) -> CorpusT:
         pickle_path = tgt_dir / "corpus.p"
+        logger.debug("Creating Corpus object from pickle file path %s", pickle_path)
         with pickle_path.open("rb") as f:
             return pickle.load(f)
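pickle() and from_pickle() above are a standard pathlib-plus-pickle round trip; the same pattern in a minimal standalone sketch (plain dict instead of a Corpus, path hypothetical):

import pickle
from pathlib import Path

obj = {"train": ["utt1", "utt2"], "valid": ["utt3"]}
pickle_path = Path("/tmp") / "corpus.p"

with pickle_path.open("wb") as f:   # what pickle() does with tgt_dir / "corpus.p"
    pickle.dump(obj, f)

with pickle_path.open("rb") as f:   # what from_pickle() does
    assert pickle.load(f) == obj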

@@ -490,6 +507,7 @@ def __init__(self, tgt_dir, feat_type="fbank", label_type="phonemes"):
     @staticmethod
     def determine_labels(tgt_dir, label_type):
         """ Returns a set of phonemes found in the corpus. """
+        logger.info("Finding phonemes of type %s in directory %s", label_type, tgt_dir)
 
         label_dir = os.path.join(tgt_dir, "label/")
         if not os.path.isdir(label_dir):
@@ -503,6 +521,7 @@ def determine_labels(tgt_dir, label_type):
             try:
                 line_phonemes = set(f.readline().split())
             except UnicodeDecodeError:
+                logger.error("Unicode decode error on file %s", fn)
                 print("Unicode decode error on file {}".format(fn))
                 raise
             phonemes = phonemes.union(line_phonemes)
10 changes: 6 additions & 4 deletions persephone/corpus_reader.py
@@ -12,7 +12,7 @@
 from . import utils
 from .exceptions import PersephoneException
 
-logging.config.fileConfig(config.LOGGING_INI_PATH)
+logger = logging.getLogger(__name__) # type: ignore
 
 class CorpusReader:
     """ Interfaces to the preprocessed corpora to read in train, valid, and
@@ -37,6 +37,7 @@ def __init__(self, corpus, num_train=None, batch_size=None, max_samples=None, rand_seed=0):
         self.corpus = corpus
 
         if max_samples:
+            logger.critical("max_samples not yet implemented in CorpusReader")
             raise NotImplementedError("Not yet implemented.")
 
         if not num_train:
@@ -46,6 +47,9 @@ def __init__(self, corpus, num_train=None, batch_size=None, max_samples=None, rand_seed=0):
         num_batches = int(num_train / batch_size)
         num_train = num_batches * batch_size
         self.num_train = num_train
+        logger.info("Number of training utterances: {}".format(num_train))
+        logger.info("Batch size: {}".format(batch_size))
+        logger.info("Batches per epoch: {}".format(int(num_train/batch_size)))
         print("Number of training utterances: {}".format(num_train))
         print("Batch size: {}".format(batch_size))
         print("Batches per epoch: {}".format(int(num_train/batch_size)))
@@ -122,7 +126,7 @@ def train_batch_gen(self):
             random.shuffle(fn_batches)
 
         for fn_batch in fn_batches:
-            logging.debug("Batch of training filenames: " +
+            logger.debug("Batch of training filenames: " +
                           pprint.pformat(fn_batch))
             yield self.load_batch(fn_batch)
 
@@ -179,8 +183,6 @@ def human_readable(self, dense_repr):
 
         return transcripts
 
-    #def __init__(self, corpus, num_train=None, batch_size=None, max_samples=None, rand_seed=0):
-
     def __repr__(self):
         return ("%s(" % self.__class__.__name__ +
                 "num_train=%s,\n" % repr(self.num_train) +
4 changes: 2 additions & 2 deletions persephone/logging.ini
@@ -16,6 +16,6 @@ handlers=hand01
 
 [handler_hand01]
 class=FileHandler
-level=NOTSET
+level=DEBUG
 formatter=simpleFormatter
-args=("log.txt", "w")
+args=("log.txt", "a")
15 changes: 9 additions & 6 deletions persephone/model.py
@@ -3,7 +3,6 @@
 import inspect
 import itertools
 import logging
-import logging.config
 import os
 import subprocess
 import sys
@@ -21,7 +20,7 @@
 allow_growth_config = tf.ConfigProto(log_device_placement=False)
 allow_growth_config.gpu_options.allow_growth=True #pylint: disable=no-member
 
-logging.config.fileConfig(config.LOGGING_INI_PATH)
+logger = logging.getLogger(__name__) # type: ignore
 
 class Model:
     """ Generic model for our ASR tasks. """
@@ -187,8 +186,8 @@ def output_lattices(self, batch, restore_model_path=None):
                         prefix + ".projection.bin", prefix + ".rmepsilon.bin"]
             subprocess.run(run_args)
         except FileNotFoundError:
-            print("Make sure you have OpenFST binaries installed and "
-                  "available on the path")
+            logger.error("Make sure you have OpenFST binaries installed and "
+                         "available on the path")
             raise
 
     def eval(self, restore_model_path=None):
@@ -197,9 +196,11 @@
         saver = tf.train.Saver()
         with tf.Session(config=allow_growth_config) as sess:
             if restore_model_path:
+                logger.info("restoring model from %s", restore_model_path)
                 saver.restore(sess, restore_model_path)
             else:
                 assert self.saved_model_path
+                logger.info("restoring model from %s", self.saved_model_path)
                 saver.restore(sess, self.saved_model_path)
 
             test_x, test_x_lens, test_y = self.corpus_reader.test_batch()
@@ -242,7 +243,7 @@ def train(self, early_stopping_steps=10, min_epochs=30, max_valid_ler=1.0,
             save_n: Whether to save the model at every n epochs.
             restore_model_path: The path to restore a model from.
         """
-
+        logger.info("Training model")
         best_valid_ler = 2.0
         steps_since_last_record = 0
 
@@ -271,6 +272,7 @@ def train(self, early_stopping_steps=10, min_epochs=30, max_valid_ler=1.0,
         sess = tf.Session(config=allow_growth_config)
 
         if restore_model_path:
+            logger.info("Restoring model from path %s", restore_model_path)
             saver.restore(sess, restore_model_path)
         else:
             sess.run(tf.global_variables_initializer())
@@ -311,9 +313,10 @@ def train(self, early_stopping_steps=10, min_epochs=30, max_valid_ler=1.0,
                     [self.ler, self.dense_decoded, self.dense_ref],
                     feed_dict=feed_dict)
             except tf.errors.ResourceExhaustedError:
-                print("Ran out of memory allocating a batch:")
                 import pprint
+                print("Ran out of memory allocating a batch:")
                 pprint.pprint(feed_dict)
+                logger.critical("Ran out of memory allocating a batch: %s", pprint.pformat(feed_dict))
                 raise
             hyps, refs = self.corpus_reader.human_readable_hyp_ref(
                 dense_decoded, dense_ref)
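pprint.pformat renders the whole feed_dict as one string so the structure survives in a single log record; the same pattern standalone (data hypothetical):

import logging
import pprint

logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)

feed_dict = {"x": [[0.1, 0.2], [0.3, 0.4]], "x_lens": [2, 2]}
try:
    raise MemoryError("simulated OOM")
except MemoryError:
    # One record containing the pretty-printed batch, then re-raise as above.
    logger.critical("Ran out of memory allocating a batch: %s",
                    pprint.pformat(feed_dict))
    raise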
3 changes: 2 additions & 1 deletion persephone/preprocess/elan.py
@@ -8,6 +8,7 @@
 
 from ..utterance import Utterance
 
+logger = logging.getLogger(__name__) # type: ignore
 
 class Eaf(pympi.Elan.Eaf):
     """ This subclass exists because eaf MEDIA_DESCRIPTOR elements typically
@@ -130,7 +131,7 @@ def utterances_from_dir(eaf_dir: Path,
     """
 
-    logging.info(
+    logger.info(
         "EAF from directory: {}, searching with tier_prefixes {}".format(
             eaf_dir, tier_prefixes))
 
12 changes: 10 additions & 2 deletions persephone/preprocess/feat_extract.py
@@ -1,5 +1,5 @@
 """ Performs feature extraction of WAV files for acoustic modelling."""
-
+import logging
 import os
 from pathlib import Path
 import subprocess
@@ -11,6 +11,8 @@
 from .. import config
 from ..exceptions import PersephoneException
 
+logger = logging.getLogger(__name__) #type: ignore
+
 def extract_energy(rate, sig):
     """ Extracts the energy of frames. """
 
@@ -74,6 +76,7 @@ def flatten(feats_3d):
     elif len(fbanks.shape) == 2:
         pass
     else:
+        logger.error("Invalid fbank array shape %s", (str(fbanks.shape)))
         raise PersephoneException("Invalid fbank array shape %s" % (str(fbanks.shape)))
 
     diff = len(fbanks) - len(pitches)
@@ -86,6 +89,7 @@
     # features goes anyway). But I'm currently keeping it this way for
     # experimental consistency.
     if diff > 2:
+        logger.warning("Excessive difference in number of frames. %d", diff)
         raise PersephoneException("Excessive difference in number of frames. %d" % diff)
     elif diff > 0:
         pitches = np.concatenate((np.array([[0,0]]*(len(fbanks) - len(pitches))), pitches))
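The concatenate call prepends diff rows of [0, 0] so the pitch matrix lines up frame-for-frame with the filterbank matrix; a small numpy sketch of that padding:

import numpy as np

fbanks = np.zeros((5, 40))              # 5 frames of 40 filterbank coefficients
pitches = np.array([[0.9, 110.0]] * 3)  # only 3 frames of pitch features

pad = np.array([[0, 0]] * (len(fbanks) - len(pitches)))
pitches = np.concatenate((pad, pitches))
assert pitches.shape == (5, 2)          # zero-padded at the front to 5 frames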
@@ -129,7 +133,7 @@ def all_wavs_processed():
 
     # Then apply file-wise feature extraction
     for filename in os.listdir(dirname):
-        print("Preparing %s features for %s" % (feat_type, filename))
+        logger.info("Preparing %s features for %s", feat_type, filename)
         path = os.path.join(dirname, filename)
         if path.endswith(".wav"):
             if feat_type == "fbank":
@@ -144,6 +148,7 @@ def all_wavs_processed():
             elif feat_type == "mfcc13_d":
                 mfcc(path)
             else:
+                logger.warning("Feature type not found: %s", feat_type)
                 raise PersephoneException("Feature type not found: %s" % feat_type)
 
 def convert_wav(org_wav_fn: Path, tgt_wav_fn: Path) -> None:
@@ -155,6 +160,7 @@ def convert_wav(org_wav_fn: Path, tgt_wav_fn: Path) -> None:
 def kaldi_pitch(wav_dir, feat_dir):
     """ Extract Kaldi pitch features. Assumes 16k mono wav files."""
 
+    logger.debug("Make wav.scp and pitch.scp files")
     # Make wav.scp and pitch.scp files
     prefixes = []
     for fn in os.listdir(wav_dir):
@@ -165,11 +171,13 @@ def kaldi_pitch(wav_dir, feat_dir):
     wav_scp_path = os.path.join(feat_dir, "wavs.scp")
     with open(wav_scp_path, "w") as wav_scp:
         for prefix in prefixes:
+            logger.info("Writing wav file: %s", os.path.join(wav_dir, prefix + ".wav"))
             print(prefix, os.path.join(wav_dir, prefix + ".wav"), file=wav_scp)
 
     pitch_scp_path = os.path.join(feat_dir, "pitch_feats.scp")
     with open(pitch_scp_path, "w") as pitch_scp:
         for prefix in prefixes:
+            logger.info("Writing scp file: %s", os.path.join(feat_dir, prefix + ".pitch.txt"))
             print(prefix, os.path.join(feat_dir, prefix + ".pitch.txt"), file=pitch_scp)
 
     # Call Kaldi pitch feat extraction
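Each print above emits one Kaldi scp entry of the form "<utterance-prefix> <path>", so wavs.scp ends up looking like this (prefixes hypothetical):

utt001 /data/wav/utt001.wav
utt002 /data/wav/utt002.wav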
9 changes: 6 additions & 3 deletions persephone/preprocess/pangloss.py
@@ -1,7 +1,10 @@
 """ Some functions to interface with the Pangloss """
+import logging
 
 from xml.etree import ElementTree
 
+logger = logging.getLogger(__name__) #type: ignore
+
 def get_sents_times_and_translations(xml_fn):
     """ Given an XML filename, loads the transcriptions, their start/end times,
     and translations. """
@@ -24,7 +27,7 @@ def get_sents_times_and_translations(xml_fn):
         else:
             transcription = child.find("FORM").text
         audio_info = child.find("AUDIO")
-        if audio_info != None:
+        if audio_info is not None:
             start_time = float(audio_info.attrib["start"])
             end_time = float(audio_info.attrib["end"])
             time = (start_time, end_time)
@@ -34,8 +37,8 @@ def get_sents_times_and_translations(xml_fn):
             translations.append(translation)
 
     return root.tag, transcriptions, times, translations
-    print(root.tag)
-    assert False
+    logger.critical('the root tag, %s, does not contain "WORDLIST", and is not "TEXT"', root.tag)
+    assert False, root.tag
 
 def remove_content_in_brackets(sentence, brackets="[]"):
     out_sentence = ''
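For reference, a minimal standalone sketch of the FORM/AUDIO access pattern used in get_sents_times_and_translations, on a hypothetical document (the real Pangloss element layout is assumed here for illustration):

from xml.etree import ElementTree

xml = """
<TEXT>
  <S>
    <FORM>a sample transcription</FORM>
    <AUDIO start="0.00" end="1.52"/>
  </S>
</TEXT>
"""

root = ElementTree.fromstring(xml)
for child in root:
    transcription = child.find("FORM").text
    audio_info = child.find("AUDIO")
    if audio_info is not None:
        time = (float(audio_info.attrib["start"]), float(audio_info.attrib["end"]))
        print(transcription, time)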
