In [None]:
from fastai.vision import unet_learner, imagenet_stats, torch, Path, os, load_learner, models, SegmentationLabelList, SegmentationItemList, TfmPixel, pil2tensor, Image, ImageSegment, partial
from experiments import getDatasets, getData, random_seed
from losses import MixedLoss
from metrics import MetricsCallback, getDatasetMetrics
from fastai.callbacks import CSVLogger, SaveModelCallback
from config import *
from dataset import pad_tensor
import PIL
import numpy as np

# Automatically re-import edited project modules (experiments, losses, metrics, dataset, ...)
# before each cell runs, so changes to those .py files are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

# Make GPU 0 the current CUDA device for this notebook.
torch.cuda.set_device(0)

In [None]:
# All artifacts of this notebook live under <EXPERIMENTS_PATH>/datasets.
EXPERIMENT_PATH = Path(EXPERIMENTS_PATH).joinpath('datasets')
MODELS_PATH = EXPERIMENT_PATH.joinpath("models")
# Create the directory tree (including parents); an existing directory is fine.
MODELS_PATH.mkdir(parents=True, exist_ok=True)

In [None]:
def icdar_web_segmentation(x):
    """Return the ground-truth mask path for an ICDAR 2013 web-text image.

    The mask file is named ``gt_<stem>.png``, and every ``/images`` in the path is
    rewritten to ``/gt``. Only the final path component is renamed, so directories
    that happen to contain the image's filename are left untouched.

    Args:
        x: path of the image (a pathlib path; ``.stem`` and ``.with_name`` are used).

    Returns:
        The mask path as a string.
    """
    gt_path = x.with_name('gt_' + x.stem + '.png')
    return str(gt_path).replace('/images', '/gt')

def icdar_scene_segmentation(x):
    """Return the ground-truth mask path for an ICDAR 2013 scene-text image.

    For paths containing ``/train/``, the mask is ``<stem>_GT.bmp`` with ``/images``
    rewritten to ``/gt``. Only the final path component is renamed. Every other path
    is delegated to ``icdar_web_segmentation``.

    Args:
        x: path of the image (a pathlib path).

    Returns:
        The mask path as a string.
    """
    if "/train/" in str(x):
        gt_path = x.with_name(x.stem + '_GT' + '.bmp')
        return str(gt_path).replace('/images', '/gt')
    return icdar_web_segmentation(x)

def total_text_segmentation(x):
    """Return the Total-Text mask path for image path ``x`` as a string.

    The mask keeps the image's filename; every ``/images`` segment of the path
    becomes ``/gt``.
    """
    pieces = str(x).split('/images')
    return '/gt'.join(pieces)

def dibco_segmentation(x):
    """Return the ground-truth mask path for a DIBCO image.

    Each DIBCO year uses its own mask naming. The year is detected by substring
    match anywhere in the path; when several rules match, the later rule wins.

    * extension: ``.tiff`` by default, ``.tif`` for 2012, ``.bmp`` for 2016-2019
    * filename:  ``<stem>`` by default, ``<stem>_estGT`` for 2010/2013/2014,
      ``<stem>_GT`` for 2011/2012, ``<stem>_gt`` for 2016-2018

    Every ``/images`` in the path becomes ``/gt``; only the final path component
    is renamed.

    Args:
        x: path of the image (a pathlib path).

    Returns:
        The mask path as a string.
    """
    path_str = str(x)

    def mentions(*years):
        # True if any of the given year strings occurs anywhere in the path.
        return any(year in path_str for year in years)

    suffix = '.tiff'
    if mentions("2012"):
        suffix = '.tif'
    if mentions("2016", "2017", "2018", "2019"):
        suffix = '.bmp'

    name = x.stem
    if mentions("2010", "2013", "2014"):
        name = x.stem + '_estGT'
    if mentions("2011", "2012"):
        name = x.stem + '_GT'
    if mentions("2016", "2017", "2018"):
        name = x.stem + '_gt'

    return str(x.with_name(name + suffix)).replace('/images', '/gt')

def kaist_segmentation(x):
    """Return the KAIST mask path for image path ``x`` as a string.

    The mask sits next to the image with the same stem and a ``.bmp`` extension.
    ``with_suffix`` touches only the final component, so directory names containing
    the image's extension are left intact.
    """
    return str(x.with_suffix('.bmp'))

class CustomSegLabel2(SegmentationLabelList):
    """Segmentation label list that loads masks as binary {0, 1} float tensors.

    Masks are converted to grayscale and shrunk in place to fit within 500x500 using
    nearest-neighbour resampling, so label values are never interpolated. How pixels
    map to {0, 1} depends on which dataset root ``self.path`` is.
    """
    def open(self, fn):
        # Close the underlying file once the converted copy has been made.
        with PIL.Image.open(fn) as raw:
            im = raw.convert('L')
        im.thumbnail((500, 500), PIL.Image.NEAREST)
        im = pil2tensor(im, np.float32)

        # Compare as Path objects so formatting differences (e.g. a trailing slash in
        # the config constants) do not break dataset detection.
        root = Path(self.path)
        if root != Path(KAIST_PATH):
            # Floor division: 255 -> 1, any value below 255 -> 0.
            im = im // 255
        else:
            # KAIST masks: any non-zero pixel counts as foreground.
            im = (im != 0).float()

        # Presumably these datasets encode text as dark pixels in their GT, so invert to
        # make text = 1 — confirm against the dataset documentation.
        inverted_roots = (Path(ICDAR2013_WEB_PATH), Path(ICDAR2013_SCENE_PATH), Path(DIBCO_PATH))
        if root in inverted_roots:
            im = 1 - im
        return ImageSegment(im)
    
class SItemListCustom2(SegmentationItemList):
    """Segmentation item list whose images are labelled with ``CustomSegLabel2`` masks.

    Inputs are loaded as RGB, shrunk in place to fit within 500x500, and scaled to [0, 1].
    """
    _label_cls = CustomSegLabel2

    def open(self, fn):
        # Close the underlying file once the converted copy has been made.
        with PIL.Image.open(fn) as raw:
            im = raw.convert('RGB')
        # LANCZOS is the filter formerly exposed as ANTIALIAS, which Pillow 10 removed.
        im.thumbnail((500, 500), PIL.Image.LANCZOS)
        return Image(pil2tensor(im, np.float32) / 255)

def getDataByName(name):
    """Build labelled, unsplit item lists for one external text-segmentation dataset.

    Args:
        name: one of 'icdar2013-web', 'icdar2013-scene', 'total-text', 'dibco', 'kaist'.

    Returns:
        The result of ``SItemListCustom2.from_folder(...).filter_by_func(...).split_none()
        .label_from_func(..., classes=['text'])``. ``split_none()`` puts every item in
        the training split.

    Raises:
        ValueError: if ``name`` is not a known dataset.
    """
    def in_train_or_test_images(p):
        return '/train/images' in str(p) or '/test/images' in str(p)

    # name -> (dataset root, item filter, image-path -> mask-path function)
    specs = {
        'icdar2013-web': (ICDAR2013_WEB_PATH, in_train_or_test_images, icdar_web_segmentation),
        'icdar2013-scene': (ICDAR2013_SCENE_PATH, in_train_or_test_images, icdar_scene_segmentation),
        'total-text': (TOTAL_TEXT_PATH, in_train_or_test_images, total_text_segmentation),
        'dibco': (DIBCO_PATH, lambda p: '/images' in str(p), dibco_segmentation),
        # KAIST images and .bmp masks share a directory; keep only the images.
        'kaist': (KAIST_PATH, lambda p: p.suffix != '.bmp', kaist_segmentation),
    }
    if name not in specs:
        raise ValueError(f"Unknown dataset name {name!r}; expected one of {sorted(specs)}")

    root, keep, label_fn = specs[name]
    return (SItemListCustom2.from_folder(root)
            .filter_by_func(keep)
            .split_none()
            .label_from_func(label_fn, classes=['text']))
    
def getDatabunch(name, all_data=None):
    """Build a DataBunch that trains on external dataset ``name`` and validates on project data.

    Args:
        name: dataset key accepted by ``getDataByName``.
        all_data: result of ``getData()``. Defaults to the notebook-level ``allData``
            (so existing calls keep working); pass it explicitly to avoid relying on
            that global.

    Returns:
        A DataBunch with bs=1, val_bs=2, num_workers=0, normalized with ImageNet stats.
    """
    if all_data is None:
        all_data = allData
    props = {'bs': 1, 'val_bs': 2, 'num_workers': 0}
    random_seed(42)
    data = getDataByName(name)
    # Only the training split is transformed, image and mask alike. pad_tensor(multiple=8)
    # presumably pads to a multiple of 8 — confirm in dataset.py.
    tfms = [TfmPixel(pad_tensor)(multiple=8)]
    data.train.transform(tfms, tfm_y=True)
    # split_none() left the external lists without a validation split; borrow the one
    # from the first project dataset instead.
    data.valid = getDatasets(all_data)[0].valid
    random_seed(42)
    return data.databunch(**props).normalize(imagenet_stats)
    
# Load the project's own data once. getDatabunch() borrows its validation split (via
# getDatasets), and the 'manga' cell trains on the datasets built from it.
allData = getData()

In [None]:
# Train one U-Net per external dataset (validated on project data), then compute its
# metrics on the first project dataset. Each step is skipped when its output exists.
for name in ['icdar2013-web', 'icdar2013-scene', 'total-text', 'dibco', 'kaist']:
    PATH = EXPERIMENT_PATH / name
    print(name)
    if not (PATH / 'model.pkl').exists():
        learn = unet_learner(
            getDatabunch(name), models.resnet34, model_dir='models',
            callback_fns=[MetricsCallback, CSVLogger,
                          partial(SaveModelCallback, monitor='normal pixel f1 %')],
            loss_func=MixedLoss(0.0, 1.0), path=PATH)
        random_seed(42)
        learn.fit_one_cycle(10, 1e-4)
        learn.save('model')
        learn.export(file='model.pkl')
    if not (PATH / 'predictions.csv').exists():
        learn = load_learner(PATH, 'model.pkl')
        random_seed(42)
        # Evaluate on the first project dataset, the same one getDatabunch takes its
        # validation split from.
        m = getDatasetMetrics(getDatasets(allData)[0], learn)
        m.save(PATH / 'predictions.csv', False)

In [None]:
# Baseline on the project's own data: train one U-Net per dataset returned by
# getDatasets(allData, crop=False, cutInHalf=False), then compute its metrics.
# Existing models/predictions are reused unless FORCE_RETRAIN is set.
FORCE_RETRAIN = False  # True: retrain and re-score even if model.pkl / predictions.csv exist
props = {'bs': 1,  'val_bs': 2, 'num_workers': 0}
name = 'manga'
random_seed(42)
for index, dataset in enumerate(getDatasets(allData, crop=False, cutInHalf=False)):
    PATH = EXPERIMENT_PATH / name / str(index)
    if FORCE_RETRAIN or not (PATH / 'model.pkl').exists():
        random_seed(42)
        data = dataset.databunch(**props).normalize(imagenet_stats)
        learn = unet_learner(
            data, models.resnet34, model_dir='models',
            callback_fns=[MetricsCallback, CSVLogger,
                          partial(SaveModelCallback, monitor='normal pixel f1 %')],
            loss_func=MixedLoss(0.0, 1.0), path=PATH)
        random_seed(42)
        learn.fit_one_cycle(10, 1e-4)
        learn.save('model')
        learn.export(file='model.pkl')
    # A retrained model must be re-scored; otherwise predictions.csv would be stale.
    if FORCE_RETRAIN or not (PATH / 'predictions.csv').exists():
        learn = load_learner(PATH, 'model.pkl')
        random_seed(42)
        m = getDatasetMetrics(dataset, learn)
        m.save(PATH / 'predictions.csv', False)

In [None]:
# Unpack the raw KAIST archives into KAIST_PATH/images/.
# Each image/mask pair shares an archive path minus its 4-character extension. Both files
# are renamed to a shared numeric id (`<id><original suffix>`), which lets
# kaist_segmentation map an image to its .bmp mask.
import glob
from zipfile import ZipFile

count = dict()  # pair key -> numeric id, assigned in order of first appearance
seen = dict()   # pair key -> number of files extracted for it (2 expected: image + mask)

for f in sorted(glob.glob(KAIST_PATH + "/KAIST/**/*.zip", recursive=True)):
    with ZipFile(f) as archive:
        for info in archive.infolist():
            name = f + info.filename
            # One archive entry apparently carries the wrong number; remap it so it pairs
            # with its counterpart — presumably a dataset labelling error, verify if data changes.
            if "Digital_Camera/(C.S)C-outdoor4.zipDSC03706" in name:
                name = name.replace('zipDSC03706', 'zipDSC03707')
            if name.endswith('.bmp') or name.lower().endswith('.jpg'):
                key = name[:-4]  # strip the 4-character extension (".bmp"/".jpg")
                if key not in count:
                    count[key] = len(count)
                    seen[key] = 0
                seen[key] += 1
                # Renaming the ZipInfo changes the output filename used by extract().
                info.filename = str(count[key]) + Path(name).suffix
                archive.extract(info, path = KAIST_PATH + '/images/')

# Report keys that did not produce exactly one image and one mask.
for key, n_files in seen.items():
    if n_files != 2:
        print(key, n_files)