In [27]:
import torch
import torch.nn as nn
from torchvision import transforms, datasets, utils
import matplotlib.pyplot as plt
import numpy as np
import torch.optim as optim
import os
import json
import time

# Use the first CUDA GPU when PyTorch can see one, otherwise fall back to the
# CPU. The model and every batch are moved with `.to(device)` further down, so
# this is the single switch that selects where training runs.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

cpu


In [28]:
class AlexNet(nn.Module):
    """AlexNet-style convolutional image classifier.

    ``features`` is a stack of five convolutions (each followed by ReLU) with
    three max-pooling stages; ``classifier`` is a three-layer fully connected
    head with dropout in front of the first two linear layers.  The head's
    first linear layer expects ``128 * 6 * 6`` inputs, i.e. a 6x6 map of 128
    channels, which is what the feature stack yields for 3x224x224 images.

    Args:
        num_classes: Number of output logits.
        init_weights: If true, call ``_initialize_weights`` after the layers
            have been created.
    """

    def __init__(self, num_classes = 1000, init_weights = False):
        super().__init__()

        # Layer order is significant: the position of each module inside the
        # Sequential container becomes part of its state_dict key.
        feature_layers = [
            nn.Conv2d(3, 48, kernel_size=11, stride=4, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(48, 128, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(128, 192, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(192, 192, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(192, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        ]
        self.features = nn.Sequential(*feature_layers)

        head_layers = [
            nn.Dropout(p=0.5),
            nn.Linear(128 * 6 * 6, 2048),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Linear(2048, 2048),
            nn.ReLU(inplace=True),
            nn.Linear(2048, num_classes),
        ]
        self.classifier = nn.Sequential(*head_layers)

        if init_weights:
            self._initialize_weights()

    def forward(self, x):
        """Return the class logits for a batch of images ``x``."""
        feature_maps = self.features(x)
        # Keep the batch dimension and collapse everything after it.
        flat = feature_maps.flatten(start_dim=1)
        return self.classifier(flat)

    def _initialize_weights(self):
        """Re-initialise all conv and linear layers in place.

        Conv weights use Kaiming-normal (fan_out, ReLU); linear weights use
        N(0, 0.01).  Biases are set to zero.
        """
        for layer in self.modules():
            if isinstance(layer, nn.Linear):
                nn.init.normal_(layer.weight, mean=0, std=0.01)
                nn.init.zeros_(layer.bias)
            elif isinstance(layer, nn.Conv2d):
                nn.init.kaiming_normal_(layer.weight, mode='fan_out', nonlinearity='relu')
                if layer.bias is not None:
                    nn.init.zeros_(layer.bias)

In [29]:
# Pre-processing pipelines, keyed by dataset split.
data_transform = dict(
    # Training: random 224x224 crop + random horizontal flip as augmentation,
    # then tensor conversion and per-channel normalisation.
    train=transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]),
    # Validation: fixed resize to 224x224, no random augmentation, same
    # normalisation as the training split.
    val=transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]),
)

In [30]:
# The dataset is located relative to the directory two levels above the
# notebook's current working directory.
data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))
# Build the path with os.path.join rather than string concatenation:
# `data_root + "Users/..."` only produces a valid path when data_root already
# ends with a separator (i.e. when it is a filesystem root such as "C:\\").
image_path = os.path.join(data_root, "Users", "Administrator", "flower_data")

# One sub-directory per class under <image_path>/train.
train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
                                     transform=data_transform["train"])

# Number of training images.
train_num = len(train_dataset)

# Reshuffle the training set every epoch; num_workers=0 loads batches in the
# main process.
train_loader = torch.utils.data.DataLoader(train_dataset,
                                           batch_size=32,
                                           shuffle=True,
                                           num_workers=0
                                          )

In [31]:
# One sub-directory per class under <image_path>/val.
validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
                                        transform=data_transform["val"])

# Number of validation images; used as the denominator of the accuracy.
val_num = len(validate_dataset)

# Accuracy is a sum over the whole split, so sample order does not matter.
# Keeping shuffle off makes evaluation deterministic and avoids drawing from
# the global RNG on every validation pass.
validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                              batch_size=32,
                                              shuffle=False,
                                              num_workers=0
                                             )

In [32]:
# Class-name -> index mapping taken from the training ImageFolder,
# e.g. {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4}
flower_list = train_dataset.class_to_idx
# Swap keys and values to get an index -> class-name mapping.
cla_dict = {index: name for name, index in flower_list.items()}

# Serialise the index -> class-name mapping and write it to a JSON file.
json_str = json.dumps(cla_dict, indent=4)
with open('class_indices.json', 'w') as json_file:
    json_file.write(json_str)


In [33]:
# Build the 5-class network with explicit weight initialisation and move it to
# the selected device.
net = AlexNet(num_classes = 5, init_weights= True)
net.to(device)
loss_function = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr = 0.0002)

save_path = './AlexNet.pth'   # where the best-performing weights are stored
best_acc = 0.0                # best validation accuracy seen so far
num_batches = len(train_loader)

for epoch in range(10):
    # ---- training pass ----
    net.train()   # enable dropout
    running_loss = 0.0
    time_start = time.perf_counter()

    for step, data in enumerate(train_loader, start= 0):
        images, labels = data
        optimizer.zero_grad()

        outputs = net(images.to(device))
        loss = loss_function(outputs, labels.to(device))
        loss.backward()
        optimizer.step()

        # Pull the scalar out of the graph once and reuse it.
        batch_loss = loss.item()
        running_loss += batch_loss

        # Single-line text progress bar: percent done, a 50-character bar and
        # the loss of the current batch ("\r" rewrites the same line).
        rate = (step + 1) / num_batches
        a = "*" * int(rate * 50)
        b = "." * int((1 - rate) * 50)
        print("\rtrain loss: {:^3.0f}%[{}->{}]{:.3f}".format(int(rate * 100), a, b, batch_loss), end="")
    print()
    print('%f s' % (time.perf_counter() - time_start))

    # ---- validation pass ----
    net.eval()    # disable dropout
    acc = 0.0     # running count of correct predictions
    with torch.no_grad():
        for val_data in validate_loader:
            val_images, val_labels = val_data
            outputs = net(val_images.to(device))
            # The index of the largest logit is the predicted class label.
            predict_y = torch.max(outputs, dim=1)[1]
            acc += (predict_y == val_labels.to(device)).sum().item()
        val_accurate = acc / val_num

        # Keep the parameters of the epoch with the highest validation accuracy.
        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(net.state_dict(), save_path)

        # Average over the number of batches. `step` is the last zero-based
        # batch index, so dividing by it would use N-1 batches and fail with a
        # ZeroDivisionError when the loader yields a single batch.
        print('[epoch %d] train_loss: %.3f  test_accuracy: %.3f \n' %
              (epoch + 1, running_loss / num_batches, val_accurate))

print('Finished Training')

    

train loss: 100%[**************************************************->]1.560
43.898343 s
[epoch 1] train_loss: 1.413  test_accuracy: 0.424 

train loss: 100%[**************************************************->]1.352
44.489712 s
[epoch 2] train_loss: 1.249  test_accuracy: 0.504 

train loss: 100%[**************************************************->]1.187
43.125236 s
[epoch 3] train_loss: 1.188  test_accuracy: 0.578 

train loss: 100%[**************************************************->]0.857
43.734212 s
[epoch 4] train_loss: 1.109  test_accuracy: 0.566 

train loss: 100%[**************************************************->]1.364
43.973357 s
[epoch 5] train_loss: 1.049  test_accuracy: 0.650 

train loss: 100%[**************************************************->]1.108
44.918972 s
[epoch 6] train_loss: 0.989  test_accuracy: 0.664 

train loss: 100%[**************************************************->]0.620
45.390587 s
[epoch 7] train_loss: 0.940  test_accuracy: 0.658 

train loss: 100%[***