In [15]:
from IPython.core.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))
import h5py
from matplotlib import pyplot as plt
import torch
import numpy as np
from pytorch3dunet.unet3d.losses import *
from pytorch3dunet.unet3d.metrics import MeanIoU
from pathlib import Path
import glob
import os

import matplotlib
matplotlib.use('TkAgg')

%matplotlib inline

In [24]:
class DiceProbLoss(DiceLoss):
    """DiceLoss with its input normalization disabled.

    Forwards ``normalization='none'`` to :class:`DiceLoss`, so the input is
    presumably expected to already contain probabilities — confirm against
    the DiceLoss implementation in use.
    """

    def __init__(self):
        super().__init__(normalization='none')

class MeanIoUBin(MeanIoU):
    """MeanIoU constructed with ``is_binarized=True``.

    NOTE(review): presumably signals that the input is already a 0/1 mask;
    the keyword belongs to the MeanIoU implementation in use, which is not
    visible here — confirm.
    """

    def __init__(self):
        super().__init__(is_binarized=True)

class BCEDiceBinLoss(nn.Module):
    """Weighted sum of binary cross-entropy and DiceBinLoss.

    Computes ``alpha * BCE(input, target) + beta * DiceBin(input, target)``.

    Args:
        alpha: weight of the ``nn.BCELoss`` term (which expects probabilities,
            not logits).
        beta: weight of the ``DiceBinLoss`` term.
    """

    def __init__(self, alpha, beta):
        super().__init__()
        self.alpha, self.beta = alpha, beta
        self.bce = nn.BCELoss()
        self.dice = DiceBinLoss()

    def forward(self, input, target):
        bce_term = self.bce(input, target)
        dice_term = self.dice(input, target)
        return self.alpha * bce_term + self.beta * dice_term


# Metrics evaluated for every predictor: display name -> metric class.
initLosses = dict(
    BCE=nn.BCELoss,
    Dice=DiceProbLoss,
    MeanIoU=MeanIoU,
)

In [25]:
basepred = Path('/home/lorenzo/3dunet-cavity/runs/run_210623_gpu/predictions')
baseorig = Path('/home/lorenzo/deep_apbs/destData/refined-set_filter')

def genDataSets(pred_dir=None, orig_dir=None):
    """Yield ``(name, (predT, labelT))`` for each prediction that has a label file.

    Args:
        pred_dir: directory holding ``*_grids_predictions.h5`` files
            (defaults to the global ``basepred``, read at call time).
        orig_dir: directory holding ``<name>/<name>_grids.h5`` label files
            (defaults to the global ``baseorig``, read at call time).

    Yields:
        ``name``: the prediction file name up to its first ``_``.
        ``predT``: the ``predictions`` dataset as a tensor, with a leading axis
            added (presumably batch; the dataset seems to already carry a
            channel axis — confirm).
        ``labelT``: the ``label`` dataset as a float32 tensor, with two
            leading axes added.

    Predictions whose label file is missing are reported and skipped.
    """
    pred_dir = Path(basepred if pred_dir is None else pred_dir)
    orig_dir = Path(baseorig if orig_dir is None else orig_dir)

    # sorted(): glob order is filesystem-dependent; sorting makes the iteration
    # order (and the per-protein result dicts) reproducible.
    for predfname in sorted(glob.glob(str(pred_dir / '*_grids_predictions.h5'))):
        name = Path(predfname).name.split('_')[0]
        labelfname = orig_dir / name / f'{name}_grids.h5'

        if not labelfname.exists():
            print(f'{labelfname} does not exist.')
            continue

        # Open read-only and close each file as soon as its array is in memory.
        with h5py.File(labelfname, 'r') as labelFile:
            labelT = torch.tensor(labelFile['label'][()], dtype=torch.float32)
        with h5py.File(predfname, 'r') as predFile:
            predT = torch.tensor(predFile['predictions'][()])

        labelT = labelT[None, None]
        predT = predT[None]

        yield name, (predT, labelT)

In [26]:
class RunningAverage:
    """Accumulate one loss/metric over a stream of samples and report its mean.

    ``loss`` is a callable ``loss(pred, label)`` whose result has ``.item()``
    (e.g. a 0-dim torch tensor). Every :meth:`update` also records the
    per-sample value under ``name`` in :attr:`losses`.
    """

    def __init__(self, loss):
        self.count = 0    # number of successful update() calls
        self.sum = 0      # running sum of per-sample values
        self.loss = loss  # callable(pred, label) -> object with .item()
        # name -> per-sample value; a repeated name overwrites its entry here
        # but both values still count towards the mean.
        self.losses = {}

    def update(self, pred, label, name):
        """Evaluate the loss on one sample, add it to the total, and store it under ``name``."""
        # Compute first so a failing loss call leaves count/sum consistent.
        loss = self.loss(pred, label).item()
        self.count += 1
        self.sum += loss
        self.losses[name] = loss

    def value(self):
        """Return the mean over all samples seen, or NaN if none were seen."""
        if self.count == 0:
            # e.g. every label file was missing; avoid a bare ZeroDivisionError
            return float('nan')
        return self.sum / self.count
    
class AverageLosses:
    """Track several RunningAverage metrics side by side.

    Args:
        losses: dict mapping a metric name to a factory (typically a loss or
            metric class); each factory is called once to build the metric.
        kwargs: optional dict mapping a metric name to the keyword arguments
            for its factory. Names absent from it (or all names, when
            ``kwargs`` is None) are constructed with no arguments.
    """

    def __init__(self, losses, kwargs=None):
        if kwargs is None:
            kwargs = {}
        # .get(): a partial kwargs dict no longer raises KeyError.
        self.losses = {name: RunningAverage(loss(**kwargs.get(name, {})))
                       for name, loss in losses.items()}

    def update(self, pred, label, prot):
        """Feed one (pred, label) pair, recorded under ``prot``, to every metric."""
        for running in self.losses.values():
            running.update(pred, label, prot)

    def value(self):
        """Return ``{metric name: mean value}`` for all tracked metrics."""
        return {name: running.value() for name, running in self.losses.items()}

In [27]:
# Score the U-Net predictions alongside reference predictors so its numbers
# have context: all-zero and all-one masks, uniform random probabilities,
# random hard 0/1 masks, and an oracle that predicts the label itself.
OracleLoss = AverageLosses(initLosses)
UnetLoss = AverageLosses(initLosses)
RandLoss = AverageLosses(initLosses)
RandUnitLoss = AverageLosses(initLosses)
ZeroLoss = AverageLosses(initLosses)
UnitLoss = AverageLosses(initLosses)

# Dedicated seeded generator: the random baselines are reproducible and the
# global torch RNG state is left untouched.
RAND_SEED = 0
randGen = torch.Generator().manual_seed(RAND_SEED)

for name, (predT, labelT) in genDataSets():
    ZeroLoss.update(torch.zeros_like(labelT), labelT, name)
    UnitLoss.update(torch.ones_like(labelT), labelT, name)
    UnetLoss.update(predT, labelT, name)
    OracleLoss.update(labelT, labelT, name)

    # Uniform [0, 1) "probabilities", drawn on CPU with randGen and then
    # moved to the label's device.
    randPred = torch.rand(labelT.size(), generator=randGen, dtype=labelT.dtype).to(labelT.device)
    RandLoss.update(randPred, labelT, name)

    # Threshold at 0.5 into a hard mask. With >= every voxel becomes exactly
    # 0 or 1; a (< 0.5 -> 0, > 0.5 -> 1) pair would leave values of 0.5 as-is.
    randBin = (randPred >= 0.5).to(labelT.dtype)
    RandUnitLoss.update(randBin, labelT, name)

/home/lorenzo/deep_apbs/destData/refined-set_filter/1oba/1oba_grids.h5 does not exist.
/home/lorenzo/deep_apbs/destData/refined-set_filter/2epn/2epn_grids.h5 does not exist.
/home/lorenzo/deep_apbs/destData/refined-set_filter/1tkb/1tkb_grids.h5 does not exist.
/home/lorenzo/deep_apbs/destData/refined-set_filter/1rnt/1rnt_grids.h5 does not exist.
/home/lorenzo/deep_apbs/destData/refined-set_filter/2h21/2h21_grids.h5 does not exist.
/home/lorenzo/deep_apbs/destData/refined-set_filter/2am4/2am4_grids.h5 does not exist.
/home/lorenzo/deep_apbs/destData/refined-set_filter/1x8d/1x8d_grids.h5 does not exist.
/home/lorenzo/deep_apbs/destData/refined-set_filter/2euk/2euk_grids.h5 does not exist.
/home/lorenzo/deep_apbs/destData/refined-set_filter/1lnm/1lnm_grids.h5 does not exist.
/home/lorenzo/deep_apbs/destData/refined-set_filter/1o1s/1o1s_grids.h5 does not exist.
/home/lorenzo/deep_apbs/destData/refined-set_filter/1r9l/1r9l_grids.h5 does not exist.
/home/lorenzo/deep_apbs/destData/refined-se

In [28]:
# Mean of each metric per predictor, in order:
# random 0/1 mask, random probabilities, all-ones, all-zeros, oracle.
(
    RandUnitLoss.value(),
    RandLoss.value(),
    UnitLoss.value(),
    ZeroLoss.value(),
    OracleLoss.value(),
)

({'BCE': 49.99950144841121,
  'Dice': 0.997966271180373,
  'MeanIoU': 0.0010180229527087738},
 {'BCE': 0.9999702343573937,
  'Dice': 0.9969496108018435,
  'MeanIoU': 0.0010180229527087738},
 {'BCE': 99.89769275371845,
  'Dice': 0.997956259892537,
  'MeanIoU': 0.001023037529263932},
 {'BCE': 0.1023037525323721, 'Dice': 1.0, 'MeanIoU': 0.0},
 {'BCE': 0.0, 'Dice': 0.0, 'MeanIoU': 1.0})

In [32]:
# Per-sample MeanIoU of the U-Net predictions: {file-name prefix (e.g. PDB code): IoU}.
UnetLoss.losses['MeanIoU'].losses

{'1sw2': 0.4645071029663086,
 '1fiv': 0.6156509518623352,
 '1a4r': 0.17319780588150024,
 '4gbd': 0.6333988904953003,
 '5g1z': 0.12134632468223572,
 '5gja': 0.3489736020565033,
 '2fqt': 0.29123446345329285,
 '4r75': 0.567950963973999,
 '4non': 0.15870153903961182,
 '3fjg': 0.5126749873161316,
 '3s0e': 0.493194043636322,
 '1dy4': 0.5372849106788635,
 '1dqn': 0.2584525942802429,
 '1m48': 0.027618248015642166,
 '5yz2': 0.20882698893547058,
 '4fxp': 0.48952972888946533,
 '5mxf': 0.0,
 '1ctt': 0.0,
 '1ceb': 0.2004687339067459,
 '5km9': 0.18358531594276428,
 '2dw7': 0.008737157098948956,
 '1t7d': 0.22022944688796997,
 '5tcj': 0.03971793130040169,
 '5vij': 0.4260115623474121,
 '4jwk': 0.23512376844882965,
 '5tuz': 0.09520634263753891}

In [None]:
# Sweep MeanIoU's binarization threshold over the U-Net predictions.
# Rounded so keys read "MeanIoU_0.3" instead of "MeanIoU_0.30000000000000004"
# (np.linspace float noise).
thress = [round(float(t), 2) for t in np.linspace(0.1, 1.0, 10)]

# NOTE(review): assumes this MeanIoU accepts a `thres` keyword; an
# implementation that swallows unknown **kwargs would silently ignore it — confirm.
# Whether thres=1.0 can ever mark a voxel positive depends on > vs >= there.
kwargs = {f"MeanIoU_{thres}": {'thres': thres} for thres in thress}
# Own name so the baseline `initLosses` (BCE/Dice/MeanIoU) is not overwritten;
# otherwise re-running the baseline cell would evaluate MeanIoU_* entries
# with default arguments instead.
thresLosses = {f"MeanIoU_{thres}": MeanIoU for thres in thress}

UnetLosses = AverageLosses(thresLosses, kwargs)

for name, (predT, labelT) in genDataSets():
    UnetLosses.update(predT, labelT, name)

In [None]:
# Per-sample IoU for every threshold in the sweep: {"MeanIoU_<thres>": {name: IoU}}.
# The sweep has no plain 'MeanIoU' key, so iterate over whatever keys exist.
{key: running.losses for key, running in UnetLosses.losses.items()}

## Scratch

In [17]:
# Scratch: export one structure's label grid to .npy (paths are machine-specific).
name = '1z95'

dstbase = '/home/lorenzo/3dunet-cavity/npy'

predfname = f'/home/lorenzo/3dunet-cavity/runs/run_210601/predictions/{name}_grids_predictions.h5'
origfname = f'/home/lorenzo/deep_apbs/destData/pdbbind_v2013_core_set_0/{name}/{name}_grids.h5'


# Open read-only and close the HDF5 file as soon as the array is in memory.
with h5py.File(origfname, 'r') as f:
    vlabel = f['label'][()]

# np.save fails if the destination directory does not exist yet.
os.makedirs(dstbase, exist_ok=True)
np.save(f'{dstbase}/{name}_label.npy', vlabel)

# h5py.File(predfname, 'r')['predictions']
