In [7]:
# ! pip install lmdb
# ! pip install fire
# ! pip install opencv-python

Collecting opencv-python
  Using cached opencv_python-4.6.0.66-cp36-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (60.9 MB)
Collecting numpy>=1.17.3
  Using cached numpy-1.23.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (17.1 MB)
Installing collected packages: numpy, opencv-python
Successfully installed numpy-1.23.3 opencv-python-4.6.0.66


In [119]:
import io
import logging
import math
import os
import re
import sys
import time

import cv2
import fire
import lmdb
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, ConcatDataset, Subset

# NOTE(review): the cells below also use `Image` (PIL) and `NormalizePAD`, which are
# neither imported nor defined anywhere in this notebook - they must be provided
# before LmdbDataset / AlignCollate can run on a fresh kernel.

In [52]:
# Target input resolution (height and width, in pixels) used when collating images.
imgH = imgW = 224

In [53]:
def checkImageIsValid(imageBin):
    """Return True when ``imageBin`` decodes to a non-empty grayscale image.

    Args:
        imageBin: Raw encoded image bytes (the content of an image file), or None.

    Returns:
        False for None, zero-length input, data OpenCV cannot decode, or an image
        with zero area; True otherwise.
    """
    # None or empty input can never hold an image; bail out before touching OpenCV.
    if imageBin is None or len(imageBin) == 0:
        return False
    imageBuf = np.frombuffer(imageBin, dtype=np.uint8)
    img = cv2.imdecode(imageBuf, cv2.IMREAD_GRAYSCALE)
    # cv2.imdecode reports an undecodable buffer by returning None, not by raising,
    # so `img.shape` must not be touched before this check.
    if img is None:
        return False
    height, width = img.shape[0], img.shape[1]
    return height * width != 0


def writeCache(env, cache):
    """Store every key/value pair of ``cache`` in ``env`` inside a single write transaction.

    Args:
        env: Object exposing ``begin(write=True)`` that yields a transaction with ``put``.
        cache: Mapping of keys to values to store.
    """
    with env.begin(write=True) as transaction:
        for key in cache:
            transaction.put(key, cache[key])


def createDataset(inputPath = "/home/jezjamez/work/Assignment1/data_lmdb_release/training/mjsynth/mnt/ramdisk/max/90kDICT32px/",
                  gtFile = "/home/jezjamez/work/Assignment1/data_lmdb_release/training/mjsynth/mnt/ramdisk/max/90kDICT32px/annotation.txt",
                  outputPath = "/home/jezjamez/work/Assignment1/data_lmdb_release/", checkValid=True):
    """Build an LMDB database of (image bytes, label) pairs from an annotation file.

    Each non-blank line of ``gtFile`` is read as ``<relative image path> <label>``,
    split on the first space. Images are stored under ``image-%09d`` keys and labels
    under ``label-%09d`` keys, numbered from 1, and the final count is stored under
    ``num-samples``.

    Args:
        inputPath: Directory the image paths in ``gtFile`` are relative to.
        gtFile: UTF-8 annotation file, one sample per line.
        outputPath: Directory of the LMDB environment (created if missing).
        checkValid: When True, skip images that fail ``checkImageIsValid``.
    """
    os.makedirs(outputPath, exist_ok=True)
    env = lmdb.open(outputPath, map_size=1099511627776)
    cache = {}
    cnt = 1

    with open(gtFile, 'r', encoding='utf-8') as data:
        datalist = data.readlines()

    nSamples = len(datalist)
    try:
        for i, line in enumerate(datalist):
            line = line.strip('\n')
            # Blank lines (e.g. a trailing newline at end of file) carry no sample.
            if not line.strip():
                continue
            # Split on the first space only so a stray extra space cannot break unpacking.
            parts = line.split(' ', 1)
            if len(parts) != 2:
                print('line %d of %s is malformed: %r' % (i, gtFile, line))
                continue
            imagePath, label = parts

            imagePath = os.path.join(inputPath, imagePath)
            if not os.path.exists(imagePath):
                print('%s does not exist' % imagePath)
                continue
            with open(imagePath, 'rb') as f:
                imageBin = f.read()
            if checkValid:
                try:
                    if not checkImageIsValid(imageBin):
                        print('%s is not a valid image' % imagePath)
                        continue
                except Exception:
                    # Only genuine errors are logged and skipped; KeyboardInterrupt
                    # still stops the run, unlike with a bare `except:`.
                    print('error occured', i)
                    with open(os.path.join(outputPath, 'error_image_log.txt'), 'a') as log:
                        log.write('%s-th image data occured error\n' % str(i))
                    continue

            imageKey = 'image-%09d'.encode() % cnt
            labelKey = 'label-%09d'.encode() % cnt
            cache[imageKey] = imageBin
            cache[labelKey] = label.encode()

            # Flush in batches of 1000 samples so the in-memory cache stays bounded.
            if cnt % 1000 == 0:
                writeCache(env, cache)
                cache = {}
                print('Written %d / %d' % (cnt, nSamples))
            cnt += 1
        nSamples = cnt-1
        cache['num-samples'.encode()] = str(nSamples).encode()
        writeCache(env, cache)
    finally:
        # Always release the environment, even when the build is interrupted.
        env.close()
    print('Created dataset with %d samples' % nSamples)

In [54]:
# fire.Fire(createDataset)

In [55]:
class LmdbDataset(Dataset):
    """Dataset reading (grayscale image, label) pairs from an LMDB database.

    The database is expected to hold ``image-%09d`` / ``label-%09d`` keys numbered
    from 1 plus a ``num-samples`` entry, as written by ``createDataset``.
    NOTE(review): ``Image`` (PIL) is used below but is not imported in the visible
    notebook - confirm it is in scope.
    """

    def __init__(self, root, ):
        """Open the LMDB environment at ``root`` read-only and read the sample count.

        Raises:
            RuntimeError: If the environment could not be opened.
            KeyError: If the database has no ``num-samples`` entry.
        """
        self.root = root
        self.env = lmdb.open(root, max_readers=32, readonly=True, lock=False, readahead=False, meminit=False)
        if not self.env:
            # Raise instead of sys.exit(0): exiting would kill the notebook kernel
            # and report success.
            raise RuntimeError('cannot create lmdb from %s' % (root))

        with self.env.begin(write=False) as txn:
            raw_count = txn.get('num-samples'.encode())
        if raw_count is None:
            raise KeyError("'num-samples' entry not found in lmdb at %s" % (root))
        self.nSamples = int(raw_count)

        # LMDB keys are 1-based while dataset indices are 0-based.
        self.filtered_index_list = [index + 1 for index in range(self.nSamples)]

    def __len__(self):
        """Return the number of samples recorded in the database."""
        return self.nSamples

    def __getitem__(self, index):
        """Return ``(img, label)`` for the sample at position ``index``.

        The label is lower-cased and stripped of everything but ``0-9a-z``. An
        unreadable image is replaced by a blank ``imgW`` x ``imgH`` image (module
        globals) with a dummy label.

        Raises:
            IndexError: If ``index`` is past the end of the dataset.
        """
        # `index == len(self)` is already out of range (the old assert let it through).
        if index >= len(self):
            raise IndexError('index range error')
        index = self.filtered_index_list[index]

        with self.env.begin(write=False) as txn:
            label_key = 'label-%09d'.encode() % index
            label = txn.get(label_key).decode('utf-8')
            img_key = 'image-%09d'.encode() % index
            imgbuf = txn.get(img_key)

            # stdlib io.BytesIO replaces six.BytesIO (six was never imported).
            buf = io.BytesIO(imgbuf)
            try:
                img = Image.open(buf).convert('L')
            except IOError:
                print(f'Corrupted image for {index}')
                img = Image.new('L', (imgW, imgH))
                label = '[dummy_label]'

            label = label.lower()
            out_of_char = '[^0123456789abcdefghijklmnopqrstuvwxyz]'
            label = re.sub(out_of_char, '', label)

        return (img, label)


In [56]:
# ! ls data_lmdb_release/mjsynth/mnt/ramdisk/max/90kDICT32px/298/1
# lmdb.open("data_lmdb_release/", max_readers=32, readonly=True, lock=False, readahead=False, meminit=False)

In [57]:
class AlignCollate(object):
    """Collate callable that resizes images to a common height and batches them.

    NOTE(review): relies on ``NormalizePAD`` and ``Image``, neither of which is
    defined or imported in the visible notebook - confirm they are in scope.
    """

    def __init__(self, imgH=32, imgW=100, keep_ratio_with_pad=True):
        """Remember the target size; ``keep_ratio_with_pad`` is stored but not read by ``__call__``."""
        self.imgH, self.imgW = imgH, imgW
        self.keep_ratio_with_pad = keep_ratio_with_pad

    def __call__(self, batch):
        """Turn a list of ``(image, label)`` samples into ``(image_tensors, labels)``.

        ``None`` samples are dropped. Each image is resized to height ``imgH`` with
        its aspect ratio kept (width capped at ``imgW``) and passed through
        ``NormalizePAD``; the results are concatenated along a new first dimension.
        """
        samples = [sample for sample in batch if sample is not None]
        images, labels = zip(*samples)

        channels = 3 if images[0].mode == 'RGB' else 1
        transform = NormalizePAD((channels, self.imgH, self.imgW))

        padded = []
        for picture in images:
            width, height = picture.size
            # Width that preserves the aspect ratio at the target height, capped at imgW.
            scaled_width = min(math.ceil(self.imgH * (width / float(height))), self.imgW)
            padded.append(transform(picture.resize((scaled_width, self.imgH), Image.BICUBIC)))

        image_tensors = torch.cat([tensor.unsqueeze(0) for tensor in padded], 0)

        return image_tensors, labels

In [207]:
class Batch_Balanced_Dataset(object):
    """Endless batch source over one LMDB dataset.

    Wraps an ``LmdbDataset`` in a shuffling ``DataLoader`` and restarts the loader
    whenever it is exhausted, so ``get_batch`` can be called indefinitely.
    """

    def __init__(self, path ="data_lmdb_release", batch_size=8, num_workers=8, imgH=224, imgW=224):
        """Open the dataset at ``path`` and create its loader.

        Args:
            path: Directory of the LMDB database.
            batch_size: Samples per batch (previously hard-coded to 8).
            num_workers: DataLoader worker processes (previously hard-coded to 8).
            imgH: Target image height passed to ``AlignCollate``.
            imgW: Target image width passed to ``AlignCollate``.
        """
        self.path = path
        dashed_line = '-' * 80
        print(dashed_line)

        _AlignCollate = AlignCollate(imgH=imgH, imgW=imgW, keep_ratio_with_pad=False)
        self.data_loader_list = []
        self.dataloader_iter_list = []

        dataset = LmdbDataset(path)
        total_number_dataset = len(dataset)
        print(total_number_dataset)
        print(dataset)
        _data_loader = torch.utils.data.DataLoader(
            dataset, batch_size=batch_size,
            shuffle=True,
            num_workers=num_workers,
            collate_fn=_AlignCollate, pin_memory=True)

        self.data_loader_list.append(_data_loader)
        self.dataloader_iter_list.append(iter(_data_loader))

    def get_batch(self):
        """Return the next ``(images, texts)`` batch, restarting exhausted loaders.

        Returns:
            A tensor of images concatenated along dimension 0 and the list of labels.

        Raises:
            RuntimeError: If no loader produced a batch.
        """
        balanced_batch_images = []
        balanced_batch_texts = []

        for i, data_loader_iter in enumerate(self.dataloader_iter_list):
            try:
                # Built-in next() is the iterator protocol; the `.next()` method is
                # not available on current PyTorch loader iterators.
                image, text = next(data_loader_iter)
            except StopIteration:
                # Epoch finished: start a fresh pass over the same loader.
                self.dataloader_iter_list[i] = iter(self.data_loader_list[i])
                image, text = next(self.dataloader_iter_list[i])
            except ValueError:
                # Best effort, as before: skip a loader whose batch could not be built.
                continue
            balanced_batch_images.append(image)
            balanced_batch_texts += text

        if not balanced_batch_images:
            raise RuntimeError('no data loader produced a batch')
        balanced_batch_images = torch.cat(balanced_batch_images, 0)

        return balanced_batch_images, balanced_batch_texts

In [208]:
# Use CUDA when it is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [209]:
class TokenLabelConverter:
    """Convert between text labels and fixed-length index tensors.

    The vocabulary is ``[GO]``, ``[s]`` followed by the characters listed in
    ``__init__``. ``[GO]`` (index 0) starts every sequence and also pads it,
    ``[s]`` (index 1) marks the end of the text.
    """

    def __init__(self, device=None):
        """Build the vocabulary.

        Args:
            device: Device encoded tensors are moved to. Defaults to CUDA when
                available, otherwise CPU.
        """
        character='0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~'
        batch_max_length = 25
        self.SPACE = '[s]'
        self.GO = '[GO]'

        self.list_token = [self.GO, self.SPACE]
        self.character = self.list_token + list(character)

        self.dict = {word: i for i, word in enumerate(self.character)}
        # Room for the text plus the [GO] and [s] markers.
        self.batch_max_length = batch_max_length + len(self.list_token)
        if device is None:
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.device = device

    def encode(self, text):
        """Encode a sequence of strings into a ``(len(text), batch_max_length)`` LongTensor.

        Each row is ``[GO] chars... [s]`` padded with the ``[GO]`` index. Text longer
        than ``batch_max_length - 2`` characters is truncated so the end marker
        always fits (previously such a label raised a size-mismatch error).

        Raises:
            KeyError: If a character is not part of the vocabulary.
        """
        max_chars = self.batch_max_length - len(self.list_token)
        batch_text = torch.LongTensor(len(text), self.batch_max_length).fill_(self.dict[self.GO])
        for i, t in enumerate(text):
            tokens = [self.GO] + list(t)[:max_chars] + [self.SPACE]
            indices = [self.dict[token] for token in tokens]
            batch_text[i][:len(indices)] = torch.LongTensor(indices)  # batch_text[:, 0] = [GO] token
        return batch_text.to(self.device)

    def decode(self, text_index, length):
        """Map rows of vocabulary indices back to strings.

        Args:
            text_index: 2-D integer tensor of vocabulary indices, one row per sample.
            length: Iterable with one entry per row; only its size is used.

        Returns:
            One string per row, including any ``[s]`` / ``[GO]`` markers.
        """
        texts = []
        for index, _ in enumerate(length):
            # tolist() converts the row once instead of indexing with 0-d tensors.
            row = text_index[index, :].tolist()
            texts.append(''.join(self.character[i] for i in row))
        return texts

In [210]:
# Shared label <-> index converter; num_class is the vocabulary size
# (the special tokens plus every character), used as the classifier width.
converter = TokenLabelConverter()
num_class = len(converter.character)
print(num_class)

96


In [211]:
from timm.models.vision_transformer import VisionTransformer, _cfg
from timm.models.registry import register_model
from timm.models import create_model
import torch.utils.model_zoo as model_zoo

In [212]:
__all__ = ['vis_model']

def create_vitstr(num_tokens, model="vis_model", checkpoint_path=''):
    """Create the named timm model with ``num_tokens`` outputs and a fresh classifier head.

    Args:
        num_tokens: Number of output classes of the head.
        model: Model name passed to ``create_model``.
        checkpoint_path: Checkpoint path forwarded to ``create_model``.

    Returns:
        The created network.
    """
    build_args = dict(
        pretrained=True,
        num_classes=num_tokens,
        checkpoint_path=checkpoint_path,
    )
    network = create_model(model, **build_args)

    # Replace the head with a newly created one of the requested size.
    network.reset_classifier(num_classes=num_tokens)

    return network

class ViTSTR(VisionTransformer):
    """VisionTransformer variant that classifies the first ``seqlen`` output tokens."""

    def __init__(self, *args, **kwargs):
        """Forward all arguments unchanged to ``VisionTransformer``."""
        super().__init__(*args, **kwargs)

    def reset_classifier(self, num_classes):
        """Replace the head with a new linear layer (or identity when ``num_classes`` is 0 or less)."""
        self.num_classes = num_classes
        if num_classes > 0:
            self.head = nn.Linear(self.embed_dim, num_classes)
        else:
            self.head = nn.Identity()

    def forward_features(self, x):
        """Embed the input, prepend the class token, and run the transformer blocks and final norm."""
        batch = x.shape[0]
        tokens = self.patch_embed(x)

        # Broadcast the single class token across the batch and put it first.
        cls_tokens = self.cls_token.expand(batch, -1, -1)
        tokens = torch.cat((cls_tokens, tokens), dim=1)
        tokens = self.pos_drop(tokens + self.pos_embed)

        for blk in self.blocks:
            tokens = blk(tokens)

        return self.norm(tokens)

    def forward(self, x, seqlen: int =25):
        """Return per-token logits shaped (batch, kept tokens, ``num_classes``) for the first ``seqlen`` tokens."""
        features = self.forward_features(x)[:, :seqlen]

        # batch, seqlen, embsize
        b, s, e = features.size()
        logits = self.head(features.reshape(b*s, e))
        return logits.view(b, s, self.num_classes)


def load_pretrained(model, cfg=None, num_classes=1000, in_chans=1, filter_fn=None, strict=True):
    """Download pretrained weights described by ``cfg`` and load them into ``model``.

    Args:
        model: Module to load into; its ``default_cfg`` is used when ``cfg`` is None.
        cfg: Dict with at least ``url``, ``first_conv``, ``classifier`` and ``num_classes``.
        num_classes: Number of classes of ``model``; when it differs from
            ``cfg['num_classes']`` the classifier weights are dropped and loading
            becomes non-strict.
        in_chans: Input channels of ``model``. For 1, the first convolution's weights
            are summed over the input-channel dimension.
        filter_fn: Optional callable applied to the downloaded state dict.
        strict: Passed to ``load_state_dict``.

    Returns:
        None. Returns early without loading anything when no URL is configured, or
        when ``in_chans == 1`` and the first-conv weight is missing from the checkpoint.
    """
    # Local logger: the `_logger` name used here before was never defined in this notebook.
    _logger = logging.getLogger(__name__)

    if cfg is None:
        cfg = getattr(model, 'default_cfg', None)
    if cfg is None or 'url' not in cfg or not cfg['url']:
        _logger.warning("Pretrained model URL is invalid, using random initialization.")
        return

    state_dict = model_zoo.load_url(cfg['url'], progress=True, map_location='cpu')
    # Some checkpoints nest the weights under a "model" key.
    if "model" in state_dict:
        state_dict = state_dict["model"]

    if filter_fn is not None:
        state_dict = filter_fn(state_dict)

    if in_chans == 1:
        conv1_name = cfg['first_conv']
        key = conv1_name + '.weight'
        if key not in state_dict:
            # Same early exit as before, but no longer silent.
            _logger.warning("%s not found in checkpoint, pretrained weights were not loaded.", key)
            return
        conv1_weight = state_dict[key]
        conv1_type = conv1_weight.dtype
        conv1_weight = conv1_weight.float()
        O, I, J, K = conv1_weight.shape
        if I > 3:
            assert conv1_weight.shape[1] % 3 == 0
            conv1_weight = conv1_weight.reshape(O, I // 3, 3, J, K)
            conv1_weight = conv1_weight.sum(dim=2, keepdim=False)
        else:
            # Collapse the input channels into a single channel by summing.
            conv1_weight = conv1_weight.sum(dim=1, keepdim=True)
        conv1_weight = conv1_weight.to(conv1_type)
        state_dict[key] = conv1_weight

    classifier_name = cfg['classifier']
    if num_classes == 1000 and cfg['num_classes'] == 1001:
        # Drop the first row/entry of the 1001-class classifier to get 1000 classes.
        classifier_weight = state_dict[classifier_name + '.weight']
        state_dict[classifier_name + '.weight'] = classifier_weight[1:]
        classifier_bias = state_dict[classifier_name + '.bias']
        state_dict[classifier_name + '.bias'] = classifier_bias[1:]
    elif num_classes != cfg['num_classes']:
        # Head size differs: discard the checkpoint's head and allow missing keys.
        del state_dict[classifier_name + '.weight']
        del state_dict[classifier_name + '.bias']
        strict = False

    print("Loading pre-trained vision transformer weights from %s ..." % cfg['url'])
    model.load_state_dict(state_dict, strict=strict)


def _conv_filter(state_dict, patch_size=16):
    out_dict = {}
    for k, v in state_dict.items():
        if 'patch_embed.proj.weight' in k:
            v = v.reshape((v.shape[0], 3, patch_size, patch_size))
        out_dict[k] = v
    return out_dict


# Registering the factory is what lets `create_model("vis_model", ...)` find it by
# name on a fresh kernel; without the decorator timm does not know this model.
@register_model
def vis_model(pretrained=False, **kwargs):
    """Build a single-channel ViTSTR (patch 16, width 384, depth 12, 6 heads).

    Args:
        pretrained: When True, load the weights from the configured URL via
            ``load_pretrained``.
        **kwargs: Extra ``ViTSTR`` constructor arguments (e.g. ``num_classes``).
            ``in_chans`` is always forced to 1.

    Returns:
        The constructed ``ViTSTR`` model.
    """
    kwargs['in_chans'] = 1
    # timm's create_model may pass pretrained-config arguments that the
    # VisionTransformer constructor does not accept; drop them if present.
    kwargs.pop('pretrained_cfg', None)
    kwargs.pop('pretrained_cfg_overlay', None)
    print(kwargs)
    model = ViTSTR(
        patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True, **kwargs)
    model.default_cfg = _cfg(
            url="https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth"
    )
    if pretrained:
        load_pretrained(
            model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 1), filter_fn=_conv_filter)
    return model

In [213]:


class Model(nn.Module):
    """Training wrapper around the ViTSTR network created by ``create_vitstr``."""

    def __init__(self,):
        """Create the network with ``num_class`` outputs (module-level global)."""
        super(Model, self).__init__()
        self.vitstr = create_vitstr(num_tokens=num_class)

    def forward(self, input, text, is_train=True, seqlen=25):
        """Return the network's prediction for ``input``.

        ``text`` and ``is_train`` are accepted for call-site compatibility but are
        not used by this method.
        """
        prediction = self.vitstr(input, seqlen=seqlen)
        return prediction


class JitModel(Model):
    """Variant of ``Model`` whose ``forward`` takes only the image and ``seqlen``."""

    def __init__(self):
        # Calls nn.Module.__init__ directly (skipping Model.__init__) so the network
        # is created exactly once, on the line below.
        super(Model, self).__init__()
        # `opt` was never defined in this notebook; use the same global as Model.
        self.vitstr = create_vitstr(num_tokens=num_class)

    def forward(self, input, seqlen:int = 25):
        """Return the network's prediction for ``input``."""
        # The network lives in `self.vitstr`; `self.network` never existed.
        prediction = self.vitstr(input, seqlen=seqlen)
        return prediction

In [214]:
# Training batch source over the default LMDB path ("data_lmdb_release");
# construction opens the database and prints its sample count.
train_dataset = Batch_Balanced_Dataset()

--------------------------------------------------------------------------------
8919257
<__main__.LmdbDataset object at 0x7f37c02482b0>


In [215]:
# Build the recognizer; create_vitstr requests pretrained=True, so this triggers
# loading of the pretrained transformer weights.
model = Model()

{'num_classes': 96, 'in_chans': 1}
Loading pre-trained vision transformer weights from https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth ...


In [216]:
# Wrap only once: re-running this cell must not nest DataParallel wrappers
# (that would add another "module." prefix to every state_dict key).
if not isinstance(model, torch.nn.DataParallel):
    model = torch.nn.DataParallel(model)
model = model.to(device)
model.train()

DataParallel(
  (module): Model(
    (vitstr): ViTSTR(
      (patch_embed): PatchEmbed(
        (proj): Conv2d(1, 384, kernel_size=(16, 16), stride=(16, 16))
        (norm): Identity()
      )
      (pos_drop): Dropout(p=0.0, inplace=False)
      (blocks): Sequential(
        (0): Block(
          (norm1): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
          (attn): Attention(
            (qkv): Linear(in_features=384, out_features=1152, bias=True)
            (attn_drop): Dropout(p=0.0, inplace=False)
            (proj): Linear(in_features=384, out_features=384, bias=True)
            (proj_drop): Dropout(p=0.0, inplace=False)
          )
          (ls1): Identity()
          (drop_path1): Identity()
          (norm2): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
          (mlp): Mlp(
            (fc1): Linear(in_features=384, out_features=1536, bias=True)
            (act): GELU(approximate=none)
            (drop1): Dropout(p=0.0, inplace=False)
            (fc2)

In [217]:
class Averager(object):
    """Running mean over all elements of the tensors passed to ``add``."""

    def __init__(self):
        self.reset()

    def add(self, v):
        """Accumulate the element count and element sum of tensor ``v``."""
        values = v.data
        self.n_count += values.numel()
        self.sum += values.sum()

    def reset(self):
        """Clear the accumulated count and sum."""
        self.n_count, self.sum = 0, 0

    def val(self):
        """Return the mean of everything added so far, or 0 when nothing was added."""
        if self.n_count == 0:
            return 0
        return self.sum / float(self.n_count)

In [218]:
def _levenshtein(a, b):
    """Return the unit-cost insert/delete/substitute edit distance between two strings."""
    if len(a) < len(b):
        a, b = b, a
    previous = list(range(len(b) + 1))
    for i, char_a in enumerate(a, start=1):
        current = [i]
        for j, char_b in enumerate(b, start=1):
            current.append(min(previous[j] + 1,                        # deletion
                               current[j - 1] + 1,                     # insertion
                               previous[j - 1] + (char_a != char_b)))  # substitution
        previous = current
    return previous[-1]


def validation(model, criterion, evaluation_loader, converter):
    """Evaluate ``model`` on every batch of ``evaluation_loader``.

    Predictions and ground truth are compared case-insensitively on ``0-9a-z`` only.

    Args:
        model: Callable as ``model(image, text=target, seqlen=...)`` returning logits.
        criterion: Loss applied to the flattened logits and targets.
        evaluation_loader: Iterable of ``(image_tensors, labels)`` batches.
        converter: Object providing ``encode``, ``decode`` and ``batch_max_length``.

    Returns:
        Tuple ``(mean loss, accuracy in percent, mean normalized edit-distance score,
        predictions of the last batch, confidence scores of the last batch,
        labels of the last batch, total forward time in seconds, number of samples)``.
        Accuracy and edit-distance score are 0.0 when the loader yields no samples.
    """
    n_correct = 0
    norm_ED = 0
    length_of_data = 0
    infer_time = 0
    valid_loss_avg = Averager()
    # Defaults so an empty loader returns cleanly instead of raising.
    preds_str, labels, confidence_score_list = [], [], []
    out_of_alphanumeric = re.compile('[^0123456789abcdefghijklmnopqrstuvwxyz]')

    for image_tensors, labels in evaluation_loader:
        batch_size = image_tensors.size(0)
        length_of_data = length_of_data + batch_size
        image = image_tensors.to(device)

        target = converter.encode(labels)

        start_time = time.time()
        preds = model(image, text=target, seqlen=converter.batch_max_length)
        _, preds_index = preds.topk(1, dim=-1, largest=True, sorted=True)
        preds_index = preds_index.view(-1, converter.batch_max_length)
        forward_time = time.time() - start_time
        infer_time += forward_time

        cost = criterion(preds.contiguous().view(-1, preds.shape[-1]), target.contiguous().view(-1))
        valid_loss_avg.add(cost)

        # Position 0 is the [GO] slot, so decoding starts at position 1.
        length_for_pred = torch.IntTensor([converter.batch_max_length - 1] * batch_size).to(device)
        preds_str = converter.decode(preds_index[:, 1:], length_for_pred)

        preds_prob = F.softmax(preds, dim=2)
        preds_max_prob, _ = preds_prob.max(dim=2)
        # Drop position 0 here too, so probabilities line up with the decoded string.
        preds_max_prob = preds_max_prob[:, 1:]
        confidence_score_list = []
        for gt, pred, pred_max_prob in zip(labels, preds_str, preds_max_prob):

            # str.find returns -1 when there is no end marker; slicing with -1 would
            # silently chop the last character, so only cut when a marker exists.
            pred_EOS = pred.find('[s]')
            if pred_EOS != -1:
                pred = pred[:pred_EOS]
                pred_max_prob = pred_max_prob[:pred_EOS]

            pred = out_of_alphanumeric.sub('', pred.lower())
            gt = out_of_alphanumeric.sub('', gt.lower())

            if pred == gt:
                n_correct += 1

            if len(gt) == 0 or len(pred) == 0:
                norm_ED += 0
            else:
                # Normalize by the longer of the two strings.
                norm_ED += 1 - _levenshtein(pred, gt) / max(len(gt), len(pred))

            # Confidence is the product of the per-position maximum probabilities;
            # an empty prediction gets 0.
            if pred_max_prob.numel() > 0:
                confidence_score = pred_max_prob.cumprod(dim=0)[-1]
            else:
                confidence_score = 0
            confidence_score_list.append(confidence_score)

    if length_of_data == 0:
        accuracy = 0.0
        norm_ED = 0.0
    else:
        accuracy = n_correct / float(length_of_data) * 100
        norm_ED = norm_ED / float(length_of_data)  # ICDAR2019 Normalized Edit Distance

    return valid_loss_avg.val(), accuracy, norm_ED, preds_str, confidence_score_list, labels, infer_time, length_of_data

In [220]:
# Index 0 is skipped by the loss (ignore_index=0).
criterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(device)
loss_avg = Averager()

# Optimize only the parameters that require gradients.
filtered_parameters = [p for p in model.parameters() if p.requires_grad]
params_num = [np.prod(p.size()) for p in filtered_parameters]

optimizer = optim.Adadelta(filtered_parameters, lr=1.0, rho=.95, eps=1e-8)
scheduler = None
best_accuracy = -1
best_norm_ED = -1
iteration = 0

_AlignCollate = AlignCollate(imgH=imgH, imgW=imgW, keep_ratio_with_pad=False)
# DataLoader needs an indexable dataset, so open the LMDB directly: a
# Batch_Balanced_Dataset instance is neither a Dataset nor unpackable into two names.
# NOTE(review): the directory name is kept verbatim ("valadation") - confirm it
# matches the folder on disk.
valid_dataset = LmdbDataset("data_lmdb_release/valadation")
valid_loader = torch.utils.data.DataLoader(
    valid_dataset, batch_size=8,
    shuffle=True,
    num_workers=int(2),
    # `AlignCollate_valid` was never defined; use the collate instance built above.
    collate_fn=_AlignCollate, pin_memory=True)

In [206]:
# Checkpoints are written below; make sure the target directory exists first.
os.makedirs('./saved_models/Models', exist_ok=True)
start_time = time.time()

# Runs until interrupted; the best checkpoints are saved at every validation step.
while(True):
    image_tensors, labels = train_dataset.get_batch()
    image = image_tensors.to(device)

    target = converter.encode(labels)
    preds = model(image, text=target, seqlen=converter.batch_max_length)
    cost = criterion(preds.view(-1, preds.shape[-1]), target.contiguous().view(-1))
    model.zero_grad()
    cost.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 5)
    optimizer.step()

    loss_avg.add(cost)
    print(cost)
    if (iteration + 1) % 2000 == 0 or iteration == 0:
        elapsed_time = time.time() - start_time
        model.eval()
        with torch.no_grad():
            # Separate names keep the validation outputs from overwriting the
            # training batch's `preds` / `labels`.
            (valid_loss, current_accuracy, current_norm_ED, valid_preds, confidence_score,
             valid_labels, infer_time, length_of_data) = validation(
                model, criterion, valid_loader, converter)
        model.train()

        print(f'[{iteration+1}/2000] Train loss: {loss_avg.val():0.5f}, Valid loss: {valid_loss:0.5f}, Elapsed_time: {elapsed_time:0.5f}')
        loss_avg.reset()
        if current_accuracy > best_accuracy:
            best_accuracy = current_accuracy
            torch.save(model.state_dict(), f'./saved_models/Models/best_accuracy.pth')
        if current_norm_ED > best_norm_ED:
            best_norm_ED = current_norm_ED
            torch.save(model.state_dict(), f'./saved_models/Models/best_norm_ED.pth')

    # Without this increment `iteration == 0` stayed true and validation ran every step.
    iteration += 1