In [1]:
import os

from loadlibs import *
import functions
import modules
# Inline configuration for this inference notebook. Disabled alternative kept for reference:
# from configs import OUTPUT, CFG, TOOLS, DIRECTORY, INFERENCE_MODELS
#
# CFG is passed to functions.prepare_data, functions.prepare_loader and inference().
# Within this notebook only 'DEVICE', 'PARALLEL' and 'VISIBLE_TQDM' are read directly;
# the other keys are presumably consumed by the project-local `functions` / `modules`
# packages. NOTE(review): model/image keys likely have to match the settings the
# checkpoints in INFER_MODELS were trained with — confirm before changing any of them.
CFG = {
    # train setting (presumably unused at inference time — TODO confirm)
    'EPOCHS':20,
    'WARM_UP_EPOCHS':6,
    'BATCH_SIZE': 128,
    'VISIBLE_TQDM' : True,     # wrap the test loader in a tqdm progress bar during inference
    'PARALLEL' : False,        # inference wraps models in DataParallel and permutes outputs when True
    'MIXED_PRECISION' : True, 
    'SEED': None,              # NOTE(review): no seed — any random split in prepare_data may differ per run
    'LEARNING_RATE':0.003,
    'CONTINUE' : False,
    'CONTINUE2' : False,
    
    # augmentation (flags presumably paired with the CSV paths of the same name in DIRECTORY)
    'CUTMIX' : True, 
    'ONES' : False,
    'ROTATION': False,
    'TRANSLATION': False, 
    'NUM_AUGS_TYPE': 0, 
    'NUM_AUGS' : 2,
    
    # image
    'IMG_HEIGHT_SIZE':64,
    'IMG_WIDTH_SIZE':192,
    'GRAY_SCALE': True,
    'INPUT_CHANNEL': 1,        # presumably must be 1 while GRAY_SCALE is True — TODO confirm

    # data loading 
    'NUM_WORKERS':2,
    'TEST_SIZE':0.2,           # presumably the validation fraction used by prepare_data
    
    # model
    'NUM_FIDUCIAL': 20,
    
        # cnn
    'PRETRAINED' : True,
    'CNN_TYPE': None, 
    'CNN_OUTPUT': None,            

        # rnn (transformer)
    # alternative value: 'SEQ_HIDDEN_SIZE': 768,
    'SEQ_HIDDEN_SIZE': 1024,
    'SEQ_TYPE' : None,
    'SEQ_ACTIVATION' : torch.nn.GELU(approximate='tanh'),  # an nn.Module instance, not a name string
    'SEQ_NUM_LAYERS': 2,
    'SEQ_BIDIRECTIONAL': True,
    'NUM_HEADS': 4,
        
            # rnn -> conformer encoder (no settings defined for this section)

        # pred    
    'NUM_CLASS': 2350,         # number of output classes (character vocabulary size)

    # etc
    'DROPOUT' : 0.2,
    'DROPPATH' : 0.2,
    'DROPBLOCK' : 0.2, 
    'DEVICE' :  'cuda' if torch.cuda.is_available() else 'cpu',  # falls back to CPU without CUDA
    'NUM_DEVICE' : None,
}

# Data locations. Every path lives under DATA_DIR so the notebook can run on another
# machine by setting the OCR_DATA_DIR environment variable instead of editing nine
# absolute paths; the default is the original author's data directory.
DATA_DIR = os.environ.get("OCR_DATA_DIR", "/home/gyuseonglee/workspace/2301_OCR/data")

DIRECTORY = {
    # CSV indexes (these *_DIR keys are file paths, not directories)
    'TRAIN_DIR' : f"{DATA_DIR}/train.csv",
    'TEST_DIR'  : f"{DATA_DIR}/test.csv",
    'SUBMIT_DIR' : f"{DATA_DIR}/sample_submission.csv",

    # image folders
    'TRAIN_IMAGE_DIR' : f"{DATA_DIR}/train",
    'TEST_IMAGE_DIR'  : f"{DATA_DIR}/test",

    # extra CSVs, presumably used by the augmentation flags of the same name in CFG
    'CUTMIX' : f"{DATA_DIR}/cutmix.csv",
    'ONES' : f"{DATA_DIR}/ones.csv",
    'TWOS' : f"{DATA_DIR}/twos.csv",
    'THREES' : f"{DATA_DIR}/threes.csv",
}

# Build the data splits, the index<->character vocabulary and the test DataLoader via
# the project-local `functions` module (its implementation is not visible here).
# NOTE(review): idx2char is rebuilt from the `train` split at inference time. If
# prepare_data splits randomly (CFG['SEED'] is None) and prepare_vocab depends on which
# rows land in `train`, the vocabulary could differ from the one the checkpoints were
# trained with and silently garble decoded text — confirm this is deterministic.
train, valid, test, submit = functions.prepare_data(CFG, DIRECTORY)
idx2char, char2idx = functions.prepare_vocab(train)
# Only the third loader (presumably the test loader) is needed for inference.
_, _, test_loader = functions.prepare_loader(train, valid, test, CFG)
# Date stamp in YYMMDD form, used in the submission filename.
today = datetime.datetime.strftime(datetime.datetime.today(), '%y%m%d')  

In [2]:
def inference(models, test_loader, idx2char, configs):
    """Decode the test set with an ensemble of saved checkpoints.

    Each checkpoint path in ``models`` is loaded with ``torch.load(...)`` and unwrapped
    via ``.module``, moved to ``configs['DEVICE']`` (optionally re-wrapped in
    ``DataParallel``), and run over ``test_loader`` in eval mode without gradients.
    The raw outputs of all checkpoints are summed in float32, divided by the number of
    checkpoints, and greedily decoded (argmax per time step) through ``idx2char``.

    NOTE(review): outputs are averaged exactly as the model returns them (no softmax
    before averaging); whether those are logits or log-probabilities is not visible
    here.

    Args:
        models: sequence of checkpoint paths (at least one).
        test_loader: iterable of input batches; each batch must support ``.to(device)``.
        idx2char: mapping from class index to character.
        configs: dict; this function reads 'DEVICE', 'PARALLEL' and 'VISIBLE_TQDM'.

    Returns:
        list[str]: one raw decoded string per test sample, in loader order. No
        de-duplication of repeated characters is applied here.

    Raises:
        ValueError: if ``models`` is empty.
    """
    global yhat  # last batch output stays module-level for interactive inspection
    if len(models) == 0:
        raise ValueError("inference() needs at least one model checkpoint")
    print(f"-- number of models : {len(models)}")

    def decode_predictions(text_batch_logits):
        # Expects [T, N, C]. Softmax is monotonic, so argmax over the raw scores picks
        # the same class without materialising another full [T, N, C] tensor.
        text_batch_tokens = text_batch_logits.argmax(2).numpy().T  # [N, T]
        return ["".join(idx2char[idx] for idx in text_tokens)
                for text_tokens in text_batch_tokens]

    probs_sum = None
    for model_path in models:
        # NOTE(review): torch>=2.6 defaults to torch.load(weights_only=True), which cannot
        # unpickle a whole module; pass weights_only=False there if loading fails.
        model = torch.load(model_path, map_location=configs['DEVICE']).module
        model = model.to(configs['DEVICE'])
        if configs['PARALLEL'] == True:
            model = torch.nn.parallel.DataParallel(model)
        model.eval()

        batch_outputs = []
        with torch.no_grad():
            test_iterator = tq(test_loader) if configs['VISIBLE_TQDM'] else test_loader
            for batch in test_iterator:
                batch = batch.to(configs['DEVICE'])
                yhat, _ = model(batch)
                if configs['PARALLEL'] == True:
                    yhat = yhat.permute(1, 0, 2)
                batch_outputs.append(yhat.cpu())

        # Per-batch outputs are assumed to be [batch, T, C] here (TODO confirm against
        # the model): concatenate along the batch axis, then move time first -> [T, N, C].
        cur_probs = torch.cat(batch_outputs, dim=0).permute(1, 0, 2)
        del batch_outputs

        # Accumulate across checkpoints. The buffer is created once, from the first
        # model's output, so its shape follows the data instead of being hard-coded and
        # earlier models are not discarded.
        if probs_sum is None:
            probs_sum = cur_probs.float()
        else:
            probs_sum.add_(cur_probs)
        del cur_probs

    probs_sum /= len(models)
    return decode_predictions(probs_sum)


def inference_single(model, test_loader, idx2char, configs):
    """Run one already-loaded model over ``test_loader`` and greedily decode it.

    Args:
        model: callable module returning ``(output, _)`` for an input batch.
        test_loader: iterable of input batches; each batch must support ``.to(device)``.
        idx2char: mapping from class index to character.
        configs: dict; this function reads 'DEVICE', 'PARALLEL' and 'VISIBLE_TQDM'.

    Returns:
        tuple[list[str], list[Tensor]]: the decoded strings (one per sample, in loader
        order, no de-duplication applied) and the raw per-batch outputs moved to CPU.
        Both lists are empty for an empty loader.
    """
    # Kept module-level so they can be inspected interactively after the call.
    global preds, yhats, text_batch_pred, yhat

    def decode_predictions(text_batch_logits):
        # Expects [T, batch, C]; softmax is monotonic, so argmax over the raw scores
        # picks the same class.
        text_batch_tokens = text_batch_logits.argmax(2).numpy().T  # [batch, T]
        return ["".join(idx2char[idx] for idx in text_tokens)
                for text_tokens in text_batch_tokens]

    model.eval()
    preds = []
    yhats = []
    with torch.no_grad():
        test_iterator = tq(test_loader) if configs['VISIBLE_TQDM'] else test_loader

        for image_batch in test_iterator:
            image_batch = image_batch.to(configs['DEVICE'])
            yhat, _ = model(image_batch)
            if configs['PARALLEL'] == True:
                yhat = yhat.permute(1, 0, 2)
            yhat = yhat.cpu()
            # Per-batch outputs are assumed to be [batch, T, C] (the layout implied by
            # concatenating batch outputs along dim 0); move time first before decoding.
            # TODO confirm against the model.
            text_batch_pred = decode_predictions(yhat.permute(1, 0, 2))
            preds.extend(text_batch_pred)
            yhats.append(yhat)

    return preds, yhats

In [6]:
def correct_prediction(word):
    """Turn a raw per-time-step decoded string into the final label.

    The string is split on '-' separators; inside every segment, consecutive repeats
    of the same character are squeezed to one occurrence, and the squeezed segments
    are concatenated. A repeated character that is separated by '-' therefore
    survives (e.g. "aa-a" -> "aa").

    Args:
        word: raw decoded string.

    Returns:
        str: the cleaned label (possibly empty).
    """
    def squeeze_repeats(segment):
        # Keep a character only when it differs from the one kept just before it;
        # inside a run, the last kept character always equals the previous one.
        kept = []
        for ch in segment:
            if not kept or kept[-1] != ch:
                kept.append(ch)
        return "".join(kept)

    return "".join(squeeze_repeats(segment) for segment in word.split("-"))

In [3]:
# Checkpoints averaged by the ensemble: the same architecture saved under two run
# directories that differ only in their numeric suffix (_42 and _1203).
INFER_MODELS = [
    f"/home/gyuseonglee/workspace/2301_OCR/Aug-RegNet-TransformerDecoder_con20_{run_suffix}/model.pt"
    for run_suffix in (42, 1203)
]

In [4]:
# Ensemble prediction over every checkpoint in INFER_MODELS (raw, not yet cleaned strings).
pred = inference(models=INFER_MODELS, test_loader=test_loader, idx2char=idx2char, configs=CFG)

-- number of models : 2


  0%|          | 0/580 [00:00<?, ?it/s]

  0%|          | 0/580 [00:00<?, ?it/s]

In [9]:
# Clean every raw prediction with correct_prediction and write the date-stamped submission.
submit['label'] = [correct_prediction(raw_text) for raw_text in pred]
submit.to_csv(f'voting_submission_{today}.csv', index=False)

In [None]:
# Load the first checkpoint on its own for interactive inspection. Use the configured
# device instead of a hard-coded 'cuda' so this cell also runs on CPU-only machines
# (CFG['DEVICE'] already falls back to 'cpu').
cur_model = torch.load(INFER_MODELS[0], map_location=CFG['DEVICE']).module
cur_model = cur_model.to(CFG['DEVICE'])
cur_model

In [None]:
# Scratch check. NOTE(review): `yhats` is not created by any earlier cell on a fresh
# kernel. inference() does not set it, and inference_single() (which sets it as a
# global) is never called, so this cell raises NameError under Restart & Run All.
len(yhats[1])

In [None]:
# Scratch: grab one entry of `yhats` for shape inspection.
# NOTE(review): depends on `yhats`, which no earlier cell defines on a fresh kernel.
out = yhats[1]

In [None]:
# Scratch: shape of the first element of `out` (only defined if the `out = yhats[1]` cell ran).
out[0].shape

In [None]:
# Scratch: shape of the last element of `out` — presumably the last, possibly smaller, batch.
out[-1].shape

In [None]:
# Scratch: stack the entries of `out` along dim 0 and show the resulting shape.
vstacked = torch.vstack(out)
vstacked.shape

In [None]:
# Scratch: shape after swapping the first two axes, the same permute(1, 0, 2) that
# inference() applies to the stacked batch outputs before decoding.
vstacked.permute(1,0,2).shape
