In [None]:
import os
import sys
import time
import random
import string

import torch
import torch.backends.cudnn as cudnn
import torch.nn.init as init
import torch.optim as optim
import torch.utils.data
import numpy as np
import torch.nn.functional as F
from nltk.metrics.distance import edit_distance

from utils import CTCLabelConverter, CTCLabelConverterForBaiduWarpctc, AttnLabelConverter, Averager
from c_dataset import custom_Batch_Balanced_Dataset,custom_dataset,AlignCollate
from model import Model
from test import validation
import easydict
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Experiment configuration, shared by `validation` and the training script
# below and handed to `Model(opt)`.
# NOTE: the former `global opt` statements were removed; at module scope the
# `global` statement has no effect because `opt` is already a module global.
opt = easydict.EasyDict({
    # --- experiment / data -------------------------------------------------
    "exp_name": "test_01",
    # The two LMDB paths are not read by the code visible in this file; the
    # training script builds its datasets from hard-coded CSV paths instead.
    "train_data": "/data/data/STARN/data_lmdb_release/training",
    "valid_data": "/data/data/STARN/data_lmdb_release/validation",
    "manualSeed": 1111,   # seed for random / numpy / torch
    "workers": 8,         # DataLoader workers (multiplied by GPU count)
    "batch_size": 1024,   # per-step batch size (multiplied by GPU count)
    "num_iter": 300000,   # not read here; the loop is bounded by nb_epochs
    "valInterval": 1,     # not read here; validation runs after every batch
    "saved_model": '',    # not read by the visible code
    "FT": False,          # not read by the visible code
    # --- optimizer ---------------------------------------------------------
    "adam": True,         # True -> optim.Adam, False -> optim.Adadelta
    "lr": 0.0001,
    "beta1": 0.9,         # currently NOT passed to Adam (default betas used)
    "rho": 0.95,          # Adadelta only
    "eps": 1e-8,          # Adadelta only
    "grad_clip": 5,       # max norm given to clip_grad_norm_
    "baiduCTC": False,    # selects the Baidu warp-ctc loss call in validation
    # --- data processing ---------------------------------------------------
    "select_data": 'ST',              # not read by the visible code
    "batch_ratio": '1',               # not read by the visible code
    "total_data_usage_ratio": '1.0',  # not read by the visible code
    "batch_max_length": 25,           # maximum label length for encode/decode
    "imgW": 100,                      # passed to AlignCollate
    "imgH": 32,                       # passed to AlignCollate
    "rgb": False,                     # True forces input_channel = 3
    # Placeholder charset: replaced at runtime by the contents of ch_range.txt.
    "character": "0123456789abcdefghijklmnopqrstuvwxyz",
    "sensitive": False,               # not read by the visible code
    "PAD": False,                     # keep_ratio_with_pad for AlignCollate
    "data_filtering_off": False,      # not read by the visible code
    # --- model architecture (presumably consumed by Model(opt) -- verify) ---
    "Transformation": "TPS",
    "FeatureExtraction": "ResNet",
    "SequenceModeling": "BiLSTM",
    "Prediction": 'Attn',  # also selects the CTC vs. attention path in validation
    "num_fiducial": 20,
    "input_channel": 1,
    "output_channel": 512,
    "hidden_size": 256,
})





def validation(model, criterion, evaluation_loader, converter, opt):
    """Evaluate ``model`` on every batch of ``evaluation_loader``.

    Args:
        model: network called as ``model(image, text)`` (CTC path) or
            ``model(image, text, is_train=False)`` (attention path).
        criterion: loss used to compute the validation cost.
        evaluation_loader: iterable yielding ``(image_tensors, labels)``.
        converter: label converter providing ``encode`` and ``decode``.
        opt: options; reads ``batch_max_length``, ``Prediction`` and
            ``baiduCTC``.

    Returns:
        An 8-tuple ``(valid_loss, accuracy, norm_ED, preds_str,
        confidence_score_list, labels, infer_time, length_of_data)``:
        average loss, accuracy in percent, mean ICDAR2019 normalized edit
        distance score, then the decoded predictions, confidence scores and
        labels of the LAST batch only (they are overwritten per batch), the
        summed forward time in seconds, and the number of samples seen.
    """
    n_correct = 0
    norm_ED = 0
    length_of_data = 0
    infer_time = 0
    valid_loss_avg = Averager()

    for image_tensors, labels in evaluation_loader:
        batch_size = image_tensors.size(0)
        length_of_data = length_of_data + batch_size
        image = image_tensors.to(device)
        # For max length prediction
        length_for_pred = torch.IntTensor([opt.batch_max_length] * batch_size).to(device)
        text_for_pred = torch.LongTensor(batch_size, opt.batch_max_length + 1).fill_(0).to(device)

        text_for_loss, length_for_loss = converter.encode(labels, batch_max_length=opt.batch_max_length)

        start_time = time.time()
        if 'CTC' in opt.Prediction:
            preds = model(image, text_for_pred)
            forward_time = time.time() - start_time

            # Calculate evaluation loss for CTC decoder.
            preds_size = torch.IntTensor([preds.size(1)] * batch_size)
            # permute 'preds' to use CTCloss format
            if opt.baiduCTC:
                cost = criterion(preds.permute(1, 0, 2), text_for_loss, preds_size, length_for_loss) / batch_size
            else:
                cost = criterion(preds.log_softmax(2).permute(1, 0, 2), text_for_loss, preds_size, length_for_loss)

            # Select max probability (greedy decoding) then decode index to character
            _, preds_index = preds.max(2)
            if opt.baiduCTC:
                # the Baidu converter is given a flattened index tensor
                preds_index = preds_index.view(-1)
            preds_str = converter.decode(preds_index.data, preds_size.data)

        else:
            preds = model(image, text_for_pred, is_train=False)
            forward_time = time.time() - start_time

            preds = preds[:, :text_for_loss.shape[1] - 1, :]
            target = text_for_loss[:, 1:]  # without [GO] Symbol
            cost = criterion(preds.contiguous().view(-1, preds.shape[-1]), target.contiguous().view(-1))

            # select max probability (greedy decoding) then decode index to character
            _, preds_index = preds.max(2)
            preds_str = converter.decode(preds_index, length_for_pred)
            labels = converter.decode(text_for_loss[:, 1:], length_for_loss)

        infer_time += forward_time
        valid_loss_avg.add(cost)

        # calculate accuracy & confidence score
        preds_prob = F.softmax(preds, dim=2)
        preds_max_prob, _ = preds_prob.max(dim=2)
        confidence_score_list = []  # reset per batch: only the last batch is returned
        for gt, pred, pred_max_prob in zip(labels, preds_str, preds_max_prob):
            if 'Attn' in opt.Prediction:
                # Prune after the "end of sentence" token ([s]). str.find
                # returns -1 when the token is absent; slicing with -1 would
                # silently drop the last character, so only prune when found.
                gt_eos = gt.find('[s]')
                if gt_eos != -1:
                    gt = gt[:gt_eos]
                pred_eos = pred.find('[s]')
                if pred_eos != -1:
                    pred = pred[:pred_eos]
                    pred_max_prob = pred_max_prob[:pred_eos]

            if pred == gt:
                n_correct += 1

            # (old version) ICDAR2017 DOST Normalized Edit Distance
            # https://rrc.cvc.uab.es/?ch=7&com=tasks
            # "For each word we calculate the normalized edit distance to the
            # length of the ground truth transcription."
            #     if len(gt) == 0:
            #         norm_ED += 1
            #     else:
            #         norm_ED += edit_distance(pred, gt) / len(gt)

            # ICDAR2019 Normalized Edit Distance: an empty gt or pred adds 0,
            # otherwise normalize by the longer of the two strings.
            if gt and pred:
                norm_ED += 1 - edit_distance(pred, gt) / max(len(gt), len(pred))

            # calculate confidence score (= product of pred_max_prob)
            if pred_max_prob.numel() > 0:
                confidence_score = pred_max_prob.cumprod(dim=0)[-1]
            else:
                # empty pred case, when pruned right at the "end of sentence" token ([s])
                confidence_score = 0
            confidence_score_list.append(confidence_score)
            # print(pred, gt, pred==gt, confidence_score)

    accuracy = n_correct / float(length_of_data) * 100
    norm_ED = norm_ED / float(length_of_data)  # ICDAR2019 Normalized Edit Distance

    return valid_loss_avg.val(), accuracy, norm_ED, preds_str, confidence_score_list, labels, infer_time, length_of_data


if __name__ == '__main__':
    # Training script: seeds the RNGs, builds the character set, datasets,
    # model and optimizer, then trains with a validation pass and a log entry
    # after every training batch.

    # ---- Seed and GPU setting ----
    # print("Random Seed: ", opt.manualSeed)
    random.seed(opt.manualSeed)
    np.random.seed(opt.manualSeed)
    torch.manual_seed(opt.manualSeed)
    torch.cuda.manual_seed(opt.manualSeed)

    cudnn.benchmark = True
    cudnn.deterministic = True
    opt.num_gpu = torch.cuda.device_count()
    # print('device count', opt.num_gpu)
    if opt.num_gpu > 1:
        print('------ Use multi-GPU setting ------')
        print('if you stuck too long time with multi-GPU setting, try to set --workers 0')
        # check multi-GPU issue https://github.com/clovaai/deep-text-recognition-benchmark/issues/1
        opt.workers = opt.workers * opt.num_gpu
        opt.batch_size = opt.batch_size * opt.num_gpu

    # ---- Character set ----
    # The whole file content becomes the character set, replacing the
    # placeholder value stored in opt.character.
    numclass_path = "/data/work_dir/img/generate_text_ko/mrjaehong_text_generation/generate_img/ch_range.txt"
    with open(numclass_path, 'r') as f:
        opt.character = f.read()

    converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    # ---- Datasets and loaders ----
    train_dataset = custom_dataset("/data/work_dir/img/generate_text_ko/mrjaehong_text_generation/generate_img/label.csv")
    valid_dataset = custom_dataset("/data/work_dir/img/generate_text_ko/mrjaehong_text_generation/generate_img_val/label.csv")

    # The same collate function is used for both loaders.
    AlignCollate_valid = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)

    valid_loader = torch.utils.data.DataLoader(
        valid_dataset, batch_size=opt.batch_size,
        shuffle=True,  # 'True' to check training progress with validation function.
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_valid, pin_memory=True)

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=opt.batch_size,
        shuffle=True,
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_valid, pin_memory=True)

    # ---- Model ----
    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)

    # Weight initialization: zero biases, Kaiming-normal weights.
    for name, param in model.named_parameters():
        if 'localization_fc2' in name:
            print(f'Skip {name} as it is already initialized')
            continue
        try:
            if 'bias' in name:
                init.constant_(param, 0.0)
            elif 'weight' in name:
                init.kaiming_normal_(param)
        except Exception:  # for batchnorm.
            # kaiming_normal_ failed on this parameter; fill weights with 1.
            if 'weight' in name:
                param.data.fill_(1)

    model = torch.nn.DataParallel(model).to(device)

    # ---- Loss, optimizer ----
    criterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(device)

    loss_avg = Averager()

    # Only parameters that require gradients are handed to the optimizer.
    filtered_parameters = [p for p in model.parameters() if p.requires_grad]

    if opt.adam:
        # opt.beta1 is intentionally not passed here; Adam runs with its
        # default betas (previous variant: betas=(opt.beta1, 0.999)).
        optimizer = optim.Adam(filtered_parameters, lr=opt.lr)
    else:
        optimizer = optim.Adadelta(filtered_parameters, lr=opt.lr, rho=opt.rho, eps=opt.eps)

    # ---- Training loop ----
    nb_epochs = 100000

    for epoch in range(nb_epochs + 1):
        for batch_idx, samples in enumerate(train_loader):
            start_time = time.time()
            model.train()

            image_tensors, labels = samples
            image = image_tensors.to(device)
            text, length = converter.encode(labels, batch_max_length=opt.batch_max_length)

            preds = model(image, text[:, :-1])  # align with Attention.forward
            target = text[:, 1:]  # without [GO] Symbol
            cost = criterion(preds.view(-1, preds.shape[-1]), target.contiguous().view(-1))

            model.zero_grad()
            cost.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip)  # gradient clipping with 5 (Default)
            optimizer.step()

            loss_avg.add(cost)

            # Report the learning rate of the last parameter group.
            learning_rate_val = optimizer.param_groups[-1]['lr']

            # ---- Evaluation (runs after every training batch) ----
            model.eval()
            with torch.no_grad():
                (valid_loss, current_accuracy, current_norm_ED, valid_preds,
                 confidence_score, valid_labels, infer_time, length_of_data) = validation(
                    model, criterion, valid_loader, converter, opt)

            end = time.time()
            loss_log = f'epoch : {epoch} [{batch_idx}/{len(train_loader)}] Train loss: {loss_avg.val():0.5f},Valid loss: {valid_loss:0.5f}, time : {end-start_time} lr : {learning_rate_val}'        
            loss_avg.reset()

            print(loss_log)

            # Table with the first five ground-truth / prediction pairs.
            dashed_line = '-' * 80
            head = f'{"Ground Truth":25s} | {"Prediction":25s} | Confidence Score & T/F'
            predicted_result_log = f'{dashed_line}\n{head}\n{dashed_line}\n'
            for gt, pred, confidence in zip(valid_labels[:5], valid_preds[:5], confidence_score[:5]):
                if 'Attn' in opt.Prediction:
                    # str.find returns -1 when '[s]' is absent; only prune
                    # when the token exists so no character is cut off.
                    gt_eos = gt.find('[s]')
                    if gt_eos != -1:
                        gt = gt[:gt_eos]
                    pred_eos = pred.find('[s]')
                    if pred_eos != -1:
                        pred = pred[:pred_eos]

                predicted_result_log += f'{gt:25s} | {pred:25s} | {confidence:0.4f}\t{str(pred == gt)}\n'
            predicted_result_log += f'{dashed_line}'
            # print(predicted_result_log)

            # Append to the log file; the context manager closes the handle
            # even if a write fails.
            with open('./log_dataset.txt', 'a') as log:
                log.write(loss_log + '\n')
                log.write(predicted_result_log + '\n')

------ Use multi-GPU setting ------
if you stuck too long time with multi-GPU setting, try to set --workers 0
Skip Transformation.LocalizationNetwork.localization_fc2.weight as it is already initialized
Skip Transformation.LocalizationNetwork.localization_fc2.bias as it is already initialized
히릿
epoch : 0 [0/23] Train loss: 7.82387,Valid loss: 7.84440, time : 25.346681833267212
epoch : 0 [1/23] Train loss: 7.75128,Valid loss: 7.82737, time : 7.4214088916778564
epoch : 0 [2/23] Train loss: 7.68429,Valid loss: 7.83121, time : 6.519996166229248
epoch : 0 [3/23] Train loss: 7.62678,Valid loss: 7.81342, time : 6.924203872680664
epoch : 0 [4/23] Train loss: 7.56756,Valid loss: 7.75100, time : 6.953729629516602
epoch : 0 [5/23] Train loss: 7.50743,Valid loss: 7.69104, time : 7.156440258026123
epoch : 0 [6/23] Train loss: 7.44455,Valid loss: 7.63721, time : 6.921473979949951
epoch : 0 [7/23] Train loss: 7.38463,Valid loss: 7.58631, time : 7.359449863433838
epoch : 0 [8/23] Train loss: 7.32303,

epoch : 4 [3/23] Train loss: 4.97610,Valid loss: 5.25751, time : 9.572705507278442
epoch : 4 [4/23] Train loss: 4.95769,Valid loss: 5.23176, time : 9.004140377044678
epoch : 4 [5/23] Train loss: 4.95608,Valid loss: 5.23752, time : 9.38229489326477
epoch : 4 [6/23] Train loss: 4.95287,Valid loss: 5.25015, time : 9.091065168380737
epoch : 4 [7/23] Train loss: 4.95462,Valid loss: 5.22233, time : 9.745593547821045
epoch : 4 [8/23] Train loss: 4.92882,Valid loss: 5.21727, time : 9.149464130401611
epoch : 4 [9/23] Train loss: 4.94419,Valid loss: 5.21660, time : 9.464049816131592
epoch : 4 [10/23] Train loss: 4.92030,Valid loss: 5.20642, time : 9.597691774368286
epoch : 4 [11/23] Train loss: 4.91290,Valid loss: 5.17824, time : 9.30196237564087
epoch : 4 [12/23] Train loss: 4.91529,Valid loss: 5.19655, time : 9.515069961547852
epoch : 4 [13/23] Train loss: 4.91311,Valid loss: 5.22620, time : 9.38420557975769
epoch : 4 [14/23] Train loss: 4.91082,Valid loss: 5.17951, time : 9.09241271018982
epo

epoch : 8 [10/23] Train loss: 4.69932,Valid loss: 4.90581, time : 9.589342832565308
epoch : 8 [11/23] Train loss: 4.68696,Valid loss: 4.92011, time : 9.522673845291138
epoch : 8 [12/23] Train loss: 4.70875,Valid loss: 4.91996, time : 9.526067018508911
epoch : 8 [13/23] Train loss: 4.69211,Valid loss: 4.89730, time : 9.607950925827026
epoch : 8 [14/23] Train loss: 4.68357,Valid loss: 4.88934, time : 9.642423868179321
epoch : 8 [15/23] Train loss: 4.68658,Valid loss: 4.88883, time : 9.760229349136353
epoch : 8 [16/23] Train loss: 4.69387,Valid loss: 4.88173, time : 9.42231011390686
epoch : 8 [17/23] Train loss: 4.69627,Valid loss: 4.87497, time : 9.604904651641846
epoch : 8 [18/23] Train loss: 4.69401,Valid loss: 4.87041, time : 9.795758962631226
epoch : 8 [19/23] Train loss: 4.67848,Valid loss: 4.89322, time : 9.48199987411499
epoch : 8 [20/23] Train loss: 4.69788,Valid loss: 4.89856, time : 9.586520195007324
epoch : 8 [21/23] Train loss: 4.67402,Valid loss: 4.87036, time : 9.2201402187

epoch : 12 [16/23] Train loss: 4.62086,Valid loss: 4.79553, time : 9.135483741760254
epoch : 12 [17/23] Train loss: 4.60629,Valid loss: 4.78002, time : 9.126425981521606
epoch : 12 [18/23] Train loss: 4.63308,Valid loss: 4.76828, time : 9.245548725128174
epoch : 12 [19/23] Train loss: 4.61355,Valid loss: 4.76373, time : 9.28244686126709
epoch : 12 [20/23] Train loss: 4.61808,Valid loss: 4.75828, time : 9.350959777832031
epoch : 12 [21/23] Train loss: 4.62403,Valid loss: 4.78001, time : 9.245245456695557
epoch : 12 [22/23] Train loss: 4.60527,Valid loss: 4.81458, time : 8.476383209228516
epoch : 13 [0/23] Train loss: 4.61453,Valid loss: 4.81931, time : 9.661285161972046
epoch : 13 [1/23] Train loss: 4.58779,Valid loss: 4.76199, time : 9.833369970321655
epoch : 13 [2/23] Train loss: 4.62462,Valid loss: 4.74930, time : 9.298033952713013
epoch : 13 [3/23] Train loss: 4.61561,Valid loss: 4.75154, time : 9.611822843551636
epoch : 13 [4/23] Train loss: 4.58386,Valid loss: 4.75826, time : 9.85

epoch : 16 [21/23] Train loss: 4.52497,Valid loss: 4.69981, time : 9.284336566925049
epoch : 16 [22/23] Train loss: 4.54315,Valid loss: 4.74442, time : 8.452738046646118
epoch : 17 [0/23] Train loss: 4.52730,Valid loss: 4.68392, time : 9.299189567565918
epoch : 17 [1/23] Train loss: 4.57006,Valid loss: 4.66489, time : 8.865921020507812
epoch : 17 [2/23] Train loss: 4.55980,Valid loss: 4.66174, time : 9.207290887832642
epoch : 17 [3/23] Train loss: 4.53063,Valid loss: 4.66830, time : 8.854294061660767
epoch : 17 [4/23] Train loss: 4.62047,Valid loss: 4.69318, time : 9.431920766830444
epoch : 17 [5/23] Train loss: 4.67914,Valid loss: 4.70817, time : 8.577419519424438
epoch : 17 [6/23] Train loss: 4.62573,Valid loss: 4.70675, time : 9.02492356300354
epoch : 17 [7/23] Train loss: 4.55487,Valid loss: 4.82180, time : 8.647639989852905
epoch : 17 [8/23] Train loss: 4.58497,Valid loss: 4.98485, time : 9.052767992019653
epoch : 17 [9/23] Train loss: 4.60110,Valid loss: 4.95437, time : 8.6570320

epoch : 21 [3/23] Train loss: 4.52832,Valid loss: 4.62714, time : 10.263680696487427
epoch : 21 [4/23] Train loss: 4.49341,Valid loss: 4.61731, time : 10.136961221694946
epoch : 21 [5/23] Train loss: 4.49688,Valid loss: 4.61668, time : 10.272418975830078
epoch : 21 [6/23] Train loss: 4.51313,Valid loss: 4.61073, time : 9.849257230758667
epoch : 21 [7/23] Train loss: 4.48669,Valid loss: 4.62374, time : 9.663696527481079
epoch : 21 [8/23] Train loss: 4.49491,Valid loss: 4.64157, time : 9.489737749099731
epoch : 21 [9/23] Train loss: 4.52398,Valid loss: 4.63214, time : 9.851526260375977
epoch : 21 [10/23] Train loss: 4.48719,Valid loss: 4.60664, time : 9.35289478302002
epoch : 21 [11/23] Train loss: 4.50282,Valid loss: 4.60332, time : 9.894147157669067
epoch : 21 [12/23] Train loss: 4.51417,Valid loss: 4.60831, time : 10.069743394851685
epoch : 21 [13/23] Train loss: 4.48324,Valid loss: 4.62783, time : 10.375164270401001
epoch : 21 [14/23] Train loss: 4.49384,Valid loss: 4.62815, time : 9

epoch : 25 [8/23] Train loss: 4.43838,Valid loss: 4.55676, time : 9.924730777740479
epoch : 25 [9/23] Train loss: 4.47516,Valid loss: 4.59972, time : 10.314893007278442
epoch : 25 [10/23] Train loss: 4.48433,Valid loss: 4.59810, time : 10.132354259490967
epoch : 25 [11/23] Train loss: 4.45801,Valid loss: 4.56623, time : 10.54594087600708
epoch : 25 [12/23] Train loss: 4.44031,Valid loss: 4.58142, time : 9.806257247924805
epoch : 25 [13/23] Train loss: 4.46602,Valid loss: 4.58643, time : 10.330292701721191
epoch : 25 [14/23] Train loss: 4.45242,Valid loss: 4.56205, time : 10.113980054855347
epoch : 25 [15/23] Train loss: 4.44880,Valid loss: 4.58113, time : 9.902518510818481
epoch : 25 [16/23] Train loss: 4.48716,Valid loss: 4.58862, time : 10.017881155014038
epoch : 25 [17/23] Train loss: 4.45232,Valid loss: 4.57322, time : 9.78641152381897
epoch : 25 [18/23] Train loss: 4.43919,Valid loss: 4.57272, time : 10.275869607925415
epoch : 25 [19/23] Train loss: 4.47362,Valid loss: 4.56769, ti

In [None]:
# Scratch cell: pull one (images, labels) pair directly from the training
# dataset. `get_batch` is presumably defined on custom_dataset in c_dataset
# (not shown in this file) -- verify before relying on it.
image_tensors, labels = train_dataset.get_batch()

In [None]:
# Scratch cell: display the shape of the image tensor fetched above.
image_tensors.shape

In [None]:
# Scratch cell: take the first batch from valid_loader and inspect the types
# of its two elements and the shape of the first one.
imzcage,dd = next(iter(valid_loader))
print(type(imzcage),type(dd))
print(imzcage.shape)

In [None]:
# Reference: https://wikidocs.net/57165
    