In [68]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models

import numpy as np
import matplotlib.pyplot as plt

In [69]:
import os

base_dir = '../../../assets/flower/102-flowers-dataset'
train_dir = os.path.join(base_dir, "train")
test_dir = os.path.join(base_dir, "test")

In [70]:
# dataset augment
data_transform = {
    'train': transforms.Compose([transforms.RandomRotation(45),
                                 transforms.CenterCrop(224),
                                 transforms.RandomVerticalFlip(p=0.5),
                                 transforms.RandomHorizontalFlip(p=0.5),
                                 transforms.ColorJitter(brightness=0.2, contrast=0.1, saturation=0.1, hue=0.1),
                                 transforms.RandomGrayscale(p=0.025),
                                 transforms.ToTensor(),
                                 transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
    'valid': transforms.Compose([transforms.Resize(256),
                                 transforms.CenterCrop(224),
                                 transforms.ToTensor(),
                                 transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
    'test': transforms.Compose([transforms.Resize(256),
                                transforms.CenterCrop(224),
                                transforms.ToTensor(),
                                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
}

In [71]:
# load data
batch_size = 48
train_datasets = datasets.ImageFolder(os.path.join(base_dir, "train"), data_transform["train"])
test_datasets = datasets.ImageFolder(os.path.join(base_dir, "test"), data_transform["test"])
valid_datasets = datasets.ImageFolder(os.path.join(base_dir, "valid"), data_transform["valid"])
train_dataloader = DataLoader(train_datasets, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(test_datasets, batch_size=batch_size, shuffle=True)
valid_dataloader = DataLoader(valid_datasets, batch_size=batch_size, shuffle=True)

train_datasets, test_datasets, valid_datasets

(Dataset ImageFolder
     Number of datapoints: 1020
     Root location: ../../../assets/flower/102-flowers-dataset\train
     StandardTransform
 Transform: Compose(
                RandomRotation(degrees=[-45.0, 45.0], interpolation=nearest, expand=False, fill=0)
                CenterCrop(size=(224, 224))
                RandomVerticalFlip(p=0.5)
                RandomHorizontalFlip(p=0.5)
                ColorJitter(brightness=[0.8, 1.2], contrast=[0.9, 1.1], saturation=[0.9, 1.1], hue=[-0.1, 0.1])
                RandomGrayscale(p=0.025)
                ToTensor()
                Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            ),
 Dataset ImageFolder
     Number of datapoints: 6149
     Root location: ../../../assets/flower/102-flowers-dataset\test
     StandardTransform
 Transform: Compose(
                Resize(size=256, interpolation=bilinear, max_size=None, antialias=None)
                CenterCrop(size=(224, 224))
                ToTensor()
     

In [72]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device

device(type='cpu')

In [73]:
class FlowerModel(nn.Module):
    def __init__(self, model):
        super(FlowerModel, self).__init__()
        self.resnet = nn.Sequential(*list(model.children())[:-1])
        self.fc = nn.Linear(in_features=2048, out_features=102)

    def forward(self, x):
        x = self.resnet(x)
        x = x.view(x.shape[0], -1)
        x = self.fc(x)
        return x

resnet152_model = models.resnet152(pretrained=True)
model = FlowerModel(resnet152_model)



In [74]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.002)
epochs = 40

In [75]:
def eval_model(model, criterion, dataset, dataloader):
    best_acc = 0.0
    running_loss = 0.0
    running_corrects = 0

    model = model.to(device)
    for (idx, (inputs, labels)) in enumerate(dataloader):
        inputs = inputs.to(device)
        labels = labels.to(device)
        model.eval()

        with torch.no_grad():
            outputs = model(inputs)

        loss = criterion(outputs, labels)
        running_loss += loss.item()

        _, preds = torch.max(outputs, 1)  # val, idx
        running_corrects += torch.sum(preds == labels)

        print(f"{(idx + 1) * batch_size / len(dataset)}")

    epoch_loss = running_loss / len(dataset)
    epoch_acc = running_corrects / len(dataset)

    print(f"loss={epoch_loss}, acc={epoch_acc}")


eval_model(model, criterion, valid_datasets, valid_dataloader)

0.03137254901960784
0.06274509803921569


KeyboardInterrupt: 

In [76]:
def train_model(model,criterion,optimizer,epochs):
    model=model.to(device)
    model.train()

    epoch_loss=0.0
    for epoch in range(epochs):
        losses=0.0
        for (idx,(inputs,labels)) in enumerate(test_dataloader):
            inputs=inputs.to(device)
            labels=labels.to(device)

            optimizer.zero_grad() # 梯度清空
            outputs=model(inputs) # 前向传播
            loss=criterion(outputs,labels) # 计算损失函数
            loss.backward() # 反向传播
            optimizer.step() # 更新参数

            losses+=loss.item()

            if idx % 20==0:
                print(f"epoch={epoch}/{epochs}, {(idx+1)*batch_size}/{len(test_datasets)}, loss={losses/((idx+1)*batch_size)}")

        epoch_loss=losses/len(test_datasets)
        print(f"epoch={epoch}/{epochs}, losses={epoch_loss}")

train_model(model,criterion,optimizer,epochs)

KeyboardInterrupt: 

In [None]:
eval_model(model,criterion,test_datasets,test_dataloader)