
minibatches for LFMMI #22

Open

yotam319 opened this issue Feb 18, 2020 · 1 comment

@yotam319
Hi,
I have a few suggestions on using LF-MMI:
I noticed you are looping over the batch, creating the supervision and calculating the criterion for each utterance.
You can use the MergeSupervision function I added to pykaldi to create the supervision for the whole batch and run the criterion only once.

Here is my collate function for the DataLoader:

import collections.abc as container_abcs

import torch
import kaldi


def supervision_collate(batch):
    """
    A collate function for using chain Supervision objects with a DataLoader.
    """
    elem = batch[0]
    if isinstance(elem, container_abcs.Sequence):
        # Recurse over tuples/lists of samples, e.g. (features, supervision).
        transposed = zip(*batch)
        return [supervision_collate(samples) for samples in transposed]
    elif isinstance(elem, kaldi.chain._chain_supervision.Supervision):
        if len(batch) == 1:
            return batch[0]
        # Merge the per-utterance supervisions into a single batch supervision.
        return kaldi.chain.merge_supervision(batch)
    elif elem is None:
        return batch
    return torch.utils.data.dataloader.default_collate(batch)
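
For reference, here is a minimal usage sketch of plugging such a collate function into a DataLoader. MyChainDataset is a hypothetical placeholder, assumed to return a (features, supervision) tuple per utterance with equal-length feature tensors; it is not part of this repository.

from torch.utils.data import DataLoader

# "MyChainDataset" is a hypothetical dataset whose __getitem__ returns a
# (features, supervision) tuple per utterance.
loader = DataLoader(MyChainDataset(), batch_size=8,
                    collate_fn=supervision_collate)

for feats, supervision in loader:
    # "feats" is collated by default_collate (assuming equal-length feature
    # tensors), while "supervision" is a single merged chain Supervision for
    # the whole minibatch, so the LF-MMI criterion only needs to be
    # evaluated once per batch.
    ...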

To add MergeSupervision I made a pull request to pykaldi (pykaldi/pykaldi#182), but you can use my fork, which already has the change (https://github.com/yotam319/pykaldi).

Also, using phone_ali gives a small supervision; you should consider using lattices and phone_lattice_to_proto_supervision instead of alignment_to_proto_supervision (see lat_to_supervision_bytes below).

Finally, you can save your supervisions as bytes and read them back again.
Here are the functions I used for this:

import kaldi
from kaldi import chain
from kaldi.fstext import StdVectorFst

def supervision_to_bytes(supervision):
    # Serialize a chain Supervision to bytes via an in-memory stream.
    out_s = kaldi.base.io.stringstream()
    supervision.write(out_s, True)
    return out_s.to_bytes()

def supervision_from_supervision_bytes(supervision_bytes):
    # Deserialize a chain Supervision from its byte representation.
    in_s = kaldi.base.io.stringstream.from_str(supervision_bytes)
    supervision = kaldi.chain.Supervision()
    supervision.read(in_s, True)
    return supervision

def split_supervision(supervision, start, duration):
    # Cut a frame range out of a supervision and remove epsilons from its FST.
    sup_cut = kaldi.chain.SupervisionSplitter(supervision).get_frame_range(start, duration)
    sup_cut.fst = StdVectorFst(sup_cut.fst).rmepsilon()
    return sup_cut

def ali_phone_to_supervision_bytes(phones_durs,
                                   opt, ctx_dep, trans_model):
    """
    input:
    phones_durs: list of (phone, duration) tuples
    opt: kaldi.chain.SupervisionOptions object
    ctx_dep: from kaldi.alignment.Aligner.read_tree("exp/chain/<ref_model>/tree")
    trans_model: from kaldi.alignment.Aligner.read_model("exp/chain/<ref_model>/0.trans_mdl")

    returns: byte representation of the supervision
    """
    # Build a proto-supervision from the phone alignment, compile it into a
    # full chain Supervision, and serialize it.
    p_supervision = chain.alignment_to_proto_supervision_with_phones_durs(opt, phones_durs)
    supervision = chain.proto_supervision_to_supervision(ctx_dep, trans_model, p_supervision, opt.convert_to_pdfs)
    return supervision_to_bytes(supervision)

def lat_to_supervision_bytes(lat, phone_lat_mdl, phone_lat_opts,
                             supervision_opts, ctx_dep, trans_model):
    """
    input:
    lat: lattice
    phone_lat_mdl: final.mdl from the lattice folder
    phone_lat_opts: PhoneAlignLatticeOptions object
    supervision_opts: kaldi.chain.SupervisionOptions object
    ctx_dep: from kaldi.alignment.Aligner.read_tree("exp/chain/<ref_model>/tree")
    trans_model: from kaldi.alignment.Aligner.read_model("exp/chain/<ref_model>/0.trans_mdl")

    returns: byte representation of the supervision
    """
    # Phone-align the lattice, convert it to a proto-supervision, then compile
    # it into a full chain Supervision and serialize it.
    (suc, phone_lat) = kaldi.lat.align.phone_align_lattice(lat, phone_lat_mdl, phone_lat_opts)
    assert suc
    phone_lat.topsort()
    p_supervision = chain.phone_lattice_to_proto_supervision(supervision_opts, phone_lat)
    supervision = chain.proto_supervision_to_supervision(ctx_dep, trans_model, p_supervision, supervision_opts.convert_to_pdfs)
    return supervision_to_bytes(supervision)
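
For completeness, a minimal round-trip sketch using the helpers above; the inputs phones_durs, opt, ctx_dep and trans_model are placeholders that would come from your data preparation as described in the docstrings, and the 0/150 frame range is just an arbitrary example.

# Build and serialize a supervision, read it back later, and cut out a chunk.
sup_bytes = ali_phone_to_supervision_bytes(phones_durs, opt, ctx_dep, trans_model)
supervision = supervision_from_supervision_bytes(sup_bytes)
chunk = split_supervision(supervision, 0, 150)  # frames [0, 150)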

Hope this helps :)

@jzlianglu (Owner)

Hi @yotam319, thanks a lot for your advice and code samples. Yes, what you said definitely makes sense. Previously I only implemented a vanilla version of LF-MMI in the toolbox and planned to revisit it later to improve efficiency. I noticed the code change in the pykaldi lib, but have not been able to find the time to work on it. Our internal tools are not built on Kaldi, so I have very limited time to work on this toolkit. I will try to integrate your dataloader into the codebase soon. Thanks again!
