In [2]:
import json
from fastai.vision import unet_learner, imagenet_stats, torch, Path, os, load_learner, models, ItemBase, SegmentationLabelList, SegmentationItemList, partial, ImageSegment
from experiments import getDatasets, getData, random_seed
from losses import MixedLoss
from metrics import MetricsCallback, getDatasetMetrics
from fastai.callbacks import CSVLogger, SaveModelCallback
from config import *
import glob
from TextGenerator import Fonts
from transforms import textify, tensorize
from PIL import Image as pilImage

# Enable IPython's autoreload extension in mode 2 so that edited modules are
# re-imported before each cell executes.
%load_ext autoreload
%autoreload 2
# Make CUDA device 0 the current device for everything that follows.
torch.cuda.set_device(0)

In [3]:
# Root directory of this experiment, plus the sub-directory for its models.
EXPERIMENT_PATH = Path(EXPERIMENTS_PATH).joinpath('synthetic')
MODELS_PATH = EXPERIMENT_PATH.joinpath("models")
# Create the models directory (and any missing parents); no error if present.
MODELS_PATH.mkdir(parents=True, exist_ok=True)

In [4]:
# Previously trained learner used to count predicted text pixels per image.
learn = load_learner(Path(EXPERIMENTS_PATH) / 'model/resnet34/0', 'final refined model 2.pkl')

# JSON cache mapping image file name -> number of pixels predicted as text.
TEXT_INFO_PATH = EXPERIMENT_PATH / 'text_info.json'
# How many new predictions may accumulate before the cache is checkpointed.
TEXT_INFO_SAVE_EVERY = 100


def _save_text_info(info, path=TEXT_INFO_PATH):
    """Write `info` to `path` as JSON without risking a truncated cache file.

    The mapping is dumped to a sibling temporary file which is then moved over
    `path` with `os.replace`, so an interrupt during the dump leaves the
    previous cache intact instead of a half-written file that `json.load`
    would reject on the next run.

    Args:
        info: dict of image file name -> text pixel count.
        path: destination `Path` of the cache file.
    """
    tmp_path = path.with_name(path.name + '.tmp')
    with open(tmp_path, 'w') as f:
        json.dump(info, f)
    os.replace(tmp_path, path)


# Resume from an existing cache so already-processed images are skipped.
if TEXT_INFO_PATH.exists():
    with open(TEXT_INFO_PATH, 'r') as f:
        info = json.load(f)
else:
    info = dict()

unsaved = 0
try:
    for file in glob.glob(DANBOORU_PATH + '/**/*.jpg', recursive=True):
        file = Path(file)
        if file.name in info:
            continue
        # NOTE(review): `open_image` is not in the explicit import list of this
        # notebook; presumably it arrives via `from config import *` — confirm.
        # Element 2 of the prediction tuple is passed through a sigmoid and
        # thresholded at 0.5 to obtain a boolean text mask.
        raw_pred = learn.predict(open_image(file))[2]
        text_mask = torch.sigmoid(raw_pred) > 0.5
        info[file.name] = text_mask.sum().item()
        unsaved += 1
        # Checkpoint periodically rather than rewriting the whole (growing)
        # JSON file after every single image.
        if unsaved >= TEXT_INFO_SAVE_EVERY:
            _save_text_info(info)
            unsaved = 0
finally:
    # Persist whatever is pending, including when the loop is interrupted.
    if unsaved:
        _save_text_info(info)

In [5]:
# Load the project data once; `allData` is what the later cells hand to
# `getDatasets` for the validation split and for the final evaluation.
allData = getData()

In [6]:
def custom_collate(batch):
    """Stack a list of (input, target) samples into one (inputs, masks) pair.

    The first sample decides which kind of batch this is:

    * if its input has an ``x_tensor`` attribute, inputs are taken from
      ``x_tensor`` and each mask is computed with ``getSegmentationMask``;
    * otherwise inputs and masks are read from the ``px`` attribute of the
      sample's first and second element respectively.

    Returns:
        Tuple ``(inputs, masks)`` of stacked tensors; ``masks`` is cast to long.
    """
    is_synthetic = hasattr(batch[0][0], "x_tensor")
    if is_synthetic:
        inputs = torch.stack([sample[0].x_tensor for sample in batch])
        masks = torch.stack([getSegmentationMask(sample) for sample in batch])
    else:
        inputs = torch.stack([sample[0].px for sample in batch])
        masks = torch.stack([sample[1].px for sample in batch])
    return inputs, masks.long()

def getSegmentationMask(dan):
    """Build a boolean mask of positions where a sample's two tensors differ.

    Args:
        dan: a sample whose first element exposes ``x_tensor`` and ``y_tensor``.

    Returns:
        Boolean tensor with a leading singleton dimension. A position is True
        when the absolute difference of the two tensors, summed over
        dimension 0 (presumably the channel axis — confirm), exceeds 0.1.
    """
    item = dan[0]
    abs_diff = (item.x_tensor - item.y_tensor).abs()
    differs = abs_diff.sum(axis=0) > 0.1
    return differs.unsqueeze(0)
    
def folder(p):
    """Return ``/<bucket>/<p>`` for an image file name ``p``.

    The bucket is the (up to) three characters that precede the final four
    characters of ``p`` (i.e. a four-character extension such as ``.jpg``),
    left-padded with zeros to a width of four.
    Presumably this mirrors the Danbooru directory layout — confirm.
    """
    bucket = p[-7:-4].rjust(4, "0")
    return f"/{bucket}/{p}"


class CustomItem(ItemBase):
    """Item that wraps an image object and is modified in place by transforms."""

    def __init__(self, image):
        # `data` is fixed to 0; the wrapped image is kept on `image`.
        self.image, self.data = image, 0

    def __str__(self):
        return str(self.image)

    def apply_tfms(self, tfms, **kwargs):
        """Call every transform with this item (plus ``kwargs``) and return ``self``.

        The return values of the transforms are ignored, so a transform only
        has an effect by mutating the item it receives.
        """
        for transform in tfms:
            transform(self, **kwargs)
        return self

class CustomLabel(SegmentationLabelList):
    """Label list whose labels are always an all-zero 1x64x64 segment."""

    def open(self, fn):
        """Ignore ``fn`` and return a zero-filled ``ImageSegment`` of shape (1, 64, 64)."""
        blank_mask = torch.zeros(1, 64, 64)
        return ImageSegment(blank_mask)
    
class CustomItemList(SegmentationItemList):
    """Item list that yields ``CustomItem`` objects wrapping RGB PIL images."""

    _label_cls = CustomLabel

    def get(self, i):
        """Open the i-th item with PIL, convert it to RGB and wrap it."""
        rgb_image = pilImage.open(self.items[i]).convert('RGB')
        return self.reconstruct(rgb_image)

    def reconstruct(self, t):
        """Wrap ``t`` in a ``CustomItem``."""
        return CustomItem(t)

    
# Fonts handed to the `textify` transform below.
fonts = Fonts(Fonts.load(Path('../fonts')))

# Paths of the images whose cached text-pixel count in `info` is exactly zero.
items = [
    DANBOORU_PATH + folder(name)
    for name, text_pixels in info.items()
    if text_pixels == 0
]

# Every item goes to the training set with the constant label 'a'.
data = (
    CustomItemList(items)
    .split_none()
    .label_const('a', classes=['text'])
)
# The validation set is taken from the first dataset returned by getDatasets.
data.valid = getDatasets(allData)[0].valid

# Training-only transforms: `textify` (given the fonts), then `tensorize`.
data.train.transform([partial(textify, fonts=fonts), tensorize])

data = (
    data
    .databunch(bs=8, val_bs=2, collate_fn=custom_collate)
    .normalize(imagenet_stats)
)

In [None]:
# U-Net learner on a resnet34 backbone, trained with MixedLoss(0, 1).
# SaveModelCallback monitors the metric named 'ignore global f1 score %'.
learn = unet_learner(
    data,
    models.resnet34,
    callback_fns=[
        MetricsCallback,
        CSVLogger,
        partial(SaveModelCallback, monitor='ignore global f1 score %'),
    ],
    model_dir='models',
    loss_func=MixedLoss(0, 1),
    path=EXPERIMENT_PATH,
)

In [None]:
# Train only when no exported model exists yet, so a re-run skips training.
final_model_file = EXPERIMENT_PATH / 'final model.pkl'
if not final_model_file.exists():
    random_seed(42)
    learn.fit_one_cycle(5, 1e-4)  # 5 epochs, max learning rate 1e-4
    learn.save('model')
    learn.export('final model.pkl')

In [None]:
# Evaluate the exported model only when no predictions file exists yet.
predictions_file = EXPERIMENT_PATH / 'final predictions.csv'
if not predictions_file.exists():
    learn = load_learner(EXPERIMENT_PATH, 'final model.pkl')
    eval_datasets = getDatasets(allData, crop=False, cutInHalf=False)
    for index, dataset in enumerate(eval_datasets):
        # Re-seed before each dataset so every evaluation starts from the same state.
        random_seed(42)
        m = getDatasetMetrics(dataset, learn)
        # False for the first dataset, True afterwards; presumably an
        # "append" flag for the CSV — confirm against the metrics module.
        after_first = index > 0
        m.save(predictions_file, after_first)