In [1]:
# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
from torch.autograd import Variable
import cv2
import os
import numpy as np
import torch.nn.functional as F

# 生成测试数据

# 数据准备

In [3]:
# https://github.com/pytorch/vision/issues/81

# 将图片数据写入到LMDB数据库存中

import lmdb
import glob
# 创建数据库
# import lmdb  # install lmdb by "pip install lmdb"
# env = lmdb.open('./data/lmdb', map_size=511627776)
# env = lmdb.open('./data/lmdb', map_size=511627776)
# from genLineText import GenTextImage
def checkImageIsValid(imageBin):
    """
    Tell whether a blob of encoded image bytes can be decoded into a non-empty image.
    ARGS:
        imageBin : raw encoded image bytes (e.g. the content of a .jpg file), or None
    RETURNS:
        False for None, for an empty buffer, for data OpenCV cannot decode and
        for images whose height or width is zero; True otherwise.
    """
    # Reject empty input before decoding: recent OpenCV builds raise an
    # assertion error from cv2.imdecode on an empty buffer instead of
    # returning None, which would abort the whole dataset build on a
    # single zero-byte file.
    if imageBin is None or len(imageBin) == 0:
        return False

    imageBuf = np.frombuffer(imageBin, dtype=np.uint8)
    img = cv2.imdecode(imageBuf, cv2.IMREAD_COLOR)
    if img is None:
        # OpenCV could not decode the data.
        return False
    imgH, imgW = img.shape[0], img.shape[1]
    return imgH * imgW != 0

def writeCache(env, cache):
    """
    Store every entry of ``cache`` in ``env`` inside a single write transaction.
    ARGS:
        env   : object whose ``begin(write=True)`` yields a transaction context manager
        cache : mapping of str key -> value; keys are encoded with ``str.encode()``,
                values are passed to ``put`` unchanged
    """
    with env.begin(write=True) as txn:
        for key in cache:
            txn.put(key.encode(), cache[key])
            
def createDataset(outputPath, imagePathList, labelList, lexiconList=None, checkValid=True,
                  map_size=511627776):
    """
    Create LMDB dataset for CRNN training.
    ARGS:
        outputPath    : LMDB output path
        imagePathList : list of image path
        labelList     : list of corresponding groundtruth texts
        lexiconList   : (optional) list of lexicon lists
        checkValid    : if true, check the validity of every image
        map_size      : maximum size of the LMDB map in bytes (default 511627776)

    Accepted samples are numbered from 1 and stored under the keys
    'image-%09d' and 'label-%09d' (plus 'lexicon-%09d' when lexiconList is
    given); the number of stored samples is written under 'num-samples'.
    Missing files and, when checkValid is true, invalid images are reported
    and skipped.
    """
    assert (len(imagePathList) == len(labelList))
    nSamples = len(imagePathList)

    env = lmdb.open(outputPath, map_size=map_size)
    try:
        cache = {}
        cnt = 1
        for i, (imagePath, label) in enumerate(zip(imagePathList, labelList)):
            if not os.path.exists(imagePath):
                print('%s does not exist' % imagePath)
                continue
            with open(imagePath, 'rb') as f:
                imageBin = f.read()

            if checkValid and not checkImageIsValid(imageBin):
                print('%s is not a valid image' % imagePath)
                continue

            cache['image-%09d' % cnt] = imageBin
            cache['label-%09d' % cnt] = label.encode()
            if lexiconList:
                cache['lexicon-%09d' % cnt] = ' '.join(lexiconList[i]).encode()
            # Flush in batches of 1000 samples so the in-memory cache stays small.
            if cnt % 1000 == 0:
                writeCache(env, cache)
                cache = {}
                print('Written %d / %d' % (cnt, nSamples))
            cnt += 1
        # cnt is one past the last accepted sample.
        nSamples = cnt - 1
        cache['num-samples'] = str(nSamples).encode()
        writeCache(env, cache)
    finally:
        # Always release the environment, even on failure, so the database
        # can be reopened later (e.g. by a dataset reader in this process).
        env.close()
    print('Created dataset with %d samples' % nSamples)


def read_text(path):
    """Return the content of the text file at ``path`` without leading/trailing whitespace."""
    with open(path) as handle:
        return handle.read().strip()

# outputPath = './data/lmdb/train'   # training data
outputPath = 'D:\\PROJECT_TW\\git\\data\\example\\lmdb'   # test data
path = 'D:\\PROJECT_TW\\git\\data\\example\\image\\*.jpg'
imagePathList = glob.glob(path)

# The label is the part of the file name before the first underscore
# (e.g. "1_2.jpg" -> "1"). os.path.basename is used instead of splitting on
# a hard-coded backslash so this does not depend on the path separator, and
# no exception is swallowed: nothing here is expected to fail.
imgLabelLists = []
for p in imagePathList:
    label = os.path.basename(p).split('_')[0]
    imgLabelLists.append((p, label))

# Write samples ordered by label length, shortest first.
imgLabelList = sorted(imgLabelLists, key=lambda x: len(x[1]))
imgPaths = [p[0] for p in imgLabelList]
txtLists = [p[1] for p in imgLabelList]
createDataset(outputPath, imgPaths, txtLists, lexiconList=None, checkValid=True)

Created dataset with 5 samples


In [2]:
# 加载数据
# 注意 dataset.alignCollate 将图片转成了灰度图，后期看怎么修改一下。
# collate_fn，是用来处理不同情况下的输入dataset的封装，一般采用默认即可，除非你自定义的数据读取输出非常少见
import common.dataset as dataset
# Location of the LMDB database to read samples from.
path = 'D:\\PROJECT_TW\\git\\data\\example\\lmdb'
train_dataset = dataset.lmdbDataset(
    root=path,
    transform=dataset.resizeNormalize((32, 32)),
)
# Batches of two samples in stored order, using the default sampler and the
# default collate function.
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=2,
    shuffle=False,
    sampler=None,
    # collate_fn=dataset.alignCollate(imgH=32, imgW=32, keep_ratio=False)
)

# dataset 方法resizeNormalize 中用了transforms.ToTensor 会将数据做归一化处理，在正式用的时候也需要将数据调用该方法做归一化处理

# 可参看 https://blog.csdn.net/victoriaw/article/details/72822005 数据预处理torchvision.transforms 


In [3]:
# Walk through the loader once; afterwards `idx` and `v` hold the index and
# the content of the last batch.
for idx, v in enumerate(train_loader):
    continue
print(v)
print(idx)
print(np.array(v[0]).shape)

[tensor([[[[ 0.6196,  0.6078,  0.6039,  ...,  0.6039,  0.6000,  0.5725],
          [ 0.6275,  0.6235,  0.6353,  ...,  0.5961,  0.6235,  0.6118],
          [ 0.6157,  0.6392,  0.6235,  ...,  0.6118,  0.5922,  0.6196],
          ...,
          [ 0.5098,  0.5961,  0.6000,  ...,  0.5647,  0.5255,  0.5686],
          [ 0.5765,  0.5843,  0.5569,  ...,  0.6157,  0.5569,  0.6235],
          [ 0.5804,  0.5725,  0.5843,  ...,  0.6078,  0.5882,  0.5569]],

         [[ 0.6431,  0.6314,  0.6275,  ...,  0.6392,  0.6353,  0.6078],
          [ 0.6510,  0.6471,  0.6588,  ...,  0.6196,  0.6588,  0.6471],
          [ 0.6392,  0.6627,  0.6471,  ...,  0.6353,  0.6275,  0.6549],
          ...,
          [ 0.5333,  0.6196,  0.6235,  ...,  0.5765,  0.5373,  0.5804],
          [ 0.6000,  0.6078,  0.5804,  ...,  0.6275,  0.5686,  0.6353],
          [ 0.6039,  0.5961,  0.6078,  ...,  0.6196,  0.6000,  0.5686]],

         [[ 0.6627,  0.6510,  0.6471,  ...,  0.6549,  0.6510,  0.6235],
          [ 0.6706,  0.6667, 

# 数据模型

In [4]:
class CNN(nn.Module):
    """
    Small two-stage convolutional classifier for fixed-size square images.

    Each stage is Conv2d(kernel 5, stride 1, padding 2) -> ReLU -> MaxPool2d(2):
    the convolution keeps the spatial size and the pooling halves it. The
    flattened features go through one linear layer that returns raw class
    scores (no softmax is applied inside the model).

    ARGS:
        in_channels : number of channels of the input images (default 3)
        num_classes : number of output scores per image (default 2)
        image_size  : height and width of the (square) input images (default 32)
    RAISES:
        ValueError if image_size is too small to survive the two poolings.
    """

    def __init__(self, in_channels=3, num_classes=2, image_size=32):
        super(CNN, self).__init__()
        # Each of the two 2x2 poolings floor-halves the side length.
        pooled_size = (image_size // 2) // 2
        if pooled_size < 1:
            raise ValueError('image_size must be at least 4, got %r' % (image_size,))

        self.conv_1 = nn.Sequential(         # input shape (in_channels, S, S), S = image_size
            nn.Conv2d(
                in_channels=in_channels,    # channels of the input image
                out_channels=16,            # n_filters
                kernel_size=5,              # filter size
                stride=1,                   # filter movement/step
                padding=2,                  # (kernel_size-1)/2 keeps width and height when stride=1
            ),                              # output shape (16, S, S)
            nn.ReLU(),                      # activation
            nn.MaxPool2d(kernel_size=2),    # max over 2x2 areas, output shape (16, S/2, S/2)
        )

        self.conv_2 = nn.Sequential(         # input shape (16, S/2, S/2)
            nn.Conv2d(16, 32, 5, 1, 2),     # output shape (32, S/2, S/2)
            nn.ReLU(),                      # activation
            nn.MaxPool2d(2),                # output shape (32, S/4, S/4)
        )

        # Fully connected layer producing one score per class.
        self.out = nn.Linear(32 * pooled_size * pooled_size, num_classes)

    def forward(self, x):
        """Return class scores of shape (batch_size, num_classes) for a batch of images."""
        x = self.conv_1(x)
        x = self.conv_2(x)
        x = x.view(x.size(0), -1)           # flatten to (batch_size, 32 * S/4 * S/4)
        output = self.out(x)
        return output
        

# 训练

In [5]:
# Reference: https://blog.csdn.net/tianweidadada/article/details/82630735
# (classification with PyTorch: binary and multi-class)
net = CNN()
opitmizer = torch.optim.SGD(net.parameters(), lr=0.01)
# NOTE(review): softmax followed by MSE against one-hot targets does train,
# but nn.CrossEntropyLoss on the raw scores is the usual choice for
# classification; kept as is so the recorded results stay comparable.
loss_fun = nn.MSELoss()
epoches = 1000


for i in range(epoches):
    for step, values in enumerate(train_loader):
        images = values[0]
        # Binary classification: the loss needs one-hot targets,
        # label 0 -> (1, 0) and label 1 -> (0, 1).
        # Built directly as a float tensor; the deprecated autograd
        # Variable wrapper is not needed.
        target = torch.tensor(
            [[1 - int(x), int(x)] for x in values[1]],
            dtype=torch.float32,
        )
        preds = F.softmax(net(images), dim=1)
        loss = loss_fun(preds, target)
        opitmizer.zero_grad()
        loss.backward()
        opitmizer.step()
    if i % 100 == 0:
        # Reports the loss of the last batch of this epoch as a plain float.
        print('loss --> {}'.format(loss.item()))


loss --> 0.27750226855278015
loss --> 0.005474110133945942
loss --> 0.0012895716354250908
loss --> 0.0006567016243934631
loss --> 0.0004239852132741362
loss --> 0.0003070026286877692
loss --> 0.0002378229983150959
loss --> 0.00019256227824371308
loss --> 0.0001609015162102878
loss --> 0.00013762549497187138


# 验证

In [28]:
import torchvision.transforms as transforms 
import time
path = 'D:\\PROJECT_TW\\git\\data\\example\\image\\1_2.jpg'
image = cv2.imread(path, cv2.IMREAD_COLOR)
# cv2.imread reports a missing or unreadable file by returning None; fail
# with a clear message instead of an AttributeError on `.shape` below.
if image is None:
    raise FileNotFoundError('cannot read image: %s' % path)
if image.shape[0] != 32 or image.shape[1] != 32:
    image = cv2.resize(image, (32, 32))
# NOTE(review): OpenCV loads pixels in BGR order - confirm this matches the
# channel order produced by the training dataset.
# Notes: aa[np.newaxis, :] adds a leading dimension; np.r_[bb, bb] stacks rows.
# image = image[np.newaxis,:]
print(image.shape)

# Inference only: switch the model to evaluation mode and skip gradient tracking.
net.eval()
start_time = time.time()
with torch.no_grad():
    for _ in range(1):
        # ToTensor converts the HxWxC uint8 array to a CxHxW float tensor in [0, 1].
        imdata = transforms.ToTensor()(image)
        imdata = imdata.unsqueeze(0)        # add the batch dimension
    #     print(imdata.size())
        preds = net(imdata)
        preds = F.softmax(preds, dim=1)
        print(preds)

print('time --> {}'.format((time.time()-start_time)))


(32, 32, 3)
tensor([[ 0.0774,  0.9226]])
time --> 0.009992837905883789
