In [1]:
# Load the SegFormer components from the Hugging Face Hub:
#   - the image processor comes from the base NVIDIA ADE20k 512x512 checkpoint;
#   - the segmentation weights come from the fine-tuned checkpoint published
#     under doggywastaken/.
from transformers import SegformerForSemanticSegmentation, SegformerImageProcessor

processor = SegformerImageProcessor.from_pretrained(
    "nvidia/segformer-b0-finetuned-ade-512-512"
)
model = SegformerForSemanticSegmentation.from_pretrained(
    "doggywastaken/segformer-b0-finetuned-bmri-prep"
)



In [2]:
# Load the segmentation dataset from the Hugging Face Hub, split it into
# reproducible train/test subsets, and fetch the id <-> label mappings.

from datasets import load_dataset

SPLIT_SEED = 1337  # seeds both the shuffle and the train/test split

hf_dataset_identifier = "doggywastaken/manual_breast_segs"
ds = load_dataset(hf_dataset_identifier)

ds = ds.shuffle(seed=SPLIT_SEED)
# train_test_split shuffles by default; without an explicit seed it draws fresh
# OS entropy, so the split would change on every run despite the seeded shuffle.
ds = ds["train"].train_test_split(test_size=0.2, seed=SPLIT_SEED)
train_ds = ds["train"]
test_ds = ds["test"]

import json
from huggingface_hub import hf_hub_download

# NOTE(review): repo_id is not used by this cell (hf_hub_download below takes the
# bare identifier plus repo_type="dataset"); kept in case later cells rely on it.
repo_id = f"datasets/{hf_dataset_identifier}"
filename = "id2label.json"
id2label_path = hf_hub_download(repo_id=hf_dataset_identifier, filename=filename, repo_type="dataset")
# Context manager closes the file handle; JSON is UTF-8 regardless of platform default.
with open(id2label_path, "r", encoding="utf-8") as id2label_file:
    id2label = json.load(id2label_file)
# JSON object keys are always strings; convert them back to integer class ids.
id2label = {int(k): v for k, v in id2label.items()}
label2id = {v: k for k, v in id2label.items()}

num_labels = len(id2label)

In [3]:
# Import the support functions (hvl_tools, pyfunctions) from ../00_support_functions/
# used for notebook-based inference with the SegFormer model loaded above.
import sys

SUPPORT_FUNCTIONS_DIR = '../00_support_functions/'
# Guard the insert so re-running this cell does not keep prepending duplicates.
if SUPPORT_FUNCTIONS_DIR not in sys.path:
    sys.path.insert(0, SUPPORT_FUNCTIONS_DIR)
# NOTE(review): star imports hide which names come from which module
# (e.g. inference_on_nii_both, load_nii, bbox3_masked, imyshow, np);
# consider replacing them with explicit imports.
from hvl_tools import *
from pyfunctions import *

# Run inference on the reference NIfTI volume with the processor/model from cell 1.
ref_nii_path = r'../testdata/sub-001/ses-01/sub-001_ses-01_ref.nii'
pred_both = inference_on_nii_both(ref_nii_path, processor, model)


100%|██████████| 112/112 [00:16<00:00,  6.86it/s]
100%|██████████| 112/112 [00:16<00:00,  6.64it/s]


In [4]:
# Split the combined prediction into left and right halves, mask each half with
# its reference volume, and display the results. The reference volume is also
# reassembled into `ref` for use as an underlay.
import numpy as np  # np is not imported explicitly elsewhere; it otherwise only leaks in via star imports

ref_nii_path = r'../testdata/sub-001/ses-01/sub-001_ses-01_ref.nii'
ref_l = load_nii(ref_nii_path, 'left')
ref_r = load_nii(ref_nii_path, 'right')
# ref_r is flipped along axis 2 before concatenation; pred_r below is likewise
# flipped after being sliced out of the right half of pred_both.
ref = np.concatenate([ref_l, np.flip(ref_r, 2)], axis=2)

fullwidth = pred_both.shape[2]
midpoint = fullwidth // 2  # same as int(np.floor(fullwidth / 2)) for non-negative widths
pred_l = pred_both[:, :, :midpoint]
pred_r = np.flip(pred_both[:, :, midpoint:], 2)

# NOTE(review): meaning of bbox3_masked's third argument (margin/padding?) is not
# visible in this notebook; confirm against the support functions.
BBOX3_MASK_ARG = 15
mask_l = bbox3_masked(pred_l, ref_l, BBOX3_MASK_ARG)
mask_r = bbox3_masked(pred_r, ref_r, BBOX3_MASK_ARG)

imyshow(mask_l)
imyshow(mask_r)



(93, 129, 148)


interactive(children=(IntSlider(value=1, description='z', max=92, min=1), Output()), _dom_classes=('widget-int…

(83, 113, 146)


interactive(children=(IntSlider(value=1, description='z', max=82, min=1), Output()), _dom_classes=('widget-int…