In [2]:
#导入需要的包
import os
import zipfile
import random
import json
import cv2
import numpy as np
from PIL import Image
import paddle
import paddle.fluid as fluid
from paddle.fluid.dygraph import Linear
import matplotlib.pyplot as plt
from argparse import Namespace as args
from pathlib import Path,PosixPath
import shutil
import pandas as pd
from sklearn.model_selection import train_test_split

# 1、函数准备

In [10]:
def unzip_data(src_path, target_path):
    '''
    Extract the zip archive ``src_path`` into the directory ``target_path``.

    Extraction is skipped when ``target_path`` already is a directory, so the
    cell can be re-run cheaply. In both cases a ``__MACOSX`` folder directly
    under ``target_path`` is deleted so it is not mistaken for a class folder.

    Args:
        src_path: path of the zip archive.
        target_path: directory to extract into.
    '''
    if not os.path.isdir(target_path):
        print(f"源文件地址：{src_path}", f"解压目标目录：{target_path}")
        # The context manager closes the archive even if extraction raises.
        with zipfile.ZipFile(src_path, 'r') as archive:
            archive.extractall(path=target_path)
    else:
        print("文件已解压")
    macosx_dir = Path(target_path) / '__MACOSX'
    if macosx_dir.is_dir():
        shutil.rmtree(macosx_dir)
        
def data_reader(df, path_col=None, label_col='lbl', size=(64, 64)):
    '''
    Build a Paddle-style reader over a metadata DataFrame.

    Columns are looked up by name instead of being unpacked positionally.
    This lets both kinds of frame be read:
    - the 3-column frame (gem_path, gem_name, lbl);
    - the 4-column frames built later (augmented training data with an extra
      ``img`` column, eval data with ``dest_path``).

    Args:
        df: DataFrame holding an image-path column and a label column.
        path_col: name of the image-path column. ``None`` (default) uses
            ``dest_path`` when the frame has it (the evaluation images are
            moved there) and ``gem_path`` otherwise.
        label_col: name of the integer label column.
        size: ``(width, height)`` handed to ``Image.resize``.

    Returns:
        A zero-argument generator function yielding ``(img, label)`` pairs.
        ``img`` is a float32 CHW array scaled to [0, 1]; ``label`` is an ``int``.
    '''
    if path_col is None:
        path_col = 'dest_path' if 'dest_path' in df.columns else 'gem_path'

    def reader():
        for img_path, lbl in zip(df[path_col], df[label_col]):
            # Close the file handle as soon as the pixels have been resized.
            with Image.open(img_path) as src:
                rgb = src if src.mode == 'RGB' else src.convert('RGB')
                resized = rgb.resize(size, Image.BILINEAR)
            arr = np.asarray(resized, dtype='float32')
            arr = arr.transpose((2, 0, 1)) / 255  # HWC -> CHW, normalise to [0, 1]
            yield arr, int(lbl)
    return reader

# 2、参数初始化

In [3]:
SEED = 1234
# Seed Python's and NumPy's global RNGs so shuffling and splitting are reproducible.
random.seed(SEED)
np.random.seed(SEED)

dataset_prefix = 'test'  # dataset name; every path below is derived from it
train_params_path = f'work/{dataset_prefix}_params.json'  # where train_params is saved
train_params = dict(
    input_size=[3, 64, 64],
    class_dim=25,                      # placeholder; recomputed from the class folders
    augment_path='augment/',
    src_path='data/data55032/archive_train.zip',
    target_path=f'work/dataset/{dataset_prefix}/',
    train_txt=f'work/{dataset_prefix}_train.txt',
    eval_txt=f'work/{dataset_prefix}_eval.txt',
    label_dict={},                     # gem name -> integer label, filled in later
    num_epochs=10,
    batch_size=8,
    learning_strategy={'lr': 0.001},
    gen_img_count=30,                  # images per class the augmentation step aims for
)

# 数据准备

In [13]:
# Extract the raw archive into the dataset directory (skipped if already extracted).
unzip_data(src_path=train_params['src_path'], target_path=train_params['target_path'])

源文件地址：data/data55032/archive_train.zip 解压目标目录：work/dataset/test/


### 划分训练集与验证集

In [4]:
# Training data folder: one sub-directory per gem class.
targetPath = Path(train_params['target_path'])
# Keep directories only. A stray file (e.g. .DS_Store) would otherwise be
# counted as a class and shift every label assigned by enumerate below.
class_dirs = sorted(p for p in targetPath.glob("*") if p.is_dir())
print(class_dirs)
train_params['class_dim'] = len(class_dirs)

# Collect (image path, gem name, label) metadata. The label is the index of the
# class directory in sorted order; the gem name is the directory name.
lst_data = []
for i, class_dir in enumerate(class_dirs):
    lst_path = list(class_dir.glob("*.jpg"))
    lst_data.extend((str(p), class_dir.name, i) for p in lst_path)

[]


In [15]:
# Inspect the current training parameters.
train_params

{'input_size': [3, 64, 64],
 'class_dim': 25,
 'augment_path': 'augment/',
 'src_path': 'data/data55032/archive_train.zip',
 'target_path': 'work/dataset/test/',
 'train_txt': 'work/test_train.txt',
 'eval_txt': 'work/test_eval.txt',
 'label_dict': {},
 'num_epochs': 10,
 'batch_size': 8,
 'learning_strategy': {'lr': 0.001},
 'gen_img_count': 30}

In [17]:
# Build the metadata DataFrame, then shuffle all rows with the fixed seed.
df_data = pd.DataFrame(lst_data, columns=['gem_path', 'gem_name', 'lbl'])
df_data = df_data.sample(frac=1, replace=False, random_state=SEED)
df_data.head(5)

Unnamed: 0,gem_path,gem_name,lbl
78,work\dataset\test\Benitoite\benitoite_21.jpg,Benitoite,2
730,work\dataset\test\Tanzanite\tanzanite_25.jpg,Tanzanite,22
160,work\dataset\test\Carnelian\carnelian_4.jpg,Carnelian,4
285,work\dataset\test\Emerald\emerald_34.jpg,Emerald,8
156,work\dataset\test\Carnelian\carnelian_33.jpg,Carnelian,4


```
DataFrame.sample(n=None, frac=None, replace=False, weights=None, random_state=None, axis=None)
其中：n 和 frac 不能同时使用；n 为抽样的行数，frac 为抽样比例（可以大于 1，此时 replace 必须为 True）；replace 表示是否允许重复抽样；random_state 用于固定随机种子。frac=1 且 replace=False 时返回随机顺序的全部行，相当于打乱数据。
```

### 得到标签字典

In [18]:
# Unique (label, gem name) pairs: the raw material for the label dictionary.
dic = df_data.loc[:, ['lbl', 'gem_name']].drop_duplicates()
dic.head(5)

Unnamed: 0,lbl,gem_name
78,2,Benitoite
730,22,Tanzanite
160,4,Carnelian
285,8,Emerald
201,6,Danburite


In [21]:
train_params['label_dict'] = {
    gem_name: int(lbl)  # int() keeps the dict JSON-serialisable for json.dumps later
    for lbl, gem_name in dic.itertuples(index=False)
}
train_params

{'input_size': [3, 64, 64],
 'class_dim': 25,
 'augment_path': 'augment/',
 'src_path': 'data/data55032/archive_train.zip',
 'target_path': 'work/dataset/test/',
 'train_txt': 'work/test_train.txt',
 'eval_txt': 'work/test_eval.txt',
 'label_dict': {'Benitoite': 2,
  'Tanzanite': 22,
  'Carnelian': 4,
  'Emerald': 8,
  'Danburite': 6,
  'Malachite': 16,
  'Hessonite': 11,
  'Garnet Red': 10,
  'Beryl Golden': 3,
  'Onyx Black': 17,
  'Variscite': 23,
  'Jade': 13,
  'Almandine': 1,
  'Cats Eye': 5,
  'Iolite': 12,
  'Fluorite': 9,
  'Zircon': 24,
  'Rhodochrosite': 20,
  'Labradorite': 15,
  'Pearl': 18,
  'Diamond': 7,
  'Alexandrite': 0,
  'Quartz Beer': 19,
  'Sapphire Blue': 21,
  'Kunzite': 14},
 'num_epochs': 10,
 'batch_size': 8,
 'learning_strategy': {'lr': 0.001},
 'gen_img_count': 30}

In [22]:
# Hold out 10% of the shuffled metadata as the validation split.
train_data, eval_data = train_test_split(df_data, random_state=42, test_size=0.1)
train_data.head(5)


Unnamed: 0,gem_path,gem_name,lbl
737,work\dataset\test\Tanzanite\tanzanite_32.jpg,Tanzanite,22
83,work\dataset\test\Benitoite\benitoite_26.jpg,Benitoite,2
503,work\dataset\test\Labradorite\labradorite_25.jpg,Labradorite,15
280,work\dataset\test\Emerald\emerald_29.jpg,Emerald,8
102,work\dataset\test\Beryl Golden\beryl golden_14...,Beryl Golden,3


### 将验证集文件移动到另一个目录

In [24]:
# Destination of each validation image: the same path with the top-level
# "dataset" directory renamed to "dataset_eval". Only the first occurrence is
# replaced, so a class/file name that happens to contain "dataset" is kept.
# assign() builds a new frame instead of writing into the split returned by
# train_test_split (which may otherwise trigger SettingWithCopyWarning).
eval_data = eval_data.assign(
    dest_path=eval_data['gem_path'].map(lambda p: str(p).replace('dataset', 'dataset_eval', 1))
)
eval_data.head(5)

Unnamed: 0,gem_path,gem_name,lbl,dest_path
33,work\dataset\test\Alexandrite\alexandrite_9.jpg,Alexandrite,0,work\dataset_eval\test\Alexandrite\alexandrite...
193,work\dataset\test\Cats Eye\cats eye_6.jpg,Cats Eye,5,work\dataset_eval\test\Cats Eye\cats eye_6.jpg
728,work\dataset\test\Tanzanite\tanzanite_23.jpg,Tanzanite,22,work\dataset_eval\test\Tanzanite\tanzanite_23.jpg
168,work\dataset\test\Cats Eye\cats eye_11.jpg,Cats Eye,5,work\dataset_eval\test\Cats Eye\cats eye_11.jpg
409,work\dataset\test\Iolite\iolite_24.jpg,Iolite,12,work\dataset_eval\test\Iolite\iolite_24.jpg


In [25]:
# Move the validation images out of the training tree so the augmentation step
# below only sees training images, then write eval.txt for the moved images.
def move_eval_imgs():
    '''
    Move every validation image from its ``gem_path`` to its ``dest_path``.

    Rows whose source is gone but whose destination exists (moved by an
    earlier run) are skipped, so the cell survives Restart & Run All.
    '''
    for src_path, dest_path in eval_data[['gem_path', 'dest_path']].itertuples(index=False):
        dest = Path(dest_path)
        if dest.exists() and not Path(src_path).exists():
            continue  # already moved on a previous run
        dest.parent.mkdir(parents=True, exist_ok=True)
        shutil.move(src_path, dest_path)


def gen_eval_txt(eval_txt_path):
    '''
    Write one "<dest_path><TAB><label>" line per validation image to ``eval_txt_path``.
    '''
    with open(eval_txt_path, 'w') as f:
        for dest_path, label in eval_data[['dest_path', 'lbl']].itertuples(index=False):
            f.write('{}\t{}\n'.format(dest_path, label))


move_eval_imgs()
gen_eval_txt(train_params['eval_txt'])
    

### reset_index()参数说明：
```
drop : bool, default False
Do not try to insert index into dataframe columns. This resets the index to the default integer index.

inplace : bool, default False
Modify the DataFrame in place (do not create a new object).
```

In [26]:
# Renumber both splits 0..n-1, discarding the old shuffled index.
train_data = train_data.reset_index(drop=True)
eval_data = eval_data.reset_index(drop=True)

In [27]:
# Preview the re-indexed validation split.
eval_data.head(10)

Unnamed: 0,gem_path,gem_name,lbl,dest_path
0,work\dataset\test\Alexandrite\alexandrite_9.jpg,Alexandrite,0,work\dataset_eval\test\Alexandrite\alexandrite...
1,work\dataset\test\Cats Eye\cats eye_6.jpg,Cats Eye,5,work\dataset_eval\test\Cats Eye\cats eye_6.jpg
2,work\dataset\test\Tanzanite\tanzanite_23.jpg,Tanzanite,22,work\dataset_eval\test\Tanzanite\tanzanite_23.jpg
3,work\dataset\test\Cats Eye\cats eye_11.jpg,Cats Eye,5,work\dataset_eval\test\Cats Eye\cats eye_11.jpg
4,work\dataset\test\Iolite\iolite_24.jpg,Iolite,12,work\dataset_eval\test\Iolite\iolite_24.jpg
5,work\dataset\test\Carnelian\carnelian_35.jpg,Carnelian,4,work\dataset_eval\test\Carnelian\carnelian_35.jpg
6,work\dataset\test\Jade\jade_5.jpg,Jade,13,work\dataset_eval\test\Jade\jade_5.jpg
7,work\dataset\test\Rhodochrosite\rhodochrosite_...,Rhodochrosite,20,work\dataset_eval\test\Rhodochrosite\rhodochro...
8,work\dataset\test\Fluorite\fluorite_21.jpg,Fluorite,9,work\dataset_eval\test\Fluorite\fluorite_21.jpg
9,work\dataset\test\Labradorite\labradorite_17.jpg,Labradorite,15,work\dataset_eval\test\Labradorite\labradorite...


### 数据增强

In [28]:
# Re-save non-RGB images (e.g. the original 4-channel files) as RGB, in place.
def proc_img(src):
    '''
    Convert every image under the directory tree ``src`` to RGB, in place.

    Images already in RGB mode are left untouched. Every file found by
    ``os.walk`` is opened with PIL, so the tree is assumed to contain only
    images -- TODO confirm.
    '''
    for root, _dirs, files in os.walk(src):
        for file in files:
            img_path = os.path.join(root, file)
            with Image.open(img_path) as img:
                if img.mode == 'RGB':
                    continue
                rgb = img.convert('RGB')
            # Save only after the source handle is closed so the file can be overwritten.
            rgb.save(img_path)


if __name__ == '__main__':
    proc_img(train_params['target_path'])

In [30]:
import os, Augmentor
import shutil, glob

augment_path = train_params['augment_path']
gen_img_count = train_params['gen_img_count']
img_root = train_params['target_path']


def aug():
    '''
    Augment each class folder under ``img_root`` and collect the results in
    ``augment_path/<class name>``.

    Does nothing when ``augment_path`` already exists, so augmentation is not
    repeated on re-runs. Augmentor writes into an ``output`` sub-folder of each
    class folder:
    - ``sample`` generates the images still missing to reach ``gen_img_count``;
    - ``process`` pushes every original image through the pipeline once more.
    '''
    if os.path.exists(augment_path):  # guard against augmenting twice
        return
    for root, dirs, files in os.walk(img_root, topdown=False):
        for name in dirs:
            path_ = os.path.join(root, name)
            if '__MACOSX' in path_:
                continue
            print('数据增强：', path_)
            print('image：', path_)

            p = Augmentor.Pipeline(path_, output_directory='output')
            p.rotate(probability=0.6, max_left_rotation=2, max_right_rotation=2)
            p.zoom(probability=0.6, min_factor=0.9, max_factor=1.1)
            p.random_distortion(probability=0.4, grid_height=2, grid_width=2, magnitude=1)
            p.flip_left_right(probability=0.3)
            p.flip_top_bottom(probability=0.3)
            p.crop_random(probability=0.3, percentage_area=0.8)
            p.greyscale(probability=0.2)
            p.random_brightness(probability=0.2, min_factor=0.8, max_factor=1.2)

            # Extra images needed for this class to reach gen_img_count.
            count = gen_img_count - len(glob.glob(pathname=path_ + '/*.jpg'))
            # Augmentor's sample(0) appears to mean "every image once" (verify
            # against the installed version) and a negative count is meaningless,
            # so only sample when images are actually missing.
            if count > 0:
                p.sample(count, multi_threaded=False)
            p.process()

    print('将生成的图片拷贝到正确的目录')
    os.makedirs(augment_path, exist_ok=True)
    for dir_ in Path(img_root).iterdir():
        src_path = dir_ / 'output'
        # Skip stray files and folders Augmentor did not write to (e.g. __MACOSX).
        if not src_path.is_dir():
            continue
        shutil.move(str(src_path), os.path.join(augment_path, dir_.name))
    print('完成数据增强')


aug()

### 得到增强后的训练集

In [31]:
# Class folders produced by the augmentation step, in lexicographic order ('0', '1', '10', ...).
targetPath_aug = Path(train_params['augment_path'])
class_dirs_aug = list(targetPath_aug.glob("*"))
class_dirs_aug.sort()
class_dirs_aug[:10]

[WindowsPath('augment/0'),
 WindowsPath('augment/1'),
 WindowsPath('augment/10'),
 WindowsPath('augment/11'),
 WindowsPath('augment/12'),
 WindowsPath('augment/13'),
 WindowsPath('augment/14'),
 WindowsPath('augment/15'),
 WindowsPath('augment/16'),
 WindowsPath('augment/17')]

In [32]:
# Rebuild the training metadata from the augmented folders as
# (path, file name, folder name, label) tuples.
# NOTE(review): the label is the folder's index in sorted order, not a lookup in
# train_params['label_dict']; it only matches the eval labels if the augmented
# folders mirror the original class folders one-to-one -- confirm.
lst_data = [
    (str(img_path), img_path.name, img_path.parent.name, label)
    for label, class_dir in enumerate(class_dirs_aug)
    for img_path in class_dir.glob("*.jpg")
]
lst_data[0]

('augment\\0\\0_original_img_10.jpg_8b847718-a221-4182-b5c5-615e8a6a62f3.jpg',
 '0_original_img_10.jpg_8b847718-a221-4182-b5c5-615e8a6a62f3.jpg',
 '0',
 0)

In [33]:
# Build the augmented training DataFrame and shuffle it with the fixed seed.
train_data = (
    pd.DataFrame(lst_data, columns=['gem_path', 'img', 'gem_name', 'lbl'])
    .sample(frac=1, replace=False, random_state=SEED)
)
train_data.head(10)

Unnamed: 0,gem_path,img,gem_name,lbl
6876,augment\23\23_original_img_11327.jpg_09b5c54a-...,23_original_img_11327.jpg_09b5c54a-8770-4e14-b...,23,16
9379,augment\29\29_original_img_14543.jpg_a504702e-...,29_original_img_14543.jpg_a504702e-838f-4fae-9...,29,22
6133,augment\21\21_original_img_10449.jpg_7b258511-...,21_original_img_10449.jpg_7b258511-5a76-43d6-9...,21,14
1051,augment\10\10_original_img_4562.jpg_5ecc752f-9...,10_original_img_4562.jpg_5ecc752f-9ec1-4fc7-aa...,10,2
10627,augment\31\31_original_img_15609.jpg_4cd04177-...,31_original_img_15609.jpg_4cd04177-66f8-42e0-b...,31,25
2946,augment\14\14_original_img_7102.jpg_93b01d26-7...,14_original_img_7102.jpg_93b01d26-766e-4282-a5...,14,6
6931,augment\23\23_original_img_11386.jpg_4dcfc25e-...,23_original_img_11386.jpg_4dcfc25e-1c8a-4201-8...,23,16
5460,augment\20\20_original_img_9606.jpg_5b346e52-2...,20_original_img_9606.jpg_5b346e52-2d74-46bc-91...,20,13
10579,augment\31\31_original_img_15530.jpg_5862fb9e-...,31_original_img_15530.jpg_5862fb9e-2a50-4fc0-a...,31,25
9356,augment\29\29_original_img_14513.jpg_f8077b73-...,29_original_img_14513.jpg_f8077b73-4673-46df-b...,29,22


#### 重新生成train.txt

In [34]:
# Regenerate train.txt from the augmented training frame.
def gen_train_txt(train_txt_path):
    '''
    Write one "<image path><TAB><label>" line per row of ``train_data``.
    '''
    with open(train_txt_path, 'w') as f:
        rows = train_data[['gem_path', 'lbl']].itertuples(index=False)
        f.writelines('{}\t{}\n'.format(path, lbl) for path, lbl in rows)


gen_train_txt(train_params['train_txt'])


### 保存参数

In [35]:
# Persist both splits and the parameters so later cells can resume from disk.
train_data.to_json('pandas_train.json')
eval_data.to_json('pandas_eval.json')
with open(train_params_path, 'w') as f:
    json.dump(train_params, f)
    

In [6]:
# Reload the persisted splits and parameters.
train_data = pd.read_json('pandas_train.json')
eval_data = pd.read_json('pandas_eval.json')
with open(train_params_path, 'r') as f:
    train_params = json.load(f)
print(train_params)

{'input_size': [3, 64, 64], 'class_dim': 25, 'augment_path': 'augment/', 'src_path': 'data/data55032/archive_train.zip', 'target_path': 'work/dataset/test', 'train_txt': 'work/test_train.txt', 'eval_txt': 'work/test_eval.txt', 'label_dict': {'Benitoite': 2, 'Tanzanite': 22, 'Carnelian': 4, 'Emerald': 8, 'Danburite': 6, 'Malachite': 16, 'Hessonite': 11, 'Garnet Red': 10, 'Beryl Golden': 3, 'Onyx Black': 17, 'Variscite': 23, 'Jade': 13, 'Almandine': 1, 'Cats Eye': 5, 'Iolite': 12, 'Fluorite': 9, 'Zircon': 24, 'Rhodochrosite': 20, 'Labradorite': 15, 'Pearl': 18, 'Diamond': 7, 'Alexandrite': 0, 'Quartz Beer': 19, 'Sapphire Blue': 21, 'Kunzite': 14}, 'num_epochs': 10, 'batch_size': 8, 'learning_strategy': {'lr': 0.001}, 'gen_img_count': 30}


In [36]:
# Summary: number of classes (largest label + 1), per-class counts, split sizes.
print("== 数据集总体情况:总类别数", train_data['lbl'].max() + 1)
print("== 训练集不同类别的样本数：")
print(train_data['gem_name'].value_counts())
print(f"== 训练集样本数：{len(train_data)}", f"验证集样本数：{len(eval_data)}")

== 数据集总体情况:总类别数 40
== 训练集不同类别的样本数：
11    635
21    566
25    472
27    460
6     400
26    400
13    400
38    400
36    400
8     400
7     400
31    400
4     400
35    400
22    400
28    400
39    400
20    400
10    400
16    400
29    400
30    400
2     400
3     400
37    400
14    400
32    400
15    400
24    400
5     400
34    400
19    400
1     400
18    400
23    400
17    400
0     400
33    400
9     400
12    400
Name: gem_name, dtype: int64
== 训练集样本数：16533 验证集样本数：82


In [37]:
# Build the batched data providers; drop_last=True discards a trailing partial batch.
train_reader, eval_reader = (
    paddle.batch(data_reader(frame), batch_size=train_params['batch_size'], drop_last=True)
    for frame in (train_data, eval_data)
)

# 用 df（DataFrame）构建训练集
下面的代码不是完整代码（仅为片段）。

In [None]:
train_parameters = dict(
    input_size=[3, 416, 416],                                # input image size
    class_dim=-1,                                            # number of classes; -1 is a placeholder
    data_path="/home/aistudio/data/data101839/data.zip",     # raw dataset archive
    target_path="/home/aistudio/data",                       # directory where files are saved
    train_list_path="/home/aistudio/data/train.txt",         # train.txt path
    val_list_path="/home/aistudio/data/val.txt",             # val.txt path
    test_list_path="/home/aistudio/data/test.txt",           # test.txt path
    readme_path="/home/aistudio/data/readme.json",           # readme.json path
    label_dict={},                                           # label dictionary
    num_epochs=5,                                            # number of training epochs
    batch_size=16,                                           # samples per training batch
    learning_strategy={"lr": 0.0005},                        # optimizer settings: learning rate
    skip_steps=5,                                            # print results every N batches
    save_steps=100,                                          # save model parameters every N batches
    checkpoints="/home/aistudio/work/logs",                  # checkpoint directory
)

In [None]:
from pathlib import Path
import pandas as pd
# Each sub-directory of <target_path>/data is one class; its index in the
# sorted directory list is its integer label. (The original referenced an
# undefined `target_path` name.)
root = Path(train_parameters['target_path'], 'data')
data_class_dirs = sorted(p for p in root.glob('*') if p.is_dir())
img_data = []
for label, class_dir in enumerate(data_class_dirs):
    print(class_dir)
    img_paths = list(class_dir.glob("*.jpg"))
    img_data.extend((str(p), class_dir.name, label) for p in img_paths)

# Shuffle once, explicitly seeded with SEED, so the split below is reproducible.
df = pd.DataFrame(img_data, columns=['img', 'dir', 'label']).sample(frac=1, random_state=SEED)
print(df.head())
dic = df[['label', 'dir']].drop_duplicates()
print(f'dic===:\n{dic}')
print(dic.describe())
# int() turns numpy integers into Python ints so the dict is JSON-serialisable.
train_parameters['label_dict'] = {name: int(lbl) for lbl, name in dic.to_records(index=False)}
# Replace the -1 placeholder with the real number of classes.
train_parameters['class_dim'] = len(data_class_dirs)
print(train_parameters['label_dict'])

In [None]:
# 80 / 10 / 10 split of the shuffled frame into train / eval / test.
num_imgs = len(df)
num_train = int(0.8 * num_imgs)
num_eval = int(0.1 * num_imgs)
num_test = num_imgs - num_train - num_eval

train_data = df.iloc[:num_train].reset_index(drop=True)
eval_data = df.iloc[num_train:num_train + num_eval].reset_index(drop=True)
test_data = df.iloc[num_train + num_eval:].reset_index(drop=True)  # the last num_test rows

print(f'{num_train} {num_eval} {num_test}')
train_data.to_json('train.json')
eval_data.to_json('eval.json')
test_data.to_json('test.json')


In [None]:
# Sanity check: path of the first training image.
print(train_data['img'].iloc[0])

In [None]:
class DatasetTask(Dataset):
    """
    Map-style dataset over a DataFrame with an ``img`` (image path) column and
    a ``label`` (class id) column.

    NOTE(review): ``Dataset`` is not imported in any visible cell; presumably
    ``paddle.io.Dataset`` -- confirm it is imported before this cell runs.
    """

    def __init__(self, train_parameters, _df):
        """
        Store the configuration and the metadata frame.

        params:
                train_parameters: parameter dict; ``target_path`` and
                    ``input_size`` (laid out as [C, H, W]) are read from it
                _df: DataFrame with ``img`` and ``label`` columns
        """
        super(DatasetTask, self).__init__()
        self.target_path = train_parameters['target_path']  # stored; not used by this class
        self.input_size = train_parameters['input_size']
        self._df = _df

    def __getitem__(self, index):
        """
        Load one sample.

        params:
                index: positional row index into the DataFrame
        returns:
                img: float32 array (3, input_size[1], input_size[2]) scaled to [0, 1]
                label: int64 array of shape (1,)
        """
        row = self._df.iloc[index]
        with Image.open(row['img']) as src:
            # PIL sizes are (width, height) while input_size is [C, H, W].
            # LANCZOS is the filter formerly exposed as ANTIALIAS (removed in Pillow 10).
            img = src.resize((self.input_size[2], self.input_size[1]), Image.LANCZOS)
        if img.mode != 'RGB':
            img = img.convert('RGB')
        img = np.array(img).astype('float32')
        img = img.transpose((2, 0, 1)) / 255  # HWC -> CHW, scaled to [0, 1]
        label = np.array([row['label']], dtype="int64")
        return img, label

    def print_sample(self, index: int = 0):
        """ Print the file name and label of sample ``index``. """
        print("文件名", self._df.iloc[index]['img'], "\t标签值", self._df.iloc[index]['label'])

    def __len__(self):
        """ Number of rows in the metadata frame. """
        return len(self._df)