###### Load libraries and directories

In [None]:
from IPython.display import display, HTML

In [None]:
# from IPython import get_ipython
from tqdm.notebook import tqdm
import pickle
import os
import pprint
pp = pprint.PrettyPrinter(indent=1)

# Custom modules for debugging
from SliceViewer import ImageSliceViewer3D, ImageSliceViewer3D_1view,ImageSliceViewer3D_2views
from investigate import *

#pd.set_option("display.max_rows", 10)
      
import json
from run_sma_experiment import find_l3_images,output_images
import pprint
from L3_finder import *

# Custom functions
import pickle
def save_object(obj, filename):
    """Pickle *obj* to *filename*, replacing any file already there.

    The highest pickle protocol supported by this interpreter is used.
    """
    protocol = pickle.HIGHEST_PROTOCOL
    with open(filename, mode='wb') as fh:
        pickle.dump(obj, fh, protocol)

def load_object(filename):
    """Unpickle and return the object stored in *filename*.

    NOTE: ``pickle.load`` can execute arbitrary code while loading. Only use
    this on pickles this pipeline wrote itself (e.g. via ``save_object``),
    never on files from an untrusted source.
    """
    with open(filename, 'rb') as fh:  # 'fh' avoids shadowing the builtin input()
        return pickle.load(fh)

In [None]:
# Run IPython's `%tb` line magic (presumably to re-show the last exception's traceback).
get_ipython().run_line_magic('tb', '')

In [None]:
# Working directory and fixed container paths used by the rest of the notebook.
cwd = os.getcwd()
print(cwd)
data = '/tf/data'
# Checkpoints of this run are saved to / loaded from here.
pickles = '/tf/pickles/v5_8pts'
# Older pickle location (used by the commented-out debug cell).
pickles_old = '/tf/pickles/'
models = '/tf/models'

In [None]:
# Load the JSON experiment config (path is relative to the working directory) and show it.
configfile = os.path.join(cwd,'config/debug_ES/v5_run_prediction_CV_poorl3.json')
with open(configfile, "r") as f:
        config = json.load(f)
pp.pprint(config)

## Section 1 - Process final images

## Section 2 - Load each study into subject object
<br>
The `Subject` object is defined in `l3finder.ingest`.

In [None]:
# Display the flag passed as new_tim_dir_structure to find_subjects in the next cell.
config['l3_finder']['new_tim_dicom_dir_structure']

In [None]:
# Debug
# Collect every subject found under the configured DICOM directory.
print("Finding subjects")

subjects = list(
    find_subjects(
        config['l3_finder']["dicom_dir"],
        new_tim_dir_structure=config['l3_finder']['new_tim_dicom_dir_structure']
    )
)

print('Subjects found: ', len(subjects))

## Section 3 - Check whether any subject has multiple folders (studies)

In [None]:
%%time
# Find series images
# Call find_series() on every subject (tqdm shows per-subject progress);
# flatten presumably concatenates the per-subject results into one list.
print("Finding series")
series = list(flatten(tqdm((s.find_series() for s in subjects),total=len(subjects))))

In [None]:
# Total number of series across all subjects.
print("Total number of series found: ", len(series))

In [None]:
%%time
# Split the series into sagittal, axial and excluded groups.
sagittal_series, axial_series, excluded_series = separate_series(series)

In [None]:
# Summary counts ("valid pats" is simply the number of subjects found).
print("Length of valid pats: ", len(subjects))
print("Length of sagittal series", len(sagittal_series))
print("Length of axial series", len(axial_series))
print("Length of excluded series", len(excluded_series))
print("Length of all series in dataset", len(series))

### Make sure each subject has at most 1 axial and 1 sagittal series

In [None]:
# Owning subject ID of every axial / sagittal series (one entry per series).
ax_ids = [ax.subject.id_ for ax in axial_series]
sag_ids = [sag.subject.id_ for sag in sagittal_series]

def find_duplicates(id_list):
    """Split *id_list* into first occurrences and repeat occurrences.

    Args:
        id_list: Iterable of hashable IDs.

    Returns:
        ``(uniques, duplicates)``. ``uniques`` holds each ID once, in order of
        first appearance. ``duplicates`` holds every later occurrence, so an
        ID seen k times appears k-1 times in it.
    """
    seen = set()  # O(1) membership instead of rescanning the uniques list
    uniques = []
    duplicates = []
    for id_ in id_list:
        if id_ in seen:
            duplicates.append(id_)
        else:
            seen.add(id_)
            uniques.append(id_)
    return uniques, duplicates


# Subjects listed in ax_d / sag_d own more than one axial / sagittal series.
ax_u,ax_d = find_duplicates(ax_ids)
sag_u,sag_d = find_duplicates(sag_ids)

print('Ax duplicates: ', ax_d)
print('Sag duplicates: ', sag_d)

In [None]:
# Find the series objects to investigate: every axial series of a subject with
# duplicate axials, and every sagittal series of a subject with duplicate
# sagittals (previously the sagittals were filtered by the *axial* duplicate IDs).
ax_d_series = [ax for ax in axial_series if ax.subject.id_ in ax_d]
sag_d_series = [sag for sag in sagittal_series if sag.subject.id_ in sag_d]

In [None]:
# Number of series belonging to subjects with duplicates.
print('axials with duplicate: ',len(ax_d_series))
print('sagittals with duplicate: ',len(sag_d_series))

In [None]:
# debug
# df_dl= load_object(os.path.join(pickles_old,'df_final.pkl'))
# display(df_dl[df_dl['ID']==ax_d[0]])

## Reconstruct Missing Sagittals

In [None]:
# By default code filters 0.5mm slices, but I am letting them pass by setting it to 0
# Build sagittal series for subjects that have none (presumably reconstructed
# from their axial series - confirm in construct_series_for_subjects_without_sagittals).
constructed_sagittals = construct_series_for_subjects_without_sagittals(
        subjects, sagittal_series, axial_series,thickness_filter=0) 

In [None]:
# Summary of the series split, including the reconstructed sagittals.
print(
        "Series separated\n",
        len(sagittal_series), "sagittal series.",
        len(axial_series), "axial series.",
        len(excluded_series), "excluded series.",
        len(constructed_sagittals), "constructed series.",
    )

In [None]:
# From here on, constructed sagittals are handled like native ones.
sagittal_series.extend(constructed_sagittals)

In [None]:
# Checkpoint the curated series so later sections can restart from here.
save_object(axial_series,os.path.join(pickles,'axial_curated.pkl'))
save_object(sagittal_series,os.path.join(pickles,'sagittal_curated.pkl'))

## Create MIPS

In [None]:
# Create one sagittal MIP per sagittal series. Caching is controlled by the
# optional config keys cache_dir (default None) and cache_intermediate_results (default False).
print("Creating sagittal MIPS")
mips = create_sagittal_mips_from_series(
        many_series=sagittal_series,
        cache_dir=config['l3_finder'].get("cache_dir", None),
        cache=config['l3_finder'].get("cache_intermediate_results", False),
    )

In [None]:
# Checkpoint the MIPs.
save_object(mips,os.path.join(pickles,'mips.pkl'))

## Find L3

In [None]:
# Restart point: reload the MIP checkpoint.
mips = load_object(os.path.join(pickles,'mips.pkl'))

In [None]:
print("Preprocessing Images")
preprocessed_images = preprocess_images(mips)

# TODO: the SagittalMIP wrapper is redundant here; the preprocessed images could be used directly.
sagittal_mips = [SagittalMIP(i) for i in preprocessed_images]

print("Separate heights for better batching")
# Group the MIPs by image dimension; each group is later predicted with that dimension as its input shape.
mips_by_dimension = group_mips_by_dimension(sagittal_mips)
print("Dimensions in set:", mips_by_dimension.keys())

In [None]:
# Checkpoint the dimension-grouped MIPs.
save_object(mips_by_dimension,os.path.join(pickles,'mips_by_dimension.pkl'))

## Find L3 - step 2

In [None]:
# Restart point: reload the checkpoints written earlier in the notebook.
mips_by_dimension = load_object(os.path.join(pickles,'mips_by_dimension.pkl'))
axial_series = load_object(os.path.join(pickles,'axial_curated.pkl'))
sagittal_series = load_object(os.path.join(pickles,'sagittal_curated.pkl'))

In [None]:
# Directory holding the L3-finder fold models (from config).
models_dir = config['l3_finder']['model_path_dir']

# Every *.h5 file in the models dir, sorted by file name; one CV fold per model.
models_list = sorted([f for f in os.listdir(models_dir) if f.endswith('.h5')])
print(models_list)
folds = len(models_list)

# Show the model path used for each fold.
for fold in range(folds):
    model_path = os.path.join(models_dir,models_list[fold])
    print(model_path)

In [None]:
runname = 'CV_poorl3'
if __name__ == "__main__":
    # One prediction pass per fold model; results and errors are pickled per fold.
    for fold in range(folds):
        model_path = os.path.join(models_dir,models_list[fold])
        print("Making predictions for fold ", fold, 'Path: ', model_path)
        prediction_results = []
        prediction_errors = []
        # Predict each dimension group with that dimension as the input shape.
        # NOTE: this loop rebinds the notebook-global `sagittal_mips`.
        for dimension, sagittal_mips in mips_by_dimension.items():
            dim_group_results,errors = make_predictions_for_sagittal_mips(
                sagittal_mips,
                model_path=model_path,
                shape=dimension
            )
            prediction_results.extend(dim_group_results)
            prediction_errors.extend(errors)

        # Save prediction results as prediction_{results,errors}_<fold>_<runname>.pkl
        pred_results_file = 'prediction_results_' + str(fold) + '_' +  runname + '.pkl'
        pred_errors_file = 'prediction_errors_' + str(fold) + '_' +  runname + '.pkl'
        save_object(prediction_results,os.path.join(pickles,pred_results_file))
        save_object(prediction_errors,os.path.join(pickles,pred_errors_file))

### Save l3 prediction results

In [None]:
# Reload the curated axial series (needed to build the L3 images).
axial_series = load_object(os.path.join(pickles,'axial_curated.pkl'))

In [None]:
# Sanity check: number of curated axial series.
len(axial_series)

In [None]:
# Load the prediction_results pickles of this run and write each fold's L3
# images to <output_directory>/<fold>.
import os

runname = 'CV_poorl3'
pred_prefix = 'prediction_results_'
pred_suffix = runname + '.pkl'

output_dir = config["l3_finder"]["output_directory"]
os.makedirs(output_dir, exist_ok=True)


def _pred_fold_sort_key(filename):
    """Sort key for prediction_results_<fold>_<runname>.pkl: numeric fold first."""
    head = filename[len(pred_prefix):].split('_', 1)[0]
    # Unparseable names sort after all numbered folds, by name.
    return (0, int(head), filename) if head.isdigit() else (1, 0, filename)


# Sort numerically by fold. A plain lexicographic sort puts fold 10 before
# fold 2, so the list position (used below as the fold label and output
# sub-directory) would no longer match the fold that produced the file.
prediction_results_list = sorted(
    (f for f in os.listdir(pickles) if f.startswith(pred_prefix) and f.endswith(pred_suffix)),
    key=_pred_fold_sort_key,
)
folds = len(prediction_results_list)
print(folds)

if __name__ == "__main__":
    for fold, pred_name in enumerate(prediction_results_list):
        prediction_results = load_object(os.path.join(pickles, pred_name))
        print('Total predictions: ',len(prediction_results))
        print('Building L3 images for fold: ', fold)
        l3_images = build_l3_images(axial_series, prediction_results)
        print('Total images: ',len(l3_images))
        # Don't run this unless you have new L3 results
        print("Outputting L3 images for fold: ", fold)
        # Clears pixel data from memory after outputting
        output_dir = os.path.join(config["l3_finder"]["output_directory"], str(fold))
        os.makedirs(output_dir, exist_ok=True)

        l3_images = output_images(
            l3_images,
            args=dict(
                output_directory=output_dir,
                should_plot=config["l3_finder"]["show_plots"],
                should_overwrite=config["l3_finder"]["overwrite"],
                should_save_plots=config["l3_finder"]["save_plots"]
            ),
        )

## Find mean L3 prediction

In [None]:
import csv
import sys
from collections import defaultdict
from pathlib import Path

from L3_finder import L3Image
from l3finder import ingest

from compare_best_to_manual_l3_and_seg import MinimalPrediction, MinimalResult

In [None]:
def load_l3_predictions(l3_prediction_dir,nfolds):
    """Collect every fold's L3 predictions (in pixels) per subject.

    Reads ``<l3_prediction_dir>/<fold>/l3_prediction_results.csv`` for folds
    ``0 .. nfolds-1``. The first row of each CSV is skipped as a header.
    Column 0 holds the subject identifier, of which only the part before the
    first '-' is kept. Column 1 holds the predicted L3 position in pixels.
    Blank rows are ignored.

    Args:
        l3_prediction_dir: Directory containing one sub-directory per fold.
        nfolds: Number of folds to read.

    Returns:
        ``defaultdict(list)`` mapping subject id -> list of float predictions,
        one entry per row found for that subject across all folds.

    Raises:
        FileNotFoundError: if a fold's CSV file is missing.
    """
    subject_id_col = 0
    pred_in_px_col = 1
    predictions = defaultdict(list)

    for fold_index in range(nfolds):
        csv_path = Path(l3_prediction_dir, str(fold_index), 'l3_prediction_results.csv')
        # newline='' is how the csv module expects files to be opened.
        with open(csv_path, newline='') as csvfile:
            reader = csv.reader(csvfile)
            next(reader, None)  # skip the header; tolerate a completely empty file

            for row in reader:
                if not row:
                    continue  # blank line
                subid = row[subject_id_col].split('-')[0]
                predictions[subid].append(float(row[pred_in_px_col]))

    return predictions

def calc_mean_predictions(all_predictions: defaultdict):
    """Average each subject's per-fold predictions.

    Returns a plain dict mapping every subject id in *all_predictions* to
    ``np.mean`` of its prediction list.
    """
    return {
        subject_id: np.mean(values)
        for subject_id, values in all_predictions.items()
    }

def find_subjects_w_preds(predictions, all_subjects):
    """Return, in their original order, the subjects whose ``id_`` is a key of *predictions*."""
    wanted_ids = set(predictions)
    return list(filter(lambda subj: subj.id_ in wanted_ids, all_subjects))

def load_l3_images_from_predictions(mean_predictions, subjects_w_preds,axials,sagittals):
    """Build an ``L3Image`` per subject from its mean predicted L3 position.

    For every subject in *subjects_w_preds*, the first sagittal series and
    the first axial series (in list order) whose ``subject.id_`` matches are
    paired with a minimal prediction result whose ``predicted_y_in_px`` is
    ``mean_predictions[subject.id_]``.

    Returns:
        List of ``L3Image`` objects in the order of *subjects_w_preds*.

    Raises:
        IndexError: if a subject has no matching sagittal or axial series
            (the same exception type the original ``[...][0]`` lookup raised).
        KeyError: if a subject has no entry in *mean_predictions*.
    """
    # Index the series by subject id once, keeping the FIRST match per id as
    # the original "[...][0]" did, instead of rescanning both lists per subject.
    first_sagittal = {}
    for series in sagittals:
        first_sagittal.setdefault(series.subject.id_, series)
    first_axial = {}
    for series in axials:
        first_axial.setdefault(series.subject.id_, series)

    l3_images = []
    for subject in subjects_w_preds:
        try:
            sagittal = first_sagittal[subject.id_]
            axial = first_axial[subject.id_]
        except KeyError as err:
            raise IndexError(
                f"No sagittal/axial series found for subject {subject.id_!r}"
            ) from err
        l3_images.append(
            L3Image(
                axial_series=axial,
                sagittal_series=sagittal,
                prediction_result=MinimalResult(
                    MinimalPrediction(
                        predicted_y_in_px=mean_predictions[subject.id_]
                    )
                )
            )
        )
    return l3_images

In [None]:
# Reload the curated series to pair them with the mean predictions.
axial_series = load_object(os.path.join(pickles,'axial_curated.pkl'))
sagittal_series = load_object(os.path.join(pickles,'sagittal_curated.pkl'))

In [None]:
# Number of folds = number of prediction_results pickles for this run
# (only the count is used here, so the sort order does not matter).
runname = 'CV_poorl3'
prediction_results_list = sorted([f for f in os.listdir(pickles) if (f.startswith('prediction_results_') and f.endswith(runname+'.pkl'))])
folds = len(prediction_results_list)
print(folds)

In [None]:
if __name__ == "__main__":
    # Read each fold's CSV from <output_directory>/<fold> and average per subject.
    all_predictions = load_l3_predictions(config["l3_finder"]["output_directory"],folds)
    mean_predictions = calc_mean_predictions(all_predictions)  

In [None]:
if __name__ == "__main__":
    # NOTE(review): unlike the Section 2 find_subjects call, this one does not pass
    # new_tim_dir_structure - confirm the subject IDs found here match the earlier ones.
    subjects_w_preds = find_subjects_w_preds(mean_predictions, list(ingest.find_subjects(config['l3_finder']["dicom_dir"])))
    l3_images = load_l3_images_from_predictions(mean_predictions, subjects_w_preds, axial_series, sagittal_series)

In [None]:
# Checkpoint the mean-prediction L3 images and the per-subject means.
save_object(l3_images,os.path.join(pickles,'l3_images_cv.pkl'))
save_object(mean_predictions,os.path.join(pickles,'mean_predictions.pkl'))

### Handle Outlier Cases

In [None]:
# Restart point: reload the mean-prediction checkpoints.
l3_images = load_object(os.path.join(pickles,'l3_images_cv.pkl'))
mean_predictions = load_object(os.path.join(pickles,'mean_predictions.pkl'))

In [None]:
# Outlier list: rows with an empty L3slice have no L3 present in the scan;
# the remaining rows carry a manually identified L3 slice.
infile  = 'poorl3.csv'
df_poorl3 = pd.read_csv(infile, index_col=False)

print('Total number of outliers for manual L3 detection: ', len(df_poorl3))
l3_absent = df_poorl3.loc[df_poorl3['L3slice'].isnull(),'ID'].values.tolist()
print('Cases with L3 not present: ', len(l3_absent))
l3_present = df_poorl3.loc[~df_poorl3['L3slice'].isnull(),'ID'].values.tolist()
print('Cases with manually identified L3s: ', len(l3_present))

l3_outliers = l3_absent + l3_present
print("Outliers: ", len(l3_outliers))

#sagittal_mips_valid = [sagittal_mip for sagittal_mip in sagittal_mips if sagittal_mip.subject_id not in df_poorl3.ID.values]

In [None]:
# Get rid of outliers without proper L3 images (cases with no L3 present)
print('Total l3_images: ', len(l3_images))
l3_images = [l3_image for l3_image in l3_images if l3_image.subject_id not in l3_absent]
print('Total l3_images after outlier removal: ', len(l3_images))

In [None]:
# Outliers that have a manually identified L3 slice.
l3_images_out = [l3_image for l3_image in l3_images if l3_image.subject_id in l3_present]
print(len(l3_images_out))

In [None]:
# Images whose subject is not in the outlier list at all.
l3_images_normals = [l3_image for l3_image in l3_images if l3_image.subject_id not in l3_outliers]
print(len(l3_images_normals))

In [None]:
# Manual L3 slice (from poorl3.csv) for each outlier image, in the same
# order as l3_images_out (a list, not a dict).
manualL3s = [
    df_poorl3.loc[df_poorl3['ID'] == l3_image.subject_id, 'L3slice'].values[0]
    for l3_image in l3_images_out
]

In [None]:
# Checkpoint the normal / outlier split and the outliers' manual L3 slices.
save_object(l3_images_normals,os.path.join(pickles,'l3_images_normals.pkl'))
save_object(l3_images_out,os.path.join(pickles,'l3_images_outliers.pkl'))
save_object(manualL3s,os.path.join(pickles,'manualL3s.pkl'))

### Segment L3 Axial Images and Calculate Muscle Area

In [None]:
# Restart point: reload the normal / outlier split.
l3_images_normals = load_object(os.path.join(pickles,'l3_images_normals.pkl'))
l3_images_out = load_object(os.path.join(pickles,'l3_images_outliers.pkl'))

In [None]:
# List from epic filter 
# Changed for V5 to read from csv file
# The PAT_ID column holds the patient IDs kept by the Epic filter.
df_v5 = pd.read_csv('patlist_with_validBMI_corrected_v5.csv', index_col=False)
normal_patients_corrected = list(df_v5.PAT_ID.values)
print(len(normal_patients_corrected))

In [None]:
# List from Andrew
# Keep only outliers with a manual L3 slice whose ID also passes the Epic filter.
infile  = 'poorl3.csv'
df_poorl3 = pd.read_csv(infile, index_col=False)

df_l3_present = df_poorl3.loc[~df_poorl3['L3slice'].isnull()]

print('Length of df_l3_present: ', len(df_l3_present))

df_l3_present_normals = df_l3_present.loc[df_l3_present['ID'].isin(normal_patients_corrected)]

print('Length of df_l3_present_normals: ', len(df_l3_present_normals))

In [None]:
# Process only the normals selected by epic filter
epic_ids = set(normal_patients_corrected)  # O(1) membership tests below
l3_images_normals = [l3 for l3 in l3_images_normals if l3.subject_id in epic_ids]
print('Length of normals processed and in epic filter: ', len(l3_images_normals))

l3_images_out = [l3 for l3 in l3_images_out if l3.subject_id in epic_ids]
print('Length of outliers processed and in epic filter: ', len(l3_images_out))

# Manual L3 slice for each outlier, aligned one-to-one with l3_images_out
# (both are passed together to do_segmentation_cv). No extra filter here:
# dropping an entry would shift every following manual L3 onto the wrong image.
manualL3s = [
    int(df_l3_present_normals.loc[df_l3_present_normals['ID'] == l3.subject_id, 'L3slice'].values[0])
    for l3 in l3_images_out
]
print('Length of manual L3s: ', len(manualL3s))

In [None]:
# Patients in Epic filter, but not in normals or outliers [i.e those missing Axial CT itself]

all_l3s = l3_images_normals + l3_images_out

l3_pats = [l3.subject_id for l3 in all_l3s]
l3_pat_set = set(l3_pats)  # O(1) lookups instead of scanning l3_pats per patient

missing_CT = [p for p in normal_patients_corrected if p not in l3_pat_set]

print('Patients with L3: ', len(l3_pats))
print('Patients from Epic: ',len(normal_patients_corrected))
print('Patients Missing Axial CT: ', len(missing_CT))

In [None]:
from compare_best_to_manual_l3_and_seg import seg_model_configs
from compare_best_to_manual_l3_and_seg import do_segmentation_cv

In [None]:
# Display the muscle segmentation model directory.
config["muscle_segmentor"]['model_path_dir']

In [None]:
if __name__ == "__main__":
    # Segment the normals; no manual L3s are passed, so presumably the
    # predicted L3 positions carried by the images are used - confirm in do_segmentation_cv.
    configs = seg_model_configs(config["muscle_segmentor"]['model_path_dir'])
    smas,average_masks,tableless_images = do_segmentation_cv(configs, l3_images_normals)
    print('Length of smas normals: ',len(smas))
    print('Length of average_masks normals: ',len(average_masks))
    print('Length of tableless_images normals: ',len(tableless_images))
    print("Done")

In [None]:
# Should equal len(l3_images_out): the two lists are passed together to do_segmentation_cv.
len(manualL3s)

In [None]:
if __name__ == "__main__":
    # Segment the outliers, passing their manual L3 slices as the third argument.
    configs = seg_model_configs(config["muscle_segmentor"]['model_path_dir'])
    smas_out,average_masks_out,tableless_images_out = do_segmentation_cv(configs, l3_images_out, manualL3s)
    print('Length of smas outliers: ',len(smas_out))
    print('Length of average_masks outliers: ',len(average_masks_out))
    print('Length of tableless_images outliers: ',len(tableless_images_out))
    print("Done")

In [None]:
# Merge results: normals first, then outliers.
smas = smas + smas_out
average_masks = average_masks + average_masks_out

In [None]:
# Stack normal then outlier images along axis 0 (presumably numpy arrays, hence np.concatenate).
tableless_images = np.concatenate((tableless_images, tableless_images_out),axis=0)

In [None]:
# Same order as the merged smas / average_masks / tableless_images: normals, then outliers.
l3_images = l3_images_normals + l3_images_out

In [None]:
# These lengths (and len(l3_images)) should all match before writing results.
print('Length of smas all: ',len(smas))
print('Length of average_masks all: ',len(average_masks))
print('Length of tableless_images all: ',len(tableless_images))

In [None]:
from imageio import imsave
import csv

def output_sma_results(output_dir, l3_images, tableless_images, average_masks, smas):
    """Write per-exam segmentation images and a CSV of muscle areas.

    For the i-th exam (1-based), writes ``<i>_<subject_id>_CT.tif`` (the
    tableless image cast to float32) and ``<i>_<subject_id>_muscle.tif`` (the
    mask multiplied by 255, the uint8 maximum) into *output_dir*. It also
    writes one row ``subject_id, area_mm2, sagittal_series, axial_series`` to
    ``areas-mm2_by_subject_id.csv`` in the same directory.

    The four sequences are consumed in lockstep with ``zip``, so iteration
    stops at the shortest one. Callers must pass aligned, equally long inputs.
    """
    os.makedirs(output_dir, exist_ok=True)

    csv_filename = os.path.join(output_dir, "areas-mm2_by_subject_id.csv")
    # newline='' keeps csv.writer from emitting blank lines between rows on Windows.
    with open(csv_filename, "w", newline="") as csvfile:
        csv_writer = csv.writer(csvfile)
        csv_writer.writerow(["subject_id", "area_mm2", "sagittal_series", "axial_series"])
        print('Saving Segmentation Results in ', output_dir)
        index = 0  # stays 0 when there is nothing to write
        records = zip(average_masks, smas, l3_images, tableless_images)
        for index, (mask, sma, l3_image, tableless_image) in enumerate(records, start=1):
            base = os.path.join(output_dir, str(index) + "_" + l3_image.subject_id)
            imsave(base + "_CT.tif", tableless_image.astype(np.float32))
            imsave(base + "_muscle.tif", mask * np.iinfo(np.uint8).max)

            csv_writer.writerow([
                l3_image.subject_id,
                sma.area_mm2,
                l3_image.sagittal_series.series_name,
                l3_image.axial_series.series_name,
            ])
        print('Total exams outputted: ', index)

In [None]:
# Write the TIFFs and the area CSV to the muscle segmentor's output directory.
output_sma_results(config["muscle_segmentor"]['output_directory'], l3_images, tableless_images, average_masks, smas)