# Load libraries and directories

In [18]:
from IPython.display import display, HTML

In [19]:
# from IPython import get_ipython
from tqdm.notebook import tqdm
import pickle
import os
import pprint
pp = pprint.PrettyPrinter(indent=1)

# Custom modules for debugging
from SliceViewer import ImageSliceViewer3D, ImageSliceViewer3D_1view,ImageSliceViewer3D_2views
from investigate import *

#pd.set_option("display.max_rows", 10)
      
import json
from run_sma_experiment import find_l3_images,output_images
import pprint
from L3_finder import *

# Custom functions
def save_object(obj, filename):
    """Pickle ``obj`` to ``filename`` with the highest available protocol.

    The target is opened in binary write mode, so an existing file at that
    path is replaced.
    """
    with open(filename, mode='wb') as sink:
        pickle.dump(obj, sink, protocol=pickle.HIGHEST_PROTOCOL)

def load_object(filename):
    """Read and return the object pickled at ``filename``.

    NOTE(review): unpickling can execute arbitrary code, so only use this on
    pickle files that come from a trusted source.
    """
    with open(filename, mode='rb') as source:
        restored = pickle.load(source)
    return restored

In [20]:
get_ipython().run_line_magic('tb', '')

No traceback available to show.


### In CCHMC's workflow, there were two DICOM dumps, with different folder structures and naming conventions:
<br>
Folder structure: <br>
Dump-1: Project_folder/Patient_folder/Series_Folder/dicom_files <br>
Dump-2: Project_folder/Patient_folder/Study_folder/Series_Folder/dicom_files<br> 
<br>
Naming Convention for patient folder: <br>
Dump-1: PATID-GMRN-PATID-STUDYNAME <br>
Dump-2: PT-PATID-PATID 

In [21]:
#Select which dump you are processing here [dump1: 1, dump2: 2]
dump = 1

In [22]:
cwd = os.getcwd()
data = '/tf/data'
output = '/tf/pickles/v5_8pts'

## Section 1 - Load list of normal patients filtered from Epic data and select those patients from the DICOM dump of all patients

In [23]:
# Read the normal-patient list, discard any 'Unnamed*' columns, and keep only
# the three identifier columns used by the rest of the notebook.
infile = 'patlist_with_validBMI_corrected_v5.csv'
df_P = (
    pd.read_csv(infile, index_col=False)
    .loc[:, lambda frame: ~frame.columns.str.contains('^Unnamed')][['GIVEN_MRN', 'PAT_ID', 'ACC']]
)
print('Columns of df_P: ', df_P.columns.tolist())
print('Length of df_P: ', df_P.shape[0])
# NOTE(review): GIVEN_MRN / PAT_ID / ACC look like patient identifiers --
# confirm this preview is safe to keep in a shared or committed notebook.
display(df_P.head(10))

Columns of df_P:  ['GIVEN_MRN', 'PAT_ID', 'ACC']
Length of df_P:  2238


Unnamed: 0,GIVEN_MRN,PAT_ID,ACC
0,807126,Z857672,7443683
1,11176208,Z1204112,7442667
2,834056,Z870530,7219002
3,1412716,Z1009393,7437949
4,1051399,Z441477,7449601
5,11277072,Z1305152,7476367
6,11021437,Z1049116,7476123
7,855379,Z881264,7207613
8,1004812,Z413629,7206650
9,742506,Z828131,7206442


In [24]:
pats = next(os.walk(data))[1]
print('Total patient folders in data dir: ',len(pats))

Total patient folders in data dir:  8


In [25]:
# Extract the PAT_ID token from each patient folder name: the first
# '-'-separated token for dump 1 (PATID-GMRN-PATID-STUDYNAME) and the last
# one for dump 2 (PT-PATID-PATID).
if dump == 1:
    patids = [pat.split('-')[0] for pat in pats]
elif dump == 2:
    patids = [pat.split('-')[-1] for pat in pats]
else:
    # Fail here with a clear message instead of a NameError on `patids` below.
    raise ValueError("Unsupported dump value {!r}; expected 1 or 2".format(dump))

# Keep only folder IDs that appear in the normal-patient list. Building the
# lookup set once avoids rescanning the PAT_ID column for every folder, and
# the set comprehension de-duplicates the result.
known_pat_ids = set(df_P.PAT_ID)
valid_ids = {patid for patid in patids if patid in known_pat_ids}

In [26]:
print('valid ids: ',len(valid_ids))

valid ids:  8


## Section 2 - Load each study into subject object
<br>
The Subject object is defined in `l3finder.ingest`.

In [27]:
# Import modules and config file
configfile = os.path.join(cwd,'config/debug_ES/series_filter_ds1.json')
with open(configfile, "r") as f:
        config = json.load(f)

config = config["series_filter"]        
print('Current config dict: ')
pp.pprint(config)

Current config dict: 
{'dicom_dir': '/tf/data',
 'model_path': 'None',
 'output_directory': '/tf/output',
 'overwrite': True,
 'save_plots': True,
 'series_to_skip_pickle_file': '/tf/output/broken_sagittal_and_axial_series.pkl',
 'show_plots': False}


In [28]:
# Record which folder layout the selected dump uses; the flag is read back
# from `config` when find_subjects is called.
if dump==2:
    config["new_tim_dicom_dir_structure"] = False
elif dump==1:
    config["new_tim_dicom_dir_structure"] = True
else:
    # Without this guard the key would be left unset for any other value and
    # the problem would only surface later, when the flag is looked up.
    raise ValueError("Unsupported dump value {!r}; expected 1 or 2".format(dump))

In [29]:
# Debug
print("Finding subjects")

subjects = list(
    find_subjects(
        config["dicom_dir"],
        new_tim_dir_structure=config["new_tim_dicom_dir_structure"]
    )
)

print('Subjects found: ', len(subjects))

Finding subjects
Subjects found:  8


## Section-3 - check if there are subjects with multiple folders (studies)

In [30]:
subjects = [subject for subject in subjects if subject.id_ in valid_ids]
print('Subjects found: ', len(subjects))
print('Valid Subjects: ', len(valid_ids))

Subjects found:  8
Valid Subjects:  8


In [31]:
# Find duplicate subjects: the first occurrence of each ID is recorded as
# unique, every later occurrence of the same ID as a duplicate.
seen_ids = set()  # constant-time membership test instead of scanning a list
unique_subjects = []
duplicate_subjects = []
for subject in subjects:
    if subject.id_ in seen_ids:
        duplicate_subjects.append(subject.id_)
    else:
        seen_ids.add(subject.id_)
        unique_subjects.append(subject.id_)

print('Duplicates: ',len(duplicate_subjects)           )
print('Uniques: ',len(unique_subjects))

Duplicates:  0
Uniques:  8


In [34]:
## Save subjects without duplicates
save_object(subjects, os.path.join(output,'subjects_noduplicates.pkl'))

## Section 4 - Load each series into series object and keep only axials and sagittals

In [35]:
subjects = load_object(os.path.join(output,'subjects_noduplicates.pkl'))

In [36]:
len(subjects)

8

In [37]:
%%time
# Find series images
print("Finding series")
series = list(flatten(tqdm((s.find_series() for s in subjects),total=len(subjects))))
print("Total number of series found: ", len(series))

Finding series


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=8.0), HTML(value='')))


Total number of series found:  45
CPU times: user 68.3 ms, sys: 16.6 ms, total: 84.9 ms
Wall time: 72.2 ms


#### Separate axial and sagittal series

In [38]:
%%time
# Split `series` into sagittal and axial lists by evaluating same_orientation
# for every series in a pool of spawned worker processes.
# NOTE(review): functools and multiprocessing are used below but are not
# imported explicitly in the visible cells; presumably they arrive through one
# of the star imports -- confirm.
from L3_finder import *
from l3finder.ingest import *
from multiprocessing import get_context
from multiprocessing import set_start_method
#set_start_method("spawn")

if __name__=='__main__':
    # Find series images (rebuilds `series` from `subjects`, without the
    # progress bar used in the previous cell)
    print("Finding series")
    series = list(flatten(s.find_series() for s in subjects))

    # Separate series
    print("Separating series")
    #sagittal_series, axial_series, excluded_series = separate_series(series)
    
    # NOTE(review): this list is bound into both partials below. If
    # same_orientation appends to it, those appends would happen on copies
    # inside the spawned workers and would not show up here -- confirm.
    excluded_series = []

    # Predicate bound to the sagittal orientation.
    sag_filter = functools.partial(
        same_orientation,
        orientation='sagittal',
        excluded_series=excluded_series
    )
    
    # Predicate bound to the axial orientation.
    axial_filter = functools.partial(
        same_orientation,
        orientation='axial',
        excluded_series=excluded_series
    )

    def pool_filter(pool, func, candidates):
        """Return the candidates for which ``func(candidate)`` is truthy.

        ``func`` is evaluated through ``pool.imap``, whose results come back
        in input order, so zipping them with ``candidates`` pairs each result
        with its own candidate. tqdm wraps the result iterator to show
        progress.
        """
        return [
            c for c, keep
            in zip(candidates, tqdm(pool.imap(func, candidates),total=len(candidates)))
            if keep]
    
    print('Filtering series using ', multiprocessing.cpu_count(), ' cores:')
    # Pool() without an argument uses the default number of worker processes.
    with get_context("spawn").Pool() as p:
        sagittal_series = pool_filter(p, sag_filter, series)
        print("Processed Sagittals")
        axial_series = pool_filter(p, axial_filter, series)
        print("Processed Axials")
        # Stop accepting work and wait for the workers before leaving the block.
        p.close()
        p.join()

    
    
    print("Series seperated")

#remove_start_method("spawn")

Finding series
Separating series
Filtering series using  48  cores:


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=45.0), HTML(value='')))


Processed Sagittals


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=45.0), HTML(value='')))


Processed Axials
Series seperated
CPU times: user 220 ms, sys: 720 ms, total: 940 ms
Wall time: 2.21 s


In [39]:
print("Length of valid pats: ", len(subjects))
print("Length of sagittal series", len(sagittal_series))
print("Length of axial series", len(axial_series))
#print("Length of excluded series", len(excluded_series))
#print("Length of all series in dataset", len(series))

Length of valid pats:  8
Length of sagittal series 1
Length of axial series 16


In [40]:
# Save required objects
save_object(axial_series, os.path.join(output,'axial_series.pkl'))
save_object(sagittal_series, os.path.join(output,'sagittal_series.pkl'))

## Section 5 - Investigate subjects and series using pandas

In [41]:
axial_series = load_object(os.path.join(output,'axial_series.pkl'))
sagittal_series = load_object(os.path.join(output,'sagittal_series.pkl'))
subjects = load_object(os.path.join(output,'subjects_noduplicates.pkl'))

In [42]:
df_a = get_summary_dfs(axial_series,sagittal_series,subjects)
save_object(df_a, os.path.join(output,'df_a.pkl'))

In [43]:
display(df_a.head(10))

Unnamed: 0,ID,Axials,Sagittals
0,Z923887,2,0
1,Z1107859,2,0
2,Z1312350,2,0
3,Z934614,3,0
4,Z1146560,2,0
5,Z719159,2,0
6,Z1141194,2,1
7,Z1199636,1,0


In [44]:
df_a_axials = get_summary_by_serieslength(axial_series)
df_a_sags = get_summary_by_serieslength(sagittal_series)
save_object(df_a_axials, os.path.join(output,'df_a_axials.pkl'))
save_object(df_a_sags, os.path.join(output,'df_a_sags.pkl'))

In [45]:
print("Length of subjects with atleast 1 axial or sagittal series: ", len(df_a))
print("Length of subjects with atleast 1 axial series: ", len(df_a_axials['ID'].unique()))
print("Length of subjects with atleast 1 sagittal series: ", len(df_a_sags['ID'].unique()))

Length of subjects with atleast 1 axial or sagittal series:  8
Length of subjects with atleast 1 axial series:  8
Length of subjects with atleast 1 sagittal series:  1


In [46]:
# Patients without Axial
pats = [pat for pat in df_a['ID'].values if pat not in df_a_axials['ID'].values]
print(len(pats))
print(pats)

0
[]


In [47]:
# Patients without Sagittal
pats = [pat for pat in df_a['ID'].values if pat not in df_a_sags['ID'].values]
print(len(pats))

7


In [None]:
imseries = get_subject_series('Z837620','Z837620-SE-6-Vol_Body_Vol._0.5',subjects)
print(imseries.orientation,' ' , imseries.slice_thickness)
imdata = imseries.pixel_data

In [None]:
%matplotlib inline
print(imdata.shape)
ImageSliceViewer3D(imdata)

In [None]:
print_summary_by_serieslength(df_a_axials)

In [None]:
print_summary_by_serieslength(df_a_sags)

## Section 6 - Create dataframe of optimal axial sagittal pairs

The function filter_finalpairs in investigate.py is used

In [48]:
axial_series = load_object(os.path.join(output,'axial_series.pkl'))
sagittal_series = load_object(os.path.join(output,'sagittal_series.pkl'))
subjects = load_object(os.path.join(output,'subjects_noduplicates.pkl'))

df_a_axials = load_object(os.path.join(output,'df_a_axials.pkl'))
df_a_sags = load_object(os.path.join(output,'df_a_sags.pkl'))
df_a = load_object(os.path.join(output,'df_a.pkl'))

In [49]:
%%time
# Build the per-subject axial/sagittal pairing table by running
# filter_finalpairs for every subject ID in a pool of spawned workers.
from L3_finder import *
from l3finder.ingest import *
from multiprocessing import get_context
from multiprocessing import set_start_method
#set_start_method("spawn")
# Pre-bind df_filt so the name exists even if the guarded body below is skipped.
df_filt = None

if __name__=='__main__':
    # Collect the ID of every subject; each ID is one unit of work for the pool.
    print("Finding IDs")
    
    IDs = [s.id_ for s in subjects]
    # Bind the summary frames and the subject list so that the only remaining
    # positional argument is the subject ID supplied by imap.
    pair_filter = functools.partial(
        filter_finalpairs,
        df_ax=df_a_axials,
        df_sag=df_a_sags,
        subjects=subjects
    )
    
    def pool_filter(pool, func, candidates):
        """Return ``func(candidate)`` for every candidate, in input order.

        Results are produced through ``pool.imap`` and wrapped in tqdm for a
        progress bar. NOTE: unlike the identically named helper in the
        series-separation cell, this one keeps every result; nothing is
        filtered out.
        """
        return [a for a in tqdm(pool.imap(func, candidates),total=len(candidates))]
        
    # NOTE(review): the message reports cpu_count(), but the pool below is
    # created with 4 workers.
    print('Filtering series using ', multiprocessing.cpu_count(), ' cores:')
        
    with get_context("spawn").Pool(4) as p:
        result_list = pool_filter(p, pair_filter, IDs)
        # Stop accepting work and wait for the workers before leaving the block.
        p.close()
        p.join()
    
    print('parallel processing over')
    # Assemble one row per worker result.
    # TODO confirm: assumes each result holds the ten values in this column order.
    df_filt  = pd.DataFrame(columns=['ID','Axial','Sagittal','Overlap','MissingScore','PairValidity', 
                                'AxSlices','SagSlices','AxThick','SagThick'])
    for i,op in enumerate(result_list):
        df_filt.loc[i] = op
    

print("Processed")

Finding IDs
Filtering series using  48  cores:


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=8.0), HTML(value='')))


parallel processing over
Processed
CPU times: user 211 ms, sys: 87.7 ms, total: 299 ms
Wall time: 10.2 s


In [50]:
save_object(df_filt, os.path.join(output,'df_filteredpairs.pkl'))

## Section 7: Investigate the dataframe for missing and low quality pairs

In [51]:
# Load all params
df_filt = load_object(os.path.join(output,'df_filteredpairs.pkl'))
axial_series = load_object(os.path.join(output,'axial_series.pkl'))
sagittal_series = load_object(os.path.join(output,'sagittal_series.pkl'))
subjects = load_object(os.path.join(output,'subjects_noduplicates.pkl'))

df_a_axials = load_object(os.path.join(output,'df_a_axials.pkl'))
df_a_sags = load_object(os.path.join(output,'df_a_sags.pkl'))
df_a = load_object(os.path.join(output,'df_a.pkl'))

In [52]:
# Make sure filtered df and subjects are equal length
print('Length of filtered df: ',len(df_filt))
print('Length of subjects: ',len(subjects))

Length of filtered df:  8
Length of subjects:  8


In [53]:
# View and Remove subjects without axial series
df_noaxials =  df_filt[df_filt['Axial'].isnull()]
print('Number of subjects without Axials: ', len(df_noaxials))
display(df_noaxials)

Number of subjects without Axials:  0


Unnamed: 0,ID,Axial,Sagittal,Overlap,MissingScore,PairValidity,AxSlices,SagSlices,AxThick,SagThick


In [54]:
# Remove subjects without axials from subjects list:
subjects = [s for s in subjects if s.id_ not in df_noaxials.ID.values]
print('Length of subjects with axials: ',len(subjects))


Length of subjects with axials:  8


In [55]:
# Print cases that don't have sagittals
df_nosags = df_filt[df_filt['Sagittal'].isnull()]
print("Number of cases without sagittals: ", len(df_nosags))

Number of cases without sagittals:  7


In [56]:
# Investigate cases with less than 0.7 overlap and  < 0.9 Missing Score [Tracks Slices missing from stack]
df_pooroverlap = df_filt[(df_filt['Overlap'] < 0.7) | (df_filt['MissingScore'] < 0.9)]
print('Cases with overlap < 0.7: ', len(df_pooroverlap))

Cases with overlap < 0.7:  0


In [57]:
display(df_pooroverlap.sort_values(by=['MissingScore'],ascending=[True]))

Unnamed: 0,ID,Axial,Sagittal,Overlap,MissingScore,PairValidity,AxSlices,SagSlices,AxThick,SagThick


In [None]:
# Handy Functions to investigate the poor pairs
def get_ax_sag(df, ind):
    """Return the (axial, sagittal) series for row ``ind`` of ``df``.

    Reads the 'ID', 'Axial' and 'Sagittal' cells of that row and resolves each
    series ID through ``get_subject_series`` against the notebook-level
    ``subjects`` list.
    """
    subject_id = df.loc[ind, 'ID']
    series_ids = [df.loc[ind, column] for column in ('Axial', 'Sagittal')]
    axial, sagittal = (
        get_subject_series(subject_id, series_id, subjects)
        for series_id in series_ids
    )
    return axial, sagittal

In [None]:
# Inspect missing slices in the sagittal series (element [1] of the pair) of row 85.
calculate_missing_slices_sagittals(get_ax_sag(df_pooroverlap,85)[1],verbose=True)

In [None]:
calculate_series_overlap(*get_ax_sag(df_pooroverlap,85),verbose=True)

### Based on investigation, eliminate series and subjects not eligible and create final df

In [58]:
# Keep a sagittal only when its overlap is at least 0.7.
# Missing overlaps (None or NaN) count as failing the threshold. A plain
# truthiness test would let NaN through, because `not nan` is False and
# `nan < 0.7` is also False.
df_final = df_filt.copy()
overlap = pd.to_numeric(df_final['Overlap'], errors='coerce')
df_final.loc[overlap.isna() | (overlap < 0.7), 'Sagittal'] = None

In [59]:
# Print cases that don't have sagittals
print("Number of cases without sagittals in filter df: ", len(df_nosags))
df_nosags2 = df_final[df_final['Sagittal'].isnull()]
print("Number of cases without sagittals in final df: ", len(df_nosags2))

Number of cases without sagittals in filter df:  7
Number of cases without sagittals in final df:  7


In [60]:
# Get final 

final_df_file = 'df_final_dump_'+str(dump)+'_8pats.pkl'
final_subs_file = 'subjects_final_dump_'+str(dump)+'_8pats.pkl'

save_object(df_final, os.path.join(output,final_df_file))
save_object(subjects, os.path.join(output,final_subs_file))