In [None]:
# Render matplotlib figures inline and re-import edited modules automatically before each cell runs.
%matplotlib inline
%reload_ext autoreload
%autoreload 2

In [None]:
from fastai.conv_learner import *
from fastai.dataset import *
from fastai.models.resnet import vgg_resnet50

import json
import pandas as pd
import glob

In [None]:
# Select CUDA device 0 as the current GPU for this kernel.
torch.cuda.set_device(0)

In [None]:
# Let cuDNN benchmark convolution algorithms and cache the fastest; worthwhile because the image size (sz) is fixed.
torch.backends.cudnn.benchmark=True

# Data

In [None]:
# Root data directory; the bare expression shows its contents as the cell output.
PATH = Path('data/cekfisik/fase2')
list(PATH.iterdir())

In [None]:
# def load(path, ):
#     asl = sorted(os.listdir(path))
#     asl_new = []
#     for a in asl:
#         aa = a.split('_asl.jpg')[0]
#         asl_new.append(aa)

# asl = sorted(os.listdir('data/cekfisik/fase2/train/asl/asl'))
# txt = sorted(os.listdir('data/cekfisik/fase2/train/txt/txt'))

# # print(len(asl), len(txt))

# asl_new = []
# for a in asl:
#     aa = a.split('_asl.jpg')[0]
#     asl_new.append(aa)

# txt_new = []
# for t in txt:
#     tt = t.split('_txt.jpg')[0]
#     txt_new.append(tt)


# new_data = {
#     'name': asl_new,
#     'asl': asl,
#     'txt': txt,
# }
# # df_train = pd.DataFrame(new_data)
# # df_train.to_csv('data/cekfisik/fase2/train.csv', index=False)
# df_train

In [None]:
# String form of the data root plus Path objects for the train/valid sub-folders.
# Built with Path composition instead of os.path.join on a pre-concatenated string.
BASE_PATH  = 'data/cekfisik/fase2'
TRAIN_PATH = Path(BASE_PATH)/'train'
VALID_PATH = Path(BASE_PATH)/'valid'

In [None]:
# Read the pairing manifests. Later cells use the 'asl' column for input image file names
# and the 'txt' column for target image file names.
# (The earlier commented-out to_csv round-trip was removed: one line would have
# overwritten train.csv with the validation manifest.)
train_csv = pd.read_csv(PATH/'train.csv')
valid_csv = pd.read_csv(PATH/'valid.csv')
train_csv.head()

In [None]:
# Preview the first rows of the validation manifest.
valid_csv.head()

In [None]:
def show_img(im, figsize=None, ax=None, alpha=None):
    """Draw an image on a matplotlib Axes with the axis decorations hidden.

    Args:
        im: image-like value accepted by ``Axes.imshow``.
        figsize: figure size, used only when a new figure has to be created.
        ax: Axes to draw on; when None a new figure and Axes are created.
        alpha: blending value forwarded to ``imshow`` (used to overlay masks).

    Returns:
        The Axes that was drawn on, so callers can layer further images on it.
    """
    # Explicit None check: don't depend on whether the Axes object is truthy.
    if ax is None:
        _, ax = plt.subplots(figsize=figsize)
    ax.imshow(im, alpha=alpha)
    ax.set_axis_off()
    return ax

In [None]:
class MatchedFilesDataset(FilesDataset):
    """Files dataset whose targets are image files too.

    ``fnames`` and ``y`` are parallel sequences of paths relative to ``path``;
    ``y[i]`` is the target image paired with ``fnames[i]``.
    """

    def __init__(self, fnames, y, transform, path):
        # Targets are stored before the parent constructor runs (same order as before),
        # presumably in case FilesDataset.__init__ touches them — TODO confirm.
        self.y = y
        assert len(fnames) == len(y)
        super().__init__(fnames, transform, path)

    def get_y(self, i):
        """Load and return the i-th target image from disk."""
        target_file = os.path.join(self.path, self.y[i])
        return open_image(target_file)

    def get_c(self):
        """Return 0 — presumably telling fastai there are no discrete classes (targets are images)."""
        return 0

In [None]:
# Sub-folders (relative to PATH) holding the input ('asl') and target ('txt') images.
train_asl, train_txt = Path('train/asl/asl'), Path('train/txt/txt')
valid_asl, valid_txt = Path('valid/asl/asl'), Path('valid/txt/txt')

def _under(folder, names):
    """Return an array of ``folder/name`` paths, one per entry of ``names``."""
    return np.array([folder/name for name in names])

# x = input image paths (from the 'asl' column), y = target image paths (from the 'txt' column).
train_x, train_y = _under(train_asl, train_csv['asl']), _under(train_txt, train_csv['txt'])
valid_x, valid_y = _under(valid_asl, valid_csv['asl']), _under(valid_txt, valid_csv['txt'])

In [None]:
# open_image(PATH/train_x[0])

In [None]:
sz = 512  # image side length passed to tfms_from_model
bs = 4    # batch size passed to ImageData
nw = 16   # presumably the number of data-loader worker processes

In [None]:
# Data augmentation. Rotation (presumably up to 4 degrees) and flip are given
# tfm_y=TfmType.CLASS so the target image is transformed identically to the input.
# RandomLighting is not given tfm_y, so presumably it alters only the input image.
aug_tfms = [RandomRotate(4, tfm_y=TfmType.CLASS),
            RandomFlip(tfm_y=TfmType.CLASS),
            RandomLighting(0.05, 0.05)]
# aug_tfms = []

In [None]:
# Transforms derived from resnet34 (presumably its normalization stats).
# tfm_y=TfmType.CLASS applies the geometric transforms to the targets too, keeping inputs and masks aligned.
tfms = tfms_from_model(resnet34, sz, crop_type=CropType.NO, tfm_y=TfmType.CLASS, aug_tfms=aug_tfms)
datasets = ImageData.get_ds(MatchedFilesDataset, (train_x,train_y), (valid_x,valid_y), tfms, path=PATH)
# Take the worker count from the config cell (nw) instead of a second hard-coded value.
md = ImageData(PATH, datasets, bs, num_workers=nw, classes=None)

In [None]:
# Take one batch from md.aug_dl (presumably augmented training batches) and undo the
# input normalization so the images can be displayed.
denorm = md.trn_ds.denorm
x,y = next(iter(md.aug_dl))
x = denorm(x)

In [None]:
# Shapes of the sample input batch and target batch.
x.shape,y.shape

In [None]:
# Overlay each target (alpha 0.5) on its denormalized input to check that augmentation
# keeps them aligned. Shows up to 4 samples; squeeze=False keeps `axes` a 2-D array so
# `.flat` works even when only one panel is drawn.
n_show = min(4, len(x))
fig, axes = plt.subplots(1, n_show, figsize=(12, 10), squeeze=False)
for i,ax in enumerate(axes.flat):
    ax=show_img(x[i], ax=ax)
    show_img(y[i], ax=ax, alpha=0.5)
plt.tight_layout(pad=0.1)

In [None]:
# Encoder architecture and its cut points from fastai's model_meta:
# `cut` truncates the backbone (used by get_base); `lr_cut` splits it into layer groups (used by UnetModel).
f = resnet34
cut,lr_cut = model_meta[f]

In [None]:
def get_base():
    """Build the encoder: ``f(True)`` truncated by ``cut_model`` at ``cut``, wrapped as an nn.Sequential.

    The ``True`` argument presumably requests pretrained weights — TODO confirm.
    """
    return nn.Sequential(*cut_model(f(True), cut))

In [None]:
def dice(pred, targs, smooth=1e-7):
    """Dice coefficient between thresholded logits and a target.

    Args:
        pred: raw model outputs (logits); entries with pred > 0 (probability > 0.5)
            are counted as foreground.
        targs: target tensor with the same shape as ``pred`` (presumably a 0/1 mask — TODO confirm).
        smooth: small constant added to numerator and denominator so that an empty
            prediction on an empty target scores 1 instead of dividing 0 by 0.

    Returns:
        (2 * |pred ∩ targs| + smooth) / (|pred| + |targs| + smooth)
    """
    pred = (pred > 0).float()
    intersection = (pred * targs).sum()
    total = (pred + targs).sum()
    return (2. * intersection + smooth) / (total + smooth)

In [None]:
class SaveFeatures:
    """Forward hook that keeps the most recent output of a module."""

    features = None  # last output seen by the hook; None until the module has run forward

    def __init__(self, m):
        # Keep the handle so the hook can be detached in remove().
        self.hook = m.register_forward_hook(self.hook_fn)

    def hook_fn(self, module, inputs, output):
        """Hook callback: store the module's output."""
        self.features = output

    def remove(self):
        """Detach the hook from the module."""
        self.hook.remove()

In [None]:
class UnetBlock(nn.Module):
    """One U-Net decoder stage.

    It upsamples the deeper features, projects the skip features, concatenates them,
    then applies ReLU followed by BatchNorm.

    Args:
        up_in: channels of the deeper (lower-resolution) input.
        x_in: channels of the skip-connection input.
        n_out: output channels, split evenly between the two branches.
    """

    def __init__(self, up_in, x_in, n_out):
        super().__init__()
        half = n_out // 2  # each branch contributes half of the output channels
        # Creation order (x_conv, tr_conv, bn) is preserved so parameter
        # initialization and registration order stay the same.
        self.x_conv = nn.Conv2d(x_in, half, 1)
        self.tr_conv = nn.ConvTranspose2d(up_in, half, 2, stride=2)
        self.bn = nn.BatchNorm2d(n_out)

    def forward(self, up_p, x_p):
        # Kernel-2 / stride-2 transposed conv doubles up_p's spatial size to match x_p.
        merged = torch.cat([self.tr_conv(up_p), self.x_conv(x_p)], dim=1)
        return self.bn(F.relu(merged))

In [None]:
class Unet34(nn.Module):
    """U-Net with a truncated ResNet-34 encoder.

    Args:
        rn: indexable encoder module (built by ``get_base``). Forward hooks on
            ``rn[2]``, ``rn[4]``, ``rn[5]`` and ``rn[6]`` capture the skip connections.

    ``forward`` returns the single output channel with the channel dimension dropped.
    """

    def __init__(self, rn):
        super().__init__()
        self.rn = rn
        # Skip-feature hooks, shallowest first; the decoder consumes them deepest first.
        self.sfs = [SaveFeatures(rn[idx]) for idx in (2, 4, 5, 6)]
        # Attribute names and creation order are kept so saved weights and the
        # layer-group split (children of the model after rn) still line up.
        self.up1 = UnetBlock(512,256,256)
        self.up2 = UnetBlock(256,128,256)
        self.up3 = UnetBlock(256,64,256)
        self.up4 = UnetBlock(256,64,256)
        self.up5 = UnetBlock(256,3,16)
        self.up6 = nn.ConvTranspose2d(16, 1, 1)  # 1x1: 16 channels -> 1 output map

    def forward(self, x):
        inp = x
        x = F.relu(self.rn(x))  # running the encoder also fills every SaveFeatures hook
        decoder = (self.up1, self.up2, self.up3, self.up4)
        for block, saved in zip(decoder, reversed(self.sfs)):
            x = block(x, saved.features)
        x = self.up5(x, inp)  # last stage uses the raw network input as its skip
        return self.up6(x)[:, 0]

    def close(self):
        """Detach all forward hooks registered on the encoder."""
        for saved in self.sfs:
            saved.remove()

In [None]:
class UnetModel:
    """Adapter handing a U-Net to fastai's learner.

    It stores the model and a name and defines the layer groups used for
    per-group learning rates.
    """

    def __init__(self, model, name='unet'):
        self.model = model
        self.name = name

    def get_layer_groups(self, precompute):
        """Return the encoder (split at lr_cut) plus one group of all later model children.

        ``precompute`` is accepted for interface compatibility and ignored.
        """
        encoder_groups = split_by_idxs(children(self.model.rn), [lr_cut])
        # Every child after the first (in Unet34 the first child is the encoder rn).
        decoder_group = children(self.model)[1:]
        return [*encoder_groups, decoder_group]

In [None]:
# Build the truncated encoder.
m_base = get_base()

In [None]:
# Attach the U-Net decoder to the encoder, move it to the GPU, and wrap it for the learner.
m = to_gpu(Unet34(m_base))
models = UnetModel(m)

In [None]:
# Learner with Adam and BCE-with-logits loss. The model emits raw logits (no sigmoid),
# so thresholded accuracy uses 0. (logit 0 == probability 0.5). This matches dice, which
# also counts pred > 0 as foreground.
learn = ConvLearner(md, models)
learn.opt_fn=optim.Adam
learn.crit=nn.BCEWithLogitsLoss()
learn.metrics=[accuracy_thresh(0.),dice]

In [None]:
# Sizes captured by each skip-connection hook. A hook's features stay None until the model
# has run a forward pass, so report None for those instead of raising AttributeError.
[None if o.features is None else o.features.size() for o in m.sfs]

In [None]:
# Train layer groups from index 1 onward; presumably layer group 0 (encoder before lr_cut) stays frozen.
learn.freeze_to(1)

In [None]:
# Learning-rate range test, then plot loss against learning rate.
learn.lr_find()
learn.sched.plot()

In [None]:
# Peak learning rate for stage 1 (frozen encoder), and weight decay for both stages.
# NOTE(review): 4e-1 is unusually high for Adam — confirm against the lr_find plot.
lr=4e-1
wd=1e-7

# Per-layer-group learning rates for the unfrozen stage (lr/400, lr/40, lr/2 after the /2);
# presumably one entry per group from UnetModel.get_layer_groups (two encoder groups + decoder).
lrs = np.array([lr/200,lr/20,lr])/2

In [None]:
# Stage 1 (frozen encoder): one cycle of length 8; use_clr=(5,8) configures fastai's cyclical LR schedule.
learn.fit(lr,1,wds=wd,cycle_len=8,use_clr=(5,8))

In [None]:
# Absolute prefix for checkpoints, presumably so they land in the working directory
# rather than the learner's default models folder.
base = os.getcwd()

In [None]:
# Checkpoint after stage 1.
learn.save(base+'/512urn-tmp')

In [None]:
# Reload the stage-1 checkpoint so later cells can resume without retraining.
learn.load(base+'/512urn-tmp')

In [None]:
# Unfreeze every layer group for fine-tuning; bn_freeze(True) presumably keeps BatchNorm statistics fixed.
learn.unfreeze()
learn.bn_freeze(True)

In [None]:
# Stage 2 (whole network): per-group learning rates lrs/2, one cycle of length 10.
learn.fit(lrs/2, 1, wds=wd, cycle_len=10,use_clr=(20,10))

In [None]:
# Checkpoint after stage 2.
# NOTE(review): "128" in the name does not match sz = 512 (stage 1 used '512urn-tmp');
# if renamed, rename the matching learn.load as well.
learn.save(base+'/128urn-tmp')

In [None]:
# Reload the stage-2 checkpoint.
learn.load(base+'/128urn-tmp')

In [None]:
# Inference on one validation batch. Switch to eval mode explicitly: the model's train/eval
# state depends on which cells ran before (a freshly loaded model is in train mode). In train
# mode, BatchNorm would normalize with, and update its running stats from, this batch.
learn.model.eval()
x,y = next(iter(md.val_dl))
py = to_np(learn.model(V(x)))

In [None]:
# Predicted mask for the first validation image: logits > 0 (probability > 0.5).
show_img(py[0]>0);

In [None]:
# Ground-truth target for the same image, for comparison.
show_img(y[0]);