In [49]:
import torch
import numpy as np
import torchvision
import os
import random
import torchvision
from torchvision import transforms
from skimage.transform import resize
from PIL import Image
from torch.utils.tensorboard import SummaryWriter

In [2]:
# Root directory of the training images; every entry directly below it is treated as one class folder.
file_path = 'train'

In [3]:
# Build two structures from the directory tree under `file_path`:
#   mapping   -- {label: class directory name}
#   all_files -- flat list of [image_path, label] pairs
# os.listdir returns entries in arbitrary order, so the class directories (and the files
# inside them) are sorted to make the label assignment reproducible between runs and machines.
# Non-directory entries at the top level are skipped so that a stray file cannot break
# the inner listing or consume a label.
mapping = {}
all_files = []
class_dirs = sorted(
    entry for entry in os.listdir(file_path)
    if os.path.isdir(file_path + '/' + entry)
)
for label, directory in enumerate(class_dirs):
    mapping[label] = directory
    class_dir = file_path + '/' + directory
    all_files.extend(
        [class_dir + '/' + file, label] for file in sorted(os.listdir(class_dir))
    )

In [4]:
# Contents of the mapping dict: {0: 'Black-grass', 1: 'Charlock', 2: 'Cleavers', 3: 'Common Chickweed', 4: 'Common wheat', 5: 'Fat Hen', 6: 'Loose Silky-bent', 7: 'Maize', 8: 'Scentless Mayweed', 9: 'Shepherds Purse', 10: 'Small-flowered Cranesbill', 11: 'Sugar beet'}
# mapping

In [5]:
# Contents of the all_files list: [['train/Black-grass/0050f38b3.png', 0], ['train/Black-grass/0183fdf68.png', 0], ['train/Black-grass/0260cffa8.png', 0], ...]
# all_files

In [6]:
# Shuffle all_files in place. A dedicated, fixed-seed generator is used so the resulting order
# (and therefore any split taken from it) is the same on every run, without touching the
# global `random` state.
random.Random(42).shuffle(all_files)

In [7]:
# all_files after shuffling: [['train/Black-grass/d622ca3d2.png', 0], ['train/Maize/a5c2eec2d.png', 7], ['train/Scentless Mayweed/4205830a0.png', 8], ['train/Loose Silky-bent/b3f997421.png', 6], ['train/Black-grass/0183fdf68.png', 0], ['train/Shepherds Purse/eae41be4f.png', 9], ...]
# all_files

In [9]:
resize_target = 200  # input images differ in size, so every image is resized to 200x200
batch_size = 32

# Augmentation pipeline applied to a PIL image: fixed-size resize, colour jitter,
# random flips and rotation, then conversion to a tensor.
# The transforms are referenced through the public `transforms.<Name>` API rather than the
# `transforms.transforms` implementation submodule.
data_transform = transforms.Compose([
    transforms.Resize((resize_target, resize_target)),
    transforms.ColorJitter(brightness=1, contrast=0.5, saturation=0.5, hue=0),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(180),
    transforms.RandomVerticalFlip(),
    transforms.ToTensor(),
])

In [10]:
def adjust_colors(img, brightness=2, contrast=1.1, saturation=1.1):
    """Apply a fixed (non-random) colour enhancement to an image.

    The three adjustments are applied in order: brightness, contrast, saturation.

    Args:
        img: image accepted by ``transforms.functional.adjust_*``.
        brightness: factor passed to ``adjust_brightness`` (default 2).
        contrast: factor passed to ``adjust_contrast`` (default 1.1).
        saturation: factor passed to ``adjust_saturation`` (default 1.1).

    Returns:
        The adjusted image, as returned by ``adjust_saturation``.
    """
    img = transforms.functional.adjust_brightness(img, brightness)
    img = transforms.functional.adjust_contrast(img, contrast)
    img = transforms.functional.adjust_saturation(img, saturation)
    return img


class dataset(torch.utils.data.Dataset):
    """Map-style dataset over a list of ``[image_path, label]`` pairs.

    Each item is loaded from disk, converted to RGB, colour-adjusted with
    ``adjust_colors`` and then either passed through ``transform`` or, when no
    transform is given, resized to ``resize_target`` x ``resize_target`` and
    converted to a tensor.

    Args:
        file_list: sequence whose elements are indexable as ``[path, label]``.
        segment: when truthy, ``segment_plant`` is applied to the RGB image first.
        transform: callable applied to the image, or a falsy value to use the
            default resize + to_tensor fallback.
    """
    def __init__(self, file_list, segment = False, transform = False):
        self.file_list = file_list
        self.transform = transform
        self.segment = segment

    def __len__(self):
        """Return the number of ``[path, label]`` entries."""
        return len(self.file_list)

    def __getitem__(self, idx):
        """Load and preprocess the sample at ``idx``; returns ``(img, label)``."""
        entry = self.file_list[idx]
        path, label = entry[0], entry[1]
        # convert() returns an independent, fully loaded copy, so the source file can be
        # closed deterministically as soon as the conversion is done.
        with Image.open(path) as raw:
            img = raw.convert('RGB')
        if self.segment:
            # NOTE(review): segment_plant is not defined in the visible part of this notebook;
            # segment=True raises NameError unless it is provided elsewhere -- confirm.
            img = segment_plant(img)
        img = adjust_colors(img)
        if self.transform:
            img = self.transform(img)
        else:
            # No transform supplied: deterministic resize followed by tensor conversion.
            img = transforms.functional.resize(img, (resize_target, resize_target))
            img = transforms.functional.to_tensor(img)
        return img, label

In [11]:
# The instance is bound to the name `dataset`, which shadows the class of the same name.
# Resolve the class from whatever `dataset` currently refers to (the class on the first run,
# an instance afterwards) so that re-running this cell does not fail with
# "'dataset' object is not callable".
_dataset_cls = dataset if isinstance(dataset, type) else type(dataset)
dataset = _dataset_cls(all_files, segment = False, transform = data_transform)

In [12]:
# Display the total number of samples in the dataset
len(dataset)

4750

In [15]:
# 检查dataset
for idx,(img,label) in enumerate(dataset):
    print('idx=',idx)
#     print(img)
    print(img.shape)
    print('label=',label)
    if idx == 2:
        break

idx= 0
torch.Size([3, 200, 200])
label= 3
idx= 1
torch.Size([3, 200, 200])
label= 1
idx= 2
torch.Size([3, 200, 200])
label= 9
idx= 3
torch.Size([3, 200, 200])
label= 9
idx= 4
torch.Size([3, 200, 200])
label= 3
idx= 5
torch.Size([3, 200, 200])
label= 8


In [16]:
# Hold out the first 20% of the indices for validation; the remainder is used for training.
val_fraction = 0.2
split = int(np.floor(val_fraction * len(dataset)))
indices = list(range(len(dataset)))
train_idx, valid_idx = indices[split:], indices[:split]
train_sampler_random = torch.utils.data.SubsetRandomSampler(train_idx)
valid_sampler_random = torch.utils.data.SubsetRandomSampler(valid_idx)

# Validation must not go through the random augmentation attached to `dataset`, otherwise the
# validation metrics are noisy and not comparable between epochs. Build a second dataset over
# the same file list with transform=False, which makes __getitem__ fall back to the plain
# resize + to_tensor path. The samplers keep the two index sets disjoint.
valid_dataset = type(dataset)(dataset.file_list, segment=dataset.segment, transform=False)

train_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, sampler=train_sampler_random)
val_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=batch_size, sampler=valid_sampler_random)

In [18]:
# 检查train_loader
for batch_idx,(img,label) in enumerate(train_loader):
    print('batch_idx=',batch_idx)
#     print(img)
    print(img.shape)
    print('label=',label)
    if batch_idx == 2:
        break

batch_idx= 0
torch.Size([32, 3, 200, 200])
label= tensor([ 8,  3,  4,  6, 10, 10,  8,  6,  3,  4,  5, 10,  8,  5,  2,  9,  3,  5,
        10,  0, 10,  6, 10,  2,  6,  5,  6, 11, 10,  9,  9,  2])
batch_idx= 1
torch.Size([32, 3, 200, 200])
label= tensor([ 3,  8,  6,  9,  3, 10,  7,  2,  3,  5, 10,  0,  3, 10,  9,  8,  5,  3,
         3,  2, 11,  1,  9,  7,  3,  1,  3,  8,  1,  6,  5,  2])
batch_idx= 2
torch.Size([32, 3, 200, 200])
label= tensor([ 0,  9,  3,  6, 10,  4, 10,  5,  8,  8,  4,  2,  6,  8,  9,  8,  0,  6,
         5, 11,  6,  4,  0,  2,  6,  9, 10,  3, 11,  6,  6,  2])


In [19]:
# 检查val_loader
for batch_idx,(img,label) in enumerate(val_loader):
    print('batch_idx=',batch_idx)
#     print(img)
    print(img.shape)
    print('label=',label)
    if batch_idx == 2:
        break

batch_idx= 0
torch.Size([32, 3, 200, 200])
label= tensor([10,  2, 11,  4,  2,  4, 11,  3,  3,  4, 11, 10,  7,  9,  6,  2,  8,  2,
         5,  9, 10,  8, 11, 10, 10,  4,  5,  1,  8,  9,  8,  4])
batch_idx= 1
torch.Size([32, 3, 200, 200])
label= tensor([ 5, 10,  6,  8,  1,  4, 10,  3,  8,  6,  6,  1,  3,  2,  1,  5,  8,  8,
         8,  3,  4, 11,  6,  3,  0,  5,  4,  4, 10,  6,  5, 10])
batch_idx= 2
torch.Size([32, 3, 200, 200])
label= tensor([11,  3,  0, 11, 11,  2,  3,  8, 10,  7,  6, 11,  1,  5, 10,  6,  5,  5,
         5, 11,  6,  7,  7,  2,  7,  5,  1,  8,  6, 10,  1,  8])


In [50]:
# Display the number of batches the training loader yields per epoch
len(train_loader)

119

In [51]:
# Display the number of batches the validation loader yields per epoch
len(val_loader)

30

In [26]:
# Section banner (displayed as the cell output): model construction. For quick testing,
# torchvision's built-in models with ImageNet pre-trained weights are used first.
'-'*30+'模型构建(这里为方便测试先使用torchvision自带的模型及ImageNet上的预训练模型)'+'-'*30

'------------------------------模型构建(这里为方便测试先使用torchvision自带的模型及ImageNet上的预训练模型)------------------------------'

In [45]:
# Number of output classes (labels 0-11 in this data set).
num_classes = 12

# Load ImageNet-pretrained backbones and replace the final classification layer so it emits
# `num_classes` logits. The input width of each new layer is read from the layer it replaces
# instead of being hard-coded.
# NOTE(review): newer torchvision releases deprecate `pretrained=` in favour of `weights=`;
# left unchanged because the installed version is not visible here.
resnet18 = torchvision.models.resnet18(pretrained=True)
resnet18.fc = torch.nn.Linear(resnet18.fc.in_features, num_classes)
vgg11 = torchvision.models.vgg11(pretrained=True)
vgg11.classifier[6] = torch.nn.Linear(vgg11.classifier[6].in_features, num_classes)

In [47]:
# Pick the compute device: first CUDA GPU when available, otherwise the CPU
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
print('device:',device)

# Network to train (switch the right-hand side to try another architecture)
Model = resnet18
model = Model.to(device)

# Loss function and optimizer
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), momentum=0.9, lr=0.01)

device: cuda:0


In [24]:
# Section banner (displayed as the cell output): start of training
'-'*30+'开始训练'+'-'*30

'------------------------------开始训练------------------------------'

In [50]:
!rm -rf logs
%reload_ext tensorboard
%tensorboard --logdir logs/fit --port=6007
# --logdir后面为tensorboard数据的地址  port为端口号
writer = SummaryWriter(log_dir='logs/fit/loss') # tensorboard 绘图
writer1 = SummaryWriter(log_dir='logs/fit/acc')

'rm' 不是内部或外部命令，也不是可运行的程序
或批处理文件。


Reusing TensorBoard on port 6007 (pid 14952), started 0:00:24 ago. (Use '!kill 14952' to kill it.)

In [51]:
# Per-epoch history of the mean training / validation metrics (four independent lists)
final_train_loss, final_train_acc = [], []
final_valid_loss, final_valid_acc = [], []
# Best mean validation accuracy seen so far
best_acc = 0

In [52]:
import time

num_epochs = 100
start = time.time()

# Main loop: one training pass and one validation pass per epoch. The epoch means are
# appended to the final_* history lists and written to TensorBoard, and the model weights
# are saved whenever the mean validation accuracy is at least as good as the best so far.
for epoch in range(1, num_epochs + 1):
    start1 = time.time()
    print('----------------------epoch = %s------------------------' % epoch)

    # ---- training pass ----
    model.train()  # the mode does not change between batches, so set it once per pass
    train_loss_list = []
    train_acc_list = []
    for total_step, (img, cls) in enumerate(train_loader, start=1):
        x, y = img.to(device), cls.to(device)
        y_pred = model(x)
        loss = criterion(y_pred, y)
        predicted = y_pred.detach().argmax(dim=1)
        # Cast to float before averaging so the result is a true fraction regardless of how
        # the installed torch version divides integer tensors.
        acc = (predicted == y).float().mean()
        train_loss_list.append(loss.item())
        train_acc_list.append(acc.item())
        if total_step % 10 == 0:
            print('**epoch=', epoch, '**train_loss=', loss.item(),'**acc=',acc.item(), '**batch / num of batch  =  ', total_step,'/',len(train_loader))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # ---- validation pass ----
    # The loader is named val_loader where it is created; the former references to
    # valid_loader / test_loader were undefined names.
    model.eval()
    valid_loss_list = []
    valid_acc_list = []
    with torch.no_grad():
        for total_step1, (img, cls) in enumerate(val_loader, start=1):
            x1, y1 = img.to(device), cls.to(device)
            y_pred1 = model(x1)
            loss = criterion(y_pred1, y1)  # loss of the current batch
            predicted = y_pred1.argmax(dim=1)
            acc = (predicted == y1).float().mean()

            valid_loss_list.append(loss.item())
            valid_acc_list.append(acc.item())

            if total_step1 % 10 == 0:
                print('**epoch=', epoch, '**valid_loss=', loss.item(),'**valid_acc=',acc.item(), '**step / num of batch  =  ', total_step1,'/',len(val_loader))

    # ---- epoch summary (unweighted mean over batches) ----
    train_loss = np.mean(train_loss_list)
    valid_loss = np.mean(valid_loss_list)
    train_acc = np.mean(train_acc_list)
    valid_acc = np.mean(valid_acc_list)

    if valid_acc >= best_acc:
        best_acc = valid_acc
        torch.save(model.state_dict(),'model_state_dict_epoch.pt')
        print('已保存第%s个epoch的模型' % epoch)
    print('epoch=', epoch, 'mean_train_loss=', train_loss)
    print('epoch=', epoch, 'mean_train_acc=', train_acc)
    print('epoch=', epoch, 'mean_valid_loss=', valid_loss)
    print('epoch=', epoch, 'mean_valid_acc=', valid_acc)
    final_train_loss.append(train_loss)
    final_train_acc.append(train_acc)
    final_valid_loss.append(valid_loss)
    final_valid_acc.append(valid_acc)
    writer.add_scalar('train_loss',train_loss,epoch)  # log the epoch means to TensorBoard
    writer.add_scalar('valid_loss',valid_loss,epoch)
    writer1.add_scalar('train_acc',train_acc,epoch)
    writer1.add_scalar('valid_acc',valid_acc,epoch)
    # Flush so the curves are on disk even if the run is interrupted later.
    writer.flush()
    writer1.flush()
    print('本epoch运行时长%s' % (time.time()-start1))
end = time.time()
print('运行时长%s' % (end-start))

----------------------epoch = 1------------------------
**epoch= 1 **train_loss= 2.6431334018707275 **acc= 0.21875 **batch / num of batch  =   10 / 119


KeyboardInterrupt: 

In [None]:
import matplotlib.pyplot as plt
plt.rc('font',family='Times New Roman')

# Plot against the number of epochs actually recorded. If training stopped early the history
# lists are shorter than num_epochs, and plotting them against range(num_epochs) raises a
# shape-mismatch error. Numbering starts at 1 to match the epoch counter of the training loop.
epochs = range(1, len(final_train_acc) + 1)

# Create the axes first and draw on them explicitly, instead of plotting through pyplot and
# requesting a subplot afterwards.
fig, ax = plt.subplots()
ax.plot(epochs, final_train_acc, 'r', label='Training Acc')
ax.plot(epochs, final_valid_acc, 'b', linewidth=1, label='Validation Acc')
# Tick label font size
ax.tick_params(axis='both', labelsize=15)
# Axis label font size
ax.set_xlabel('Number of iteration', fontsize=18)
ax.set_ylabel('Acc', fontsize=18)
ax.legend(fontsize=15)
plt.show()