In [None]:
# default_exp models.module

# Module

> Deep learning modules built with fastai/PyTorch.

In [None]:
# hide
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
# export
from steel_segmentation.core import *
from steel_segmentation.data import *
from steel_segmentation.preprocessing import *
from steel_segmentation.models.dls import *
from steel_segmentation.models.metrics import *

import warnings
with warnings.catch_warnings():
    warnings.simplefilter('ignore')
    from fastai.vision.all import *
    import fastai
from fastcore.foundation import *

In [None]:
# hide
# Keep only a small, fixed subset of images in each dataframe
# (presumably to make development runs of this notebook fast).
only_imgs = ["0a1cade03.jpg", "bca4ae758.jpg", "988cf521f.jpg", "b6a257b28.jpg",
             "b2ad335bf.jpg", "72aaba8ad.jpg", "f383950e8.jpg"]
train = train.loc[train["ImageId"].isin(only_imgs)].copy()
train_all = train_all.loc[train_all["ImageId"].isin(only_imgs)].copy()
train_multi = train_multi.loc[train_multi["ImageId"].isin(only_imgs)].copy()

First, we train a classification model to obtain an encoder that knows how to recognize defects.
Then, we build a U-Net on top of the trained encoder and train a segmentation model.

In [None]:
# exports
# Directory from which model checkpoints are loaded.
models_dir = path.joinpath("models")

In [None]:
# List every entry currently in the models directory (no limit on the count).
models_dir.ls(n_max=None)

## Classification

In [None]:
# List only the `.pth` checkpoints. Note: fastcore's `file_type` is a MIME-type
# prefix (e.g. 'image'), so `file_type='pth'` matched nothing and listed every file;
# filtering by extension needs `file_exts`.
models_dir.ls(file_exts='.pth')

(#6) [Path('../data/test_images'),Path('../data/labels'),Path('../data/codes.txt'),Path('../data/train.csv'),Path('../data/train_images'),Path('../data/sample_submission.csv')]

In [None]:
# exports
# Default metrics for the multi-label defect classifier.
class_metrics = [
    accuracy_multi,
    PrecisionMulti(),
    RecallMulti(),
]

In [None]:
# export
def get_classifier_learner(bs:int, arch=resnet18, metrics=class_metrics, toload:str=None):
    """Get a classification `Learner` built on a pretrained `arch`.

    Args:
        bs: batch size forwarded to `get_classification_dls`.
        arch: architecture constructor handed to `cnn_learner` (default `resnet18`).
        metrics: metrics attached to the `Learner` (default `class_metrics`).
        toload: optional checkpoint name, relative to `models_dir`, ending in ".pth";
            any other value (or `None`) means no weights are loaded.

    Returns:
        The `Learner`, with weights from `models_dir/toload` loaded when requested.
    """
    dls = get_classification_dls(bs)
    # Pass `arch` unwrapped: fastai looks it up in `model_meta` (body cut point,
    # parameter-group splitter and, in recent versions, normalization stats).
    # A `functools.partial` is not a key there, so wrapping it silently fell back to
    # the default metadata. `pretrained=True` already requests pretrained weights.
    learner = cnn_learner(dls=dls, arch=arch, metrics=metrics, pretrained=True)

    if toload and toload.endswith(".pth"):
        # `Learner.load` appends ".pth" itself and joins relative names onto
        # `learner.path/learner.model_dir`. Passing an absolute path without the
        # extension makes it read exactly `models_dir/toload`.
        return learner.load(models_dir.resolve()/toload[:-len(".pth")])

    return learner

In [None]:
# NOTE(review): `bs` is first assigned in the Segmentation section further down;
# on a fresh kernel this cell only works if `bs` is already defined (e.g. via one
# of the star imports) — confirm, or define `bs` before this cell.
class_learner = get_classifier_learner(bs=bs)

In [None]:
# Show the installed fastai version the outputs in this notebook were produced with.
getattr(fastai, "__version__")

'2.1.8'

In [None]:
# NOTE(review): in the saved output this call raises
# `TypeError: 'int' object is not iterable`; the cause is not visible from this
# notebook — investigate before relying on this cell.
class_learner.summary()

TypeError: 'int' object is not iterable

## Segmentation

In [None]:
# exports
# Default metrics for the segmentation learner.
seg_metrics = [
    DiceMulti(),
    dice_kaggle,
]

In [None]:
# Batch size and image size(s) used to build the segmentation DataLoaders.
bs, szs = 4, (128, 800)

In [None]:
# export
def get_segmentation_learner(bs: int, szs, arch=resnet18, metrics=seg_metrics, toload: str = None):
    """Get a U-Net segmentation `Learner`, optionally with a pre-trained encoder.

    Args:
        bs: batch size forwarded to `get_segmentation_dls`.
        szs: size argument forwarded to `get_segmentation_dls` (e.g. `(128, 800)`).
        arch: encoder architecture handed to `unet_learner` (default `resnet18`).
        metrics: metrics attached to the `Learner` (default `seg_metrics`).
        toload: optional name, relative to `models_dir`, of a ".pt" state dict
            (e.g. "ResNet18-2_class.pt") loaded into `model[0]`; any other value
            (or `None`) means no weights are loaded.

    Returns:
        The U-Net `Learner`.
    """
    dls = get_segmentation_dls(bs, szs)
    segmentation_learner = unet_learner(
        dls=dls, arch=arch, metrics=metrics, pretrained=True)
    if toload and toload.endswith('.pt'):
        # Load the file the caller named instead of a fixed checkpoint name.
        encoder_path = models_dir / toload
        # `map_location='cpu'` lets a GPU-saved state dict load on CPU-only machines;
        # `load_state_dict` copies into the existing parameters, keeping their device.
        segmentation_learner.model[0].load_state_dict(
            torch.load(encoder_path, map_location='cpu'), strict=True)
    return segmentation_learner

In [None]:
# Build the segmentation learner with the batch size / image size set above.
seg_learn = get_segmentation_learner(bs=bs, szs=szs)

In [None]:
# hide
# Export the `# export`/`# exports` cells of the project's notebooks into the
# library modules (the output below lists the notebooks that were converted).
from nbdev.export import notebook2script
notebook2script()

Converted 00_core.ipynb.
Converted 01_data.ipynb.
Converted 02_preprocessing.ipynb.
Converted 03_model.fastai.ipynb.
Converted 04_model.metrics.ipynb.
Converted index.ipynb.
