In [6]:
from IPython.core.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))
import h5py
import torch
import numpy as np
from pytorch3dunet.unet3d.losses import *
from pytorch3dunet.unet3d.metrics import MeanIoU
from pathlib import Path
import glob
import seaborn as sns

In [7]:
class DiceBinLoss(DiceLoss):
    """``DiceLoss`` preconfigured with ``normalization='none'``.

    NOTE(review): presumably this makes the loss use its input as-is (no
    sigmoid/softmax), so it can be fed probabilities directly — confirm against
    pytorch3dunet's ``DiceLoss``.
    """

    def __init__(self):
        super().__init__(normalization='none')

class MeanIoUBin(MeanIoU):
    """``MeanIoU`` preconfigured with ``is_binarized=True``.

    NOTE(review): presumably this tells the metric its inputs are already 0/1
    masks — confirm against pytorch3dunet's ``MeanIoU``.
    """

    def __init__(self):
        super().__init__(is_binarized=True)

class BCEDiceBinLoss(torch.nn.Module):
    """Linear combination of BCE and Dice losses: ``alpha * BCE + beta * Dice``.

    Both terms are evaluated on ``input`` exactly as given; ``torch.nn.BCELoss``
    requires values in [0, 1], so ``input`` must already be probabilities.

    Args:
        alpha: weight of the binary cross-entropy term.
        beta: weight of the Dice term (``DiceBinLoss``).
    """

    def __init__(self, alpha, beta):
        super().__init__()
        self.alpha = alpha
        # Use torch.nn explicitly: a bare `nn` only exists here because it leaks
        # through `from pytorch3dunet.unet3d.losses import *`.
        self.bce = torch.nn.BCELoss()
        self.beta = beta
        self.dice = DiceBinLoss()

    def forward(self, input, target):
        """Return ``alpha * BCE(input, target) + beta * Dice(input, target)``."""
        return self.alpha * self.bce(input, target) + self.beta * self.dice(input, target)


# name -> zero-argument factory for each loss/metric to average; torch.nn is
# referenced explicitly rather than relying on `nn` leaking from a star import.
initLosses = {
    "BCE": torch.nn.BCELoss,
    "Dice": DiceBinLoss,
    "MeanIoU": MeanIoUBin
}

In [8]:
basepred = Path('/home/lorenzo/3dunet-cavity/runs/run_210601_local/predictions')
baseorig = Path('/home/lorenzo/deep_apbs/destData/pdbbind_v2013_core_set_0')  # e.g. <baseorig>/2yfe/2yfe_grids.h5

def genDataSets(pred_dir=None, orig_dir=None):
    """Yield ``(predT, labelT)`` tensor pairs for every prediction file.

    For each ``<name>_grids_predictions.h5`` in ``pred_dir`` (default: ``basepred``),
    the matching label grid is read from ``<orig_dir>/<name>/<name>_grids.h5``
    (default: ``baseorig``). Files are visited in sorted order so runs are
    reproducible.

    Yields:
        predT: the ``'predictions'`` dataset with one leading axis added.
        labelT: the ``'label'`` dataset as float32 with two leading axes added.
    """
    pred_dir = Path(basepred if pred_dir is None else pred_dir)
    orig_dir = Path(baseorig if orig_dir is None else orig_dir)

    for predfname in sorted(glob.glob(str(pred_dir / '*_grids_predictions.h5'))):
        # Structure id is the filename prefix before the first underscore (e.g. '1z95').
        name = Path(predfname).name.split('_')[0]
        labelfname = orig_dir / name / f'{name}_grids.h5'

        # Open read-only and close right away: `[()]` copies each dataset into memory.
        with h5py.File(labelfname, 'r') as labelfile:
            labelT = torch.tensor(labelfile['label'][()], dtype=torch.float32)
        with h5py.File(predfname, 'r') as predfile:
            predT = torch.tensor(predfile['predictions'][()])

        yield (predT[None], labelT[None, None])

In [9]:
class RunningAverage:
    """Arithmetic mean of a scalar loss/metric accumulated over successive samples.

    Args:
        loss: callable ``loss(pred, label)`` returning a one-element result that
            supports ``.item()`` (e.g. a scalar tensor).
    """

    def __init__(self, loss):
        self.count = 0
        self.sum = 0
        self.loss = loss

    def update(self, pred, label):
        """Evaluate ``self.loss(pred, label)`` and add it to the running total."""
        # Evaluate first so a failing loss call does not bump `count` without `sum`.
        value = self.loss(pred, label).item()
        self.sum += value
        self.count += 1

    def value(self):
        """Return the mean of all values seen so far, or NaN if nothing was added."""
        if self.count == 0:
            # e.g. the data generator yielded no samples; avoid ZeroDivisionError.
            return float('nan')
        return self.sum / self.count
    
class AverageLosses:
    """One ``RunningAverage`` per named loss/metric, all updated together.

    Args:
        losses: mapping ``name -> factory``; each factory is called with no
            arguments to build the loss callable that gets averaged.
    """

    def __init__(self, losses):
        self.losses = {}
        for name, factory in losses.items():
            self.losses[name] = RunningAverage(factory())

    def update(self, pred, label):
        """Feed the same ``(pred, label)`` pair to every tracked average."""
        for tracker in self.losses.values():
            tracker.update(pred, label)

    def value(self):
        """Return ``{name: current mean}`` for every tracked loss."""
        return dict((name, tracker.value()) for name, tracker in self.losses.items())


# Running averages of every metric in `initLosses` for several reference
# predictors, so the network's numbers can be read against trivial baselines.
OracleLoss = AverageLosses(initLosses)    # labels scored against themselves
UnetLoss = AverageLosses(initLosses)      # predictions loaded from *_grids_predictions.h5
RandLoss = AverageLosses(initLosses)      # uniform random values in [0, 1)
RandUnitLoss = AverageLosses(initLosses)  # random values thresholded at 0.5
ZeroLoss = AverageLosses(initLosses)      # all zeros
UnitLoss = AverageLosses(initLosses)      # all ones

# Seed the global RNG so the random baselines are reproducible after a restart.
torch.manual_seed(0)

# Pure evaluation: no autograd graph is needed.
with torch.no_grad():
    for predT, labelT in genDataSets():
        ZeroLoss.update(torch.zeros_like(labelT), labelT)
        UnitLoss.update(torch.ones_like(labelT), labelT)
        UnetLoss.update(predT, labelT)
        OracleLoss.update(labelT, labelT)

        randPred = torch.rand(size=labelT.size(), dtype=labelT.dtype, device=labelT.device)
        RandLoss.update(randPred, labelT)

        # `>= 0.5` maps every value to exactly 0 or 1; a `< 0.5` / `> 0.5` mask
        # pair would leave values equal to 0.5 untouched.
        RandUnitLoss.update((randPred >= 0.5).to(randPred.dtype), labelT)

In [None]:
def binary_iou(prediction, target):
    """Intersection-over-union of two boolean masks.

    Both arguments must support ``&`` and ``|`` (e.g. bool tensors of the same
    shape). The union is clamped to 1e-8, so two empty masks give 0 instead of a
    division by zero.
    """
    intersection = torch.sum(prediction & target).float()
    union = torch.clamp(torch.sum(prediction | target).float(), min=1e-8)
    return intersection / union

In [16]:
# Take only the first (prediction, label) pair for quick inspection.
predT, labelT = next(iter(genDataSets()))

In [17]:
# Largest value anywhere in the sample's prediction tensor.
torch.max(predT)

tensor(0.1157)

In [10]:
# Baseline: uniform random predictions thresholded at 0.5.
RandUnitLoss.value()

{'BCE': 49.99549454909105,
 'Dice': 0.9976705129329975,
 'MeanIoU': 0.0011661801209601646}

In [11]:
# Baseline: raw uniform random predictions from torch.rand.
RandLoss.value()

{'BCE': 0.9998995753434988, 'Dice': 0.9965128898620605, 'MeanIoU': 0.0}

In [12]:
# Baseline: every element predicted as 1.
# NOTE(review): BCE near 100 presumably comes from nn.BCELoss clamping log terms at -100.
UnitLoss.value()

{'BCE': 99.88337824894832,
 'Dice': 0.9976704808381888,
 'MeanIoU': 0.001166191807267471}

In [13]:
# Baseline: every element predicted as 0.
ZeroLoss.value()

{'BCE': 0.11661918060137676, 'Dice': 1.0, 'MeanIoU': 0.0}

In [14]:
# Predictions loaded from the *_grids_predictions.h5 files.
UnetLoss.value()

{'BCE': 0.0060122772657240815, 'Dice': 0.9599529137978187, 'MeanIoU': 0.0}

In [15]:
# Sanity check: labels scored against themselves (perfect prediction).
OracleLoss.value()

{'BCE': 0.0, 'Dice': 0.0, 'MeanIoU': 1.0}

## Scratch: export a single label grid to `.npy`

In [180]:
name = '1z95'

dstbase = '/home/lorenzo/3dunet-cavity/npy'

# NOTE(review): this points at run_210601, while `basepred` uses run_210601_local — confirm which run is intended.
predfname = f'/home/lorenzo/3dunet-cavity/runs/run_210601/predictions/{name}_grids_predictions.h5'
origfname = f'/home/lorenzo/deep_apbs/destData/pdbbind_v2013_core_set_0/{name}/{name}_grids.h5'


# Copy the label grid into memory; open read-only and close the HDF5 file immediately.
with h5py.File(origfname, 'r') as origfile:
    vlabel = origfile['label'][()]

# Make sure the destination directory exists before writing `<name>_label.npy`.
Path(dstbase).mkdir(parents=True, exist_ok=True)
with open(f'{dstbase}/{name}_label.npy', 'wb') as f:
    np.save(f, vlabel)

# To inspect the prediction grid instead: h5py.File(predfname, 'r')['predictions']


array([[[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        ...,
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]],

       [[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        ...,
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]],

       [[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        ...,
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]],

       ...,

       [[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        ...,
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]],

       [[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        ...,
        [0, 0, 0, ..., 