Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions egs/aishell/s10/chain/chain_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from torch.utils.dlpack import to_dlpack

import kaldi
import kaldi_pybind.chain as chain
from kaldi import chain

g_nnet_output_deriv_tensor = None
g_xent_output_deriv_tensor = None
Expand Down Expand Up @@ -56,15 +56,15 @@ def forward(ctx, opts, den_graph, supervision, nnet_output_tensor,
# it contains [objf, l2_term, weight] and will be returned to the caller
objf_l2_term_weight_tensor = torch.zeros(3).float()

nnet_output = kaldi.CuSubMatrixFromDLPack(to_dlpack(nnet_output_tensor))
nnet_output = kaldi.PytorchToCuSubMatrix(to_dlpack(nnet_output_tensor))

nnet_output_deriv = kaldi.CuSubMatrixFromDLPack(
nnet_output_deriv = kaldi.PytorchToCuSubMatrix(
to_dlpack(g_nnet_output_deriv_tensor))

xent_output_deriv = kaldi.CuSubMatrixFromDLPack(
xent_output_deriv = kaldi.PytorchToCuSubMatrix(
to_dlpack(g_xent_output_deriv_tensor))

objf_l2_term_weight = kaldi.SubVectorFromDLPack(
objf_l2_term_weight = kaldi.PytorchToSubVector(
to_dlpack(objf_l2_term_weight_tensor))

chain.ComputeChainObjfAndDeriv(opts=opts,
Expand Down
1 change: 0 additions & 1 deletion src/pybind/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,6 @@ test: all
make -C chain test
make -C cudamatrix test
make -C dlpack test
$(eval include ../../tools/env.sh)
make -C feat test
make -C fst test
make -C matrix test
Expand Down
2 changes: 1 addition & 1 deletion src/pybind/chain/chain_supervision_pybind_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

import numpy as np

import kaldi_pybind.chain as chain
from kaldi import chain


class TestChainSupervision(unittest.TestCase):
Expand Down
9 changes: 4 additions & 5 deletions src/pybind/feat/feat_pybind_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,8 @@
import unittest
import numpy as np

import kaldi_pybind as k

import kaldi_pybind.feat as feat
import kaldi
from kaldi import feat
from kaldi import SequentialWaveReader
from kaldi import SequentialMatrixReader

Expand All @@ -35,8 +34,8 @@ def test_mfcc(self):
value.Duration() * value.SampFreq(),
places=1)

waveform = k.FloatSubVector(nd.reshape(nsamp))
features = k.FloatMatrix(1, 1)
waveform = kaldi.FloatSubVector(nd.reshape(nsamp))
features = kaldi.FloatMatrix(1, 1)
mfcc.ComputeFeatures(waveform, value.SampFreq(), 1.0, features)
self.assertEqual(key, gold_reader.Key())
gold_feat = gold_reader.Value().numpy()
Expand Down
2 changes: 1 addition & 1 deletion src/pybind/fst/arc_pybind_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

import unittest

import kaldi_pybind.fst as fst
from kaldi import fst


class TestArc(unittest.TestCase):
Expand Down
2 changes: 1 addition & 1 deletion src/pybind/fst/fst_pybind_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

import unittest

import kaldi_pybind.fst as fst
from kaldi import fst


class TestArc(unittest.TestCase):
Expand Down
2 changes: 1 addition & 1 deletion src/pybind/fst/symbol_table_pybind_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@

import unittest

import kaldi_pybind.fst as fst
import kaldi
from kaldi import fst


class TestSymbolTable(unittest.TestCase):
Expand Down
2 changes: 1 addition & 1 deletion src/pybind/fst/vector_fst_pybind_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@

import unittest

import kaldi_pybind.fst as fst
import kaldi
from kaldi import fst


class TestStdVectorFst(unittest.TestCase):
Expand Down
2 changes: 1 addition & 1 deletion src/pybind/fst/weight_pybind_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

import unittest

import kaldi_pybind.fst as fst
from kaldi import fst


class TestWeight(unittest.TestCase):
Expand Down
21 changes: 0 additions & 21 deletions src/pybind/kaldi.py

This file was deleted.

32 changes: 32 additions & 0 deletions src/pybind/kaldi/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import os
import sys
sys.path.insert(0, os.path.join(os.path.dirname(__file__), os.pardir))
sys.path.insert(0, os.path.dirname(__file__))

from kaldi_pybind import *

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suppose there are going to be a lot of decisions about which things to put in the global kaldi namespace. But it's OK.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are lots of classes in kaldi_pybind and it is a pain to list each one explicitly.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, that's OK; kaldi is just a thin wrapper for kaldi_pybind. I was actually talking about the lines below, where you import things like Matrix. But as I said, it's OK.

from symbol_table import *
from pytorch_util import PytorchToCuSubMatrix
from pytorch_util import PytorchToCuSubVector
from pytorch_util import PytorchToSubMatrix
from pytorch_util import PytorchToSubVector

from table import SequentialNnetChainExampleReader
from table import RandomAccessNnetChainExampleReader
from table import NnetChainExampleWriter

from table import SequentialWaveReader
from table import RandomAccessWaveReader

from table import SequentialWaveInfoReader
from table import RandomAccessWaveInfoReader

from table import SequentialMatrixReader
from table import RandomAccessMatrixReader
from table import MatrixWriter

from table import SequentialVectorReader
from table import RandomAccessVectorReader
from table import VectorWriter

from table import CompressedMatrixWriter
30 changes: 30 additions & 0 deletions src/pybind/kaldi/pytorch_util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
#!/usr/bin/env python3

# Copyright 2019 Mobvoi AI Lab, Beijing, China (author: Fangjun Kuang)
# Apache 2.0

import os
import sys
sys.path.insert(0, os.path.join(os.path.dirname(__file__), os.pardir))

import kaldi_pybind


def PytorchToCuSubMatrix(dlpack_tensor):
cu_sub_matrix = kaldi_pybind.CuSubMatrixFromDLPack(dlpack_tensor)
return cu_sub_matrix


def PytorchToSubMatrix(dlpack_tensor):
sub_matrix = kaldi_pybind.SubMatrixFromDLPack(dlpack_tensor)
return sub_matrix


def PytorchToCuSubVector(dlpack_tensor):
cu_sub_vector = kaldi_pybind.CuSubVectorFromDLPack(dlpack_tensor)
return cu_sub_vector


def PytorchToSubVector(dlpack_tensor):
sub_vector = kaldi_pybind.SubVectorFromDLPack(dlpack_tensor)
return sub_vector
File renamed without changes.
Loading