In [14]:
#General Imports
import torch
import torch.nn  as nn
import torch.optim as optim
import numpy as np
from torch.utils.data import DataLoader

import matplotlib.pyplot as plt
from tensorboardX import SummaryWriter

#Load fake, non handwritten generator 
from fake_texts.pytorch_dataset_fake import Dataset,my_collate,AverageMeter

#Import the loss from baidu 
from torch_baidu_ctc import CTCLoss

#Import the model 
from fully_conv_model import cnn_attention_ocr

#Helper to count params
def count_parameters(model):
    """Return the number of trainable scalar parameters in ``model``.

    Only parameters whose ``requires_grad`` flag is truthy are counted;
    each contributes ``numel()`` elements to the total.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

#Evaluation function 
from evaluation import wer_eval

In [7]:
# TensorBoard writer for the current run.
# NOTE(review): this is an absolute, machine-specific path — consider making it
# relative or configurable before sharing the notebook.
LOG_DIR = "/home/leander/AI/repos/OCR-CNN/logs2/test__"
writer = SummaryWriter(log_dir=LOG_DIR)

In [8]:
# Build the OCR network, move it to the GPU and switch it to training mode.
cnn = cnn_attention_ocr(model_dim=64, nclasses=93, n_layers=16).cuda().train()

# Trailing expression so the cell displays the trainable-parameter count.
count_parameters(cnn)

9925152

In [9]:
# Adam step size.
# NOTE(review): the original note read "Good initial is 5e5", which does not match
# the value used here; presumably 5e-4 was meant — confirm.
LEARNING_RATE = 5e-4

# CTC loss from torch_baidu_ctc, configured with reduction="mean" and average_frames=True.
ctc_loss = CTCLoss(reduction="mean", average_frames=True)

# Adam over all parameters of the model.
optimizer = optim.Adam(cnn.parameters(), lr=LEARNING_RATE)

In [10]:
# Batch size shared by the dataset generator and the DataLoader.
batch_size = 2

# Global step counter used as the x-axis for the TensorBoard scalars.
n_iter = 0

# Meters for the training loss and the character error rate (CER);
# their .average() is what gets logged during training.
ave_total_loss = AverageMeter()
CER_total = AverageMeter()

In [11]:
#cnn.load_state_dict(torch.load("contiued_training_on_heavy_agment_CER4%_22_6kite.pt"))

In [12]:
#Keep a handle on the model's state dict (taken before the training loop runs).
#NOTE(review): presumably this shares storage with the live parameters and so
#reflects later training steps; it goes stale if `cnn` is re-created — confirm.
m=cnn.state_dict()

In [15]:
# Main training loop. Every outer iteration builds a fresh synthetic dataset,
# trains on it with CTC loss, and logs the running loss and character error
# rate (CER) to TensorBoard. Intended to be stopped manually (KeyboardInterrupt).
for epochs in range(10000):
    # Instantiating the dataset downloads data and creates the images.
    # The generator has more options than the ones used here.
    print("getting data")
    ds = Dataset(batch_size, epoch_size=1000, random_strings=True, num_words=5)

    # DataLoader with a custom collate function that packs (pads) the samples,
    # so variable widths should work. Loading stays single-process because
    # multiple workers lead to a crash with the CTC loss.
    trainset = DataLoader(dataset=ds,
                          batch_size=batch_size,
                          shuffle=False,
                          collate_fn=my_collate)
    print("train start")

    for ge in trainset:
        # Skip overly wide batches to avoid running out of GPU memory.
        if ge[0].shape[3] >= 800:
            continue

        # Reset gradients left over from the previous step.
        optimizer.zero_grad()

        # Forward pass; permuted into the axis order the CTC loss is fed with.
        log_probs = cnn(ge[0]).permute((2, 0, 1))

        # Targets have to live on the CPU for the Baidu loss.
        targets = ge[1].cpu()

        # The network downsamples the width by 2, so the lengths are halved.
        # Floor division keeps them integral: true division ("/") yields floats
        # on NumPy and on current PyTorch, which breaks both the CTC length
        # argument and the slice bound used in the CER computation below.
        input_lengths = ge[2] // 2
        target_lengths = ge[3]

        loss = ctc_loss(log_probs, targets, input_lengths, target_lengths)

        # Backward pass and parameter update.
        loss.backward()
        optimizer.step()

        # Track the loss and write its running value to TensorBoard.
        ave_total_loss.update(loss.item())
        writer.add_scalar("total_loss", ave_total_loss.average(), n_iter)

        # Character error rate: split the concatenated targets back into one
        # chunk per sample and score each sample's valid frames against it.
        cum_len = np.cumsum(target_lengths)
        targets_per_sample = np.split(ge[1].cpu(), cum_len[:-1])
        wer_list = [
            wer_eval(log_probs[:, j, :][0:input_lengths[j], :], targets_per_sample[j])
            for j in range(log_probs.shape[1])
        ]

        # The rate might become infinite; only plausible values are recorded.
        batch_cer = np.average(wer_list)
        if batch_cer < 10:
            CER_total.update(batch_cer)
            writer.add_scalar("CER", CER_total.average(), n_iter)

        n_iter = n_iter + 1

getting data
train start


  return we/len(preds)


KeyboardInterrupt: 

In [17]:
# Save the model's weights. The state dict is read from `cnn` at save time
# instead of from the handle `m` captured in an earlier cell, so the file always
# matches the model that was actually trained, even if the model cell was re-run.
torch.save(cnn.state_dict(),"16_layers_5e4_training_2k_bs_2.pt")