In [None]:
import copy

from fastai.vision import unet_learner, imagenet_stats, torch, Path, os, load_learner, models, sys, Learner, partial, flatten_model, requires_grad, bn_types, defaults
from fastai.callbacks import CSVLogger, SaveModelCallback
import segmentation_models_pytorch as smp

from experiments import getDatasets, getData, random_seed
from losses import MixedLoss
from metrics import MetricsCallback, getDatasetMetrics
from config import *
# Make the sibling text-segmentation checkout importable. The path is relative, so it
# resolves against the kernel's current working directory.
sys.path.append('../../text-segmentation')

# Automatically reload edited project modules before each cell executes.
%load_ext autoreload
%autoreload 2

# Must stay after the sys.path.append above: `models.text_segmentation` is presumably found
# in that appended directory. A `from ... import` binds only the two class names, so
# fastai's `models` imported in the first line remains the name `models` in this notebook.
from models.text_segmentation import TextSegament, XceptionTextSegment

# Select GPU 0 as the current CUDA device for this kernel.
torch.cuda.set_device(0)

In [None]:
# Output root for this experiment (EXPERIMENTS_PATH comes from the config star-import),
# plus a `models` sub-folder that is created up front, parents included.
EXPERIMENT_PATH = Path(EXPERIMENTS_PATH).joinpath('model')
MODELS_PATH = EXPERIMENT_PATH.joinpath("models")
MODELS_PATH.mkdir(parents=True, exist_ok=True)

In [None]:
# Build the full data collection once; every sweep below derives its per-split
# datasets from it through getDatasets(allData, ...).
allData = getData()

In [None]:
# Default DataBunch arguments for the first sweep.
props = dict(bs=4, val_bs=2, num_workers=0)
# Candidates for the first sweep: fastai's resnet34 constructor, plus two custom
# text-segmentation networks that are constructed once, right here.
modelDict = dict(
    resnet34=models.resnet34,
    xception=XceptionTextSegment(),
    segament=TextSegament(),
)
# Per-model DataBunch overrides: both custom networks use a batch size of 2 instead of 4.
propsOverride = {name: {'bs': 2} for name in ('xception', 'segament')}

In [None]:
# Sweep 1: for every entry of modelDict, train and export one learner per split from
# getDatasets(allData), then score each exported model on the matching split from
# getDatasets(allData, crop=False, cutInHalf=False).
# A split is skipped when its output file already exists, so re-running the cell
# resumes an interrupted sweep.
for name, model in modelDict.items():
    for index, dataset in enumerate(getDatasets(allData)):
        PATH = EXPERIMENT_PATH / name / str(index)
        if (PATH / 'final model.pkl').exists():
            continue
        overrides = propsOverride.get(name, {})
        random_seed(42)
        data = dataset.databunch(**{**props, **overrides}).normalize(imagenet_stats)
        if isinstance(model, torch.nn.Module):
            # modelDict holds a single pre-built instance for all splits, and Learner
            # trains the module it is given in place. Train a fresh copy per split so
            # weights fitted on one split never become the starting point of the next.
            func, arch = Learner, copy.deepcopy(model)
        else:
            # An encoder constructor such as models.resnet34, passed to fastai's
            # unet_learner to build the segmentation network.
            func, arch = unet_learner, model
        random_seed(42)
        learn = func(data, arch, callback_fns=[MetricsCallback, CSVLogger], model_dir='models', loss_func=MixedLoss(0, 1), path=PATH)
        random_seed(42)
        learn.fit_one_cycle(10, 1e-4)
        learn.save('model')
        learn.export(file='final model.pkl')
    for index, dataset in enumerate(getDatasets(allData, crop=False, cutInHalf=False)):
        PATH = EXPERIMENT_PATH / name / str(index)
        if (PATH / 'final predictions.csv').exists():
            continue
        learn = load_learner(PATH, 'final model.pkl')
        random_seed(42)
        m = getDatasetMetrics(dataset, learn)
        m.save(PATH / 'final predictions.csv')

In [None]:
# Sweep 2: every segmentation_models_pytorch architecture in `archs` paired with every
# ImageNet-pretrained encoder in `encoders`. Each pair is trained and scored per split
# like Sweep 1, with inputs padded via padding=16. Existing output files are skipped,
# so re-running resumes.
# The settings use their own names so this cell neither overwrites `props` /
# `propsOverride` from Sweep 1 nor shadows fastai's `models` module from the first cell.
smpProps = {'bs': 4, 'val_bs': 2, 'num_workers': 0}
encoders = ['resnet50', 'dpn68', 'vgg16', 'densenet169', 'efficientnet-b4']
smpPropsOverride = {}
archs = [smp.Unet, smp.Linknet, smp.FPN, smp.PSPNet, smp.PAN]
# Encoders that smp.PAN does not support.
PAN_UNSUPPORTED_ENCODERS = {'vgg16', 'densenet169'}
for arch in archs:
    for encoder in encoders:
        if arch is smp.PAN and encoder in PAN_UNSUPPORTED_ENCODERS:
            continue
        runName = f'{encoder} {arch.__name__}'
        for index, dataset in enumerate(getDatasets(allData, padding=16)):
            PATH = EXPERIMENT_PATH / runName / str(index)
            if (PATH / 'final model.pkl').exists():
                continue
            overrides = smpPropsOverride.get(encoder, {})
            random_seed(42)
            data = dataset.databunch(**{**smpProps, **overrides}).normalize(imagenet_stats)
            random_seed(42)
            learn = Learner(data, arch(encoder, encoder_weights='imagenet'),
                            callback_fns=[MetricsCallback, CSVLogger,
                                          partial(SaveModelCallback, monitor="ignore global f1 score %")],
                            model_dir='models', loss_func=MixedLoss(0, 1), path=PATH)
            random_seed(42)
            # Freeze the encoder, leaving only its batch-norm layers trainable
            # (smp itself does not implement encoder freezing yet).
            if hasattr(learn.model, 'reset'): learn.model.reset()
            for layer in flatten_model(learn.model.encoder):
                requires_grad(layer, isinstance(layer, bn_types))
            # Recreate the optimizer after toggling requires_grad, as fastai's own freeze methods do.
            learn.create_opt(defaults.lr)
            random_seed(42)
            learn.fit_one_cycle(10, 1e-4)
            learn.save('model')
            learn.export(file='final model.pkl')
        for index, dataset in enumerate(getDatasets(allData, crop=False, cutInHalf=False, padding=16)):
            PATH = EXPERIMENT_PATH / runName / str(index)
            if (PATH / 'final predictions.csv').exists():
                continue
            learn = load_learner(PATH, 'final model.pkl')
            random_seed(42)
            m = getDatasetMetrics(dataset, learn)
            m.save(PATH / 'final predictions.csv')

In [None]:
# Cache the exported xception models' predictions for the full-size validation images of
# every split as .pt files under <EXPERIMENT_PATH>/xception/predictions/<image's parent folder>/.
# Existing files are not recomputed, and a split's model is loaded only if at least one
# of its predictions is still missing.
PATH = EXPERIMENT_PATH / 'xception'
PREDICTIONS_PATH = PATH / 'predictions'
for index, dataset in enumerate(getDatasets(allData, crop=False, cutInHalf=False)):
    # NOTE(review): output files are not separated by split, so an image that appears in
    # several splits' validation sets is predicted only by the first one — confirm intended.
    pending = []
    for idx, img in enumerate(dataset.valid.x.items):
        # Same file name with its extension swapped for .pt (the original referenced an
        # undefined `path` here and raised NameError).
        tensorPath = PREDICTIONS_PATH / img.parent.name / img.with_suffix('.pt').name
        if not tensorPath.exists():
            pending.append((idx, tensorPath))
    if not pending:
        continue
    learn = load_learner(PATH / str(index), 'final model.pkl')
    for idx, tensorPath in pending:
        tensorPath.parent.mkdir(parents=True, exist_ok=True)
        # Element [2] of learn.predict is saved (presumably the raw model output — confirm).
        pred = learn.predict(dataset.valid.x.get(idx, False))[2]
        torch.save(pred, tensorPath)