In [4]:
# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
from torch.autograd import Variable
import cv2
import os
import matplotlib.pyplot as plt
import numpy as np
import torch.nn.functional as F
from PIL import Image
# Default matplotlib figure size, given as (width, height), for every plot in this notebook.
plt.rcParams['figure.figsize'] = 15, 10

# 生成测试数据

In [18]:
# When True, get_text_rect shows its intermediate images with matplotlib.
DEBUG = False
# Extract text-box rectangles
def _debug_show(img):
    """Display a grayscale intermediate image, but only when DEBUG is enabled."""
    if DEBUG:
        plt.imshow(img, 'gray')
        plt.show()


def get_text_rect(image):
    """Locate candidate text-box rectangles in a BGR image.

    Pipeline: grayscale -> Gaussian blur -> inverted adaptive threshold ->
    erode/dilate with a 5x3 element -> box blur -> mask out tall vertical
    structures -> dilate twice with a 13x5 element so neighbouring glyphs
    merge into blobs -> contour bounding boxes.

    Args:
        image: BGR image array (it is converted with cv2.COLOR_BGR2GRAY).
            The input is not modified.

    Returns:
        list of [y, x, h, w] lists, one per contour whose bounding-box width
        is strictly between 150 and 400, in the order cv2.findContours
        reports them (unsorted).
    """
    img = image.copy()

    img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    # The kernel is wider than tall (21x7), so smoothing is stronger horizontally.
    img = cv2.GaussianBlur(img, (21, 7), 0)
    img = cv2.adaptiveThreshold(img, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
                                cv2.THRESH_BINARY_INV, 11, 2)

    # Erode then dilate with the same element (a morphological opening).
    element = cv2.getStructuringElement(cv2.MORPH_RECT, (5, 3))
    img = cv2.erode(img, element)
    img = cv2.dilate(img, element)
    _debug_show(img)

    img = cv2.blur(img, (5, 5))
    _debug_show(img)

    # Remove black borders: keep only structures at least ~1/40 of the image
    # width tall, thicken them, and use the inverse as a mask.
    vertical = np.copy(img)
    cols = vertical.shape[1]
    # Clamp to 1 so images narrower than 40 px do not request a zero-height
    # structuring element.
    vertical_size = max(1, cols // 40)
    verticalStructure = cv2.getStructuringElement(cv2.MORPH_RECT, (1, vertical_size))
    vertical = cv2.erode(vertical, verticalStructure)
    vertical = cv2.dilate(vertical, verticalStructure)

    vertical = cv2.blur(vertical, (3, 3))
    element = cv2.getStructuringElement(cv2.MORPH_RECT, (17, 17), (-1, -1))
    vertical = cv2.dilate(vertical, element, iterations=1)
    vertical = cv2.bitwise_not(vertical)
    img = cv2.bitwise_and(img, vertical)
    _debug_show(img)

    # Wide dilation merges neighbouring glyphs into one blob.
    element = cv2.getStructuringElement(cv2.MORPH_RECT, (13, 5), (-1, -1))
    img = cv2.dilate(img, element, iterations=2)
    _debug_show(img)

    # findContours returns (image, contours, hierarchy) on OpenCV 3.x but
    # (contours, hierarchy) on 2.x/4.x; contours is always second from the end.
    contours = cv2.findContours(img, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)[-2]
    cnts = []
    for cnt in contours:
        x, y, w, h = cv2.boundingRect(cnt)
        if 150 < w < 400:
            cnts.append([y, x, h, w])
    return cnts

def check_in_areas(rect, rect_lists):
    """Tell whether ``rect`` lies strictly inside any rectangle of ``rect_lists``.

    Args:
        rect: (y, x, h, w) of the candidate rectangle.
        rect_lists: iterable of (y, x, h, w) rectangles to test against.

    Returns:
        True when at least one rectangle strictly contains ``rect`` (all four
        edges strictly inside), otherwise False. A shared edge does not count.
    """
    top, left, height, width = rect
    bottom = top + height
    right = left + width
    # Every entry is evaluated (no early exit), as in the flag-based loop.
    contained = [
        o_top < top and o_left < left
        and bottom < (o_top + o_height) and right < (o_left + o_width)
        for o_top, o_left, o_height, o_width in rect_lists
    ]
    return any(contained)

# Extract character bounding boxes
def get_char_rect(image):
    """Find bounding boxes of individual characters in a BGR text-box image.

    Pipeline: grayscale -> 3x3 Gaussian blur -> inverted mean adaptive
    threshold -> erode once / dilate three times with a 2x2 element ->
    contour bounding boxes.

    Args:
        image: BGR image array (converted with cv2.COLOR_BGR2GRAY). The
            input is not modified.

    Returns:
        list of [y, x, h, w] lists for contours whose area exceeds 250,
        skipping any box that lies strictly inside a box accepted earlier.
        The order follows cv2.findContours (unsorted).
    """
    img = image.copy()
    img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    img = cv2.GaussianBlur(img, (3, 3), 0)
    img = cv2.adaptiveThreshold(img, 255, cv2.ADAPTIVE_THRESH_MEAN_C,
                                cv2.THRESH_BINARY_INV, 11, 2)

    # Remove noise: one erosion drops small specks, three dilations then
    # thicken the strokes that survived.
    element = cv2.getStructuringElement(cv2.MORPH_RECT, (2, 2))
    img = cv2.erode(img, element)
    img = cv2.dilate(img, element, iterations=3)

    # findContours returns (image, contours, hierarchy) on OpenCV 3.x but
    # (contours, hierarchy) on 2.x/4.x; contours is always second from the end.
    contours = cv2.findContours(img, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)[-2]
    cnts = []
    for cnt in contours:
        if cv2.contourArea(cnt) <= 250:
            continue
        x, y, w, h = cv2.boundingRect(cnt)
        # Skip a box nested inside one already kept. Only earlier boxes are
        # checked, so the result depends on contour order.
        if not check_in_areas([y, x, h, w], cnts):
            cnts.append([y, x, h, w])

    return cnts

# Sort text-box rectangles into reading order (row by row, left to right).
# rect_list: list of (y, x, h, w) rectangles
# Returns the sorted rectangles as a list of [y, x, h, w] lists.
def text_rect_sort(rect_list):
    """Order rectangles row by row, and left to right within each row.

    When the y coordinates have a population standard deviation below 10 the
    boxes are treated as a single row and sorted by x only. Otherwise they
    are sorted by (y, x); a new row starts wherever the y gap between
    consecutive boxes, measured in units of the sample standard deviation and
    rounded to one decimal, exceeds 0.5. Each row is then sorted by x.

    Args:
        rect_list: sequence of (y, x, h, w) rectangles; may be empty.

    Returns:
        list of [y, x, h, w] lists in reading order ([] for empty input).

    Side effects:
        Prints the mean and standard deviation of the y coordinates
        (nothing is printed for empty input).
    """
    # np.array([]) is 1-D, so the column slices below would raise IndexError.
    if len(rect_list) == 0:
        return []

    rects = np.array(rect_list)
    y_std = np.std(rects[:, 0])
    print('y mean --> {}'.format(np.mean(rects[:,0])))
    print('y std --> {}'.format(y_std))

    if y_std < 10:
        # Small spread in y: every box is on the same row, sort by x only.
        rects = rects[np.lexsort([rects[:, 1]])]
        return rects[:, 0:4].tolist()

    # Sort by y (primary) then x so consecutive boxes can be compared.
    rects = rects[np.lexsort([rects[:, 1], rects[:, 0]])]

    # Normalised y offset from the top-most box. y_std >= 10 guarantees at
    # least two distinct y values, so the sample std is non-zero.
    rect_area = (rects[:, 0] - rects[0][0]) / np.std(rects[:, 0], ddof=1)
    rect_area = np.around(rect_area, 1)

    # Assign a row id to each box; a jump of more than 0.5 starts a new row.
    rect_current_class = 0
    rect_class = [rect_current_class]
    for inx in range(1, len(rect_area)):
        if abs(rect_area[inx] - rect_area[inx - 1]) > 0.5:
            rect_current_class += 1
        rect_class.append(rect_current_class)

    # The last lexsort key is the primary one: row id first, then x.
    rects = rects[np.lexsort([rects[:, 1], np.array(rect_class)])]
    return rects[:, 0:4].tolist()

def answer_bind_rect(rect_list, answer_list):
    """Bind question ids to rectangles, aligning both sequences at their ends.

    The last answer is paired with the last rectangle, the second-to-last
    with the second-to-last, and so on. Surplus rectangles at the start of
    ``rect_list`` are ignored.

    Args:
        rect_list: sequence of rectangles, already in reading order.
        answer_list: sequence of hashable question ids in the same order.

    Returns:
        dict mapping each id in ``answer_list`` to its rectangle, inserted
        from the last answer to the first.

    Raises:
        ValueError: if there are fewer rectangles than answers. Without this
            check the negative index wraps around and silently binds the
            wrong rectangles.
    """
    rect_len = len(rect_list)
    answer_len = len(answer_list)
    if rect_len < answer_len:
        raise ValueError(
            'cannot bind {} answers to only {} rectangles'.format(answer_len, rect_len))
    # zip stops at the shorter sequence (the answers), so leading surplus
    # rectangles are dropped.
    return dict(zip(reversed(answer_list), reversed(rect_list)))
    

In [15]:
# Load data
from common.examdetect import EXDetect
# Output directory for the per-question text-box clips written by the loop below.
clip_path = 'D:\\PROJECT_TW\\git\\data\\example\\image\\clip\\'

# "mobile" sheet definition
# qlists:        indices into the regions returned by EXDetect.clip_exam() that get processed
# qdefines:      region index -> question numbers bound to the text boxes found in that region
# answerdefines: question number -> answer pattern embedded in the clip file name; the
#                char-extraction cell saves the crop at each '-' position with the '1_' prefix
#                (presumably the marked option -- confirm against the answer sheets)
qlists = [0,1,2,4]
qdefines = {0:[1, 2, 3, 4, 5, 6],1:[7,8,9,10,11,12,13,14],2:[15, 16, 17, 18, 19, 20],
            4:[26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45]}
answerdefines = { 
    1:'A-CDEF',2:'ABC-EF',3:'ABC-EF',4:'AB-DEF',5:'ABCDE-',6:'-BCDEF',7:'-BCD',8:'ABC-',
    9:'AB-D',10:'AB-D',11:'-BCD',12:'AB-D',13:'AB-D',14:'-BCD',15:'T-',16:'-F',17:'T-',
    18:'-F',19:'T-',20:'T-',26:'-BCD',27:'A-CD',28:'AB-D',29:'-BCD',30:'ABC-',31:'AB-D',
    32:'-BCD',33:'AB-D',34:'A-CD',35:'A-CD',36:'ABC-',37:'A-CD',38:'-BCD',39:'-BCD',40:'A-CD',
    41:'-BCD',42:'AB-D',43:'A-CD',44:'A-CD',45:'AB-D'
}


# "mobile3" sheet definition (alternative configuration, currently disabled)
# qlists = [0,1]
# qdefines = {0:[82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93],1:[94,  95,  96,  97,  98,  99, 100, 101, 102, 103, 104, 105]}
# answerdefines = {82:'-BCDE',83:'-BCDE',84:'ABCD-',85:'-BCDE',86:'A-CDE',87:'ABC-E',
#                  88:'A-CDE',89:'A-CDE',90:'-BCDE',91:'ABCD-',92:'ABC-E',93:'ABCD-',
#                  94:'--CDE',95:'-BCD-',96:'AB-D-',97:'A-CDE',98:'AB-DE',99:'A--D-',
#                  100:'-BCDE',101:'A-CD-',102:'AB-DE',103:'ABC-E',104:'A--DE',105:'-BCDE'
#                 }


# filename = '乐视2max'   (disabled: a single hard-coded file name)
# Input directory holding the photographed exam sheets.
path = 'D:\\PROJECT_TW\\examrc\\data\\exam\\mobile\\'

# Only file names ending in 'jpg' are processed.
file_list = os.listdir(path)
file_list = [x for x in file_list if x.endswith('jpg')]

# For every sheet photo: split it into regions, locate the text boxes in each selected
# region, bind them to question numbers and save one clip per question as
# <file>_<question>_<answer pattern>.jpg under clip_path.
for item in file_list:
    # Name without extension. Everything after the first '.' is dropped.
    filename = item.split('.')[0]
    print('------------ 开始处理 {} -----------------'.format(filename))
    with  open('{}{}'.format(path,item),'rb') as f:
        imgb = f.read()
    # The raw (still encoded) file bytes are handed to EXDetect as a uint8 array.
    image_bin = np.frombuffer(imgb, np.uint8)    
    exd = EXDetect(id=1, image=image_bin, mobile_type='default')
    # Split the exam sheet and take out the sections containing choice questions
    # (original author's note; EXDetect is project code not visible here).
    regions = exd.clip_exam()

    for qidx in qlists:
        # Each region is unpacked as (y, x, h, w) in origin_image coordinates.
        y,x,h,w = regions[qidx]
        rect_list = get_text_rect(exd.origin_image[y:y+h,x:x+w])
        # NOTE(review): writes the current region to a fixed path on every iteration,
        # overwriting the previous one -- looks like leftover debugging output.
        cv2.imwrite('d:\\clip_11.jpg',exd.origin_image[y:y+h,x:x+w])
        rect_list = text_rect_sort(rect_list)
        print(rect_list)
        answers = answer_bind_rect(rect_list, qdefines[qidx])
        print(answers)
        for ans_no in answers.keys():
            # Rect coordinates are relative to the region, so add the region offset.
            ay,ax,ah,aw = answers[ans_no]
            rimg = exd.origin_image[y+ay:y+ay+ah, x+ax:x+ax+aw]
            # Encode in memory and write through open() rather than cv2.imwrite
            # (presumably to cope with non-ASCII file names -- confirm).
            with open('{}{}_{}_{}.jpg'.format(clip_path,filename,ans_no,answerdefines[ans_no]),'wb') as f:
                imgbin = cv2.imencode('.jpg',rimg)[1]
                f.write(imgbin)
            
    


------------ 开始处理 360n4s -----------------
[[2282, 846, 31, 412], [2287, 378, 30, 416], [2365, 844, 30, 413], [2370, 375, 30, 417], [2446, 842, 31, 414], [2452, 372, 31, 418], [2529, 840, 31, 416], [2535, 370, 32, 418], [2613, 839, 30, 415], [2618, 368, 33, 419]]
[[1120, 312, 298, 1816], [1418, 309, 381, 1816], [1799, 299, 300, 1825], [2099, 291, 553, 1834], [2652, 275, 483, 1853]]
y mean --> 183.5
y std --> 36.16973873281365
[[153, 88, 47, 353], [157, 495, 46, 352], [160, 901, 44, 350], [162, 1306, 44, 343], [234, 89, 46, 351], [235, 494, 47, 353]]
{6: [235, 494, 47, 353], 5: [234, 89, 46, 351], 4: [162, 1306, 44, 343], 3: [160, 901, 44, 350], 2: [157, 495, 46, 352], 1: [153, 88, 47, 353]}
y mean --> 264.75
y std --> 38.93825240043523
[[235, 84, 46, 265], [235, 404, 45, 264], [235, 724, 47, 262], [233, 1039, 48, 262], [235, 1353, 48, 260], [317, 83, 44, 264], [314, 402, 46, 265], [314, 721, 46, 263]]
{14: [314, 721, 46, 263], 13: [314, 402, 46, 265], 12: [317, 83, 44, 264], 11: [235, 

y mean --> 215.0
y std --> 1.8257418583505538
[[218, 14, 48, 169], [217, 243, 50, 172], [214, 476, 51, 173], [214, 708, 49, 167], [213, 935, 50, 169], [214, 1164, 50, 171]]
{20: [214, 1164, 50, 171], 19: [213, 935, 50, 169], 18: [214, 708, 49, 167], 17: [214, 476, 51, 173], 16: [217, 243, 50, 172], 15: [218, 14, 48, 169]}
y mean --> 254.66666666666666
y std --> 94.25817298570546
[[68, 777, 50, 178], [151, 78, 45, 251], [150, 376, 46, 252], [149, 675, 45, 252], [146, 974, 45, 251], [143, 1273, 48, 250], [229, 76, 45, 252], [226, 374, 47, 253], [224, 673, 47, 254], [225, 948, 44, 277], [222, 1273, 45, 248], [307, 75, 44, 250], [304, 372, 46, 254], [304, 673, 44, 253], [302, 972, 44, 253], [299, 1272, 47, 251], [385, 71, 47, 254], [371, 370, 58, 255], [382, 671, 45, 254], [379, 971, 46, 254], [382, 1273, 45, 252]]
{45: [382, 1273, 45, 252], 44: [379, 971, 46, 254], 43: [382, 671, 45, 254], 42: [371, 370, 58, 255], 41: [385, 71, 47, 254], 40: [299, 1272, 47, 251], 39: [302, 972, 44, 253], 

------------ 开始处理 华为mate7 -----------------
[[2170, 987, 29, 388], [2172, 551, 29, 390], [2248, 986, 29, 389], [2250, 549, 29, 391], [2325, 984, 29, 390], [2327, 547, 30, 391], [2403, 983, 29, 390], [2405, 544, 30, 392], [2481, 981, 30, 391], [2484, 542, 30, 393]]
[[1058, 486, 289, 1722], [1347, 482, 364, 1721], [1711, 475, 288, 1722], [1999, 469, 521, 1727], [2520, 454, 456, 1741]]
y mean --> 180.83333333333334
y std --> 31.184486884061798
[[152, 82, 47, 334], [159, 464, 45, 335], [162, 848, 42, 331], [163, 1227, 43, 333], [229, 79, 46, 336], [220, 463, 59, 334]]
{6: [220, 463, 59, 334], 5: [229, 79, 46, 336], 4: [163, 1227, 43, 333], 3: [162, 848, 42, 331], 2: [159, 464, 45, 335], 1: [152, 82, 47, 334]}
y mean --> 255.75
y std --> 36.36189626518397
[[225, 80, 45, 251], [228, 380, 44, 249], [229, 680, 45, 248], [227, 975, 47, 251], [229, 1273, 47, 250], [303, 77, 43, 251], [302, 377, 45, 252], [303, 676, 44, 251]]
{14: [303, 676, 44, 251], 13: [302, 377, 45, 252], 12: [303, 77, 43, 25

KeyboardInterrupt: 

In [19]:
# Extract individual character crops from every text-box clip.
# Clip names look like <device>_<question>_<answer pattern>.jpg; the crop at each '-'
# position of the pattern is saved as 1_<n>.jpg, every other crop as 0_<n>.jpg.
import time
clip_path = 'D:\\PROJECT_TW\\git\\data\\example\\image\\clip\\'
clip_save_path = 'D:\\PROJECT_TW\\git\\data\\example\\image\\char\\'
clip_file_name = '23_A-CD'
b_number = 0   # running index of crops saved with the '1_' prefix
n_number = 0   # running index of crops saved with the '0_' prefix

files = os.listdir(clip_path)
# files = ['360n4s_17_T-.jpg',]
start_time = time.time()
for file in files:
    clip_file_name = file.split('.')[0]
    print(clip_file_name)
    # Read the bytes and decode in memory instead of calling cv2.imread
    # (presumably to cope with non-ASCII file names -- confirm).
    with open('{}{}.jpg'.format(clip_path,clip_file_name),'rb') as ff:
        imgbin = ff.read()

    imgbin = np.frombuffer(imgbin, np.uint8)
    image = cv2.imdecode(imgbin,cv2.IMREAD_COLOR)
    # imdecode returns None when it cannot decode; fail here with the file
    # name instead of later with an AttributeError inside get_char_rect.
    if image is None:
        raise Exception('{} could not be decoded as an image'.format(clip_file_name))

    # Text after the last '_' is the answer pattern, one character per option.
    answer_pattern = clip_file_name.split('_')[-1]

    cnts = get_char_rect(image)
    # Check the count BEFORE touching the coordinates: with zero detections
    # the old np.array(cnts)[:,1] raised IndexError and hid this message.
    if len(cnts) < len(answer_pattern):
        raise Exception('{} 切割失败 切割长度 {} 最少长度 {}'.format(clip_file_name,len(cnts), 
                                                         (len(answer_pattern) + 1)))

    # Left-to-right order. sorted() is stable, like the single-key lexsort it replaces.
    cnts = sorted(cnts, key=lambda rect: rect[1])

    # Positions of '-' in the pattern.
    binx = [idx for idx,x in enumerate(answer_pattern) if x =='-']
    # Keep only the right-most len(answer_pattern) boxes.
    cnts = cnts[len(cnts)-len(answer_pattern):]
    for i, (y,x,h,w) in enumerate(cnts):
        crop_img = image[y:y+h,x:x+w]
        if i in binx:
            cv2.imwrite('{}1_{}.jpg'.format(clip_save_path,b_number),crop_img)
            b_number = b_number + 1
        else:
            cv2.imwrite('{}0_{}.jpg'.format(clip_save_path,n_number),crop_img)
            n_number = n_number + 1


print('total files {}  time {:.4f}'.format(len(files),(time.time() - start_time)))
    
        

360n4s_10_AB-D
360n4s_11_-BCD
360n4s_12_AB-D
360n4s_13_AB-D
360n4s_14_-BCD
360n4s_15_T-
360n4s_16_-F
360n4s_17_T-
360n4s_18_-F
360n4s_19_T-
360n4s_1_A-CDEF
360n4s_20_T-
360n4s_26_-BCD
360n4s_27_A-CD
360n4s_28_AB-D
360n4s_29_-BCD
360n4s_2_ABC-EF
360n4s_30_ABC-
360n4s_31_AB-D
360n4s_32_-BCD
360n4s_33_AB-D
360n4s_34_A-CD
360n4s_35_A-CD
360n4s_36_ABC-
360n4s_37_A-CD
360n4s_38_-BCD
360n4s_39_-BCD
360n4s_3_ABC-EF
360n4s_40_A-CD
360n4s_41_-BCD
360n4s_42_AB-D
360n4s_43_A-CD
360n4s_44_A-CD
360n4s_45_AB-D
360n4s_4_AB-DEF
360n4s_5_ABCDE-
360n4s_6_-BCDEF
360n4s_7_-BCD
360n4s_8_ABC-
360n4s_9_AB-D
iphone6p_10_AB-D
iphone6p_11_-BCD
iphone6p_12_AB-D
iphone6p_13_AB-D
iphone6p_14_-BCD
iphone6p_15_T-
iphone6p_16_-F
iphone6p_17_T-
iphone6p_18_-F
iphone6p_19_T-
iphone6p_1_A-CDEF
iphone6p_20_T-
iphone6p_26_-BCD
iphone6p_27_A-CD
iphone6p_28_AB-D
iphone6p_29_-BCD
iphone6p_2_ABC-EF
iphone6p_30_ABC-
iphone6p_31_AB-D
iphone6p_32_-BCD
iphone6p_33_AB-D
iphone6p_34_A-CD
iphone6p_35_A-CD
iphone6p_36_ABC-
iphone6p_37

华为荣耀V8_3_ABC-EF
华为荣耀V8_40_A-CD
华为荣耀V8_41_-BCD
华为荣耀V8_42_AB-D
华为荣耀V8_43_A-CD
华为荣耀V8_44_A-CD
华为荣耀V8_45_AB-D
华为荣耀V8_4_AB-DEF
华为荣耀V8_5_ABCDE-
华为荣耀V8_6_-BCDEF
华为荣耀V8_7_-BCD
华为荣耀V8_8_ABC-
华为荣耀V8_9_AB-D
小米5_100_-BCDE
小米5_101_A-CD-
小米5_102_AB-DE
小米5_103_ABC-E
小米5_104_A--DE
小米5_105_-BCDE
小米5_10_AB-D
小米5_11_-BCD
小米5_12_AB-D
小米5_13_AB-D
小米5_14_-BCD
小米5_15_T-
小米5_16_-F
小米5_17_T-
小米5_18_-F
小米5_19_T-
小米5_1_A-CDEF
小米5_20_T-
小米5_26_-BCD
小米5_27_A-CD
小米5_28_AB-D
小米5_29_-BCD
小米5_2_ABC-EF
小米5_30_ABC-
小米5_31_AB-D
小米5_32_-BCD
小米5_33_AB-D
小米5_34_A-CD
小米5_35_A-CD
小米5_36_ABC-
小米5_37_A-CD
小米5_38_-BCD
小米5_39_-BCD
小米5_3_ABC-EF
小米5_40_A-CD
小米5_41_-BCD
小米5_42_AB-D
小米5_43_A-CD
小米5_44_A-CD
小米5_45_AB-D
小米5_4_AB-DEF
小米5_5_ABCDE-
小米5_6_-BCDEF
小米5_7_-BCD
小米5_82_-BCDE
小米5_83_-BCDE
小米5_84_ABCD-
小米5_85_-BCDE
小米5_86_A-CDE
小米5_87_ABC-E
小米5_88_A-CDE
小米5_89_A-CDE
小米5_8_ABC-
小米5_90_-BCDE
小米5_91_ABCD-
小米5_92_ABC-E
小米5_93_ABCD-
小米5_94_--CDE
小米5_95_-BCD-
小米5_96_AB-D-
小米5_97_A-CDE
小米5_98_AB-DE
小米5_99_A--D-
小米5_9_AB-D
小米6_10_AB-D
小米6

# 数据准备

In [20]:
# https://github.com/pytorch/vision/issues/81

# 将图片数据写入到LMDB数据库存中

import lmdb
import glob
# 创建数据库
# import lmdb  # install lmdb by "pip install lmdb"
# env = lmdb.open('./data/lmdb', map_size=511627776)
# env = lmdb.open('./data/lmdb', map_size=511627776)
# from genLineText import GenTextImage
def checkImageIsValid(imageBin):
    """Check that a byte string decodes to a non-empty image.

    Args:
        imageBin: encoded image bytes (e.g. the contents of a .jpg file), or None.

    Returns:
        False if ``imageBin`` is None or empty, if OpenCV cannot decode it,
        or if the decoded image has zero height or width; True otherwise.
    """
    # Reject empty input before decoding: cv2.imdecode is not given an empty
    # buffer (OpenCV builds commonly raise cv2.error for one).
    if imageBin is None or len(imageBin) == 0:
        return False

    imageBuf = np.frombuffer(imageBin, dtype=np.uint8)
    img = cv2.imdecode(imageBuf, cv2.IMREAD_COLOR)
    if img is None:
        return False
    imgH, imgW = img.shape[0], img.shape[1]
    return imgH * imgW != 0

def writeCache(env, cache):
    """Store every entry of ``cache`` in ``env`` inside one write transaction.

    Args:
        env: object whose ``begin(write=True)`` returns a context manager
            with a ``put(key, value)`` method (an LMDB environment here).
        cache: dict mapping str keys to values; keys are encoded with
            ``str.encode()`` before being stored, values are passed through.
    """
    with env.begin(write=True) as transaction:
        for key in cache:
            transaction.put(key.encode(), cache[key])
            
def createDataset(outputPath, imagePathList, labelList, lexiconList=None, checkValid=True):
    """
    Create LMDB dataset for CRNN training.

    Each accepted sample n (numbered from 1) is stored under the keys
    'image-%09d' (raw file bytes), 'label-%09d' (encoded label) and, when a
    lexicon is supplied, 'lexicon-%09d' (space-joined lexicon). The total
    count is stored under 'num-samples'. Missing files and, when
    ``checkValid`` is set, undecodable images are reported and skipped.

    ARGS:
        outputPath    : LMDB output path
        imagePathList : list of image path
        labelList     : list of corresponding groundtruth texts
        lexiconList   : (optional) list of lexicon lists
        checkValid    : if true, check the validity of every image

    RAISES:
        AssertionError: if imagePathList and labelList differ in length
    """
    assert (len(imagePathList) == len(labelList))
    nSamples = len(imagePathList)

    env = lmdb.open(outputPath, map_size=511627776)
    try:
        cache = {}
        cnt = 1
        for i, (imagePath, label) in enumerate(zip(imagePathList, labelList)):
            if not os.path.exists(imagePath):
                print('%s does not exist' % imagePath)
                continue
            with open(imagePath, 'rb') as f:
                imageBin = f.read()

            if checkValid and not checkImageIsValid(imageBin):
                print('%s is not a valid image' % imagePath)
                continue

            cache['image-%09d' % cnt] = imageBin
            cache['label-%09d' % cnt] = label.encode()
            if lexiconList:
                cache['lexicon-%09d' % cnt] = ' '.join(lexiconList[i]).encode()
            # Flush every 1000 accepted samples so the cache stays small.
            if cnt % 1000 == 0:
                writeCache(env, cache)
                cache = {}
                print('Written %d / %d' % (cnt, nSamples))
            cnt += 1

        # cnt is one past the last accepted sample.
        nSamples = cnt - 1
        cache['num-samples'] = str(nSamples).encode()
        writeCache(env, cache)
    finally:
        # Release the environment even on failure: the same path is opened
        # again later in this notebook, and it must not still be open here.
        env.close()
    print('Created dataset with %d samples' % nSamples)


def read_text(path):
    """Return the content of the text file at ``path`` without leading/trailing whitespace.

    The file is opened in text mode with the platform's default encoding.
    """
    with open(path) as handle:
        return handle.read().strip()

# outputPath = './data/lmdb/train'   # training data
outputPath = 'D:\\PROJECT_TW\\git\\data\\example\\lmdb'   # test data
path = 'D:\\PROJECT_TW\\git\\data\\example\\image\\char\\*.jpg'
imagePathList = glob.glob(path)
imgLabelLists = []
for p in imagePathList:
    # The label is the file-name prefix before the first '_' ('0' or '1' for
    # the crops written by the char-extraction cell). os.path.basename is
    # used instead of a hard-coded '\\' split; no try/except is needed
    # because splitting a str cannot fail.
    label = os.path.basename(p).split('_')[0]
    imgLabelLists.append((p,label))

# Stable sort by label length. Equal-length labels keep their glob order.
imgLabelList = sorted(imgLabelLists, key=lambda x: len(x[1]))
imgPaths = [p[0] for p in imgLabelList]
txtLists = [p[1] for p in imgLabelList]
createDataset(outputPath, imgPaths, txtLists, lexiconList=None, checkValid=True)

Written 1000 / 3158
Written 2000 / 3158
Written 3000 / 3158
Created dataset with 3158 samples


In [28]:
# Load the data
# Note: dataset.alignCollate converts images to grayscale; revisit how to change that later.
# collate_fn wraps how samples from the dataset are assembled into a batch; the default is
# normally fine unless the dataset returns something unusual.
import common.dataset as dataset
path = 'D:\\PROJECT_TW\\git\\data\\example\\lmdb'
# Every sample goes through dataset.resizeNormalize((32,32)); CNN below expects 32x32 inputs.
train_dataset = dataset.lmdbDataset(root=path, transform=dataset.resizeNormalize((32,32)))
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=100,
    shuffle=True,
    sampler=None
#     collate_fn=dataset.alignCollate(imgH=32, imgW=32, keep_ratio=False)
)

# dataset 方法resizeNormalize 中用了transforms.ToTensor 会将数据做归一化处理，在正式用的时候也需要将数据调用该方法做归一化处理

# 可参看 https://blog.csdn.net/victoriaw/article/details/72822005 数据预处理torchvision.transforms 


In [22]:
# Walk the whole loader once. The batches themselves are discarded.
for idx,v in enumerate(train_loader):
#     print(idx,v)
#     print(v)
    pass
# print(v)
# After the loop idx is the index of the last batch, i.e. number of batches - 1.
# NOTE(review): idx is undefined (NameError) if the loader yields nothing.
print(idx)
# print(np.array(v[0]).shape)

1578


# 数据模型

In [23]:
class CNN(nn.Module):
    """Small two-stage convolutional network producing two raw class scores.

    Shapes in the comments assume a (N, 3, 32, 32) input. The final linear
    layer takes 32 * 8 * 8 features, which only matches 32x32 images.
    """

    def __init__(self):
        super().__init__()

        # Stage 1: (3, 32, 32) -> conv 5x5, stride 1, padding 2 keeps the
        # spatial size -> (16, 32, 32) -> 2x2 max-pool -> (16, 16, 16).
        self.conv_1 = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=5, stride=1, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )

        # Stage 2: (16, 16, 16) -> (32, 16, 16) -> 2x2 max-pool -> (32, 8, 8).
        self.conv_2 = nn.Sequential(
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5, stride=1, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )

        # Classifier head: flattened (32 * 8 * 8) features -> 2 scores.
        self.out = nn.Linear(in_features=32 * 8 * 8, out_features=2)

    def forward(self, x):
        """Return unnormalised scores of shape (N, 2) for a batch ``x``."""
        features = self.conv_2(self.conv_1(x))
        # Flatten everything except the batch dimension.
        flat = features.view(features.size(0), -1)
        return self.out(flat)
        

# 训练

In [29]:
# https://blog.csdn.net/tianweidadada/article/details/82630735   classification with PyTorch (binary and multi-class)
# Train the CNN with plain SGD, minimising the mean squared error between the softmax
# output and a two-column target built from each label.
net = CNN()
opitmizer = torch.optim.SGD(net.parameters(),lr=0.01)
loss_fun = nn.MSELoss() 
epoches = 1000


for i in range(epoches):
    for step, values in enumerate(train_loader):
        # values[0] is the image batch, values[1] the matching labels.
        images = values[0]
        # Binary classification: the loss needs targets in (0, 1) / (1, 0) form,
        # so label x becomes the row [1 - x, x].
        target = [ [1-int(x),int(x) ] for x in values[1]]
        target =  Variable(torch.FloatTensor(target)) # float tensor with one [1 - x, x] row per sample
        preds = F.softmax(net(images),dim=1)
        loss = loss_fun(preds,target)
        opitmizer.zero_grad()
        loss.backward()
        opitmizer.step()
        # Progress report every 100 batches.
        if step % 100 == 0:
            print('{} --> {}  loss : {:.8f}'.format(i, step, loss))
    # Every 100 epochs also report the loss of the epoch's last batch.
    if i%100 == 0:
        print('loss --> {:.4f}'.format(loss))


0 --> 0  loss : 0.27270165
loss --> 0.1460
1 --> 0  loss : 0.14705318
2 --> 0  loss : 0.09346888
3 --> 0  loss : 0.07818905
4 --> 0  loss : 0.08490984
5 --> 0  loss : 0.04774324
6 --> 0  loss : 0.03185656
7 --> 0  loss : 0.01841942
8 --> 0  loss : 0.01615721
9 --> 0  loss : 0.01287981
10 --> 0  loss : 0.00671923
11 --> 0  loss : 0.01088424
12 --> 0  loss : 0.00874346
13 --> 0  loss : 0.00457397
14 --> 0  loss : 0.00618647
15 --> 0  loss : 0.00769486
16 --> 0  loss : 0.00346732
17 --> 0  loss : 0.00188480
18 --> 0  loss : 0.00185975
19 --> 0  loss : 0.00235405
20 --> 0  loss : 0.00162395
21 --> 0  loss : 0.00395346
22 --> 0  loss : 0.00271449
23 --> 0  loss : 0.00133114
24 --> 0  loss : 0.00168175
25 --> 0  loss : 0.00195039
26 --> 0  loss : 0.00120667
27 --> 0  loss : 0.00113640
28 --> 0  loss : 0.00157883
29 --> 0  loss : 0.00246879
30 --> 0  loss : 0.00089439
31 --> 0  loss : 0.00161752
32 --> 0  loss : 0.00064279
33 --> 0  loss : 0.00043792
34 --> 0  loss : 0.00102361
35 --> 0  loss

TypeError: src data type = 17 is not supported

# 验证

In [40]:
# Run the trained network on a single character crop and time the forward pass.
import torchvision.transforms as transforms 
import time
path = 'D:\\PROJECT_TW\\git\\data\\example\\image\\char\\1_447.jpg'
# NOTE(review): cv2.imread returns None when the file cannot be read, in which case
# the .shape access below fails.
image = cv2.imread(path,cv2.IMREAD_COLOR)    
# The network expects 32x32 inputs, so resize anything else.
if image.shape[0] != 32 or image.shape[1] != 32:
    image = cv2.resize(image,(32,32))
# aa[np.newaxis,:].shape -- np.newaxis adds a dimension
# np.r_[bb,bb].shape -- appends rows
# image = image[np.newaxis,:]
print(image.shape)

start_time = time.time()
for _ in range(1):

    # ToTensor is the same conversion the author's note says resizeNormalize applies in training.
    imdata = transforms.ToTensor()(image)
    # Add the leading batch dimension expected by the network.
    imdata = imdata.unsqueeze(0)
#     print(imdata.size())
    preds = net(imdata)
    # Normalise the two scores so each row sums to 1.
    preds = F.softmax(preds,dim=1)
    print(preds)
    
print('time --> {}'.format((time.time()-start_time)))


(32, 32, 3)
tensor([[ 0.0089,  0.9911]])
time --> 0.007995367050170898
