In [1]:
from fastai.vision.all import *

In [2]:
import pandas as pd

In [3]:
from tqdm.notebook import tqdm

In [4]:
def label_func(x):
    """Return the mask path that corresponds to the record `x`.

    `x['file']` is converted to a string and every occurrence of the
    substring "images" is replaced by "../../data/data/masks_2class".
    """
    image_path = str(x['file'])
    return image_path.replace("images", "../../data/data/masks_2class")

In [5]:
def acc_seg(input, target):
    """Accuracy of the dim-1 argmax of `input` against `target`.

    `target` is squeezed along dimension 1 before the comparison; the result
    is the mean of the element-wise equality mask as a float tensor.
    """
    labels = target.squeeze(1)
    predicted = input.argmax(dim=1)
    hits = predicted == labels
    return hits.float().mean()

def multi_dice(input:Tensor, targs:Tensor, class_id=0, inverse=False):
    """Mean per-sample Dice score for one class of a multi-class segmentation.

    Args:
        input: model output; the predicted class is taken as the argmax over
            dimension 1.
        targs: target labels whose first dimension is the batch; all remaining
            dimensions are flattened per sample.
        class_id: label that is treated as foreground (everything else is
            background) to reduce the problem to the binary case.
        inverse: if True, swap foreground and background, i.e. score every
            pixel that is NOT `class_id`.

    Returns:
        A scalar tensor: the Dice score averaged over the batch. A sample whose
        prediction and target are both empty (0/0) counts as a perfect score of 1.
    """
    n = targs.shape[0]
    # reshape (rather than view) so that non-contiguous targets, e.g. permuted
    # or sliced masks, are flattened instead of raising a RuntimeError.
    pred_mask = (input.argmax(dim=1).reshape(n, -1) == class_id).float()
    targ_mask = (targs.reshape(n, -1) == class_id).float()
    if inverse:
        pred_mask = 1 - pred_mask
        targ_mask = 1 - targ_mask
    intersect = (pred_mask * targ_mask).sum(dim=1)
    # "union" here is |pred| + |target|, the denominator of the Dice formula.
    union = (pred_mask + targ_mask).sum(dim=1)
    res = 2. * intersect / union
    # 0/0 happens when the class is absent from both prediction and target.
    res[torch.isnan(res)] = 1
    return res.mean()

def diceComb(input:Tensor, targs:Tensor):
    """Dice score over all pixels that are not class 0 (inverse of class 0)."""
    return multi_dice(input, targs, inverse=True, class_id=0)
def diceLV(input:Tensor, targs:Tensor):
    """Dice score for class 1 (presumably the left ventricle -- confirm against the mask encoding)."""
    return multi_dice(input, targs, 1)
def diceMY(input:Tensor, targs:Tensor):
    """Dice score for class 2 (presumably the myocardium -- confirm against the mask encoding)."""
    score = multi_dice(input, targs, class_id=2)
    return score

In [6]:
# Table of the images to score: tab-separated, no header row, columns "pid" and "file".
fullImgList = pd.read_csv(
    "../../analysis/kaggle/image_list.tsv.xz",
    sep="\t",
    names=["pid", "file"],
    header=None,
)

In [7]:
# NOTE(review): load_learner restores a pickled Learner, so only load export
# files from a trusted source. The functions defined in the cells above are
# presumably needed in the namespace for unpickling -- confirm before removing them.
trainedModel = load_learner(
    "fastai2/model-unfrozen-30.pkl",
    cpu=False,
)

In [9]:
# Number of image-list rows passed to the model per get_preds call.
BATCH_SIZE = 10000

# Collect one frame per batch and concatenate once at the end; growing a
# DataFrame with pd.concat inside the loop re-copies the whole table each time.
pixelFrames = []
# Stepping by BATCH_SIZE never yields an empty trailing batch, which the
# int(n/10000)+1 form produced whenever n was an exact multiple of 10000.
for start in tqdm(range(0, fullImgList.shape[0], BATCH_SIZE)):
    imgInBatch = fullImgList.iloc[start:start + BATCH_SIZE]
    testDL = trainedModel.dls.test_dl(imgInBatch)
    predictions, _ = trainedModel.get_preds(dl=testDL)
    # Per-pixel predicted class: argmax over dimension 1 of the model output.
    predictions = predictions.argmax(dim=1)
    # Count the pixels predicted as class 1 and class 2 for each image.
    lv_pixels = (predictions == 1).sum(dim=(1, 2))
    my_pixels = (predictions == 2).sum(dim=(1, 2))
    pixelFrames.append(pd.DataFrame({
        'file': testDL.items['file'],
        # tolist() hands pandas plain Python ints instead of a torch tensor.
        'lv_pixels': lv_pixels.tolist(),
        'my_pixels': my_pixels.tolist(),
    }))

if pixelFrames:
    pixelTable = pd.concat(pixelFrames)
else:
    # Empty image list: still write a file with the expected header.
    pixelTable = pd.DataFrame({'file': [], 'lv_pixels': [], 'my_pixels': []})

pixelTable.to_csv("fastai2/predictions.tsv",sep="\t",index=False)

  0%|          | 0/44 [00:00<?, ?it/s]