In [1]:
%matplotlib inline

In [2]:
from datetime import datetime

In [3]:
#renderer
from PIL import ImageFont
import numpy as np
import cv2
import torch.nn as nn
from torch.nn.utils import clip_grad_norm_

char_size = 24
# char render
def render(text, font=None):
    if font is None:
        font = ImageFont.truetype("/usr/share/fonts/opentype/noto/NotoSansCJK-Regular.ttc", char_size)
    mask = font.getmask(text)
    size = mask.size[::-1]
    a = np.asarray(mask).reshape(size) / 255
    res = cv2.resize(a, dsize=(char_size, char_size), interpolation=cv2.INTER_CUBIC)
    return res

In [4]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

In [5]:
# https://github.com/yunjey/pytorch-tutorial/blob/master/tutorials/02-intermediate/language_model/data_utils.py
import torch
import re

class Dictionary(object):
    def __init__(self, max_size=None):
        self.word2idx = {}
        self.idx2word = {}
        self.idx = 1
        self.word2idx['⸘'] = 0 # as unk
        self.idx2word[0] = '⸘'
        self.max_size = max_size + 1
    
    def add_word(self, word):
        if not word in self.word2idx and self.idx < self.max_size:
            self.word2idx[word] = self.idx
            self.idx2word[self.idx] = word
            self.idx += 1
    
    def __len__(self):
        return len(self.word2idx)


class Corpus(object):
    def __init__(self, max_size=None):
        self.dictionary = Dictionary(max_size=max_size)

    def get_data(self, path, batch_size=20):
        # Add words to the dictionary
#         with open(path, 'r') as f:
#             tokens = 0
#             for line in f:
#                 words = line.split() + ['<eos>']
#                 tokens += len(words)
#                 for word in words: 
#                     self.dictionary.add_word(word)  

        # split words to char and add to dictionary
        with open(path, 'r') as f:
            tokens = 0
            for line in f:
                line = ' '.join(line) # split words to char
                line = re.sub(r'[" "]+', ' ', line) # remove continous space
                chars = line.split() + ['¿'] # ¿ as <eos>
                tokens += len(chars)
                for char in chars:
                    self.dictionary.add_word(char)
        
        # Tokenize the file content
        ids = torch.LongTensor(tokens)
        token = 0
        with open(path, 'r') as f:
            for line in f:
                line = ' '.join(line) # split words to char
                line = re.sub(r'[" "]+', ' ', line) # remove continous space
                chars = line.split() + ['¿'] # ¿ as <eos>
                for char in chars:
                    if char in self.dictionary.word2idx:
                        ids[token] = self.dictionary.word2idx[char]
                        token += 1
                    else:
                        ids[token] = self.dictionary.word2idx['⸘']
                        token += 1
        num_batches = ids.size(0) // batch_size
        ids = ids[:num_batches*batch_size]
        return ids.view(batch_size, -1)

In [6]:
# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cuda


In [7]:
# RNN based language model
class RNNLM(nn.Module):
    def __init__(self, vocab_size, embed_size, hidden_size, num_layers):
        super(RNNLM, self).__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.gru = nn.GRU(embed_size, hidden_size, num_layers, batch_first=True)
        self.linear = nn.Linear(hidden_size, vocab_size)
        
    def forward(self, char_id, char_cnn_o, h):
        # Embed word ids to vectors
        x = self.embed(char_id) + char_cnn_o
        
        # Forward propagate LSTM
        out, h = self.gru(x, h)
        
        # Reshape output to (batch_size*sequence_length, hidden_size)
        out = out.reshape(out.size(0)*out.size(1), out.size(2))
        
        # Decode hidden states of all time steps
        out = self.linear(out)
        return out, h

In [8]:
"""
Utility function for computing output of convolutions
takes a tuple of (h,w) and returns a tuple of (h,w)
"""
def conv_output_shape(h_w, kernel_size=1, stride=1, pad=0, dilation=1):
    from math import floor
    if type(kernel_size) is not tuple:
        kernel_size = (kernel_size, kernel_size)
    if type(stride) is not tuple:
        stride = (stride, stride)
    h = floor( ((h_w[0] + (2 * pad) - ( dilation * (kernel_size[0] - 1) ) - 1 )/ stride[0]) + 1)
    w = floor( ((h_w[1] + (2 * pad) - ( dilation * (kernel_size[1] - 1) ) - 1 )/ stride[1]) + 1)
    return h, w

# Dai et al. 's CNN glyph encoder
class Dai_CNN(nn.Module):
    def __init__(self, embed_size, input_size=(24, 24)):
        super(Dai_CNN, self).__init__()
        
        self.conv1 = nn.Conv2d(1, 32, (7, 7), stride=(2,2))
        torch.nn.init.xavier_uniform_(self.conv1.weight)
        torch.nn.init.zeros_(self.conv1.bias)
        h, w = conv_output_shape(input_size, (7, 7), (2, 2))
        
        self.conv2 = nn.Conv2d(32, 16, (5, 5), stride=(2,2))
        torch.nn.init.xavier_uniform_(self.conv2.weight)
        torch.nn.init.zeros_(self.conv2.bias)
        h, w = conv_output_shape((h, w), (5, 5), (2, 2))
                
        self.fc = nn.Linear(16*h*w, embed_size)
        torch.nn.init.xavier_uniform_(self.fc.weight)
        torch.nn.init.zeros_(self.fc.bias)
        
        self.h, self.w = h, w
        
    def forward(self, char_img):
        b = char_img.size(0)
        x = self.conv1(char_img)
        x = torch.nn.functional.relu(x)
        x = self.conv2(x)
        x = torch.nn.functional.relu(x)
        x = x.view(-1, 16*self.h*self.w)
        return self.fc(x)

In [9]:
# Hyper-parameters
embed_size = 300
hidden_size = 128
num_layers = 1
num_epochs = 50
batch_size = 16
seq_length = 32
learning_rate = 1e-3

# Load dataset
corpus = Corpus(max_size=4000)
ids = corpus.get_data('icwb2-data/training/msr_training.utf8', batch_size)
vocab_size = len(corpus.dictionary)
num_batches = ids.size(1) // seq_length

In [10]:
# save char images for reference
for char, idx in corpus.dictionary.word2idx.items():
    np.save(f'char_img/noto_CJK/msr4/{idx}.npy', render(char))

In [11]:
model = RNNLM(vocab_size, embed_size, hidden_size, num_layers).to(device)
cnn_encoder = Dai_CNN(embed_size, input_size=(char_size, char_size)).to(device)
model.train()
cnn_encoder.train()
params = list(model.parameters())+list(cnn_encoder.parameters())

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params, lr=learning_rate)

# Truncated backpropagation
def detach(state):
    return state.detach()

In [12]:
# Train the charID RNNLM only
for param in cnn_encoder.parameters():
    param.requires_grad = False

for epoch in range(num_epochs//3):
    # Set initial hidden and cell states
    state = torch.zeros(num_layers, batch_size, hidden_size).to(device)
    
    for i in range(0, ids.size(1) - seq_length, seq_length):
        # Get mini-batch inputs and targets
        inputs = ids[:, i:i+seq_length].to(device)
        targets = ids[:, (i+1):(i+1)+seq_length].to(device)
        
        # Get images
        images = np.zeros((inputs.size(0), inputs.size(1), char_size, char_size, 1))
        for b, seq in enumerate(ids):
            for s, idx in enumerate(ids[b][i:i+seq_length]):
                images[b, s] = np.load(f'char_img/noto_CJK/msr4/{idx}.npy').reshape(char_size,char_size,1)
        images = torch.from_numpy(images).float().to(device) # B N H W C
        images = images.view(-1, char_size, char_size, 1) # B*N H W C
        images = images.permute(0, 3, 1, 2) # from B*N H W C to B*N C H W
        
        # Get encoded images
        cnn_o = cnn_encoder(images)
        cnn_o = torch.reshape(cnn_o, (inputs.size(0), inputs.size(1), -1))
               
        # Forward pass
        state = detach(state)
        outputs, state = model(inputs, cnn_o, state)
        loss = criterion(outputs, targets.reshape(-1))
        
        # Backward and optimize
        model.zero_grad()
        cnn_encoder.zero_grad()
        loss.backward()
        clip_grad_norm_(params, 0.5)
        optimizer.step()

        step = (i+1) // seq_length
        if step % 1000 == 0:
            print ('{} Epoch [{}/{}], Step[{}/{}], Loss: {:.4f}, Perplexity: {:5.2f}'
                   .format(datetime.now(), epoch+1, num_epochs, step, num_batches, loss.item(), np.exp(loss.item())))

2019-11-25 16:34:58.072595 Epoch [1/50], Step[0/8080], Loss: 8.3401, Perplexity: 4188.49
2019-11-25 16:39:44.468942 Epoch [1/50], Step[1000/8080], Loss: 5.0944, Perplexity: 163.11
2019-11-25 16:44:32.734452 Epoch [1/50], Step[2000/8080], Loss: 4.5996, Perplexity: 99.45
2019-11-25 16:50:38.531550 Epoch [1/50], Step[3000/8080], Loss: 4.4433, Perplexity: 85.06
2019-11-25 16:56:43.342895 Epoch [1/50], Step[4000/8080], Loss: 4.2708, Perplexity: 71.58
2019-11-25 17:02:34.221519 Epoch [1/50], Step[5000/8080], Loss: 4.7157, Perplexity: 111.68
2019-11-25 17:08:04.106129 Epoch [1/50], Step[6000/8080], Loss: 4.6639, Perplexity: 106.05
2019-11-25 17:14:12.028457 Epoch [1/50], Step[7000/8080], Loss: 4.1268, Perplexity: 61.98
2019-11-25 17:20:25.275289 Epoch [1/50], Step[8000/8080], Loss: 4.3250, Perplexity: 75.57
2019-11-25 17:20:55.196630 Epoch [2/50], Step[0/8080], Loss: 4.2807, Perplexity: 72.29
2019-11-25 17:27:15.567507 Epoch [2/50], Step[1000/8080], Loss: 4.1963, Perplexity: 66.44
2019-11-25 

2019-11-26 01:26:53.356157 Epoch [11/50], Step[2000/8080], Loss: 3.8872, Perplexity: 48.77
2019-11-26 01:33:31.217732 Epoch [11/50], Step[3000/8080], Loss: 3.8239, Perplexity: 45.78
2019-11-26 01:40:06.401453 Epoch [11/50], Step[4000/8080], Loss: 3.6757, Perplexity: 39.48
2019-11-26 01:46:47.300892 Epoch [11/50], Step[5000/8080], Loss: 4.2616, Perplexity: 70.92
2019-11-26 01:53:24.876903 Epoch [11/50], Step[6000/8080], Loss: 3.9693, Perplexity: 52.95
2019-11-26 02:00:04.160616 Epoch [11/50], Step[7000/8080], Loss: 3.6650, Perplexity: 39.06
2019-11-26 02:06:42.169588 Epoch [11/50], Step[8000/8080], Loss: 3.9640, Perplexity: 52.67
2019-11-26 02:07:14.615242 Epoch [12/50], Step[0/8080], Loss: 3.8633, Perplexity: 47.62
2019-11-26 02:13:52.311128 Epoch [12/50], Step[1000/8080], Loss: 3.7098, Perplexity: 40.84
2019-11-26 02:20:29.803552 Epoch [12/50], Step[2000/8080], Loss: 3.8713, Perplexity: 48.00
2019-11-26 02:27:08.697202 Epoch [12/50], Step[3000/8080], Loss: 3.8278, Perplexity: 45.96
20

In [13]:
# Train the CNN + RNNLM without Char Embedding
for param in model.embed.parameters():
    param.requires_grad = False
for param in cnn_encoder.parameters():
    param.requires_grad = True

for epoch in range(num_epochs//3):
    # Set initial hidden and cell states
    state = torch.zeros(num_layers, batch_size, hidden_size).to(device)
    
    for i in range(0, ids.size(1) - seq_length, seq_length):
        # Get mini-batch inputs and targets
        inputs = ids[:, i:i+seq_length].to(device)
        targets = ids[:, (i+1):(i+1)+seq_length].to(device)
        
        # Get images
        images = np.zeros((inputs.size(0), inputs.size(1), char_size, char_size, 1))
        for b, seq in enumerate(ids):
            for s, idx in enumerate(ids[b][i:i+seq_length]):
                images[b, s] = np.load(f'char_img/noto_CJK/msr4/{idx}.npy').reshape(char_size,char_size,1)
        images = torch.from_numpy(images).float().to(device) # B N H W C
        images = images.view(-1, char_size, char_size, 1) # B*N H W C
        images = images.permute(0, 3, 1, 2) # from B*N H W C to B*N C H W
        
        # Get encoded images
        cnn_o = cnn_encoder(images)
        cnn_o = torch.reshape(cnn_o, (inputs.size(0), inputs.size(1), -1))
               
        # Forward pass
        state = detach(state)
        outputs, state = model(inputs, cnn_o, state)
        loss = criterion(outputs, targets.reshape(-1))
        
        # Backward and optimize
        model.zero_grad()
        cnn_encoder.zero_grad()
        loss.backward()
        clip_grad_norm_(params, 0.5)
        optimizer.step()

        step = (i+1) // seq_length
        if step % 1000 == 0:
            print ('{} Epoch [{}/{}], Step[{}/{}], Loss: {:.4f}, Perplexity: {:5.2f}'
                   .format(datetime.now(), epoch+1, num_epochs, step, num_batches, loss.item(), np.exp(loss.item())))

2019-11-26 06:35:11.802643 Epoch [1/50], Step[0/8080], Loss: 3.8297, Perplexity: 46.05
2019-11-26 06:41:53.660148 Epoch [1/50], Step[1000/8080], Loss: 3.6675, Perplexity: 39.15
2019-11-26 06:48:35.315657 Epoch [1/50], Step[2000/8080], Loss: 3.8359, Perplexity: 46.34
2019-11-26 06:55:19.225863 Epoch [1/50], Step[3000/8080], Loss: 3.7271, Perplexity: 41.56
2019-11-26 07:01:59.101525 Epoch [1/50], Step[4000/8080], Loss: 3.6714, Perplexity: 39.31
2019-11-26 07:08:39.327278 Epoch [1/50], Step[5000/8080], Loss: 4.2187, Perplexity: 67.94
2019-11-26 07:15:23.734155 Epoch [1/50], Step[6000/8080], Loss: 3.8757, Perplexity: 48.21
2019-11-26 07:22:02.632987 Epoch [1/50], Step[7000/8080], Loss: 3.6237, Perplexity: 37.48
2019-11-26 07:28:43.033837 Epoch [1/50], Step[8000/8080], Loss: 3.9379, Perplexity: 51.31
2019-11-26 07:29:15.034192 Epoch [2/50], Step[0/8080], Loss: 3.8115, Perplexity: 45.22
2019-11-26 07:35:50.747305 Epoch [2/50], Step[1000/8080], Loss: 3.6427, Perplexity: 38.20
2019-11-26 07:42

2019-11-26 15:44:49.533847 Epoch [11/50], Step[2000/8080], Loss: 3.8393, Perplexity: 46.50
2019-11-26 15:50:59.593768 Epoch [11/50], Step[3000/8080], Loss: 3.7506, Perplexity: 42.55
2019-11-26 15:57:09.218949 Epoch [11/50], Step[4000/8080], Loss: 3.6560, Perplexity: 38.70
2019-11-26 16:03:20.958462 Epoch [11/50], Step[5000/8080], Loss: 4.2227, Perplexity: 68.22
2019-11-26 16:09:27.367717 Epoch [11/50], Step[6000/8080], Loss: 3.8905, Perplexity: 48.93
2019-11-26 16:15:42.681668 Epoch [11/50], Step[7000/8080], Loss: 3.6294, Perplexity: 37.69
2019-11-26 16:21:52.130293 Epoch [11/50], Step[8000/8080], Loss: 3.8928, Perplexity: 49.05
2019-11-26 16:22:20.546166 Epoch [12/50], Step[0/8080], Loss: 3.8212, Perplexity: 45.66
2019-11-26 16:28:33.284201 Epoch [12/50], Step[1000/8080], Loss: 3.6425, Perplexity: 38.19
2019-11-26 16:35:02.758678 Epoch [12/50], Step[2000/8080], Loss: 3.8267, Perplexity: 45.91
2019-11-26 16:41:52.984385 Epoch [12/50], Step[3000/8080], Loss: 3.7374, Perplexity: 41.99
20

In [14]:
# Train the CNN + RNNLM
for param in model.embed.parameters():
    param.requires_grad = True

for epoch in range(num_epochs//3 + num_epochs%3):
    # Set initial hidden and cell states
    state = torch.zeros(num_layers, batch_size, hidden_size).to(device)
    
    for i in range(0, ids.size(1) - seq_length, seq_length):
        # Get mini-batch inputs and targets
        inputs = ids[:, i:i+seq_length].to(device)
        targets = ids[:, (i+1):(i+1)+seq_length].to(device)
        
        # Get images
        images = np.zeros((inputs.size(0), inputs.size(1), char_size, char_size, 1))
        for b, seq in enumerate(ids):
            for s, idx in enumerate(ids[b][i:i+seq_length]):
                images[b, s] = np.load(f'char_img/noto_CJK/msr4/{idx}.npy').reshape(char_size,char_size,1)
        images = torch.from_numpy(images).float().to(device) # B N H W C
        images = images.view(-1, char_size, char_size, 1) # B*N H W C
        images = images.permute(0, 3, 1, 2) # from B*N H W C to B*N C H W
        
        # Get encoded images
        cnn_o = cnn_encoder(images)
        cnn_o = torch.reshape(cnn_o, (inputs.size(0), inputs.size(1), -1))
               
        # Forward pass
        state = detach(state)
        outputs, state = model(inputs, cnn_o, state)
        loss = criterion(outputs, targets.reshape(-1))
        
        # Backward and optimize
        model.zero_grad()
        cnn_encoder.zero_grad()
        loss.backward()
        clip_grad_norm_(params, 0.5)
        optimizer.step()

        step = (i+1) // seq_length
        if step % 1000 == 0:
            print ('{} Epoch [{}/{}], Step[{}/{}], Loss: {:.4f}, Perplexity: {:5.2f}'
                   .format(datetime.now(), epoch+1, num_epochs, step, num_batches, loss.item(), np.exp(loss.item())))

2019-11-26 21:02:53.902375 Epoch [1/50], Step[0/8080], Loss: 3.8049, Perplexity: 44.92
2019-11-26 21:09:43.977851 Epoch [1/50], Step[1000/8080], Loss: 3.6734, Perplexity: 39.39
2019-11-26 21:16:31.442470 Epoch [1/50], Step[2000/8080], Loss: 3.8881, Perplexity: 48.82
2019-11-26 21:23:17.103985 Epoch [1/50], Step[3000/8080], Loss: 3.7566, Perplexity: 42.80
2019-11-26 21:30:08.216743 Epoch [1/50], Step[4000/8080], Loss: 3.6858, Perplexity: 39.88
2019-11-26 21:36:55.311133 Epoch [1/50], Step[5000/8080], Loss: 4.2488, Perplexity: 70.02
2019-11-26 21:43:44.661322 Epoch [1/50], Step[6000/8080], Loss: 3.9033, Perplexity: 49.57
2019-11-26 21:50:31.997673 Epoch [1/50], Step[7000/8080], Loss: 3.6539, Perplexity: 38.62
2019-11-26 21:57:18.693975 Epoch [1/50], Step[8000/8080], Loss: 3.9283, Perplexity: 50.82
2019-11-26 21:57:51.727076 Epoch [2/50], Step[0/8080], Loss: 3.8039, Perplexity: 44.87
2019-11-26 22:04:44.687927 Epoch [2/50], Step[1000/8080], Loss: 3.6599, Perplexity: 38.86
2019-11-26 22:11

2019-11-27 06:24:10.799696 Epoch [11/50], Step[2000/8080], Loss: 3.8758, Perplexity: 48.22
2019-11-27 06:30:50.435989 Epoch [11/50], Step[3000/8080], Loss: 3.7868, Perplexity: 44.11
2019-11-27 06:37:29.160993 Epoch [11/50], Step[4000/8080], Loss: 3.6692, Perplexity: 39.22
2019-11-27 06:44:10.851592 Epoch [11/50], Step[5000/8080], Loss: 4.1723, Perplexity: 64.86
2019-11-27 06:50:51.983604 Epoch [11/50], Step[6000/8080], Loss: 3.8793, Perplexity: 48.39
2019-11-27 06:57:30.485356 Epoch [11/50], Step[7000/8080], Loss: 3.6125, Perplexity: 37.06
2019-11-27 07:04:03.348289 Epoch [11/50], Step[8000/8080], Loss: 3.9218, Perplexity: 50.49
2019-11-27 07:04:36.664392 Epoch [12/50], Step[0/8080], Loss: 3.8147, Perplexity: 45.36
2019-11-27 07:11:11.777896 Epoch [12/50], Step[1000/8080], Loss: 3.5736, Perplexity: 35.64
2019-11-27 07:17:48.786294 Epoch [12/50], Step[2000/8080], Loss: 3.8825, Perplexity: 48.54
2019-11-27 07:24:25.754819 Epoch [12/50], Step[3000/8080], Loss: 3.7569, Perplexity: 42.82
20

In [15]:
model.eval()
cnn_encoder.eval()

perplexity = .0
num_step = 0
for i in range(0, ids.size(1) - seq_length, seq_length):
    # Get mini-batch inputs and targets
    inputs = ids[:, i:i+seq_length].to(device)
    targets = ids[:, (i+1):(i+1)+seq_length].to(device)

    # Get images
    images = np.zeros((inputs.size(0), inputs.size(1), char_size, char_size, 1))
    for b, seq in enumerate(ids):
        for s, idx in enumerate(ids[b][i:i+seq_length]):
            images[b, s] = np.load(f'char_img/noto_CJK/msr4/{idx}.npy').reshape(char_size,char_size,1)
    images = torch.from_numpy(images).float().to(device) # B N H W C
    images = images.view(-1, char_size, char_size, 1) # B*N H W C
    images = images.permute(0, 3, 1, 2) # from B*N H W C to B*N C H W

    # Get encoded images
    cnn_o = cnn_encoder(images)
    cnn_o = torch.reshape(cnn_o, (inputs.size(0), inputs.size(1), -1))

    # Forward pass
    state = detach(state)
    outputs, state = model(inputs, cnn_o, state)
    loss = criterion(outputs, targets.reshape(-1))

    perplexity += np.exp(loss.item())

    num_step += 1
    
print(f"Train Perplexity: {perplexity / num_step}")

Train Perplexity: 44.965010491396036


In [16]:
test_ids = corpus.get_data('icwb2-data/testing/msr_test.utf8', batch_size)
# filter out unknown character
test_ids = test_ids.view(-1)
mask = test_ids < vocab_size
test_ids = test_ids[mask]
num_batches = test_ids.size(0) // batch_size
test_ids = test_ids[:num_batches*batch_size]
test_ids = test_ids.view(batch_size, -1)

In [17]:
perplexity = .0
num_step = 0
for i in range(0, test_ids.size(1) - seq_length, seq_length):
    # Get mini-batch inputs and targets
    inputs = test_ids[:, i:i+seq_length].to(device)
    targets = test_ids[:, (i+1):(i+1)+seq_length].to(device)

    # Get images
    images = np.zeros((inputs.size(0), inputs.size(1), char_size, char_size, 1))
    for b, seq in enumerate(test_ids):
        for s, idx in enumerate(test_ids[b][i:i+seq_length]):
            images[b, s] = np.load(f'char_img/noto_CJK/msr4/{idx}.npy').reshape(char_size,char_size,1)
    images = torch.from_numpy(images).float().to(device) # B N H W C
    images = images.view(-1, char_size, char_size, 1) # B*N H W C
    images = images.permute(0, 3, 1, 2) # from B*N H W C to B*N C H W

    # Get encoded images
    cnn_o = cnn_encoder(images)
    cnn_o = torch.reshape(cnn_o, (inputs.size(0), inputs.size(1), -1))

    # Forward pass
    state = detach(state)
    outputs, state = model(inputs, cnn_o, state)
    loss = criterion(outputs, targets.reshape(-1))
    
    perplexity += np.exp(loss.item())

    num_step += 1
    
print(f"Test Perplexity: {perplexity / num_step}")

Test Perplexity: 58.09019991798291
