In [None]:
import os
# Expose only GPU 0 to CUDA. This cell runs before the cell that imports
# torch, so the restriction is in place before any CUDA work starts.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

In [None]:
import h5py
from IPython.display import display
import numpy as np
from os.path import join as pj
import pandas as pd
import random
import sys
import torch
from tqdm import tqdm
import visdom

# Logger
from IO.logger import Logger
# Data Augmentation
from dataset.classification.loader import create_validation_split, load_validation_data, create_train_data, create_train_data_DCL
# RCM
from dataset.classification.region_confusion_mechanism import region_confusion_mechanism
# Model
from model.resnet.utils import define_weight
from model.resnet.predict import test_classification
# Evaluation
from evaluation.classification.evaluate import accuracy, confusion_matrix
# Statistics
from evaluation.classification.statistics import compute_each_size_df, compute_all_size_df
# Visualize
from evaluation.classification.visualize import create_confusion_matrix, plot_df_distrib_size

# Train Config

In [None]:
# Notebook-wide configuration. The class is used as a plain namespace: it is
# never instantiated, its attributes are read as `args.<name>`, and the class
# object itself is handed to Logger(args).
class args:
    # Experiment name; appended to model_root and figure_root below.
    experiment_name = "resnet50_b20_r45_lr1e-5_crossvalid_resize_other_without_grouping"
    # data split
    # train_ratio is not referenced by any visible cell of this notebook.
    train_ratio = 0.8 # unused parameters
    # test_ratio is passed to create_validation_split and also fixes the number
    # of folds as int(1.0 / test_ratio).
    test_ratio = 0.2
    # paths
    # HDF5 file from which the datasets "X" and "Y" are read.
    all_data_path = "/home/tanida/workspace/Insect_Phenology_Detector/data/all_classification_data/classify_insect_std_resize_aquatic_other_without_grouping"
    # NOTE(review): both output roots live under ".../ResNet101" while
    # model_name below is "ResNet50" — confirm this is intentional.
    model_root = pj("/home/tanida/workspace/Insect_Phenology_Detector/output_model/classification/ResNet101", experiment_name)
    figure_root = pj("/home/tanida/workspace/Insect_Phenology_Detector/figure/classification/ResNet101", experiment_name)
    # Class names. len(labels) is passed to the model constructor as its first
    # argument, and the list is passed to the confusion-matrix helpers.
    # Alternative label sets kept for reference:
    #labels =  ['Aquatic_insects', 'Other_insects']
    #labels =  ['Diptera', 'Ephemeridae', 'Ephemeroptera', 
    #           'Lepidoptera', 'Plecoptera', 'Trichoptera']
    labels =  ['Diptera', 'Ephemeridae', 'Ephemeroptera', 
               'Lepidoptera', 'Plecoptera', 'Trichoptera', 
               'Coleoptera', 'Hemiptera', 'medium insect', 'small insect']
    # train config
    # Selects which ResNet class is imported and instantiated.
    model_name = "ResNet50" # choice ["ResNet18", "ResNet34", "ResNet50", "ResNet101"]
    # When True, training goes through region_confusion_mechanism and train_DCL
    # instead of train.
    use_DCL = False
    # Passed to region_confusion_mechanism; division_number also goes to the
    # model constructor and sizes the coordinate target (division_number**2).
    division_number = 7
    neighborhood_range = 1
    # Multipliers applied to the three loss terms inside train_DCL.
    cls_weight = 1
    dest_weight = 10
    coord_weight = 0.1
    # Batch size, Adam learning rate and number of epochs.
    bs = 20
    lr = 1e-5
    nepoch = 40
    # Passed to create_train_data / create_train_data_DCL.
    # presumably a rotation angle for augmentation — TODO confirm in dataset.classification.loader
    rotate = 45
    # Passed to the model constructor as `pretrain`.
    pretrain = True
    # test config
    # Whether figures / DataFrames are written under figure_root.
    save_fig = True
    save_df = True
    # visdom
    # Whether to create visdom windows and plot per-epoch curves, and the port
    # given to visdom.Visdom.
    visdom = True
    port = 8097

# Load Model

In [None]:
# Model
# Import only the ResNet variant named by args.model_name. The cross-validation
# cell later instantiates the class under the same name.
if args.model_name == "ResNet18":
    from model.resnet.resnet18 import ResNet18
elif args.model_name == "ResNet34":
    from model.resnet.resnet34 import ResNet34
elif args.model_name == "ResNet50":
    from model.resnet.resnet50 import ResNet50
elif args.model_name == "ResNet101":
    from model.resnet.resnet101 import ResNet101
else:
    # Fail fast: merely printing here lets the notebook run on until a later
    # cell dies with an unrelated NameError.
    raise ValueError(
        "error! write correct model name! got %r, expected one of "
        "ResNet18, ResNet34, ResNet50, ResNet101" % (args.model_name,)
    )

# Train

In [None]:
def train(model, xtr, ytr, bs=20, lr=1e-5, nepoch=40, visdom=False):
    """Train `model` on (xtr, ytr) with class-weighted cross entropy and Adam.

    After every epoch the model is evaluated with `accuracy` on the notebook
    globals `xte` / `yte`, and the summed training loss and the test accuracy
    are written to stdout (and to visdom when requested).

    Args:
        model: network to optimise; called as `model(x)` and expected to
            return the class scores fed to CrossEntropyLoss.
        xtr: training inputs, indexed with a list of sample indices and moved
            to the GPU with `.cuda()` (presumably a torch tensor — TODO confirm).
        ytr: training labels, indexed the same way; also used to count
            samples per class.
        bs: mini-batch size.
        lr: Adam learning rate.
        nepoch: number of epochs.
        visdom: when truthy, plot the curves through the notebook globals
            `vis`, `win_train_loss` and `win_test_acc`. Note that this
            parameter shadows the `visdom` module inside the function.

    Returns:
        None. The model is updated in place and left in train mode.
    """
    # per-class sample counts and the total number of training samples
    _, counts = np.unique(ytr, return_counts=True)
    counts_sum = int(counts.sum())

    # class weights come from the project's define_weight helper
    ce = torch.nn.CrossEntropyLoss(define_weight(counts))

    # define optimizer
    opt = torch.optim.Adam(model.parameters(), lr=lr)

    # set model train mode
    model.train()

    t = tqdm(range(nepoch), leave=False)
    # training
    for epoch in t:
        # NOTE: this is the sum of the batch losses, not their mean
        total_loss = 0
        # fresh random permutation of the sample indices for every epoch
        index = random.sample(range(counts_sum), counts_sum)
        t.set_description("epoch=%s" % (epoch))
        # Only full batches are used. The "+ 1" keeps the last full batch,
        # which the previous bound (counts_sum - bs) skipped whenever
        # counts_sum was an exact multiple of bs.
        for start in range(0, counts_sum - bs + 1, bs):
            idx = index[start:start + bs]
            x = xtr[idx].cuda()
            y = ytr[idx].cuda()
            opt.zero_grad()
            out = model(x)
            loss = ce(out, y)
            total_loss += loss.item()
            loss.backward()
            opt.step()

        # eval() clears `training` on the model AND on every submodule;
        # assigning model.training = False only changed the top-level flag.
        model.eval()
        te_acc = accuracy(model, xte, yte, bs)
        model.train()
        if visdom:
            visualize(vis, epoch+1, total_loss, win_train_loss)
            visualize(vis, epoch+1, te_acc, win_test_acc)
        sys.stdout.write("\rtotal_loss=%f, te_acc=%f" % (total_loss,te_acc))
        sys.stdout.flush()
        

def train_DCL(model, xtr, ytr, target_dest_or_not, target_coordinate, bs=20, lr=1e-5, nepoch=40, visdom=False, cls_weight=1, dest_weight=10, coord_weight=0.1):
    """Train `model` with the three-term DCL objective and Adam.

    The model is called as `model(x)` and must return three outputs:
    class scores, predicted locations and destructed-or-not scores. The loss
    is `cls_weight * CE(class) + dest_weight * CE(dest) + coord_weight * L1(coord)`,
    where only the classification cross entropy is class-weighted.

    After every epoch the model is evaluated with `accuracy` on the notebook
    globals `xte` / `yte`, and the summed losses and the test accuracy are
    written to stdout (and to visdom when requested).

    Args:
        model: network to optimise.
        xtr: training inputs, indexed with a list of sample indices and moved
            to the GPU with `.cuda()` (presumably a torch tensor — TODO confirm).
        ytr: class labels, indexed the same way; also used to count samples
            per class.
        target_dest_or_not: per-sample target for the destructed-or-not head.
        target_coordinate: per-sample target for the location head.
        bs: mini-batch size.
        lr: Adam learning rate.
        nepoch: number of epochs.
        visdom: when truthy, plot the curves through the notebook globals
            `vis`, `win_cls_loss`, `win_dest_loss`, `win_coord_loss`,
            `win_train_loss` and `win_test_acc`. This parameter shadows the
            `visdom` module inside the function.
        cls_weight: multiplier of the classification loss.
        dest_weight: multiplier of the destructed-or-not loss.
        coord_weight: multiplier of the coordinate loss.

    Returns:
        None. The model is updated in place and left in train mode.
    """
    # per-class sample counts and the total number of training samples
    _, counts = np.unique(ytr, return_counts=True)
    counts_sum = int(counts.sum())

    # loss functions; only the classification loss uses class weights
    cls_ce = torch.nn.CrossEntropyLoss(define_weight(counts))
    dest_ce = torch.nn.CrossEntropyLoss()
    coord_l1 = torch.nn.L1Loss()

    # define optimizer
    opt = torch.optim.Adam(model.parameters(), lr=lr)

    # set model train mode
    model.train()

    t = tqdm(range(nepoch), leave=False)
    # training
    for epoch in t:
        # NOTE: these are sums of the batch losses, not means
        sum_cls_loss = 0
        sum_dest_loss = 0
        sum_coord_loss = 0
        total_loss = 0
        # fresh random permutation of the sample indices for every epoch
        index = random.sample(range(counts_sum), counts_sum)
        t.set_description("epoch=%s" % (epoch))
        # Only full batches are used. The "+ 1" keeps the last full batch,
        # which the previous bound (counts_sum - bs) skipped whenever
        # counts_sum was an exact multiple of bs.
        for start in range(0, counts_sum - bs + 1, bs):
            idx = index[start:start + bs]
            x = xtr[idx].cuda()
            y = ytr[idx].cuda()
            sample_dest_or_not = target_dest_or_not[idx].cuda()
            sample_coordinate = target_coordinate[idx].cuda()
            opt.zero_grad()
            out, predict_loc, dest_or_not = model(x)
            cls_loss = cls_ce(out, y) * cls_weight
            dest_loss = dest_ce(dest_or_not, sample_dest_or_not) * dest_weight
            coord_loss = coord_l1(predict_loc, sample_coordinate) * coord_weight
            loss = cls_loss + dest_loss + coord_loss
            sum_cls_loss += cls_loss.item()
            sum_dest_loss += dest_loss.item()
            sum_coord_loss += coord_loss.item()
            total_loss += loss.item()
            loss.backward()
            opt.step()

        # eval() clears `training` on the model AND on every submodule;
        # assigning model.training = False only changed the top-level flag.
        model.eval()
        te_acc = accuracy(model, xte, yte, bs)
        model.train()
        if visdom:
            visualize(vis, epoch+1, sum_cls_loss, win_cls_loss)
            visualize(vis, epoch+1, sum_dest_loss, win_dest_loss)
            visualize(vis, epoch+1, sum_coord_loss, win_coord_loss)
            visualize(vis, epoch+1, total_loss, win_train_loss)
            visualize(vis, epoch+1, te_acc, win_test_acc)
        sys.stdout.write("\rcls_loss=%f, dest_loss=%f, coord_loss=%f, train_loss=%f, te_acc=%f" % (sum_cls_loss, sum_dest_loss, sum_coord_loss, total_loss, te_acc))
        sys.stdout.flush()

# Set Visdom

In [None]:
def _new_line_window(vis, title, ylabel):
    """Create an 800x400 visdom line plot seeded with the point (0, 0).

    Args:
        vis: visdom.Visdom client.
        title: window title.
        ylabel: label of the y axis; the x axis is always 'epoch'.

    Returns:
        Whatever `vis.line` returns; it is later passed back as `win=` to
        append points to the same plot.
    """
    return vis.line(
        X=np.array([0]),
        Y=np.array([0]),
        opts=dict(
            title=title,
            xlabel='epoch',
            ylabel=ylabel,
            width=800,
            height=400
        )
    )


if args.visdom:
    # Create visdom client
    vis = visdom.Visdom(port=args.port)

    # Curves plotted by both train() and train_DCL().
    win_train_loss = _new_line_window(vis, 'train_loss', 'loss')
    # The y axis of this window is an accuracy; it used to be labelled 'loss'.
    win_test_acc = _new_line_window(vis, 'test_accuracy', 'accuracy')

    if args.use_DCL is True:
        # Per-term loss curves plotted only by train_DCL().
        win_cls_loss = _new_line_window(vis, 'classification_loss', 'loss')
        win_dest_loss = _new_line_window(vis, 'adversarial_loss', 'loss')
        win_coord_loss = _new_line_window(vis, 'coordinate_loss', 'loss')

In [None]:
def visualize(vis, phase, visualized_data, window):
    """Append a single (phase, visualized_data) point to a visdom line plot.

    Args:
        vis: object exposing a visdom-style `line` method.
        phase: x value of the new point (the callers pass the epoch number).
        visualized_data: y value of the new point.
        window: handle of the plot to extend, forwarded as `win=`.

    Returns:
        None.
    """
    # each coordinate is wrapped in a one-element array before being sent
    x_point = np.array([phase])
    y_point = np.array([visualized_data])
    line_kwargs = dict(X=x_point, Y=y_point, update='append', win=window)
    vis.line(**line_kwargs)

### Save args

In [None]:
# Hand the config class to the project's Logger and call save().
# NOTE(review): what is written, and where, is decided inside IO.logger,
# which is not visible from this notebook.
args_logger = Logger(args)
args_logger.save()

# Cross Validation

In [None]:
# Checkpoint path. Every fold of the cross-validation below saves to this
# same file, so only the last fold's weights remain on disk.
model_save_path = pj(args.model_root, "final.pth")
# exist_ok=True creates missing directories and is a no-op for existing ones,
# replacing the separate exists() check.
os.makedirs(args.model_root, exist_ok=True)
os.makedirs(args.figure_root, exist_ok=True)

In [None]:
# Number of cross-validation folds implied by the test share (0.2 -> 5 folds).
valid_num = int(1.0/args.test_ratio)
# Open read-only: the file is only read here, and older h5py releases did not
# default to read-only when no mode was given.
with h5py.File(args.all_data_path, "r") as f:
    X = f["X"][:]
    Y = f["Y"][:]
# Per-class sample counts over the whole dataset (used later for plotting).
_, ntests = np.unique(Y, return_counts=True)
train_idxs, test_idxs = create_validation_split(Y, args.test_ratio)
# Predictions of test_classification, accumulated over all folds.
result = []
for valid_count in range(valid_num):
    # xte / yte are also read as globals by train() and train_DCL().
    xtr, ytr, xte, yte = load_validation_data(X, Y, train_idxs[valid_count], test_idxs[valid_count])

    if args.use_DCL is True:
        new_xtr, new_coordinate = region_confusion_mechanism(xtr, division_number=args.division_number, neighborhood_range=args.neighborhood_range)
        # 0 for the original images, 1 for their region-confused copies.
        # Built BEFORE ytr is doubled so that it has one flag per sample;
        # building it afterwards produced twice as many flags as samples.
        target_dest_or_not = np.concatenate([np.zeros(ytr.shape), np.ones(ytr.shape)])
        xtr = np.concatenate([xtr, new_xtr])
        ytr = np.concatenate([ytr, ytr])
        # NOTE(review): xtr is ordered [originals, confused copies] but the
        # coordinates are ordered [new_coordinate, identity] — confirm that
        # this pairing is intended.
        target_coordinate = np.concatenate([new_coordinate, np.asarray([np.arange(args.division_number**2)] * new_coordinate.shape[0])])
        xtr, ytr, target_dest_or_not, target_coordinate = create_train_data_DCL(xtr, ytr, target_dest_or_not, target_coordinate, args.rotate)
    else:
        xtr, ytr = create_train_data(xtr, ytr, args.rotate)

    # Pick the class imported earlier for args.model_name.
    if args.model_name == "ResNet18":
        model_cls = ResNet18
    elif args.model_name == "ResNet34":
        model_cls = ResNet34
    elif args.model_name == "ResNet50":
        model_cls = ResNet50
    elif args.model_name == "ResNet101":
        model_cls = ResNet101
    else:
        # Fail fast instead of printing and then training a stale or
        # undefined `model`.
        raise ValueError("error! write correct model name! got %r" % (args.model_name,))
    # A fresh model is built for every fold.
    model = model_cls(len(args.labels), use_DCL=args.use_DCL, division_number=args.division_number, pretrain=args.pretrain, training=True).cuda()

    if args.use_DCL is True:
        train_DCL(model, xtr, ytr, target_dest_or_not, target_coordinate, bs=args.bs, lr=args.lr, nepoch=args.nepoch, visdom=args.visdom, cls_weight=args.cls_weight, dest_weight=args.dest_weight, coord_weight=args.coord_weight)
    else:
        train(model, xtr, ytr, bs=args.bs, lr=args.lr, nepoch=args.nepoch, visdom=args.visdom)

    # Overwrites the checkpoint written by the previous fold.
    torch.save(model.state_dict(), model_save_path)

    # eval() clears `training` on the model and on every submodule; assigning
    # model.training = False only changed the top-level flag. The model stays
    # in eval mode until test_classification below has run as well.
    model.eval()
    matrix = confusion_matrix(model, xte, yte, args.labels, bs=args.bs)
    df = pd.DataFrame(matrix)
    display(df)
    if valid_count == 0:
        validation_matrix = matrix
        x_all = xte.cpu().numpy()
        y_all = yte.cpu().numpy()
    else:
        validation_matrix += matrix
        x_all = np.concatenate([x_all, xte.cpu().numpy()])
        y_all = np.concatenate([y_all, yte.cpu().numpy()])

    result.extend(test_classification(model, xte))
    # Leave the model in train mode, as before.
    model.train()

In [None]:
# Confusion matrix summed over all folds (accumulated in the cross-validation
# cell), optionally written as CSV under figure_root.
df = pd.DataFrame(validation_matrix)
if args.save_df is True:
    df.to_csv(pj(args.figure_root, "validation_matrix.csv"))
# Bare last expression so that the notebook renders the table.
df

In [None]:
# Draw the summed confusion matrix. ntests holds the per-class sample counts
# obtained from np.unique(Y); the figure is saved under figure_root when
# args.save_fig is set.
# NOTE(review): how ntests is used for the plot is decided inside
# evaluation.classification.visualize — not visible here.
create_confusion_matrix(validation_matrix, ntests, args.labels, args.figure_root, save=args.save_fig)

In [None]:
# result: test_classification outputs of every fold; x_all / y_all: the test
# inputs and labels of every fold, concatenated in the same fold order.
# NOTE(review): the statistics themselves are computed inside
# evaluation.classification.statistics — not visible here.
each_df = compute_each_size_df(result, x_all, y_all)
if args.save_df is True:
    each_df.to_csv(pj(args.figure_root, "each_size_df.csv"))
# Bare last expression so that the notebook renders the table.
each_df

In [None]:
# Derive the overall table from each_df and optionally write it as CSV under
# figure_root.
# NOTE(review): the aggregation is defined in
# evaluation.classification.statistics — not visible here.
all_df = compute_all_size_df(each_df)
if args.save_df is True:
    all_df.to_csv(pj(args.figure_root, "all_size_df.csv"))
# Bare last expression so that the notebook renders the table.
all_df

In [None]:
# Plot all_df with the project's helper; args.figure_root and args.save_fig
# are forwarded so the helper can save the figure.
plot_df_distrib_size(all_df, args.figure_root, save=args.save_fig)

### Load and Test model