# 数据预处理

In [1]:
import os
import random
data_path=r'plant-seedlings-classification'
train_path=os.path.join(data_path,'train')
test_path=os.path.join(data_path,'test')

In [2]:
train_path,test_path

('plant-seedlings-classification\\train',
 'plant-seedlings-classification\\test')

In [3]:
def get_data_list(path,train_list_path,eval_list_path):
    
    print("path=%s"%path)
    img_classes=os.listdir(path)

    train_list=[]
    eval_list=[]

    label=0
    cnt=0
    for img_class in img_classes:
        label_dict[str(label)]=img_class
        print("{}:{}".format(label,img_class))

        img_class_path=os.path.join(path,img_class)
        imgs=os.listdir(img_class_path)
        for img in imgs:
            img_path=os.path.join(img_class_path,img)
            cnt+=1
            if cnt%10==0:
                eval_list.append("%s\t%s\n"%(img_path,str(label)))
                print("%s\t%s"%(img_path,str(label)))
            else:
                train_list.append("%s\t%s\n"%(img_path,str(label)))
                print("%s\t%s\n"%(img_path,str(label)))
        label+=1
    
    random.shuffle(train_list)
    with open(train_list_path,'w') as f:
        for line in train_list:
            f.write(line)
    
    with open(eval_list_path,'w') as f:
        for line in eval_list:
            f.write(line)

In [4]:
label_dict={}
get_data_list(train_path,"train_list.txt","eval_list.txt")

path=plant-seedlings-classification\train
0:Black-grass
plant-seedlings-classification\train\Black-grass\0050f38b3.png	0

plant-seedlings-classification\train\Black-grass\0183fdf68.png	0

plant-seedlings-classification\train\Black-grass\0260cffa8.png	0

plant-seedlings-classification\train\Black-grass\05eedce4d.png	0

plant-seedlings-classification\train\Black-grass\075d004bc.png	0

plant-seedlings-classification\train\Black-grass\078eae073.png	0

plant-seedlings-classification\train\Black-grass\082314602.png	0

plant-seedlings-classification\train\Black-grass\0ace21089.png	0

plant-seedlings-classification\train\Black-grass\0b228a6b8.png	0

plant-seedlings-classification\train\Black-grass\0b3e7a7a9.png	0
plant-seedlings-classification\train\Black-grass\0bb75ded8.png	0

plant-seedlings-classification\train\Black-grass\0be707615.png	0

plant-seedlings-classification\train\Black-grass\0c67c3fc3.png	0

plant-seedlings-classification\train\Black-grass\0d1a9985f.png	0

plant-seedlings-class

In [5]:
label_dict

{'0': 'Black-grass',
 '1': 'Charlock',
 '2': 'Cleavers',
 '3': 'Common Chickweed',
 '4': 'Common wheat',
 '5': 'Fat Hen',
 '6': 'Loose Silky-bent',
 '7': 'Maize',
 '8': 'Scentless Mayweed',
 '9': 'Shepherds Purse',
 '10': 'Small-flowered Cranesbill',
 '11': 'Sugar beet'}

# 定义Dataset

In [6]:
from PIL import Image 
import torch
import numpy as np
class Reader(torch.utils.data.Dataset):
    def __init__(self,path):
        super().__init__()
        
        self.img_paths=[]
        self.labels=[]
        
        with open(path,'r',encoding='utf-8') as f:
            for line in f.readlines():
                self.img_paths.append(line.strip().split('\t')[0])
                self.labels.append(int(line.strip().split('\t')[1]))

    def __getitem__(self,index):

        img_path=self.img_paths[index]
        label=self.labels[index]
        
        img=Image.open(img_path)
        if img.mode !='RGB':
            img=img.convert('RGB')
        img=img.resize((224,224),Image.BILINEAR)
        img=np.array(img).astype('float32')
        img=img.transpose((2,0,1))
        img/=255.0
        
        label=np.array([label],dtype='int64')
        return img,label
    
    def pt(self,index):
        print("路径:{}\t 标签值:{}".format(self.img_paths[index],self.labels[index]))

    def __len__(self):
        return len(self.img_paths)

  from .autonotebook import tqdm as notebook_tqdm


In [7]:
batch_size=8

In [8]:
train_dataset=Reader("train_list.txt")
train_loader=torch.utils.data.DataLoader(train_dataset,batch_size=batch_size,shuffle=True)


eval_dataset=Reader("eval_list.txt")
eval_loader=torch.utils.data.DataLoader(eval_dataset,batch_size=batch_size,shuffle=False)

In [9]:
train_dataset.pt(100)
print("训练集样本数:%i"%train_dataset.__len__())
print("测试集样本数:%i"%eval_dataset.__len__())
eval_dataset.pt(0)

路径:plant-seedlings-classification\train\Cleavers\bd4f2a692.png	 标签值:2
训练集样本数:4275
测试集样本数:475
路径:plant-seedlings-classification\train\Black-grass\0b3e7a7a9.png	 标签值:0


# 配置模型

In [10]:
class ConvPool(torch.nn.Module):
    def __init__(self,num_channels,num_filters,filter_size,pool_size,pool_stride,groups,conv_stride=1,conv_padding=1):
        super(ConvPool,self).__init__()

        for i in range(groups):
            self.add_module(
                "bb_%d"%i,
                torch.nn.Conv2d(
                    in_channels=num_channels,
                    out_channels=num_filters,
                    kernel_size=filter_size,
                    stride=conv_stride,
                    padding=conv_padding
                )
            )
            self.add_module(
                'relu%d'%i,
                torch.nn.LeakyReLU(0.0001)
            )
            num_channels=num_filters # 循环中下次的输入通道和上次的输出通道数相同
        
        self.add_module(
            'Maxpool',
            torch.nn.MaxPool2d(
                kernel_size=pool_size,
                stride=pool_stride
            )
        )
    def forward(self,inputs):
        x=inputs
        for prefix,sub_layer in self.named_children():
            x=sub_layer(x)
        return x

In [11]:

class VGGNet(torch.nn.Module):
    def __init__(self):
        super(VGGNet,self).__init__()
        self.convpool01=ConvPool(3,64,3,2,2,2)
        self.convpool02=ConvPool(64,128,3,2,2,2)
        self.convpool03=ConvPool(128,256,3,2,2,3)
        self.convpool04=ConvPool(256,512,3,2,2,3)
        self.convpool05=ConvPool(512,512,3,2,2,3)

        self.fc01=torch.nn.Linear(512*7*7,12)
#         self.fc02=torch.nn.Linear(4096,4096)
#         self.fc03=torch.nn.Linear(4096,12)
    
    def forward(self,inputs):
        # print('input_size:{}'.format(inputs.shape))
        
        x=self.convpool01(inputs)
#         print(self.convpool01.state_dict())
        x=self.convpool02(x)
        x=self.convpool03(x)
        
        
        
        x=self.convpool04(x)
        x=self.convpool05(x)
        
        x=torch.reshape(x,[-1,512*7*7])
        x=self.fc01(x)
#         x=self.fc02(x)
#         y=self.fc03(x)
        return x

In [12]:
class test_model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc0=torch.nn.Linear(3*224*224,1024)
        self.fc1=torch.nn.Linear(1024,256)
        self.fc2=torch.nn.Linear(256,64)
        self.fc3=torch.nn.Linear(64,12)
    
    def forward(self,x):
        x=self.fc0(torch.reshape(x,[-1,3*224*224]))
        x=self.fc1(x)
        x=self.fc2(x)
        x=self.fc3(x)
        return x

In [13]:
# model=VGGNet()
model=test_model()


In [19]:
def fit():
    device = torch.device('cuda')
    model.to(device)
    model.train()

    cross_entropy=torch.nn.CrossEntropyLoss()
    opt=torch.optim.Adam(params=model.parameters(),lr=0.00005)

    steps=0
    Iters,losses,accs=[],[],[]
    maxacc=0
    for epo in range(50):
        for batch_id,data in enumerate(train_loader):
            opt.zero_grad()
            x=data[0].to(device)
            y=data[1].to(device)
            pred=model(x)
#             print(pred.shape,y.shape)
            loss=cross_entropy(pred,y.squeeze())
            res = torch.argmax(pred,1)
#             print(res,y.squeeze())
            acc = ((torch.eq(res,y.squeeze())).sum()).item()/res.shape[0]
#             acc=(pred==y).mean()
            loss.backward()
            opt.step()


            steps+=1
            if steps%10==0:
                Iters.append(steps)
                losses.append(loss.item())
                accs.append(acc)
                print("epoch:{},step:{},loss:{},acc:{}".format(epo,steps,loss.item(),acc))

            if steps%20==0:
                if maxacc<acc:
                    maxacc=acc
                    save_path='best.pd'
                    print('save to %s'%save_path)
                    torch.save(model.state_dict(),save_path)

        torch.save(model.state_dict(),"final.pt")

In [None]:
fit()

  img=img.resize((224,224),Image.BILINEAR)


epoch:0,step:10,loss:0.7246943712234497,acc:0.875
epoch:0,step:20,loss:0.6717216968536377,acc:0.625
save to best.pd
epoch:0,step:30,loss:1.3189449310302734,acc:0.5
epoch:0,step:40,loss:2.207777976989746,acc:0.5
epoch:0,step:50,loss:0.7840787172317505,acc:0.75
epoch:0,step:60,loss:0.4390544891357422,acc:0.875
save to best.pd
epoch:0,step:70,loss:1.0815703868865967,acc:0.625
epoch:0,step:80,loss:0.3363659679889679,acc:1.0
save to best.pd
epoch:0,step:90,loss:0.8209835290908813,acc:0.625
epoch:0,step:100,loss:0.5085561275482178,acc:0.75
epoch:0,step:110,loss:0.4212138056755066,acc:0.875
epoch:0,step:120,loss:0.4519316852092743,acc:0.75
epoch:0,step:130,loss:2.14052414894104,acc:0.5
epoch:0,step:140,loss:1.1353451013565063,acc:0.375
epoch:0,step:150,loss:1.4551613330841064,acc:0.5
epoch:0,step:160,loss:1.126947045326233,acc:0.5
epoch:0,step:170,loss:0.9341404438018799,acc:0.75
epoch:0,step:180,loss:0.9432311058044434,acc:0.5
epoch:0,step:190,loss:1.0590041875839233,acc:0.625
epoch:0,step:2

epoch:3,step:1630,loss:0.6640348434448242,acc:0.75
epoch:3,step:1640,loss:1.0299476385116577,acc:0.75
epoch:3,step:1650,loss:1.0055170059204102,acc:0.625
epoch:3,step:1660,loss:0.5094414353370667,acc:0.875
epoch:3,step:1670,loss:1.0461534261703491,acc:0.625
epoch:3,step:1680,loss:0.8393316864967346,acc:0.625
epoch:3,step:1690,loss:0.82033771276474,acc:0.75
epoch:3,step:1700,loss:0.9700087308883667,acc:0.75
epoch:3,step:1710,loss:0.3080141842365265,acc:1.0
epoch:3,step:1720,loss:0.3201158046722412,acc:1.0
epoch:3,step:1730,loss:0.6604888439178467,acc:0.75
epoch:3,step:1740,loss:0.8449487090110779,acc:0.75
epoch:3,step:1750,loss:1.3846219778060913,acc:0.5
epoch:3,step:1760,loss:0.9289695620536804,acc:0.75
epoch:3,step:1770,loss:0.4782201945781708,acc:0.875
epoch:3,step:1780,loss:1.7217713594436646,acc:0.5
epoch:3,step:1790,loss:0.9764373302459717,acc:0.625
epoch:3,step:1800,loss:1.723541498184204,acc:0.5
epoch:3,step:1810,loss:0.7130775451660156,acc:0.625
epoch:3,step:1820,loss:0.2374432

epoch:6,step:3230,loss:1.0071499347686768,acc:0.75
epoch:6,step:3240,loss:0.4999796748161316,acc:0.75
epoch:6,step:3250,loss:0.6065613627433777,acc:0.75
epoch:6,step:3260,loss:0.6871888637542725,acc:0.75
epoch:6,step:3270,loss:0.29320085048675537,acc:0.875
epoch:6,step:3280,loss:0.24876369535923004,acc:1.0
epoch:6,step:3290,loss:0.6298748254776001,acc:0.875
epoch:6,step:3300,loss:0.15671707689762115,acc:1.0
epoch:6,step:3310,loss:0.6281723976135254,acc:0.75
epoch:6,step:3320,loss:0.4764429032802582,acc:0.875
epoch:6,step:3330,loss:0.48408690094947815,acc:0.875
epoch:6,step:3340,loss:1.2753826379776,acc:0.5
epoch:6,step:3350,loss:0.31379374861717224,acc:0.875
epoch:6,step:3360,loss:0.5163087844848633,acc:0.75
epoch:6,step:3370,loss:0.558951199054718,acc:0.75
epoch:6,step:3380,loss:0.33534151315689087,acc:1.0
epoch:6,step:3390,loss:0.7394810914993286,acc:0.75
epoch:6,step:3400,loss:0.33501100540161133,acc:0.875
epoch:6,step:3410,loss:0.698578953742981,acc:0.875
epoch:6,step:3420,loss:0.3

epoch:9,step:4830,loss:0.8244940638542175,acc:0.5
epoch:9,step:4840,loss:0.6083436012268066,acc:0.625
epoch:9,step:4850,loss:0.46697232127189636,acc:0.875
epoch:9,step:4860,loss:0.1507672369480133,acc:1.0
epoch:9,step:4870,loss:0.4120815396308899,acc:0.875
epoch:9,step:4880,loss:0.9510124325752258,acc:0.625
epoch:9,step:4890,loss:0.503869891166687,acc:0.875
epoch:9,step:4900,loss:0.6702907681465149,acc:0.75
epoch:9,step:4910,loss:0.3475397527217865,acc:1.0
epoch:9,step:4920,loss:0.36337754130363464,acc:1.0
epoch:9,step:4930,loss:1.1955804824829102,acc:0.75
epoch:9,step:4940,loss:0.6211116313934326,acc:0.75
epoch:9,step:4950,loss:1.139604091644287,acc:0.625
epoch:9,step:4960,loss:0.3790247142314911,acc:0.875
epoch:9,step:4970,loss:0.5148234963417053,acc:0.75
epoch:9,step:4980,loss:0.3876262605190277,acc:0.875
epoch:9,step:4990,loss:0.5575256943702698,acc:0.75
epoch:9,step:5000,loss:0.4731884300708771,acc:0.75
epoch:9,step:5010,loss:0.996226966381073,acc:0.875
epoch:9,step:5020,loss:1.03

epoch:11,step:6410,loss:1.2512202262878418,acc:0.75
epoch:11,step:6420,loss:0.25441405177116394,acc:1.0
epoch:12,step:6430,loss:0.48599979281425476,acc:0.75
epoch:12,step:6440,loss:1.3787751197814941,acc:0.5
epoch:12,step:6450,loss:1.2157384157180786,acc:0.75
epoch:12,step:6460,loss:0.717886209487915,acc:0.75
epoch:12,step:6470,loss:0.5812557935714722,acc:0.875
epoch:12,step:6480,loss:0.74409419298172,acc:0.75
epoch:12,step:6490,loss:0.21691162884235382,acc:1.0
epoch:12,step:6500,loss:1.4738925695419312,acc:0.25
epoch:12,step:6510,loss:0.25331535935401917,acc:0.875
epoch:12,step:6520,loss:0.4743640422821045,acc:0.875
epoch:12,step:6530,loss:0.7019385099411011,acc:0.75
epoch:12,step:6540,loss:0.8759492039680481,acc:0.625
epoch:12,step:6550,loss:1.030710220336914,acc:0.625
epoch:12,step:6560,loss:0.9299812912940979,acc:0.625
epoch:12,step:6570,loss:1.116363525390625,acc:0.625
epoch:12,step:6580,loss:0.8264061808586121,acc:0.75
epoch:12,step:6590,loss:0.44842326641082764,acc:0.875
epoch:1

epoch:14,step:7980,loss:0.9912042617797852,acc:0.625
epoch:14,step:7990,loss:0.4288875460624695,acc:0.875
epoch:14,step:8000,loss:0.18543381989002228,acc:1.0
epoch:14,step:8010,loss:0.2418965995311737,acc:1.0
epoch:14,step:8020,loss:0.4767400920391083,acc:0.75
epoch:15,step:8030,loss:0.27721378207206726,acc:0.875
epoch:15,step:8040,loss:0.8370087742805481,acc:0.625
epoch:15,step:8050,loss:0.23141661286354065,acc:1.0
epoch:15,step:8060,loss:0.7672930359840393,acc:0.75
epoch:15,step:8070,loss:0.3563608229160309,acc:0.875
epoch:15,step:8080,loss:0.4858035147190094,acc:0.875
epoch:15,step:8090,loss:0.2861197888851166,acc:0.875
epoch:15,step:8100,loss:0.5599643588066101,acc:0.75
epoch:15,step:8110,loss:0.6313034296035767,acc:0.75
epoch:15,step:8120,loss:1.0548893213272095,acc:0.625
epoch:15,step:8130,loss:0.6215237379074097,acc:0.75
epoch:15,step:8140,loss:0.5788406133651733,acc:0.875
epoch:15,step:8150,loss:0.3419177532196045,acc:0.875
epoch:15,step:8160,loss:0.36977142095565796,acc:0.875


epoch:17,step:9550,loss:0.4159487187862396,acc:0.75
epoch:17,step:9560,loss:0.7216272950172424,acc:0.625
epoch:17,step:9570,loss:0.5368025898933411,acc:0.875
epoch:17,step:9580,loss:0.7198483943939209,acc:0.625
epoch:17,step:9590,loss:0.2303471863269806,acc:1.0
epoch:17,step:9600,loss:0.5153545141220093,acc:0.75
epoch:17,step:9610,loss:0.3792434334754944,acc:0.875
epoch:17,step:9620,loss:0.5927518010139465,acc:0.875
epoch:17,step:9630,loss:0.10375157743692398,acc:1.0
epoch:18,step:9640,loss:0.32037168741226196,acc:1.0
epoch:18,step:9650,loss:1.286491870880127,acc:0.625
epoch:18,step:9660,loss:0.09276081621646881,acc:1.0
epoch:18,step:9670,loss:0.7672497034072876,acc:0.75
epoch:18,step:9680,loss:0.18979883193969727,acc:1.0
epoch:18,step:9690,loss:0.5891258716583252,acc:0.625
epoch:18,step:9700,loss:0.1012631431221962,acc:1.0
epoch:18,step:9710,loss:0.538529098033905,acc:0.75
epoch:18,step:9720,loss:0.37066778540611267,acc:0.875
epoch:18,step:9730,loss:0.36092671751976013,acc:0.75
epoch:

epoch:20,step:11100,loss:0.509047269821167,acc:0.75
epoch:20,step:11110,loss:0.4483267068862915,acc:0.625
epoch:20,step:11120,loss:0.7114523649215698,acc:0.75
epoch:20,step:11130,loss:0.15093308687210083,acc:1.0
epoch:20,step:11140,loss:0.17132937908172607,acc:1.0
epoch:20,step:11150,loss:0.41949957609176636,acc:0.75
epoch:20,step:11160,loss:1.3206614255905151,acc:0.5
epoch:20,step:11170,loss:0.12228165566921234,acc:1.0
epoch:20,step:11180,loss:1.6354210376739502,acc:0.5
epoch:20,step:11190,loss:0.6297823786735535,acc:0.75
epoch:20,step:11200,loss:0.2884833812713623,acc:0.875
epoch:20,step:11210,loss:0.9723929166793823,acc:0.75
epoch:20,step:11220,loss:0.17527860403060913,acc:1.0
epoch:20,step:11230,loss:1.115792155265808,acc:0.75
epoch:21,step:11240,loss:0.177921861410141,acc:1.0
epoch:21,step:11250,loss:0.6973922252655029,acc:0.75
epoch:21,step:11260,loss:0.23759835958480835,acc:1.0
epoch:21,step:11270,loss:0.33275970816612244,acc:0.875
epoch:21,step:11280,loss:0.42670249938964844,ac