In [None]:
import json
import os
import nrrd
from tqdm import tqdm
import numpy as np
import sys
from pathlib import Path
sys.path.append('..')
from data_utils.helpers_classification import get_pos_coords, get_neg_coords, get_nearly_pos_coords, get_vol_hard_mask, get_hard_neg_coords, get_vol_paths, normalize_vols

In [None]:
# Load the classification dataset metadata (global fields plus per-volume 'vol_meta').
with open('../dataset/classification/dataset.json', 'r') as meta_file:
    meta = json.load(meta_file)

In [None]:
# Start the test-set metadata from a shallow copy of the global fields only;
# per-volume entries are filled in after sampling.
test_meta = dict(meta)
test_meta.pop('vol_meta', None)
test_meta['N'] = 20_000

In [None]:
# Keep only the volumes assigned to the validation split (values are the same dict objects as in meta).
valid_meta = dict(
    item for item in meta['vol_meta'].items() if item[1]['split'] == 'valid'
)

In [None]:
# Replace each validation volume's patch list with the set of its coordinates (label column
# dropped) for fast membership tests while sampling the test set.
# The dicts are rebuilt from `meta` rather than mutated in place: `valid_meta` shares its
# dicts with `meta['vol_meta']`, and the old in-place version both altered `meta` and
# stripped one more coordinate every time the cell was re-run.
valid_meta = {
    vol_id: {
        **vol_meta,
        'patches': {tuple(coords[:-1]) for coords in meta['vol_meta'][vol_id]['patches']},
    }
    for vol_id, vol_meta in valid_meta.items()
}

In [None]:
# Enumerate the raw ASOCA volumes; each entry starts with the volume id and is later
# unpacked as (vol_id, vol_path, targ_path, heart_mask_path).
vol_paths = get_vol_paths(
    '../dataset/raw/ASOCA2020Data/',
)

In [None]:
# Point each entry's volume path at the preprocessed .npy file in the volume's split
# directory, keeping the id and the remaining (target / heart-mask) paths unchanged.
vol_paths = [
    (vid, Path(f'../dataset/classification/{meta["vol_meta"][str(vid)]["split"]}/vols/{vid}.npy'), *rest)
    for vid, _, *rest in vol_paths
]

In [None]:
# Total number of test patches to sample across all validation volumes.
n_patches = 20_000
# Patch edge length passed to the sampling helpers (presumably in voxels — confirm).
patch_size = 68

In [None]:
from typing import Callable
from functools import partial

In [None]:
def sample_new(sample_fn: Callable, n_samples: int, already_sampled: set, max_iters: int = 10_000) -> np.ndarray:
    """Collect ``n_samples`` labelled coordinates whose position is not in ``already_sampled``.

    ``sample_fn`` is called repeatedly with no arguments. Each call must return a 2-D
    array-like whose rows are ``[*coords, label]``. Rows whose coordinate tuple is in
    ``already_sampled``, or was already collected, are discarded until ``n_samples``
    distinct new coordinates are gathered. Each kept row retains its own label.

    Args:
        sample_fn: zero-argument sampler returning rows of ``[*coords, label]``.
        n_samples: number of distinct new coordinates to return.
        already_sampled: coordinate tuples (without label) that must be excluded.
        max_iters: maximum number of ``sample_fn`` calls before giving up.

    Returns:
        Integer array whose rows are ``[*coords, label]``, at most ``n_samples`` rows.

    Raises:
        RuntimeError: if ``max_iters`` calls do not produce enough new coordinates
            (previously this looped forever).
    """
    collected = {}  # coordinate tuple -> label; dict preserves first-seen order
    n_calls = 0
    while len(collected) < n_samples:
        if n_calls >= max_iters:
            raise RuntimeError(
                f'sample_new: only {len(collected)}/{n_samples} new coordinates '
                f'after {max_iters} calls to the sampler'
            )
        n_calls += 1
        batch = np.asarray(sample_fn())
        if batch.size == 0:
            # An empty draw has no label to read (the old code raised IndexError); retry.
            continue
        for row in batch.tolist():
            coord = tuple(row[:-1])
            if coord not in already_sampled and coord not in collected:
                collected[coord] = row[-1]
    rows = [[*coord, label] for coord, label in collected.items()]
    return np.array(rows)[:n_samples].astype(int)

In [None]:
# Sample a held-out test set of labelled patch coordinates for every validation volume.
# Seed NumPy's global RNG so the test set is reproducible; np.random.permutation below uses it.
# NOTE(review): presumably the data_utils samplers draw from the same global RNG — confirm.
np.random.seed(42)

assert n_patches % 4 == 0
res = {}
n_per_vol = n_patches // len(valid_meta)
# Later cells assume exactly n_patches samples in total, so the per-volume split must be even.
assert n_per_vol * len(valid_meta) == n_patches, (n_patches, len(valid_meta))

# Per-volume quotas: 50% positives, 10% random negatives, 20% near-positive negatives, and
# hard negatives take the remainder (nominally 20%). Integer division of every share could
# otherwise leave the quotas summing to less than n_per_vol.
n_pos = n_per_vol // 2
n_neg_rand = n_per_vol // 10
n_near_pos = 2 * n_per_vol // 10
n_hard_neg = n_per_vol - n_pos - n_neg_rand - n_near_pos

for vol_id, vol_path, targ_path, heart_mask_path in tqdm(vol_paths):
    if str(vol_id) not in valid_meta:
        continue
    vol = np.load(vol_path)
    targs, _ = nrrd.read(targ_path, index_order='C')
    heart_mask, _ = nrrd.read(heart_mask_path, index_order='C')

    targs = targs.astype(np.uint8)
    heart_mask = heart_mask.astype(np.uint8)

    # Coordinates to avoid: the validation patches plus everything already drawn for this
    # volume, so the categories cannot collide with each other. It is a copy, so
    # valid_meta is never modified.
    taken = set(valid_meta[str(vol_id)]['patches'])

    pos_coords = sample_new(
        partial(get_pos_coords, targs, patch_size, n_samples=n_pos),
        n_samples=n_pos,
        already_sampled=taken,
    )
    taken.update(tuple(row[:-1]) for row in pos_coords.tolist())

    neg_coords = sample_new(
        partial(get_neg_coords, targs, heart_mask, patch_size, n_neg_rand),
        n_samples=n_neg_rand,
        already_sampled=taken,
    )
    taken.update(tuple(row[:-1]) for row in neg_coords.tolist())

    neg_near_pos_coords = sample_new(
        partial(get_nearly_pos_coords, targs, patch_size, n_near_pos, offset=8),
        n_samples=n_near_pos,
        already_sampled=taken,
    )
    taken.update(tuple(row[:-1]) for row in neg_near_pos_coords.tolist())

    vol_hard_mask = get_vol_hard_mask(vol, targs, heart_mask)
    # NOTE(review): assumes get_hard_neg_coords excludes the coordinates in the set it is given.
    hard_neg_coords = get_hard_neg_coords(vol_hard_mask, targs, taken, patch_size, n_hard_neg)

    coords = np.vstack((pos_coords, neg_near_pos_coords, neg_coords, hard_neg_coords)).astype(int)
    # Fail fast here instead of in the downstream fixed-size (n_patches) checks.
    assert len(coords) == n_per_vol, (vol_id, len(coords), n_per_vol)
    coords = np.random.permutation(coords)
    res[vol_id] = coords

In [None]:
# Sanity check: no newly sampled coordinate may coincide with an existing validation patch.
for vid, sampled_coords in res.items():
    new_positions = {tuple(row[:-1]) for row in sampled_coords}
    assert new_positions.isdisjoint(valid_meta[str(vid)]['patches']), vid

In [None]:
# Compact summary instead of dumping every coordinate array: (n_rows, n_columns) per volume.
{vol_id: coords.shape for vol_id, coords in res.items()}

In [None]:
# One metadata entry per sampled volume; the test patches are all drawn from 'valid' volumes.
test_meta['vol_meta'] = {
    vid: dict(split='valid', n_patches=len(coords), patches=coords.tolist())
    for vid, coords in res.items()
}

In [None]:
# Persist the generated test-set metadata next to dataset.json.
test_meta_path = '../dataset/classification/dataset_test.json'
with open(test_meta_path, 'w') as out_file:
    json.dump(test_meta, out_file)

In [None]:
from data_utils.datamodule import AsocaClassificationDataset

In [None]:
# Read the test metadata back from disk to confirm it round-trips through JSON.
with open('../dataset/classification/dataset_test.json', 'r') as in_file:
    test_meta = json.load(in_file)

In [None]:
# Load the 'valid' split described by dataset_test.json.
ds = AsocaClassificationDataset(
    split='valid',
    ds_path='../dataset/classification',
    meta_fname='dataset_test.json',
)

In [None]:
# Inspect the dataset's file ids (presumably the volumes loaded for this split — confirm).
ds.file_ids

In [None]:
# Reload both metadata files from disk so the comparison below reflects what is saved.
with open('../dataset/classification/dataset.json', 'r') as valid_fh:
    meta = json.load(valid_fh)
with open('../dataset/classification/dataset_test.json', 'r') as test_fh:
    meta_test = json.load(test_fh)

In [None]:
from collections import Counter

In [None]:
# Per volume: recorded patch count, number of distinct coordinates, and label histogram.
for vid, info in meta['vol_meta'].items():
    patches = info['patches']
    n_unique = len({tuple(p[:-1]) for p in patches})
    label_counts = Counter(p[-1] for p in patches)
    print(vid, info['n_patches'], n_unique, label_counts)

In [None]:
# Per test volume: recorded patch count, distinct coordinates, overlap with the validation
# patches of the same volume (expected 0), and label histogram.
for vid, test_info in meta_test['vol_meta'].items():
    valid_coords = {tuple(row[:-1]) for row in meta['vol_meta'][vid]['patches']}
    test_coords = {tuple(row[:-1]) for row in test_info['patches']}
    n_overlap = len(test_coords & valid_coords)
    label_counts = Counter(row[-1] for row in test_info['patches'])
    print(vid, test_info['n_patches'], len(test_coords), n_overlap, label_counts)

In [None]:
import wandb
import torch
import shutil
import sys
sys.path.append('..')
from train import get_class
from data_utils.datamodule import AsocaClassificationDataModule
import pytorch_lightning as plt


data_dir = '../dataset/classification'
# Where checkpoints are downloaded. Override with the MODEL_DIR environment variable instead
# of editing this user-specific absolute path.
model_dir = os.environ.get('MODEL_DIR', '/var/scratch/ebekkers/damyan/models')
# Classification runs from the ASOCA_final project trained with any of the three seeds.
runs = wandb.Api().runs(path='ASOCA_final', filters={
    'config.seed': {'$in': [0, 11, 42]},
    'config.model/model': {"$regex": "models.classification.*"},
})

dm = AsocaClassificationDataModule(data_dir=data_dir)


In [None]:
# Load one trained run's checkpoint from W&B and predict on the data module's test loader.
run = [run for run in runs if run.name == 'gallant-sun-765'][0]
# Constructor kwargs are the config entries whose key contains 'model', keyed by the last
# '/'-separated component; the 'model' entry itself names the class to build.
model_params = {k.split('/')[-1]: v for k, v in run.config.items() if 'model' in k}
class_name = model_params.pop('model')
if 'initialize' in model_params:
    model_params['initialize'] = False  # weights are restored from the checkpoint below

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
bs = 32

# NOTE(review): assumes the run's first file is the checkpoint — confirm.
with run.files()[0].download(model_dir, replace=True) as model_f:
    model: plt.LightningModule = get_class(class_name)(**model_params)
    with open(model_f.name, 'rb') as f:
        # map_location lets a checkpoint saved on GPU load regardless of the current device.
        ckpt = torch.load(f, map_location='cpu')
    state_dict = {k: v for k, v in ckpt['state_dict'].items() if 'in_indices_' not in k}
    model.load_state_dict(state_dict, strict=False)
    # Put the model on the same device as the inputs (previously only inputs were moved to
    # CUDA) and switch to inference mode (affects dropout / batch-norm layers, if any).
    model = model.to(device).eval()

# Accumulate per-batch results instead of writing into fixed-size torch.empty buffers,
# which left uninitialised values whenever the dataset size differed from 20000.
pred_chunks, targ_chunks = [], []
with torch.no_grad():
    # The loader yields (inputs, targets) batches; enumerate() is not needed.
    for x, targs in dm.test_dataloader(batch_size=bs):
        preds = model(x.to(device))
        # One score per sample is expected; reshape raises if the model emits more.
        pred_chunks.append(preds.cpu().reshape(x.shape[0]).float())
        targ_chunks.append(torch.as_tensor(targs).reshape(-1).float())
all_preds = torch.cat(pred_chunks)
all_targs = torch.cat(targ_chunks)
        

In [None]:
# Remove this run's directory under model_dir (presumably where the checkpoint landed — confirm).
run_ckpt_dir = f'{model_dir}/{run.project}/{run.id}'
shutil.rmtree(run_ckpt_dir)

In [None]:
import numpy as np

In [None]:
# Saved predictions/targets as a plain numeric array; pickled objects are refused explicitly.
mp = np.load('../manual_preds.npy', allow_pickle=False)

In [None]:
# Row 0 of the saved array holds the predictions, row 1 the targets.
preds = mp[0]
targs = mp[1]

In [None]:
# Binarize predictions at 0.5. Unlike .round(), this always yields 0/1 (so the confusion
# counts below cover every sample) and maps an exact 0.5 to the positive class instead of
# rounding half to even. Re-running the cell leaves 0/1 values unchanged.
# NOTE(review): assumes preds are probabilities (post-sigmoid), not logits — confirm.
preds = (preds >= 0.5).astype(preds.dtype)

In [None]:

# Fraction of predictions that match the targets.
acc = np.mean(preds == targs)

In [None]:
# Display overall accuracy (fraction of predictions equal to the targets).
acc

In [None]:
# Confusion-matrix counts for the positive class (label 1).
correct = preds == targs
pred_pos = preds == 1.
pred_neg = preds == 0.
tp = np.sum(correct & pred_pos)
tn = np.sum(correct & pred_neg)
fp = np.sum(~correct & pred_pos)
fn = np.sum(~correct & pred_neg)

In [None]:
# Every prediction must fall into exactly one confusion-matrix cell. A shortfall means preds
# contains values other than 0 and 1. This checks against the actual number of predictions
# instead of a hard-coded 20000.
assert tp + tn + fp + fn == preds.size, (tp, tn, fp, fn, preds.size)

In [None]:
# Display the confusion-matrix counts: (true positives, true negatives, false positives, false negatives).
tp, tn, fp, fn

In [None]:
# Precision and recall for the positive class. An empty denominator (no predicted or no
# actual positives) yields 0.0 instead of NaN with a RuntimeWarning, matching
# scikit-learn's zero_division=0 convention.
precision = tp / (tp + fp) if tp + fp > 0 else 0.0
recall = tp / (tp + fn) if tp + fn > 0 else 0.0

In [None]:
# Display precision and recall for the positive class (label 1).
precision, recall

In [None]:
# F1 = harmonic mean of precision and recall; defined as 0.0 when both are zero (or NaN)
# instead of dividing by zero.
f1 = 2 * precision * recall / (precision + recall) if precision + recall > 0 else 0.0

In [None]:
# Display the F1 score (harmonic mean of precision and recall).
f1