In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from PIL import Image
from torchvision import transforms, datasets, models
import os
import shutil
from tqdm import tqdm
import json
import sys
import platform
from torchvision.models import mobilenet_v2
import model_v3

In [None]:
# The four coarse categories; a category's position in this list is its numeric label.
nums = ['其他垃圾', '厨余垃圾', '可回收物', '有害垃圾']

# Locations of the class-rule file and of the re-organised 4-class dataset.
json_path = 'garbage_classify_rule.json'
dataset_path = os.path.join(os.getcwd(), "dataset_class4")

# Fail early when the rule file is missing, then parse it into a dict.
assert os.path.exists(json_path), "file: '{}' dose not exist.".format(json_path)
with open(json_path, mode='r', encoding="utf8") as f:
    class_indict = json.loads(f.read())

In [None]:
# Re-organise the Huawei Cloud data into one folder per coarse category.
#
# Every ".txt" file found under ./train_data is treated as a label file whose
# first line is "<image file name>, <class id>". The class id is looked up in
# class_indict, the part of the resulting name before "/" is located in nums,
# and the image is copied to dataset_path/<index in nums>/.
os.makedirs(dataset_path, exist_ok=True)
for root, dirs, files in os.walk(os.path.join(os.getcwd(), "train_data")):
    for file in tqdm(files):
        if os.path.splitext(file)[1] != '.txt':
            continue
        label_path = os.path.join(root, file)
        # Only the first line is needed, so do not read the whole file.
        with open(label_path, "r") as f:
            first_line = f.readline()
        if not first_line.strip():
            # An empty label file names no image; report it and carry on
            # instead of aborting the whole copy run with an IndexError.
            print("skipping empty label file: {}".format(label_path))
            continue
        file_name, classify = first_line.split(",")
        file_name = file_name.strip()
        classify = classify.strip()
        # Map the class id to the index of its coarse category in nums.
        classify = str(nums.index(class_indict[classify].split("/")[0]))
        target_dir = os.path.join(dataset_path, classify)
        os.makedirs(target_dir, exist_ok=True)
        shutil.copyfile(os.path.join(root, file_name), os.path.join(target_dir, file_name))

In [None]:
# Class statistics: number of categories and number of images per category folder.
classify_num = len(nums)
print("there exist {} classes in this datasets".format(classify_num))
for root, dirs, files in os.walk(dataset_path):
    if len(files) == 0:
        continue
    # Category folders are named after the index into nums, not after a key of
    # class_indict, so translate through nums; any other folder name is shown as is.
    folder = os.path.split(root)[-1]
    if folder.isdigit() and int(folder) < classify_num:
        folder = nums[int(folder)]
    print(folder, len(files))

In [None]:
# Image pre-processing pipelines, keyed by split name ("train" / "val").

# Per-channel mean / std used for normalisation
# (presumably the ImageNet statistics the pretrained backbones expect — confirm).
_channel_mean = [0.485, 0.456, 0.406]
_channel_std = [0.229, 0.224, 0.225]

# Training: resize + centre crop to 224, random augmentation, then normalise.
_train_steps = [
    transforms.Resize(224),
    transforms.CenterCrop(224),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomVerticalFlip(p=0.5),
    transforms.RandomRotation(45),
    transforms.RandomGrayscale(p=0.5),
    transforms.ToTensor(),
    transforms.Normalize(_channel_mean, _channel_std),
]

# Validation: deterministic resize to 256, centre crop to 224, then normalise.
_val_steps = [
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(_channel_mean, _channel_std),
]

data_transform = {
    "train": transforms.Compose(_train_steps),
    "val": transforms.Compose(_val_steps),
}

In [None]:
# Train / validation split and data loaders.

# Pick the device: the first CUDA GPU when one is available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Load the images and split them into a training and a validation subset.
split_rate = 0.8
split_seed = 0  # fixed seed so the split is the same on every run
# Two views of the same folder: `dataset` applies the augmenting training
# pipeline, `valid_source` the deterministic validation pipeline. Splitting a
# single ImageFolder would push the validation images through the random
# training augmentations as well.
dataset = datasets.ImageFolder(root=dataset_path, transform=data_transform['train'])
valid_source = datasets.ImageFolder(root=dataset_path, transform=data_transform['val'])
assert len(valid_source) == len(dataset), "dataset folder changed while it was being indexed"
train_size = int(len(dataset) * split_rate)
valid_size = len(dataset) - train_size
# One shuffled index list is shared by both views so the subsets never overlap.
shuffled_indices = torch.randperm(
    len(dataset), generator=torch.Generator().manual_seed(split_seed)
).tolist()
train_dataset = torch.utils.data.Subset(dataset, shuffled_indices[:train_size])
valid_dataset = torch.utils.data.Subset(valid_source, shuffled_indices[train_size:])
print("using {} images for training, {} images for validation.".format(train_size, valid_size))

# Build the data loaders.
batch_size = 16
# os.cpu_count() may return None; fall back to a single worker in that case.
num_work = min([os.cpu_count() or 1, batch_size if batch_size > 1 else 0, 8])  # number of workers
# On Windows use num_work = 0.
num_work = 0 if platform.system().lower() == "windows" else num_work
print("using {} num_work for training and validation.".format(num_work))
train_loader = torch.utils.data.DataLoader(train_dataset,
                                            batch_size=batch_size, shuffle=True,
                                            num_workers=num_work)

validate_loader = torch.utils.data.DataLoader(valid_dataset,
                                            batch_size=batch_size, shuffle=False,
                                            num_workers=num_work)

In [None]:
# Network construction: build the backbone, load its pretrained weights,
# freeze it, attach a trainable head and create the loss / optimizer.
nets = ["resnet18", "resnet50", "resnet101", "mobilenet_v2", "mobilenet_v3"]
choose_net = nets[3]
print("using {} net to train!".format(choose_net))
if choose_net == "resnet18":
    net = models.resnet18(pretrained=False)  # True would download the weights directly
elif choose_net == "resnet50":
    net = models.resnet50(pretrained=False)
elif choose_net == "resnet101":
    net = models.resnet101(pretrained=False)
elif choose_net == "mobilenet_v2":
    net = models.mobilenet_v2(pretrained=False)
elif choose_net == "mobilenet_v3":
    net = model_v3.mobilenet_v3_large(num_classes=classify_num)

# Load the pretrained weights.
# NOTE: torch.load unpickles the file; only use checkpoints from a trusted source.
model_weight_path = "{name}-pre.pth".format(name=choose_net)
assert os.path.exists(model_weight_path), "file {} does not exist.".format(model_weight_path)
pretrained_state = torch.load(model_weight_path, map_location=device)
if choose_net == "mobilenet_v3":
    # This net is already built with classify_num outputs, so layers whose shape
    # differs from the checkpoint cannot be loaded. Keep only the tensors that
    # fit and remember which entries stayed at their fresh initialisation.
    # NOTE(review): assumes only head layers differ from the checkpoint — confirm against model_v3.
    own_state = net.state_dict()
    pretrained_state = {name: tensor for name, tensor in pretrained_state.items()
                        if name in own_state and own_state[name].shape == tensor.shape}
    load_result = net.load_state_dict(pretrained_state, strict=False)
    fresh_names = set(load_result.missing_keys)
else:
    net.load_state_dict(pretrained_state)
    fresh_names = set()

# Freeze every weight that came from the checkpoint so it is not updated;
# parameters that could not be loaded stay trainable.
for name, param in net.named_parameters():
    param.requires_grad = name in fresh_names

# Replace the last layer so the output size equals the number of categories.
# A newly created layer is trainable by default, so nothing has to be unfrozen.
if "resnet" in choose_net:
    channel_in = net.fc.in_features  # input width of the fully connected layer
    net.fc = nn.Linear(channel_in, classify_num)
elif "mobilenet_v2" in choose_net:
    channel_in = net.last_channel  # input width of the classifier
    net.classifier = nn.Sequential(nn.Dropout(0.2), nn.Linear(channel_in, classify_num))

net.to(device)

# Loss function.
loss_function = nn.CrossEntropyLoss()

# Optimise only the parameters that are still trainable.
trainable_params = [p for p in net.parameters() if p.requires_grad]
if not trainable_params:
    raise RuntimeError("no trainable parameters left in {}; nothing to optimise".format(choose_net))
optimizer = torch.optim.SGD(trainable_params, lr=0.001, momentum=0.9)
# Alternative: torch.optim.Adam(trainable_params, lr=0.0003)
save_path = 'garbage_{}_class4.pkl'.format(choose_net)


In [None]:
# Training: run `epochs` passes over train_loader, validate after each pass and
# keep the model with the best validation accuracy in save_path.
losslist = []  # mean training loss of every finished epoch
epochs = 20
best_acc = 0.0
train_steps = len(train_loader)  # number of batches per epoch
for epoch in range(epochs):
    # ---- train ----
    net.train()
    running_loss = 0.0
    train_bar = tqdm(train_loader, file=sys.stdout)
    for step, data in enumerate(train_bar):
        images, labels = data
        # Reset the gradients accumulated by the previous step.
        optimizer.zero_grad()
        output = net(images.to(device))
        loss = loss_function(output, labels.to(device))

        # Back-propagate and update the trainable parameters.
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        running_loss += batch_loss
        train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1, epochs, batch_loss)

    # The loss function already averages over each batch, so the epoch mean is
    # the sum divided by the number of batches, not by the number of samples.
    epoch_loss = running_loss / max(train_steps, 1)
    losslist.append(epoch_loss)

    # ---- eval ----
    net.eval()
    acc = 0.0  # number of correctly classified validation samples
    with torch.no_grad():
        valid_bar = tqdm(validate_loader, file=sys.stdout)
        for val_data in valid_bar:
            val_images, val_labels = val_data
            val_labels = val_labels.to(device)  # move once, reuse for loss and accuracy
            outputs = net(val_images.to(device))
            loss = loss_function(outputs, val_labels)
            predict = torch.max(outputs, dim=1)[1]
            acc += torch.eq(predict, val_labels).sum().item()
            valid_bar.desc = "valid epoch[{}/{}] loss:{:.3f}".format(epoch + 1, epochs, loss.item())
    val_accurate = acc / valid_size
    print('[epoch %d] train_loss: %.3f  val_accuracy: %.3f\n' % (epoch + 1, epoch_loss, val_accurate))

    # Keep the whole model object of the best epoch so far; it is pickled, so
    # loading it later needs the same model classes to be importable.
    if val_accurate > best_acc:
        best_acc = val_accurate
        torch.save(net, save_path)

print("Finished Training")



In [None]:
# Test one image with a saved model and show the picture.
model_path = "garbage_resnet18_class4.pkl"

# Load the image. Force three channels so grayscale / RGBA files still match
# the 3-channel normalisation of the validation pipeline.
img_path = "test_pic/3/yaqian.jpg"
assert os.path.exists(img_path), "file: '{}' dose not exist.".format(img_path)
img = Image.open(img_path).convert("RGB")
plt.imshow(img)

# Pre-process and add the batch dimension: [N, C, H, W].
img = data_transform['val'](img)
img = torch.unsqueeze(img, dim=0)

# Read class_indict.
assert os.path.exists(json_path), "file: '{}' dose not exist.".format(json_path)
with open(json_path, 'r', encoding = "utf8") as f:
    class_indict = json.load(f)

# Load the saved model object onto the active device; without map_location a
# model saved on a GPU cannot be restored on a CPU-only machine.
# NOTE: the file is a pickled model — only load files from a trusted source.
# NOTE(review): recent PyTorch releases default torch.load to weights_only=True,
# which rejects whole-model pickles — confirm against the installed version.
assert os.path.exists(model_path), "file: '{}' dose not exist.".format(model_path)
model = torch.load(model_path, map_location=device)

# Index produced by the network -> name of the training folder.
cla_dict = dict((val, key) for key, val in dataset.class_to_idx.items())

# Prediction.
model.eval()
with torch.no_grad():
    output = torch.squeeze(model(img.to(device))).cpu()
    predict = torch.softmax(output, dim=0)
    predict_cla = torch.argmax(predict).numpy()

# Folder names are indices into nums, so translate the folder name to the category.
print(nums[int(cla_dict[predict_cla.item()])])
plt.show()

In [None]:
# Test set: classify every image under test_pic, print label and prediction,
# show the picture and report the category-level accuracy.
test_dataset = datasets.ImageFolder(root = 'test_pic',transform = data_transform['val'])
test_loader = torch.utils.data.DataLoader(test_dataset,batch_size =5,shuffle =  False,num_workers = 0)
# Load onto the active device so a GPU-saved model also works on a CPU host.
# NOTE: the file is a pickled model — only load files from a trusted source.
net = torch.load('garbage_resnet50_class4.pkl', map_location=device)
net.eval()  # inference mode for dropout / batch-norm layers

# The test set may not contain every class, so build its own index -> folder mapping.
test_indict = dict((val, key) for key, val in test_dataset.class_to_idx.items())

# Constants used to undo the normalisation (var_tensor holds the per-channel std).
mean_tensor = torch.tensor([0.485, 0.456, 0.406], dtype=torch.float32).reshape(3, 1, 1)
var_tensor = torch.tensor([0.229, 0.224, 0.225], dtype=torch.float32).reshape(3, 1, 1)

correct = 0
total = 0
with torch.no_grad():  # no gradients are needed for evaluation
    for data in test_loader:
        images, labels = data
        images, labels = images.to(device), labels.to(device)
        outputs = net(images)
        # torch.max returns (values, indices); the indices are the predicted classes.
        _, predicted = torch.max(outputs, 1)
        labels = labels.cpu().numpy()
        predicted = predicted.cpu().numpy()
        for i in range(len(labels)):
            true_name = class_indict[test_indict[labels[i]]]
            predicted_category = nums[int(cla_dict[predicted[i]])]
            print("lable:{}\npredict:{}\n".format(true_name, predicted_category))
            # The part of the label name before "/" is its coarse category.
            correct += int(true_name.split("/")[0] == predicted_category)
            total += 1
            # Undo the normalisation; clamp so rounding noise cannot leave [0, 1].
            image = (images[i].cpu() * var_tensor + mean_tensor).clamp(0.0, 1.0)
            img = transforms.ToPILImage()(image)
            display(img)

if total > 0:
    print('Accuracy of the network on the test images: %d %%' % (100 * correct / total))