In [None]:
import torch
import torchvision
from torch import nn
import pandas as pd
import numpy as np
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import os
from matplotlib import pyplot as plt
from tqdm import tqdm
import seaborn as sns

In [None]:
# Load the training index CSV; its 'label' column is used in later cells to build the class list.
labels_frame = pd.read_csv('./train.csv')
# Summary statistics of the frame, shown as the cell output.
labels_frame.describe()

In [None]:
# Distinct class names in sorted order; the position in this list becomes the class index.
leaves_labels = sorted(set(labels_frame['label']))
num_classes = len(leaves_labels)
# Preview the first ten class names as the cell output.
leaves_labels[:10]

In [None]:
# Class name -> integer index, pairing each name with its position in leaves_labels.
class2num = {name: idx for name, idx in zip(leaves_labels, range(num_classes))}
class2num

In [None]:
# Inverse lookup: integer index -> class name.
num2class = dict(zip(class2num.values(), class2num.keys()))

In [None]:
class LeavesDataset(Dataset):
    """Image dataset driven by a CSV index file.

    Column 0 of the CSV holds an image path that is appended to ``file_path``;
    in 'train' and 'valid' mode column 1 holds the class name. The CSV is read
    with ``header=None``, so its first row (the header) is row 0 and is skipped.

    Labelled samples are returned as ``(image_tensor, class_index)`` where the
    index is looked up in the module-level ``class2num`` mapping; 'test'
    samples are returned as the image tensor alone.

    Args:
        csv_path: Path of the CSV index file.
        file_path: Prefix concatenated in front of every image path.
        mode: 'train' (first part of the rows, with random horizontal flip),
            'valid' (remaining rows) or 'test' (all rows, no labels).
        valid_ratio: Fraction of the data rows held out for 'valid' mode.
        resize_height: Height every image is resized to.
        resize_width: Width every image is resized to.

    Raises:
        ValueError: If ``mode`` is not 'train', 'valid' or 'test'.
    """

    def __init__(self, csv_path, file_path, mode='train', valid_ratio=0.2, resize_height=224, resize_width=224):
        if mode not in ('train', 'valid', 'test'):
            raise ValueError("mode must be 'train', 'valid' or 'test', got {!r}".format(mode))

        self.resize_height = resize_height
        self.resize_width = resize_width
        self.file_path = file_path
        self.mode = mode

        self.data_info = pd.read_csv(csv_path, header=None)
        # Row 0 is the header line, hence the -1.
        self.data_len = len(self.data_info.index) - 1
        self.train_len = int(self.data_len * (1 - valid_ratio))

        # Drop the header row once so the split indices below count samples directly:
        # the first train_len samples are for training, the rest for validation.
        samples = self.data_info.iloc[1:]

        if mode == 'train':
            self.train_image = np.asarray(samples.iloc[:self.train_len, 0])
            self.train_label = np.asarray(samples.iloc[:self.train_len, 1])
            self.image_arr = self.train_image
            self.label_arr = self.train_label
        elif mode == 'valid':
            self.valid_image = np.asarray(samples.iloc[self.train_len:, 0])
            self.valid_label = np.asarray(samples.iloc[self.train_len:, 1])
            self.image_arr = self.valid_image
            self.label_arr = self.valid_label
        else:
            self.test_image = np.asarray(samples.iloc[:, 0])
            self.image_arr = self.test_image
            self.label_arr = None  # the test CSV carries no labels

        # Build the transform pipeline once instead of on every __getitem__ call.
        steps = [torchvision.transforms.Resize((self.resize_height, self.resize_width))]
        if mode == 'train':
            # Augmentation is applied to the training split only.
            steps.append(torchvision.transforms.RandomHorizontalFlip(p=0.5))
        steps.append(torchvision.transforms.ToTensor())
        self.transform = torchvision.transforms.Compose(steps)

        self.real_len = len(self.image_arr)

        print('Finished reading the {} set of Leaves Dataset ({} samples found)'.format(mode, self.real_len))

    def __getitem__(self, index):
        """Load, transform and return the sample at ``index``."""
        single_image_name = self.image_arr[index]
        # The context manager closes the file handle; convert('RGB') yields a loaded
        # 3-channel copy so every sample has the same number of channels.
        with Image.open(self.file_path + single_image_name) as raw_image:
            img_as_img = raw_image.convert('RGB')

        img_as_img = self.transform(img_as_img)

        if self.mode == 'test':
            return img_as_img

        # The class name comes from the label column, not from the image path.
        label = self.label_arr[index]
        number_label = class2num[label]
        return img_as_img, number_label

    def __len__(self):
        """Number of samples in the selected split."""
        return self.real_len

In [None]:
train_path = './train.csv'
test_path = './test.csv'
# The paths inside the CSV files already include the images directory, so the prefix stops at its parent.
image_path = './'

# The train and valid splits are carved out of the same labelled CSV; the test CSV has no labels.
train_dataset = LeavesDataset(csv_path=train_path, file_path=image_path, mode='train')
valid_dataset = LeavesDataset(csv_path=train_path, file_path=image_path, mode='valid')
test_dataset = LeavesDataset(csv_path=test_path, file_path=image_path, mode='test')

print(train_dataset, valid_dataset, test_dataset, sep='\n')

In [None]:
# Training batches are reshuffled every epoch so the optimizer does not see the
# samples in the same CSV order each time.
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=8,
    shuffle=True,
    num_workers=5
)

# Validation and test keep their order: shuffling is unnecessary for evaluation.
valid_loader = torch.utils.data.DataLoader(
    dataset=valid_dataset,
    batch_size=8,
    shuffle=False,
    num_workers=5
)

test_loader = torch.utils.data.DataLoader(
    dataset=test_dataset,
    batch_size=8,
    shuffle=False,
    num_workers=5
)

In [None]:
# Print the loader object itself (its default repr), not the batches it yields.
print(train_loader)

In [None]:
# def im_convert(tensor):
#     """Convert an image tensor into an array that plt.imshow can display."""

#     image = tensor.to("cpu").clone().detach()
#     image = image.numpy().squeeze() # squeeze() drops size-1 dimensions so the image plots correctly
#     image = image.transpose(1,2,0) # imshow() expects h w c, while the image is c h w
#     image = image.clip(0, 1)

#     return image

# fig=plt.figure(figsize=(20, 12))
# columns = 4  # 2*4=8, exactly one batch (batch_size is 8)
# rows = 2

# dataiter = iter(valid_loader)
# inputs, classes = next(dataiter)

# for idx in range (columns*rows):
#     ax = fig.add_subplot(rows, columns, idx+1, xticks=[], yticks=[])
#     ax.set_title(num2class[int(classes[idx])])
#     plt.imshow(im_convert(inputs[idx]))
# plt.show()

In [None]:
def get_device():
    """Return 'cuda' when torch reports a usable CUDA device, otherwise 'cpu'."""
    if torch.cuda.is_available():
        return 'cuda'
    return 'cpu'

# Pick the device once; later cells move the model and batches onto it.
device = get_device()
print(device)

In [None]:
def set_parameter_requires_grad(model, feature_extracting):
    """Freeze the model: when `feature_extracting` is truthy, turn off gradients for every parameter.

    Does nothing when `feature_extracting` is falsy. Returns None either way.
    """
    if not feature_extracting:
        return
    for weight in model.parameters():
        weight.requires_grad_(False)

            
def res_model(num_classes, feature_extracting=False, use_pretrained=True):
    """Build a ResNet-34 whose final fully connected layer outputs `num_classes` logits.

    Args:
        num_classes: Output size of the replacement classification head.
        feature_extracting: When truthy, the backbone parameters are frozen
            before the new head is attached, so only the head stays trainable.
        use_pretrained: Forwarded to torchvision as the `pretrained` argument.

    Returns:
        The modified ResNet-34 module.
    """
    backbone = torchvision.models.resnet34(pretrained=use_pretrained)
    # Freeze first: the head created below is new and therefore keeps requires_grad=True.
    set_parameter_requires_grad(backbone, feature_extracting)
    head = nn.Linear(backbone.fc.in_features, num_classes)
    # The head stays wrapped in nn.Sequential so its state_dict keys remain 'fc.0.*'.
    backbone.fc = nn.Sequential(head)
    return backbone

In [None]:
# Hyperparameters
learning_rate = 3e-4  # passed to the Adam optimizer as lr
weight_decay = 1e-3  # passed to the Adam optimizer as weight_decay
num_epoch = 30  # number of passes of the training loop
model_path = './models/classify_leaves_v1.pth'  # where the best checkpoint is written

In [None]:
def _run_epoch(model, loader, criterion, device, optimizer=None):
    """Run one pass over `loader` and return `(mean_loss, accuracy)` as Python floats.

    With an `optimizer` the model is put in train mode and updated after every
    batch; without one the model is put in eval mode and gradient tracking is
    disabled. Both metrics are weighted by batch size, so a smaller final batch
    does not skew the epoch averages.

    Raises:
        ValueError: If the loader yields no samples.
    """
    training = optimizer is not None
    model.train(training)

    loss_sum, correct, seen = 0.0, 0, 0
    with torch.set_grad_enabled(training):
        for imgs, labels in tqdm(loader):
            imgs = imgs.to(device)
            labels = labels.to(device)
            logits = model(imgs)
            batch_loss = criterion(logits, labels)

            if training:
                optimizer.zero_grad()
                batch_loss.backward()
                optimizer.step()

            batch_size = labels.size(0)
            # criterion returns the batch mean; scale back up to a per-sample sum.
            loss_sum += batch_loss.item() * batch_size
            correct += (logits.argmax(dim=-1) == labels).sum().item()
            seen += batch_size

    if seen == 0:
        raise ValueError('loader produced no samples')
    return loss_sum / seen, correct / seen


# Initialize the model
model = res_model(num_classes)
model = model.to(device)
model.device = device

# Loss function - cross-entropy. The default 'mean' reduction yields a scalar,
# which backward() and item() require; reduction='none' would return one value per sample.
loss = nn.CrossEntropyLoss()

# Optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

# Number of epochs
n_epochs = num_epoch

# Make sure the checkpoint directory exists before the first save.
os.makedirs(os.path.dirname(model_path) or '.', exist_ok=True)

best_acc = 0.0
for epoch in range(n_epochs):
    # ----------------- Train -----------------
    train_loss, train_acc = _run_epoch(model, train_loader, loss, device, optimizer)
    print(f"[ Train | {epoch + 1:03d}/{n_epochs:03d} ] loss = {train_loss:.5f}, acc = {train_acc:.5f}")

    # ----------------- Validation -----------------
    valid_loss, valid_acc = _run_epoch(model, valid_loader, loss, device)
    print(f"[ Valid | {epoch + 1:03d}/{n_epochs:03d} ] loss = {valid_loss:.5f}, acc = {valid_acc:.5f}")

    # Keep only the checkpoint with the best validation accuracy so far.
    if valid_acc > best_acc:
        best_acc = valid_acc
        torch.save(model.state_dict(), model_path)
        print('saving model with acc {:.3f}'.format(best_acc))
