In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
#export
from nb_005 import *

# Carvana

## Setup

(See final section of notebook for one-time data processing steps.)

In [None]:
# Root directory of the Carvana data; every other path below is derived from it.
PATH = Path('data/carvana')
PATH_PNG = PATH/'train_masks_png'
# Image (X) and mask (Y) folders, in a `_FULL` and a `_128` variant.
PATH_X_FULL = PATH/'train'
PATH_X_128 = PATH/'train-128'
PATH_Y_FULL = PATH_PNG  # alias: the `_FULL` masks are the PNG mask folder
PATH_Y_128 = PATH/'train_masks-128'

# start with the 128x128 images
PATH_X = PATH_X_128
PATH_Y = PATH_Y_128

In [None]:
# Sanity check: open whichever entry `iterdir` yields first, display it and report its shape.
img_f = next(PATH_X.iterdir())
x = open_image(img_f)
x.show()
x.shape

In [None]:
#export
class ImageMask(Image):
    """`Image` subclass used for segmentation masks.

    Lighting transforms are skipped, and `refresh` forces `'nearest'` as the sampling mode.
    """

    def lighting(self, func, *args, **kwargs):
        "Ignore `func` and its arguments and return this mask unchanged."
        return self

    def refresh(self):
        "Set `sample_kwargs['mode']` to `'nearest'`, then delegate to the parent `refresh`."
        self.sample_kwargs['mode'] = 'nearest'
        return super().refresh()

def open_mask(fn):
    "Open the image file `fn` with PIL and wrap its float tensor in an `ImageMask`."
    pil_img = PIL.Image.open(fn)
    mask_tensor = pil2tensor(pil_img).float()
    return ImageMask(mask_tensor)

In [None]:
def get_y_fn(x_fn):
    """Return the mask path inside `PATH_Y` that corresponds to the image path `x_fn`.

    The mask file name is the image's stem followed by `_mask.png`. The stem is used
    rather than slicing off the last four characters, so the result stays correct for
    suffixes that are not exactly three letters long (e.g. `.jpeg`).
    """
    return PATH_Y/f'{x_fn.stem}_mask.png'

# Look up the mask file that belongs to the sample image `img_f`, open it and display it.
img_y_f = get_y_fn(img_f)
y = open_mask(img_y_f)
y.show()

In [None]:
# Single-image helper: same as the earlier `show_image`, but renamed with a `_` prefix
def _show_image(img, ax=None, figsize=(3,3), hide_axis=True, cmap='binary', alpha=None):
    """Draw `img` on `ax` and return the axes.

    A new figure of size `figsize` is created when `ax` is None. `cmap` and `alpha` are
    forwarded to `imshow`; the axis decorations are turned off when `hide_axis` is true.
    """
    if ax is None:
        _, ax = plt.subplots(figsize=figsize)
    ax.imshow(image2np(img), cmap=cmap, alpha=alpha)
    if hide_axis:
        ax.axis('off')
    return ax

def show_image(x, y=None, ax=None, figsize=(3,3), alpha=0.4, hide_axis=True, cmap='viridis'):
    """Show image `x`, optionally with `y` drawn over it at opacity `alpha`.

    `figsize` is only used when `ax` is None and a new figure has to be created; it is
    now forwarded to `_show_image` (it used to be accepted but ignored). Returns None.
    """
    ax1 = _show_image(x, ax=ax, figsize=figsize, hide_axis=hide_axis, cmap=cmap)
    # `_show_image` already hides the axis on `ax1` when requested, so no extra call is needed
    if y is not None:
        _show_image(y, ax=ax1, alpha=alpha, hide_axis=hide_axis, cmap=cmap)

def _show(self, ax=None, y=None, **kwargs):
    "Show `self.data` through `show_image`; when `y` is given, its `.data` is overlaid."
    overlay = None if y is None else y.data
    return show_image(self.data, ax=ax, y=overlay, **kwargs)

# Monkey-patch `Image.show` with `_show`, which accepts an optional `y` to overlay.
Image.show = _show

In [None]:
x.show(y=y)

## xy transforms

- Data types: regression, classification, segmentation, bounding box, polygon, generative (super-resolution, colorization), custom

In [None]:
#export
class DatasetTfm(Dataset):
    "Wrap `ds` so that `tfms` are applied to each `x` (and to `y` as well when `tfm_y`) on item access."

    def __init__(self, ds:Dataset,tfms:Collection[Callable]=None,tfm_y:bool=False, **kwargs):
        self.ds = ds
        self.tfms = tfms
        self.tfm_y = tfm_y
        self.x_kwargs = kwargs
        # `y` reuses the keyword arguments of `x` with `do_resolve` off: don't reset random vars
        self.y_kwargs = dict(self.x_kwargs, do_resolve=False)

    def __len__(self):
        return len(self.ds)

    def __getitem__(self,idx):
        item_x, item_y = self.ds[idx]
        item_x = apply_tfms(self.tfms, item_x, **self.x_kwargs)
        if not self.tfm_y:
            return item_x, item_y
        return item_x, apply_tfms(self.tfms, item_y, **self.y_kwargs)

    @property
    def c(self):
        "The `c` attribute of the wrapped dataset."
        return self.ds.c
    
# Rebind the `DatasetTfm` attribute of the earlier notebook modules to the class defined in this file.
# NOTE(review): presumably so that helpers living in those modules pick up this version — confirm.
import nb_002b,nb_005
nb_002b.DatasetTfm = DatasetTfm  
nb_005.DatasetTfm  = DatasetTfm  

## Dataset

In [None]:
#export
@dataclass
class MatchedFilesDataset(Dataset):
    "Dataset pairing the image file at each index of `x_fns` with the mask file at the same index of `y_fns`."
    x_fns:List[Path]
    y_fns:List[Path]

    def __post_init__(self):
        # the two file lists must pair up one-to-one
        assert len(self.x_fns)==len(self.y_fns)

    def __repr__(self):
        return f'{type(self).__name__} of len {len(self.x_fns)}'

    def __len__(self):
        return len(self.x_fns)

    def __getitem__(self, i):
        image = open_image(self.x_fns[i])
        mask = open_mask(self.y_fns[i])
        return image, mask
    
def split_arrs_by_idx(idxs, *a):
    """
    Split each array passed in *a into a pair (elements selected by idxs, the remaining elements).
    This can be used to split multiple arrays containing training data into a validation and a training set.
    :param idxs [int]: list of indexes selected; may be empty, in which case nothing is selected
    :param a list: np.array objects (plain lists/tuples are converted with np.asarray), each with the
            same number of elements in the first dimension
    :return: list of tuples, each containing a split of the corresponding array from *a.
            First element of each tuple is an array composed of the elements selected by idxs,
            second element is an array of the remaining elements. Selection goes through a boolean
            mask, so both halves keep the original element order and repeated indexes count once.
            An empty list is returned when no arrays are given.
    """
    if not a: return []
    # only plain sequences are converted; arrays and other indexable objects are used as passed
    arrs = [np.asarray(o) if isinstance(o, (list, tuple)) else o for o in a]
    idxs = np.asarray(idxs)
    # an empty list converts to a float array, which numpy rejects as an index
    if idxs.size == 0: idxs = idxs.astype(int)
    mask = np.zeros(len(arrs[0]),dtype=bool)
    mask[idxs] = True
    return [(o[mask],o[~mask]) for o in arrs]

In [None]:
# `iterdir` yields entries in arbitrary order, so sort the file list: otherwise the
# "first 1008 files" held out below would differ between machines and runs.
x_fns = sorted(o for o in PATH_X.iterdir() if o.is_file())
y_fns = [get_y_fn(o) for o in x_fns]
# Hold out the first 1008 files for validation; everything else is used for training.
val_idxs = list(range(1008))
((val_x,trn_x),(val_y,trn_y)) = split_arrs_by_idx(val_idxs, np.array(x_fns), np.array(y_fns))
train_ds = MatchedFilesDataset(trn_x, trn_y)
valid_ds = MatchedFilesDataset(val_x, val_y)
train_ds, valid_ds

In [None]:
# Fetch the first item of the untransformed training set and inspect shapes and types.
x,y = next(iter(train_ds))
x.shape, y.shape, type(x), type(y)

In [None]:
size=128
# Augmentation settings for this run.
tfms = get_transforms(do_flip=True, max_rotate=20, max_zoom=2., max_lighting=0.7, max_warp=0.3,p_affine=0.75)
# `tfm_y=True` is requested so the masks get transformed together with the images.
# NOTE(review): the third returned dataset (`augm_tds`) is not used anywhere in the visible cells.
train_tds, valid_tds, augm_tds = transform_datasets(train_ds, valid_ds, tfms, tfm_y=True, size=size)

In [None]:
# Show the first four transformed training items side by side, each with its mask overlaid.
_,axes = plt.subplots(1,4, figsize=(12,6))
for i, ax in enumerate(axes.flat):
    imgx,imgy = train_tds[i]
    imgx.show(ax, y=imgy)

In [None]:
#export
def normalize_batch(b, mean, std, do_y=False):
    "Normalize the `x` part of batch `b` with `mean` and `std`; `y` as well when `do_y` is set."
    inputs, targets = b
    inputs = normalize(inputs, mean, std)
    if do_y:
        targets = normalize(targets, mean, std)
    return inputs, targets

def normalize_funcs(mean, std, do_y=False, device=None):
    "Return a `(normalize_batch, denormalize)` pair of partials bound to `mean` and `std`."
    if device is None:
        device = default_device
    # the batch normalizer gets its stats moved to `device`; the denormalizer keeps them as passed in
    mean_dev, std_dev = mean.to(device), std.to(device)
    norm = partial(normalize_batch, mean=mean_dev, std=std_dev, do_y=do_y)
    denorm = partial(denormalize, mean=mean, std=std)
    return norm, denorm

In [None]:
#imagenet mean/std
# Per-channel statistics (see the cell comment: imagenet mean/std).
default_mean, default_std = tensor([0.485, 0.456, 0.406]), tensor([0.229, 0.224, 0.225])
# Batch normalizer and the matching `denormalize` partial built from those statistics (`do_y` stays False).
default_norm,default_denorm = normalize_funcs(default_mean,default_std)

In [None]:
bs = 64
# `default_norm` is passed as `dl_tfms`.
# NOTE(review): presumably applied to every batch the loaders yield — confirm in `DataBunch.create`.
data = DataBunch.create(train_tds, valid_tds, bs=bs, dl_tfms=default_norm)

In [None]:
def show_xy_images(x,y,rows,figsize=(9,9)):
    "Plot a `rows` x `rows` grid in which cell `i` shows `x[i]` with `y[i]` overlaid."
    _, grid = plt.subplots(rows, rows, figsize=figsize)
    for idx, cell in enumerate(grid.flatten()):
        show_image(x[idx], y=y[idx], ax=cell)
    plt.tight_layout()

In [None]:
# Pull one training batch, move it to the CPU and undo the normalization of `x`
# before plotting a 6x6 grid of images with their masks.
x,y = next(iter(data.train_dl))
x,y = x.cpu(),y.cpu()
x = default_denorm(x)
show_xy_images(x,y,6, figsize=(9,9))
x.shape, y.shape

In [None]:
from torchvision.models import resnet34
arch = resnet34

class Debugger(nn.Module):
    "Pass-through module that calls `set_trace()` on every forward call and returns its input unchanged."

    def forward(self, x):
        # debugging aid: stop here, then hand the input back untouched
        set_trace()
        return x

class StdUpsample(nn.Module):
    "Upsampling block: `ConvTranspose2d` (kernel 2, stride 2) -> ReLU -> `BatchNorm2d`."

    def __init__(self, nin, nout):
        super().__init__()
        self.conv = nn.ConvTranspose2d(nin, nout, 2, stride=2)
        self.bn = nn.BatchNorm2d(nout)

    def forward(self, x):
        upsampled = self.conv(x)
        activated = F.relu(upsampled)
        return self.bn(activated)

# NOTE(review): despite its name, `x[:,]` indexes with a single full slice, so no channel
# is dropped here — confirm that keeping the channel dimension is intended.
flatten_channel = Lambda(lambda x: x[:,])
    
# Backbone: every child module of `arch(True)` except the last two.
# NOTE(review): `True` is presumably the `pretrained` flag — confirm.
body = nn.Sequential(*list(arch(True).children())[:-2])
# Decoder head: ReLU, four `StdUpsample` blocks, then a final stride-2 transposed
# convolution that outputs a single channel.
head = nn.Sequential(
    nn.ReLU(),
    StdUpsample(512,256),
    StdUpsample(256,256),
    StdUpsample(256,256),
    StdUpsample(256,256),
    nn.ConvTranspose2d(256, 1, 2, stride=2),
    flatten_channel
)

def dice(pred, targs):
    """Dice coefficient between the thresholded predictions `pred` and the mask `targs`.

    `pred` is binarized with `pred > 0`. The result is the scalar tensor
    `2 * sum(pred * targs) / sum(pred + targs)`. When that denominator is zero
    (nothing predicted and nothing in the target) 1 is returned instead of the
    NaN that 0/0 would produce.
    """
    pred = (pred>0).float()
    denom = (pred+targs).sum()
    # empty prediction and empty target agree perfectly; avoid 0/0
    if denom == 0: return denom.new_ones(())
    return 2. * (pred*targs).sum() / denom


model = nn.Sequential(body, head)
learn = Learner(data, model, metrics=dice)
# Split the model at the head (`model[1]`) and freeze.
# NOTE(review): presumably this leaves only the new head trainable at first — confirm against `Learner.freeze`.
learn.split([model[1]])
learn.freeze()
# Initialize the head with `kaiming_normal_`.
apply_init(learn.model[1], nn.init.kaiming_normal_)
# The head ends in a plain transposed convolution (no sigmoid), so the loss works on logits.
learn.loss_fn = nn.BCEWithLogitsLoss()

In [None]:
lr_find(learn)
learn.recorder.plot()

In [None]:
lr = 1e-2
learn.fit_one_cycle(5, slice(lr))

In [None]:
learn.recorder.plot_losses()

In [None]:
learn.recorder.plot_metrics()

In [None]:
learn.unfreeze()

In [None]:
# NOTE: the trailing `; lrs` is a bare expression that is not the cell's last statement, so nothing is displayed for it.
lrs = learn.lr_range(slice(lr/25,lr)); lrs
# Train the unfrozen model with the learning rates from `lr_range` divided by 5.
learn.fit_one_cycle(6, lrs/5, pct_start=0.01, pct_end=0.4, div_factor=50)

In [None]:
learn.recorder.plot_losses()

In [None]:
learn.recorder.plot_metrics()

In [None]:
# Run the model on one validation batch and detach the predictions from the autograd graph.
# NOTE(review): eval mode is not set explicitly and the forward pass is not under `torch.no_grad()` —
# confirm the model is already in eval mode when this cell runs.
x,y = next(iter(data.valid_dl))
py = learn.model(x)
py = py.detach()

In [None]:
(py[0]>0).shape

In [None]:
# 4x4 grid: each of the first 16 inputs with its thresholded prediction (`py > 0`) overlaid as a mask.
for i, ax in enumerate(plt.subplots(4,4,figsize=(16,16))[1].flat):
    Image(x[i]).show(ax=ax,y=ImageMask(py[i]>0))

In [None]:
learn.save('carvana_simple_128')

In [None]:
def get_data(size, bs):
    "Build a `DataBunch` from `train_ds`/`valid_ds` at image size `size` and batch size `bs`."
    aug_tfms = get_transforms(do_flip=True, max_rotate=20, max_zoom=2., max_lighting=0.7, max_warp=0.3,p_affine=0.75)
    # three datasets come back; only the first two go into the DataBunch
    train_tds, valid_tds, _ = transform_datasets(train_ds, valid_ds, aug_tfms, tfm_y=True, size=size)
    return DataBunch.create(train_tds, valid_tds, bs=bs, dl_tfms=default_norm)

In [None]:
# Rebuild the data at size 512 with a batch size of 8 (the 128px run used 64).
size=512
bs = 8

data = get_data(size, bs)

In [None]:
lr = 1e-3
# New Learner around the same `model` object on the size-512 data, then load the
# weights that were saved under 'carvana_simple_128'.
learn = Learner(data, model, metrics=dice)
learn.loss_fn = nn.BCEWithLogitsLoss()
learn.load('carvana_simple_128')

In [None]:
learn.fit_one_cycle(6, lr)

In [None]:
learn.save('carvana_simple_512')