In [68]:
import os
import cv2
import torch
import numpy as np
from pathlib import Path
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
from torch import nn
import torch.nn.functional as F

In [69]:
from itertools import permutations
import zipfile
from typing import Optional, List
from pathlib import Path
import cv2
import numpy as np
from collections import defaultdict, Counter
import lmdb


class Vocabulary:
    """Bidirectional mapping between class labels and contiguous integer indices.

    Labels are deduplicated and sorted, so the index assigned to a label depends
    only on the set of labels given, not on their order or multiplicity.
    """

    def __init__(self, classes):
        self.classes = sorted(set(classes))
        self._class_to_index = {label: position for position, label in enumerate(self.classes)}

    def class_by_index(self, idx: int) -> str:
        """Return the label stored at position ``idx`` of the sorted label list."""
        return self.classes[idx]

    def index_by_class(self, cls: str) -> int:
        """Return the index of label ``cls``; raises ``KeyError`` for an unknown label."""
        return self._class_to_index[cls]

    def num_classes(self) -> int:
        """Return the number of distinct labels."""
        return len(self.classes)


class ArchivedHWDBReader:
    """Reads sample images stored as individual entries of a zip archive.

    Usable as a context manager: the archive is opened on ``__enter__`` and
    closed on ``__exit__``.
    """

    def __init__(self, path: Path):
        # Passed straight to zipfile.ZipFile, so a filesystem path or a binary
        # file-like object both work.
        self.path = path
        self.archive = None

    def open(self):
        """Open the archive, first closing any handle left open by a previous ``open()``."""
        self.close()
        self.archive = zipfile.ZipFile(self.path)

    def namelist(self):
        """Return the names of all archive entries (requires ``open()``)."""
        return self.archive.namelist()

    def decode_image(self, name):
        """Read entry ``name`` and decode it as a single-channel (grayscale) image.

        Returns whatever ``cv2.imdecode`` returns for the bytes.
        """
        sample = self.archive.read(name)
        # frombuffer wraps the bytes without the extra bytearray copy.
        buf = np.frombuffer(sample, dtype='uint8')
        return cv2.imdecode(buf, cv2.IMREAD_GRAYSCALE)

    def close(self):
        """Close the archive if it is open; safe before ``open()`` and when repeated."""
        if self.archive is not None:
            self.archive.close()
            self.archive = None

    def __enter__(self):
        self.open()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()


GB = 2**30


class LMDBReader:
    """Reads sample images stored as encoded bytes in a single-file LMDB database.

    Keys are UTF-8 encoded sample names; values are the encoded image bytes.
    Usable as a context manager (opened on ``__enter__``, closed on ``__exit__``).
    """

    def __init__(self, path: Path):
        self.path = path
        self.env = None
        self.namelist_ = []

    def open(self):
        """Open the database read-only and cache the list of all keys."""
        self.close()
        # os.fspath accepts str, bytes and Path objects alike.
        self.env = lmdb.open(os.fspath(self.path),
                             map_size=GB * 16,
                             lock=False,
                             subdir=False,
                             readonly=True)
        with self.env.begin(buffers=True) as txn:
            # Only keys are needed, so do not materialise the (large) values.
            self.namelist_ = [bytes(key).decode('utf-8')
                              for key in txn.cursor().iternext(keys=True, values=False)]

    def namelist(self):
        """Return the sample names cached by ``open()`` (empty before that)."""
        return self.namelist_

    def decode_image(self, name):
        """Fetch sample ``name`` and decode it as a single-channel (grayscale) image.

        Raises:
            KeyError: if ``name`` is not stored in the database.
        """
        key = name.encode('utf-8')
        with self.env.begin() as txn:
            sample = txn.get(key)
        if sample is None:
            # txn.get returns None for a missing key; fail with a clear error
            # instead of a TypeError from np.frombuffer(None).
            raise KeyError(name)
        buf = np.frombuffer(sample, dtype='uint8')
        return cv2.imdecode(buf, cv2.IMREAD_GRAYSCALE)

    def close(self):
        """Close the environment if open; safe before ``open()`` and when repeated."""
        if self.env is not None:
            self.env.close()
            self.env = None

    def __enter__(self):
        self.open()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()


class HWDBDatasetHelper:
    """Indexes the samples of one split exposed by a reader.

    A sample's class label is the name of its parent directory (see
    ``_get_class``). Works with any reader providing ``namelist()`` and
    ``decode_image(name)``.
    """

    def __init__(self, reader, prefix='Train', vocabulary: Optional[Vocabulary]=None, namelist: Optional[List[str]]=None):
        """
        Args:
            reader: source of entry names and decoded images.
            prefix: substring an entry name must contain to belong to this split;
                only used when ``namelist`` is omitted.
            vocabulary: label<->index mapping; built from this split's labels when omitted.
            namelist: explicit list of entry names (used by ``train_val_split``).
        """
        self.reader = reader
        self.prefix = prefix
        self.index = defaultdict(list)  # label -> positions in self.namelist
        self.counter = Counter()        # label -> number of samples
        self.namelist = namelist
        if self.namelist is None:
            # Plain substring test: the prefix may match anywhere in the name.
            self.namelist = [name for name in self.reader.namelist() if self.prefix in name]
        self.vocabulary = vocabulary
        self._build_index()

    def get_item(self, idx):
        """Return ``(image, class_index)`` for the ``idx``-th sample."""
        name = self.namelist[idx]
        image = self.reader.decode_image(name)
        return image, self.vocabulary.index_by_class(HWDBDatasetHelper._get_class(name))

    def size(self):
        """Return the number of samples in this split."""
        return len(self.namelist)

    def __len__(self):
        # Allows the helper to be passed to len() directly.
        return self.size()

    def get_all_class_items(self, idx):
        """Return the positions (into ``namelist``) of all samples of class ``idx``.

        A class with no samples in this split yields an empty list and is not
        inserted into ``self.index``.
        """
        cls = self.vocabulary.class_by_index(idx)
        return self.index.get(cls, [])

    def most_common_classes(self, n=None):
        """Return the ``n`` most frequent ``(label, count)`` pairs (all of them when ``n`` is None)."""
        return self.counter.most_common(n)

    def train_val_split(self, train_part=0.8, seed=42):
        """Randomly split the samples into two helpers that share this vocabulary.

        Args:
            train_part: fraction of samples, in ``[0, 1]``, assigned to the first helper.
            seed: seed of the permutation, so the split is reproducible.

        Returns:
            ``(train_helper, val_helper)``.

        Raises:
            ValueError: if ``train_part`` is outside ``[0, 1]``.
        """
        if not 0.0 <= train_part <= 1.0:
            raise ValueError(f'train_part must be in [0, 1], got {train_part}')
        rnd = np.random.default_rng(seed)
        permutation = rnd.permutation(len(self.namelist))
        n_train = int(len(permutation) * train_part)
        train_names = [self.namelist[idx] for idx in permutation[:n_train]]
        val_names = [self.namelist[idx] for idx in permutation[n_train:]]

        return (HWDBDatasetHelper(self.reader, self.prefix, self.vocabulary, train_names),
                HWDBDatasetHelper(self.reader, self.prefix, self.vocabulary, val_names))

    @staticmethod
    def _get_class(name):
        """Label of an entry: the name of its parent directory."""
        return Path(name).parent.name

    def _build_index(self):
        """Fill ``index`` and ``counter``; create a vocabulary from this split's labels if none was given."""
        for idx, name in enumerate(self.namelist):
            cls = HWDBDatasetHelper._get_class(name)
            self.index[cls].append(idx)
            self.counter[cls] += 1

        if self.vocabulary is None:
            self.vocabulary = Vocabulary(self.index.keys())

In [70]:
from pathlib import Path

def _read_labels(path):
    """Parse a UTF-8 file of ``<name> <class>`` lines into a ``{name: class}`` dict.

    Blank lines are skipped. The class is the last whitespace-separated token, so
    names may contain spaces. A repeated name keeps its last class.
    """
    labels = {}
    with open(path, encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            name, cls = line.rsplit(maxsplit=1)
            labels[name] = cls
    return labels


def evaluate(gt_path, pred_path):
    """Return the fraction of ground-truth samples whose predicted class is correct.

    Both files hold one ``<name> <class>`` pair per line and are read as UTF-8,
    matching how the prediction file is written. Ground-truth samples without a
    prediction count as wrong. A sample predicted more than once is counted once
    (last prediction wins), so the result never exceeds 1.

    Raises:
        ValueError: if the ground-truth file contains no samples.
        KeyError: if a prediction names a sample that is not in the ground truth.
    """
    gt = _read_labels(gt_path)
    if not gt:
        raise ValueError(f'no ground-truth samples in {gt_path}')
    pred = _read_labels(pred_path)

    unknown = pred.keys() - gt.keys()
    if unknown:
        raise KeyError(next(iter(unknown)))

    n_good = sum(1 for name, cls in gt.items() if pred.get(name) == cls)
    return n_good / len(gt)

In [71]:
# Dataset location: a `data` directory two levels above the current working directory.
root = Path().absolute().parent.parent / 'data'
train_path, test_path = (os.path.join(root, f'{split}.lmdb') for split in ('train', 'test'))
# Ground truth and predictions live next to the notebook.
gt_path = './gt.txt'
pred_path = './pred.txt'

In [72]:
# Open the training LMDB and split it with train_val_split's defaults (80/20, seed 42).
# The reader stays open: the helpers read images through it in get_item.
train_reader = LMDBReader(train_path)
train_reader.open()
train_helper, val_helper = HWDBDatasetHelper(train_reader).train_val_split()
# NOTE(review): HWDBDataset is not defined in any cell shown here (CustomDataset in
# the next cell looks like its replacement). Confirm it comes from an earlier cell,
# otherwise Restart & Run All fails at this line.
train_dataset, val_dataset = HWDBDataset(train_helper), HWDBDataset(val_helper)
train_loader = DataLoader(train_dataset, batch_size=512, shuffle=True, drop_last=True, num_workers=0)
val_loader = DataLoader(val_dataset, batch_size=512, shuffle=False, num_workers=0)

In [73]:
class CustomDataset(Dataset):
    """Torch dataset of resized, normalised grayscale samples taken from a HWDBDatasetHelper.

    Each item is ``(image, label)``. ``image`` is a float32 tensor of shape
    ``(height, width)`` with no channel axis. ``label`` is the class index
    returned by the helper's ``get_item``.
    """

    def __init__(self, helper: 'HWDBDatasetHelper', image_size=(32, 104 * 32 // 79)):
        """
        Args:
            helper: provides ``size()`` and ``get_item(idx) -> (image, label)``.
            image_size: ``(width, height)`` handed to ``cv2.resize``; the default is 32x42.
        """
        self.helper = helper
        self.image_size = tuple(image_size)

    def __len__(self):
        # The helper reports its length through size().
        return self.helper.size()

    def __getitem__(self, idx):
        img, label = self.helper.get_item(idx)
        img = cv2.resize(img, self.image_size)
        # Map [0, 255] pixel values to [-0.5, 0.5].
        img = (img - 127.5) / 255.
        img = torch.from_numpy(img).float()
        return img, label

class CustomLoss(nn.Module):
    """Cross-entropy over scaled cosine logits with an additive margin on the target class.

    The logit for class ``j`` is ``s * cos(input, weight[j])``. For the target
    class, the cosine is first reduced by ``m`` (CosFace-style margin). The
    module owns the class-weight matrix ``weight`` of shape
    ``(out_features, in_features)``.
    """

    def __init__(self, in_features, out_features, s=30.0, m=0.50):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.s = s
        self.m = m
        initial = torch.empty(out_features, in_features, dtype=torch.float32)
        self.weight = nn.Parameter(initial)
        nn.init.xavier_uniform_(self.weight)

    def forward(self, input, target):
        """
        Args:
            input: embeddings; ``F.normalize`` normalises along dim 1, so presumably
                ``(batch, in_features)``.
            target: integer class indices in ``[0, out_features)``.

        Returns:
            Mean cross-entropy loss (scalar tensor).
        """
        cosine = F.linear(F.normalize(input), F.normalize(self.weight))
        # one_hot is exactly 0/1, so this subtracts m only at the target position.
        margin = F.one_hot(target, num_classes=self.out_features).float() * self.m
        return F.cross_entropy(self.s * (cosine - margin), target)

class CustomBlock(nn.Module):
    """ResNet basic residual block: two 3x3 conv+BN layers plus a shortcut connection.

    The shortcut is the identity (an empty ``nn.Sequential``) unless the block
    changes stride or channel count. In that case it is a strided 1x1 conv
    followed by BN, so the shapes can be added.
    """
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        out_planes = self.expansion * planes
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        if stride == 1 and in_planes == out_planes:
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )

    def forward(self, x):
        residual = self.shortcut(x)
        y = F.relu(self.bn1(self.conv1(x)))
        y = self.bn2(self.conv2(y))
        return F.relu(y + residual)

class CustomResNet(nn.Module):
    """ResNet for single-channel images: 3x3 stem, four residual stages, global average pooling, linear head.

    Stage widths are 32/64/128/256 (times ``block.expansion``). Stages 2-4 use
    stride 2. Global pooling makes the head independent of the input height and
    width.
    """

    def __init__(self, block, num_blocks, num_classes=10):
        """
        Args:
            block: residual block class with an ``expansion`` attribute and an
                ``(in_planes, planes, stride)`` constructor.
            num_blocks: number of blocks in each of the four stages.
            num_classes: number of output logits.
        """
        super().__init__()
        self.in_planes = 32
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(32)
        self.layer1 = self._make_layer(block, 32, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 64, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 128, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 256, num_blocks[3], stride=2)
        self.linear = nn.Linear(256*block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Build one stage; only its first block applies ``stride``."""
        layers = []
        layers.append(block(self.in_planes, planes, stride))
        self.in_planes = planes * block.expansion
        for _ in range(1, num_blocks):
            layers.append(block(self.in_planes, planes, 1))
        return nn.Sequential(*layers)

    def forward(self, x):
        """Map ``(batch, 1, H, W)`` images to ``(batch, num_classes)`` logits."""
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        # Average over the whole final feature map. A fixed avg_pool2d(out, 4) only
        # covers it when the map is exactly 4x4 (32x32 input). A 42x32 input gives
        # a 6x4 map whose bottom rows it would ignore, and larger inputs would break
        # the flatten -> linear shape.
        out = F.adaptive_avg_pool2d(out, 1)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out

In [74]:
# Define the device here: this cell moves the model to it, and the training cell
# that also defines `device` runs only later.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Seed weight init and the DataLoader shuffle order so runs are reproducible.
torch.manual_seed(42)
# NOTE(review): ResNet / BasicBlock are not defined in any cell shown here; the
# previous cell defines CustomResNet / CustomBlock, which look like their
# replacements. Confirm these names come from an earlier cell, otherwise
# Restart & Run All fails here.
model = ResNet(BasicBlock, [2, 2, 2, 2], train_helper.vocabulary.num_classes())
model = model.to(device)
optim = torch.optim.AdamW(model.parameters(), lr=0.001)
# Plain softmax cross-entropy on the model's logits (CustomLoss is not used).
loss_fn = torch.nn.CrossEntropyLoss()

In [75]:
def run_validation(val_loader: DataLoader, model: nn.Module, n_steps=None, device=None):
    """Return the classification accuracy of ``model`` on (part of) ``val_loader``.

    Args:
        val_loader: yields ``(X, y)`` batches. ``X`` gets a channel axis via
            ``unsqueeze(1)`` and is cast to float32 before the forward pass.
        model: returns logits whose argmax along dim 1 is the predicted class.
        n_steps: evaluate only the first ``n_steps`` batches. ``None`` evaluates
            the whole loader with a tqdm progress bar.
        device: where to run the forward pass. Defaults to the device of the
            model's first parameter, or CPU for a parameter-less model.

    Returns:
        Fraction of correctly classified samples, or ``0.0`` when no sample was seen.

    The model runs in eval mode and is restored to its previous train/eval mode afterwards.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    if n_steps is None:
        n_steps = len(val_loader)
        batches = tqdm(val_loader)
    else:
        batches = val_loader

    was_training = model.training
    model.eval()
    n_good = 0
    n_all = 0
    try:
        with torch.no_grad():
            for batch, (X, y) in enumerate(batches):
                if batch == n_steps:
                    break
                logits = model(X.unsqueeze(1).to(torch.float32).to(device))
                classes = torch.argmax(logits, dim=1).cpu()
                n_good += int((classes == y.cpu()).sum())
                n_all += len(classes)
    finally:
        model.train(was_training)

    return n_good / n_all if n_all else 0.0

In [76]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
for epoch in range(10):
    print(f'Epoch {epoch}:')
    # run_validation at the end of each epoch switches the model to eval mode;
    # nothing inside the batch loop changes the mode, so switching back once
    # per epoch is enough.
    model.train()
    for batch, (X, y) in enumerate(tqdm(train_loader)):
        logits = model(X.unsqueeze(1).to(torch.float32).to(device))
        loss = loss_fn(logits, y.to(torch.long).to(device))

        optim.zero_grad()
        loss.backward()
        optim.step()

    # One checkpoint per epoch in the working directory.
    torch.save(model.state_dict(), f'my_epoch{epoch}.pth')

    accuracy = run_validation(val_loader, model)
    print(f'accuracy: {accuracy}')


100%|██████████| 5036/5036 [11:31<00:00,  7.28it/s]
100%|██████████| 1260/1260 [02:06<00:00,  9.92it/s]


accuracy: 0.896014483198342


100%|██████████| 5036/5036 [11:30<00:00,  7.29it/s]
100%|██████████| 1260/1260 [02:01<00:00, 10.33it/s]


accuracy: 0.9265647857848711


100%|██████████| 5036/5036 [11:30<00:00,  7.29it/s]
100%|██████████| 1260/1260 [02:01<00:00, 10.37it/s]


accuracy: 0.938592231880101


100%|██████████| 5036/5036 [11:30<00:00,  7.29it/s]
100%|██████████| 1260/1260 [02:02<00:00, 10.30it/s]


accuracy: 0.9433036150596719


100%|██████████| 5036/5036 [11:30<00:00,  7.29it/s]
100%|██████████| 1260/1260 [02:03<00:00, 10.21it/s]


accuracy: 0.9461890851663567


100%|██████████| 5036/5036 [11:30<00:00,  7.29it/s]
100%|██████████| 1260/1260 [02:01<00:00, 10.34it/s]


accuracy: 0.9478970973101524


 62%|██████▏   | 3141/5036 [07:10<04:19,  7.29it/s]IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

100%|██████████| 1260/1260 [02:01<00:00, 10.33it/s]


accuracy: 0.9511548861402804


100%|██████████| 5036/5036 [11:30<00:00,  7.29it/s]
100%|██████████| 1260/1260 [02:01<00:00, 10.33it/s]


accuracy: 0.9509501108423867


100%|██████████| 5036/5036 [11:30<00:00,  7.29it/s]
100%|██████████| 1260/1260 [02:04<00:00, 10.14it/s]

accuracy: 0.9513394941739877





In [77]:
# Test split. The loader does not shuffle, so prediction order matches
# test_helper.namelist, which the prediction-writing cell relies on.
test_reader = LMDBReader(test_path)
test_reader.open()
# No vocabulary is passed, so test_helper builds its own from the test labels.
# Labels from test_loader are therefore not train-class indices; they are
# ignored at inference, where predictions are decoded with train_helper.vocabulary.
test_helper = HWDBDatasetHelper(reader=test_reader, prefix='Test')
test_dataset = HWDBDataset(test_helper)
test_loader = DataLoader(test_dataset, batch_size=512, shuffle=False, num_workers=0)

In [78]:
# Predict a class index (in the training vocabulary) for every test sample, in loader order.
model.eval()
preds = []
with torch.no_grad():
    for X, _ in tqdm(test_loader):
        batch_logits = model(X.unsqueeze(1).to(torch.float32).to(device))
        preds.extend(batch_logits.argmax(dim=1).cpu().numpy())

100%|██████████| 1517/1517 [02:30<00:00, 10.08it/s]


In [79]:
# One "<name> <class>" line per test sample; preds[i] belongs to test_helper.namelist[i].
with open(pred_path, 'w', encoding="utf-8") as f_pred:
    for idx, pred in enumerate(preds):
        name = test_helper.namelist[idx]
        cls = train_helper.vocabulary.class_by_index(pred)
        f_pred.write(f'{name} {cls}\n')

In [80]:
# Test accuracy: fraction of ground-truth samples predicted correctly.
evaluate(gt_path=gt_path, pred_path=pred_path)

0.9310374580018879