In [None]:
import os
import sys
import json
from datetime import datetime
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, datasets
from tqdm import tqdm
from model import resnet34
# 导入tensorboard库
from torch.utils.tensorboard import SummaryWriter

In [None]:

# Prefer the first CUDA GPU when one is visible; otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using {device} device.")

# Preprocessing pipelines keyed by split name.
#   train: random-resized crop to 224 + random horizontal flip (augmentation).
#   val:   deterministic resize to 256, then center crop to 224.
# Both end with ToTensor + per-channel normalization using the standard ImageNet
# mean/std (presumably matching how resnet34-pre.pth was trained — confirm).
data_transform = dict(
    train=transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]),
    val=transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]),
)

# The dataset lives two directories above the notebook's working directory,
# under data_set/flower_data/{train,val}/<class_name>/.
data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))
image_path = os.path.join(data_root, "data_set", "flower_data")
assert os.path.exists(image_path), "{} path does not exist.".format(image_path)

# Training split, read with the augmenting "train" pipeline. ImageFolder assigns
# class indices from the sub-directory names.
train_dataset = datasets.ImageFolder(
    root=os.path.join(image_path, "train"),
    transform=data_transform["train"],
)
train_num = len(train_dataset)

# Invert class_name -> index into index -> class_name and save it as JSON,
# presumably for a separate prediction script to decode model outputs.
flower_list = train_dataset.class_to_idx
cla_dict = {idx: name for name, idx in flower_list.items()}
json_str = json.dumps(cla_dict, indent=4)
with open('class_indices.json', 'w') as json_file:
    json_file.write(json_str)

# Batch size shared by the training and validation loaders.
batch_size = 16
# Number of DataLoader worker processes: at most 8, never more than the CPU
# count, and 0 (load in the main process) when batch_size is 1.
# os.cpu_count() returns None when the count cannot be determined; fall back
# to 1 so min() does not raise TypeError on None.
nw = min([os.cpu_count() or 1, batch_size if batch_size > 1 else 0, 8])
print('Using {} dataloader workers every process'.format(nw))

train_loader = torch.utils.data.DataLoader(train_dataset,
                                           batch_size=batch_size, shuffle=True,
                                           num_workers=nw)

# Validation split: deterministic "val" pipeline, no shuffling.
validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
                                        transform=data_transform["val"])
# ImageFolder derives label indices from the sub-directory names of each split
# independently. If val/ is missing or renames a class, its indices shift and
# validation accuracy becomes silently wrong, so fail loudly instead.
assert validate_dataset.class_to_idx == train_dataset.class_to_idx, \
    "train/val class folders differ: {} vs {}".format(
        sorted(train_dataset.class_to_idx), sorted(validate_dataset.class_to_idx))
val_num = len(validate_dataset)
validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                              batch_size=batch_size, shuffle=False,
                                              num_workers=nw)

print("Using {} images for training, {} images for validation.".format(train_num, val_num))

# Instantiate the project's resnet34 model and load the local pretrained weights.
net = resnet34()
model_weight_path = "./resnet34-pre.pth"
assert os.path.exists(model_weight_path), "File {} does not exist.".format(model_weight_path)
# NOTE(review): torch.load unpickles the checkpoint, so only load files from a
# trusted source. On torch >= 1.13, weights_only=True restricts what can be loaded.
net.load_state_dict(torch.load(model_weight_path, map_location='cpu'))

# Replace the final fully connected layer so it emits one logit per dataset class.
# The class count comes from the training folders rather than being hard-coded,
# so the head always agrees with the labels ImageFolder produces.
in_channel = net.fc.in_features
net.fc = nn.Linear(in_channel, len(train_dataset.classes))
net.to(device)

loss_function = nn.CrossEntropyLoss()
# Optimise only parameters that still require gradients.
params = [p for p in net.parameters() if p.requires_grad]
optimizer = optim.Adam(params, lr=0.0001)

epochs = 30
best_acc = 0.0  # best validation accuracy seen so far; gates checkpoint saving
save_path = './resNet34.pth'

### 创建 TensorBoard 日志 (Create the TensorBoard log writer)

In [2]:
# Create a SummaryWriter for logging. Each run writes to its own timestamped
# directory under ./logs (e.g. logs/20240101-120000), so TensorBoard shows
# successive runs as separate curves.
log_dir = os.path.join('logs', format(datetime.now(), '%Y%m%d-%H%M%S'))
writer = SummaryWriter(log_dir=log_dir)


In [None]:
# Train for `epochs` epochs and validate after each one. Metrics go to
# TensorBoard through `writer`. The checkpoint at save_path is overwritten
# whenever validation accuracy improves on best_acc.
train_steps = len(train_loader)

try:
    for epoch in range(epochs):
        # ---- training pass ----
        net.train()
        running_loss = 0.0
        train_bar = tqdm(train_loader, file=sys.stdout)

        for step, data in enumerate(train_bar):
            images, labels = data
            optimizer.zero_grad()
            logits = net(images.to(device))
            loss = loss_function(logits, labels.to(device))
            loss.backward()
            optimizer.step()

            # Convert the loss to a Python float once and reuse it for the
            # running sum, the TensorBoard log and the progress bar.
            loss_value = loss.item()
            running_loss += loss_value

            # Record the loss in TensorBoard. The global step must increase
            # across batches and epochs. Using `epoch` would stack every batch
            # of an epoch on the same x-coordinate.
            writer.add_scalar('Train/Loss', loss_value, epoch * train_steps + step)
            train_bar.desc = "Train epoch[{}/{}] loss:{:.3f}".format(epoch + 1, epochs, loss_value)

        # ---- validation pass ----
        net.eval()
        acc = 0.0  # running count of correctly classified validation images
        with torch.no_grad():
            val_bar = tqdm(validate_loader, file=sys.stdout)

            for val_data in val_bar:
                val_images, val_labels = val_data
                outputs = net(val_images.to(device))
                predict_y = torch.max(outputs, dim=1)[1]  # predicted class index per image
                acc += torch.eq(predict_y, val_labels.to(device)).sum().item()
                val_bar.desc = "Validation epoch[{}/{}]".format(epoch + 1, epochs)

        val_accurate = acc / val_num
        epoch_loss = running_loss / train_steps
        print('[Epoch %d] Train_loss: %.3f  Val_accuracy: %.3f' %
              (epoch + 1, epoch_loss, val_accurate))

        # Per-epoch metrics, indexed by epoch. Record validation accuracy in TensorBoard.
        writer.add_scalar('Train/EpochLoss', epoch_loss, epoch)
        writer.add_scalar('Validation/Accuracy', val_accurate, epoch)

        # Keep only the best-so-far weights on disk.
        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(net.state_dict(), save_path)
finally:
    # Close (and flush) the SummaryWriter even when training is interrupted,
    # e.g. by KeyboardInterrupt in the notebook, so logged events are not lost.
    writer.close()

print('Finished Training')


