Permalink
Browse files

Move babbling classes into a separate module.

  • Loading branch information...
1 parent 3f2b043 commit e1773d94613c2f763928c8ea1856b1e97c776396 Leif Johnson committed Mar 28, 2012
Showing with 89 additions and 69 deletions.
  1. +2 −2 lmj/trm/__init__.py
  2. +87 −0 lmj/trm/babbler.py
  3. +0 −67 lmj/trm/postures.py
View
@@ -21,5 +21,5 @@
'''A Python wrapper for the Tube Resonance Model from gnuspeech.'''
from tube import Parameters, TubeModel, parse_input_file, synthesize
-
-import synth
+from postures import Repertoire
+import babbler
View
@@ -0,0 +1,87 @@
+# Copyright (c) 2011 Leif Johnson <leif@leifjohnson.net>
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+'''Classes for babbling using the postures in a repertoire.'''
+
+
+class Babbler(list):
+ '''A babbler generates sequences of phone symbols from a repertoire.'''
+
+ def __init__(self, repertoire):
+ self.extend(sorted(repertoire.postures))
+
+ def generate(self, n=7):
+ for _ in range(n):
+ yield self[self.index()]
+
+ def index(self):
+ raise NotImplementedError
+
+
+class Uniform(Babbler):
+ '''A uniform babbler selects phones randomly.'''
+
+ def index(self):
+ return rng.randint(len(self))
+
+
+class Unigram(Babbler):
+ '''A unigram babbler samples from a discrete distribution.'''
+
+ def __init__(self, repertoire, pmf):
+ super(Unigram, self).__init__(repertoire)
+
+ pmf = numpy.asarray(pmf)
+ assert len(pmf) == len(self)
+ assert numpy.allclose(pmf.sum(), 1)
+ self.cdf = pmf.cumsum()
+
+ def index(self):
+ return self.cdf.searchsorted(rng.random())
+
+
+class Bigram(Babbler):
+ '''A bigram babbler samples from a discrete distribution for each phone.'''
+
+ def __init__(self, repertoire, unigrams, bigrams):
+ super(Bigram, self).__init__(repertoire)
+
+ unigrams = numpy.asarray(unigrams)
+ assert len(unigrams) == len(self)
+ assert numpy.allclose(unigrams.sum(), 1)
+ self.unigram_cdf = unigrams.cumsum()
+
+ bigrams = numpy.asarray(bigrams)
+ assert bigrams.shape == (len(self), len(self))
+ assert numpy.allclose(bigrams.sum(axis=1), numpy.ones(len(bigrams)))
+ self.bigram_cdfs = [p.cumsum() for p in bigrams]
+
+ self.idx = None
+
+ def generate(self, *args, **kwargs):
+ self.idx = None
+ return super(Bigram, self).generate(*args, **kwargs)
+
+ def index(self):
+ cdf = self.unigram_cdf
+ if self.idx is not None:
+ cdf = self.bigram_cdfs[self.idx]
+ self.idx = cdf.searchsorted(rng.random())
+ return self.idx
View
@@ -303,70 +303,3 @@ def interpolate(self, control_rate, *symbols):
scipy.interpolate.UnivariateSpline(times, p, k=3)
for p in numpy.array(postures).T]
return numpy.array([s(t) for s in interpolators]).T
-
-
-class Planner(list):
- '''A planner generates sequences of phone symbols from a repertoire.
- '''
-
- def __init__(self, repertoire):
- self.extend(sorted(repertoire.postures))
-
- def generate(self, n=7):
- for _ in range(n):
- yield self[self.index()]
-
- def index(self):
- raise NotImplementedError
-
-
-class Uniform(Planner):
- '''A uniform planner selects phones randomly.'''
-
- def index(self):
- return rng.randint(len(self))
-
-
-class Unigram(Planner):
- '''A unigram planner samples from a discrete distribution.'''
-
- def __init__(self, repertoire, pmf):
- super(Unigram, self).__init__(repertoire)
-
- pmf = numpy.asarray(pmf)
- assert len(pmf) == len(self)
- assert numpy.allclose(pmf.sum(), 1)
- self.cdf = pmf.cumsum()
-
- def index(self):
- return self.cdf.searchsorted(rng.random())
-
-
-class Bigram(Planner):
- '''A bigram planner samples from a discrete distribution for each phone.'''
-
- def __init__(self, repertoire, unigrams, bigrams):
- super(Bigram, self).__init__(repertoire)
-
- unigrams = numpy.asarray(unigrams)
- assert len(unigrams) == len(self)
- assert numpy.allclose(unigrams.sum(), 1)
- self.unigram_cdf = unigrams.cumsum()
-
- bigrams = numpy.asarray(bigrams)
- assert bigrams.shape == (len(self), len(self))
- assert numpy.allclose(bigrams.sum(axis=1), numpy.ones(len(bigrams)))
- self.bigram_cdfs = [p.cumsum() for p in bigrams]
-
- self.idx = None
-
- def generate(self, *args, **kwargs):
- self.idx = None
- return super(Bigram, self).generate(*args, **kwargs)
-
- def index(self):
- cdf = self.unigram_cdf
- if self.idx is not None:
- cdf = self.bigram_cdfs[self.idx]
- self.idx = cdf.searchsorted(rng.random())
- return self.idx

0 comments on commit e1773d9

Please sign in to comment.