In [2]:
import torch
import torch.nn as nn
import torchvision.models as models
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import Compose

import json
from os.path import join
import cv2
import numpy as np

In [3]:
class crnn(nn.Module):
    """CRNN: VGG16 convolutional features -> bidirectional LSTM -> linear classifier.

    The width axis of the convolutional feature map is used as the sequence
    axis, so every feature-map column becomes one time step.

    Args:
        hid_dim: hidden size of each LSTM direction.
        char_dim: number of output classes produced per time step.
    """
    def __init__(self, hid_dim, char_dim):
        super().__init__()
        ## (batch_size,3,a,b) --> (batch_size,512,a/32,b/32)
        self.vgg16 = models.vgg16(pretrained = True).features
        ## (seq_len,batch_size,dim1) --> (seq_len,batch_size,2*hid_dim)
        # input_size is channels * feature-map height: 512 * 7, which matches
        # inputs that are 224 pixels high (224 / 32 = 7).
        self.bilstm = nn.LSTM(input_size=512*7, hidden_size=hid_dim, batch_first=False,
                             dropout=0, bidirectional=True)
        # The layer class is nn.Linear; torch.nn has no attribute `linear`.
        self.linear1 = nn.Linear(in_features=2*hid_dim,out_features=char_dim,bias=True)

    def forward(self, x):
        """Map an image batch to per-time-step class scores.

        Args:
            x: tensor of shape (batch_size, 3, height, width).

        Returns:
            Tensor of shape (seq_len, batch_size, char_dim), where seq_len is
            the width of the convolutional feature map.
        """
        feats = self.vgg16(x)                              # (batch, channels, rows, cols)
        # Put the width axis first so it acts as the sequence axis; make the
        # result contiguous so the following view is always valid.
        feats = feats.permute(3, 0, 1, 2).contiguous()     # (cols, batch, channels, rows)
        seq_len, batch_size, channels, rows = feats.size()
        z = feats.view(seq_len, batch_size, channels * rows)
        z = self.bilstm(z)[0]                              # keep the outputs, drop (h_n, c_n)
        z = self.linear1(z) # out: (seq_len,batch_size,char_dim)

        return z

class coco_train(Dataset):
    """Training set of word-crop images paired with their transcriptions.

    The annotation JSON must provide an ``'abc'`` entry (the alphabet) and a
    ``'train'`` entry (a list of ``{'text': ..., 'name': ...}`` dicts).

    Args:
        root_dir: directory holding the annotation file and the images.
        annotation: annotation file name, relative to ``root_dir``.
        transform: optional callable applied to every sample dict.
    """
    def __init__(self, root_dir, annotation, transform=None):
        # JSON is UTF-8 by specification; do not depend on the locale default.
        with open(join(root_dir,annotation), encoding='utf-8') as f:
            myanno = json.load(f)

        self.mydict = myanno['abc'] # type: str
        self.root_dir = root_dir
        self.annotation = myanno['train'] # list of dicts {'text':..., 'name':...}
        self.transform = transform

    def __len__(self):
        """Number of annotated training samples."""
        return len(self.annotation)

    def __getitem__(self, idx):
        """Load one sample.

        Returns:
            ``{'text': list of alphabet indices, 'image': array read by cv2}``,
            passed through ``self.transform`` when one was given.

        Raises:
            OSError: if the image file cannot be read.
            ValueError: if the text contains a character missing from the alphabet.
        """
        entry = self.annotation[idx]
        img_path = join(self.root_dir, entry['name'])

        img = cv2.imread(img_path) # shape: row * col * channels
        if img is None:
            # cv2.imread signals failure by returning None instead of raising;
            # fail here so the error names the offending file.
            raise OSError('could not read image: {}'.format(img_path))

        seq = self.text_to_seq(entry['text'])
        sample = { 'text':seq, 'image':img }
        if self.transform:
            sample = self.transform(sample)

        return sample

    def text_to_seq(self, text):
        """Convert ``text`` to the list of each character's index in the alphabet.

        Raises:
            ValueError: if a character is not part of the alphabet.
        """
        seq = []
        for char in text:
            try:
                seq.append(self.mydict.index(char))
            except ValueError:
                raise ValueError(
                    'character {!r} is not in the alphabet'.format(char)
                ) from None
        return seq

    
    
# data transform
class resize:
    """Sample transform that resizes ``sample['image']`` to a fixed size.

    Args:
        size: target size forwarded unchanged to ``cv2.resize``.
    """
    def __init__(self, size=(224,224)):
        self.size = size

    def __call__(self, sample):
        """Resize the image in place inside ``sample`` and return the same dict."""
        resized = cv2.resize(sample['image'], self.size)
        sample['image'] = resized
        return sample

# strategy to combine a batch of data
def data_collate(batch):
    """Combine a list of samples into one batch dict.

    Args:
        batch: list of dicts with an ``'image'`` array laid out as
            (rows, cols, channels) and a ``'text'`` list of integer indices.
            All images must share the same shape so they can be stacked.

    Returns:
        dict with
            ``'img'``: float tensor (batch, channels, rows, cols),
            ``'seq'``: int32 tensor holding every sample's indices concatenated,
            ``'seq_len'``: int32 tensor with the length of each sample's text.
    """
    images = []
    targets = []
    target_lengths = []
    for sample in batch:
        # (rows, cols, channels) -> (channels, rows, cols)
        chw = sample['image'].transpose((2,0,1))
        images.append(torch.from_numpy(chw).float())
        targets.extend(sample['text'])
        target_lengths.append(len(sample['text']))

    # Build the integer tensors directly: torch.Tensor(list).int() goes through
    # float32 and silently corrupts integers above 2**24.
    return {
        "img": torch.stack(images),
        "seq": torch.tensor(targets, dtype=torch.int32),
        "seq_len": torch.tensor(target_lengths, dtype=torch.int32),
    }


In [None]:
# Resize every crop to a fixed 224x224 so data_collate can stack the images.
transform = Compose([ resize(size=(224,224)) ])

trainset = coco_train(root_dir='cropped_COCO',annotation='desc.json', transform=transform)

trainloader = DataLoader(trainset, batch_size=32,
                         shuffle=True, num_workers=1, collate_fn=data_collate)

# Sanity check: fetch one batch and show only the tensor shapes rather than
# dumping the raw tensor contents into the notebook output.
first_batch = next(iter(trainloader))
{name: tuple(tensor.shape) for name, tensor in first_batch.items()}

In [2]:
# Inspect the raw (pre-resize) shape of a single crop.
# NOTE(review): the path is relative, like the dataset's root_dir='cropped_COCO';
# this assumes the notebook is run from the directory containing cropped_COCO.
sample_path = join('cropped_COCO', 'test0_COCO_train2014_000000023203.jpg')
img = cv2.imread(sample_path)
# cv2.imread returns None on failure instead of raising.
assert img is not None, 'could not read image: ' + sample_path
img.shape

(19, 54, 3)

In [None]:
# Shape smoke test on random input.
# crnn requires both hid_dim and char_dim; the output size is taken from the
# dataset alphabet.
# NOTE(review): add 1 to char_dim if the loss needs an extra blank class — confirm.
net = crnn(hid_dim=128, char_dim=len(trainset.mydict))
x = torch.randn(32,3,224,224)
# Call the module itself (not .forward) so hooks run; no gradients are needed here.
with torch.no_grad():
    out = net(x)
out.size()

In [None]:
# Scratch check: what size() returns and how permute(3,0,1,2) reorders the axes.
x = torch.randn(1, 2, 3, 4)
size_before = x.size()
print(type(size_before))

size_after = x.permute(3, 0, 1, 2).size()
print( size_after )