In [1]:
#General Imports
import torch
import torch.nn  as nn
import torch.optim as optim
import numpy as np
from torch.utils.data import DataLoader
import random

import matplotlib.pyplot as plt
from tensorboardX import SummaryWriter

#Load the handwriting dataset (hwrDataset from IAM_dataset)
from IAM_dataset import hwrDataset as Dataset
#Import the loss from baidu 
from torch_baidu_ctc import CTCLoss

#Import the model 
from fully_conv_model import cnn_attention_ocr

#Helper to count params
def count_parameters(model):
    """Return the number of trainable scalar parameters of ``model``.

    Iterates over ``model.parameters()`` and adds up ``numel()`` of every
    parameter whose ``requires_grad`` flag is set; frozen parameters are
    not counted.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total


#Evaluation function preds_to_integer
from evaluation import wer_eval,preds_to_integer,show,AverageMeter,my_collate

In [2]:
# TensorBoard writer for the current run; everything logged through `writer`
# ends up below this directory.
# NOTE(review): absolute, machine-specific path - adjust when running elsewhere.
log_dir = "/home/leander/AI/repos/Pytorch-OCR-Fully-Conv//logs2/handwritten_test_4"
writer = SummaryWriter(log_dir=log_dir)

In [3]:
# Build the OCR network (model_dim=64, 93 output classes, 8 layers), move it
# to the GPU and put it into training mode.
cnn = cnn_attention_ocr(model_dim=64, nclasses=93, n_layers=8).cuda().train()

# Last expression of the cell: shows the number of trainable parameters.
count_parameters(cnn)

3543584

In [4]:
# CTC loss (baidu bindings), configured with reduction="mean" and average_frames=True.
ctc_loss = CTCLoss(average_frames=True, reduction="mean")

# Adam over all model parameters; 5e-5 is a good initial learning rate.
learning_rate = 5e-5
optimizer = optim.Adam(cnn.parameters(), lr=learning_rate)

In [5]:
# Running averages of the training loss and of the character error rate (CER).
ave_total_loss, CER_total = AverageMeter(), AverageMeter()

# n_iter counts optimizer steps and is the x-axis of every TensorBoard entry;
# batch_size is handed to the DataLoader.
n_iter, batch_size = 0, 4

In [6]:
# Warm-start from previously trained weights before fine-tuning.
pretrained_state = torch.load("8_layers_continued_on_blanks_340k.pt")
cnn.load_state_dict(pretrained_state)

In [7]:
# Dataset plus a shuffling loader; my_collate assembles the individual samples
# into one batch.
ds = Dataset()
trainset = DataLoader(ds, batch_size=batch_size, shuffle=True, collate_fn=my_collate)

In [8]:
# Main training loop: one optimizer step per batch, with the running loss and
# the character error rate (CER) logged to TensorBoard at every step.
for epochs in range(10000):

    # The DataLoader uses a custom collate function which packs the data
    # (does the padding), so it should work with a variable number of widths.
    # NOTE: multiple DataLoader workers lead to a crash with the CTC loss.
    print("train start")
    for batch in trainset:
        # Batch layout as used below: [0] images, [1] concatenated targets,
        # [2] input lengths, [3] target lengths.
        images = batch[0]

        # Skip batches whose last image dimension (presumably the padded
        # width - TODO confirm against my_collate) exceeds 800.
        if images.shape[3] > 800:
            continue

        # Gradients accumulate, so they must be cleared before every step.
        optimizer.zero_grad()

        # Get predictions, permuted into the layout handed to the CTC loss.
        log_probs = cnn(images).permute((2, 0, 1))

        # Targets have to be on the CPU for the baidu loss.
        flat_targets = batch[1].cpu()

        # Lengths are halved because the network downsamples the width by 2.
        # Floor division keeps them integral: on current PyTorch `/` on an
        # integer tensor is true division and yields floats, which are valid
        # neither as CTC lengths nor as slice bounds further down.
        input_lengths = batch[2] // 2
        target_lengths = batch[3]

        # CTC loss, then backward pass and parameter update.
        loss = ctc_loss(log_probs, flat_targets, input_lengths, target_lengths)
        loss.backward()
        optimizer.step()

        # Save the loss in the average meter and write it to TensorBoard.
        ave_total_loss.update(loss.item())
        writer.add_scalar("total_loss", ave_total_loss.average(), n_iter)

        # --- Character error rate -------------------------------------
        # Metrics need no gradients; work on a detached view so the slicing
        # below does not extend the autograd graph.
        eval_log_probs = log_probs.detach()

        # Split the concatenated targets back into one sequence per sample.
        cum_len = np.cumsum(target_lengths)
        targets = np.split(flat_targets, cum_len[:-1])

        wer_list = []
        for j in range(eval_log_probs.shape[1]):
            # Only the first input_lengths[j] frames of sample j are scored.
            sample_log_probs = eval_log_probs[:, j, :][0:input_lengths[j], :]
            wer_list.append(wer_eval(sample_log_probs, targets[j]))
        batch_cer = np.average(wer_list)

        # Log the worst sample of the batch together with its decoding and
        # ground truth, but only for poor batches (CER > 0.1) and only after
        # the first 10000 steps.
        if batch_cer > 0.1 and n_iter > 10000:
            max_elem = np.argmax(wer_list)
            max_image = images[max_elem].detach().cpu()

            max_target = "".join(ds.decode_dict[x] for x in targets[max_elem].tolist())

            ou = preds_to_integer(eval_log_probs[:, max_elem, :])
            max_preds = "".join(ds.decode_dict[x] for x in ou)

            writer.add_text("label", max_target, n_iter)
            writer.add_text("pred", max_preds, n_iter)
            writer.add_image("img", max_image.numpy(), n_iter)

        # The per-sample error might become infinite (or NaN); such batches
        # fail this comparison and are kept out of the running CER.
        if batch_cer < 10:
            CER_total.update(batch_cer)
            writer.add_scalar("CER", CER_total.average(), n_iter)

        n_iter = n_iter + 1

train start


  return we/len(preds)


train start
train start
train start
train start
train start
train start
train start
train start
train start
train start
train start
train start
train start
train start
train start
train start
train start
train start
train start
train start
train start
train start
train start


KeyboardInterrupt: 

In [9]:
# Store only the model weights (no optimizer state).
weights_path = "67ksteps_on_handwritten_CER012.pt"
torch.save(cnn.state_dict(), weights_path)

In [10]:
##Checkpointing

In [12]:
# Full training checkpoint: model weights, optimizer state, the last loss and
# the step counter. Note that the 'epoch' entry holds n_iter (the number of
# optimizer steps), not an epoch index.
full_checkpoint = {
    'epoch': n_iter,
    'model_state_dict': cnn.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': loss,
}
torch.save(full_checkpoint, "67ksteps_on_handwritten_CER012_whole_cp.pt")

In [None]:
# Restore a full training checkpoint (model weights, optimizer state, step
# counter and last loss), e.g. to resume training after a kernel restart.
# This notebook's model object is `cnn`, and the file below is the one written
# by the checkpointing cell.

checkpoint = torch.load("67ksteps_on_handwritten_CER012_whole_cp.pt")
cnn.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
# The 'epoch' entry was saved from n_iter, i.e. it is the number of optimizer
# steps taken so far; assign it to n_iter to continue the step counter.
epoch = checkpoint['epoch']
loss = checkpoint['loss']
