In [None]:
import cv2
import numpy
import os
import time
import copy
import matplotlib.pyplot
import torch
import torchvision
import gc
import barbar
import pandas
import datetime
import requests

In [None]:
# Run-wide configuration shared by every cell below.
study = 'AFB'                        # prefix of every saved checkpoint file name
n_classes = 2                        # output size of every classifier head
epoch_number = 300                   # max epochs per model (train_model may stop early)
IMAGE_FOLDER = '/root/AFB/IMAGE01/'  # must contain TRAIN/ and VALID/ ImageFolder trees

# Timestamp (YYYYMMDDHHMM) embedded in checkpoint names so separate runs don't collide.
today = datetime.datetime.today().strftime("%Y%m%d%H%M")

# Train on the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
"""
Calcuate customized mean and std according to https://github.com/inhovation97/Image_classification_pipeline_Project/blob/main/pytorch/pytorch_project_Image_preprocessing.ipynb
"""
datasets_mean = [0.60195434, 0.63816965, 0.7695724]
datasets_std = [0.19413872, 0.18046075, 0.13762417]

In [None]:
"""
Adopted from https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html
"""

def train_model(model, criterion, optimizer, num_epochs):
    model_name = study + "-" + MODE + "-" + today
    since = time.time()

    best_acc = 0.0
    best_loss_epoch = 0
    vali_acc = []
    vali_los = []

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0

            #for inputs, labels in dataloaders[phase]:
            for idx, (inputs, labels) in enumerate(barbar.Bar(dataloaders[phase])):
                inputs = inputs.to(device)
                labels = labels.to(device)
                optimizer.zero_grad()

                if phase == 'train':
                    loss.backward()
                    optimizer.step()

                # statistics
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects.double() / dataset_sizes[phase]

            print('{} Loss: {:.6f} Acc: {:.6f}'.format(phase, epoch_loss, epoch_acc))

            if phase == 'val':
                if epoch == 0:
                    best_loss = epoch_loss
                valid_accu_value = torch.Tensor.numpy(torch.Tensor.cpu(epoch_acc))
                
            if phase == 'val' and epoch_loss < best_loss:
                best_loss = epoch_loss
                best_loss_epoch = epoch
                if is_debug:
                    pass
                else:
                    torch.save(model, "/root/AFB/MODEL/" + study + "_" + model_type + "_" + today + ".pt")
        if best_loss_epoch < epoch - 30:
            print("Early stopping")
            break
        print("Best val loss: {} at {}".format(best_loss, best_loss_epoch))

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best val Loss: {:6f}'.format(best_loss))

In [None]:
MODES = ['W', 'R', 'X', 'D', 'C', 'V', 'RG']

# MODE -> (batch size, model_type). model_type is read as a global by
# train_model when it builds the checkpoint file name.
MODE_SETTINGS = {
    'W': (16, 'wideresnet'),
    'R': (16, 'resnet'),
    'X': (16, 'resnext'),
    # NOTE(review): 166 looks like a typo for 16 (every other mode uses <= 16);
    # confirm before a full run, densenet161 at this batch size may not fit in memory.
    'D': (166, 'densenet'),
    'C': (8, 'convnext_large'),
    'V': (16, 'ViT'),
    'RG': (14, 'RegNet'),
}


def build_model(mode, num_classes):
    """Return an untrained torchvision model for `mode` with a `num_classes`-way output layer.

    Raises:
        ValueError: if `mode` is not one of the known MODE codes.
    """
    if mode == "R":
        net = torchvision.models.resnet152(weights=None)
        net.fc = torch.nn.Linear(net.fc.in_features, num_classes)
    elif mode == "W":
        net = torchvision.models.wide_resnet101_2(weights=None)
        net.fc = torch.nn.Linear(net.fc.in_features, num_classes)
    elif mode == "D":
        net = torchvision.models.densenet161(weights=None)
        net.classifier = torch.nn.Linear(net.classifier.in_features, num_classes)
    elif mode == "X":
        net = torchvision.models.resnext101_64x4d(weights=None)
        net.fc = torch.nn.Linear(net.fc.in_features, num_classes)
    elif mode == "C":
        net = torchvision.models.convnext_large(weights=None)
        net.classifier[2] = torch.nn.Linear(net.classifier[2].in_features, num_classes)
    elif mode == "RG":
        net = torchvision.models.regnet_y_32gf(weights=None)
        net.fc = torch.nn.Linear(net.fc.in_features, num_classes)
    elif mode == "V":
        net = torchvision.models.vit_b_16(weights=None)
        net.heads.head = torch.nn.Linear(net.heads.head.in_features, num_classes)
    else:
        raise ValueError("Unknown MODE: {}".format(mode))
    return net


# Transforms and datasets do not depend on MODE, so build them once.
data_transforms = {
    'train': torchvision.transforms.Compose([
        torchvision.transforms.Resize((224, 224)),
        torchvision.transforms.RandomHorizontalFlip(),
        torchvision.transforms.RandomVerticalFlip(),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(datasets_mean, datasets_std)
    ]),
    # No random augmentation at validation time: random flips here made the
    # val loss, and therefore checkpoint selection and early stopping, noisy.
    'val': torchvision.transforms.Compose([
        torchvision.transforms.Resize((224, 224)),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(datasets_mean, datasets_std)
    ])
}
image_datasets = {
    'train': torchvision.datasets.ImageFolder(root=(IMAGE_FOLDER + 'TRAIN'), transform=data_transforms['train']),
    'val':   torchvision.datasets.ImageFolder(root=(IMAGE_FOLDER + 'VALID'), transform=data_transforms['val'])
}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
class_names = image_datasets['train'].classes

for MODE in MODES:
    # Release the previous mode's model and optimizer before allocating the next one.
    model = optimizer = None
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    print("MODE: {}".format(MODE))
    batch_number, model_type = MODE_SETTINGS[MODE]  # KeyError on an unknown MODE
    print("Model : {}, Batch number: {}".format(model_type, batch_number))

    dataloaders = {
        'train': torch.utils.data.DataLoader(image_datasets['train'], batch_size=batch_number, shuffle=True),
        # Sample order cannot change validation metrics, so don't shuffle.
        'val': torch.utils.data.DataLoader(image_datasets['val'], batch_size=batch_number, shuffle=False)
    }

    model = build_model(MODE, n_classes)
    print(model)
    model = model.to(device)
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters())
    model = train_model(model, criterion, optimizer, num_epochs=epoch_number)