|
| 1 | +#!/usr/bin/env python3 |
| 2 | + |
| 3 | +# Copyright 2017 Johns Hopkins University (Author: Daniel Povey) |
| 4 | +# 2017 Johns Hopkins University (Author: Daniel Garcia-Romero) |
| 5 | +# 2017 David Snyder |
| 6 | +# Apache 2.0 |
| 7 | + |
| 8 | +# This script, which is used in getting training examples, decides |
| 9 | +# which examples will come from which recordings, and at what point |
| 10 | +# during the training. |
| 11 | + |
| 12 | +# You call it as (e.g.) |
| 13 | +# |
| 14 | +# allocate_egs.py --min-frames-per-chunk=50 --max-frames-per-chunk=200 \ |
| 15 | +# --frames-per-iter=1000000 --num-repeats=60 --num-archives=169 \ |
| 16 | +# --num-jobs=24 exp/xvector_a/egs/temp/utt2len.train exp/xvector_a/egs |
| 17 | +# |
| 18 | +# The program outputs certain things to the temp directory (e.g., |
| 19 | +# exp/xvector_a/egs/temp) that will enable you to dump the chunks for xvector |
| 20 | +# training. What we'll eventually be doing is invoking the following program |
| 21 | +# with something like the following args: |
| 22 | +# |
| 23 | +# nnet3-xvector-get-egs [options] exp/xvector_a/temp/ranges.1 \ |
| 24 | +# scp:data/train/feats.scp ark:exp/xvector_a/egs/egs_temp.1.ark \ |
| 25 | +# ark:exp/xvector_a/egs/egs_temp.2.ark ark:exp/xvector_a/egs/egs_temp.3.ark |
| 26 | +# |
| 27 | +# where exp/xvector_a/temp/ranges.1 contains something like the following: |
| 28 | +# |
| 29 | +# utt1 0 1 0 65 0 |
| 30 | +# utt1 6 7 160 50 0 |
| 31 | +# utt2 ... |
| 32 | +# |
| 33 | +# where each line is interpreted as follows: |
| 34 | +# <source-utterance> <relative-archive-index> <absolute-archive-index> \ |
| 35 | +# <start-frame-index> <num-frames> <spkr-label> |
| 36 | +# |
| 37 | +# Note: <relative-archive-index> is the zero-based offset of the archive-index |
| 38 | +# within the subset of archives that a particular ranges file corresponds to; |
| 39 | +# and <absolute-archive-index> is the 1-based numeric index of the destination |
| 40 | +# archive among the entire list of archives, which will form part of the |
| 41 | +# archive's filename (e.g. egs/egs.<absolute-archive-index>.ark); |
| 42 | +# <absolute-archive-index> is only kept for debug purposes so you can see which |
| 43 | +# archive each line corresponds to. |
| 44 | +# |
| 45 | +# For each line of the ranges file, we specify an eg containing a chunk of data |
| 46 | +# from a given utterane, the corresponding speaker label, and the output |
| 47 | +# archive. The list of archives corresponding to ranges.n will be written to |
| 48 | +# output.n, so in exp/xvector_a/temp/outputs.1 we'd have: |
| 49 | +# |
| 50 | +# ark:exp/xvector_a/egs/egs_temp.1.ark ark:exp/xvector_a/egs/egs_temp.2.ark \ |
| 51 | +# ark:exp/xvector_a/egs/egs_temp.3.ark |
| 52 | +# |
| 53 | +# The number of these files will equal 'num-jobs'. If you add up the |
| 54 | +# word-counts of all the outputs.* files you'll get 'num-archives'. The number |
| 55 | +# of frames in each archive will be about the --frames-per-iter. |
| 56 | +# |
| 57 | +# This program will also output to the temp directory a file called |
| 58 | +# archive_chunk_length which tells you the frame-length associated with |
| 59 | +# each archive, e.g., |
| 60 | +# 1 60 |
| 61 | +# 2 120 |
| 62 | +# the format is: <archive-index> <num-frames>. The <num-frames> will always |
| 63 | +# be in the range [min-frames-per-chunk, max-frames-per-chunk]. |
| 64 | + |
| 65 | + |
| 66 | +# We're using python 3.x style print but want it to work in python 2.x. |
| 67 | +from __future__ import print_function |
| 68 | +import re, os, argparse, sys, math, warnings, random |
| 69 | + |
| 70 | +def get_args(): |
| 71 | + parser = argparse.ArgumentParser(description="Writes ranges.*, outputs.* and archive_chunk_lengths files " |
| 72 | + "in preparation for dumping egs for xvector training.", |
| 73 | + epilog="Called by sid/nnet3/xvector/get_egs.sh") |
| 74 | + parser.add_argument("--prefix", type=str, default="", |
| 75 | + help="Adds a prefix to the output files. This is used to distinguish between the train " |
| 76 | + "and diagnostic files.") |
| 77 | + parser.add_argument("--num-repeats", type=int, default=10, help="Number of times each speaker repeats within an archive.") |
| 78 | + parser.add_argument("--min-frames-per-chunk", type=int, default=50, |
| 79 | + help="Minimum number of frames-per-chunk used for any archive") |
| 80 | + parser.add_argument("--max-frames-per-chunk", type=int, default=300, |
| 81 | + help="Maximum number of frames-per-chunk used for any archive") |
| 82 | + parser.add_argument("--randomize-chunk-length", type=str, |
| 83 | + help="If true, randomly pick a chunk length in [min-frames-per-chunk, max-frames-per-chunk]." |
| 84 | + "If false, the chunk length varies from min-frames-per-chunk to max-frames-per-chunk" |
| 85 | + "according to a geometric sequence.", |
| 86 | + default="true", choices = ["false", "true"]) |
| 87 | + parser.add_argument("--frames-per-iter", type=int, default=1000000, |
| 88 | + help="Target number of frames for each archive") |
| 89 | + parser.add_argument("--num-archives", type=int, default=-1, |
| 90 | + help="Number of archives to write"); |
| 91 | + parser.add_argument("--num-jobs", type=int, default=-1, |
| 92 | + help="Number of jobs we're going to use to write the archives; the ranges.* " |
| 93 | + "and outputs.* files are indexed by job. Must be <= the --num-archives option."); |
| 94 | + parser.add_argument("--seed", type=int, default=123, |
| 95 | + help="Seed for random number generator") |
| 96 | + parser.add_argument("--num-pdfs", type=int, default=-1, |
| 97 | + help="Num pdfs") |
| 98 | + |
| 99 | + # now the positional arguments |
| 100 | + parser.add_argument("--utt2len-filename", type=str, required=True, |
| 101 | + help="utt2len file of the features to be used as input (format is: " |
| 102 | + "<utterance-id> <num-frames>)"); |
| 103 | + parser.add_argument("--utt2int-filename", type=str, required=True, |
| 104 | + help="utt2int file of the features to be used as input (format is: " |
| 105 | + "<utterance-id> <id>)"); |
| 106 | + parser.add_argument("--egs-dir", type=str, required=True, |
| 107 | + help="Name of egs directory, e.g. exp/xvector_a/egs"); |
| 108 | + |
| 109 | + print(' '.join(sys.argv), file=sys.stderr) |
| 110 | + print(sys.argv, file=sys.stderr) |
| 111 | + args = parser.parse_args() |
| 112 | + args = process_args(args) |
| 113 | + return args |
| 114 | + |
| 115 | +def process_args(args): |
| 116 | + if args.num_repeats < 1: |
| 117 | + raise Exception("--num-repeats should have a minimum value of 1") |
| 118 | + if not os.path.exists(args.utt2int_filename): |
| 119 | + raise Exception("This script expects --utt2int-filename to exist") |
| 120 | + if not os.path.exists(args.utt2len_filename): |
| 121 | + raise Exception("This script expects --utt2len-filename to exist") |
| 122 | + if args.min_frames_per_chunk <= 1: |
| 123 | + raise Exception("--min-frames-per-chunk is invalid.") |
| 124 | + if args.max_frames_per_chunk < args.min_frames_per_chunk: |
| 125 | + raise Exception("--max-frames-per-chunk is invalid.") |
| 126 | + if args.frames_per_iter < 1000: |
| 127 | + raise Exception("--frames-per-iter is invalid.") |
| 128 | + if args.num_archives < 1: |
| 129 | + raise Exception("--num-archives is invalid") |
| 130 | + if args.num_jobs > args.num_archives: |
| 131 | + raise Exception("--num-jobs is invalid (must not exceed num-archives)") |
| 132 | + return args |
| 133 | + |
| 134 | +# Create utt2len |
| 135 | +def get_utt2len(utt2len_filename): |
| 136 | + utt2len = {} |
| 137 | + f = open(utt2len_filename, "r") |
| 138 | + if f is None: |
| 139 | + sys.exit("Error opening utt2len file " + str(utt2len_filename)) |
| 140 | + utt_ids = [] |
| 141 | + lengths = [] |
| 142 | + for line in f: |
| 143 | + tokens = line.split() |
| 144 | + if len(tokens) != 2: |
| 145 | + sys.exit("bad line in utt2len file " + line) |
| 146 | + utt2len[tokens[0]] = int(tokens[1]) |
| 147 | + f.close() |
| 148 | + return utt2len |
| 149 | + # Done utt2len |
| 150 | + |
| 151 | +# Handle utt2int, create spk2utt, spks |
| 152 | +def get_labels(utt2int_filename): |
| 153 | + f = open(utt2int_filename, "r") |
| 154 | + if f is None: |
| 155 | + sys.exit("Error opening utt2int file " + str(utt2int_filename)) |
| 156 | + spk2utt = {} |
| 157 | + utt2spk = {} |
| 158 | + for line in f: |
| 159 | + tokens = line.split() |
| 160 | + if len(tokens) != 2: |
| 161 | + sys.exit("bad line in utt2int file " + line) |
| 162 | + spk = int(tokens[1]) |
| 163 | + utt = tokens[0] |
| 164 | + utt2spk[utt] = spk |
| 165 | + if spk not in spk2utt: |
| 166 | + spk2utt[spk] = [utt] |
| 167 | + else: |
| 168 | + spk2utt[spk].append(utt) |
| 169 | + spks = spk2utt.keys() |
| 170 | + f.close() |
| 171 | + return spks, spk2utt, utt2spk |
| 172 | + # Done utt2int |
| 173 | + |
| 174 | + |
| 175 | +# this function returns a random integer utterance index, limited to utterances |
| 176 | +# above a minimum length in frames, with probability proportional to its length. |
| 177 | +def get_random_utt(spkr, spk2utt, min_length): |
| 178 | + this_utts = spk2utt[spkr] |
| 179 | + this_num_utts = len(this_utts) |
| 180 | + i = random.randint(0, this_num_utts-1) |
| 181 | + utt = this_utts[i] |
| 182 | + return utt |
| 183 | + |
| 184 | +def random_chunk_length(min_frames_per_chunk, max_frames_per_chunk): |
| 185 | + ans = random.randint(min_frames_per_chunk, max_frames_per_chunk) |
| 186 | + return ans |
| 187 | + |
| 188 | +# This function returns an integer in the range |
| 189 | +# [min-frames-per-chunk, max-frames-per-chunk] according to a geometric |
| 190 | +# sequence. For example, suppose min-frames-per-chunk is 50, |
| 191 | +# max-frames-per-chunk is 200, and args.num_archives is 3. Then the |
| 192 | +# lengths for archives 0, 1, and 2 will be 50, 100, and 200. |
| 193 | +def deterministic_chunk_length(archive_id, num_archives, min_frames_per_chunk, max_frames_per_chunk): |
| 194 | + if max_frames_per_chunk == min_frames_per_chunk: |
| 195 | + return max_frames_per_chunk |
| 196 | + elif num_archives == 1: |
| 197 | + return int(max_frames_per_chunk); |
| 198 | + else: |
| 199 | + return int(math.pow(float(max_frames_per_chunk) / |
| 200 | + min_frames_per_chunk, float(archive_id) / |
| 201 | + (num_archives-1)) * min_frames_per_chunk + 0.5) |
| 202 | + |
| 203 | + |
| 204 | + |
| 205 | +# given an utterance length utt_length (in frames) and two desired chunk lengths |
| 206 | +# (length1 and length2) whose sum is <= utt_length, |
| 207 | +# this function randomly picks the starting points of the chunks for you. |
| 208 | +# the chunks may appear randomly in either order. |
| 209 | +def get_random_offset(utt_length, length): |
| 210 | + if length > utt_length: |
| 211 | + sys.exit("code error: length > utt-length") |
| 212 | + free_length = utt_length - length |
| 213 | + |
| 214 | + offset = random.randint(0, free_length) |
| 215 | + return offset |
| 216 | + |
| 217 | + |
| 218 | +def main(): |
| 219 | + args = get_args() |
| 220 | + if not os.path.exists(args.egs_dir + "/temp"): |
| 221 | + os.makedirs(args.egs_dir + "/temp") |
| 222 | + random.seed(args.seed) |
| 223 | + utt2len = get_utt2len(args.utt2len_filename) |
| 224 | + spks, spk2utt, utt2spk = get_labels(args.utt2int_filename) |
| 225 | + if args.num_pdfs == -1: |
| 226 | + args.num_pdfs = max(spks) + 1 |
| 227 | + |
| 228 | + # archive_chunk_lengths is an mapping from archive id to the number of |
| 229 | + # frames in examples of that archive. |
| 230 | + archive_chunk_lengths = [] |
| 231 | + # all_egs contains 2-tuples of the form (utt-id, offset) |
| 232 | + all_egs= [] |
| 233 | + |
| 234 | + prefix = "" |
| 235 | + if args.prefix != "": |
| 236 | + prefix = args.prefix + "_" |
| 237 | + |
| 238 | + info_f = open(args.egs_dir + "/temp/" + prefix + "archive_chunk_lengths", "w") |
| 239 | + if info_f is None: |
| 240 | + sys.exit(str("Error opening file {0}/temp/" + prefix + "archive_chunk_lengths").format(args.egs_dir)); |
| 241 | + for archive_index in range(args.num_archives): |
| 242 | + print("Processing archive {0}".format(archive_index + 1)) |
| 243 | + if args.randomize_chunk_length == "true": |
| 244 | + # don't constrain the lengths to be the same |
| 245 | + length = random_chunk_length(args.min_frames_per_chunk, args.max_frames_per_chunk) |
| 246 | + else: |
| 247 | + length = deterministic_chunk_length(archive_index, args.num_archives, args.min_frames_per_chunk, args.max_frames_per_chunk); |
| 248 | + print("{0} {1}".format(archive_index + 1, length), file=info_f) |
| 249 | + archive_chunk_lengths.append(length) |
| 250 | + this_num_egs = int((args.frames_per_iter / length) + 1) |
| 251 | + this_egs = [ ] # A 2-tuple of the form (utt-id, start-frame) |
| 252 | + spkrs = args.num_repeats * list(spk2utt.keys()) |
| 253 | + random.shuffle(spkrs) |
| 254 | + for n in range(this_num_egs): |
| 255 | + if len(spkrs) == 0: |
| 256 | + print("Ran out of speakers for archive {0}".format(archive_index + 1)) |
| 257 | + break |
| 258 | + spkr = spkrs.pop() |
| 259 | + utt = get_random_utt(spkr, spk2utt, length) |
| 260 | + utt_len = utt2len[utt] |
| 261 | + offset = get_random_offset(utt_len, length) |
| 262 | + this_egs.append( (utt, offset) ) |
| 263 | + all_egs.append(this_egs) |
| 264 | + info_f.close() |
| 265 | + |
| 266 | + # work out how many archives we assign to each job in an equitable way. |
| 267 | + num_archives_per_job = [ 0 ] * args.num_jobs |
| 268 | + for i in range(0, args.num_archives): |
| 269 | + num_archives_per_job[i % args.num_jobs] = num_archives_per_job[i % args.num_jobs] + 1 |
| 270 | + |
| 271 | + pdf2num = {} |
| 272 | + cur_archive = 0 |
| 273 | + for job in range(args.num_jobs): |
| 274 | + this_ranges = [] |
| 275 | + this_archives_for_job = [] |
| 276 | + this_num_archives = num_archives_per_job[job] |
| 277 | + |
| 278 | + for i in range(0, this_num_archives): |
| 279 | + this_archives_for_job.append(cur_archive) |
| 280 | + for (utterance_index, offset) in all_egs[cur_archive]: |
| 281 | + this_ranges.append( (utterance_index, i, offset) ) |
| 282 | + cur_archive = cur_archive + 1 |
| 283 | + |
| 284 | + f = open(args.egs_dir + "/temp/" + prefix + "ranges." + str(job + 1), "w") |
| 285 | + if f is None: |
| 286 | + sys.exit("Error opening file " + args.egs_dir + "/temp/" + prefix + "ranges." + str(job + 1)) |
| 287 | + for (utterance_index, i, offset) in sorted(this_ranges): |
| 288 | + archive_index = this_archives_for_job[i] |
| 289 | + print("{0} {1} {2} {3} {4} {5}".format(utterance_index, |
| 290 | + i, |
| 291 | + archive_index + 1, |
| 292 | + offset, |
| 293 | + archive_chunk_lengths[archive_index], |
| 294 | + utt2spk[utterance_index]), |
| 295 | + file=f) |
| 296 | + if utt2spk[utterance_index] in pdf2num: |
| 297 | + pdf2num[utt2spk[utterance_index]] += 1 |
| 298 | + else: |
| 299 | + pdf2num[utt2spk[utterance_index]] = 1 |
| 300 | + f.close() |
| 301 | + |
| 302 | + |
| 303 | + f = open(args.egs_dir + "/temp/" + prefix + "outputs." + str(job + 1), "w") |
| 304 | + if f is None: |
| 305 | + sys.exit("Error opening file " + args.egs_dir + "/temp/" + prefix + "outputs." + str(job + 1)) |
| 306 | + print( " ".join([ str("{0}/" + prefix + "egs_temp.{1}.ark").format(args.egs_dir, n + 1) for n in this_archives_for_job ]), |
| 307 | + file=f) |
| 308 | + f.close() |
| 309 | + |
| 310 | + f = open(args.egs_dir + "/" + prefix + "pdf2num", "w") |
| 311 | + nums = [] |
| 312 | + for k in range(0, args.num_pdfs): |
| 313 | + if k in pdf2num: |
| 314 | + nums.append(pdf2num[k]) |
| 315 | + else: |
| 316 | + nums.append(0) |
| 317 | + |
| 318 | + print(" ".join(map(str, nums)), file=f) |
| 319 | + f.close() |
| 320 | + |
| 321 | + print("allocate_egs.py: finished generating " + prefix + "ranges.* and " + prefix + "outputs.* files") |
| 322 | + |
| 323 | +if __name__ == "__main__": |
| 324 | + main() |
| 325 | + |
0 commit comments