In [1]:
%matplotlib inline

In [2]:
from datetime import datetime

In [3]:
#renderer
from PIL import ImageFont
import numpy as np
import cv2
import torch.nn as nn
from torch.nn.utils import clip_grad_norm_

char_size = 12
# char render
def render(text, font=None):
    if font is None:
        font = ImageFont.truetype("/usr/share/fonts/opentype/noto/NotoSansCJK-Regular.ttc", char_size)
    mask = font.getmask(text)
    size = mask.size[::-1]
    a = np.asarray(mask).reshape(size) / 255
    res = cv2.resize(a, dsize=(char_size, char_size), interpolation=cv2.INTER_CUBIC)
    return res

In [4]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

In [5]:
# https://github.com/yunjey/pytorch-tutorial/blob/master/tutorials/02-intermediate/language_model/data_utils.py
import torch
import re

class Dictionary(object):
    def __init__(self, max_size=None):
        self.word2idx = {}
        self.idx2word = {}
        self.idx = 1
        self.word2idx['⸘'] = 0 # as unk
        self.idx2word[0] = '⸘'
        self.max_size = max_size + 1
    
    def add_word(self, word):
        if not word in self.word2idx and self.idx < self.max_size:
            self.word2idx[word] = self.idx
            self.idx2word[self.idx] = word
            self.idx += 1
    
    def __len__(self):
        return len(self.word2idx)


class Corpus(object):
    def __init__(self, max_size=None):
        self.dictionary = Dictionary(max_size=max_size)

    def get_data(self, path, batch_size=20):
        # Add words to the dictionary
#         with open(path, 'r') as f:
#             tokens = 0
#             for line in f:
#                 words = line.split() + ['<eos>']
#                 tokens += len(words)
#                 for word in words: 
#                     self.dictionary.add_word(word)  

        # split words to char and add to dictionary
        with open(path, 'r') as f:
            tokens = 0
            for line in f:
                line = ' '.join(line) # split words to char
                line = re.sub(r'[" "]+', ' ', line) # remove continous space
                chars = line.split() + ['¿'] # ¿ as <eos>
                tokens += len(chars)
                for char in chars:
                    self.dictionary.add_word(char)
        
        # Tokenize the file content
        ids = torch.LongTensor(tokens)
        token = 0
        with open(path, 'r') as f:
            for line in f:
                line = ' '.join(line) # split words to char
                line = re.sub(r'[" "]+', ' ', line) # remove continous space
                chars = line.split() + ['¿'] # ¿ as <eos>
                for char in chars:
                    if char in self.dictionary.word2idx:
                        ids[token] = self.dictionary.word2idx[char]
                        token += 1
                    else:
                        ids[token] = self.dictionary.word2idx['⸘']
                        token += 1
        num_batches = ids.size(0) // batch_size
        ids = ids[:num_batches*batch_size]
        return ids.view(batch_size, -1)

In [6]:
# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cuda


In [7]:
# RNN based language model
class RNNLM(nn.Module):
    def __init__(self, vocab_size, embed_size, hidden_size, num_layers):
        super(RNNLM, self).__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.gru = nn.GRU(embed_size, hidden_size, num_layers, batch_first=True)
        self.linear = nn.Linear(hidden_size, vocab_size)
        
    def forward(self, char_id, char_cnn_o, h):
        # Embed word ids to vectors
        x = self.embed(char_id) + char_cnn_o
        
        # Forward propagate LSTM
        out, h = self.gru(x, h)
        
        # Reshape output to (batch_size*sequence_length, hidden_size)
        out = out.reshape(out.size(0)*out.size(1), out.size(2))
        
        # Decode hidden states of all time steps
        out = self.linear(out)
        return out, h

In [8]:
# encoding: utf-8
"""
From https://github.com/ShannonAI/glyce/blob/master/glyce/glyph_cnn_models/glyph_group_cnn.py
@author: wuwei
@contact: wu.wei@pku.edu.cn
@version: 1.0
@file: cnn_for_fonts.py
@time: 19-1-2 上午11:07
用CNN将字体的灰度图卷积成特征向量
"""
import torch.nn.functional as F 

class GlyphGroupCNN(nn.Module):     #55000
    def __init__(self, cnn_type='simple', kernel_size=5, font_channels=1, shuffle=False, ntokens=4000,
                 num_features=8*12*4, final_width=2, cnn_drop=0.5, groups=16):
        super(GlyphGroupCNN, self).__init__()
        self.aux_logits=False
        self.cnn_type = cnn_type
        output_channels = num_features
        self.conv1 = nn.Conv2d(font_channels, output_channels, kernel_size)
        midchannels = output_channels//4
        self.mid_groups = max(groups//2, 1)
        self.downsample = nn.Conv2d(output_channels, midchannels, kernel_size=1, groups=self.mid_groups)   #//2是因为参数量主要集中在下一层
        self.max_pool = nn.AdaptiveMaxPool2d((final_width, final_width))
        self.num_features = num_features
        self.reweight_conv = nn.Conv2d(midchannels, output_channels, kernel_size=final_width, groups=groups)
        self.output_channels=output_channels
        self.shuffle = shuffle
        self.dropout = nn.Dropout(cnn_drop)
        self.init_weights()

    def forward(self, x):
        # x = self.base_conv(x)
        x = F.relu(self.conv1(x) ) # [(seq_len*batchsize, Co, h, w), ...]*len(Ks)
        x = self.max_pool(x)  # n, c, 2, 2
        x = self.downsample(x)
        if self.shuffle:
            x = channel_shuffle(x, groups=2)
        x = F.relu(self.reweight_conv(x))
        if self.shuffle:
            x = channel_shuffle(x, groups=2)
        x = x.view(x.size(0), -1)
        x = self.dropout(x)
        return x  # (seq_len*batchsize, nfeats)

    def init_weights(self):
        initrange = 0.1
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                m.weight.data.uniform_(-initrange, initrange)
                # init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.normal_(mean=1, std=0.001)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.weight.data.normal_(std=0.001)
                if m.bias is not None:
                    m.bias.data.zero_()

In [9]:
# Hyper-parameters
embed_size = 256
hidden_size = 128
num_layers = 1
num_epochs = 50
batch_size = 16
seq_length = 32
learning_rate = 1e-3

# Load dataset
corpus = Corpus(max_size=4000)
ids = corpus.get_data('icwb2-data/training/msr_training.utf8', batch_size)
vocab_size = len(corpus.dictionary)
num_batches = ids.size(1) // seq_length

In [10]:
# save char images for reference
for char, idx in corpus.dictionary.word2idx.items():
    np.save(f'char_img/noto_CJK/msr2/{idx}.npy', render(char))

In [11]:
model = RNNLM(vocab_size, embed_size, hidden_size, num_layers).to(device)
cnn_encoder = GlyphGroupCNN(num_features=256).to(device)
model.train()
cnn_encoder.train()
params = list(model.parameters())+list(cnn_encoder.parameters())

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params, lr=learning_rate)

# Truncated backpropagation
def detach(state):
    return state.detach()

In [12]:
# Train the CNN + RNNLM
for param in model.embed.parameters():
    param.requires_grad = True

for epoch in range(num_epochs):
    # Set initial hidden and cell states
    state = torch.zeros(num_layers, batch_size, hidden_size).to(device)
    
    for i in range(0, ids.size(1) - seq_length, seq_length):
        # Get mini-batch inputs and targets
        inputs = ids[:, i:i+seq_length].to(device)
        targets = ids[:, (i+1):(i+1)+seq_length].to(device)
        
        # Get images
        images = np.zeros((inputs.size(0), inputs.size(1), char_size, char_size, 1))
        for b, seq in enumerate(ids):
            for s, idx in enumerate(ids[b][i:i+seq_length]):
                images[b, s] = np.load(f'char_img/noto_CJK/msr2/{idx}.npy').reshape(char_size,char_size,1)
        images = torch.from_numpy(images).float().to(device) # B N H W C
        images = images.view(-1, char_size, char_size, 1) # B*N H W C
        images = images.permute(0, 3, 1, 2) # from B*N H W C to B*N C H W
        
        # Get encoded images
        cnn_o = cnn_encoder(images)
        cnn_o = torch.reshape(cnn_o, (inputs.size(0), inputs.size(1), -1))
               
        # Forward pass
        state = detach(state)
        outputs, state = model(inputs, cnn_o, state)
        loss = criterion(outputs, targets.reshape(-1))
        
        # Backward and optimize
        model.zero_grad()
        cnn_encoder.zero_grad()
        loss.backward()
        clip_grad_norm_(params, 0.5)
        optimizer.step()

        step = (i+1) // seq_length
        if step % 1000 == 0:
            print ('{} Epoch [{}/{}], Step[{}/{}], Loss: {:.4f}, Perplexity: {:5.2f}'
                   .format(datetime.now(), epoch+1, num_epochs, step, num_batches, loss.item(), np.exp(loss.item())))

2019-11-25 16:44:04.337784 Epoch [1/50], Step[0/8080], Loss: 8.3254, Perplexity: 4127.25
2019-11-25 16:50:21.965699 Epoch [1/50], Step[1000/8080], Loss: 5.0842, Perplexity: 161.45
2019-11-25 16:56:39.449228 Epoch [1/50], Step[2000/8080], Loss: 4.6569, Perplexity: 105.31
2019-11-25 17:02:40.757554 Epoch [1/50], Step[3000/8080], Loss: 4.4453, Perplexity: 85.23
2019-11-25 17:08:18.068562 Epoch [1/50], Step[4000/8080], Loss: 4.3006, Perplexity: 73.74
2019-11-25 17:14:42.974429 Epoch [1/50], Step[5000/8080], Loss: 4.6906, Perplexity: 108.92
2019-11-25 17:21:07.573676 Epoch [1/50], Step[6000/8080], Loss: 4.6470, Perplexity: 104.27
2019-11-25 17:27:41.312060 Epoch [1/50], Step[7000/8080], Loss: 4.1596, Perplexity: 64.05
2019-11-25 17:34:11.123083 Epoch [1/50], Step[8000/8080], Loss: 4.3503, Perplexity: 77.50
2019-11-25 17:34:46.625706 Epoch [2/50], Step[0/8080], Loss: 4.3467, Perplexity: 77.22
2019-11-25 17:42:19.159301 Epoch [2/50], Step[1000/8080], Loss: 4.1743, Perplexity: 64.99
2019-11-25

2019-11-26 01:57:04.684106 Epoch [11/50], Step[2000/8080], Loss: 3.9504, Perplexity: 51.96
2019-11-26 02:03:55.588541 Epoch [11/50], Step[3000/8080], Loss: 3.7801, Perplexity: 43.82
2019-11-26 02:10:48.046512 Epoch [11/50], Step[4000/8080], Loss: 3.6888, Perplexity: 40.00
2019-11-26 02:17:34.528396 Epoch [11/50], Step[5000/8080], Loss: 4.3289, Perplexity: 75.86
2019-11-26 02:24:23.865908 Epoch [11/50], Step[6000/8080], Loss: 4.0757, Perplexity: 58.89
2019-11-26 02:31:15.048665 Epoch [11/50], Step[7000/8080], Loss: 3.6867, Perplexity: 39.91
2019-11-26 02:38:05.125886 Epoch [11/50], Step[8000/8080], Loss: 3.9961, Perplexity: 54.39
2019-11-26 02:38:37.910790 Epoch [12/50], Step[0/8080], Loss: 3.9975, Perplexity: 54.46
2019-11-26 02:45:30.387381 Epoch [12/50], Step[1000/8080], Loss: 3.7632, Perplexity: 43.09
2019-11-26 02:52:17.886011 Epoch [12/50], Step[2000/8080], Loss: 3.9313, Perplexity: 50.98
2019-11-26 02:59:09.571361 Epoch [12/50], Step[3000/8080], Loss: 3.7830, Perplexity: 43.95
20

2019-11-26 11:16:24.594493 Epoch [21/50], Step[3000/8080], Loss: 3.7555, Perplexity: 42.75
2019-11-26 11:23:15.306751 Epoch [21/50], Step[4000/8080], Loss: 3.7017, Perplexity: 40.52
2019-11-26 11:30:10.400361 Epoch [21/50], Step[5000/8080], Loss: 4.3047, Perplexity: 74.05
2019-11-26 11:37:04.232856 Epoch [21/50], Step[6000/8080], Loss: 3.9889, Perplexity: 54.00
2019-11-26 11:43:55.435974 Epoch [21/50], Step[7000/8080], Loss: 3.6042, Perplexity: 36.75
2019-11-26 11:50:45.800613 Epoch [21/50], Step[8000/8080], Loss: 3.9447, Perplexity: 51.66
2019-11-26 11:51:18.129917 Epoch [22/50], Step[0/8080], Loss: 3.8864, Perplexity: 48.74
2019-11-26 11:58:07.477336 Epoch [22/50], Step[1000/8080], Loss: 3.7012, Perplexity: 40.50
2019-11-26 12:04:59.423795 Epoch [22/50], Step[2000/8080], Loss: 3.8749, Perplexity: 48.18
2019-11-26 12:11:48.501669 Epoch [22/50], Step[3000/8080], Loss: 3.7330, Perplexity: 41.81
2019-11-26 12:18:36.148179 Epoch [22/50], Step[4000/8080], Loss: 3.6774, Perplexity: 39.54
20

2019-11-26 20:42:59.989336 Epoch [31/50], Step[4000/8080], Loss: 3.6531, Perplexity: 38.59
2019-11-26 20:50:00.214771 Epoch [31/50], Step[5000/8080], Loss: 4.2701, Perplexity: 71.53
2019-11-26 20:56:56.937645 Epoch [31/50], Step[6000/8080], Loss: 3.9853, Perplexity: 53.80
2019-11-26 21:03:50.743179 Epoch [31/50], Step[7000/8080], Loss: 3.5733, Perplexity: 35.63
2019-11-26 21:10:49.224484 Epoch [31/50], Step[8000/8080], Loss: 3.8911, Perplexity: 48.96
2019-11-26 21:11:21.845656 Epoch [32/50], Step[0/8080], Loss: 3.8752, Perplexity: 48.19
2019-11-26 21:18:20.197078 Epoch [32/50], Step[1000/8080], Loss: 3.6515, Perplexity: 38.53
2019-11-26 21:25:15.311932 Epoch [32/50], Step[2000/8080], Loss: 3.8456, Perplexity: 46.79
2019-11-26 21:32:16.510291 Epoch [32/50], Step[3000/8080], Loss: 3.7303, Perplexity: 41.69
2019-11-26 21:39:10.745286 Epoch [32/50], Step[4000/8080], Loss: 3.6540, Perplexity: 38.63
2019-11-26 21:46:09.731599 Epoch [32/50], Step[5000/8080], Loss: 4.2921, Perplexity: 73.12
20

2019-11-27 06:11:48.602098 Epoch [41/50], Step[5000/8080], Loss: 4.2600, Perplexity: 70.81
2019-11-27 06:18:38.005206 Epoch [41/50], Step[6000/8080], Loss: 4.0064, Perplexity: 54.95
2019-11-27 06:25:23.578977 Epoch [41/50], Step[7000/8080], Loss: 3.5278, Perplexity: 34.05
2019-11-27 06:32:09.730196 Epoch [41/50], Step[8000/8080], Loss: 3.8940, Perplexity: 49.11
2019-11-27 06:32:43.504074 Epoch [42/50], Step[0/8080], Loss: 3.8349, Perplexity: 46.29
2019-11-27 06:39:27.187743 Epoch [42/50], Step[1000/8080], Loss: 3.6717, Perplexity: 39.32
2019-11-27 06:46:14.123452 Epoch [42/50], Step[2000/8080], Loss: 3.8392, Perplexity: 46.49
2019-11-27 06:53:01.551189 Epoch [42/50], Step[3000/8080], Loss: 3.7238, Perplexity: 41.42
2019-11-27 06:59:46.985931 Epoch [42/50], Step[4000/8080], Loss: 3.6165, Perplexity: 37.21
2019-11-27 07:06:34.277138 Epoch [42/50], Step[5000/8080], Loss: 4.2637, Perplexity: 71.07
2019-11-27 07:13:21.898472 Epoch [42/50], Step[6000/8080], Loss: 3.9841, Perplexity: 53.74
20

In [13]:
model.eval()
cnn_encoder.eval()

perplexity = .0
num_step = 0
for i in range(0, ids.size(1) - seq_length, seq_length):
    # Get mini-batch inputs and targets
    inputs = ids[:, i:i+seq_length].to(device)
    targets = ids[:, (i+1):(i+1)+seq_length].to(device)

    # Get images
    images = np.zeros((inputs.size(0), inputs.size(1), char_size, char_size, 1))
    for b, seq in enumerate(ids):
        for s, idx in enumerate(ids[b][i:i+seq_length]):
            images[b, s] = np.load(f'char_img/noto_CJK/msr2/{idx}.npy').reshape(char_size,char_size,1)
    images = torch.from_numpy(images).float().to(device) # B N H W C
    images = images.view(-1, char_size, char_size, 1) # B*N H W C
    images = images.permute(0, 3, 1, 2) # from B*N H W C to B*N C H W

    # Get encoded images
    cnn_o = cnn_encoder(images)
    cnn_o = torch.reshape(cnn_o, (inputs.size(0), inputs.size(1), -1))

    # Forward pass
    state = detach(state)
    outputs, state = model(inputs, cnn_o, state)
    loss = criterion(outputs, targets.reshape(-1))

    perplexity += np.exp(loss.item())

    num_step += 1
    
print(f"Train Perplexity: {perplexity / num_step}")

Train Perplexity: 43.96638490339131


In [14]:
test_ids = corpus.get_data('icwb2-data/testing/msr_test.utf8', batch_size)
# filter out unknown character
test_ids = test_ids.view(-1)
mask = test_ids < vocab_size
test_ids = test_ids[mask]
num_batches = test_ids.size(0) // batch_size
test_ids = test_ids[:num_batches*batch_size]
test_ids = test_ids.view(batch_size, -1)

In [15]:
perplexity = .0
num_step = 0
for i in range(0, test_ids.size(1) - seq_length, seq_length):
    # Get mini-batch inputs and targets
    inputs = test_ids[:, i:i+seq_length].to(device)
    targets = test_ids[:, (i+1):(i+1)+seq_length].to(device)

    # Get images
    images = np.zeros((inputs.size(0), inputs.size(1), char_size, char_size, 1))
    for b, seq in enumerate(test_ids):
        for s, idx in enumerate(test_ids[b][i:i+seq_length]):
            images[b, s] = np.load(f'char_img/noto_CJK/msr2/{idx}.npy').reshape(char_size,char_size,1)
    images = torch.from_numpy(images).float().to(device) # B N H W C
    images = images.view(-1, char_size, char_size, 1) # B*N H W C
    images = images.permute(0, 3, 1, 2) # from B*N H W C to B*N C H W

    # Get encoded images
    cnn_o = cnn_encoder(images)
    cnn_o = torch.reshape(cnn_o, (inputs.size(0), inputs.size(1), -1))

    # Forward pass
    state = detach(state)
    outputs, state = model(inputs, cnn_o, state)
    loss = criterion(outputs, targets.reshape(-1))
    
    perplexity += np.exp(loss.item())

    num_step += 1
    
print(f"Test Perplexity: {perplexity / num_step}")

Test Perplexity: 58.55850155913826
