In [2]:
import os
import warnings
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import transforms
# from demo import BATCH_SIZE, DEVICE, EPOCHS, LEARNING_RATE, NUM_WORKERS, RUN_FOLDER, WEIGHT_DECAY
from net import resnet50
from utils import ISIC2018Dataset, save_model, Logger, Evaluation, plot_confusion_matrix, plot_roc_curves, plot_losses
warnings.filterwarnings('ignore')

In [None]:
# Run configuration: output folder, data loading, device and optimisation hyper-parameters.
RUN_FOLDER = './demo'
BATCH_SIZE = 32
NUM_WORKERS = 4
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

EPOCHS = 100
LEARNING_RATE = 0.0005
# NOTE(review): 0.2 is an unusually strong L2 penalty for (non-decoupled) Adam — confirm it is intended.
WEIGHT_DECAY = 0.2
SEED = 42

# Seed torch so weight initialisation, shuffling and random augmentations are repeatable.
torch.manual_seed(SEED)

# Create each output sub-folder independently (exist_ok) so a pre-existing RUN_FOLDER
# that lacks images/ or models/ still gets them before figures/models are written.
for sub_dir in ('images', 'models'):
    os.makedirs(os.path.join(RUN_FOLDER, sub_dir), exist_ok=True)

LOGGER = Logger(RUN_FOLDER, 'demo')
LOGGER.info('isic2018 demo run by cc')

In [None]:
# Per-channel normalisation constants (the commonly used ImageNet mean/std).
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]


def _isic_pipeline(*augmentations):
    """Build a transform: optional augmentations first, then the shared crop/resize/normalise tail."""
    return transforms.Compose([
        *augmentations,
        transforms.CenterCrop((450, 450)),
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ])


# Group 1: no augmentation (baseline).
train_trans1 = _isic_pipeline()
# Group 2: random horizontal and vertical flips.
train_trans2 = _isic_pipeline(transforms.RandomHorizontalFlip(),
                              transforms.RandomVerticalFlip())
# Group 3: random contrast jitter.
train_trans3 = _isic_pipeline(transforms.ColorJitter(contrast=0.5))
# Evaluation is never augmented.
test_trans = _isic_pipeline()

train_trans = [train_trans1, train_trans2, train_trans3]
# One training dataset per augmentation group (comprehension, so no loop variable leaks
# into the notebook namespace).
train_dataset = [
    ISIC2018Dataset(
        csv_file_path='./data/ISIC2018/Train_GroundTruth.csv',
        img_dir='./data/ISIC2018/ISIC2018_Task3_Training_Input',
        transform=trans,
    )
    for trans in train_trans
]

# NOTE(review): the test split reads images from the *Training_Input* folder; presumably
# Test_GroundTruth.csv lists a held-out subset of the training images — confirm.
test_dataset = ISIC2018Dataset(
    csv_file_path='./data/ISIC2018/Test_GroundTruth.csv',
    img_dir='./data/ISIC2018/ISIC2018_Task3_Training_Input',
    transform=test_trans,
)

In [None]:
# One shuffled loader per augmentation group; drop_last keeps every training batch full-size.
train_iter = [
    DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        drop_last=True,
        num_workers=NUM_WORKERS,
    )
    for dataset in train_dataset
]

# Evaluation loader: DataLoader defaults (no shuffling, last partial batch kept).
test_iter = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
)

In [None]:
# ResNet-50 with a 7-way output head (presumably the 7 ISIC 2018 diagnosis categories),
# moved to the configured device.
net = resnet50(num_classes=7).to(DEVICE)

# Summed (not averaged) cross-entropy; the training loop divides the accumulated
# loss by the number of samples to report a per-sample mean.
loss_fn = nn.CrossEntropyLoss(reduction='sum')

optim = torch.optim.Adam(
    params=net.parameters(),
    lr=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
)

In [None]:
def _run_epoch(net, data_iter, loss_fn, device, optimizer=None):
    """Make one pass over `data_iter` and return ``(mean_loss, accuracy)``.

    With an `optimizer` the pass trains `net` (train mode, gradients on, one optimizer
    step per batch); without one it only evaluates (eval mode, no gradients).

    The summed batch losses are divided by the sample count, so the returned loss is a
    per-sample mean only when `loss_fn` sums over the batch (``reduction='sum'``).

    Raises:
        ValueError: if `data_iter` yields no samples (e.g. a dataset smaller than one
            batch with ``drop_last=True``), instead of dividing by zero.
    """
    training = optimizer is not None
    net.train(training)  # train(False) is equivalent to eval()
    total_loss, correct, num_data = 0.0, 0, 0
    with torch.set_grad_enabled(training):
        for X, y in data_iter:
            X, y = X.to(device), y.to(device)
            out = net(X)
            batch_loss = loss_fn(out, y)
            if training:
                # Clear gradients from the previous step; otherwise every step would
                # apply the running sum of all past gradients.
                optimizer.zero_grad()
                batch_loss.backward()
                optimizer.step()
            total_loss += batch_loss.item()
            correct += (out.argmax(dim=1) == y).sum().item()
            num_data += X.shape[0]
    if num_data == 0:
        raise ValueError("data_iter yielded no samples")
    return total_loss / num_data, correct / num_data


def train(train_iter, test_iter=test_iter, net=net, loss_fn=loss_fn, version='',
          optimizer=None, epochs=None):
    """Train `net` on `train_iter`, evaluating on `test_iter` after every epoch.

    After training it writes a loss/accuracy plot, saves the model, logs an evaluation
    report and writes a confusion matrix and ROC curves, all tagged with `version`.

    Args:
        train_iter: iterable of ``(X, y)`` training batches.
        test_iter: iterable of ``(X, y)`` evaluation batches.
        net: model to train (updated in place).
        loss_fn: called as ``loss_fn(logits, y)``; a summed reduction is expected.
        version: suffix for every output file/folder name so runs don't overwrite each
            other; converted with ``str`` (ints are accepted).
        optimizer: optimizer over `net`'s parameters; defaults to the module-level
            ``optim``, looked up at call time.
        epochs: number of epochs; defaults to ``EPOCHS``.

    Returns:
        ``(final_train_acc, final_test_acc)`` from the last epoch.

    Raises:
        ValueError: if `epochs` < 1 or a loader yields no samples.
    """
    if optimizer is None:
        optimizer = optim
    if epochs is None:
        epochs = EPOCHS
    if epochs < 1:
        raise ValueError("epochs must be >= 1, got {}".format(epochs))
    version = str(version)

    train_l, train_acc, test_l, test_acc = [], [], [], []
    for epoch in range(epochs):
        epoch_train_loss, epoch_train_acc = _run_epoch(net, train_iter, loss_fn, DEVICE, optimizer)
        epoch_test_loss, epoch_test_acc = _run_epoch(net, test_iter, loss_fn, DEVICE)
        train_l.append(epoch_train_loss)
        train_acc.append(epoch_train_acc)
        test_l.append(epoch_test_loss)
        test_acc.append(epoch_test_acc)

        LOGGER.info("Epoch {:03d} --- train loss: {:.4f} train acc: {:.4f}\ttest loss: {:.4f} test acc: {:.4f}".format(
            epoch + 1, train_l[-1], train_acc[-1], test_l[-1], test_acc[-1]))

    # Keep the version before the extension, like the cm/roc files below.
    plot_losses([train_l, train_acc, test_l, test_acc],
                title="loss and acc",
                legend=["train loss", "train acc", "test loss", "test acc"],
                filename=os.path.join(RUN_FOLDER, "images", "loss{}.png".format(version)))

    # NOTE(review): save_model's contract lives in utils; with version='' this is the
    # pre-created RUN_FOLDER/models directory, so it presumably expects a directory.
    # Per-version directories are created here so they exist before saving.
    model_dir = os.path.join(RUN_FOLDER, "models" + version)
    os.makedirs(model_dir, exist_ok=True)
    save_model(model=net, path=model_dir)

    # Model evaluation
    evaluation = Evaluation(net, test_iter, DEVICE,
                            categories=test_dataset.categories)
    report = evaluation.get_report()
    LOGGER.info(report)
    result = evaluation.evaluate(["c_matrix", "roc_curves"])
    plot_confusion_matrix(result["c_matrix"], test_dataset.categories,
                          title="confusion matrix",
                          filename=os.path.join(RUN_FOLDER, "images", "cm{}.png".format(version)))
    plot_roc_curves(result["roc_curves"][0],
                    result["roc_curves"][1],
                    result["roc_curves"][2],
                    categories=test_dataset.categories,
                    filename=os.path.join(RUN_FOLDER, "images", "roc_curve{}.png".format(version)))
    return train_acc[-1], test_acc[-1]

In [None]:
from matplotlib import pyplot as plt

# Train one freshly initialised model per augmentation group. Reusing the single
# `net`/`optim` for every group would start groups 2 and 3 from the previous group's
# weights, so the final accuracies would not be comparable.
train_acc, test_acc = [], []
for group_idx, group_loader in enumerate(train_iter, start=1):
    group_net = resnet50(num_classes=7).to(DEVICE)
    # `train` steps the module-level `optim` by default, so rebind it to the new model.
    optim = torch.optim.Adam(group_net.parameters(),
                             lr=LEARNING_RATE,
                             weight_decay=WEIGHT_DECAY)
    final_train_acc, final_test_acc = train(group_loader, net=group_net,
                                            version=str(group_idx))
    train_acc.append(final_train_acc)
    test_acc.append(final_test_acc)

# Grouped bar chart: train and test bars side by side (offset) for each group.
groups = list(range(1, len(train_acc) + 1))
bar_width = 0.3
fig, ax = plt.subplots()
ax.bar([g - bar_width / 2 for g in groups], train_acc,
       lw=0.5, fc="r", width=bar_width, label="train")
ax.bar([g + bar_width / 2 for g in groups], test_acc,
       lw=0.5, fc="b", width=bar_width, label="test")
ax.set_xticks(groups)
# NOTE(review): matplotlib's default font may lack CJK glyphs; configure a CJK-capable
# font (rcParams["font.sans-serif"]) if these labels render as empty boxes.
ax.set_title("实验对照")
ax.set_xlabel("组别")
ax.set_ylabel("准确度")
ax.legend()

fig.savefig(os.path.join(RUN_FOLDER, "images", "acc.png"), dpi=300)
plt.show()
