In [1]:
%run ../chap05/dataset.ipynb

In [2]:
# Length limit for generated sentences: automata_generate_data discards any
# sentence with MAX_LENGTH letters or more, and sizes its buffers from this.
MAX_LENGTH = 100

# Symbols of the language. Sentences are handled as lists of indices into
# this list, not as strings.
ALPHABET = list('TPXVS')

# Transition table of the automaton. EDGES[state] lists the outgoing edges of
# `state`; each edge is [next_state, letter_index], where letter_index points
# into ALPHABET.
EDGES = [
    [               # state 0
        [1, 0],     #   'T' -> state 1
        [3, 3],     #   'V' -> state 3
    ],
    [               # state 1
        [1, 1],     #   'P' -> state 1 (self-loop)
        [2, 0],     #   'T' -> state 2
    ],
    [               # state 2
        [3, 2],     #   'X' -> state 3
        [5, 4],     #   'S' -> state 5
    ],
    [               # state 3
        [3, 2],     #   'X' -> state 3 (self-loop)
        [4, 3],     #   'V' -> state 4
    ],
    [               # state 4
        [2, 1],     #   'P' -> state 2
        [5, 4],     #   'S' -> state 5
    ],
    [],             # state 5: no outgoing edges
]

# A walk begins in START_STATE and stops as soon as it reaches a state listed
# in END_STATES; a sentence is accepted only if it finishes in one of them.
START_STATE = 0
END_STATES = [5]

In [3]:
class AutomataDataset(Dataset):
    """Dataset of letter sequences labelled by whether the automaton accepts them.

    Only the shapes and the nominal training-set size are declared here; the
    data-access and visualization methods are attached to the class in later
    cells.
    """

    def __init__(self):
        """Initialise the base Dataset and declare per-sample shapes."""
        # The two strings are forwarded to Dataset.__init__ unchanged
        # (presumably dataset name and task mode -- confirm against Dataset).
        super(AutomataDataset, self).__init__('automata', 'binary')

        # One row more than MAX_LENGTH: automata_generate_data uses row 0 of
        # each sample to store the sentence length, letters start at row 1.
        row_cnt = MAX_LENGTH + 1
        letter_cnt = len(ALPHABET)
        self.input_shape = [row_cnt, letter_cnt]

        # A single value per sample: the accept/reject label.
        self.output_shape = [1]

    @property
    def train_count(self):
        """Fixed nominal number of training samples (10000).

        Training batches are generated on demand, so this is not the size of
        a stored array.
        """
        return 10000

In [4]:
def automata_get_train_data(self, batch_size, nth):
    """Return a freshly generated training batch of `batch_size` samples.

    `self` and the batch index `nth` are ignored: nothing is stored, so every
    call produces new random data via automata_generate_data.

    Returns:
        (xs, ys) exactly as produced by automata_generate_data(batch_size).
    """
    batch_xs, batch_ys = automata_generate_data(batch_size)
    return batch_xs, batch_ys

def automata_get_validate_data(self, count):
    """Return `count` freshly generated samples for validation.

    `self` is ignored; the data comes straight from automata_generate_data.

    Returns:
        (xs, ys) exactly as produced by automata_generate_data(count).
    """
    val_xs, val_ys = automata_generate_data(count)
    return val_xs, val_ys
    
def automata_get_test_data(self):
    """Return a freshly generated test set of 1000 samples.

    `self` is ignored; the data comes straight from automata_generate_data.

    Returns:
        (xs, ys) exactly as produced by automata_generate_data(1000).
    """
    test_sample_cnt = 1000
    test_xs, test_ys = automata_generate_data(test_sample_cnt)
    return test_xs, test_ys
    
def automata_generate_data(count):
    """Generate `count` random sentences together with accept/reject labels.

    Each sample is a [MAX_LENGTH+1, len(ALPHABET)] array, the same shape that
    AutomataDataset declares as its input_shape:
      * row 0, column 0 holds the sentence length,
      * rows 1..len(sent) hold the one-hot encoded letters,
      * all remaining entries are zero padding.

    Args:
        count: number of samples to generate.

    Returns:
        (xs, ys): xs has shape [count, MAX_LENGTH+1, len(ALPHABET)];
        ys has shape [count, 1] and holds 1.0 where
        automata_is_correct_sent accepts the sentence, 0.0 otherwise.
    """
    alpha_cnt = len(ALPHABET)
    # Row k of the identity matrix is the one-hot code of letter k; built once
    # instead of once per sample.
    one_hot = np.eye(alpha_cnt)

    # +1 row for the length header stored in row 0 of every sample.
    xs = np.zeros([count, MAX_LENGTH + 1, alpha_cnt])
    ys = np.zeros([count, 1])

    for n in range(count):
        # Rejection sampling: draw sentences until one is shorter than
        # MAX_LENGTH (same acceptance rule as before the buffer was resized).
        while True:
            sent = automata_generate_sent()
            if len(sent) < MAX_LENGTH:
                break

        xs[n, 0, 0] = len(sent)
        # Indexing the identity matrix with a list of letter indices yields
        # one one-hot row per letter.
        xs[n, 1:len(sent) + 1, :] = one_hot[sent]
        ys[n, 0] = automata_is_correct_sent(sent)

    return xs, ys

# Attach the module-level data functions to AutomataDataset as methods.
# Visualization samples come from the same generator as validation samples.
AutomataDataset.get_visualize_data = automata_get_validate_data
AutomataDataset.get_validate_data = automata_get_validate_data
AutomataDataset.get_test_data = automata_get_test_data
AutomataDataset.get_train_data = automata_get_train_data

In [5]:
def automata_generate_sent():
    """Random-walk the automaton from START_STATE to an end state.

    A coin flip decides the emission mode for the whole sentence:
      * 1 -- every step emits the letter of the edge it follows, so the
        sentence spells out a valid path through EDGES;
      * 0 -- every step emits a letter drawn uniformly from the alphabet,
        independent of the edge followed, so the sentence is usually (but
        not necessarily) invalid.
    In both modes the walk itself follows uniformly chosen outgoing edges and
    stops once the current state is in END_STATES.

    Returns:
        list of letter indices into ALPHABET.
    """
    follow_edges = np.random.randint(2)
    letters = []
    current = START_STATE

    while current not in END_STATES:
        outgoing = EDGES[current]
        target, edge_letter = outgoing[np.random.randint(len(outgoing))]

        if follow_edges:
            emitted = edge_letter
        else:
            emitted = np.random.randint(len(ALPHABET))

        letters.append(emitted)
        current = target

    return letters

def automata_is_correct_sent(sent):
    """Tell whether the automaton accepts `sent`.

    Starting from START_STATE, each letter must match the letter of one of
    the current state's outgoing edges (the first match is taken); the
    sentence is accepted only if every letter matches and the walk finishes
    in a state listed in END_STATES.

    Args:
        sent: iterable of letter indices into ALPHABET.

    Returns:
        True if the sentence is accepted, False otherwise.
    """
    current = START_STATE
    for symbol in sent:
        for target, edge_letter in EDGES[current]:
            if symbol == edge_letter:
                current = target
                break
        else:
            # No outgoing edge of the current state carries this letter.
            return False
    return current in END_STATES

In [6]:
def automata_visualize(self, xs, est, ans):
    """Print one line per sample comparing the estimate with the true label.

    Each sample is decoded back to text: row 0, column 0 gives the sentence
    length and the following rows are one-hot letters. The printed line shows
    the text, the true label, the estimated label with its score, and 'O'
    when both fall on the same side of 0.5 ('X' otherwise, including when
    either value is exactly 0.5).

    Args:
        self: unused.
        xs: array of samples laid out as described above.
        est: estimated scores, read as est[n][0].
        ans: true labels, read as ans[n][0].
    """
    for n, sample in enumerate(xs):
        sent_len = int(sample[0, 0])
        letter_ids = np.argmax(sample[1:sent_len+1], axis=1)
        text = "".join(ALPHABET[k] for k in letter_ids)

        truth = ans[n][0]
        score = est[n][0]
        truth_pos = truth > 0.5
        score_pos = score > 0.5

        answer = '올바른 패턴' if truth_pos else '잘못된 패턴'
        guess = '합격추정' if score_pos else '탈락추정'

        if truth_pos and score_pos:
            result = 'O'
        elif truth < 0.5 and score < 0.5:
            result = 'O'
        else:
            result = 'X'

        line = '{}: {} => {}({:4.2f}) : {}'.format(
            text, answer, guess, score, result)
        print(line)
        
# Attach the module-level printer to AutomataDataset as its `visualize` method.
AutomataDataset.visualize = automata_visualize