In [None]:
import os
# Expose only GPU 0 to this process; set before torch is imported so CUDA sees it.
os.environ.update({"CUDA_VISIBLE_DEVICES": "0"})

In [None]:
import h5py
from IPython.display import display
import numpy as np
from os.path import join as pj
from os import getcwd as cwd
import pandas as pd
import random
import sys
import torch
import torch.nn as nn
import torch.utils.data as data
from tqdm import tqdm
import visdom
from PIL import Image

# Logger
from IO.logger import Logger
# Loader
from dataset.classification.loader import create_validation_split, load_validation_data
# Data Sampling
from dataset.classification.sampler import adopt_sampling
# Dataset
from dataset.classification.dataset import insects_dataset
# Model
from model.resnet.resnet import ResNet
from model.resnet.utils import define_weight
from model.resnet.predict import test_classification
from model.optimizer import AdamW
# Evaluation
from evaluation.classification.evaluate import accuracy, confusion_matrix
# Statistics
from evaluation.classification.statistics import compute_each_size_df, compute_all_size_df
# Visualize
from evaluation.classification.visualize import create_confusion_matrix, plot_df_distrib_size

# Train Config

In [None]:
class args:
    # Experiment configuration, read as class attributes (args.bs, args.lr, ...)
    # throughout the notebook and persisted by the "Save args" cell via Logger.
    # Attribute order is kept unchanged in case Logger serialises it in order.

    # experiment_name (also the leaf directory of model_root / figure_root)
    experiment_name = "resnet50_b20_r45_lr1e-5_crossvalid_20200806_Rotate"
    # data split
    train_ratio = 0.8 # unused parameter
    # fraction of data per test fold; the number of folds is int(1.0 / test_ratio)
    test_ratio = 0.2
    # paths
    all_data_path = pj(cwd(), "data/all_classification_data/classify_insect_std_20200806")
    # NOTE(review): output folders say "ResNet101" while model_name is "resnet50";
    # kept as-is so existing output locations do not move — confirm intent.
    model_root = pj(cwd(), "output_model/classification/ResNet101", experiment_name)
    figure_root = pj(cwd(), "figure/classification/ResNet101", experiment_name)
    # train config
    model_name = "resnet50" # choice ["resnet18", "resnet34", "resnet50", "resnet101", "resnet152"]
    # batch size; NOTE: the DataLoaders also use it as num_workers
    bs = 20
    lr = 1e-5
    # weight of the explicit L2 parameter penalty added in train(); 0 disables it
    lamda = 0
    nepoch = 100
    pretrain = True
    param_freeze = False
    sampling = None # choice [None, "RandomSample", "OverSample"]
    # augmentation methods passed to insects_dataset for the training split
    method_aug = ["Rotate"]
    optimizer = "AdamW" # choice ["Adam", "AdamW"]
    activation_function = "ReLU" # choice ["ReLU", "LeakyReLU", "RReLU"]
    decoder = None # choice [None, "Concatenate", "FPN"]
    # test config
    save_fig = True
    save_df = True
    # visdom: enable live plotting and the port of the visdom server
    visdom = True
    port = 8099

In [None]:
# Class labels for every known dataset directory. All variants share the six
# insect orders; the "_plus_other" variant adds an extra "Other" class.
_BASE_LABELS = ['Diptera', 'Ephemeridae', 'Ephemeroptera',
                'Lepidoptera', 'Plecoptera', 'Trichoptera']
_LABELS_BY_DATASET = {
    'classify_insect_std': _BASE_LABELS,
    'classify_insect_std_resizeFAR': _BASE_LABELS,
    'classify_insect_std_resize': _BASE_LABELS,
    'classify_insect_std_plus_other': _BASE_LABELS + ['Other'],
    'classify_insect_std_20200806': _BASE_LABELS,
}

dataset_name = os.path.basename(args.all_data_path)
# Fail here with a clear message instead of leaving args.labels unset, which
# would otherwise only surface later as an AttributeError when building the model.
if dataset_name not in _LABELS_BY_DATASET:
    raise ValueError("unknown dataset %r; expected one of %s"
                     % (dataset_name, sorted(_LABELS_BY_DATASET)))
# Copy so later mutation of args.labels cannot alter the shared table.
args.labels = list(_LABELS_BY_DATASET[dataset_name])

# Train

In [None]:
def train(model, counts, train_dataloader, valid_dataloader, test_dataloader, lr=1e-5, nepoch=40, visdom=False):
    """Train ``model`` for ``nepoch`` epochs and report losses/accuracies each epoch.

    Args:
        model: network to optimise; batches are moved with ``.cuda()``, so the
            model is expected to already live on the GPU.
        counts: per-class sample counts of the training split, turned into
            CrossEntropyLoss class weights by ``define_weight``.
        train_dataloader: yields ``(x, y)`` batches used for optimisation.
        valid_dataloader: scored with ``accuracy`` after every epoch; printed as
            ``train_acc`` and plotted in the "train_accuracy" visdom window.
        test_dataloader: scored with ``accuracy`` after every epoch.
        lr: learning rate for the optimizer.
        nepoch: number of epochs.
        visdom: if True, append per-epoch metrics to the module-level visdom
            windows (``vis``, ``win_*``) via ``visualize``.

    Also reads the module-level ``args`` for ``optimizer`` and ``lamda``.

    Raises:
        ValueError: if ``args.optimizer`` is neither "Adam" nor "AdamW".
    """
    # class-weighted classification loss
    ce = torch.nn.CrossEntropyLoss(define_weight(counts)).cuda()
    # 'elementwise_mean' was deprecated and later removed from PyTorch;
    # 'mean' is the equivalent reduction.
    l2_loss = nn.MSELoss(reduction='mean').cuda()

    # define optimizer
    if args.optimizer == "Adam":
        opt = torch.optim.Adam(model.parameters(), lr=lr)
    elif args.optimizer == "AdamW":
        opt = AdamW(model.parameters(), lr=lr)
    else:
        # previously fell through and crashed later with an unbound `opt`
        raise ValueError("unknown optimizer %r (expected 'Adam' or 'AdamW')" % (args.optimizer,))

    # set model train mode
    model.train()

    epoch_tqdm = tqdm(range(nepoch), leave=False)
    for epoch in epoch_tqdm:
        sum_cls_loss = 0
        sum_norm_loss = 0
        total_loss = 0
        epoch_tqdm.set_description("epoch=%s" % (epoch))
        for x, y in train_dataloader:
            x = x.cuda()
            y = y.cuda()
            opt.zero_grad()
            out = model(x)
            cls_loss = ce(out, y)
            sum_cls_loss += cls_loss.item()
            if args.lamda != 0:
                # explicit L2 penalty: sum over parameters of mean(param ** 2), scaled by lamda
                norm_loss = 0
                for param in model.parameters():
                    norm_loss += l2_loss(param, torch.zeros_like(param))
                norm_loss = norm_loss * args.lamda
                sum_norm_loss += norm_loss.item()
                loss = cls_loss + norm_loss
            else:
                loss = cls_loss

            total_loss += loss.item()
            loss.backward()
            opt.step()
            if args.lamda != 0:
                sys.stdout.write("\rcls_loss=%f, norm_loss=%f" % (cls_loss.item(), norm_loss.item()))
            else:
                sys.stdout.write("\rcls_loss=%f" % (cls_loss.item()))
            sys.stdout.flush()

        # eval()/train() propagate to every submodule; assigning model.training
        # only changed the root flag and left BatchNorm/Dropout in train mode.
        model.eval()
        valid_acc = accuracy(model, valid_dataloader)
        te_acc = accuracy(model, test_dataloader)
        model.train()
        if visdom:
            visualize(vis, epoch+1, sum_cls_loss, win_cls_loss)
            visualize(vis, epoch+1, sum_norm_loss, win_norm_loss)
            visualize(vis, epoch+1, total_loss, win_train_loss)
            visualize(vis, epoch+1, te_acc, win_test_acc)
            visualize(vis, epoch+1, valid_acc, win_train_acc)
        print("sum_cls_loss=%f, sum_norm_loss=%f, total_loss=%f, train_acc=%f, te_acc=%f" % (sum_cls_loss, sum_norm_loss, total_loss, valid_acc, te_acc))

# Set Up Visdom

In [None]:
def _create_line_window(vis, title, ylabel):
    """Create an 800x400 visdom line plot seeded with the point (0, 0).

    Returns whatever ``vis.line`` returns; it is later passed as ``window``
    to ``visualize`` to append points to this plot.
    """
    return vis.line(
        X=np.array([0]),
        Y=np.array([0]),
        opts=dict(
            title=title,
            xlabel='epoch',
            ylabel=ylabel,
            width=800,
            height=400
        )
    )


if args.visdom:
    # Create visdom client and one window per tracked metric.
    vis = visdom.Visdom(port=args.port)

    win_cls_loss = _create_line_window(vis, 'cls_loss', 'loss')
    win_norm_loss = _create_line_window(vis, 'norm_loss', 'loss')
    win_train_loss = _create_line_window(vis, 'train_loss', 'loss')
    # accuracy windows were previously mislabelled with ylabel='loss'
    win_train_acc = _create_line_window(vis, 'train_accuracy', 'accuracy')
    win_test_acc = _create_line_window(vis, 'test_accuracy', 'accuracy')

In [None]:
def visualize(vis, phase, visualized_data, window):
    """Append a single (phase, value) point to an existing visdom line window.

    Args:
        vis: visdom client exposing ``line``.
        phase: x coordinate of the new point (the epoch number in ``train``).
        visualized_data: y value of the new point.
        window: window handle returned by an earlier ``vis.line`` call.
    """
    point_kwargs = {
        "X": np.array([phase]),
        "Y": np.array([visualized_data]),
        "update": 'append',
        "win": window,
    }
    vis.line(**point_kwargs)

### Save args

In [None]:
# Persist the experiment configuration held in `args` with the project Logger.
# NOTE(review): where Logger.save() writes is not visible here; if it targets
# args.model_root, confirm it does not rely on the directories created in the
# next cell already existing.
args_logger = Logger(args)
args_logger.save()

# Cross Validation

In [None]:
# Weights file written by the cross-validation loop after each fold.
model_save_path = pj(args.model_root, "final.pth")
# exist_ok=True keeps the cell idempotent on re-run and avoids the
# check-then-create race of testing os.path.exists first.
os.makedirs(args.model_root, exist_ok=True)
os.makedirs(args.figure_root, exist_ok=True)

In [None]:
def _make_loader(dataset, shuffle):
    """Wrap ``dataset`` in a DataLoader with the notebook's batch/worker settings."""
    # NOTE(review): num_workers is tied to the batch size (args.bs), as before;
    # confirm this is intended rather than a separate worker-count setting.
    return data.DataLoader(dataset, args.bs, num_workers=args.bs, shuffle=shuffle)


# One fold per test_ratio slice (e.g. 0.2 -> 5 folds).
valid_num = int(1.0/args.test_ratio)
# Open read-only: older h5py versions default to append mode, which can create
# or modify the data file.
with h5py.File(args.all_data_path, "r") as f:
    X = f["X"][:]
    Y = f["Y"][:]
# Per-class sample counts over the whole dataset; used when plotting the
# fold-summed confusion matrix.
_, ntests = np.unique(Y, return_counts=True)
train_idxs, test_idxs = create_validation_split(Y, args.test_ratio)
result = []
fail_count = np.zeros(Y.shape[0], dtype="int")
validation_matrix = None
x_all_parts = []
y_all_parts = []

# The full-dataset loader does not depend on the fold, so build it once.
all_dataset = insects_dataset(X, Y, training=False)
all_dataloader = _make_loader(all_dataset, shuffle=False)

for valid_count in range(valid_num):
    # create validation data
    valid_train_idx = adopt_sampling(Y, train_idxs[valid_count], args.sampling)
    valid_test_idx = test_idxs[valid_count]
    xtr, ytr, xte, yte = load_validation_data(X, Y, valid_train_idx, valid_test_idx)
    _, counts = np.unique(ytr, return_counts=True)
    # create dataloaders; the "valid" loader is the training split without augmentation
    train_dataset = insects_dataset(xtr, ytr, training=True, method_aug=args.method_aug)
    train_dataloader = _make_loader(train_dataset, shuffle=True)
    valid_dataset = insects_dataset(xtr, ytr, training=False)
    valid_dataloader = _make_loader(valid_dataset, shuffle=False)
    test_dataset = insects_dataset(xte, yte, training=False)
    test_dataloader = _make_loader(test_dataset, shuffle=False)

    # create model
    model = ResNet(args.model_name, len(args.labels), pretrain=args.pretrain, training=True, param_freeze=args.param_freeze, activation_function=args.activation_function, decoder=args.decoder).cuda()

    # training
    train(model, counts, train_dataloader, valid_dataloader, test_dataloader, lr=args.lr, nepoch=args.nepoch, visdom=args.visdom)
    # NOTE(review): every fold writes to the same path, so only the last fold's
    # weights survive.
    torch.save(model.state_dict(), model_save_path)

    # make result; eval()/train() switch all submodules (BatchNorm, Dropout),
    # unlike assigning model.training on the root module only.
    model.eval()
    matrix = confusion_matrix(model, test_dataloader, args.labels)
    result.extend(test_classification(model, test_dataloader))
    _, correct = accuracy(model, all_dataloader, return_correct=True)
    model.train()
    # np.logical_not is right for boolean and 0/1 integer masks alike;
    # `~` on an integer mask would add -1/-2 instead of 1/0.
    fail_count += np.logical_not(correct)

    # other
    df = pd.DataFrame(matrix)
    display(df)
    # Non-in-place sum so the first fold's matrix object is not mutated.
    validation_matrix = matrix if validation_matrix is None else validation_matrix + matrix
    x_all_parts.append(xte)
    y_all_parts.append(yte)

# Test splits of all folds, in fold order, concatenated once (aligned with `result`).
x_all = np.concatenate(x_all_parts)
y_all = np.concatenate(y_all_parts)

In [None]:
# Confusion matrix summed over all cross-validation folds; optionally saved as CSV.
df = pd.DataFrame(validation_matrix)
if args.save_df:
    df.to_csv(pj(args.figure_root, "validation_matrix.csv"))
df

In [None]:
# Plot the fold-summed confusion matrix. ntests holds the per-class sample
# counts over the whole dataset (computed in the cross-validation cell).
# NOTE(review): ntests is presumably used to normalise rows — confirm in
# create_confusion_matrix; saving is controlled by args.save_fig.
create_confusion_matrix(validation_matrix, ntests, args.labels, args.figure_root, save=args.save_fig)

In [None]:
# Statistics table built from the accumulated test predictions (`result`) and
# the concatenated test images/labels (`x_all`, `y_all`); see
# compute_each_size_df for the exact columns. Optionally saved as CSV.
each_df = compute_each_size_df(result, x_all, y_all)
if args.save_df:
    each_df.to_csv(pj(args.figure_root, "each_size_df.csv"))
each_df

In [None]:
# Aggregate view of each_df produced by compute_all_size_df; optionally saved as CSV.
all_df = compute_all_size_df(each_df)
if args.save_df:
    all_df.to_csv(pj(args.figure_root, "all_size_df.csv"))
all_df

In [None]:
# For every sample of the full dataset (training and test samples alike), the
# number of folds whose model did not mark it correct; optionally saved as CSV.
fail_df = pd.DataFrame({"fail_count": fail_count})
if args.save_df:
    fail_df.to_csv(pj(args.figure_root, "fail_count.csv"))
fail_df

In [None]:
# Plot the distribution of the aggregated size statistics in all_df.
# NOTE(review): presumably written under args.figure_root when args.save_fig
# is True — confirm in plot_df_distrib_size.
plot_df_distrib_size(all_df, args.figure_root, save=args.save_fig)