In [1]:
from openc2seg import *
from base import *
import sys
import gc
import os
from os import listdir
from os.path import isfile, join
import joblib 
import numpy as np
from datetime import datetime
import time
from torch import optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
import copy
import warnings
warnings.filterwarnings("ignore")

In [2]:
# ---------------------------------------------------------------------------
# Experiment configuration
# ---------------------------------------------------------------------------

# Dataset to run on; accepted values: 'vaihingen' or 'potsdam'.
dataset = 'vaihingen'

# Class setting passed on as `hidden_classes`; for both vaihingen and potsdam
# the valid choices are 0, 1, 2, 3 or 4.
# NOTE(review): presumably the index of the class held out as "unknown" — confirm.
hidden = 0

# Stage(s) to run: 'train', 'prepare_evaluation', 'define_thresholds', 'eval',
# or 'all' (presumably every stage in turn — see config_execution).
task = 'all'

# Number of training epochs.
epochs = 10

# Every other setting lives in parse_arguments() and can be changed there.

In [3]:
# parse_arguments() presumably wraps argparse. Inside a Jupyter kernel sys.argv
# carries the kernel's own flags (e.g. "-f <connection file>"), which argparse
# rejects by raising SystemExit. In that case, retry with an empty command line.
# Other errors are no longer swallowed.
try:
    args = parse_arguments()
except SystemExit:
    sys.argv = ['']
    args = parse_arguments()

(args, exp_name, pretrained_path, datestr, final_outp_path, images_path_roc,
 images_path_trainglog, metrics_path, pretreined_path_closedset, images_path,
 charts_path) = config_execution(args, dataset, hidden, task, epochs=epochs)

# Build the data loaders exactly once. Building them is presumably costly
# (the cell output reports the loader keeps data in memory).
train_loader, valtrain_loader, test_loader, val_threshold_loader = get_loaders(args)
gc.collect()

# Each of these options is a comma-separated list; this run uses only the
# first entry of each.
alpha = float(args.alphas.split(',')[0])
args.alpha = alpha
finalblock = args.final_extra_convs_block.split(',')[0]
norm = args.normalizations.split(',')[0]
activation = args.activations.split(',')[0]

ARQUIVO PRETREINO EXISTE?  True
cuda:0
loader in memory:  True


In [4]:

# Build the open-set network from the first-choice hyperparameters parsed above.
net = OpenC2Seg(
    4,
    num_classes=args.n_classes,
    hidden_classes=args.hidden,
    norm=norm,
    activation=activation,
    finalextraconvsblock=int(finalblock),
    alpha=args.alpha,
    args=args,
).to(args.device)

net_unique_name = net.unique_name()

# Checkpoint identifier: network id, model, matching options, dataset and epoch
# count joined by '_'. Surrounding spaces are trimmed, commas become
# underscores, and square brackets are dropped.
save_name = "_".join([
    net.unique_name("_"),
    args.model,
    str(args.select_non_match),
    str(args.ignore_others_non_match),
    args.dataset,
    str(args.epochs),
])
save_name = save_name.strip(' ').replace(",", "_").replace("[", "").replace("]", "")

final_model_path = os.path.join(args.ckpt_path, exp_name, f"model_os_{save_name}_final.pth")
best_model_path = os.path.join(args.ckpt_path, exp_name, f"model_os_{save_name}_best.pth")

# The eval-type suffix is added only after the checkpoint paths are fixed, so
# it affects logs and figures but not checkpoint file names.
save_name += "_" + args.eval_type

train_log_name = 'trainlog_' + save_name + '_' + datestr
train_log_path = os.path.join(final_outp_path, train_log_name + '.csv')
train_log_image = os.path.join(images_path_trainglog, train_log_name + '.jpg')
images_path_roc_filename = os.path.join(images_path_roc, save_name + '.jpg')

# Filled by the threshold stage. The eval stage reloads them from disk when
# their count does not match args.thresholds.
thresholds_values = []

if args.train_model:
    net.initialize()
    net.to(args.device)
    print('total parameters:', count_parameters(net))

    if args.load_model:
        # Resume from a previously saved best checkpoint of this configuration.
        print('Loading pretrained weights from file "' + best_model_path + '"')
        if os.path.isfile(best_model_path):
            # map_location puts the tensors on the network's device and lets a
            # GPU-saved checkpoint load on a CPU-only machine.
            net.load_state_dict(torch.load(best_model_path, map_location=args.device))
            print('full model loaded')
        else:
            print("ERROR, NOT LOADED!")
    else:
        # Start from the closed-set pretrained weights.
        load_my_state_dict(net, torch.load(pretreined_path_closedset))
        print('pretreined closed set loaded')

    # Bias parameters get twice the base learning rate and Adam's default
    # weight decay (0). All other parameters use the base learning rate with
    # weight decay.
    bias_params = [param for name, param in net.named_parameters() if name.endswith('bias')]
    other_params = [param for name, param in net.named_parameters() if not name.endswith('bias')]
    optimizer = optim.Adam([
        {'params': bias_params, 'lr': 2 * args.lr},
        {'params': other_params, 'lr': args.lr, 'weight_decay': args.weight_decay},
    ], betas=(args.momentum, 0.99))

    # Multiply the LR by 0.2 roughly every third of the run. StepLR takes
    # `last_epoch % step_size`, so step_size must be >= 1; clamp it for runs
    # shorter than three epochs.
    scheduler = optim.lr_scheduler.StepLR(optimizer, max(1, args.epochs // 3), 0.2)
    curr_epoch = 1
    if args.only_new_model and os.path.isfile(best_model_path):
        print('modelo ja treinado!')
        exit()

    maxValAUC = 0
    minTrainLoss = 999999999999
    best_epoch = 1
    validate_train_auc = 0
    # Weights of the best-validated epoch; stays None until validation improves.
    best_model = None

    print("iniciando treino")
    for epoch in range(curr_epoch, args.epochs + 1):
        tic_train = time.time()
        gc.collect()
        train_loss = train_os(train_loader, net, loss_custom, optimizer, epoch,
                              args.num_known_classes, args.num_unknown_classes,
                              args.hidden, args, train_log_path)

        # NOTE(review): train_loss is array-like (it has .mean()); presumably
        # one entry per batch — confirm in train_os.
        train_loss_mean = train_loss.mean()
        train_loss_std = train_loss.std()

        if train_loss_mean < minTrainLoss:
            print("melhor modelo de treino!")
            minTrainLoss = train_loss_mean

        # Never updated in this loop; logged as 0 to keep the CSV column layout.
        train_auc = 0
        test_auc = 0

        if epoch % int(args.validaterate) == 0 and epoch > 1:
            (roc_auc_metrics, flatten_full_msks, flatten_full_trues,
             flatten_full_prds, flatten_full_minlosses) = validate_train(
                valtrain_loader, net, args, args.thresholds, False)

            validate_train_auc = roc_auc_metrics[4]

            del roc_auc_metrics, flatten_full_msks, flatten_full_trues, flatten_full_prds, flatten_full_minlosses

            # The AUC only changes here, so the best model is only tracked here.
            if validate_train_auc > maxValAUC:
                print("melhor modelo de validacao!")
                best_model = copy.deepcopy(net.state_dict())
                maxValAUC = validate_train_auc
                best_epoch = epoch

            print("Best (", str(best_epoch), ") AUC: ", str(maxValAUC))

        log_fields = [
            'train', net.unique_name(), str(epoch), str(train_loss_mean), str(train_loss_std),
            'validation_train', str(validate_train_auc), str(train_auc), str(test_auc), datestr,
        ]
        with open(train_log_path, "a") as log_file:
            log_file.write(';'.join(log_fields) + ';\n')

        # Stop when more than `early_stop` epochs have passed since the best epoch.
        if epoch - best_epoch > args.early_stop:
            break

        scheduler.step()
        print("tempo total da epoca:", time.time() - tic_train)

    if best_model is None:
        # Validation never ran or never beat an AUC of 0 (e.g. fewer epochs
        # than validaterate). Fall back to the last epoch's weights instead of
        # failing on an undefined `best_model`.
        print("no validated best model; saving the last epoch's weights")
        best_model = net.state_dict()

    print("Salvando modelos!")
    os.makedirs(os.path.dirname(best_model_path), exist_ok=True)
    torch.save(best_model, best_model_path)

    print('salvando trainlog chart.')
    train_log_chart = logloss(train_log_path, prefix=exp_name.replace('_', ' - '),
                              savefile=True, savefilepath=train_log_image)
    train_log_chart.plot_chart()

    del net, best_model
    gc.collect()

def _load_best_model(model_path, args, norm, activation, finalblock):
    """Rebuild the OpenC2Seg network and load a saved state dict into it.

    Args:
        model_path: path of the state-dict checkpoint written by the training stage.
        args: parsed experiment configuration (device, class counts, alpha, ...).
        norm, activation, finalblock: architecture options taken from the
            first entry of their comma-separated command-line lists.

    Returns:
        The network with the checkpoint loaded, placed on ``args.device``.

    Raises:
        FileNotFoundError: if ``model_path`` does not exist. This replaces a
            bare ``exit()``, which ends the process with status 0 and so hides
            the failure from any caller running this as a script.
    """
    if not os.path.isfile(model_path):
        print("ERROR, NOT LOADED!")
        raise FileNotFoundError(model_path)
    model = OpenC2Seg(4, num_classes=args.n_classes, hidden_classes=args.hidden,
                      norm=norm, activation=activation, finalextraconvsblock=int(finalblock),
                      alpha=args.alpha, args=args).to(args.device)
    model.initialize()
    # map_location keeps every tensor on args.device. The old `net.cuda()`
    # could move the model to a different GPU than args.device.
    model.load_state_dict(torch.load(model_path, map_location=args.device))
    model.to(args.device)
    print('loaded')
    return model


if args.prep_eval:
    print(args.thresholds)

    print('Loading pretrained weights from file "' + best_model_path + '"')
    net = _load_best_model(best_model_path, args, norm, activation, finalblock)

    print("gerando full images do melhor modelo")
    full_imgs, full_msks, full_trues, full_prds, full_outs, full_minlosses = get_predictions_reconstructions(test_loader, net, args)
    gc.collect()

    best_model_suffix = '_best'

    if args.save_images:
        print("salvando imagens! " + best_model_path)
        save_images(full_msks, images_path, save_name + "_full_msk" + best_model_suffix)
        save_images(full_trues, images_path, save_name + "_full_true" + best_model_suffix)
        save_images(full_prds, images_path, save_name + "_full_prd" + best_model_suffix)
        save_images(full_minlosses, images_path, save_name + "_full_minloss" + best_model_suffix, norm=True)

    del net, full_imgs, full_msks, full_trues, full_prds, full_outs, full_minlosses

if args.prep_thresholds:
    print('generating thresholds: ', alpha)
    print('loading best model: ', best_model_path)
    net = _load_best_model(best_model_path, args, norm, activation, finalblock)

    roc_auc_metrics, flatten_full_msks, flatten_full_trues, flatten_full_prds, flatten_full_minlosses = validate_train(val_threshold_loader, net, args, args.thresholds, True)
    thresholds_values = get_quantiles(flatten_full_minlosses, flatten_full_msks, flatten_full_prds, args.thresholds)
    # Persist so the eval stage can reload them; it waits for this file when
    # the thresholds were not computed in the same run.
    np.savez(os.path.join(images_path, save_name + "_thresholds_values.npz"), thresholds_values=thresholds_values)

    del net, roc_auc_metrics, flatten_full_msks, flatten_full_trues, flatten_full_prds, flatten_full_minlosses
if args.eval_model:
    gc.collect()

    best_model_suffix = '_best'
    best_metrics_path = metrics_path.replace(".csv", best_model_suffix + ".csv")
    best_roc_image_path = images_path_roc_filename.replace(".jpg", best_model_suffix + ".jpg")

    print(best_metrics_path)

    # Presumably the arrays written by save_images in the prepare-evaluation
    # stage (same images_path / save_name / suffix).
    full_msks, full_trues, full_prds, full_minlosses = load_images_array(images_path, save_name, best_model_suffix, test_loader)

    if len(thresholds_values) != len(args.thresholds):
        # Thresholds were not computed in this run. Poll until the threshold
        # stage (possibly another process) writes them.
        # NOTE: this blocks forever if that file is never produced.
        threshold_path = os.path.join(images_path, save_name + "_thresholds_values.npz")

        while not os.path.isfile(threshold_path):
            time.sleep(30)
            print("waiting: ", threshold_path)

        # np.load on an .npz returns an NpzFile holding an open file handle.
        # Indexing it reads the array into memory, so it is safe to close after.
        with np.load(threshold_path) as thresholds_npz:
            thresholds_values = thresholds_npz['thresholds_values']
        print(thresholds_values)

    flatten_full_msks = get_flatten_image(full_msks)
    flatten_full_trues = get_flatten_image(full_trues)
    flatten_full_prds = get_flatten_image(full_prds)
    flatten_full_minlosses = get_flatten_image(full_minlosses)
    get_os_metrics2(net_unique_name, args, flatten_full_minlosses, flatten_full_trues, flatten_full_prds, [],
                    args.num_known_classes, args.thresholds, thresholds_values,
                    best_metrics_path, best_roc_image_path)

print("FIM!")

/mnt/DADOS_GRENOBLE_1/ian/openseg/outputs/OpenC2Seg_unet_Vaihingen_0/metrics_val_train_20220209225638_best.csv
[0.43645944 0.45236336 0.46951624 0.4890776  0.51550144 0.52219268
 0.52944157 0.53739147 0.54598134 0.55568879 0.56659153 0.57953048
 0.59635747 0.62330713 0.84267704]
    Computing AUC ROC...
    AUC:  0.8336444561498698
/mnt/DADOS_GRENOBLE_1/ian/openseg/outputs/charts/OpenC2Seg_unet_Vaihingen_0/roc/OpenC2Seg_4_5_0_64_True_batch_relu_all_0_0_95_0_001_OpenC2Seg_random_False_Vaihingen_5_val_train_best.jpg
(array([1.88856664, 0.88856664, 0.87147778, ..., 0.07083319, 0.07078827,
       0.02850567]), array([0.00000000e+00, 0.00000000e+00, 6.79981613e-07, ...,
       9.99999830e-01, 1.00000000e+00, 1.00000000e+00]), array([0.00000000e+00, 5.77844528e-08, 5.77844528e-08, ...,
       9.99996475e-01, 9.99996475e-01, 1.00000000e+00]), [3662237, 4132383, 4687162, 5378819, 6241698, 6450748, 6681453, 6939881, 7224951, 7537398, 7849450, 8186665, 8589716, 9137660, -1], 0.8336444561498698)
