In [1]:
import torch
import torchvision
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

In [2]:
import sys
sys.path.append('../tools/')
from net import HWNet

In [3]:
# Use the GPU when one is available, otherwise fall back to CPU so the notebook
# still runs end-to-end on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")



In [4]:
def train(net,data_loader,loss_fn,optimizer,device,log_interval=10):
    """Run one training epoch of ``net`` over ``data_loader``.

    Args:
        net: model to train; it is switched to training mode.
        data_loader: iterable of ``(X, y)`` batches; ``len(data_loader.dataset)``
            is used as the total in the progress display.
        loss_fn: callable ``loss_fn(y_hat, y)`` returning a scalar tensor.
        optimizer: optimizer over ``net``'s parameters.
        device: device each batch is moved to (should match ``net``).
        log_interval: print a progress line every ``log_interval`` batches.

    Targets are cast to float before the loss. NOTE(review): in this notebook
    they appear to be one-hot / probability vectors (``test`` compares
    ``y.argmax``), so ``loss_fn`` receives float targets -- confirm.
    """
    size = len(data_loader.dataset)
    net.train()
    seen = 0  # samples processed so far; exact even when a batch is partial
    for batch, (X, y) in enumerate(data_loader):
        X = X.to(device)
        y = y.float().to(device)
        # Clear gradients before the backward pass so any gradients left over
        # from outside this loop are never applied by optimizer.step().
        optimizer.zero_grad()
        y_hat = net(X)
        loss = loss_fn(y_hat, y)
        loss.backward()
        optimizer.step()
        seen += len(X)
        if batch % log_interval == 0:
            print(f' loss:{loss.item():>.5f},[{seen:>5d}/{size:>5d}]', end='\r')
    print(f"Already trained a epoch,wait for test ...")

In [5]:
def test(net,data_loader,loss_fn,device):
    size = len(data_loader.dataset)
    num_batches = len(data_loader)
    net.eval()
    correct,test_loss = 0,0
    with torch.no_grad():
        for X,y in data_loader:
            X,y = X.to(device),y.float().to(device)
            y_hat = net(X)
            test_loss += loss_fn(y_hat,y).item()
            for i,j in zip(y,y_hat):
                if i.argmax() == j.argmax():
                    correct += 1
        test_loss /= num_batches
        correct /= size
        print(f"Test error:\n accuracy: {(100*correct):0.1f}%, avg loss: {test_loss:.5f}")

In [6]:
batch_size: int = 32  # samples per mini-batch, passed to every get_data_loader call below

In [7]:
from dataset import get_data_loader
# Training split: returns a loader, the number of labels, and the dataset.
trn_loader,num_labels_trn,tsn_set = get_data_loader('../data_for_test/train',batch_size,True)
# Test directory: unpacked into separate validation and test loaders/datasets.
# NOTE(review): the meaning of the positional boolean flags is defined in
# dataset.get_data_loader -- confirm there.
val_loader,tst_loader,num_labels_tst,val_set,tst_set = get_data_loader('../data_for_test/test',batch_size,True,False)
# The model is sized from num_labels_trn, so the evaluation split must share the
# same label count. (Equal counts do not by themselves guarantee an identical
# label-to-index mapping.)
assert num_labels_trn == num_labels_tst, (
    f"label count mismatch: train={num_labels_trn}, test={num_labels_tst}"
)

dict initialized successfully,there's 8 lables in the dict.
lenth of dataset is : 1918
dict initialized successfully,there's 8 lables in the dict.
lenth of dataset is : 477


In [8]:
# Sanity check: the val and test splits together should cover the whole test directory.
sum(len(split) for split in (val_set, tst_set))

477

In [9]:
# Loss, model (one output per training label, placed on `device`) and a plain SGD optimizer.
loss_fn = nn.CrossEntropyLoss()
net = HWNet(num_labels_trn).to(device)
optimizer = torch.optim.SGD(params=net.parameters(), lr=2e-3)

In [10]:
import time

In [11]:
epochs = 32
for t in range(epochs):
    print(f"Epoch {t+1}")
    # perf_counter is monotonic and high-resolution -- the right clock for elapsed
    # time (time.time() can jump if the system clock is adjusted).
    start = time.perf_counter()
    train(net, trn_loader, loss_fn, optimizer, device)
    # NOTE(review): each epoch is checked on tst_loader and the final check below
    # uses val_loader; conventionally these roles are swapped (validate every
    # epoch, test once at the end) -- confirm this is intentional.
    test(net, tst_loader, loss_fn, device)
    end = time.perf_counter()
    interval = end - start
    print(f"Time : {interval:.3f}s\n-------------------------------")
print('***End of train***')
test(net,val_loader,loss_fn,device)

Epoch 1
Already trained a epoch,wait for test ...
test error:
 accuracy: 13.4%, avg loss: 2.09940
time : 5.854s
-------------------------------
Epoch 2
Already trained a epoch,wait for test ...
test error:
 accuracy: 18.0%, avg loss: 2.04949
time : 5.255s
-------------------------------
Epoch 3
Already trained a epoch,wait for test ...
test error:
 accuracy: 27.6%, avg loss: 1.93983
time : 5.899s
-------------------------------
Epoch 4
Already trained a epoch,wait for test ...
test error:
 accuracy: 35.1%, avg loss: 1.85944
time : 6.533s
-------------------------------
Epoch 5
Already trained a epoch,wait for test ...
test error:
 accuracy: 43.1%, avg loss: 1.73367
time : 6.747s
-------------------------------
Epoch 6
Already trained a epoch,wait for test ...
test error:
 accuracy: 46.4%, avg loss: 1.59462
time : 6.791s
-------------------------------
Epoch 7
Already trained a epoch,wait for test ...
test error:
 accuracy: 50.6%, avg loss: 1.45908
time : 7.009s
------------------------

In [12]:
from dataset import HWVocab

In [13]:
# Build the label vocabulary from the training directory (the same path used for
# the training loader). It is used below to map a predicted index back to its
# character via `lable2char`. NOTE(review): assumes HWVocab assigns indices in the
# same order as the loader built by get_data_loader -- confirm.
vocab = HWVocab('../data_for_test/train')

dict initialized successfully,there's 8 lables in the dict.


In [14]:
from torchvision import transforms

# Tensor/ndarray -> PIL image converter (mode inferred from the data); handy for viewing inputs.
to_img = transforms.ToPILImage(mode=None)

In [15]:
# Preprocessing for ad-hoc inference: resize to 48x48 (the size used in the
# reshape below) and convert to a float tensor.
# NOTE(review): should match the preprocessing inside dataset.get_data_loader -- confirm.
transform = transforms.Compose(
    [
        transforms.Resize(size=(48, 48)),
        transforms.ToTensor(),
    ]
)

In [16]:
import PIL
# `import PIL` alone does not load the PIL.Image submodule; import it explicitly
# instead of relying on torchvision having imported it as a side effect.
from PIL import Image

# Close the file handle once the image has been transformed into a tensor.
with Image.open('../data_for_test/train/皑/101.jpg') as img:
    feature = transform(img)

In [17]:
# Move to the model's device and add batch/channel dims: (N=1, C=1, 48, 48).
# Reshape (not unsqueeze) keeps this cell idempotent if re-run.
feature = feature.to(device).reshape(1, 1, 48, 48)

In [18]:
feature.size()  # expected: torch.Size([1, 1, 48, 48])

torch.Size([1, 1, 48, 48])

In [19]:
# Single-image prediction: force eval mode (don't rely on an earlier cell having
# set it) and skip autograd bookkeeping, since no backward pass is needed.
net.eval()
with torch.no_grad():
    pred_idx = net(feature).argmax(1).item()
pred_idx

6

In [20]:
# Predict the class index for `feature` in eval mode without autograd, then map
# it back to its character through the vocabulary.
net.eval()
with torch.no_grad():
    pred_label = net(feature).argmax(1).item()
vocab.lable2char(pred_label)

'皑'