In [1]:
import os

import cv2
import torch
from torch.utils.data import DataLoader
import torch.optim as optim
from torch.nn import CTCLoss

from CRNN_Dataset import SVHC_Dataset, SVHC_collate_fn,train_path,valid_path
from CRNN_model import CRNN
from CRNN_evaluate import evaluate
from config import train_config as config

In [2]:
def train_batch(crnn, data, optimizer, criterion, device):
    """Run one optimisation step of the CRNN on a single batch.

    Args:
        crnn: the model; called as ``crnn(images)``. Its output is treated as
            logits laid out (T, N, n_class), the layout ``nn.CTCLoss`` needs.
        data: an ``(images, targets, target_lengths)`` batch; every element
            must support ``.to(device)``.
        optimizer: optimiser over ``crnn``'s parameters.
        criterion: CTC loss module, called as
            ``criterion(log_probs, targets, input_lengths, target_lengths)``.
        device: device the batch is moved to before the forward pass.

    Returns:
        The batch loss as a Python float. Its scale depends on the criterion's
        ``reduction`` (this notebook builds it with ``reduction='sum'``).
    """
    crnn.train()
    images, targets, target_lengths = [d.to(device) for d in data]

    logits = crnn(images)
    # Normalise over the class axis (dim=2) so that, at every time step, the
    # class probabilities sum to 1; CTCLoss consumes log-probabilities.
    log_probs = torch.nn.functional.log_softmax(logits, dim=2)

    # Every sequence uses the full output length T = logits.size(0), so build
    # one int64 length per batch element.
    batch_size = images.size(0)
    input_lengths = torch.full((batch_size,), logits.size(0), dtype=torch.long)
    # CTCLoss wants target_lengths as a 1-D (N,) tensor; the collate function
    # may deliver it with an extra dimension (TODO confirm), so flatten it.
    target_lengths = torch.flatten(target_lengths)

    loss = criterion(log_probs, targets, input_lengths, target_lengths)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()

In [3]:
# Pull the training hyper-parameters out of the shared config dict, in a
# fixed key order.
(epochs, train_batch_size, eval_batch_size, lr,
 show_interval, valid_interval, save_interval,
 cpu_workers, reload_checkpoint, valid_max_iter) = (
    config[key] for key in (
        'epochs', 'train_batch_size', 'eval_batch_size', 'lr',
        'show_interval', 'valid_interval', 'save_interval',
        'cpu_workers', 'reload_checkpoint', 'valid_max_iter'))

# Input image geometry expected by the model.
img_width, img_height = config['img_width'], config['img_height']

# One output class per character, plus one extra class (presumably the CTC
# blank, which nn.CTCLoss defaults to index 0).
num_class = 1 + len(SVHC_Dataset.seq_to_char)

In [15]:
def check_file(save_path=None):
    """Ensure the checkpoint directory exists and return its path.

    Args:
        save_path: directory to create. Defaults to
            ``config['checkpoints_dir']``, so zero-argument calls behave as
            before.

    Returns:
        The directory path, unchanged.

    Raises:
        FileExistsError: if ``save_path`` exists but is not a directory.
    """
    if save_path is None:
        save_path = config['checkpoints_dir']
    # makedirs also creates missing parent directories. exist_ok avoids the
    # check-then-create race of os.path.exists + os.mkdir, yet still refuses
    # a path that is occupied by a regular file.
    os.makedirs(save_path, exist_ok=True)
    return save_path

In [16]:
# Make sure the checkpoint directory exists before training saves into it;
# the returned path is echoed as this cell's output.
check_file()

'checkpoints/'

In [4]:
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f'device: {device}')

device: cuda


In [5]:
# Datasets for the two splits (training first, then validation).
train_dataset = SVHC_Dataset(train_path)
valid_dataset = SVHC_Dataset(valid_path)

# Both loaders share the worker count and the collate function; only the
# batch size differs between training and evaluation.
train_loader = DataLoader(
    train_dataset,
    collate_fn=SVHC_collate_fn,
    num_workers=cpu_workers,
    shuffle=True,
    batch_size=train_batch_size)
# NOTE(review): shuffling is not needed for a full validation pass; it only
# matters if evaluation stops after a limited number of batches.
valid_loader = DataLoader(
    valid_dataset,
    collate_fn=SVHC_collate_fn,
    num_workers=cpu_workers,
    shuffle=True,
    batch_size=eval_batch_size)


In [6]:
# Build the recogniser. The leading 1 is the number of input channels
# (the first conv layer takes a single grayscale channel).
crnn = CRNN(
    1, img_height, img_width, num_class,
    map_to_seq_hidden=config['map_to_seq_hidden'],
    rnn_hidden=config['rnn_hidden'],
    leaky_relu=config['leaky_relu'],
)

# Optionally warm-start from a saved state_dict.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
if reload_checkpoint:
    crnn.load_state_dict(torch.load(reload_checkpoint, map_location=device))

crnn.to(device)

CRNN(
  (cnn): Sequential(
    (conv0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (relu0): ReLU(inplace=True)
    (pooling0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (relu1): ReLU(inplace=True)
    (pooling1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (conv2): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (relu2): ReLU(inplace=True)
    (conv3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (relu3): ReLU(inplace=True)
    (pooling2): MaxPool2d(kernel_size=(2, 1), stride=(2, 1), padding=0, dilation=1, ceil_mode=False)
    (conv4): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (batchnorm4): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu4): ReLU(inplace=True)
    (conv5): Conv2d(512, 512, 

In [7]:
# CTC loss summed over the batch (not averaged); train_batch returns this sum
# and the training loop divides by the batch size when printing.
criterion = CTCLoss(reduction='sum')
optimizer = optim.RMSprop(params=crnn.parameters(), lr=lr)
# As the cell's last expression this also echoes the loss module's repr.
criterion.to(device)

CTCLoss()

In [8]:
# The training loop writes the validation loss of the current step into the
# checkpoint file name, so every save step must also be a validation step.
assert save_interval % valid_interval == 0, (
    f'save_interval ({save_interval}) must be a multiple of '
    f'valid_interval ({valid_interval})')
# Global step counter. It is not reset by the training cell, so re-running
# that cell continues counting from where it stopped.
i = 1

In [17]:
# Main training loop. It uses the global step counter `i` from the previous
# cell and validates/saves on fixed step intervals.
# NOTE(review): `valid_max_iter` is read from config but not passed to
# evaluate(); TODO confirm whether evaluate accepts a max-iteration argument.
for epoch in range(1, epochs + 1):
    print(f'epoch: {epoch}')
    tot_train_loss = 0.
    tot_train_count = 0
    for train_data in train_loader:
        loss = train_batch(crnn, train_data, optimizer, criterion, device)
        train_size = train_data[0].size(0)

        tot_train_loss += loss
        tot_train_count += train_size
        if i % show_interval == 0:
            # train_batch returns a summed loss; divide for a per-sample value.
            print('train_batch_loss[', i, ']: ', loss / train_size)

        # A save step needs a fresh validation result, so evaluate on save
        # steps too. This is identical to the original behaviour when
        # save_interval is a multiple of valid_interval, but no longer relies
        # on that invariant holding.
        is_save_step = i % save_interval == 0
        if i % valid_interval == 0 or is_save_step:
            evaluation = evaluate(crnn, valid_loader, criterion,
                                  decode_method=config['decode_method'],
                                  beam_size=config['beam_size'])
            print('valid_evaluation: loss={loss}, acc={acc}'.format(**evaluation))

        if is_save_step:
            prefix = 'crnn'
            # Use a separate name so the training `loss` is not overwritten.
            valid_loss = evaluation['loss']
            checkpoints_dir = config['checkpoints_dir']
            # Create the directory on demand so saving does not depend on an
            # earlier cell having been run.
            if checkpoints_dir:
                os.makedirs(checkpoints_dir, exist_ok=True)
            save_model_path = os.path.join(checkpoints_dir,
                                           f'{prefix}_{i:06}_loss{valid_loss}.pt')
            torch.save(crnn.state_dict(), save_model_path)
            print('save model at ', save_model_path)

        i += 1

    # Guard against an empty loader, which would otherwise divide by zero.
    if tot_train_count:
        print('train_loss: ', tot_train_loss / tot_train_count)
    else:
        print('train_loss: no training samples in this epoch')

epoch: 1
train_batch_loss[ 2000 ]:  2.364166259765625


Evaluate: 100%|████████████████████████████████████████████████████████████████████████| 17/17 [00:08<00:00,  1.99it/s]


valid_evaluation: loss=2.3553525853839528, acc=0.48852614147149276
save model at  checkpoints/crnn_002000_loss2.3553525853839528.pt
train_batch_loss[ 2010 ]:  1.668026328086853
train_batch_loss[ 2020 ]:  1.6407315731048584
train_batch_loss[ 2030 ]:  2.3200650215148926
train_batch_loss[ 2040 ]:  2.1294784545898438
train_batch_loss[ 2050 ]:  3.013375759124756
train_batch_loss[ 2060 ]:  2.634132146835327
train_batch_loss[ 2070 ]:  2.27886962890625
train_batch_loss[ 2080 ]:  1.8047376871109009
train_batch_loss[ 2090 ]:  1.9633941650390625
train_batch_loss[ 2100 ]:  2.57022762298584
train_batch_loss[ 2110 ]:  2.056077003479004
train_batch_loss[ 2120 ]:  2.142949104309082
train_batch_loss[ 2130 ]:  2.058037757873535
train_batch_loss[ 2140 ]:  2.3490867614746094
train_batch_loss[ 2150 ]:  1.6002488136291504
train_batch_loss[ 2160 ]:  2.2955923080444336
train_batch_loss[ 2170 ]:  2.563969612121582
train_batch_loss[ 2180 ]:  1.304586410522461
train_batch_loss[ 2190 ]:  1.5196025371551514
train_

Evaluate: 100%|████████████████████████████████████████████████████████████████████████| 17/17 [00:09<00:00,  1.87it/s]


valid_evaluation: loss=2.065146939647883, acc=0.5410456588597113
train_batch_loss[ 2510 ]:  1.976383090019226
train_batch_loss[ 2520 ]:  2.278855323791504
train_batch_loss[ 2530 ]:  2.315486431121826
train_batch_loss[ 2540 ]:  1.3894059658050537
train_batch_loss[ 2550 ]:  2.960674524307251
train_batch_loss[ 2560 ]:  1.6165680885314941
train_batch_loss[ 2570 ]:  1.3557409048080444
train_batch_loss[ 2580 ]:  2.0065338611602783
train_batch_loss[ 2590 ]:  2.4578564167022705
train_batch_loss[ 2600 ]:  2.3821582794189453
train_batch_loss[ 2610 ]:  1.9949064254760742
train_batch_loss[ 2620 ]:  1.7662887573242188
train_batch_loss[ 2630 ]:  2.1986234188079834
train_batch_loss[ 2640 ]:  1.3785514831542969
train_batch_loss[ 2650 ]:  1.7360048294067383
train_batch_loss[ 2660 ]:  1.4571890830993652
train_batch_loss[ 2670 ]:  1.6219332218170166
train_batch_loss[ 2680 ]:  1.6136810779571533
train_batch_loss[ 2690 ]:  1.1545145511627197
train_batch_loss[ 2700 ]:  1.6373369693756104
train_batch_loss[ 2

Evaluate: 100%|████████████████████████████████████████████████████████████████████████| 17/17 [00:09<00:00,  1.88it/s]


valid_evaluation: loss=1.836256075503275, acc=0.6109533948426781
train_batch_loss[ 3010 ]:  1.164271593093872
train_batch_loss[ 3020 ]:  1.436570644378662
train_batch_loss[ 3030 ]:  1.9959170818328857
train_batch_loss[ 3040 ]:  1.1380832195281982
train_batch_loss[ 3050 ]:  0.7196630239486694
train_batch_loss[ 3060 ]:  2.0604944229125977
train_batch_loss[ 3070 ]:  1.3065364360809326
train_batch_loss[ 3080 ]:  1.2138125896453857
train_batch_loss[ 3090 ]:  1.6344420909881592
train_batch_loss[ 3100 ]:  1.2625234127044678
train_batch_loss[ 3110 ]:  1.4065265655517578
train_batch_loss[ 3120 ]:  0.8107227683067322
train_batch_loss[ 3130 ]:  1.7729727029800415
train_batch_loss[ 3140 ]:  1.1283236742019653
train_batch_loss[ 3150 ]:  1.3057959079742432
train_batch_loss[ 3160 ]:  1.7472319602966309
train_batch_loss[ 3170 ]:  1.5366017818450928
train_batch_loss[ 3180 ]:  2.4883275032043457
train_batch_loss[ 3190 ]:  1.127942943572998
train_loss:  1.490113369130013
epoch: 4
train_batch_loss[ 3200 ]

Evaluate: 100%|████████████████████████████████████████████████████████████████████████| 17/17 [00:10<00:00,  1.62it/s]


valid_evaluation: loss=1.6670364513446851, acc=0.6599242961911521
train_batch_loss[ 3510 ]:  1.1000984907150269
train_batch_loss[ 3520 ]:  1.4764807224273682
train_batch_loss[ 3530 ]:  0.8715772032737732
train_batch_loss[ 3540 ]:  1.2823457717895508
train_batch_loss[ 3550 ]:  0.7553696036338806
train_batch_loss[ 3560 ]:  2.115997791290283
train_batch_loss[ 3570 ]:  0.951066255569458
train_batch_loss[ 3580 ]:  1.5173060894012451
train_loss:  1.3029616009529186
epoch: 5
train_batch_loss[ 3590 ]:  0.45204824209213257
train_batch_loss[ 3600 ]:  1.8375344276428223
train_batch_loss[ 3610 ]:  1.5268807411193848
train_batch_loss[ 3620 ]:  0.7495046854019165
train_batch_loss[ 3630 ]:  0.912958562374115
train_batch_loss[ 3640 ]:  2.3940587043762207
train_batch_loss[ 3650 ]:  1.4865427017211914
train_batch_loss[ 3660 ]:  1.2332134246826172
train_batch_loss[ 3670 ]:  1.2014539241790771
train_batch_loss[ 3680 ]:  1.4372329711914062
train_batch_loss[ 3690 ]:  0.9017958641052246
train_batch_loss[ 370

Evaluate: 100%|████████████████████████████████████████████████████████████████████████| 17/17 [00:09<00:00,  1.88it/s]


valid_evaluation: loss=1.656305598738855, acc=0.6724627395315826
save model at  checkpoints/crnn_004000_loss1.656305598738855.pt
train_batch_loss[ 4010 ]:  0.600747287273407
train_batch_loss[ 4020 ]:  1.0282666683197021
train_batch_loss[ 4030 ]:  0.7218974828720093
train_batch_loss[ 4040 ]:  0.8932498693466187
train_batch_loss[ 4050 ]:  0.7901784777641296
train_batch_loss[ 4060 ]:  1.2278287410736084
train_batch_loss[ 4070 ]:  0.5419585108757019
train_batch_loss[ 4080 ]:  1.6814696788787842
train_batch_loss[ 4090 ]:  1.5476001501083374
train_batch_loss[ 4100 ]:  0.8230554461479187
train_batch_loss[ 4110 ]:  0.7138364315032959
train_batch_loss[ 4120 ]:  1.315244197845459
train_batch_loss[ 4130 ]:  0.7799280285835266
train_batch_loss[ 4140 ]:  1.0918692350387573
train_batch_loss[ 4150 ]:  0.6140811443328857
train_batch_loss[ 4160 ]:  0.7916460037231445
train_batch_loss[ 4170 ]:  1.7030596733093262
train_batch_loss[ 4180 ]:  0.7111654281616211
train_batch_loss[ 4190 ]:  0.8049297332763672

Evaluate: 100%|████████████████████████████████████████████████████████████████████████| 17/17 [00:10<00:00,  1.70it/s]


valid_evaluation: loss=1.5565544725846308, acc=0.6904423941329548
train_batch_loss[ 4510 ]:  1.168224811553955
train_batch_loss[ 4520 ]:  0.8833922147750854
train_batch_loss[ 4530 ]:  1.1184996366500854
train_batch_loss[ 4540 ]:  0.7709082365036011
train_batch_loss[ 4550 ]:  1.2420600652694702
train_batch_loss[ 4560 ]:  0.6417824625968933
train_batch_loss[ 4570 ]:  1.063063621520996
train_batch_loss[ 4580 ]:  1.4960944652557373
train_batch_loss[ 4590 ]:  1.0633370876312256
train_batch_loss[ 4600 ]:  0.6079260110855103
train_batch_loss[ 4610 ]:  0.8312145471572876
train_batch_loss[ 4620 ]:  0.7752418518066406
train_batch_loss[ 4630 ]:  0.6934607625007629
train_batch_loss[ 4640 ]:  0.6704452037811279
train_batch_loss[ 4650 ]:  0.7085787057876587
train_batch_loss[ 4660 ]:  0.7291940450668335
train_batch_loss[ 4670 ]:  0.9465459585189819
train_batch_loss[ 4680 ]:  1.508072853088379
train_batch_loss[ 4690 ]:  1.7726575136184692
train_batch_loss[ 4700 ]:  0.3856433033943176
train_batch_loss[

Evaluate: 100%|████████████████████████████████████████████████████████████████████████| 17/17 [00:10<00:00,  1.68it/s]


valid_evaluation: loss=1.5414507549819243, acc=0.7014431038561627
train_batch_loss[ 5010 ]:  0.7526815533638
train_batch_loss[ 5020 ]:  0.748944878578186
train_batch_loss[ 5030 ]:  1.5638148784637451
train_batch_loss[ 5040 ]:  0.705430805683136
train_batch_loss[ 5050 ]:  1.3062254190444946
train_batch_loss[ 5060 ]:  0.48743319511413574
train_batch_loss[ 5070 ]:  0.46019190549850464
train_batch_loss[ 5080 ]:  0.31844019889831543
train_batch_loss[ 5090 ]:  0.7319234609603882
train_batch_loss[ 5100 ]:  0.47237271070480347
train_batch_loss[ 5110 ]:  0.496817022562027
train_batch_loss[ 5120 ]:  0.7625218629837036
train_batch_loss[ 5130 ]:  0.8731170892715454
train_batch_loss[ 5140 ]:  0.6582636833190918
train_batch_loss[ 5150 ]:  0.46523067355155945
train_batch_loss[ 5160 ]:  0.5530381202697754
train_batch_loss[ 5170 ]:  0.4368952512741089
train_loss:  0.7947727947313314
epoch: 9
train_batch_loss[ 5180 ]:  0.7053214311599731
train_batch_loss[ 5190 ]:  0.7463304996490479
train_batch_loss[ 52

Evaluate: 100%|████████████████████████████████████████████████████████████████████████| 17/17 [00:09<00:00,  1.81it/s]


valid_evaluation: loss=1.5385330016093528, acc=0.703927135083984
train_batch_loss[ 5510 ]:  0.7463782429695129
train_batch_loss[ 5520 ]:  1.0107824802398682
train_batch_loss[ 5530 ]:  0.5746711492538452
train_batch_loss[ 5540 ]:  0.42605283856391907
train_batch_loss[ 5550 ]:  0.5638715028762817
train_batch_loss[ 5560 ]:  0.4725308120250702
train_batch_loss[ 5570 ]:  1.007293939590454
train_loss:  0.7061841677198123
epoch: 10
train_batch_loss[ 5580 ]:  0.2056664079427719
train_batch_loss[ 5590 ]:  0.7790564894676208
train_batch_loss[ 5600 ]:  0.358333557844162
train_batch_loss[ 5610 ]:  0.40767520666122437
train_batch_loss[ 5620 ]:  0.33409619331359863
train_batch_loss[ 5630 ]:  0.5027170777320862
train_batch_loss[ 5640 ]:  0.2921300232410431
train_batch_loss[ 5650 ]:  0.3450584411621094
train_batch_loss[ 5660 ]:  0.7478896975517273
train_batch_loss[ 5670 ]:  0.6270022392272949
train_batch_loss[ 5680 ]:  0.2787173092365265
train_batch_loss[ 5690 ]:  0.7110592126846313
train_batch_loss[ 