## Periodontal Diease Model ##
### Download the model weights and the test data ###
Running this notebook requires AWS keys which can be obtained from the authors.

In [12]:
import os
import pandas as pd
import numpy as np
import tarfile
import glob
from pathlib import Path

# Imports from this project
%load_ext autoreload
%autoreload 2
import periomodel
from periomodel.fileutils import FileOP
from periomodel.imageproc import is_image
print(f'Project version:  {periomodel.__version__}')

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
Project version:  0.0.post1.dev19+gf9f2c51


### Some useful functions for this notebook ###

In [9]:
def download_and_extract(url, download_dir):
    output_file_path = FileOP().download_from_url(url=url, download_dir=data_dir)
    if os.path.isfile(output_file_path):
        extract_dir = os.path.join(data_dir, 
                              os.path.splitext(os.path.basename(output_file_path))[0])
        Path(extract_dir).mkdir(parents=True, exist_ok=True)
        with tarfile.open(output_file_path) as tar:
            tar.extractall(extract_dir)
        print(f'Data downloaded and extracted: {output_file_path}')
    else:
        logger.warning('Download failed.')
        extract_dir = None
    return extract_dir
    
def patient_image_stat(data, file_col='file', group_col='group'):
    df_n_patients = data[['PatientIDE', group_col]].\
                    drop_duplicates().\
                    groupby(by=group_col).nunique().\
                    reset_index(drop=False).\
                    rename(columns={'PatientIDE': 'n_patients'}).\
                    sort_values(by=group_col, ascending=True).\
                    reset_index(drop=True)

    df_n_images = data[[file_col, group_col]].\
                drop_duplicates().\
                groupby(by=group_col).nunique().\
                reset_index(drop=False).\
                rename(columns={file_col: 'n_images'}).\
                sort_values(by=group_col, ascending=True).\
                reset_index(drop=True)

    df_stat = df_n_patients.merge(df_n_images, on=group_col, how='inner')
    df_stat = df_stat.assign(images_per_patient=np.round(df_stat['n_images']/df_stat['n_patients']).astype(int))

    return df_stat

### Download and extract the test data ###

In [10]:
# Test data URL
image_url = 'https://dsets.s3.amazonaws.com/classification_datasets/periodata_test.tar.gz'
# Model checkpoint
model_url = 'https://dsets.s3.amazonaws.com/classification_datasets/periomodel_checkpoint_1400.tar.gz'
# Download directory
# The DATA_ROOT directory can be set in the .env file (see repository)
data_dir = os.environ.get('DATA_ROOT')
assert os.path.isdir(data_dir), f'data_dir {data_dir} does not exist.'

In [24]:
# Extract and verify the test data sets
print(f'Extracting test data from: {image_url}')
image_extract_dir = download_and_extract(url=image_url, download_dir=data_dir)
if image_extract_dir is not None:
    image_dir = os.path.join(image_extract_dir, 'test')
    df_file = glob.glob(os.path.join(image_dir, '*.parquet'))
    if len(df_file) > 0:
        df_file = df_file[0]
        file_list = glob.glob(os.path.join(image_dir, '*.png'))
        file_list_verified = [file for file in file_list if is_image(file)]
        test_df = pd.read_parquet(df_file)
        display(test_df.head())
        n_files = len(test_df["image"].unique())
        print(f'{n_files} samples in data frame')
        print(f'{len(file_list_verified)} images verified in {image_dir}')
        assert n_files == len(file_list) == len(file_list_verified), 'WARNING, could not verify all files'
    else:
        logger.warning(f'Missing .parquet data file in: {image_dir}')

# Download and extract the model weights
print()
print(f'Extracting model data from: {model_url}')
model_dir = download_and_extract(url=model_url, download_dir=data_dir)
if model_dir is not None:
    try:
        checkpoint_file = glob.glob(os.path.join(model_dir, '*.ckpt'))[0]
        log_file = glob.glob(os.path.join(model_dir, '*.log'))[0]
    except IndexError:
        logger.warning('.ckpt or .log file not found.')
    print(f'Model checkpoint file: {checkpoint_file}')
    print(f'Training log file:     {log_file}')

Extracting test data from: https://dsets.s3.amazonaws.com/classification_datasets/periodata_test.tar.gz
Extracting from .gz archive.
Uncompressed output file exists: /home/andreas/data/dcmdata/periodata_test.tar. Skipping.
Data downloaded and extracted: /home/andreas/data/dcmdata/periodata_test.tar


Unnamed: 0,PatientIDE,annotation_id,disease,cl3,dset,image
0,a5c42982,829,unstable,2,test,box_a5c42982_20230414_02_00.png
1,956a3988,1758,stable,1,test,box_956a3988_20221108_02_01.png
2,037914b7,1545,unstable,2,test,box_037914b7_20230331_21_03.png
3,c0b47799,5703,very unstable,2,test,box_c0b47799_20230208_05_05.png
4,2cfc437a,3334,stable,1,test,box_2cfc437a_20230515_11_02.png


614 samples in data frame
614 images verified in /home/andreas/data/dcmdata/periodata_test/test

Extracting model data from: https://dsets.s3.amazonaws.com/classification_datasets/periomodel_checkpoint_1400.tar.gz
Extracting from .gz archive.
Uncompressed output file exists: /home/andreas/data/dcmdata/periomodel_checkpoint_1400.tar. Skipping.
Data downloaded and extracted: /home/andreas/data/dcmdata/periomodel_checkpoint_1400.tar
Model checkpoint file: /home/andreas/data/dcmdata/periomodel_checkpoint_1400/periomodel_checkpoint_1400.ckpt
Training log file:     /home/andreas/data/dcmdata/periomodel_checkpoint_1400/train.log


In [17]:
# Label assignments in data set
label_dict = {}
label_dict_inv = {}
label_col_list = ['cl3']
for label_col in label_col_list:
    cl_list = sorted(list(test_df.get(label_col).unique()))
    ds_list = [tuple(test_df.loc[test_df[label_col]==cl, 'disease'].unique()) for cl in cl_list]
    label_cl_dict = dict(zip(cl_list, ds_list))
    label_dict.update({label_col: label_cl_dict})
    
    ds_list_inv = sorted(list(test_df.get('disease').unique()))
    cl_list_inv = [int(test_df.loc[test_df['disease'] == disease, label_col].\
                       values[0]) for disease in ds_list_inv]
    label_cl_dict_inv = dict(zip(ds_list_inv, cl_list_inv))
    label_dict_inv.update({label_col: label_cl_dict_inv})
display(label_dict)
print()
display(label_dict_inv)

{'cl3': {np.int64(0): ('healthy',),
  np.int64(1): ('stable',),
  np.int64(2): ('unstable', 'very unstable')}}




{'cl3': {'healthy': 0, 'stable': 1, 'unstable': 2, 'very unstable': 2}}

### Summary statistics for the test data set ###

In [19]:
# Image data used for these models
dset_list = sorted(list(test_df.get('dset').unique()))
for dset in dset_list:
    df_dset = test_df.loc[test_df['dset'] == dset]
    dset_n_images = len(df_dset['image'].unique())
    print(f'{dset.upper()}: {dset_n_images} IMAGES')
    disease_stat = patient_image_stat(df_dset, file_col='image', group_col='disease')
    display(disease_stat)
    for label_col in ['cl3']:
        cl_stat = patient_image_stat(df_dset, file_col='image', group_col=label_col)
        print(f'{dset.upper()} IMAGES FOR {label_col.upper()} MODEL')
        display(cl_stat)

TEST: 614 IMAGES


Unnamed: 0,disease,n_patients,n_images,images_per_patient
0,healthy,20,120,6
1,stable,25,283,11
2,unstable,16,133,8
3,very unstable,4,78,20


TEST IMAGES FOR CL3 MODEL


Unnamed: 0,cl3,n_patients,n_images,images_per_patient
0,0,20,120,6
1,1,25,283,11
2,2,16,211,13


### Summary output for this notebook ###

In [23]:
# This cell should run without errors.
if image_extract_dir is not None:
    print(f'Image files verified: {len(file_list_verified)}')
    print(f'Image file location:  {image_dir}')
    print(f'Label file:           {df_file}')
if model_dir is not None:
    print(f'Checkpoint:     {checkpoint_file}')
    print(f'Log file:       {log_file}')

Image files verified: 614
Image file location:  /home/andreas/data/dcmdata/periodata_test/test
Label file:           /home/andreas/data/dcmdata/periodata_test/test/perimodel_labels_test.parquet
Checkpoint:     /home/andreas/data/dcmdata/periomodel_checkpoint_1400/periomodel_checkpoint_1400.ckpt
Log file:       /home/andreas/data/dcmdata/periomodel_checkpoint_1400/train.log
