In [1]:
import torch
import os, sys, json, cv2, random, torchvision
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from torchvision import transforms
import seaborn as sns
from PIL import Image
from numpy import interp
import warnings
import torch.nn as nn
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader
from tqdm import tqdm
from torchinfo import summary
from sklearn.metrics import auc, f1_score, roc_curve, classification_report, confusion_matrix
from torch.utils.tensorboard import SummaryWriter
from torch.optim.lr_scheduler import CosineAnnealingLR
from itertools import cycle


from config import device, epochs, root, batch_size, lr, weight_decay, save_path, data_transform, resume, best_val_accuracy
from my_dataset import MyDataset
from utils import Plot_ROC, train_step, val_step, read_split_data
from VGGPredict import predict_single_image, predictor

In [None]:
# Split the dataset found under `root` into train/validation file lists and labels.
(train_image_path, train_image_label,
 val_image_path, val_image_label,
 class_indices) = read_split_data(root)

# Options common to both loaders; only shuffling and the collate function differ.
_loader_kwargs = dict(batch_size=batch_size, num_workers=0, pin_memory=True)

# Training split: uses the 'train' transform pipeline, reshuffled every epoch.
train_dataset = MyDataset(train_image_path, train_image_label, data_transform['train'])
train_loader = DataLoader(
    train_dataset,
    shuffle=True,
    collate_fn=train_dataset.collate_fn,
    **_loader_kwargs,
)

# Validation split: uses the 'valid' transform pipeline, fixed order.
valid_dataset = MyDataset(val_image_path, val_image_label, data_transform['valid'])
valid_loader = DataLoader(
    valid_dataset,
    shuffle=False,
    collate_fn=valid_dataset.collate_fn,
    **_loader_kwargs,
)

In [3]:
# Load VGG-16 with ImageNet-pretrained weights.
# `weights` must be a member of the VGG16_Weights enum (not the enum class itself);
# passing the class is rejected by torchvision's weight verification.
vgg16 = torchvision.models.vgg16(weights=torchvision.models.VGG16_Weights.IMAGENET1K_V1)

In [7]:
# Replace the final 1000-way ImageNet layer with a fresh 5-way linear head.
# Merely assigning `out_features` does not resize the layer's weight/bias tensors,
# so the network would keep producing 1000 logits; the layer itself must be swapped.
num_classes = 5
vgg16.classifier[6] = nn.Linear(vgg16.classifier[6].in_features, num_classes)

In [8]:
# Display the module tree to inspect the network's layers.
vgg16

VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1

In [9]:
# Place the model on the configured device before building the optimizer so the
# optimizer is created over the parameters that will actually be trained.
# `Module.to` moves parameters in place and returns the same module, so this is
# harmless if the training helpers also move the model (not visible from this notebook).
net = vgg16.to(device)

# Adam over every model parameter, with learning rate and L2 weight decay from config.
optimizer = torch.optim.Adam(net.parameters(), lr=lr, weight_decay=weight_decay)

In [None]:
# Gradient scaler for mixed-precision training, or None to train without one.
# The CUDA-availability guard comes first because older PyTorch releases raise
# from `is_bf16_supported()` when no CUDA device is present.
use_amp = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
scalar = torch.cuda.amp.GradScaler() if use_amp else None

# Learning-rate schedule stepped once per epoch and saved with each checkpoint.
# NOTE(review): T_max=epochs anneals over the whole run; adjust if a different
# schedule length was intended.
lr_scheduler = CosineAnnealingLR(optimizer, T_max=epochs)

# NOTE(review): this overrides the `best_val_accuracy` imported from config,
# so every run of this cell starts from zero — confirm this is intended when resuming.
best_val_accuracy = 0

for epoch in range(epochs):
    # train
    train_loss, train_accuracy = train_step(net, optimizer, train_loader, device, epoch, scalar)
    # valid
    val_loss, val_accuracy = val_step(net, valid_loader, device, epoch)

    lr_scheduler.step()

    # Checkpoint only when validation accuracy improves on the best seen so far.
    if val_accuracy > best_val_accuracy:
        best_val_accuracy = val_accuracy
        save_parameters = {
            'model': net.state_dict(),
            'best_accuracy': val_accuracy,
            'optimizer': optimizer.state_dict(),
            'lr_scheduler': lr_scheduler.state_dict()
        }

        torch.save(save_parameters, save_path)

# Post-training evaluation: single-image demo, validation-set predictor, ROC curves.
print('Now we predict an image!!!')
predict_single_image()
print('\n')
f1score = predictor(valid_loader)
Plot_ROC(net, valid_loader, save_path, device)