<a href="https://colab.research.google.com/github/BYU-Handwriting-Lab/GettingStarted/blob/master/notebooks/transcription_correction.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Language Model Text Error Correction

This notebook contains code that corrects the output of a handwriting
recognition model using techniques from neural machine translation. We
implement a basic encoder/decoder architecture with a transformer to
correct the output.

In [0]:
try:
  %tensorflow_version 2.x
except Exception:
  pass

In [0]:
# TensorFlow
import tensorflow as tf
import tensorflow_addons as tfa

# Python
import os
import time

# Data Structures
import pandas as pd
import numpy as np

# Image/Plotting
from matplotlib import pyplot as plt

# Debugging
from tqdm import tqdm
from IPython.core.ultratb import AutoFormattedTB
__ITB__ = AutoFormattedTB(mode='Verbose', color_scheme='LightBg', tb_offset=1)

Download the Dataset from Google Drive

In [0]:
# ID: 1w0sumZm2YPxgMAsz9utAm9t2HriIe-JL
!wget -q --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=1w0sumZm2YPxgMAsz9utAm9t2HriIe-JL' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=1w0sumZm2YPxgMAsz9utAm9t2HriIe-JL" -O error.csv && rm -rf /tmp/cookies.txt

### Character to Index Mapping

In [0]:
# CHAR_SET = {"idx_to_char": {"1": " ", "2": "!", "3": "\"", "4": "#", "5": "$", "6": "%", "7": "&", "8": "'", "9": "(","10": ")", "11": "*", "12": "+", "13": ",", "14": "-", "15": ".", "16": "/", "17": "0", "18": "1", "19": "2", "20": "3", "21": "4", "22": "5", "23": "6", "24": "7", "25": "8", "26": "9", "27": ":", "28": ";", "29": "=", "30": "?", "31": "A", "32": "B", "33": "C", "34": "D", "35": "E", "36": "F", "37": "G", "38": "H", "39": "I", "40": "J", "41": "K", "42": "L", "43": "M", "44": "N", "45": "O", "46": "P", "47": "Q", "48": "R", "49": "S", "50": "T", "51": "U", "52": "V", "53": "W", "54": "X", "55": "Y", "56": "Z", "57": "[", "58": "]", "59": "_", "60": "`", "61": "a", "62": "b", "63": "c", "64": "d", "65": "e", "66": "f", "67": "g", "68": "h", "69": "i", "70": "j", "71": "k", "72": "l", "73": "m", "74": "n", "75": "o", "76": "p", "77": "q", "78": "r", "79": "s", "80": "t", "81": "u", "82": "v", "83": "w", "84": "x", "85": "y", "86": "z", "87": "|", "88": "~", "89": "\u00a3", "90": "\u00a7", "91": "\u00a8", "92": "\u00ab", "93": "\u00ac", "94": "\u00ad", "95": "\u00b0", "96": "\u00b2", "97": "\u00b4", "98": "\u00b7", "99": "\u00ba", "100": "\u00bb", "101": "\u00bc", "102": "\u00bd", "103": "\u00be", "104": "\u00c0", "105": "\u00c2", "106": "\u00c4", "107": "\u00c7", "108": "\u00c8", "109": "\u00c9", "110": "\u00ca", "111": "\u00d4", "112": "\u00d6", "113": "\u00dc", "114": "\u00df", "115": "\u00e0", "116": "\u00e1", "117": "\u00e2", "118": "\u00e4", "119": "\u00e6", "120": "\u00e7", "121": "\u00e8", "122": "\u00e9", "123": "\u00ea", "124": "\u00eb", "125": "\u00ec", "126": "\u00ee", "127": "\u00ef", "128": "\u00f1", "129": "\u00f2", "130": "\u00f3", "131": "\u00f4", "132": "\u00f6", "133": "\u00f8", "134": "\u00f9", "135": "\u00fa", "136": "\u00fb", "137": "\u00fc", "138": "\u00ff", "139": "\u0142", "140": "\u0152", "141": "\u0153", "142": "\u0393", "143": "\u0396", "144": "\u03a4", "145": "\u03ac", "146": "\u03ae", "147": "\u03b1", "148": "\u03b4", "149": "\u03b5", "150": "\u03b7", "151": "\u03b9", "152": "\u03ba", "153": "\u03bb", "154": "\u03bc", "155": "\u03bd", "156": "\u03be", "157": "\u03bf", "158": "\u03c0", "159": "\u03c1", "160": "\u03c4", "161": "\u03c5", "162": "\u03c7", "163": "\u03c8", "164": "\u03c9", "165": "\u03cc", "166": "\u03ce", "167": "\u0406", "168": "\u2012", "169": "\u2013", "170": "\u2014", "171": "\u2020", "172": "\u2021", "173": "\u2030", "174": "\u2039", "175": "\u203a", "176": "\u2082", "177": "\u20a4", "178": "\u2114", "179": "\u2153", "180": "\u2154", "181": "\u2155", "182": "\u2156", "183": "\u2157", "184": "\u2158", "185": "\u2159", "186": "\u215a", "187": "\u215b", "188": "\u2206", "189": "\u2207", "190": "\u222b", "191": "\u2260", "192": "\u25a1", "193": "\u2640", "194": "\u2642", "195": "\u2713", "196": "\uff46"},
            # "char_to_idx": {"\u203a": 175, "\u2014": 170, "\u25a1": 192, " ": 1, "\u00a3": 89, "$": 5, "\u00a7": 90, "(": 9, "\u00ab": 92, "\u2206": 188, ",": 13, "\u03b1": 147, "0": 17, "\u03b5": 149, "4": 21, "\u00b7": 98, "\u03b9": 151, "8": 25, "\u00bb": 100, "\u03bd": 155, "\u03c1": 159, "\u2640": 193, "\u0142": 139, "\u03c5": 161, "D": 34, "\u00c7": 107, "\u2260": 191, "\u03c9": 164, "H": 38, "L": 42, "P": 46, "\u0152": 140, "T": 50, "\u2156": 182, "X": 54, "\u215a": 186, "\u00df": 114, "`": 60, "d": 64, "\u00e7": 120, "h": 68, "\u00eb": 124, "l": 72, "\u00ef": 127, "p": 76, "\u00f3": 130, "t": 80, "x": 84, "\u00fb": 136, "|": 87, "\u00ff": 138, "\u2207": 189, "\u2153": 179, "\u2013": 169, "\u0396": 143, "#": 4, "\u20a4": 177, "'": 8, "\u00a8": 91, "+": 12, "\u00ac": 93, "/": 16, "\u03ae": 146, "\u00b0": 95, "3": 20, "\u00b4": 97, "7": 24, ";": 28, "\u03ba": 152, "\u00bc": 101, "?": 30, "\u03be": 156, "\u00c0": 104, "C": 33, "\u00c4": 106, "G": 37, "\u2020": 171, "\u00c8": 108, "K": 41, "O": 45, "\u03ce": 166, "S": 49, "\u2155": 181, "\u00d4": 111, "W": 53, "\u2159": 185, "[": 57, "\u00dc": 113, "_": 59, "\u00e0": 115, "c": 63, "\u00e4": 118, "g": 67, "\u00e8": 121, "k": 71, "\u00ec": 125, "o": 75, "s": 79, "\u00f4": 131, "w": 83, "\u00f8": 133, "\u2021": 172, "\u00fc": 137, "\u2030": 173, "\u0406": 167, "\u0393": 142, "\u2012": 168, "\u2114": 178, "\"": 3, "&": 7, "*": 11, "\u00ad": 94, ".": 15, "2": 19, "\u03b7": 150, "6": 23, "\u03bb": 153, ":": 27, "\u00bd": 102, "\u03bf": 157, "B": 32, "\u03c7": 162, "F": 36, "\u00c9": 109, "J": 40, "N": 44, "R": 48, "\u2154": 180, "V": 52, "\u2158": 184, "Z": 56, "\u00e1": 116, "b": 62, "\u2039": 174, "f": 66, "\u00e9": 122, "j": 70, "n": 74, "\u00f1": 128, "r": 78, "v": 82, "\u00f9": 134, "z": 86, "~": 88, "\u2082": 176, "\u2713": 195, "\u2642": 194, "!": 2, "%": 6, "\u03a4": 144, ")": 10, "\uff46": 196, "-": 14, "\u03ac": 145, "1": 18, "\u00b2": 96, "5": 22, "\u03b4": 148, "9": 26, "\u00ba": 99, "=": 29, "\u03bc": 154, "\u00be": 103, "A": 31, "\u03c0": 158, "\u00c2": 105, "E": 35, "\u03c4": 160, "I": 39, "\u03c8": 163, "\u00ca": 110, "M": 43, "\u03cc": 165, "Q": 47, "\u0153": 141, "U": 51, "\u2157": 183, "\u00d6": 112, "Y": 55, "\u215b": 187, "]": 58, "a": 61, "\u00e2": 117, "e": 65, "\u00e6": 119, "i": 69, "\u00ea": 123, "m": 73, "\u00ee": 126, "q": 77, "\u00f2": 129, "u": 81, "\u00f6": 132, "y": 85, "\u00fa": 135, "\u222b": 190}
            # }

CHAR_SET = {'char_to_idx': {'Ĵ': '318', '¬': '1', 'Õ': '2', 'Y': '3', 'Į': '4', 'ø': '5', 'Ÿ': '6', ',': '7', '«': '8', 'ĳ': '9', 'e': '10', 'Ô': '11', 'U': '12', '[': '13', 'j': '14', 'Ũ': '15', '3': '16', 'o': '17', 'ï': '18', 'd': '19', 'x': '20', 'ċ': '21', 'Ü': '22', 'ı': '23', 'Ð': '24', 'Ď': '25', 'Ŋ': '26', '2': '27', '®': '28', '9': '29', 'ß': '30', 'ľ': '31', '/': '32', 'V': '33', '½': '34', 'û': '35', 'h': '36', 'ě': '37', 'r': '38', 'm': '39', '¥': '40', 'g': '41', 'ĺ': '42', 'B': '43', 'Ė': '44', 'Ř': '45', 'Ĺ': '46', 'Ò': '47', 'ĥ': '48', 'À': '49', '{': '50', 'Ž': '51', 'ã': '52', ':': '53', 'Ì': '54', 'Ī': '55', 'Ķ': '56', 'ń': '57', 'õ': '58', 'Å': '59', 'G': '60', 'È': '61', 'ſ': '62', 'Ą': '63', '5': '64', 'ë': '65', 'Ō': '66', 'ŋ': '67', 'ţ': '68', 'Ħ': '69', 'Q': '70', 'č': '71', 'Ŀ': '72', '=': '73', 'Ĉ': '74', 'Ş': '75', 'Ū': '76', 'ħ': '77', 'ŗ': '78', 'É': '79', '%': '80', 'ť': '81', 'æ': '82', '±': '83', '?': '84', 'D': '85', '»': '86', 'ż': '87', 'ć': '88', '<': '89', '|': '90', 'C': '91', 'Ġ': '92', '´': '93', 'ŏ': '94', '.': '95', '$': '96', 'ü': '97', '+': '98', 'ġ': '99', 'Ï': '100', 'ŕ': '101', 'Ă': '102', 'i': '103', 'Ý': '104', '"': '105', 'w': '106', 'Ù': '107', 'Ŝ': '108', 'Đ': '109', 'Ä': '110', 'ì': '111', '`': '112', 'ű': '113', '\xad': '114', 'ģ': '115', 'î': '116', '7': '117', 'Ö': '118', 'İ': '119', 'ĵ': '120', 'Z': '121', '¶': '122', 'Ņ': '123', '¨': '124', '4': '125', 'R': '126', ']': '127', '^': '128', 'F': '129', 'ļ': '130', 'ğ': '131', 'k': '132', 'ī': '133', 'é': '134', 'ŉ': '135', 'Ń': '136', 'Ľ': '137', '!': '138', 'ù': '139', 'Ĳ': '140', 'S': '141', 'E': '142', 'â': '143', ')': '144', '·': '145', '¾': '146', 'Þ': '147', 'Ł': '148', 'ř': '149', 'Ļ': '150', 'Ê': '151', 'ä': '152', 'n': '153', 'œ': '154', '(': '155', 'ĕ': '156', '§': '157', 'ê': '158', '°': '159', 'ý': '160', '@': '161', 'Ź': '162', '-': '163', 'Ţ': '164', 'ũ': '165', 'ė': '166', '0': '167', 'Ĩ': '168', 'ş': '169', 'š': '170', 'ō': '171', 'ą': '172', 'H': '173', 'ų': '174', 'O': '175', 'ŭ': '176', ' ': '177', 'ñ': '178', 'ś': '179', 'b': '180', '¦': '181', 'Ú': '182', 'Œ': '183', 'ª': '184', 'ĩ': '185', 'W': '186', 'M': '187', 'ă': '188', 'ö': '189', 'ž': '190', 'ò': '191', 'µ': '192', 'f': '193', 'ň': '194', 'þ': '195', '1': '196', 'Ç': '197', 'Ć': '198', '¹': '199', 'Ŗ': '200', 'á': '201', 'c': '202', '>': '203', '8': '204', 'ł': '205', 'Š': '206', 'ő': '207', 'Ģ': '208', 'ŷ': '209', 'Ĕ': '210', 'Ś': '211', 'ŝ': '212', 'ź': '213', 'Â': '214', 'ĭ': '215', '³': '216', 'Ċ': '217', 'Ã': '218', 'į': '219', 'l': '220', 'Û': '221', 'Ĭ': '222', 'Ŧ': '223', 'Ż': '224', 'K': '225', 'N': '226', '¡': '227', '_': '228', 'å': '229', '£': '230', 'ū': '231', 'Ų': '232', '×': '233', 'Ā': '234', 'u': '235', 'ů': '236', 'Ě': '237', '*': '238', 'v': '239', 'T': '240', 'Ŕ': '241', 'ē': '242', 'A': '243', 'X': '244', '¼': '245', 'q': '246', '¤': '247', 's': '248', 'Ű': '249', 't': '250', 'Ŷ': '251', 'Č': '252', 'ĝ': '253', '\\': '254', 'Ů': '255', '#': '256', "'": '257', 'Á': '258', '¿': '259', '}': '260', 'y': '261', 'Ē': '262', 'Ŭ': '263', 'Ë': '264', '~': '265', 'Ę': '266', 'Ŵ': '267', 'Æ': '268', 'ð': '269', 'º': '270', 'Ó': '271', 'ā': '272', 'ô': '273', 'J': '274', 'ÿ': '275', 'ó': '276', 'Ĝ': '277', '&': '278', 'P': '279', '©': '280', 'Ğ': '281', 'è': '282', 'ę': '283', 'ĸ': '284', '²': '285', 'Ĥ': '286', '¢': '287', 'ŵ': '288', 'Î': '289', 'đ': '290', 'Í': '291', 'a': '292', ';': '293', 'à': '294', '¯': '295', '¸': '296', 'ņ': '297', 'L': '298', 'Ő': '299', 'ķ': '300', 'p': '301', 'Ŏ': '302', 'í': '303', 'ŧ': '304', 'ç': '305', 'Ť': '306', 'ŀ': '307', 'z': '308', 'ď': '309', 'Ň': '310', '6': '311', 'I': '312', '÷': '313', 'ú': '314', 'Ø': '315', 'Ñ': '316', 'ĉ': '317'},
            'idx_to_char': {'318': 'Ĵ', '1': '¬', '2': 'Õ', '3': 'Y', '4': 'Į', '5': 'ø', '6': 'Ÿ', '7': ',', '8': '«', '9': 'ĳ', '10': 'e', '11': 'Ô', '12': 'U', '13': '[', '14': 'j', '15': 'Ũ', '16': '3', '17': 'o', '18': 'ï', '19': 'd', '20': 'x', '21': 'ċ', '22': 'Ü', '23': 'ı', '24': 'Ð', '25': 'Ď', '26': 'Ŋ', '27': '2', '28': '®', '29': '9', '30': 'ß', '31': 'ľ', '32': '/', '33': 'V', '34': '½', '35': 'û', '36': 'h', '37': 'ě', '38': 'r', '39': 'm', '40': '¥', '41': 'g', '42': 'ĺ', '43': 'B', '44': 'Ė', '45': 'Ř', '46': 'Ĺ', '47': 'Ò', '48': 'ĥ', '49': 'À', '50': '{', '51': 'Ž', '52': 'ã', '53': ':', '54': 'Ì', '55': 'Ī', '56': 'Ķ', '57': 'ń', '58': 'õ', '59': 'Å', '60': 'G', '61': 'È', '62': 'ſ', '63': 'Ą', '64': '5', '65': 'ë', '66': 'Ō', '67': 'ŋ', '68': 'ţ', '69': 'Ħ', '70': 'Q', '71': 'č', '72': 'Ŀ', '73': '=', '74': 'Ĉ', '75': 'Ş', '76': 'Ū', '77': 'ħ', '78': 'ŗ', '79': 'É', '80': '%', '81': 'ť', '82': 'æ', '83': '±', '84': '?', '85': 'D', '86': '»', '87': 'ż', '88': 'ć', '89': '<', '90': '|', '91': 'C', '92': 'Ġ', '93': '´', '94': 'ŏ', '95': '.', '96': '$', '97': 'ü', '98': '+', '99': 'ġ', '100': 'Ï', '101': 'ŕ', '102': 'Ă', '103': 'i', '104': 'Ý', '105': '"', '106': 'w', '107': 'Ù', '108': 'Ŝ', '109': 'Đ', '110': 'Ä', '111': 'ì', '112': '`', '113': 'ű', '114': '\xad', '115': 'ģ', '116': 'î', '117': '7', '118': 'Ö', '119': 'İ', '120': 'ĵ', '121': 'Z', '122': '¶', '123': 'Ņ', '124': '¨', '125': '4', '126': 'R', '127': ']', '128': '^', '129': 'F', '130': 'ļ', '131': 'ğ', '132': 'k', '133': 'ī', '134': 'é', '135': 'ŉ', '136': 'Ń', '137': 'Ľ', '138': '!', '139': 'ù', '140': 'Ĳ', '141': 'S', '142': 'E', '143': 'â', '144': ')', '145': '·', '146': '¾', '147': 'Þ', '148': 'Ł', '149': 'ř', '150': 'Ļ', '151': 'Ê', '152': 'ä', '153': 'n', '154': 'œ', '155': '(', '156': 'ĕ', '157': '§', '158': 'ê', '159': '°', '160': 'ý', '161': '@', '162': 'Ź', '163': '-', '164': 'Ţ', '165': 'ũ', '166': 'ė', '167': '0', '168': 'Ĩ', '169': 'ş', '170': 'š', '171': 'ō', '172': 'ą', '173': 'H', '174': 'ų', '175': 'O', '176': 'ŭ', '177': ' ', '178': 'ñ', '179': 'ś', '180': 'b', '181': '¦', '182': 'Ú', '183': 'Œ', '184': 'ª', '185': 'ĩ', '186': 'W', '187': 'M', '188': 'ă', '189': 'ö', '190': 'ž', '191': 'ò', '192': 'µ', '193': 'f', '194': 'ň', '195': 'þ', '196': '1', '197': 'Ç', '198': 'Ć', '199': '¹', '200': 'Ŗ', '201': 'á', '202': 'c', '203': '>', '204': '8', '205': 'ł', '206': 'Š', '207': 'ő', '208': 'Ģ', '209': 'ŷ', '210': 'Ĕ', '211': 'Ś', '212': 'ŝ', '213': 'ź', '214': 'Â', '215': 'ĭ', '216': '³', '217': 'Ċ', '218': 'Ã', '219': 'į', '220': 'l', '221': 'Û', '222': 'Ĭ', '223': 'Ŧ', '224': 'Ż', '225': 'K', '226': 'N', '227': '¡', '228': '_', '229': 'å', '230': '£', '231': 'ū', '232': 'Ų', '233': '×', '234': 'Ā', '235': 'u', '236': 'ů', '237': 'Ě', '238': '*', '239': 'v', '240': 'T', '241': 'Ŕ', '242': 'ē', '243': 'A', '244': 'X', '245': '¼', '246': 'q', '247': '¤', '248': 's', '249': 'Ű', '250': 't', '251': 'Ŷ', '252': 'Č', '253': 'ĝ', '254': '\\', '255': 'Ů', '256': '#', '257': "'", '258': 'Á', '259': '¿', '260': '}', '261': 'y', '262': 'Ē', '263': 'Ŭ', '264': 'Ë', '265': '~', '266': 'Ę', '267': 'Ŵ', '268': 'Æ', '269': 'ð', '270': 'º', '271': 'Ó', '272': 'ā', '273': 'ô', '274': 'J', '275': 'ÿ', '276': 'ó', '277': 'Ĝ', '278': '&', '279': 'P', '280': '©', '281': 'Ğ', '282': 'è', '283': 'ę', '284': 'ĸ', '285': '²', '286': 'Ĥ', '287': '¢', '288': 'ŵ', '289': 'Î', '290': 'đ', '291': 'Í', '292': 'a', '293': ';', '294': 'à', '295': '¯', '296': '¸', '297': 'ņ', '298': 'L', '299': 'Ő', '300': 'ķ', '301': 'p', '302': 'Ŏ', '303': 'í', '304': 'ŧ', '305': 'ç', '306': 'Ť', '307': 'ŀ', '308': 'z', '309': 'ď', '310': 'Ň', '311': '6', '312': 'I', '313': '÷', '314': 'ú', '315': 'Ø', '316': 'Ñ', '317': 'ĉ'}
           }

class CharsetMapper:
    def __init__(self, max_sequence_size=128, blank_character=0):
        self.max_sequence_size = max_sequence_size
        self.blank_character = blank_character

    @staticmethod
    def remove_duplicates(idxs):
        new_idxs = []

        for i in range(len(idxs)):
            # Only append if the next character in the sequence is not
            # identical to the current character. If we're at the end of
            # the sequence, add it.
            if i + 1 == len(idxs) or idxs[i] != idxs[i + 1]:
                new_idxs.append(idxs[i])

        return new_idxs

    def idx_to_char(self, idx):
        if idx == self.blank_character:
            return ''  # Return empty string for the blank character
        else:
          try:
            return CHAR_SET['idx_to_char'][str(int(idx))]
          except KeyError:
            return ''

    def char_to_idx(self, char):
        try:
          return int(CHAR_SET['char_to_idx'][char])
        except KeyError:
          return self.blank_character

    def str_to_idxs(self, string):
        idxs = []

        zeros = np.full(self.max_sequence_size, self.blank_character)
        for char in string:
            idxs.append(self.char_to_idx(char))

        # Pad the array to the max sequence size
        idxs = np.concatenate((idxs, zeros))[:self.max_sequence_size]

        return idxs

    def idxs_to_str(self, idxs, remove_duplicates=True):
        string = ''

        if remove_duplicates:
            idxs = CharsetMapper.remove_duplicates(idxs)

        for idx in idxs:
            string += self.idx_to_char(idx)

        return string

    def str_to_idxs_batch(self, batch):
        idxs = []

        for string in batch:
            idx = self.str_to_idxs(string)
            idxs.append(idx)

        return idxs

    def idxs_to_str_batch(self, batch, remove_duplicates=True):
        strings = []

        for idxs in batch:
            strings.append(self.idxs_to_str(idxs, remove_duplicates=remove_duplicates))

        return strings

    def get_vocab_size(self):
      return len(CHAR_SET['char_to_idx']) + 1

### Data Loading
* Keras Sequence
* TfRecord Conversion

In [0]:
class ErrorSequence(tf.keras.utils.Sequence):
  def __init__(self, path='/content/error.csv'):
    self.df = pd.read_csv(path, header=None, sep='\t', names=['original', 'error'])
    self.charset_mapper = CharsetMapper()
  
  def __getitem__(self, index):
    x = self.charset_mapper.str_to_idxs(str(self.df['error'][index]))
    y = self.charset_mapper.str_to_idxs(str(self.df['original'][index]))

    return tf.constant(x), tf.constant(y)
  
  def __len__(self):
    return len(self.df)

In [7]:
sequence = ErrorSequence()
mapper = CharsetMapper()
x, y = sequence[4]

print('error:', mapper.idxs_to_str(x.numpy()))
print('corrected:', mapper.idxs_to_str(y.numpy()))

error: EXste alfarábio, nã¨o o devo ao meu velhÀ cronista do Paseioy PúbJico. É, como se dise no
corrected: Este alfarábio, não o devo ao meu velho cronista do Paseio Público. É, como se dise no


Create the TfRecord Dataset

In [0]:
def create_tfrecord_from_sequence(sequence, tfrecord_path):
    """
    Create a TfRecord dataset from a sequence

    :param sequence: The Keras sequence to load dataasdfs of arbitrary format
    :param tfrecord_path: Filepath and name for location of TfRecord dataset
    """
    print('Started creating TFRecord Dataset...')

    writer = tf.io.TFRecordWriter(tfrecord_path)

    for index, (img, label) in enumerate(sequence):
        feature = {'label': _bytes_feature(tf.io.serialize_tensor(label)),
                   'image': _bytes_feature(tf.io.serialize_tensor(img))}

        example = tf.train.Example(features=tf.train.Features(feature=feature))
        writer.write(example.SerializeToString())
        if index % 1000 == 0:
            print(str(index) + '/' + str(len(sequence)))

    print(str(len(sequence)) + '/' + str(len(sequence)))

    print('Finished: TFRecord created at', tfrecord_path)


def read_tfrecord(single_record):
    """
    Function to decode a TfRecord. Usually this function will be called within
    a TfDataset map function. Note that out_types for image and label must be
    tf.float32 and tf.int64 respectively.

    :param single_record: A single TfRecord
    :return: A decoded image and label as tensors
    """
    feature_description = {
        'label': tf.io.FixedLenFeature((), tf.string),
        'image': tf.io.FixedLenFeature((), tf.string)
    }

    single_record = tf.io.parse_single_example(single_record, feature_description)

    image = tf.io.parse_tensor(single_record['image'], out_type=tf.int64)
    label = tf.io.parse_tensor(single_record['label'], out_type=tf.int64)

    return image, label


def _bytes_feature(value):
    """Returns a bytes_list from a string / byte."""
    if isinstance(value, type(tf.constant(0))):
        value = value.numpy()  # BytesList won't unpack a string from an EagerTensor.
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


def _float_feature(value):
    """Returns a float_list from a float / double."""
    return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))


def _int64_feature(value):
    """Returns an int64_list from a bool / enum / int / uint."""
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

In [8]:
create_tfrecord_from_sequence(ErrorSequence(), 'error.tfrecord')

Started creating TFRecord Dataset...
0/341304
1000/341304
2000/341304
3000/341304
4000/341304
5000/341304
6000/341304
7000/341304
8000/341304
9000/341304
10000/341304
11000/341304
12000/341304
13000/341304
14000/341304
15000/341304
16000/341304
17000/341304
18000/341304
19000/341304
20000/341304
21000/341304
22000/341304
23000/341304
24000/341304
25000/341304
26000/341304
27000/341304
28000/341304
29000/341304
30000/341304
31000/341304
32000/341304
33000/341304
34000/341304
35000/341304
36000/341304
37000/341304
38000/341304
39000/341304
40000/341304
41000/341304
42000/341304
43000/341304
44000/341304
45000/341304
46000/341304
47000/341304
48000/341304
49000/341304
50000/341304
51000/341304
52000/341304
53000/341304
54000/341304
55000/341304
56000/341304
57000/341304
58000/341304
59000/341304
60000/341304
61000/341304
62000/341304
63000/341304
64000/341304
65000/341304
66000/341304
67000/341304
68000/341304
69000/341304
70000/341304
71000/341304
72000/341304
73000/341304
74000/341304
7

In [9]:
dataset = tf.data.TFRecordDataset('error.tfrecord').take(30000).map(read_tfrecord)

for image_features in dataset.take(1):
  print('Error:', mapper.idxs_to_str(image_features[0]))
  print('Corrected:', mapper.idxs_to_str(image_features[1]))

Error: A ALMADO LÁZARO
Corrected: A ALMA DO LÁZARO


### Create the Sequence-to-Sequence Model

In [0]:
EPOCHS = 10
BATCH_SIZE = 100
EMBEDDING_DIM = 128
UNITS = 128

mapper = CharsetMapper()
dataset = tf.data.TFRecordDataset('error.tfrecord').take(30000).map(read_tfrecord).batch(BATCH_SIZE)

In [0]:
class Encoder(tf.keras.Model):
  def __init__(self, vocab_size, embedding_dim, enc_units, batch_sz):
    super(Encoder, self).__init__()
    self.batch_sz = batch_sz
    self.enc_units = enc_units
    self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)
    self.gru = tf.keras.layers.GRU(self.enc_units,
                                   return_sequences=True,
                                   return_state=True,
                                   recurrent_initializer='glorot_uniform')

  def call(self, x, hidden):
    x = self.embedding(x)
    output, state = self.gru(x, initial_state = hidden)
    return output, state

  def initialize_hidden_state(self):
    return tf.zeros((self.batch_sz, self.enc_units))

In [39]:
encoder = Encoder(mapper.get_vocab_size(), EMBEDDING_DIM, UNITS, BATCH_SIZE)

example_input_batch, example_target_batch = next(iter(dataset))
example_input_batch.shape, example_target_batch.shape

sample_hidden = encoder.initialize_hidden_state()
sample_output, sample_hidden = encoder(example_input_batch, sample_hidden)
print('Encoder output shape: (batch size, sequence length, units) {}'.format(sample_output.shape))
print('Encoder Hidden state shape: (batch size, units) {}'.format(sample_hidden.shape))
encoder.summary()

Encoder output shape: (batch size, sequence length, units) (64, 128, 128)
Encoder Hidden state shape: (batch size, units) (64, 128)
Model: "encoder_4"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
embedding_6 (Embedding)      multiple                  40832     
_________________________________________________________________
gru_6 (GRU)                  multiple                  99072     
Total params: 139,904
Trainable params: 139,904
Non-trainable params: 0
_________________________________________________________________


In [0]:
class BahdanauAttention(tf.keras.layers.Layer):
  def __init__(self, units):
    super(BahdanauAttention, self).__init__()
    self.W1 = tf.keras.layers.Dense(units)
    self.W2 = tf.keras.layers.Dense(units)
    self.V = tf.keras.layers.Dense(1)

  def call(self, query, values):
    # query hidden state shape == (batch_size, hidden size)
    # query_with_time_axis shape == (batch_size, 1, hidden size)
    # values shape == (batch_size, max_len, hidden size)
    # we are doing this to broadcast addition along the time axis to calculate the score
    query_with_time_axis = tf.expand_dims(query, 1)

    # score shape == (batch_size, max_length, 1)
    # we get 1 at the last axis because we are applying score to self.V
    # the shape of the tensor before applying self.V is (batch_size, max_length, units)
    score = self.V(tf.nn.tanh(
        self.W1(query_with_time_axis) + self.W2(values)))

    # attention_weights shape == (batch_size, max_length, 1)
    attention_weights = tf.nn.softmax(score, axis=1)

    # context_vector shape after sum == (batch_size, hidden_size)
    context_vector = attention_weights * values
    context_vector = tf.reduce_sum(context_vector, axis=1)

    return context_vector, attention_weights

In [41]:
attention_layer = BahdanauAttention(10)
attention_result, attention_weights = attention_layer(sample_hidden, sample_output)

print("Attention result shape: (batch size, units) {}".format(attention_result.shape))
print("Attention weights shape: (batch_size, sequence_length, 1) {}".format(attention_weights.shape))

Attention result shape: (batch size, units) (64, 128)
Attention weights shape: (batch_size, sequence_length, 1) (64, 128, 1)


In [0]:
class Decoder(tf.keras.Model):
  def __init__(self, vocab_size, embedding_dim, dec_units, batch_sz):
    super(Decoder, self).__init__()
    self.batch_sz = batch_sz
    self.dec_units = dec_units
    self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)
    self.gru = tf.keras.layers.GRU(self.dec_units,
                                   return_sequences=True,
                                   return_state=True,
                                   recurrent_initializer='glorot_uniform')
    self.fc = tf.keras.layers.Dense(vocab_size)

    # used for attention
    self.attention = BahdanauAttention(self.dec_units)

  def call(self, x, hidden, enc_output):
    # enc_output shape == (batch_size, max_length, hidden_size)
    context_vector, attention_weights = self.attention(hidden, enc_output)

    # x shape after passing through embedding == (batch_size, 1, embedding_dim)
    x = self.embedding(x)

    # x shape after concatenation == (batch_size, 1, embedding_dim + hidden_size)
    x = tf.concat([tf.expand_dims(context_vector, 1), x], axis=-1)

    # passing the concatenated vector to the GRU
    output, state = self.gru(x)

    # output shape == (batch_size * 1, hidden_size)
    output = tf.reshape(output, (-1, output.shape[2]))

    # output shape == (batch_size, vocab)
    x = self.fc(output)

    return x, state, attention_weights

In [43]:
decoder = Decoder(mapper.get_vocab_size(), EMBEDDING_DIM, UNITS, BATCH_SIZE)

sample_decoder_output, _, _ = decoder(tf.random.uniform((BATCH_SIZE, 1)),
                                      sample_hidden, sample_output)

print ('Decoder output shape: (batch_size, vocab size) {}'.format(sample_decoder_output.shape))
decoder.summary()

Decoder output shape: (batch_size, vocab size) (64, 319)
Model: "decoder_2"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
embedding_7 (Embedding)      multiple                  40832     
_________________________________________________________________
gru_7 (GRU)                  multiple                  148224    
_________________________________________________________________
dense_20 (Dense)             multiple                  41151     
_________________________________________________________________
bahdanau_attention_6 (Bahdan multiple                  33153     
Total params: 263,360
Trainable params: 263,360
Non-trainable params: 0
_________________________________________________________________


In [0]:
optimizer = tf.keras.optimizers.Adam()
loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
    from_logits=True, reduction='none')

def loss_function(real, pred):
  mask = tf.math.logical_not(tf.math.equal(real, 0))
  loss_ = loss_object(real, pred)

  mask = tf.cast(mask, dtype=loss_.dtype)
  loss_ *= mask

  return tf.reduce_mean(loss_)

In [0]:
checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(optimizer=optimizer,
                                 encoder=encoder,
                                 decoder=decoder)

In [0]:
@tf.function
def train_step(inp, targ, enc_hidden):
  loss = 0

  with tf.GradientTape() as tape:
    enc_output, enc_hidden = encoder(inp, enc_hidden)

    dec_hidden = enc_hidden

    dec_input = tf.expand_dims([0] * BATCH_SIZE, 1)

    # Teacher forcing - feeding the target as the next input
    for t in range(1, targ.shape[1]):
      print('Iter:', t)
      # passing enc_output to the decoder
      predictions, dec_hidden, _ = decoder(dec_input, dec_hidden, enc_output)

      loss += loss_function(targ[:, t], predictions)

      # using teacher forcing
      dec_input = tf.expand_dims(targ[:, t], 1)

  batch_loss = (loss / int(targ.shape[1]))

  variables = encoder.trainable_variables + decoder.trainable_variables

  gradients = tape.gradient(loss, variables)

  optimizer.apply_gradients(zip(gradients, variables))

  return batch_loss

In [0]:
EPOCHS = 10

train_loss = tf.keras.metrics.Mean(name='train_loss')

for epoch in range(EPOCHS):
  start = time.time()

  enc_hidden = encoder.initialize_hidden_state()
  train_loss.reset_states()

  train_loop = tqdm(total=30000//BATCH_SIZE, position=0, leave=True)
  for (batch, (inp, targ)) in enumerate(dataset):
    batch_loss = train_step(inp, targ, enc_hidden)
    train_loss(batch_loss)
    train_loop.set_description('Epoch: {}, Loss: {:.4f}'.format(epoch, train_loss.result()))
    train_loop.update(1)
  train_loop.close()

  # saving (checkpoint) the model every 2 epochs
  if (epoch + 1) % 2 == 0:
    checkpoint.save(file_prefix = checkpoint_prefix)

  print('Time taken for 1 epoch {} sec\n'.format(time.time() - start))

Epoch: 0, Loss: 1.0877:  58%|█████▊    | 272/468 [03:51<02:37,  1.24it/s]

### Results

In [0]:
def evaluate(sentence):
  attention_plot = np.zeros((max_length_targ, max_length_inp))

  sentence = preprocess_sentence(sentence)

  inputs = [inp_lang.word_index[i] for i in sentence.split(' ')]
  inputs = tf.keras.preprocessing.sequence.pad_sequences([inputs],
                                                         maxlen=max_length_inp,
                                                         padding='post')
  inputs = tf.convert_to_tensor(inputs)

  result = ''

  hidden = [tf.zeros((1, units))]
  enc_out, enc_hidden = encoder(inputs, hidden)

  dec_hidden = enc_hidden
  dec_input = tf.expand_dims([targ_lang.word_index['<start>']], 0)

  for t in range(max_length_targ):
    predictions, dec_hidden, attention_weights = decoder(dec_input,
                                                         dec_hidden,
                                                         enc_out)

    # storing the attention weights to plot later on
    attention_weights = tf.reshape(attention_weights, (-1, ))
    attention_plot[t] = attention_weights.numpy()

    predicted_id = tf.argmax(predictions[0]).numpy()

    result += targ_lang.index_word[predicted_id] + ' '

    if targ_lang.index_word[predicted_id] == '<end>':
      return result, sentence, attention_plot

    # the predicted ID is fed back into the model
    dec_input = tf.expand_dims([predicted_id], 0)

  return result, sentence, attention_plot

In [0]:
# function for plotting the attention weights
def plot_attention(attention, sentence, predicted_sentence):
  fig = plt.figure(figsize=(10,10))
  ax = fig.add_subplot(1, 1, 1)
  ax.matshow(attention, cmap='viridis')

  fontdict = {'fontsize': 14}

  ax.set_xticklabels([''] + sentence, fontdict=fontdict, rotation=90)
  ax.set_yticklabels([''] + predicted_sentence, fontdict=fontdict)

  ax.xaxis.set_major_locator(ticker.MultipleLocator(1))
  ax.yaxis.set_major_locator(ticker.MultipleLocator(1))

  plt.show()

In [0]:
def translate(sentence):
  result, sentence, attention_plot = evaluate(sentence)

  print('Input: %s' % (sentence))
  print('Predicted translation: {}'.format(result))

  attention_plot = attention_plot[:len(result.split(' ')), :len(sentence.split(' '))]
  plot_attention(attention_plot, sentence.split(' '), result.split(' '))