In [19]:
from inference.inferenceSdk import SegModel
from network_training.model_restore import load_model_and_checkpoint_files_llm
from run.load_pretrained_weights import *
import json
import os
import numpy as np
from dataset.utils import nnUNet_resize
import torch
import random

os.environ["CUDA_VISIBLE_DEVICES"] = '0'
device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')

config = {
    'llm_folder':'/folder/store/llm_checkpoint',
    'seg_folder':'/folder/store/seg_checkpoint',
    'llm_chk':'llm_checkpoint_name',
    'seg_chk':'seg_checkpoint_name',
    'output_dir':'/folder/save/output/mask/and/report',
    'eval_mode':'region_segtool'
}

In [None]:
def get_trainer(config):
    trainer, params = load_model_and_checkpoint_files_llm(config['llm_folder'], mixed_precision=True,
                                                checkpoint_name=config['llm_chk'])
    trainer.load_checkpoint_ram(params[0], False)
    load_pretrained_weights(trainer.network, join(config['seg_folder'],config['seg_chk']+'.model'))

    return trainer

trainer = get_trainer(config)
segmodel = SegModel(config)

In [21]:
def setup_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True

print("emptying cuda cache")
torch.cuda.empty_cache()

setup_seed(42)
trainer.network.eval()
trainer.llm_model.eval()

emptying cuda cache


LanguageModel(
  (gpt_with_lm_head): GPT2LMHeadModel(
    (transformer): GPT2Model(
      (wte): Embedding(50257, 1024)
      (wpe): Embedding(1024, 1024)
      (drop): Dropout(p=0.1, inplace=False)
      (h): ModuleList(
        (0-23): 24 x GPT2Block(
          (ln_1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
          (attn): GPT2PseudoAttention(
            (c_attn): Conv1DWithTrainedWeights()
            (c_proj): Conv1DWithTrainedWeights()
            (attn_dropout): Dropout(p=0.1, inplace=False)
            (resid_dropout): Dropout(p=0.1, inplace=False)
            (uk): Linear(in_features=1024, out_features=1024, bias=True)
            (uv): Linear(in_features=1024, out_features=1024, bias=True)
          )
          (ln_2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
          (mlp): GPT2MLP(
            (c_fc): Conv1D()
            (c_proj): Conv1D()
            (act): NewGELUActivation()
            (dropout): Dropout(p=0.1, inplace=False)
          

In [22]:
test_file = [
    {
      "image": "path/to/image.nii.gz",
      "modal": "T2FLAIR"
    }
]

hammer_anas = json.load(open('utils_file/hammer_anas.json','r'))

In [None]:
list_of_lists = [[j['image']] for j in test_file]
list_of_ab_segs = [j['label'] if 'label' in j else None for j in test_file]
list_of_ana_segs = [j['label2'] if 'label2' in j else None for j in test_file]
list_of_reports = None if 'report' not in test_file[0] else [j['report'] for j in test_file]
modals = [j['modal'] for j in test_file]

list_of_ab_segs, list_of_ana_segs = segmodel.seg(list_of_lists, list_of_ab_segs, list_of_ana_segs, modals)

In [24]:
pred_report = []
for i, the_image_path in enumerate(list_of_lists):
    the_ab_seg_path = list_of_ab_segs[i] if list_of_ab_segs is not None else None
    the_ana_seg_path = list_of_ana_segs[i] if list_of_ana_segs is not None else None

    if the_ab_seg_path is not None:
        d, s_ab, dct = trainer.preprocess_patient(the_image_path, the_ab_seg_path, target_shape=None)
    else:
        s_ab = None
    
    if the_ana_seg_path is not None:
        d, s_ana, dct = trainer.preprocess_patient(the_image_path, the_ana_seg_path, target_shape=None)
    else:
        s_ana = None
    
    modal = modals[i]

    d = np.expand_dims(nnUNet_resize(d[0],trainer.patch_size,axis=0),axis=0)
    s_ab = nnUNet_resize(s_ab[0], trainer.patch_size,is_seg=True,axis=0) if s_ab is not None else np.zeros(trainer.patch_size)
    s_ab = np.expand_dims(s_ab, axis=0)
    s_ana = nnUNet_resize(s_ana[0], trainer.patch_size,is_seg=True,axis=0) if s_ana is not None else np.zeros(trainer.patch_size)
    s_ana = np.expand_dims(s_ana, axis=0)
    s = np.concatenate((s_ana,s_ab),axis=0)

    region_features, region_direction_names  = trainer.predict_preprocessed_data_return_region_report(
        d, s, None, do_mirroring=False, mirror_axes=trainer.data_aug_params['mirror_axes'], use_sliding_window=True,
        step_size=0.5, use_gaussian=True, all_in_gpu=False,
        mixed_precision=True, modal=modal, eval_mode=config['eval_mode'])
    
    region_features = torch.tensor(np.array([item.cpu().detach().numpy() for item in region_features]), dtype=torch.float32).to(trainer.llm_model.device)
    
    output = trainer.llm_model.generate(
                    region_features,
                    max_length=300,
                    num_beams=1,
                    num_beam_groups=1,
                    do_sample=False,
                    num_return_sequences = 1,
                    early_stopping=True
            )
    del region_features

    generated_sents_for_selected_regions = trainer.tokenizer.batch_decode(
            output, skip_special_tokens=True, clean_up_tokenization_spaces=True
    )

    ## the index of the sentence mask ##
    if config['eval_mode'] == "region_segtool":
        pred_global_report = generated_sents_for_selected_regions[-1]

        pred_region_concat_report = []
        for cur_idx, se in enumerate(generated_sents_for_selected_regions[:-1]):
            # no anatomy is mentioned in the sentence
            ana_flag = False
            for cur_a in hammer_anas:
                if cur_a in se.lower():
                    ana_flag=True
                    break
            # anatomy is mentioned in the sentence
            if not ana_flag:
                sort_ana = sorted(region_direction_names[cur_idx][1],key =lambda x:-x[0])
                most_pixel = sort_ana[0][0]
                sort_ana = list(filter(lambda x:x[0]>=most_pixel, sort_ana))
                ana_str = ' and '.join([item[1] for item in sort_ana])
                se = se.strip()
                se = se[:-1] + ' in '+ana_str+'.' if se[-1] == '.' or se[-1] == ',' else se + ' in '+ana_str+'.'
            else:
                if region_direction_names[cur_idx][0] == "left":
                    se = se.replace('right','left')
                elif region_direction_names[cur_idx][0] == "right":
                    se = se.replace('left','right')
            pred_region_concat_report.append(se)
            
        pred_region_concat_report = " ".join(pred_region_concat_report)

        left_sentence = ""

        pred_split = pred_global_report.split('.')
        pred_split_2 = []
        for se in pred_split:
            pred_split_2.extend(se.split(','))
        pred_split = list(map(lambda x:x+'.',pred_split_2))
        
        if 'ventricle' not in pred_region_concat_report.lower() and 'ventricle' not in left_sentence.lower():
            left_sentence = left_sentence+" ".join([g for g in pred_split if 'ventricle' in g.lower()])            
        if 'midline' not in pred_region_concat_report.lower() and 'midline' not in left_sentence.lower():
            left_sentence = left_sentence+" "+" ".join([g for g in pred_split if 'midline' in g.lower()])
        if 'sulci' not in pred_region_concat_report.lower() and 'midline' not in left_sentence.lower():
            left_sentence = left_sentence+" "+" ".join([g for g in pred_split if 'sulci' in g.lower()])
        
        if 'midline' not in pred_region_concat_report.lower() and 'midline' not in left_sentence.lower():
            left_sentence += " No midline shift."
            
        pred_region_concat_report +=" "+left_sentence
        
        pred_report.append({'image':the_image_path,'pred_report':pred_region_concat_report,'ab_mask':the_ab_seg_path,'ana_mask':the_ana_seg_path})

    elif config['eval_mode'] == "given_mask":
        pred_region_concat_report = " ".join(generated_sents_for_selected_regions)
        pred_report.append({'image':the_image_path,'pred_report':pred_region_concat_report,'ab_mask':the_ab_seg_path,'ana_mask':the_ana_seg_path})

before crop: (1, 155, 240, 240) after crop: (1, 146, 171, 136) spacing: [1. 1. 1.] 

no resampling necessary
no resampling necessary
before: {'spacing': array([1., 1., 1.]), 'spacing_transposed': array([1., 1., 1.]), 'data.shape (data is transposed)': (1, 146, 171, 136)} 
after:  {'spacing': array([1., 1., 1.]), 'data.shape (data is resampled)': (1, 146, 171, 136)} 

before crop: (1, 155, 240, 240) after crop: (1, 146, 171, 136) spacing: [1. 1. 1.] 

no resampling necessary
no resampling necessary
before: {'spacing': array([1., 1., 1.]), 'spacing_transposed': array([1., 1., 1.]), 'data.shape (data is transposed)': (1, 146, 171, 136)} 
after:  {'spacing': array([1., 1., 1.]), 'data.shape (data is resampled)': (1, 146, 171, 136)} 

debug: mirroring False mirror_axes (0, 1, 2)
do mirror: False


In [None]:

print(pred_report)
with open(os.path.join(config['output_dir'],'pred_report.json'), 'w') as f:
    json.dump(pred_report, f, indent=4)