In [1]:
%matplotlib inline

In [2]:
from datetime import datetime

In [3]:
#renderer
from PIL import ImageFont
import numpy as np
import cv2
import torch.nn as nn
from torch.nn.utils import clip_grad_norm_

char_size = 24
# char render
def render(text, font=None):
    if font is None:
        font = ImageFont.truetype("/usr/share/fonts/opentype/noto/NotoSansCJK-Regular.ttc", char_size)
    mask = font.getmask(text)
    size = mask.size[::-1]
    a = np.asarray(mask).reshape(size) / 255
    res = cv2.resize(a, dsize=(char_size, char_size), interpolation=cv2.INTER_CUBIC)
    return res

In [4]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

In [5]:
# https://github.com/yunjey/pytorch-tutorial/blob/master/tutorials/02-intermediate/language_model/data_utils.py
import torch
import re

class Dictionary(object):
    def __init__(self, max_size=None):
        self.word2idx = {}
        self.idx2word = {}
        self.idx = 1
        self.word2idx['⸘'] = 0 # as unk
        self.idx2word[0] = '⸘'
        self.max_size = max_size + 1
    
    def add_word(self, word):
        if not word in self.word2idx and self.idx < self.max_size:
            self.word2idx[word] = self.idx
            self.idx2word[self.idx] = word
            self.idx += 1
    
    def __len__(self):
        return len(self.word2idx)


class Corpus(object):
    def __init__(self, max_size=None):
        self.dictionary = Dictionary(max_size=max_size)

    def get_data(self, path, batch_size=20):
        # Add words to the dictionary
#         with open(path, 'r') as f:
#             tokens = 0
#             for line in f:
#                 words = line.split() + ['<eos>']
#                 tokens += len(words)
#                 for word in words: 
#                     self.dictionary.add_word(word)  

        # split words to char and add to dictionary
        with open(path, 'r') as f:
            tokens = 0
            for line in f:
                line = ' '.join(line) # split words to char
                line = re.sub(r'[" "]+', ' ', line) # remove continous space
                chars = line.split() + ['¿'] # ¿ as <eos>
                tokens += len(chars)
                for char in chars:
                    self.dictionary.add_word(char)
        
        # Tokenize the file content
        ids = torch.LongTensor(tokens)
        token = 0
        with open(path, 'r') as f:
            for line in f:
                line = ' '.join(line) # split words to char
                line = re.sub(r'[" "]+', ' ', line) # remove continous space
                chars = line.split() + ['¿'] # ¿ as <eos>
                for char in chars:
                    if char in self.dictionary.word2idx:
                        ids[token] = self.dictionary.word2idx[char]
                        token += 1
                    else:
                        ids[token] = self.dictionary.word2idx['⸘']
                        token += 1
        num_batches = ids.size(0) // batch_size
        ids = ids[:num_batches*batch_size]
        return ids.view(batch_size, -1)

In [6]:
# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cuda


In [7]:
# RNN based language model
class RNNLM(nn.Module):
    def __init__(self, vocab_size, embed_size, hidden_size, num_layers):
        super(RNNLM, self).__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.gru = nn.GRU(embed_size, hidden_size, num_layers, batch_first=True)
        self.linear = nn.Linear(hidden_size, vocab_size)
        
    def forward(self, char_id, h):
        # Embed word ids to vectors
        x = self.embed(char_id)
        
        # Forward propagate LSTM
        out, h = self.gru(x, h)
        
        # Reshape output to (batch_size*sequence_length, hidden_size)
        out = out.reshape(out.size(0)*out.size(1), out.size(2))
        
        # Decode hidden states of all time steps
        out = self.linear(out)
        return out, h

In [8]:
# Hyper-parameters
embed_size = 300
hidden_size = 128
num_layers = 1
num_epochs = 50
batch_size = 16
seq_length = 32
learning_rate = 1e-3

# Load dataset
corpus = Corpus(max_size=4000)
ids = corpus.get_data('icwb2-data/training/msr_training.utf8', batch_size)
vocab_size = len(corpus.dictionary)
num_batches = ids.size(1) // seq_length

In [9]:
model = RNNLM(vocab_size, embed_size, hidden_size, num_layers).to(device)
model.train()
params = list(model.parameters())

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params, lr=learning_rate)

# Truncated backpropagation
def detach(state):
    return state.detach()

In [10]:
# Train the model
for epoch in range(num_epochs):
    # Set initial hidden and cell states
    state = torch.zeros(num_layers, batch_size, hidden_size).to(device)
    
    for i in range(0, ids.size(1) - seq_length, seq_length):
        # Get mini-batch inputs and targets
        inputs = ids[:, i:i+seq_length].to(device)
        targets = ids[:, (i+1):(i+1)+seq_length].to(device)
               
        # Forward pass
        state = detach(state)
        outputs, state = model(inputs, state)
        loss = criterion(outputs, targets.reshape(-1))
        
        # Backward and optimize
        model.zero_grad()
        loss.backward()
        clip_grad_norm_(params, 0.5)
        optimizer.step()

        step = (i+1) // seq_length
        if step % 1000 == 0:
            print ('{} Epoch [{}/{}], Step[{}/{}], Loss: {:.4f}, Perplexity: {:5.2f}'
                   .format(datetime.now(), epoch+1, num_epochs, step, num_batches, loss.item(), np.exp(loss.item())))

2019-11-25 16:34:42.927360 Epoch [1/50], Step[0/8080], Loss: 8.3332, Perplexity: 4159.65
2019-11-25 16:34:46.226363 Epoch [1/50], Step[1000/8080], Loss: 5.0667, Perplexity: 158.65
2019-11-25 16:34:49.540944 Epoch [1/50], Step[2000/8080], Loss: 4.6069, Perplexity: 100.17
2019-11-25 16:34:52.933879 Epoch [1/50], Step[3000/8080], Loss: 4.3897, Perplexity: 80.61
2019-11-25 16:34:56.535907 Epoch [1/50], Step[4000/8080], Loss: 4.3132, Perplexity: 74.68
2019-11-25 16:35:00.530098 Epoch [1/50], Step[5000/8080], Loss: 4.7531, Perplexity: 115.94
2019-11-25 16:35:04.527642 Epoch [1/50], Step[6000/8080], Loss: 4.6090, Perplexity: 100.39
2019-11-25 16:35:08.178163 Epoch [1/50], Step[7000/8080], Loss: 4.1619, Perplexity: 64.19
2019-11-25 16:35:11.882611 Epoch [1/50], Step[8000/8080], Loss: 4.2702, Perplexity: 71.54
2019-11-25 16:35:12.196362 Epoch [2/50], Step[0/8080], Loss: 4.3602, Perplexity: 78.27
2019-11-25 16:35:15.548286 Epoch [2/50], Step[1000/8080], Loss: 4.2432, Perplexity: 69.63
2019-11-25

2019-11-25 16:39:33.553033 Epoch [11/50], Step[2000/8080], Loss: 3.9155, Perplexity: 50.17
2019-11-25 16:39:37.046685 Epoch [11/50], Step[3000/8080], Loss: 3.7420, Perplexity: 42.18
2019-11-25 16:39:40.605762 Epoch [11/50], Step[4000/8080], Loss: 3.7413, Perplexity: 42.15
2019-11-25 16:39:44.165428 Epoch [11/50], Step[5000/8080], Loss: 4.3696, Perplexity: 79.01
2019-11-25 16:39:47.717851 Epoch [11/50], Step[6000/8080], Loss: 3.9294, Perplexity: 50.88
2019-11-25 16:39:51.289860 Epoch [11/50], Step[7000/8080], Loss: 3.7481, Perplexity: 42.44
2019-11-25 16:39:54.925217 Epoch [11/50], Step[8000/8080], Loss: 3.9216, Perplexity: 50.48
2019-11-25 16:39:55.197532 Epoch [12/50], Step[0/8080], Loss: 3.9281, Perplexity: 50.81
2019-11-25 16:39:58.678749 Epoch [12/50], Step[1000/8080], Loss: 3.8325, Perplexity: 46.18
2019-11-25 16:40:02.152361 Epoch [12/50], Step[2000/8080], Loss: 3.8973, Perplexity: 49.27
2019-11-25 16:40:05.674003 Epoch [12/50], Step[3000/8080], Loss: 3.7405, Perplexity: 42.12
20

2019-11-25 16:44:35.028521 Epoch [21/50], Step[3000/8080], Loss: 3.6895, Perplexity: 40.03
2019-11-25 16:44:38.828042 Epoch [21/50], Step[4000/8080], Loss: 3.6738, Perplexity: 39.40
2019-11-25 16:44:42.505853 Epoch [21/50], Step[5000/8080], Loss: 4.3070, Perplexity: 74.22
2019-11-25 16:44:46.540201 Epoch [21/50], Step[6000/8080], Loss: 3.8736, Perplexity: 48.11
2019-11-25 16:44:50.602572 Epoch [21/50], Step[7000/8080], Loss: 3.6539, Perplexity: 38.63
2019-11-25 16:44:54.804658 Epoch [21/50], Step[8000/8080], Loss: 3.8664, Perplexity: 47.77
2019-11-25 16:44:55.122196 Epoch [22/50], Step[0/8080], Loss: 3.8729, Perplexity: 48.08
2019-11-25 16:44:59.216249 Epoch [22/50], Step[1000/8080], Loss: 3.7090, Perplexity: 40.81
2019-11-25 16:45:03.125923 Epoch [22/50], Step[2000/8080], Loss: 3.9050, Perplexity: 49.65
2019-11-25 16:45:06.893025 Epoch [22/50], Step[3000/8080], Loss: 3.6865, Perplexity: 39.90
2019-11-25 16:45:10.855960 Epoch [22/50], Step[4000/8080], Loss: 3.6629, Perplexity: 38.97
20

2019-11-25 16:49:56.827747 Epoch [31/50], Step[4000/8080], Loss: 3.6275, Perplexity: 37.62
2019-11-25 16:50:00.617289 Epoch [31/50], Step[5000/8080], Loss: 4.2477, Perplexity: 69.94
2019-11-25 16:50:04.628475 Epoch [31/50], Step[6000/8080], Loss: 3.8688, Perplexity: 47.88
2019-11-25 16:50:08.537782 Epoch [31/50], Step[7000/8080], Loss: 3.6273, Perplexity: 37.61
2019-11-25 16:50:12.925974 Epoch [31/50], Step[8000/8080], Loss: 3.8978, Perplexity: 49.30
2019-11-25 16:50:13.267956 Epoch [32/50], Step[0/8080], Loss: 3.8473, Perplexity: 46.86
2019-11-25 16:50:17.467991 Epoch [32/50], Step[1000/8080], Loss: 3.6638, Perplexity: 39.01
2019-11-25 16:50:21.800254 Epoch [32/50], Step[2000/8080], Loss: 3.9033, Perplexity: 49.57
2019-11-25 16:50:25.536254 Epoch [32/50], Step[3000/8080], Loss: 3.6492, Perplexity: 38.45
2019-11-25 16:50:29.509969 Epoch [32/50], Step[4000/8080], Loss: 3.6274, Perplexity: 37.61
2019-11-25 16:50:33.284363 Epoch [32/50], Step[5000/8080], Loss: 4.2537, Perplexity: 70.36
20

2019-11-25 16:55:24.200865 Epoch [41/50], Step[5000/8080], Loss: 4.2617, Perplexity: 70.93
2019-11-25 16:55:28.077689 Epoch [41/50], Step[6000/8080], Loss: 3.8557, Perplexity: 47.26
2019-11-25 16:55:31.821122 Epoch [41/50], Step[7000/8080], Loss: 3.6191, Perplexity: 37.30
2019-11-25 16:55:35.792220 Epoch [41/50], Step[8000/8080], Loss: 3.8641, Perplexity: 47.66
2019-11-25 16:55:36.111000 Epoch [42/50], Step[0/8080], Loss: 3.8405, Perplexity: 46.55
2019-11-25 16:55:39.999243 Epoch [42/50], Step[1000/8080], Loss: 3.6239, Perplexity: 37.48
2019-11-25 16:55:43.843061 Epoch [42/50], Step[2000/8080], Loss: 3.8773, Perplexity: 48.29
2019-11-25 16:55:47.792781 Epoch [42/50], Step[3000/8080], Loss: 3.6432, Perplexity: 38.21
2019-11-25 16:55:51.524966 Epoch [42/50], Step[4000/8080], Loss: 3.6464, Perplexity: 38.34
2019-11-25 16:55:55.259601 Epoch [42/50], Step[5000/8080], Loss: 4.2188, Perplexity: 67.95
2019-11-25 16:55:59.103004 Epoch [42/50], Step[6000/8080], Loss: 3.8921, Perplexity: 49.01
20

In [11]:
model.eval()

perplexity = .0
num_step = 0
for i in range(0, ids.size(1) - seq_length, seq_length):
    # Get mini-batch inputs and targets
    inputs = ids[:, i:i+seq_length].to(device)
    targets = ids[:, (i+1):(i+1)+seq_length].to(device)

    # Forward pass
    state = detach(state)
    outputs, state = model(inputs, state)
    loss = criterion(outputs, targets.reshape(-1))
    
    perplexity += np.exp(loss.item())

    num_step += 1
    
print(f"Train Perplexity: {perplexity / num_step}")

Train Perplexity: 43.86830965343008


In [12]:
test_ids = corpus.get_data('icwb2-data/testing/msr_test.utf8', batch_size)
# filter out unknown character
test_ids = test_ids.view(-1)
mask = test_ids < vocab_size
test_ids = test_ids[mask]
num_batches = test_ids.size(0) // batch_size
test_ids = test_ids[:num_batches*batch_size]
test_ids = test_ids.view(batch_size, -1)

In [13]:
perplexity = .0
num_step = 0
for i in range(0, test_ids.size(1) - seq_length, seq_length):
    # Get mini-batch inputs and targets
    inputs = test_ids[:, i:i+seq_length].to(device)
    targets = test_ids[:, (i+1):(i+1)+seq_length].to(device)

    # Forward pass
    state = detach(state)
    outputs, state = model(inputs, state)
    loss = criterion(outputs, targets.reshape(-1))

    perplexity += np.exp(loss.item())

    num_step += 1
    
print(f"Test Perplexity: {perplexity / num_step}")

Test Perplexity: 58.28757771350248
