In [None]:
import os
# Restrict this process to a single GPU. This cell runs before the cell that
# imports torch, so the setting is in place before CUDA is initialised.
VISIBLE_GPU_IDS = "0"
os.environ["CUDA_VISIBLE_DEVICES"] = VISIBLE_GPU_IDS

In [None]:
import h5py
from IPython.display import display
import numpy as np
from os.path import join as pj
from os import getcwd as cwd
import pandas as pd
import random
import sys
import torch
import torch.nn as nn
from tqdm import tqdm
import visdom
from PIL import Image

# Logger
from IO.logger import Logger
# Data Augmentation
from dataset.classification.loader import create_validation_split, load_validation_data, create_train_data
# Data Sampling
from dataset.classification.sampler import get_randomsampled_idx, get_randomoversampled_idx
# Model
from model.resnet.resnet import ResNet
from model.resnet.utils import define_weight
from model.resnet.predict import test_classification
from model.resnet.data_augmentation import AutoAugment
from model.optimizer import AdamW
# Evaluation
from evaluation.classification.evaluate import accuracy, confusion_matrix
# Statistics
from evaluation.classification.statistics import compute_each_size_df, compute_all_size_df
# Visualize
from evaluation.classification.visualize import create_confusion_matrix, plot_df_distrib_size

# Train Config

In [None]:
class args:
    # Notebook-wide experiment configuration, used as a plain namespace of
    # class attributes (the class itself is passed around, e.g. Logger(args)).
    # NOTE(review): model_root/figure_root are filed under "ResNet101" while
    # model_name below is "resnet50" - confirm the directory name is intended.
    # experiment_name (also the leaf directory of model_root / figure_root)
    experiment_name = "resnet50_b20_r45_lr1e-5_crossvalid_fastautoaugment"
    # data split
    train_ratio = 0.8 # not referenced by any cell shown in this notebook
    test_ratio = 0.2 # held-out fraction per fold; int(1 / test_ratio) folds are run
    # paths (all relative to the current working directory)
    all_data_path = pj(cwd(), "data/all_classification_data/classify_insect_std")
    model_root = pj(cwd(), "output_model/classification/ResNet101", experiment_name)
    figure_root = pj(cwd(), "figure/classification/ResNet101", experiment_name)
    # train config
    model_name = "resnet50" # choice ["resnet18", "resnet34", "resnet50", "resnet101", "resnet152"]
    bs = 20 # mini-batch size
    lr = 1e-5 # optimizer learning rate
    lamda = 1e-2 # weight of the L2 parameter penalty added in train(); 0 disables it
    nepoch = 40 # training epochs per fold
    rotate = 45 # forwarded to create_train_data (presumably a rotation step in degrees - confirm)
    pretrain = True # forwarded to ResNet(pretrain=...)
    param_freeze = False # forwarded to ResNet(param_freeze=...)
    correction_term = False # if True, train() rescales the class weights after every epoch
    sampling = None # choice [None, "RandomSample", "RandomOverSample"]
    augment = "fastautoaugment" # choice [None, "RandomSizeCrop", "RegionConfusionMechanism", "autoaugment", "fastautoaugment"]
    optimizer = "AdamW" # choice ["Adam", "AdamW"]
    activation_function = "ReLU" # choice ["ReLU", "LeakyReLU", "RReLU"]
    decoder = None # choice [None, "Concatenate", "FPN"]
    # test config
    save_fig = True # forwarded as save=... to the plotting helpers
    save_df = True # write the result tables to CSV under figure_root
    # visdom
    visdom = False # plot per-epoch training curves to a visdom server
    port = 8097 # visdom server port

In [None]:
# Class names for every supported dataset variant, keyed by the last path
# component of args.all_data_path. All variants share the six base classes;
# the "plus_other" variant adds a catch-all class.
_BASE_LABELS = ['Diptera', 'Ephemeridae', 'Ephemeroptera',
                'Lepidoptera', 'Plecoptera', 'Trichoptera']
_LABELS_BY_DATASET = {
    'classify_insect_std': _BASE_LABELS,
    'classify_insect_std_resizeFAR': _BASE_LABELS,
    'classify_insect_std_resize': _BASE_LABELS,
    'classify_insect_std_plus_other': _BASE_LABELS + ['Other'],
}

dataset_name = os.path.basename(args.all_data_path)
if dataset_name not in _LABELS_BY_DATASET:
    # Fail here instead of with an AttributeError on args.labels much later.
    raise ValueError(
        "unknown dataset %r; expected one of %s"
        % (dataset_name, sorted(_LABELS_BY_DATASET))
    )
# Copy so that later edits to args.labels cannot leak into the lookup table.
args.labels = list(_LABELS_BY_DATASET[dataset_name])

# Train

In [None]:
def train(model, xtr, ytr, bs=20, lr=1e-5, nepoch=40, visdom=False):
    """Train `model` in place with class-weighted cross entropy.

    Besides its arguments this function reads notebook-level globals:
    `args` (optimizer, augment, lamda, correction_term), `xte`/`yte` (held-out
    data scored after every epoch) and, when `visdom` is true, `vis` together
    with the `win_*` window handles.

    Args:
        model: network to optimise; its parameters are updated in place.
        xtr: training inputs, indexable by sample.
        ytr: class label per training sample.
        bs: mini-batch size. Only full batches are used; a trailing partial
            batch is skipped.
        lr: learning rate handed to the optimizer.
        nepoch: number of passes over the training set.
        visdom: if true, append per-epoch losses and accuracy to the plots.

    Raises:
        ValueError: if `args.optimizer` is neither "Adam" nor "AdamW".
    """
    # Per-class sample counts feed the class weights; their total is the
    # number of training samples.
    _, counts = np.unique(ytr, return_counts=True)
    counts_sum = int(counts.sum())

    # define weight and create loss function
    ce = torch.nn.CrossEntropyLoss(define_weight(counts))

    # define optimizer
    if args.optimizer == "Adam":
        opt = torch.optim.Adam(model.parameters(), lr=lr)
    elif args.optimizer == "AdamW":
        opt = AdamW(model.parameters(), lr=lr)
    else:
        raise ValueError(
            'args.optimizer must be "Adam" or "AdamW", got %r' % (args.optimizer,)
        )

    # define autoaugment
    use_autoaugment = args.augment in ("autoaugment", "fastautoaugment")
    if args.augment == "autoaugment":
        autoaugment = AutoAugment()
    elif args.augment == "fastautoaugment":
        autoaugment = AutoAugment(policy_dir=pj(cwd(), "model/resnet"))

    # set model train mode
    model.train()

    t = tqdm(range(nepoch), leave=False)
    # training
    for epoch in t:
        sum_cls_loss = 0
        sum_norm_loss = 0
        total_loss = 0
        # Fresh random permutation of the sample indices for this epoch.
        index = random.sample(range(counts_sum), counts_sum)
        t.set_description("epoch=%s" % (epoch))
        # The "+ 1" keeps the last full batch when counts_sum is an exact
        # multiple of bs; a trailing partial batch is still skipped.
        for start in range(0, counts_sum - bs + 1, bs):
            idx = index[start:start + bs]
            if use_autoaugment:
                # Augment each sample as a PIL image, then stack into a batch.
                x = np.asarray([np.asarray(autoaugment(Image.fromarray(xtr[i].astype("uint8")))) for i in idx])
                y = ytr[idx]
                # NOTE(review): transpose(1, -1) swaps axis 1 with the last
                # axis, so a (N, H, W, C) batch becomes (N, C, W, H) - height
                # and width end up swapped. Confirm this matches the layout of
                # the evaluation data, or use permute(0, 3, 1, 2).
                x = torch.from_numpy(x).transpose(1, -1).float().cuda()
                y = torch.from_numpy(y).cuda()
            else:
                x = xtr[idx].cuda()
                y = ytr[idx].cuda()
            opt.zero_grad()
            out = model(x)
            cls_loss = ce(out, y)
            if args.lamda != 0:
                # L2 penalty: sum over parameter tensors of their mean squared
                # value (equal to MSE against zeros), scaled by lamda.
                norm_loss = args.lamda * sum(
                    param.pow(2).mean() for param in model.parameters()
                )
                norm_value = norm_loss.item()
                loss = cls_loss + norm_loss
            else:
                # Penalty disabled: keep a plain float for logging so that no
                # .item() is called on a Python number.
                norm_value = 0.0
                loss = cls_loss
            sum_cls_loss += cls_loss.item()
            sum_norm_loss += norm_value
            total_loss += loss.item()
            loss.backward()
            opt.step()
            sys.stdout.write("\rcls_loss=%f, norm_loss=%f" % (cls_loss.item(), norm_value))
            sys.stdout.flush()

        # NOTE(review): assigning model.training only flips the flag on the
        # top-level module; submodules such as BatchNorm/Dropout keep their
        # mode unless ResNet or accuracy() handle it. Consider model.eval() /
        # model.train() once that has been checked against model.resnet.
        model.training = False
        if args.correction_term:
            te_acc, correction_term = accuracy(model, xte, yte, bs, return_correction_term=True, low_trainable_correction=True)
            # Re-weight the classes for the next epoch with the returned term.
            ce = torch.nn.CrossEntropyLoss(define_weight(counts) * correction_term)
            print(correction_term)
        else:
            te_acc = accuracy(model, xte, yte, bs)
        model.training = True
        if visdom:
            visualize(vis, epoch+1, sum_cls_loss, win_cls_loss)
            visualize(vis, epoch+1, sum_norm_loss, win_norm_loss)
            visualize(vis, epoch+1, total_loss, win_train_loss)
            visualize(vis, epoch+1, te_acc, win_test_acc)
        print("sum_cls_loss=%f, sum_norm_loss=%f, total_loss=%f, te_acc=%f" % (sum_cls_loss, sum_norm_loss, total_loss, te_acc))

# Set Visdom

In [None]:
def _create_line_window(vis, title, ylabel):
    """Open a visdom line plot seeded with the point (0, 0) and return its window handle."""
    return vis.line(
        X=np.array([0]),
        Y=np.array([0]),
        opts=dict(
            title=title,
            xlabel='epoch',
            ylabel=ylabel,
            width=800,
            height=400
        )
    )


if args.visdom:
    # Create visdom client and one window per tracked curve; train() appends
    # to these windows through the module-level names below.
    vis = visdom.Visdom(port=args.port)

    win_cls_loss = _create_line_window(vis, 'cls_loss', 'loss')
    win_norm_loss = _create_line_window(vis, 'norm_loss', 'loss')
    win_train_loss = _create_line_window(vis, 'train_loss', 'loss')
    # This curve shows accuracy, so label its y-axis accordingly.
    win_test_acc = _create_line_window(vis, 'test_accuracy', 'accuracy')

In [None]:
def visualize(vis, phase, visualized_data, window):
    """Append the single point (phase, visualized_data) to an existing line plot.

    Args:
        vis: object exposing a visdom-style `line(...)` method.
        phase: x coordinate of the new point (train() passes the epoch number).
        visualized_data: y coordinate of the new point.
        window: handle of the window to append to.
    """
    x_point = np.array([phase])
    y_point = np.array([visualized_data])
    vis.line(X=x_point, Y=y_point, win=window, update='append')

### Save args

In [None]:
# Hand the configuration class to the project Logger and save it
# (presumably a record of the settings in `args`; output location and format
# are defined in IO.logger - confirm there).
args_logger = Logger(args)
args_logger.save()

# Cross Validation

In [None]:
# Checkpoint file written after each fold, plus the output directories.
model_save_path = pj(args.model_root, "final.pth")
# exist_ok=True removes the check-then-create race and keeps the cell
# re-runnable when the directories already exist.
os.makedirs(args.model_root, exist_ok=True)
os.makedirs(args.figure_root, exist_ok=True)

In [None]:
def adopt_sampling(sampling, Y, idx):
    """Return the training indices to use after the chosen resampling strategy.

    Args:
        sampling: "RandomSample", "RandomOverSample", or anything else for no
            resampling.
        Y: labels passed through to the sampler.
        idx: candidate training indices.

    Returns:
        The sampler's output for the two named strategies; otherwise `idx`
        itself, unchanged.
    """
    if sampling == "RandomSample":
        print("sampling = RandomSample")
        return get_randomsampled_idx(Y, idx)
    if sampling == "RandomOverSample":
        print("sampling == RandomOverSample")
        return get_randomoversampled_idx(Y, idx)
    # Any other value (including None) keeps the indices as given.
    print("sampling = None")
    return idx

In [None]:
# Number of folds implied by the held-out fraction (0.2 -> 5 folds).
valid_num = int(1.0/args.test_ratio)
# Open read-only explicitly: older h5py versions default to an append/create
# mode, which could create or modify the dataset file by accident.
with h5py.File(args.all_data_path, "r") as f:
    X = f["X"][:]
    Y = f["Y"][:]
# Per-class sample counts over the whole dataset, used when plotting.
_, ntests = np.unique(Y, return_counts=True)
train_idxs, test_idxs = create_validation_split(Y, args.test_ratio)
result = []
x_folds = []
y_folds = []
validation_matrix = None
for valid_count in range(valid_num):
    new_train_idx = adopt_sampling(args.sampling, Y, train_idxs[valid_count])
    # xte / yte are also read as globals by train() for its per-epoch accuracy.
    xtr, ytr, xte, yte = load_validation_data(X, Y, new_train_idx, test_idxs[valid_count])
    xtr, ytr = create_train_data(xtr, ytr, args.rotate, args.augment)

    # A fresh model is trained for every fold.
    model = ResNet(args.model_name, len(args.labels), pretrain=args.pretrain, training=True, param_freeze=args.param_freeze, activation_function=args.activation_function, decoder=args.decoder).cuda()
    train(model, xtr, ytr, bs=args.bs, lr=args.lr, nepoch=args.nepoch, visdom=args.visdom)
    # NOTE(review): every fold writes to the same path, so only the last
    # fold's weights survive - confirm that is intended.
    torch.save(model.state_dict(), model_save_path)

    # Both evaluation passes run while the model is flagged as not training;
    # the predictions used for the size statistics were previously collected
    # after the flag had been switched back to True.
    # NOTE(review): this only flips the top-level module's flag; consider
    # model.eval() / model.train() after checking model.resnet.
    model.training = False
    matrix = confusion_matrix(model, xte, yte, args.labels, bs=args.bs)
    result.extend(test_classification(model, xte))
    model.training = True
    df = pd.DataFrame(matrix)
    display(df)
    # Accumulate into a new object so the per-fold matrix is never mutated.
    if validation_matrix is None:
        validation_matrix = matrix
    else:
        validation_matrix = validation_matrix + matrix
    x_folds.append(xte.cpu().numpy())
    y_folds.append(yte.cpu().numpy())

# Concatenate once instead of re-copying all earlier folds on every iteration.
x_all = np.concatenate(x_folds)
y_all = np.concatenate(y_folds)

In [None]:
# Confusion matrix accumulated over all folds, as a table; written to CSV
# under figure_root when save_df is enabled, then shown as the cell output.
df = pd.DataFrame(validation_matrix)
if args.save_df is True:
    validation_matrix_csv = pj(args.figure_root, "validation_matrix.csv")
    df.to_csv(validation_matrix_csv)
df

In [None]:
# Draw the accumulated confusion matrix with the project helper. `ntests`
# holds the per-class sample counts of the full dataset; figure_root and
# save_fig are forwarded (presumably to store the figure - see
# evaluation.classification.visualize).
create_confusion_matrix(validation_matrix, ntests, args.labels, args.figure_root, save=args.save_fig)

In [None]:
# Statistics table built by the project helper from the collected
# predictions and the held-out inputs/labels of every fold; optionally
# written to CSV, then shown as the cell output.
each_df = compute_each_size_df(result, x_all, y_all)
if args.save_df is True:
    each_size_csv = pj(args.figure_root, "each_size_df.csv")
    each_df.to_csv(each_size_csv)
each_df

In [None]:
# Table derived from each_df by the project helper; optionally written to
# CSV, then shown as the cell output.
all_df = compute_all_size_df(each_df)
if args.save_df is True:
    all_size_csv = pj(args.figure_root, "all_size_df.csv")
    all_df.to_csv(all_size_csv)
all_df

In [None]:
# Plot all_df with the project helper; figure_root and save_fig are forwarded
# (presumably to store the figure - see evaluation.classification.visualize).
plot_df_distrib_size(all_df, args.figure_root, save=args.save_fig)