In [1]:
from pathlib import Path
import torch
import pandas as pd
import stat
import numbers
import pydicom
import numpy as np
from tqdm import tqdm
from collections import Counter
import re
import os
from PIL import Image
from matplotlib import pyplot as plt
import cv2

In [2]:
Path().resolve()

PosixPath('/home/buehlern/Documents/Masterarbeit/notebooks')

# Create Balanced Dataset for Fracture Detection Finetuning

In [3]:
df_loc = Path('../data/df_min.pkl')
df = pd.read_pickle(df_loc)

In [4]:
df.columns

Index(['patientid', 'path', 'pixelarr_shape', 'inverted', 'bodypart',
       'fracture', 'foreignmaterial'],
      dtype='object')

In [5]:
df['fracture'].value_counts()

fracture
NO        441
YES       212
Unsure     20
Name: count, dtype: int64

In [6]:
bool_map = {'YES': True, 'NO': False, 'Unsure': float('NaN')}
df['fracture_bool'] = df['fracture'].map(bool_map)
df['foreignmaterial_bool'] = df['foreignmaterial'].map(bool_map)

In [7]:
df_frac = df.dropna(subset=['fracture_bool'])

In [8]:
df_frac['fracture'].value_counts()

fracture
NO     441
YES    212
Name: count, dtype: int64

In [9]:
df_frac['bodypart'].value_counts()

bodypart
knee        66
elbow       64
foot        62
hand        57
wrist       57
cspine      57
shoulder    57
tspine      53
clavicle    51
rips        48
scapula     45
skull       36
Name: count, dtype: int64

In [10]:
df_frac.groupby(['bodypart', 'fracture']).size()

bodypart  fracture
clavicle  NO          20
          YES         31
cspine    NO          54
          YES          3
elbow     NO          30
          YES         34
foot      NO          45
          YES         17
hand      NO          43
          YES         14
knee      NO          61
          YES          5
rips      NO          26
          YES         22
scapula   NO          30
          YES         15
shoulder  NO          41
          YES         16
skull     NO          36
tspine    NO          28
          YES         25
wrist     NO          27
          YES         30
dtype: int64

In [11]:
df_frac.groupby(['bodypart', 'fracture']).size().unstack().fillna(0).min(axis=1).sort_values(ascending=False)

bodypart
elbow       30.0
wrist       27.0
tspine      25.0
rips        22.0
clavicle    20.0
foot        17.0
shoulder    16.0
scapula     15.0
hand        14.0
knee         5.0
cspine       3.0
skull        0.0
dtype: float64

In [12]:
df_balanced = pd.DataFrame()
bp_list = ['elbow', 'wrist']
for bp in bp_list:
    df_bp = df_frac[df_frac['bodypart'] == bp]
    num_samples = df_bp['fracture'].value_counts().min()
    df_bp = df_bp.groupby('fracture_bool').apply(lambda x: x.sample(n=num_samples, random_state=1))
    df_balanced = pd.concat([df_balanced, df_bp])
df_balanced = df_balanced.reset_index(drop=True)

  df_bp = df_bp.groupby('fracture_bool').apply(lambda x: x.sample(n=num_samples, random_state=1))
  df_bp = df_bp.groupby('fracture_bool').apply(lambda x: x.sample(n=num_samples, random_state=1))


In [13]:
df_balanced[['bodypart', 'fracture_bool']]

Unnamed: 0,bodypart,fracture_bool
0,elbow,False
1,elbow,False
2,elbow,False
3,elbow,False
4,elbow,False
...,...,...
109,wrist,True
110,wrist,True
111,wrist,True
112,wrist,True


In [14]:
df_balanced.groupby(['bodypart', 'fracture']).size()

bodypart  fracture
elbow     NO          30
          YES         30
wrist     NO          27
          YES         27
dtype: int64

In [15]:
df_ft_balanced_loc = Path('../data/df_min_ft_test_114.pkl')
pd.to_pickle(df_balanced, df_ft_balanced_loc)

# Inspect Data

In [32]:
from PIL import Image

In [52]:
df_ft_balanced_loc = Path('../data/df_min_ft_test_114.pkl')
df_balanced = pd.read_pickle(df_ft_balanced_loc)

In [53]:
def show_image(image, title=''):
    # image is [H, W, 1]
    assert image.shape[2] == 1
    plt.imshow(image, cmap=plt.cm.bone)
    plt.title(title, fontsize=8)
    plt.axis('off')
    return

In [None]:
bp_list = ['elbow', 'wrist']
for bp in bp_list:
    plt.rcParams['figure.figsize'] = [40, 10]
    for i, fracture in enumerate([False, True]):
        scan = df_balanced[(df_balanced['bodypart'] == bp) & (df_balanced['fracture_bool'] == fracture)].sample(1)
        scan_id = scan.index[0]
        scan_frac = scan["fracture"].iloc[0]
        scan_path = scan["path"].iloc[0]
        for j, scale in enumerate([1.0, 0.5, 0.25]):
            pixel_values_raw = pydicom.read_file(scan_path).pixel_array
            pixel_values_resize = cv2.resize(pixel_values_raw, (0, 0), fx=scale, fy=scale)
            pixel_values = pixel_values_resize[:, :, np.newaxis]
            
            print(f"{i*3+j+1}: Scan {scan_id}: path={scan_path}, bp={bp}, fracture={scan_frac}, scale={scale}, shape={pixel_values.shape}")
    
            plt.subplot(1, 8, i*3+j+1)
            show_image(pixel_values, title=f"{bp}, Fracture: {scan_frac}, Scale: {scale}")
            
            im = Image.fromarray(pixel_values_resize)
            im.save(f'/home/buehlern/Documents/Masterarbeit/notebooks/Data Exploration Graphics/Finetuning/{scan_id}_{bp}_{scan_frac}_{scale}_{pixel_values_raw.shape}.png')
            plt.axis('off')
    plt.show()

# Create Dataset

In [19]:
import sys
sys.path.insert(1, '/home/buehlern/Documents/Masterarbeit/models')
from src.data.mri_datamodule import MRIDataModule

  from .autonotebook import tqdm as notebook_tqdm


In [20]:
# Load the DataModule
mri_datamodule = MRIDataModule(
            batch_size=1,
            num_workers=1,
            persistent_workers=True,
            pin_memory=True,
            df_name='df_min_ft_test_114',
            batch_binning='smart',
            label='fracture')

Using label fracture as stratification_target
Initializing MRIDatasetBase...
Loading dataframe from /home/buehlern/Documents/Masterarbeit/data/df_min_ft_test_114.pkl...
MRIDatasetBase(len=114) initialized
Initializing MRIDataset(mode=train)...


WARN: NO TRAINVAL TEST SPLIT FOUND AT /home/buehlern/Documents/Masterarbeit/data/splits/split_test_df_min_ft_test_114_straton_fracture.csv, type YES[enter] to generate one:  YES


WARN: GENERATING NEW TRAINVAL TEST SPLIT
MRIDataset(mode=train, len=91) initialized
Initializing MRIDataset(mode=val)...
MRIDataset(mode=val, len=7) initialized
Initializing MRIDataset(mode=test)...
WARN: Including test data
MRIDataset(mode=test, len=16) initialized


In [21]:
data_sources = [mri_datamodule.data_train, mri_datamodule.data_val, mri_datamodule.data_test]
for data_source in data_sources:
    it = iter(data_source)
    frac = 0
    total = len(data_source)
    for i in range(total):
        item = next(it)
        image = item[0]
        label = item[1]
        frac += label
    print(f"Fractures: {frac}/{total}")

Fractures: 47/91
Fractures: 2/7
Fractures: 8/16
