In [1]:
import torch
import torchvision
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

In [2]:
import sys
sys.path.append('../tools/')
from net import HWNet

In [3]:
device = torch.device("cuda")
device

device(type='cuda')

In [4]:
def train(net, data_loader, loss_fn, optimizer, device):
    size = len(data_loader.dataset)
    net.train()
    for batch, (X, y) in enumerate(data_loader):
        X = X.to(device)
        y = y.to(device,dtype=torch.float)
        y_hat = net(X)
        loss = loss_fn(y_hat, y)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        if batch % 2 == 0:
            loss = loss.item()
            current_batch = (batch + 1) * len(X)
            print(f' loss:{loss:>.5f},[{current_batch:>5d}/{size:>5d}]', end='\r')
    print(f"Already trained an epoch, waiting for validation ...")

In [5]:
def test(net,data_loader,loss_fn,device):
    size = len(data_loader.dataset)
    num_batches = len(data_loader)
    net.eval()
    correct,test_loss = 0,0
    with torch.no_grad():
        for X,y in data_loader:
            X,y = X.to(device),y.to(device,dtype=torch.float)
            y_hat = net(X)
            test_loss += loss_fn(y_hat,y).item()
            for i,j in zip(y,y_hat):
                if i.argmax() == j.argmax():
                    correct += 1
        test_loss /= num_batches
        correct /= size
        print(f"Test error:\n accuracy: {(100*correct):0.1f}%, avg loss: {test_loss:.5f}")

In [6]:
batch_size = 256

In [7]:
from dataset import get_data_loader
trn_loader,num_labels_trn,tsn_set = get_data_loader('../data/train',batch_size,True)
val_loader,tst_loader,num_labels_tst,val_set,tst_set = get_data_loader('../data/test',batch_size,True,False)

Dict initialized successfully, there are 3755 labels in the dict.
Length of dataset is: 897725
Dict initialized successfully, there are 3755 labels in the dict.
Length of dataset is: 223716


In [8]:
net = HWNet(num_labels_trn).to(device)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(net.parameters(), lr=0.008)

In [9]:
import time

In [10]:
epochs = 22
for t in range(epochs):
    print(f"Epoch {t+1}")
    start = time.time()
    train(net,trn_loader , loss_fn, optimizer,device)
    test(net,tst_loader, loss_fn,device)
    torch.save(net.state_dict(), f'handwriting_epoch_{t+1}.pth')
    end = time.time()
    interval = (end-start)
    print(f"Time : {interval:.3f}s\n-------------------------------")
print('***End of train***')
test(net,val_loader,loss_fn,device)

Epoch 1
Already trained an epoch, waiting for validation ...
Test error:
 accuracy: 0.1%, avg loss: 8.21876
Time : 1387.973s
-------------------------------
Epoch 2
Already trained an epoch, waiting for validation ...
Test error:
 accuracy: 3.2%, avg loss: 6.88806
Time : 1472.380s
-------------------------------
Epoch 3
Already trained an epoch, waiting for validation ...
Test error:
 accuracy: 33.3%, avg loss: 3.38476
Time : 1338.211s
-------------------------------
Epoch 4
Already trained an epoch, waiting for validation ...
Test error:
 accuracy: 65.8%, avg loss: 1.60724
Time : 1196.743s
-------------------------------
Epoch 5
Already trained an epoch, waiting for validation ...
Test error:
 accuracy: 77.4%, avg loss: 1.01014
Time : 1191.171s
-------------------------------
Epoch 6
Already trained an epoch, waiting for validation ...
Test error:
 accuracy: 82.0%, avg loss: 0.77317
Time : 37239.070s
-------------------------------
Epoch 7
Already trained an epoch, waiting for validat

KeyboardInterrupt: 

In [11]:
vocab = tsn_set.vocab


In [12]:
from torchvision import transforms

to_img = transforms.ToPILImage()
transform = transforms.Compose([
            transforms.Resize((48,48)),
            transforms.ToTensor()
        ])

In [20]:
import PIL

img = PIL.Image.open('../data_for_test/test/阿//11.jpg')
feature = transform(img)

In [21]:
feature=feature.reshape(1,1,48,48).to(device)

In [22]:
vocab.label2char(net(feature).argmax(1).item())

'沿'