In [2]:
import json
import fastai
from experiments import *
from fastai.vision import *
from fastai.callbacks import *
from losses import MixedLoss
from dataset import *
from transforms import *
from config import *
import glob
from PIL import Image as pilImage
from metrics import *

%load_ext autoreload
%autoreload 2

# Make GPU 0 the current CUDA device for everything that follows in this notebook.
torch.cuda.set_device(0)

In [4]:
# Root directory of the 'synthetic' experiment. text_info.json, the saved
# model weights and the predictions CSV used below all live under it.
EXPERIMENT_PATH = Path(EXPERIMENTS_PATH) / 'synthetic'

def custom_loss(pred, truth):
    """Binary cross-entropy between ``pred`` and the target mask ``truth``.

    ``truth`` is cast to float first so that it can be compared against the
    floating-point predictions. ``pred`` is forwarded untouched, so it must
    already contain probabilities (no sigmoid is applied here).

    Returns the mean-reduced loss as a scalar tensor.
    """
    target = truth.float()
    loss = F.binary_cross_entropy(input=pred, target=target)
    return loss

def custom_collate(batch):
    """Collate a list of ``(input, target)`` samples into two stacked tensors.

    The layout is decided by looking at the first sample only:

    * the target is an ``int``: inputs are read from ``.data`` and the
      integer targets are wrapped in tensors before stacking;
    * the input has an ``x_tensor`` attribute: inputs are the ``x_tensor``
      values and each target mask is derived with ``getSegmentationMask``;
    * otherwise: both sides are read from their ``.px`` attribute.

    In the last two layouts the stacked target is cast to ``long``.

    Returns a ``(inputs, targets)`` tuple of tensors with a new leading
    batch dimension.
    """
    first = batch[0]

    # Layout 1: plain integer labels.
    if isinstance(first[1], int):
        stacked_inputs = torch.stack([sample[0].data for sample in batch])
        stacked_labels = torch.stack([tensor(sample[1]) for sample in batch])
        return stacked_inputs, stacked_labels

    # Layout 2: items carrying both source and target tensors.
    if hasattr(first[0], "x_tensor"):
        stacked_inputs = torch.stack([sample[0].x_tensor for sample in batch])
        stacked_masks = torch.stack([getSegmentationMask(sample) for sample in batch])
        return stacked_inputs, stacked_masks.long()

    # Layout 3: image / mask pairs exposing their pixels via ``.px``.
    stacked_inputs = torch.stack([sample[0].px for sample in batch])
    stacked_masks = torch.stack([sample[1].px for sample in batch])
    return stacked_inputs, stacked_masks.long()

def getSegmentationMask(dan):
    """Mark the pixels where a sample's target differs from its input.

    ``dan`` is an ``(item, ...)`` sample; only ``dan[0]`` is used, and it must
    expose ``x_tensor`` and ``y_tensor``. The first three channels of both
    tensors are compared position by position.

    Returns a boolean tensor with a leading singleton channel dimension that
    is True wherever at least one of the three channels differs.
    """
    item = dan[0]
    target, source = item.y_tensor, item.x_tensor

    # A pixel is unchanged only if all three channels agree.
    unchanged = (target[0] == source[0]) & (target[1] == source[1]) & (source[2] == target[2])
    return (~unchanged).unsqueeze(0)
    
def folder(p):
    """Return ``/<bucket>/<p>`` for the file name ``p``.

    The bucket is made of the (at most) three characters that sit right
    before the last four characters of ``p`` -- i.e. before a four-character
    suffix such as ``.jpg`` -- left-padded with zeros to a width of four.
    """
    tail = p[-7:-4]
    bucket = tail.rjust(4, "0")
    return "/{}/{}".format(bucket, p)


class CustomItem(ItemBase):
    """Item that carries an image object in ``image``.

    ``data`` is only initialised to the placeholder ``0`` here; anything
    meaningful has to be written onto the item afterwards (the transforms
    passed to ``apply_tfms`` receive the item itself).
    """

    def __init__(self, image):
        self.image, self.data = image, 0

    def __str__(self):
        return str(self.image)

    def apply_tfms(self, tfms, **kwargs):
        """Call each transform on this item, in order, and return the item.

        Every transform is invoked as ``tfm(self, **kwargs)``. Return values
        are discarded, so a transform takes effect only through what it
        changes on the item.
        """
        for transform in tfms:
            transform(self, **kwargs)
        return self

class CustomLabel(SegmentationLabelList):
    """Label list whose labels are never read from disk."""

    def open(self, fn):
        """Ignore ``fn`` and return an all-zero 1x64x64 ``ImageSegment``."""
        blank = torch.zeros(1, 64, 64)
        return ImageSegment(blank)
    
class CustomItemList(SegmentationItemList):
    """Item list that yields ``CustomItem`` objects wrapping RGB PIL images."""

    _label_cls = CustomLabel

    def get(self, i):
        """Open the ``i``-th path with PIL, convert it to RGB and wrap it."""
        rgb_image = pilImage.open(self.items[i]).convert('RGB')
        return self.reconstruct(rgb_image)

    def reconstruct(self, t):
        """Wrap ``t`` in a ``CustomItem``."""
        return CustomItem(t)


# Fonts handed to the `textify` transform, and the per-file values stored in
# text_info.json (a mapping from file name to a value).
fonts = Fonts(Fonts.load(Path('../fonts')))
with open(EXPERIMENT_PATH / 'text_info.json', 'r') as f:
    info = json.load(f)

# Seed first, then load the evaluation data.
random_seed(42)
allData = getData()

# Full paths of every file whose text_info.json value is 0.
items = [DANBOORU_PATH + folder(name) for name, flag in info.items() if flag == 0]

# Training side: the first ten such files, unsplit, with a constant label.
data = CustomItemList(items[0:10]).split_none().label_const('a', classes=['text'])

# Validation side: borrowed from the first dataset returned by getDatasets.
data.valid = getDatasets(allData)[0].valid

# Transforms applied to training items, in this order.
data.train.transform([partial(textify, fonts=fonts), tensorize])

data = data.databunch(
    bs=8,
    val_bs=2,
    collate_fn=custom_collate,
).normalize(imagenet_stats)

learn = unet_learner(
    data,
    models.resnet18,
    metrics=[
        accuracy_thresh,
        partial(accuracy_thresh, thresh=0.95, sigmoid=False),
    ],
    loss_func=custom_loss,
    y_range=(0, 1),
)

In [18]:
# Position of the contour list inside the tuple returned by cv2.findContours:
# OpenCV 3.x returns (image, contours, hierarchy), other versions return
# (contours, hierarchy).
cntIndex = 1 if cv2.__version__.startswith("3") else 0

def expand(mask, img):
    """Grow a coarse mask to the whole dark connected components it touches.

    The image is converted to grayscale and adaptively thresholded. For every
    contour of that binary image, the bounding box is thresholded again on its
    own and the dark (inverted) connected components inside it are labelled.
    A component is copied to the output when

    * its box holds fewer than 10 labels (background included),
    * it has more than 3 pixels, and
    * more than 10% of its pixels are foreground in ``mask``.

    Parameters
    ----------
    mask : 2-D numpy array
        Coarse mask with the image's height and width. It is cast to uint8
        and every non-zero pixel counts as foreground.
    img : object with a ``.data`` tensor
        ``.data`` is multiplied by 255 and permuted from (C, H, W) to
        (H, W, C) before being handed to OpenCV as RGB.
        # assumes values in [0, 1] on the CPU -- TODO confirm against caller

    Returns
    -------
    numpy.ndarray
        uint8 array of the image's height and width, 255 on the selected
        component pixels and 0 elsewhere.
    """
    mask = mask.astype('uint8')
    rgb = img.data.mul(255).permute(1, 2, 0).numpy().astype('uint8')
    gray = cv2.cvtColor(rgb, cv2.COLOR_RGB2GRAY)
    thresh = cv2.adaptiveThreshold(gray, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY, 15, 30)
    cnts = cv2.findContours(thresh, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)[cntIndex]
    im3 = np.zeros(thresh.shape, np.uint8)

    for c in cnts:
        x, y, w, h = cv2.boundingRect(c)
        # Threshold the box in isolation so the adaptive statistics are local to it.
        local_thresh = cv2.adaptiveThreshold(gray[y:y+h, x:x+w], 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY, 15, 30)
        ret, markers = cv2.connectedComponents(cv2.bitwise_not(local_thresh), connectivity=8)
        if ret >= 10:
            # Too many components in one box: skip it entirely.
            continue

        # Binarise the mask patch BEFORE combining it with the component.
        # `m & patch > 0` parses as `(m & patch) > 0`, which bit-ANDs a bool
        # with the raw uint8 values and therefore only works when the mask's
        # foreground value happens to have its lowest bit set.
        mask_patch = mask[y:y+h, x:x+w] > 0
        out_patch = im3[y:y+h, x:x+w]  # view: writes land in im3

        for label in range(1, ret):  # label 0 is the background
            m = markers == label
            size = m.sum()
            if size > 3 and (m & mask_patch).sum() > size * 0.1:
                out_patch[m] = 255

    return im3
def removeNoise(mask):
    """Erase, in place, the contours of ``mask`` that are judged to be noise.

    Works on the contours found in ``mask`` and keeps a per-contour ``good``
    flag that is decided in three passes:

    1. A contour starts as good when its area is at least 50.
    2. Its bounding box is checked: an aspect ratio above 5 clears the flag,
       and above 8 additionally bans the contour (it is skipped by the rest
       of pass 2 and by pass 3). A non-banned contour is then marked good
       when its box lies inside the box of a contour of the dilated mask
       that has area >= 50 and whose box holds more non-zero mask pixels
       than half the number of points stored for that dilated contour.
    3. Repeated until nothing changes: a non-good, non-banned contour is
       promoted when a good contour (index >= idx - 50) has its box centre
       close enough on both axes, provided the contour's area differs from
       the squared radius of its minimum enclosing circle by more than 20.

    Finally every contour that is still not good is filled with 0 in ``mask``.
    Nothing is returned; the caller's array is modified.
    """
    cnts = cv2.findContours(mask, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)[cntIndex]
    # Parallel per-contour lists, all indexed like `cnts`.
    goods = [cv2.contourArea(c) >= 50 for c in cnts]
    rects = [cv2.boundingRect(c) for c in cnts]
    circles = [cv2.minEnclosingCircle(c) for c in cnts]  # ((cx, cy), radius)
    banned = [False] * len(cnts)
        
    # Contours of the dilated mask and their bounding boxes.
    # NOTE(review): the kernel argument is the tuple (5, 5), not an array such
    # as np.ones((5, 5)); confirm which structuring element OpenCV builds from it.
    m = cv2.dilate(mask,(5, 5),iterations = 7)
    cc = cv2.findContours(m, cv2.RETR_LIST, cv2.CHAIN_APPROX_NONE)[cntIndex] 
    rr = [cv2.boundingRect(c) for c in cc]

    # Pass 2: aspect-ratio filter, then rescue by containment in a dilated contour's box.
    for c, good, idx, rect in zip(cnts, goods, range(0, len(cnts)), rects):
        x,y,w,h = rect
        
        if max(w,h) / min(w,h) > 5:
            goods[idx] = False
            if max(w,h) / min(w,h) > 8:
                # Extremely elongated: never reconsidered below.
                banned[idx] = True
                continue
        
        for r2, c2 in zip(rr, cc):
            x2, y2, w2, h2 = r2
            # Box of this contour inside box r2, r2's contour big enough, and
            # enough mask pixels inside r2 relative to the length of c2.
            if cv2.contourArea(c2) >= 50 and x >= x2 and x + w <= x2 + w2 and y >= y2 and y + h <= y2 + h2 and (mask[y2:y2+h2, x2:x2+w2] > 0).sum() > len(c2) * 0.5:
                goods[idx] = True    

    
    # Pass 3: propagate "good" to nearby contours until a full sweep changes nothing.
    changed = True
    while changed:
        changed = False
        # zip pulls from `goods` lazily, so `good` is the flag's value at the
        # moment this contour is reached, including promotions made earlier
        # in the same sweep.
        for c, good, idx, rect in zip(cnts, goods, range(0, len(cnts)), rects):
            if banned[idx]:
                continue
                
            x,y,w,h = rect
            x, y = x + w / 2, y + h / 2   # centre of this contour's box
            if not good:
                # Candidate neighbours: from 50 indices before idx to the end of the list.
                for a in range(max(idx - 50, 0), len(cnts)):
                    if a != idx and goods[a]:
                        x2, y2, w2, h2 = rects[a]
                        x2, y2, = x2 + w2 / 2, y2 + h2 / 2   # centre of the neighbour's box
                        
                        # The first good neighbour that is too far away vertically ends the scan.
                        # NOTE(review): presumably relies on contours being ordered by y -- confirm.
                        if abs(y2 - y) > 100 + h:
                            break
                        
                        # NOTE(review): area is compared with r**2, not pi * r**2 -- confirm this is intended.
                        if abs (cv2.contourArea(cnts[idx]) - circles[idx][1]**2) > 20 and abs(y2 - y) < (h + h2) / 2 + 20 and abs(x2 - x) < (w + w2) / 2 + 20:
                            good = goods[idx] = True
                            changed = True
                            break

 

    # Fill (thickness -1) every contour that was not kept with 0.
    for c, good, idx, rect in zip(cnts, goods, range(0, len(cnts)), rects):                        
        if not good:
            cv2.drawContours(mask, [c], 0, (0, 0, 0), -1)

In [5]:
# Load the 'v1_2' weights stored under EXPERIMENT_PATH/models into `learn`.
# The trailing ';' keeps the cell from echoing the call's return value.
learn.load(EXPERIMENT_PATH / 'models' / 'v1_2');

In [19]:
# Evaluate the loaded model on the validation part of every dataset and write
# the metrics to one CSV.
PREDICTIONS_CSV = EXPERIMENT_PATH / 'v1_2 predictions.csv'

# True  -> always recompute and rewrite the CSV.
# False -> skip this cell's work when the CSV already exists.
FORCE_RECOMPUTE = True

if FORCE_RECOMPUTE or not PREDICTIONS_CSV.exists():
    for index, dataset in enumerate(getDatasets(allData, crop=False, cutInHalf = False)):
        random_seed(42)
        m = MetricsCallback(None)
        m.on_train_begin()
        for idx in range(len(dataset.valid.x.items)):
            x = dataset.valid.x.get(idx, False)
            # Threshold the prediction at 0.95 and turn it into an HxWx1 0/255 array.
            y = learn.predict(x)[2] > 0.95
            y = y.permute(1,2,0).numpy() * 255
            # Post-process: grow to full components, then erase noise in place.
            y = expand(y[:,:,0], x)
            removeNoise(y)
            # `y` only holds 0 and 255 at this point, so `> 0` yields the same
            # boolean mask as dividing by 255 would, without an in-place
            # division on a uint8 tensor (which newer PyTorch rejects because
            # the float result cannot be cast back to Byte).
            y = tensor(y).unsqueeze(0) > 0
            m.on_batch_end(False, y, dataset.valid.y.get(idx, False).px)
        m.calculateMetrics()
        # Second argument is False for the first dataset and True afterwards.
        # presumably an "append" switch -- verify against MetricsCallback.save
        m.save(PREDICTIONS_CSV, index > 0)