In [1]:
import torch
import torchvision
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

In [2]:
import sys
sys.path.append('../tools/')
from net import HWNet

In [3]:
device = torch.device("cuda")



In [4]:
def train(net,data_loader,loss_fn,optimizer,device):
    size = len(data_loader.dataset)
    net.train()
    for batch,(X,y) in enumerate(data_loader):
        X = X.to(device)
        y = y.float().to(device)
        y_hat = net(X)
        loss = loss_fn(y_hat,y)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        if batch%2==0:
            loss = loss.item()
            current_batch = (batch+1)*len(X)
            print(f' loss:{loss:>.5f},[{current_batch:>5d}/{size:>5d}]',end='\r')
    print(f"Already trained an epoch,waiting for validate ...")

In [5]:
def test(net,data_loader,loss_fn,device):
    size = len(data_loader.dataset)
    num_batches = len(data_loader)
    net.eval()
    correct,test_loss = 0,0
    with torch.no_grad():
        for X,y in data_loader:
            X,y = X.to(device),y.float().to(device)
            y_hat = net(X)
            test_loss += loss_fn(y_hat,y).item()
            for i,j in zip(y,y_hat):
                if i.argmax() == j.argmax():
                    correct += 1
        test_loss /= num_batches
        correct /= size
        print(f"Test error:\n accuracy: {(100*correct):0.1f}%, avg loss: {test_loss:.5f}")

In [6]:
batch_size = 16

In [7]:
from dataset import get_data_loader
trn_loader,num_labels_trn,tsn_set = get_data_loader('../data_for_test/train',batch_size,True)
val_loader,tst_loader,num_labels_tst,val_set,tst_set = get_data_loader('../data_for_test/test',batch_size,True,False)

Dict initialized successfully, there are 8 labels in the dict.
Length of dataset is: 1918
Dict initialized successfully, there are 8 labels in the dict.
Length of dataset is: 477


In [8]:
len(val_set) + len(tst_set)

477

In [9]:
net = HWNet(num_labels_trn).to(device)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(net.parameters(), lr=2e-3)

In [10]:
import time

In [11]:
epochs = 12
for t in range(epochs):
    print(f"Epoch {t+1}")
    start = time.time()
    train(net,trn_loader , loss_fn, optimizer,device)
    test(net,tst_loader, loss_fn,device)
    torch.save(net.state_dict(), 'handwriting.params')
    end = time.time()
    interval = (end-start)
    print(f"Time : {interval:.3f}s\n-------------------------------")
print('***End of train***')
test(net,val_loader,loss_fn,device)

Epoch 1
Already trained an epoch,waiting for validate ...
Test error:
 accuracy: 23.4%, avg loss: 2.01960
Time : 7.052s
-------------------------------
Epoch 2
Already trained an epoch,waiting for validate ...
Test error:
 accuracy: 33.9%, avg loss: 1.82290
Time : 4.389s
-------------------------------
Epoch 3
Already trained an epoch,waiting for validate ...
Test error:
 accuracy: 59.4%, avg loss: 1.48348
Time : 4.384s
-------------------------------
Epoch 4
Already trained an epoch,waiting for validate ...
Test error:
 accuracy: 64.4%, avg loss: 1.20101
Time : 4.360s
-------------------------------
Epoch 5
Already trained an epoch,waiting for validate ...
Test error:
 accuracy: 80.3%, avg loss: 0.82195
Time : 4.433s
-------------------------------
Epoch 6
Already trained an epoch,waiting for validate ...
Test error:
 accuracy: 82.4%, avg loss: 0.64165
Time : 4.529s
-------------------------------
Epoch 7
Already trained an epoch,waiting for validate ...
Test error:
 accuracy: 83.3%, 

In [12]:
vocab = tsn_set.vocab

In [13]:
from torchvision import transforms

to_img = transforms.ToPILImage()

In [14]:
transform = transforms.Compose([
            transforms.Resize((48,48)),
            transforms.ToTensor()
        ])

In [15]:
import PIL

img = PIL.Image.open('../data_for_test/阿/101.jpg')
feature = transform(img)

In [16]:
feature=feature.reshape(1,1,48,48).to(device)

In [17]:
feature.shape

torch.Size([1, 1, 48, 48])

In [18]:
net(feature).argmax(1).item()

7

In [19]:
vocab.label2char(net(feature).argmax(1).item())

'阿'

In [20]:
torch.save(net.state_dict(),'model_parameters.pth')

In [21]:
new_net = HWNet(8)

In [22]:
new_net.load_state_dict(torch.load('handwriting.params'))

<All keys matched successfully>

In [23]:
new_net.eval()

HWNet(
  (conv1x1): Conv2d(1, 3, kernel_size=(1, 1), stride=(1, 1), bias=False)
  (res101): Sequential(
    (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (4): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inpl

In [24]:
img = PIL.Image.open('../data_for_test/train/阿//101.jpg')
feature = transform(img)

In [25]:
feature=feature.reshape(1,1,48,48)

In [26]:
new_net(feature).argmax(1).item()

7