In [1]:
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F

import torchvision.models as models
from torch.utils.data.dataset import Dataset
from torchvision import transforms

In [2]:
HIDDEN_DIM = 512
OUTPUT_DIM = 13
MAX_LEN = 5
BATCH_SIZE = 32

In [3]:
def one_hot_embedding(labels, num_classes):
    """Embedding labels to one-hot form.

    Args:
      labels: (LongTensor) class labels, sized [N,].
      num_classes: (int) number of classes.

    Returns:
      (tensor) encoded labels, sized [N, #classes].
    """
    y = torch.eye(num_classes) 
    return y[labels] 

In [4]:
class CRNN(nn.Module):
    def __init__(self, backbone):
        super(CRNN, self).__init__()
        self.backbone = backbone
        self.linear2 = nn.Linear(HIDDEN_DIM, MAX_LEN)
        self.lstm = nn.LSTM(OUTPUT_DIM, HIDDEN_DIM, batch_first=True)
        self.out = nn.Linear(HIDDEN_DIM, OUTPUT_DIM)
        
    def forward(self, x, target):
        latent = self.backbone(x)
        length = self.linear2(latent)
        inputs = torch.zeros(BATCH_SIZE, 1, OUTPUT_DIM)
        hidden = (latent.unsqueeze(0), torch.zeros(1, BATCH_SIZE, HIDDEN_DIM))
        number = []
        
        for i in range(target.size(1)):
            output, hidden = self.lstm(inputs, hidden)
            digit = self.out(output[:, -1, :])
            number.append(digit)
            inputs = target[:, i, :].unsqueeze(1)
            
        output, hidden = self.lstm(inputs, hidden)
        digit = self.out(output[:, -1, :])
        number.append(digit)
        return length, number
    
    def evaluate(self, x):
        pass

In [5]:
resnet50 = models.resnet50(pretrained=True)
resnet50.fc = nn.Linear(2048, HIDDEN_DIM)

model = CRNN(resnet50)

Downloading: "https://download.pytorch.org/models/resnet50-19c8e357.pth" to /Users/admin/.torch/models/resnet50-19c8e357.pth
102502400it [00:12, 8276354.36it/s]


In [97]:
def one_hot_digits(digits):
    
    cleaned_digits = []
    for digit in digits:
        digit = digit.replace('.', '10')
        digit = digit.replace(',', '11')
        digit = digit.replace('-', '12')
        cleaned_digits.append(int(digit))

    cleaned_digits.append(13)
    print(cleaned_digits)
    
    return one_hot_embedding(cleaned_digits, 14)

In [99]:
class OCRDataset(object):
    def __init__(self, df):
        self.df = d
        self.images = self.df.values[:, 0]
        self.labels = self.df.values[:, 1]
        self.length = len(self.df.index)
    
    def __getitem__(self, index):
        image_path = self.images[index]
        image = Image.open(image_path)
        image = torch.Tensor(image).float()
        
        label = self.labels[index]
        label = one_hot_digits(label)
        
        return (image, label)

    def __len__(self):
        return self.len

In [100]:
model(torch.zeros(BATCH_SIZE, 3, 224, 224), torch.zeros(BATCH_SIZE, 5, 13))

(tensor([[-0.1064, -0.4131, -0.0072, -0.2683, -0.2053],
         [-0.1064, -0.4131, -0.0072, -0.2683, -0.2053],
         [-0.1064, -0.4131, -0.0072, -0.2683, -0.2053],
         [-0.1064, -0.4131, -0.0072, -0.2683, -0.2053],
         [-0.1064, -0.4131, -0.0072, -0.2683, -0.2053],
         [-0.1064, -0.4131, -0.0072, -0.2683, -0.2053],
         [-0.1064, -0.4131, -0.0072, -0.2683, -0.2053],
         [-0.1064, -0.4131, -0.0072, -0.2683, -0.2053],
         [-0.1064, -0.4131, -0.0072, -0.2683, -0.2053],
         [-0.1064, -0.4131, -0.0072, -0.2683, -0.2053],
         [-0.1064, -0.4131, -0.0072, -0.2683, -0.2053],
         [-0.1064, -0.4131, -0.0072, -0.2683, -0.2053],
         [-0.1064, -0.4131, -0.0072, -0.2683, -0.2053],
         [-0.1064, -0.4131, -0.0072, -0.2683, -0.2053],
         [-0.1064, -0.4131, -0.0072, -0.2683, -0.2053],
         [-0.1064, -0.4131, -0.0072, -0.2683, -0.2053],
         [-0.1064, -0.4131, -0.0072, -0.2683, -0.2053],
         [-0.1064, -0.4131, -0.0072, -0.2683, -0

In [101]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [None]:
dataset = OCRDataset()