In [1]:
%matplotlib inline
%reload_ext autoreload
%autoreload 2

In [2]:
from fastai.conv_learner import *
# from fastai.dataset import *
from fastai.models.resnet import vgg_resnet50

import json
from glob import glob

In [3]:
torch.backends.cudnn.benchmark=True

## Data

In [4]:
PATH = Path('../data/all')

In [5]:
def show_img(im, figsize=None, ax=None, alpha=None):
    """Draw `im` on a matplotlib axes with the axis decorations switched off.

    Args:
        im: image data accepted by `Axes.imshow`.
        figsize: forwarded to `plt.subplots`; only used when `ax` is not given.
        ax: axes to draw on. A new figure/axes pair is created when None.
        alpha: opacity forwarded to `imshow`.

    Returns:
        The axes that was drawn on.
    """
    # Compare against None explicitly: a caller-supplied axes-like object must
    # never be replaced just because it happens to evaluate as falsy.
    if ax is None: _,ax = plt.subplots(figsize=figsize)
    ax.imshow(im, alpha=alpha)
    ax.set_axis_off()
    return ax

In [6]:
# Pixel values in the segmentation masks that identify each class of interest;
# convert_y / convert_y_ce compare the mask array against these ids.
VEHICLES=10
ROADS=7
ROAD_LINES=6

In [7]:
# Run configuration. workers / random_crop / pseudo_label / val_folder are
# passed positionally to torch_loader further down.
# NOTE(review): TRAIN_DN, MASKS_DN, S_PREFIX and pretrain are not referenced
# anywhere else in the visible notebook (torch_loader hard-codes the
# 'CameraRGB' / 'CameraSeg' folder names) — confirm whether they are still needed.
TRAIN_DN = 'CameraRGB'
MASKS_DN = 'CameraSeg'
workers=7  # DataLoader worker processes
random_crop=True  # add the RRC random-resized-crop transform to the training pipeline
pseudo_label=False  # when True, torch_loader also globs ../data/pseudo
# val_folder = 'sample_test_sync'
val_folder = 'val'  # folder under ../data used as the validation set
S_PREFIX = '49_fp16'
pretrain=False

In [8]:
from torchvision.datasets.folder import pil_loader
import torchvision.transforms as transforms
import torchvision.transforms.functional as TTF

### Create dataloader

In [9]:
class MatchedFilesDataset(Dataset):
    """Dataset that pairs each input image file with its mask file.

    `fnames[i]` and `y[i]` are both opened with `pil_loader` (paths are joined
    onto `path`) and then passed together through every callable in `tfms`.

    Args:
        fnames: sequence of input image paths.
        y: sequence of mask paths; must have the same length as `fnames`.
        tfms: iterable of callables taking `(x, y)` and returning `(x, y)`.
        path: prefix joined in front of every file name.
    """
    def __init__(self, fnames, y, tfms, path):
        assert(len(fnames)==len(y))
        self.path,self.fnames = path,fnames
        self.y=y
        self.open_fn = pil_loader
        self.open_y_fn = pil_loader
        self.tfms = tfms

        self.n = self.get_n()
        self.c = self.get_c()

    def get_x(self, i): return self.open_fn(os.path.join(self.path, self.fnames[i]))
    def get_y(self, i): return self.open_y_fn(os.path.join(self.path, self.y[i]))
    def get_n(self): return len(self.fnames)
    def get_c(self): return 2

    def get(self, tfms, x, y):
        """Apply each paired transform in order and return the final `(x, y)`."""
        for fn in tfms:
            x, y = fn(x, y)
        return (x, y)

    def __getitem__(self, idx):
        x,y = self.get_x(idx),self.get_y(idx)
        return self.get(self.tfms, x, y)

    def __len__(self): return self.n

    def resize_imgs(self, targ, new_path):
        """Return a new dataset of the same class rooted at the resized-image folder.

        The transforms are carried over from `self.tfms`; the previous code read
        a non-existent `self.transform` attribute and raised AttributeError.
        """
        # NOTE(review): only `self.fnames` is handed to resize_imgs, yet the new
        # dataset resolves the mask paths against `dest` as well — confirm the
        # masks are available there before relying on this method.
        dest = resize_imgs(self.fnames, targ, self.path, new_path)
        return self.__class__(self.fnames, self.y, self.tfms, dest)

In [10]:

# Seems to speed up training by ~2%
class DataPrefetcher():
    """Wrap a DataLoader so the next batch is copied to the GPU on a side stream.

    Args:
        loader: iterable yielding `(input, target)` batches; must expose `.dataset`.
        stop_after: optional int. Iteration stops once the number of batches
            already yielded exceeds this value.
    """
    def __init__(self, loader, stop_after=None):
        self.loader = loader
        self.dataset = loader.dataset
        self.stream = torch.cuda.Stream()
        self.stop_after = stop_after
        self.next_input = None
        self.next_target = None

    def __len__(self):
        return len(self.loader)

    @staticmethod
    def _to_cuda(t):
        """Start an asynchronous host-to-device copy of `t` and return the result."""
        try:
            return t.cuda(non_blocking=True)
        except TypeError:
            # Older PyTorch releases spell this flag `async`. That word is
            # reserved from Python 3.7 on, so it cannot be written as a literal
            # keyword argument and is passed through a dict instead.
            return t.cuda(**{'async': True})

    def preload(self):
        """Fetch the next batch and queue its copy to the GPU on `self.stream`."""
        try:
            self.next_input, self.next_target = next(self.loaditer)
        except StopIteration:
            # Exhausted: __iter__ stops when next_input is None.
            self.next_input = None
            self.next_target = None
            return
        with torch.cuda.stream(self.stream):
            self.next_input = self._to_cuda(self.next_input)
            self.next_target = self._to_cuda(self.next_target)

    def __iter__(self):
        count = 0
        self.loaditer = iter(self.loader)
        self.preload()
        while self.next_input is not None:
            # Make sure the side-stream copy has finished before the batch is used.
            torch.cuda.current_stream().wait_stream(self.stream)
            input = self.next_input
            target = self.next_target
            # Kick off the copy of the following batch before handing this one out.
            self.preload()
            count += 1
            yield input, target
            # NOTE(review): the check runs after the yield and uses `>`, so
            # stop_after=N lets N+1 batches through — confirm that is intended.
            if type(self.stop_after) is int and (count > self.stop_after):
                break

In [11]:
def crop_bg_pil(x,y):
    """Crop the same horizontal band out of an image and its mask.

    The band starts at row int(h/3.75) and ends at row int(h*.9 + h/150).
    The requested crop width is the image width rounded up to the next
    multiple of 32.

    Returns:
        Tuple `(cropped_x, cropped_y)` produced by `TTF.crop` with identical arguments.
    """
    width, height = x.size
    first_row = int(height/3.75)
    last_row = int(height*.9 + height/150)
    # Columns missing to reach a multiple of 32 (0 when already aligned).
    extra_cols = -width % 32
    band = (first_row, 0, last_row-first_row, width+extra_cols)
    return TTF.crop(x, *band), TTF.crop(y, *band)

In [12]:
class RHF(object):
    """Random horizontal flip applied to an image and its mask together.

    With probability `p` both inputs are flipped via `TTF.hflip`; otherwise
    both are returned untouched.
    """
    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, x, y):
        # A single draw decides for both, so image and mask stay aligned.
        if random.random() >= self.p:
            return x,y
        return TTF.hflip(x), TTF.hflip(y)

In [13]:
class RR(object):
    """Random rotation applied to an image and its mask with one shared angle.

    The angle is drawn uniformly from [-degrees, degrees] on every call.
    """
    def __init__(self, degrees=2):
        self.degrees = degrees

    def __call__(self, x, y):
        theta = random.uniform(-self.degrees, self.degrees)
        rotated_x = TTF.rotate(x, theta)
        rotated_y = TTF.rotate(y, theta)
        return rotated_x, rotated_y

In [14]:
def tfm_x_wrapper(tfm):
    """Adapt an image-only transform to the paired `(x, y)` transform protocol.

    The returned callable applies `tfm` to `x` and passes `y` through unchanged.
    """
    def apply_to_x(x, y):
        return tfm(x), y
    return apply_to_x

In [15]:
class RC():
    """Random fixed-size crop applied identically to an image and its mask.

    Args:
        targ_sz: `(width, height)` of the crop window.
    """
    def __init__(self, targ_sz):
        self.targ_sz = targ_sz

    def __call__(self, x, y):
        # Two independent draws: the first positions the window horizontally,
        # the second vertically.
        frac_w = random.uniform(0, 1)
        frac_h = random.uniform(0, 1)
        (w, h), (t_w, t_h) = x.size, self.targ_sz
        left = np.floor(frac_w*(w-t_w)).astype(int)
        top = np.floor(frac_h*(h-t_h)).astype(int)
        window = (top, left, t_h, t_w)
        return TTF.crop(x, *window), TTF.crop(y, *window)

In [16]:
def convert_y_ce(y_img):
    """Turn a mask of raw class ids into a LongTensor of training labels.

    Labels: 0 = everything else, 1 = vehicle, 2 = road or road line.
    Vehicle pixels in rows from int(rows * .875) downwards are relabelled 0.
    Road labels are assigned last, so they are not affected by that cutoff.
    """
    is_vehicle = (y_img == VEHICLES)
    first_ignored_row = int(y_img.shape[0]*.875)
    is_vehicle[first_ignored_row:, :] = False
    is_road = (y_img == ROADS) | (y_img == ROAD_LINES)
    labels = np.where(is_road, 2, np.where(is_vehicle, 1, 0)).astype(int)
    return torch.from_numpy(labels).long()

In [17]:
def convert_y(y_img):
    """Turn a mask of raw class ids into a 3-plane integer tensor.

    Planes are stacked along axis 0 in the order (neither, vehicle, road).
    Vehicle pixels in rows from int(rows * .875) downwards are cleared
    before the "neither" plane is derived.
    """
    road = np.logical_or(y_img==ROADS, y_img==ROAD_LINES)
    car = (y_img==VEHICLES)
    first_ignored_row = int(car.shape[0]*.875)
    car[first_ignored_row:,:] = 0
    neither = np.logical_not(np.logical_or(road, car))
    planes = np.stack((neither, car, road))
    return torch.from_numpy(planes.astype(int))


def xy_tensor(x,y):
    """Convert an (image, mask) pair into training tensors.

    Returns:
        `(TTF.to_tensor(x), convert_y_ce(first channel of y))`.
    """
    # np.asarray copies only when a conversion is needed. The previous
    # `np.array(..., copy=False)` raises on NumPy 2 whenever a copy is
    # unavoidable, which is always the case for an int32 conversion.
    y_img = np.asarray(y, dtype=np.int32)
    # The label ids are read from channel 0 of the mask.
    return TTF.to_tensor(x), convert_y_ce(y_img[:,:,0])

In [18]:
class RRC(transforms.RandomResizedCrop):
    """RandomResizedCrop that applies one sampled window to image and mask.

    The image is resized with the transform's configured interpolation. The
    mask holds class ids, so it is resized with nearest-neighbour instead:
    a blending filter would invent ids that belong to neither neighbour.
    """
    # PIL's integer code for nearest-neighbour resampling (PIL.Image.NEAREST == 0).
    mask_interpolation = 0

    def __call__(self, x, y):
        # One set of crop parameters, sampled from the image, is used for both inputs.
        i, j, h, w = self.get_params(x, self.scale, self.ratio)
        x = TTF.resized_crop(x, i, j, h, w, self.size, self.interpolation)
        y = TTF.resized_crop(y, i, j, h, w, self.size, self.mask_interpolation)
        return x, y

In [19]:
def torch_loader(f_ext, data_path, bs, size, workers=7, random_crop=False, pseudo_label=False, val_folder=None, val_bs=None):
    """Build a ModelData whose train/val loaders are GPU-prefetching DataLoaders.

    Args:
        f_ext: suffix appended to the 'CameraRGB' / 'CameraSeg' folder names.
        data_path: Path containing the training folders.
        bs: training batch size.
        size: output size given to the RRC transform when `random_crop` is set.
        workers: DataLoader worker count.
        random_crop: insert an RRC transform into the training pipeline.
        pseudo_label: also train on the files under ../data/pseudo.
        val_folder: folder under ../data used for validation. When falsy, the
            last 300 files of the training listing are held out instead.
        val_bs: validation batch size; defaults to `bs`.

    Returns:
        `ModelData(data_path, train_loader, val_loader)`.
    """
    # Data loading code
    # Only the first 500 files (in sorted order) are used.
    x_names = np.sort(np.array(glob(str(data_path/f'CameraRGB{f_ext}'/'*.png'))))[:500]
    y_names = np.sort(np.array(glob(str(data_path/f'CameraSeg{f_ext}'/'*.png'))))[:500]

    x_n = x_names.shape[0]
    val_idxs = list(range(x_n-300, x_n))

    if pseudo_label:
        x_names_test = np.sort(np.array(glob(f'../data/pseudo/CameraRGB{f_ext}/*.png')))
        y_names_test = np.sort(np.array(glob(f'../data/pseudo/CameraSeg{f_ext}/*.png')))
        x_names = np.concatenate((x_names, x_names_test))
        # Masks extend y_names. Assigning this result to x_names (as before)
        # replaced the images with masks and left the two lists unequal.
        y_names = np.concatenate((y_names, y_names_test))
        print(f'Pseudo-Labels: {len(x_names_test)}')
    if val_folder:
        x_names_val = np.sort(np.array(glob(f'../data/{val_folder}/CameraRGB{f_ext}/*.png')))
        y_names_val = np.sort(np.array(glob(f'../data/{val_folder}/CameraSeg{f_ext}/*.png')))
        val_x,val_y = x_names_val, y_names_val
        trn_x,trn_y = x_names, y_names
        print(f'Val Labels:', len(val_x))
    else:
        ((val_x,trn_x),(val_y,trn_y)) = split_by_idx(val_idxs, x_names, y_names)
    print(f'Val x:{len(val_x)}, y:{len(val_y)}')
    print(f'Trn x:{len(trn_x)}, y:{len(trn_y)}')
    print(f'All x:{len(x_names)}')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    train_tfms = [
        crop_bg_pil,
        tfm_x_wrapper(transforms.ColorJitter(.2,.2,.2)),
#         tfm_x_wrapper(Lighting(0.1, __imagenet_pca['eigval'], __imagenet_pca['eigvec'])),
        RR(),
        RHF(),
#         RC((size,size)),
        xy_tensor,
        tfm_x_wrapper(normalize),
    ]
    if random_crop:
        # Index 3 places the crop after the rotation and before the flip.
        train_tfms.insert(3,RRC(size, scale=(0.4, 1.0)))
    train_dataset = MatchedFilesDataset(trn_x, trn_y, train_tfms, path='')
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=bs, shuffle=True,
        num_workers=workers, pin_memory=True)

    # Validation uses no random augmentation: crop, tensor conversion, normalisation.
    val_tfms = [
        crop_bg_pil,
        xy_tensor,
        tfm_x_wrapper(normalize)
    ]
    val_dataset = MatchedFilesDataset(val_x, val_y, val_tfms, path='')
    if val_bs is None: val_bs = bs
    val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=val_bs, shuffle=False,
        num_workers=workers, pin_memory=True)

    train_loader = DataPrefetcher(train_loader)
    val_loader = DataPrefetcher(val_loader)

    data = ModelData(data_path, train_loader, val_loader)
    return data


In [20]:
def denorm(x):
    """Undo the mean/std normalisation of a tensor and return a numpy array.

    Axis 0 of the input is moved to position 2, then the array is multiplied
    by the per-channel std and shifted by the per-channel mean.
    """
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    channels_last = np.moveaxis(x.cpu().numpy(), 0, 2)
    return channels_last*std+mean

## U-net (ish)

In [21]:
from torchvision.models import vgg11_bn

In [22]:
def vgg11(pre):
    """Return the first child module of `vgg11_bn(pre)`."""
    net = vgg11_bn(pre)
    return children(net)[0]

In [23]:
# Maps each backbone constructor to a two-item list [cut, lr_cut]; get_base
# unpacks it, passing `cut` to cut_model and returning `lr_cut` unchanged.
# NOTE(review): the constructors other than vgg11 are presumably supplied by
# the fastai star import, which may also define its own model_meta — confirm.
model_meta = {
    resnet18:[8,6], resnet34:[8,6], resnet50:[8,6], resnet101:[8,6], resnet152:[8,6],
    vgg11:[0,13], vgg16:[0,22], vgg19:[0,22],
    resnext50:[8,6], resnext101:[8,6], resnext101_64:[8,6],
    wrn:[8,6], inceptionresnet_2:[-2,9], inception_4:[-1,9],
    dn121:[0,7], dn161:[0,7], dn169:[0,7], dn201:[0,7],
}

In [24]:
def get_base(f, pretrain):
    """Build the truncated backbone for constructor `f`.

    Returns:
        `(nn.Sequential of the layers kept by cut_model, lr_cut)` where `cut`
        and `lr_cut` are looked up in `model_meta[f]`.
    """
    cut, lr_cut = model_meta[f]
    kept_layers = cut_model(f(pretrain), cut)
    backbone = nn.Sequential(*kept_layers)
    return backbone, lr_cut

In [25]:
class SaveFeatures():
    """Forward hook that remembers the most recent output of module `m`."""
    features = None  # stays None until the hooked module has run once

    def __init__(self, m):
        self.hook = m.register_forward_hook(self.hook_fn)

    def hook_fn(self, module, input, output):
        self.features = output

    def remove(self):
        """Detach the hook; `features` keeps its last stored value."""
        self.hook.remove()

In [26]:
class UnetBlock(nn.Module):
    """Decoder block: upsample one input, project the other, concatenate, ReLU, BN.

    Args:
        up_in: channels of the tensor to be upsampled.
        x_in: channels of the skip-connection tensor.
        n_out: channels of the block output; each branch contributes n_out//2.
    """
    def __init__(self, up_in, x_in, n_out):
        super().__init__()
        half = n_out//2
        self.x_conv  = nn.Conv2d(x_in, half, 1)
        self.tr_conv = nn.ConvTranspose2d(up_in, half, 2, stride=2)
        self.bn = nn.BatchNorm2d(n_out)

    def forward(self, up_p, x_p):
        # The transposed conv (kernel 2, stride 2) doubles the spatial size of
        # up_p; the 1x1 conv only changes the channel count of x_p.
        upsampled = self.tr_conv(up_p)
        projected = self.x_conv(x_p)
        merged = torch.cat([upsampled, projected], dim=1)
        activated = F.relu(merged, inplace=True)
        return self.bn(activated)

In [27]:
class UnetModel():
    """Thin wrapper that exposes a model's layer groups to the learner."""
    def __init__(self,model,name='unet'):
        self.model = model
        self.name = name

    def get_layer_groups(self, precompute=False):
        """Return the backbone split at `lr_cut`, followed by one group holding
        every child of the model after the first."""
        # An FP16 wrapper keeps the real network in `.module`.
        net = self.model.module if isinstance(self.model, FP16) else self.model
        groups = list(split_by_idxs(children(net.rn), [net.lr_cut]))
        groups.append(children(net)[1:])
        return groups

In [28]:
def carce_f_p_r(pred, targs):
    """F-beta (beta=2), precision and recall for class 1, taken from the
    argmax over dim 1 of `pred` against integer-label targets."""
    predicted_class = torch.max(pred, 1)[1]
    is_car_pred = (predicted_class == 1)
    is_car_true = (targs[:,:,:] == 1)
    return fbeta_score(is_car_pred, is_car_true, beta=2)

In [29]:
def rdce_f(pred, targs):
    """F-beta (beta=0.5) for class 2, taken from the argmax over dim 1 of
    `pred` against integer-label targets. Only the F value is returned."""
    predicted_class = torch.max(pred, 1)[1]
    scores = fbeta_score(predicted_class == 2, targs[:,:,:] == 2, beta=0.5)
    f, _, _ = scores
    return f

In [30]:
def carsig_f_p_r(pred, targs):
    """F-beta (beta=2), precision and recall for channel 0 of sigmoid
    outputs, thresholded at 0.5."""
    probs = F.sigmoid(pred)
    car_probs, car_true = probs[:,0,:,:], targs[:,0,:,:]
    return fbeta_score(car_probs, car_true, beta=2, threshold=0.5)

In [31]:
def rdsig_f(pred, targs):
    """F-beta (beta=0.5) for channel 1 of sigmoid outputs, thresholded at
    0.5. Only the F value is returned."""
    probs = F.sigmoid(pred)
    scores = fbeta_score(probs[:,1,:,:], targs[:,1,:,:], beta=0.5, threshold=0.5)
    f, _, _ = scores
    return f

In [32]:
def car_f_p_r(pred, targs):
    """F-beta (beta=2), precision and recall for class 1, comparing the
    argmax over dim 1 of `pred` with channel 1 of `targs`."""
    predicted_class = torch.max(pred, 1)[1]
    car_plane = targs[:,1,:,:]
    return fbeta_score(predicted_class == 1, car_plane, beta=2)

In [33]:
def rd_f(pred, targs):
    """F-beta (beta=0.5) for class 2, comparing the argmax over dim 1 of
    `pred` with channel 2 of `targs`. Only the F value is returned."""
    predicted_class = torch.max(pred, 1)[1]
    road_plane = targs[:,2,:,:]
    f, _, _ = fbeta_score(predicted_class == 2, road_plane, beta=0.5)
    return f

In [34]:
def new_acc_sig(pred, targs):
    """Fraction of elements where the sigmoid of `pred`, thresholded at 0.5,
    equals `targs`."""
    hard_pred = (F.sigmoid(pred) > 0.5).long()
    matches = (hard_pred == targs)
    return matches.float().mean()

In [35]:
def new_acc_ce(preds, targs):
    """Fraction of elements where the argmax over dim 1 of `preds` equals the
    integer label in `targs`."""
    predicted_class = torch.max(preds, 1)[1]
    matches = (predicted_class == targs)
    return matches.float().mean()

In [36]:
def new_acc(pred, targs):
    """Fraction of elements where the argmax over dim 1 of `pred` equals the
    argmax over dim 1 of `targs`."""
    predicted_class = torch.max(pred, 1)[1]
    target_class = torch.max(targs, 1)[1]
    return (predicted_class == target_class).float().mean()

In [37]:
def dice_coeff_weight(pred, target, weight):
    """Class-weighted soft Dice coefficient over the whole batch.

    `pred` and `target` are 4-D and are flattened to (num, c, -1). `weight`
    is reshaped to (1, -1, 1) so it scales each class before the three sums
    (intersection, prediction, target) are taken. A smoothing term of 1 is
    added to numerator and denominator.
    """
    smooth = 1.
    num, c, _, _ = pred.shape
    flat_pred = pred.view(num, c, -1)
    flat_targ = target.view(num, c, -1)
    class_w = V(weight.view(1,-1,1))
    overlap_sum = (class_w*(flat_pred * flat_targ)).sum()
    pred_sum = (class_w*flat_pred).sum()
    targ_sum = (class_w*flat_targ).sum()
    return (2. * overlap_sum + smooth) / (pred_sum + targ_sum + smooth)

def dice_coeff(pred, target):
    """Soft Dice coefficient over the whole batch.

    `pred` and `target` are 4-D and are flattened to (num, c, -1). A
    smoothing term of 1 is added to numerator and denominator.
    """
    smooth = 1.
    num, c, _, _ = pred.shape
    flat_pred = pred.view(num, c, -1)
    flat_targ = target.view(num, c, -1)
    overlap = (flat_pred * flat_targ).sum()
    total = flat_pred.sum() + flat_targ.sum()
    return (2. * overlap + smooth) / (total + smooth)


class SoftDiceLoss(nn.Module):
    """Dice-based loss for integer-label targets with three classes (0, 1, 2).

    Args:
        weight: optional per-class weights; when given, dice_coeff_weight is
            used instead of dice_coeff.
        size_average: accepted but not used by this implementation.
        softmax: apply softmax over the class dimension when True, an
            element-wise sigmoid otherwise.
    """
    def __init__(self, weight=None, size_average=True, softmax=True):
        super(SoftDiceLoss, self).__init__()
        self.weight = weight
        self.softmax = softmax

    def forward(self, logits, targets):
        # The class dimension is named explicitly instead of relying on the
        # implicit-dim fallback of F.softmax.
        probs = F.softmax(logits, dim=1) if self.softmax else F.sigmoid(logits)
        num = targets.size(0)  # Number of samples in the batch
        # One-hot encode the labels along a new dim 1 to line up with probs.
        targets = torch.cat(((targets==0).unsqueeze(1), (targets==1).unsqueeze(1), (targets==2).unsqueeze(1)), dim=1).float()
        # Keep the targets in the same precision as the logits.
        if isinstance(logits.data, torch.cuda.HalfTensor):
            targets = targets.half()
        else:
            targets = targets.float()
        if self.weight is not None:
            score = dice_coeff_weight(probs, targets, self.weight)
        else:
            score = dice_coeff(probs, targets)
        # NOTE(review): both dice helpers already reduce over the whole batch,
        # so dividing by `num` keeps the loss at or above 1 - 1/num — confirm
        # this scaling is intended.
        score = 1 - score.sum() / num
        return score

In [38]:
def fbeta_score(y_pred, y_true, beta, threshold=None, eps=1e-9):
    """Compute F-beta, precision and recall over all elements of two tensors.

    Args:
        y_pred: predictions; binarised with `>= threshold` when a threshold
            is given, otherwise used as-is after a float cast.
        y_true: ground-truth tensor, cast to float.
        beta: weight of recall relative to precision.
        threshold: optional cut-off applied to `y_pred`.
        eps: small constant guarding the divisions.

    Returns:
        Tuple `(fbeta, precision, recall)`.
    """
    beta2 = beta**2

    # `is not None` so that an explicit threshold of 0 is honoured rather
    # than being treated as "no threshold".
    if threshold is not None:
        y_pred = torch.ge(y_pred.float(), threshold).float()
    else:
        y_pred = y_pred.float()
    y_true = y_true.float()

    true_positive = (y_pred * y_true).sum()
    precision = true_positive/(y_pred.sum()+(eps))
    recall = true_positive/(y_true.sum()+eps)

    fb = (precision*recall)/(precision*beta2 + recall + eps)*(1+beta2)

    return fb, precision, recall

In [39]:
def lyft_score(pred, target, weight):
    """Per-sample soft F-beta score averaged over classes 1 and 2.

    `pred` and `target` are 4-D and are flattened to (num, c, -1). Precision
    and recall are computed per sample and class, `weight ** 2` serves as the
    per-class beta-squared, and the class scores are combined with the fixed
    weights [0, .5, .5] (class 0 does not contribute).

    Returns:
        A tensor with one score per sample.
    """
    num,c,h,w = pred.shape
    flat_pred = pred.view(num, c, -1)
    flat_targ = target.view(num, c, -1)
    overlap = (flat_pred * flat_targ).sum(dim=-1)
    pred_sum = flat_pred.sum(dim=-1)
    targ_sum = flat_targ.sum(dim=-1)

    eps = 1e-9
    precision = overlap / (pred_sum + eps)
    recall = overlap / (targ_sum + eps)
    beta2 = V(weight ** 2)

    fscore = ((1.+beta2) * precision * recall) / (beta2 * precision + recall + eps)

    class_avg = torch.cuda.FloatTensor([0,.5,.5])
    # Match the precision of the predictions.
    if isinstance(flat_pred.data, torch.cuda.HalfTensor):
        class_avg = class_avg.half()
    else:
        class_avg = class_avg.float()
    return (V(class_avg) * fscore).sum(dim=-1)

class FLoss(nn.Module):
    """Loss equal to 1 minus the batch-averaged `lyft_score`.

    Args:
        weight: per-class tensor passed on to `lyft_score`. When None, a
            CUDA tensor [1, 2, 0.5] is created at construction time.
        softmax: apply softmax over the class dimension when True, an
            element-wise sigmoid otherwise.
    """
    def __init__(self, weight=None, softmax=True):
        super().__init__()
        # The default is built here rather than in the signature: a tensor
        # default is evaluated when the class is defined, which needs a GPU
        # just to define the class and shares one tensor across instances.
        if weight is None:
            weight = torch.cuda.FloatTensor([1,2,0.5])
        self.weight = weight
        self.softmax = softmax

    def forward(self, logits, targets):
        # The class dimension is named explicitly instead of relying on the
        # implicit-dim fallback of F.softmax.
        probs = F.softmax(logits, dim=1) if self.softmax else F.sigmoid(logits)
        num = targets.size(0)  # Number of samples in the batch
        # One-hot encode the labels along a new dim 1 to line up with probs.
        targets = torch.cat(((targets==0).unsqueeze(1), (targets==1).unsqueeze(1), (targets==2).unsqueeze(1)), dim=1).float()
        # Keep the targets in the same precision as the logits.
        if isinstance(logits.data, torch.cuda.HalfTensor):
            targets = targets.half()
        else:
            targets = targets.float()

        score = lyft_score(probs, targets, self.weight)
        score = 1 - score.sum() / num
        return score

In [40]:
class Unet34Mod(nn.Module):
    """U-Net style decoder on top of a truncated backbone.

    Forward hooks on backbone layers 2, 4, 5 and 6 capture the activations
    used as skip connections, and a 1x1 conv on the raw input supplies the
    skip for the last decoder block.

    Args:
        out: number of output channels.
        f: backbone constructor, looked up in `model_meta` by get_base.
        pretrain: forwarded to the backbone constructor.
    """
    def __init__(self, out=3, f=resnet34, pretrain=True):
        super().__init__()
        m_base, lr_cut = get_base(f, pretrain)
        self.rn = m_base
        self.lr_cut = lr_cut
        self.sfs = [SaveFeatures(self.rn[i]) for i in [2,4,5,6]]
        self.up1 = UnetBlock(512,256,256)
        self.up2 = UnetBlock(256,128,256)
        self.up3 = UnetBlock(256,64,128)
        self.up4 = UnetBlock(128,64,64)
        self.up5 = UnetBlock(64,32,32)
        self.up6 = nn.ConvTranspose2d(32, out, 1)
        # x_skip is fed the RGB input image, so its input channel count is 3,
        # independent of how many classes the network predicts. The previous
        # `nn.Conv2d(out, ...)` only worked when out happened to be 3.
        self.x_skip = nn.Sequential(
            nn.Conv2d(3,32,1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True)
        )

    def forward(self,x):
        x_skip = self.x_skip(x)
        x = self.rn(x)
        # Decode deepest-first, pairing each step with the matching saved activation.
        x = self.up1(x, self.sfs[3].features)
        x = self.up2(x, self.sfs[2].features)
        x = self.up3(x, self.sfs[1].features)
        x = self.up4(x, self.sfs[0].features)
        x = self.up5(x, x_skip)
        x = self.up6(x)
        # Only a singleton channel dim is dropped. A bare torch.squeeze would
        # also drop the batch dim for a batch of one.
        return x.squeeze(1) if x.size(1) == 1 else x

    def close(self):
        """Remove the forward hooks registered on the backbone."""
        for sf in self.sfs: sf.remove()

In [43]:
class FP16Loss(nn.Module):
    """Wrap a criterion so built-in losses are averaged in float32 and
    returned as a half-precision value.

    Criteria that are not instances of torch's `_Loss` base class are called
    and their result returned unchanged.
    """
    def __init__(self, crit):
        super().__init__()
        self.crit = crit
        if isinstance(self.crit, torch.nn.modules.loss._Loss):
            # Request per-element losses so the reduction can happen in float32 below.
            crit.reduce = False

    def forward(self, *args, **kwargs):
        raw = self.crit(*args, **kwargs)
        if not isinstance(self.crit, torch.nn.modules.loss._Loss):
            return raw
        # Sum in float32, divide by the element count, then cast down to half.
        total = raw.float().sum()
        n_elems = float(np.prod(np.array(raw.shape)))
        return (total/n_elems).half()

In [48]:
def get_learner(md, m_fn=Unet34Mod, weights=[1,200,2], half=False, softmax=True, dice=False, floss=False):
    """Assemble a ConvLearner for the segmentation model.

    Args:
        md: model data handed to ConvLearner.
        m_fn: model constructor, called with the number of output channels
            (3 when `softmax`, else 2).
        weights: per-class weights for the criterion.
        half: switch the learner and the class weights to half precision.
        softmax: selects the output size, the criterion mode and the metric set.
        dice: use SoftDiceLoss.
        floss: use FLoss (only consulted when `dice` is False).

    Returns:
        The configured learner. Its criterion is always wrapped in FP16Loss.
    """
    n_out = 3 if softmax else 2
    net = to_gpu(m_fn(n_out))
    learn = ConvLearner(md, UnetModel(net))
    learn.opt_fn = optim.Adam

    class_weights = torch.cuda.FloatTensor(weights)
    if half:
        class_weights = class_weights.half()
        learn.half()

    if dice:
        crit = SoftDiceLoss(weight=class_weights, softmax=softmax)
    elif floss:
        crit = FLoss(weight=class_weights, softmax=softmax)
    else:
        crit = nn.CrossEntropyLoss(weight=class_weights)
    learn.crit = FP16Loss(crit)

    if softmax:
        learn.metrics = [new_acc_ce, rdce_f, carce_f_p_r]
    else:
        learn.metrics = [new_acc_sig, rdsig_f, carsig_f_p_r]
    return learn

### Models

In [56]:
# Build the data object. `ext` is the folder-name suffix torch_loader appends
# to 'CameraRGB' / 'CameraSeg', and `sz` is the size handed to the RRC transform.
# workers, pseudo_label and val_folder come from the configuration cell.
ext = '-300'
sz=128
bs=64
random_crop=True
md = torch_loader(ext, PATH, bs, sz, workers, random_crop, pseudo_label, val_folder)

Val Labels: 30
Val x:30, y:30
Trn x:500, y:500
All x:500


In [57]:
learn = get_learner(md, m_fn=Unet34Mod, weights=[1,4,1], softmax=True, half=True)

In [58]:
# Base learning rate and weight decay for the fit call below.
lr=1e-3
wd=1e-7

# Three learning rates, smallest first — presumably one per layer group
# returned by UnetModel.get_layer_groups (verify against the learner).
lrs = np.array([lr/200,lr/20,lr])

In [59]:
# Unfreeze the learner, then ask it to freeze batch-norm on the backbone.
# `.module` is used because the learner above was created with half=True,
# which get_layer_groups treats as an FP16 wrapper around the real model.
learn.unfreeze()
learn.set_bn_freeze(learn.model.module.rn, True)

In [60]:
learn.fit(lrs, 1, wds=wd, cycle_len=10,use_clr=(20,2))

Model params: 144
Group params: 48
Group params: 60
Group params: 36


HBox(children=(IntProgress(value=0, description='Epoch', max=10), HTML(value='')))

epoch      trn_loss   val_loss   new_acc_ce rdce_f     carce_f_p_r 
    0      1.18058    1.173828   0.300717   0.351396   0.005949   0.010313   0.00538   
    1      1.146974   1.174805   0.29942    0.349933   0.013987   0.01642    0.013487  
    2      1.090882   1.155273   0.319229   0.369258   0.04286    0.019213   0.061909  
    3      1.012394   1.026367   0.73976    0.624092   0.042746   0.024187   0.052893  
    4      0.926098   0.696777   0.924643   0.890211   0.30219    0.194525   0.350719  
    5      0.841845   0.452637   0.9519     0.949771   0.379328   0.2823     0.414986  
    6      0.764642   0.377441   0.954611   0.951573   0.380536   0.285304   0.415182  
    7      0.699655   0.368652   0.954294   0.949573   0.389349   0.286811   0.427564  
    8      0.645722   0.379883   0.950798   0.939076   0.40101    0.284935   0.446481  
    9      0.601227   0.40332    0.947993   0.934336   0.395824   0.271608   0.446923  



[array([0.40332]),
 0.9479934573173523,
 0.9343362418746435,
 0.3958241765224263,
 0.2716075187745372,
 0.4469229824345794]

In [None]:
learn.fit(.00001, 1, cycle_len=1)

In [44]:
%pdb on

Automatic pdb calling has been turned ON


In [69]:
# Manual variant of get_learner used for debugging the fp16 path: a
# pretrained Unet34Mod, unweighted cross-entropy wrapped in FP16Loss, and
# the softmax metric set. opt_fn is left at the learner's default.
out_sz = 3
m = to_gpu(Unet34Mod(out_sz, pretrain=True))
models = UnetModel(m)
# learn = Learner(md, models)

learn = ConvLearner(md, models)
# learn.opt_fn=optim.Adam
learn.crit=FP16Loss(nn.CrossEntropyLoss())
# learn.crit = nn.CrossEntropyLoss(reduce=True)
learn.half()
learn.metrics = [new_acc_ce, rdce_f, carce_f_p_r]
    

In [70]:
learn.fit(.00001, 1, cycle_len=1)

Model params: 36
Group params: 0
Group params: 0
Group params: 36


HBox(children=(IntProgress(value=0, description='Epoch', max=1), HTML(value='')))

  0%|          | 0/16 [00:00<?, ?it/s]Variable containing:
 1.1943
[torch.cuda.HalfTensor of size 1 (GPU 0)]

  6%|▋         | 1/16 [00:00<00:08,  1.81it/s, loss=1.19]Variable containing:
 1.1934
[torch.cuda.HalfTensor of size 1 (GPU 0)]

  6%|▋         | 1/16 [00:00<00:08,  1.70it/s, loss=1.19]Variable containing:
 1.1992
[torch.cuda.HalfTensor of size 1 (GPU 0)]

  6%|▋         | 1/16 [00:00<00:09,  1.58it/s, loss=1.2] Variable containing:
 1.1943
[torch.cuda.HalfTensor of size 1 (GPU 0)]

 25%|██▌       | 4/16 [00:00<00:02,  5.98it/s, loss=1.2]Variable containing:
 1.1924
[torch.cuda.HalfTensor of size 1 (GPU 0)]

 25%|██▌       | 4/16 [00:00<00:02,  5.74it/s, loss=1.19]Variable containing:
 1.1914
[torch.cuda.HalfTensor of size 1 (GPU 0)]

 25%|██▌       | 4/16 [00:00<00:02,  5.55it/s, loss=1.19]Variable containing:
 1.1924
[torch.cuda.HalfTensor of size 1 (GPU 0)]

 25%|██▌       | 4/16 [00:00<00:02,  5.40it/s, loss=1.19]Variable containing:
 1.1953
[torch.cuda.HalfTensor of size 

[array([1.15332]),
 0.01840432733297348,
 0.016891920825063265,
 0.07953820448846301,
 0.017088578830159493,
 0.9207561955136722]

### Step by step tests

In [49]:
ext = '-150'
sz=64
bs=64
random_crop=True
md = torch_loader(ext, PATH, bs, sz, workers, random_crop, pseudo_label, val_folder)

Val Labels: 30
Val x:30, y:30
Trn x:500, y:500
All x:500


In [50]:
out_sz = 3
m = to_gpu(Unet34Mod(out_sz, pretrain=False))
# models = UnetModel(m)
# learn = Learner(md, models)
# # learn.opt_fn=optim.Adam
# class_weights = torch.cuda.FloatTensor([1,1,1])
# learn.crit=nn.CrossEntropyLoss(weight=class_weights).half()
# learn.half()
# learn.metrics = [new_acc_ce, rdce_f, carce_f_p_r]
    

In [51]:
# Take one training batch and run it through the FP16-wrapped model by hand.
x,y = next(iter(md.trn_dl))

m16 = FP16(m)

# Per-element loss (reduce=False); the following cells reduce it manually,
# mirroring what FP16Loss.forward does.
crit = nn.CrossEntropyLoss(reduce=False).half()
# crit = FP16Loss(nn.CrossEntropyLoss().half())

out = m16(V(x))

loss = crit(out, V(y))
# lm = loss.mean()

In [52]:
l2 = loss.float().sum()

In [53]:
sz = float(np.prod(np.array(loss.shape)))

In [54]:
wtf = (l2/sz).half()

In [55]:
wtf.backward()

In [None]:
wtf = loss.float().sum()/loss.size(0)/loss.size(1)/loss.size(2)

In [None]:
lol = wtf.half()

In [None]:
loss.backward(lm.data)

In [None]:
sz = torch.cuda.LongTensor(int(np.prod(np.array(loss.shape))))

In [None]:
type(loss.sum())

In [None]:
torch.autograd.Variable

In [None]:
loss.sum()/V(sz.half())

In [None]:
# The keyword is `requires_grad`; the previous spelling `required_grad` is not
# an accepted argument.
loss.sum()/torch.autograd.Variable(sz.float(), requires_grad=False)

In [None]:
loss.sum().backward()

In [None]:
loss.grad_fn

In [None]:
np.prod(np.array(out.shape))

In [None]:
torch.cuda.HalfTensor([49152])

In [None]:
torch.cuda.HalfTensor([6e4])

In [None]:
torch.cuda.FloatTensor([1e38])

In [None]:
lm.backward()

In [None]:
loss.sum().backward()

In [None]:
lm.mean()

In [None]:
loss.mean().backward()

In [None]:
ls = F.log_softmax(out)

In [None]:
F.nll_loss(ls, V(y), reduce=False).mean()

In [None]:
x,y = next(iter(md.trn_dl))

crit = nn.CrossEntropyLoss()

m.float()
out = m(V(x))

print(out.mean())

print(crit(out, V(y)))

In [None]:
learn.fit(.00001, 1, cycle_len=1)

In [None]:
m

In [None]:
list(m.children())

In [None]:
m16 = FP16(m)

In [None]:
def batchnorm_to_fp32_t(module):
    '''
    Walk a module tree and convert every batch-norm layer back to float32,
    printing each one that is found.

    The walk is written by hand because the type check has to be made on
    each module before it is converted. The same `module` object is returned.
    '''
    is_bn = isinstance(module, nn.modules.batchnorm._BatchNorm)
    if is_bn:
        module.float()
        print('is batchnorm:', module)
    # Recurse into the direct children; each call handles its own subtree.
    for sub in module.children():
        batchnorm_to_fp32_t(sub)
    return module

In [None]:
t2 = batchnorm_to_fp32_t(m.half())

In [None]:
mc = list(m.children())

In [None]:
isinstance(list(mc[0].children())[1], nn.modules.batchnorm._BatchNorm)

In [None]:
m16

In [None]:
learn.fit(.00001, 1, cycle_len=1)

In [None]:
learn.fit(.00001, 1, cycle_len=1)