In [1]:
import torch
import torch.nn as nn
from torch.autograd import Variable
import collections

In [2]:
class strLabelConverter(object):
    """Convert between str and label.

    NOTE:
        Insert `blank` to the alphabet for CTC.

    Labels are 1-based positions in ``alphabet``; label 0 is the blank.
    ``self.alphabet`` carries a trailing '-' so that ``alphabet[0 - 1]``
    (the blank) renders as '-' in raw decoding.

    Args:
        alphabet (str): set of the possible characters.
        ignore_case (bool, default=True): whether or not to ignore all of the case.
    """

    def __init__(self, alphabet, ignore_case=True):
        self._ignore_case = ignore_case
        if self._ignore_case:
            alphabet = alphabet.lower()
        self.alphabet = alphabet + '-'  # for last index

        # NOTE: 0 is reserved for 'blank' required by wrap_ctc
        self.dict = {char: i + 1 for i, char in enumerate(alphabet)}

    def encode(self, text):
        """Support batch or single str.

        Args:
            text (str or iterable of str): texts to convert.

        Returns:
            torch.IntTensor [length_0 + length_1 + ... length_{n - 1}]: encoded texts.
            torch.IntTensor [n]: length of each text.

        Raises:
            KeyError: when a character is not part of the alphabet.
            TypeError: when `text` is neither a str nor an iterable of str.
        """
        # `collections.Iterable` was removed in Python 3.10; the ABC lives in
        # `collections.abc`.  Imported locally so the block is self-contained.
        from collections.abc import Iterable

        if isinstance(text, str):
            texts = [text]
        elif isinstance(text, Iterable):
            # Materialize once so one-shot iterables (generators) can be
            # walked for both the lengths and the characters.
            texts = list(text)
        else:
            raise TypeError(
                "text must be a str or an iterable of str, got {}".format(
                    type(text).__name__))

        fold_case = self._ignore_case
        labels = [
            self.dict[char.lower() if fold_case else char]
            for item in texts
            for char in item
        ]
        lengths = [len(item) for item in texts]
        return (torch.IntTensor(labels), torch.IntTensor(lengths))

    def _decode_labels(self, labels, raw):
        """Turn one sequence of int labels into a str.

        With `raw` every label is mapped to its character (blank -> '-').
        Otherwise blanks are dropped and a label equal to the one directly
        before it is skipped.
        """
        if raw:
            return ''.join(self.alphabet[label - 1] for label in labels)

        chars = []
        previous = None  # no predecessor for the first label
        for label in labels:
            if label != 0 and label != previous:
                chars.append(self.alphabet[label - 1])
            previous = label
        return ''.join(chars)

    def decode(self, t, length, raw=False):
        """Decode encoded texts back into strs.

        Args:
            t (torch.IntTensor [length_0 + length_1 + ... length_{n - 1}]): encoded texts.
            length (torch.IntTensor [n]): length of each text.
            raw (bool, default=False): when True, map every label to its
                character without dropping blanks or repeated labels.

        Raises:
            AssertionError: when the texts and its length does not match.

        Returns:
            text (str or list of str): a single str when `length` holds exactly
                one element, otherwise a list with one str per text.
        """
        # Convert to plain Python ints once instead of indexing strings and
        # slices with 0-d tensors inside the loops.
        labels = t.reshape(-1).tolist()

        if length.numel() == 1:
            expected = int(length.item())
            assert len(labels) == expected, "text with length: {} does not match declared length: {}".format(len(labels), expected)
            return self._decode_labels(labels, raw)

        # batch mode
        lengths = [int(n) for n in length.reshape(-1).tolist()]
        total = sum(lengths)
        assert len(labels) == total, "texts with length: {} does not match declared length: {}".format(len(labels), total)
        texts = []
        start = 0
        for n in lengths:
            texts.append(self._decode_labels(labels[start:start + n], raw))
            start += n
        return texts

In [3]:
# Toy fixture: a seven-letter alphabet and three 3-character words.
alphabet = 'abcdefg'
samples = ['abb', 'cee', 'fga']

# Alphabet passed positionally; ignore_case keeps its default.
converter = strLabelConverter(alphabet)

In [4]:
# Encode the whole batch at once; the displayed value is a (labels, lengths) pair of IntTensors.
converter.encode(samples)

  elif isinstance(text, collections.Iterable):


(tensor([1, 2, 2, 3, 5, 5, 6, 7, 1], dtype=torch.int32),
 tensor([3, 3, 3], dtype=torch.int32))

In [32]:
class labelConverter(object):
    """Map characters to integer labels and back.

    Each character of `alphabet` gets the label ``position + 1``; label 0 is
    mapped to the empty string, so characters that are not in the alphabet
    encode to 0 and vanish when decoded.

    Args:
        alphabet (str): the characters that receive a label.
    """

    def __init__(self, alphabet):
        self.dict = {char: i + 1 for i, char in enumerate(alphabet)}
        self.dict[''] = 0

    def encode(self, text):
        """Convert list of strings to label seq.

        Args:
            text (iterable of str): words to encode.

        Returns:
            torch.IntTensor: labels of all words, concatenated in order.
            torch.IntTensor: length of each word.
        """
        length = []
        seq = []

        for item in text:
            length.append(len(item))
            # Characters outside the alphabet fall back to label 0.
            seq.extend(self.dict.get(char, 0) for char in item)

        return (torch.IntTensor(seq), torch.IntTensor(length))

    def decode(self, seq, length):
        """Reverse the above conversion.

        Args:
            seq (torch.Tensor): concatenated labels, as returned by `encode`.
            length (torch.Tensor): length of each word.

        Returns:
            list of str: one decoded word per entry of `length`.

        Raises:
            AssertionError: when the lengths do not add up to the size of `seq`.
            ValueError: when `seq` holds a label that is not in the dictionary.
        """
        # Work on plain Python numbers rather than comparing 0-d tensors.
        labels = seq.reshape(-1).tolist()
        lengths = [int(n) for n in length.reshape(-1).tolist()]

        assert sum(lengths) == len(labels), (
            "lengths sum to {} but seq has {} labels".format(
                sum(lengths), len(labels)))

        # Build the reverse table once per call instead of rebuilding the key
        # and value lists for every character.  setdefault keeps the first
        # character for a label, matching a first-match linear search.
        label_to_char = {}
        for char, label in self.dict.items():
            label_to_char.setdefault(label, char)

        text = []
        start = 0
        for n in lengths:
            chars = []
            for label in labels[start:start + n]:
                if label not in label_to_char:
                    raise ValueError(
                        "label {} is not in the dictionary".format(label))
                chars.append(label_to_char[label])
            text.append(''.join(chars))
            start += n

        return text
    


In [33]:
# Round-trip demo for labelConverter on the same toy fixture.
alphabet = 'abcdefg'
samples = ['abb', 'cee', 'fga']
converter = labelConverter(alphabet)

# Show the character -> label table.
print(converter.dict)

# Keep only the labels; the per-word lengths are spelled out by hand below.
x, _ = converter.encode(samples)
converter.decode(x, torch.IntTensor([3, 3, 3]))

{'a': 1, 'b': 2, 'c': 3, 'd': 4, 'e': 5, 'f': 6, 'g': 7, '': 0}


['abb', 'cee', 'fga']