In [1]:
import json
import os

import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
from sklearn.model_selection import train_test_split

import torch
from torch import nn, optim
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader

import torchvision
from torchvision import transforms
from torchvision.utils import make_grid

# Probe CUDA once up front; later cells read `use_gpu`.
use_gpu = torch.cuda.is_available()

# Report the environment so the saved outputs record which versions were used.
print('PyTorch version:', torch.__version__)
print('torchvision version:', torchvision.__version__)
print('Is GPU available:', use_gpu)

PyTorch version: 1.0.0
torchvision version: 0.2.1
Is GPU available: True


In [15]:
# General run settings shared by the cells below.

# Mini-batch size handed to the DataLoaders.
batchsize = 64

# Seed used for PyTorch and for the train/validation split.
seed = 1

# Run on the GPU when one was detected, otherwise on the CPU.
device = torch.device('cuda' if use_gpu else 'cpu')

# Seed the PyTorch RNGs. Note: cuDNN kernels may still be non-deterministic.
torch.manual_seed(seed)
if use_gpu:
    torch.cuda.manual_seed(seed)

In [102]:
# Directory settings. Every directory string keeps its trailing '/' because
# other cells build file paths by plain string concatenation.
root_dir = '../../data/'
# Directory holding the training images.
image_dir = root_dir + 'kkanji2_expansion_can_get_radical/'
# Directory where logs and trained weights are written.
save_dir = root_dir + 'kkanji2_result/'
# makedirs(exist_ok=True) also creates missing parents and does not fail when
# the directory already exists (no check-then-create race).
os.makedirs(save_dir, exist_ok=True)

# Load the mapping from character code to its list of radicals.
# JSON is decoded as UTF-8 explicitly so the result does not depend on the
# platform's default locale encoding.
with open(root_dir + 'kanjivg_radical/utf16_to_radical.json', encoding='utf-8') as f:
    utf16_to_radical = json.load(f)

In [17]:
# Build the radical -> column index table used for the multi-hot labels.

# Every distinct radical that appears in any character's radical list.
radical_set = {
    radical
    for radicals in utf16_to_radical.values()
    for radical in radicals
}

# Sorting makes the index assignment stable from run to run.
radical_list = list(radical_set)
radical_list.sort()

radical_dict = {radical: position for position, radical in enumerate(radical_list)}

n_radical = len(radical_dict)
print('the number of radical:', n_radical)

the number of radical: 1300


In [18]:
# make dataset class for image loading and label setting
class KanjiRadicalDataset(Dataset):
    """Dataset yielding (image, multi-hot radical label) pairs.

    Args:
        image_dir: directory that contains the image files.
        image_name_list: file names inside ``image_dir``; stored sorted.
        utf16_to_radical: mapping from the first four characters of a file
            name to the list of radicals for that character.
        radical_dict: mapping from radical to its index in the label vector.
        transform: optional callable applied to the opened PIL image.
    """

    def __init__(self, image_dir, image_name_list, utf16_to_radical, radical_dict, transform=None):
        self.image_dir = image_dir
        self.image_name_list = sorted(image_name_list)

        self.utf16_to_radical = utf16_to_radical
        self.radical_dict = radical_dict

        self.n_radical = len(radical_dict)
        self.transform = transform

    def __len__(self):
        """Return the number of images in the dataset."""
        return len(self.image_name_list)

    def _make_label(self, image_name):
        """Return the multi-hot label tensor (length ``n_radical``) for a file name.

        Raises:
            KeyError: if the file-name prefix or one of its radicals is unknown.
        """
        label = torch.zeros(self.n_radical)
        # The first four characters of the file name are the lookup key.
        utf16_code = image_name[:4]
        for radical in self.utf16_to_radical[utf16_code]:
            # Use the mapping given to this instance, not a module-level global.
            label[self.radical_dict[radical]] = 1
        return label

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the ``idx``-th file name (sorted order)."""
        image_name = self.image_name_list[idx]
        # os.path.join works whether or not image_dir ends with a separator.
        image = Image.open(os.path.join(self.image_dir, image_name))
        if self.transform:
            image = self.transform(image)

        return image, self._make_label(image_name)

In [19]:
# Make the datasets and the train/validation split.
# os.listdir returns names in arbitrary, filesystem-dependent order, so the
# list is sorted first; otherwise the same random_state could still yield a
# different split on another machine.
image_name_list = sorted(os.listdir(image_dir))
train_name_list, validation_name_list = train_test_split(image_name_list, test_size=0.2, random_state=seed)

# Training images get a random 64x64 crop after 8 px padding as augmentation;
# validation images are only converted to tensors.
tf_train = transforms.Compose([transforms.RandomCrop(64, padding=8), transforms.ToTensor()])
tf_validation = transforms.ToTensor()

train_data = KanjiRadicalDataset(image_dir, train_name_list, utf16_to_radical, radical_dict, transform=tf_train)
validation_data = KanjiRadicalDataset(image_dir, validation_name_list, utf16_to_radical, radical_dict, transform=tf_validation)

print('The number of training data:', len(train_data))
print('The number of validation data:', len(validation_data))

# Make the DataLoaders: shuffle only the training set.
train_loader = DataLoader(train_data, batch_size=batchsize, shuffle=True)
validation_loader = DataLoader(validation_data, batch_size=batchsize, shuffle=False)

The number of training data: 94127
The number of validation data: 23532


In [20]:
# a = set()
# for vs in utf16_to_radical.values():
#     for v in vs:
#        a.add(v)
# print(a)

In [21]:
# Inspect the distribution of the labels in the data
# radical_count = torch.zeros(n_radical)
# for index, (image, label) in enumerate(validation_loader):
#    radical_count += torch.sum(label, dim=0)
#    print('\rprogress[%d/%d]' % (index+1, len(validation_loader)), end='')

In [22]:
# radical_count_non_zero = radical_count[radical_count > 0]
# plt.xlim([0,500])
# plt.ylim([0,100])
# plt.hist(radical_count_non_zero, bins = 100)
# plt.show()

In [23]:
# plt.bar(range(len(radical_count_non_zero)), radical_count_non_zero)
# plt.show()

In [24]:
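# PreActResNet-18, adapted from https://github.com/kuangliu/pytorch-cifar/blob/master/models/preact_resnet.py
# The very first input channel count is changed to 1, and the number of filters is increased throughout.
# Proper weight initialization would probably be better, but it is left as is for now.
# PreActResNet-18 was chosen to match the baseline of the paper that proposed KMNIST; however, that baseline
# was used for 32x32 KMNIST and K49, so applying it to 64x64 Kkanji may be questionable.
# The dataset paper uses manifold mixup as data augmentation, but that is not done here yet.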

class PreActBlock(nn.Module):
    """Two-convolution residual block with pre-activation (BN -> ReLU -> conv).

    Args:
        in_planes: number of input channels.
        planes: number of output channels of both 3x3 convolutions.
        stride: stride of the first convolution (and of the shortcut projection).
    """
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        out_planes = self.expansion * planes

        self.bn1 = nn.BatchNorm2d(in_planes)
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)

        # A 1x1 projection is only created when the residual branch changes
        # the tensor shape; otherwise no `shortcut` attribute exists at all.
        shape_preserved = (stride == 1 and in_planes == out_planes)
        if not shape_preserved:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
            )

    def forward(self, x):
        """Return the block output: residual branch plus (projected) shortcut."""
        activated = F.relu(self.bn1(x))

        # The projection, when present, is applied to the pre-activated input;
        # the identity shortcut uses the raw input.
        if hasattr(self, 'shortcut'):
            residual = self.shortcut(activated)
        else:
            residual = x

        hidden = self.conv1(activated)
        hidden = F.relu(self.bn2(hidden))
        result = self.conv2(hidden)
        result += residual
        return result

    
class PreActResNet(nn.Module):
    """Pre-activation ResNet: a stem convolution, four stages and a linear head.

    Args:
        block: block class; called as ``block(in_planes, planes, stride)`` and
            must expose an ``expansion`` class attribute.
        num_blocks: sequence of four ints, the number of blocks per stage.
        num_classes: size of the output vector.
        base_planes: channel width of the stem and first stage; later stages
            use 2x, 4x and 8x this value. Defaults to 256.
        in_channels: number of channels of the input images. Defaults to 1.
    """

    def __init__(self, block, num_blocks, num_classes=10, base_planes=256, in_channels=1):
        super(PreActResNet, self).__init__()
        planes = base_planes
        # Running input width for the next block; updated by _make_layer.
        self.in_planes = planes

        self.conv1 = nn.Conv2d(in_channels, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.layer1 = self._make_layer(block,   planes, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 2*planes, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 4*planes, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 8*planes, num_blocks[3], stride=2)
        self.linear = nn.Linear(8*planes*block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Build one stage: the first block uses ``stride``, the rest use stride 1."""
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return the raw (pre-sigmoid/softmax) scores for a batch of images."""
        out = self.conv1(x)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        # Global average pooling to 1x1. For 64x64 inputs the final map is 8x8,
        # so this equals the former fixed 8x8 average pool, but it no longer
        # ties the network to a single input resolution.
        out = F.adaptive_avg_pool2d(out, 1)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out

# TODO: there is an information bottleneck around avg_pool; improve it with reference to e.g. the ImageNet PreActResNet.
# Or is it fine as it is?
def PreActResNet18(num_classes):
    """Build a PreActResNet from PreActBlock with two blocks in each of the four stages."""
    blocks_per_stage = [2, 2, 2, 2]
    return PreActResNet(PreActBlock, blocks_per_stage, num_classes)

In [25]:
# Build the network with one output per radical and move it to the target device.
net = PreActResNet18(n_radical).to(device)

# Independent sigmoid + binary cross-entropy per output (multi-label loss).
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.SGD(net.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)

# Count only the parameters that will receive gradients.
trainable_param_sizes = [p.numel() for p in net.parameters() if p.requires_grad]
num_trainable_params = sum(trainable_param_sizes)

print('The number of trainable parameters:', num_trainable_params)
print('\nModel:\n', net)
print('\nLoss function:\n', criterion)
print('\nOptimizer:\n', optimizer)

The number of trainable parameters: 181213204

Model:
 PreActResNet(
  (conv1): Conv2d(1, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (layer1): Sequential(
    (0): PreActBlock(
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    )
    (1): PreActBlock(
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    )
  )
  

In [72]:
def calculate_accuracy_per_kanji_and_radical(outputs, labels):
    """Compute exact-match and element-wise accuracy for multi-label predictions.

    Args:
        outputs: raw scores (logits); a sigmoid and a 0.5 threshold are applied.
        labels: target tensor with the same shape as ``outputs``.

    Returns:
        ``[accuracy_per_kanji, accuracy_per_radical]`` as 0-dim tensors:
        the fraction of rows in which every entry is predicted correctly, and
        the fraction of individual entries predicted correctly.
    """
    with torch.no_grad():
        # Bring predictions and targets to the same integer dtype so the
        # comparison does not depend on the dtype returned by `>`.
        predictions = (torch.sigmoid(outputs) > 0.5).type(torch.uint8)
        targets = labels.type(torch.uint8)
        is_correct = (predictions == targets)
        # A row counts as correct only when all of its entries match.
        accuracy_per_kanji = is_correct.all(dim=1).float().mean()
        accuracy_per_radical = is_correct.float().mean()
    return [accuracy_per_kanji, accuracy_per_radical]

In [74]:
def train(train_loader):
    """Run one training epoch over ``train_loader`` and update ``net`` in place.

    Uses the module-level ``net``, ``criterion``, ``optimizer`` and ``device``.

    Returns:
        ``(train_loss, accuracy_per_kanji, accuracy_per_radical)`` as Python
        floats, each the mean of the per-batch values.
    """
    net.train()
    running_loss = 0.0
    running_accuracy_per_kanji = 0.0
    running_accuracy_per_radical = 0.0

    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)

        outputs = net(inputs)
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        accuracies = calculate_accuracy_per_kanji_and_radical(outputs.detach(), labels)
        # Accumulate plain floats rather than 0-dim (possibly CUDA) tensors so
        # the returned metrics can be formatted and converted to NumPy arrays.
        running_accuracy_per_kanji += float(accuracies[0])
        running_accuracy_per_radical += float(accuracies[1])

    n_batches = len(train_loader)
    train_loss = running_loss / n_batches
    accuracy_per_kanji = running_accuracy_per_kanji / n_batches
    accuracy_per_radical = running_accuracy_per_radical / n_batches

    return train_loss, accuracy_per_kanji, accuracy_per_radical

In [76]:
def validation(validation_loader):
    """Evaluate ``net`` on ``validation_loader`` without updating its weights.

    Uses the module-level ``net``, ``criterion`` and ``device``.

    Returns:
        ``(validation_loss, accuracy_per_kanji, accuracy_per_radical)`` as
        Python floats, each the mean of the per-batch values.
    """
    net.eval()
    running_loss = 0.0
    running_accuracy_per_kanji = 0.0
    running_accuracy_per_radical = 0.0

    with torch.no_grad():
        for inputs, labels in validation_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            outputs = net(inputs)
            loss = criterion(outputs, labels)

            running_loss += loss.item()
            accuracies = calculate_accuracy_per_kanji_and_radical(outputs.detach(), labels)
            # Accumulate plain floats rather than 0-dim (possibly CUDA) tensors
            # so the returned metrics can be formatted and converted to NumPy.
            running_accuracy_per_kanji += float(accuracies[0])
            running_accuracy_per_radical += float(accuracies[1])

    n_batches = len(validation_loader)
    validation_loss = running_loss / n_batches
    accuracy_per_kanji = running_accuracy_per_kanji / n_batches
    accuracy_per_radical = running_accuracy_per_radical / n_batches

    return validation_loss, accuracy_per_kanji, accuracy_per_radical

In [None]:
# Per-epoch history of the training metrics.
train_loss_list = []
train_accuracy_per_kanji_list = []
train_accuracy_per_radical_list = []

# Per-epoch history of the validation metrics.
validation_loss_list = []
validation_accuracy_per_kanji_list = []
validation_accuracy_per_radical_list = []

n_epochs = 2

for epoch in range(n_epochs):
    train_loss, train_accuracy_per_kanji, train_accuracy_per_radical = train(train_loader)
    validation_loss, validation_accuracy_per_kanji, validation_accuracy_per_radical = validation(validation_loader)

    # float() turns 0-dim tensors (including CUDA ones) into plain numbers so
    # the histories can be converted with np.array below.
    train_loss_list.append(float(train_loss))
    train_accuracy_per_kanji_list.append(float(train_accuracy_per_kanji))
    train_accuracy_per_radical_list.append(float(train_accuracy_per_radical))

    validation_loss_list.append(float(validation_loss))
    validation_accuracy_per_kanji_list.append(float(validation_accuracy_per_kanji))
    validation_accuracy_per_radical_list.append(float(validation_accuracy_per_radical))

    # Report the epoch 1-based so the last line reads [n_epochs/n_epochs].
    print('epoch[%3d/%3d] train[loss:%1.4f accuracy_per_kanji:%1.4f accuracy_per_radical:%1.4f]'
          % (epoch + 1, n_epochs, train_loss, train_accuracy_per_kanji, train_accuracy_per_radical),
          '\n          validation[loss:%1.4f accuracy_per_kanji:%1.4f accuracy_per_radical:%1.4f]'
          % (validation_loss, validation_accuracy_per_kanji, validation_accuracy_per_radical))

# Save every history as <save_dir>/<name>.npy.
history = {
    'train_loss_list': train_loss_list,
    'train_accuracy_per_kanji_list': train_accuracy_per_kanji_list,
    'train_accuracy_per_radical_list': train_accuracy_per_radical_list,
    'validation_loss_list': validation_loss_list,
    'validation_accuracy_per_kanji_list': validation_accuracy_per_kanji_list,
    'validation_accuracy_per_radical_list': validation_accuracy_per_radical_list,
}
for history_name, history_values in history.items():
    np.save(os.path.join(save_dir, history_name + '.npy'), np.array(history_values))

# Save only the learned parameters (state_dict), not the whole module object.
torch.save(net.state_dict(), os.path.join(save_dir, 'weights.pth'))