# Dataset Exploration

The aim of this notebook is to explore all the data, labels and metadata available in the dataset, in order to gather the insights needed to improve data collection and preprocessing.

## Imports

In [None]:
import os
import multiprocessing
from collections import Counter
import ast

import numpy as np
import pandas as pd
from tqdm import tqdm
from matplotlib import pyplot as plt

from lib.plot_utils import create_histogram
from lib.image_processing import load_numpy_data, get_stats
from lib.data_processing import get_labels_from_str

## Config

In [None]:
# Path to the folder with the subjects folders extracted
subjects_path = "../../../datasets/BIMCV-COVID19-cIter_1_2/covid19_posi/"

# Path to the TSV with all the image file paths by subject and session
partitions_tsv_path = os.path.join(subjects_path, "derivatives/partitions.tsv")

# Path to the TSV with the main labels (and medical report) for each session
labels_tsv_path = os.path.join(subjects_path, "derivatives/labels/labels_covid_posi.tsv")

# Path to the TSV with the results of the COVID tests for each subject
tests_tsv_path = os.path.join(subjects_path, "derivatives/EHR/sil_reg_covid_posi.tsv")

# Maximum number of cores to use during some exploration steps
# (used as the size of the multiprocessing pool that computes the image stats)
max_cores = multiprocessing.cpu_count()

## Load TSV of the dataset images

This TSV contains for each unique pair of subject ID and session ID the corresponding path to the image of that session. The path is relative to the folder containing all the subjects folders. The name of this folder should be "covid19_posi".

In [None]:
# Read the partitions TSV (one row per image). The header row of the file is
# consumed (header=0) and replaced by the fixed column names below.
cols = ["subject", "session", "filepath"]
part_df = pd.read_csv(
    partitions_tsv_path,
    names=cols,
    header=0,
    sep="\t",
)
part_df.head()

In [None]:
# Grouped views reused by the counting cells below:
# one groups the image rows per subject, the other per session
part_groupby_sub = part_df.groupby(by=["subject"])
part_groupby_sess = part_df.groupby(by=["session"])

## Dataset counts

In [None]:
# Global counts taken from the partitions table (one row per image)
n_subjects = part_df["subject"].nunique()
print(f"Number of subjects: {n_subjects}")

n_sessions = part_df["session"].nunique()
print(f"Number of sessions: {n_sessions}")

n_images = part_df["filepath"].size
print(f"Number of images: {n_images}")

### Check if the sessions ids are unique

In [None]:
# If a session ID were shared by several subjects, there would be more
# (subject, session) pairs than distinct session IDs
sub_sess_groups = part_df.groupby(["subject", "session"])
n_sub_sess_pairs = sub_sess_groups.ngroups

print(f"Total sessions: {n_sessions}")
print(f"Total unique pairs of subject & session: {n_sub_sess_pairs}")
print(f"Are sessions IDs unique? {n_sessions == n_sub_sess_pairs}")

### Sessions by subject

In [None]:
# Number of distinct sessions that each subject has
sess_count_by_sub = part_groupby_sub["session"].nunique().to_numpy()

# Summary statistics, one per output line
print(
    "Count of sessions by subject:",
    f" - mean: {sess_count_by_sub.mean():.2f}",
    f" - median: {np.median(sess_count_by_sub)}",
    f" - max: {sess_count_by_sub.max()}",
    f" - min: {sess_count_by_sub.min()}",
    sep="\n",
)

create_histogram(
    data=sess_count_by_sub,
    title="Count of sessions by subject",
    xlabel="Sessions",
    ylabel="Subjects",
)

### Images by subject

In [None]:
# Number of image rows (non-null file paths) that each subject has
images_count_by_sub = part_groupby_sub["filepath"].count().to_numpy()

# Summary statistics, one per output line
print(
    "Count of images by subject:",
    f" - mean: {images_count_by_sub.mean():.2f}",
    f" - median: {np.median(images_count_by_sub)}",
    f" - max: {images_count_by_sub.max()}",
    f" - min: {images_count_by_sub.min()}",
    sep="\n",
)

create_histogram(
    data=images_count_by_sub,
    title="Count of images by subject",
    xlabel="Images",
    ylabel="Subjects",
)

### Images by session

In [None]:
# Number of image rows (non-null file paths) that each session has
images_count_by_sess = part_groupby_sess["filepath"].count().to_numpy()

# Summary statistics, one per output line
print(
    "Count of images by session:",
    f" - mean: {images_count_by_sess.mean():.2f}",
    f" - median: {np.median(images_count_by_sess)}",
    f" - max: {images_count_by_sess.max()}",
    f" - min: {images_count_by_sess.min()}",
    sep="\n",
)

create_histogram(
    data=images_count_by_sess,
    title="Count of images by session",
    xlabel="Images",
    ylabel="Sessions",
)

## Images analysis

### Count of images views

Here we count how many images are anterior-posterior (AP) or posterior-anterior (PA). These views are the ones we are interested in.

In [None]:
img_files = part_df["filepath"].values  # Relative paths of all the images

# Substrings of the file path that identify the Anterior-Posterior (AP)
# and Posterior-Anterior (PA) views
frontal_tags = ("vp-ap", "vp-pa")

# Split the paths by view, keeping the original order in both lists
ap_pa_img_files = [f_name for f_name in img_files
                   if any(tag in f_name for tag in frontal_tags)]
not_ap_img_files = [f_name for f_name in img_files
                    if not any(tag in f_name for tag in frontal_tags)]

print(f"AP and PA images count: {len(ap_pa_img_files)}")
print(f"Other views count: {len(not_ap_img_files)}")

### Check images file extensions

In [None]:
# For each group of views: count the files per extension and warn about
# any file that is neither a PNG nor a NIfTI archive
for view, images in (("AP/PA", ap_pa_img_files), ("other", not_ap_img_files)):
    count_png = sum(1 for filepath in images if filepath.endswith("png"))
    count_nii = sum(1 for filepath in images if filepath.endswith("nii.gz"))

    # Warnings are emitted in file order, before the summary of the group
    for filepath in images:
        if not filepath.endswith(("png", "nii.gz")):
            print(f"Warning: Unexpected image extension in {filepath}")

    print(f"Images file extensions in {view} views:")
    print(f" - png: {count_png}")
    print(f" - nii.gz: {count_nii}")

### Explore pixels data

In the next cell we load all the images corresponding to the views AP and PA to extract some information about them.

### WARNING: The execution of this cell can take up to one hour to complete.

In [None]:
# Prefix every AP/PA relative path with the dataset folder
img_full_paths = []
for img_file in ap_pa_img_files:
    img_full_paths.append(os.path.join(subjects_path, img_file))

# Run get_stats on every image with a pool of worker processes;
# imap keeps the results in the same order as the input paths
with multiprocessing.Pool(max_cores) as pool:
    stats_iter = pool.imap(get_stats, img_full_paths)
    stats = list(tqdm(stats_iter, total=len(img_full_paths)))

# Transpose the list of 5-field results into one sequence per field
means, stds, maxs, mins, shapes = zip(*stats)

Show a histogram with the means extracted from each image over all its pixels

In [None]:
# Distribution of the first field of each get_stats result (one value per AP/PA image)
create_histogram(data=means,
                 title="Pixels means for each image",
                 ylabel="Images",
                 xlabel="Pixels mean",
                 bins=300,
                 fig_size=(8, 5))

In [None]:
# Distribution of the second field of each get_stats result (one value per AP/PA image)
create_histogram(data=stds,
                 title="Pixels stds for each image",
                 ylabel="Images",
                 xlabel="Pixels std",
                 bins=300,
                 fig_size=(8, 5))

Show a histogram with the maximum pixels values extracted from each image

In [None]:
# Distribution of the third field of each get_stats result (one value per AP/PA image)
create_histogram(data=maxs,
                 title="Pixels maximums for each image",
                 ylabel="Images",
                 xlabel="Maximum pixel",
                 bins=300,
                 fig_size=(8, 5))

Show a histogram with the minimum pixels values extracted from each image

In [None]:
# Distribution of the fourth field of each get_stats result (one value per AP/PA image)
create_histogram(data=mins,
                 title="Pixels minimums for each image",
                 ylabel="Images",
                 xlabel="Minimum Pixel",
                 bins=300,
                 fig_size=(8, 5))

Count different images shapes

In [None]:
print("Images shapes count:")

# Pair every distinct shape with its number of occurrences, most frequent first
shape_count_pairs = sorted(Counter(shapes).items(), key=lambda x: x[1], reverse=True)

# Each pair is (shape_tuple, count), i.e. a ragged nested sequence. Passing it
# to np.array() raises ValueError on NumPy >= 1.24, so the (n_shapes, 2) object
# array is filled explicitly: column 0 holds the shape, column 1 the count.
sorted_counts = np.empty((len(shape_count_pairs), 2), dtype=object)
for row, (shape, count) in enumerate(shape_count_pairs):
    sorted_counts[row, 0] = shape
    sorted_counts[row, 1] = count

for shape, count in sorted_counts:
    print(f"{shape}: {count}")

Scatter plot to view the general shape distributions. To create this plot we take into account the top "N" most frequent shapes.

In [None]:
# Select the top 40 most frequent shapes
N = 40
# A plain slice works both for a list of (shape, count) pairs and for a 2-D array
selected_counts = sorted_counts[:N]

# Split every (shape, count) row into plain Python values, so the plot does
# not depend on object-dtype array indexing
top_shapes = [tuple(shape) for shape, _ in selected_counts]
top_counts = [int(count) for _, count in selected_counts]

# NOTE(review): assumes every shape is a (height, width) pair - confirm against get_stats
heights, widths = zip(*top_shapes)  # Split height and width values

colors = np.asarray(top_counts)  # Get color from the count number (numeric dtype)
area = 300  # Size of the circles

fig, ax = plt.subplots(figsize=(9, 8), dpi=100)
points = ax.scatter(widths, heights, s=area, c=colors, cmap="Accent", alpha=0.6)
ax.grid(True)
cbar = fig.colorbar(points, ax=ax)
cbar.ax.set_ylabel("Images Count")
ax.set_title("Images shapes distribution")
ax.set_ylabel("Image Height")
ax.set_xlabel("Image Width")
plt.show()

## Labels Analysis

### Load TSV with labels by session

In [None]:
# Load the per-session labels table; the cells below read its
# "PatientID", "ReportID" and "Labels" columns
labels_df = pd.read_csv(labels_tsv_path, sep="\t")
labels_df.head()

### Look for nan values

In [None]:
# Number of missing values in every column of the labels table
print("NaNs count by column:")
print(labels_df.isna().sum())

### Check subjects and sessions IDs

In [None]:
# Counts taken from the labels table. They use their own names so the
# n_subjects / n_sessions values computed from the partitions table are not
# overwritten (re-running the earlier cells that print them stays correct).
n_labels_subjects = labels_df["PatientID"].nunique()
n_labels_sessions = labels_df["ReportID"].nunique()
print(f"Number of subjects: {n_labels_subjects}")
print(f"Number of sessions: {n_labels_sessions}")

# Compare the sets of unique session IDs of each DataFrame.
# Both differences count session IDs, not individual image files.
imgs_df_sess = set(part_df["session"])
labels_df_sess = set(labels_df["ReportID"])
sess_diff1 = len(imgs_df_sess - labels_df_sess)  # sessions present only in the images table
sess_diff2 = len(labels_df_sess - imgs_df_sess)  # sessions present only in the labels table
print(f"Number of sessions with images but without labels: {sess_diff1}")
print(f"Number of sessions with labels but without images: {sess_diff2}")

### Count labels

Note: Each value in the "Labels" column is a string that represents a Python list
      containing the label names of the session.

In [None]:
# Sessions without labels get the explicit label NONE
# (written as the string form of a one-element list, like the other rows)
labels_df["Labels"] = labels_df["Labels"].fillna('["NONE"]')

# Accumulate how many sessions carry each individual label
labels_counter = Counter()
for labels_str in labels_df["Labels"]:
    labels_list = get_labels_from_str(labels_str, verbose=True)  # str -> list of labels
    for label in labels_list:
        labels_counter[label] += 1


Show counts results

In [None]:
print(f"Total number of unique labels: {len(labels_counter)}")

print("\nLabels count (sorted):")
# most_common() lists the labels from the most to the least frequent
for label, count in labels_counter.most_common():
    print(f'"{label}": {count}')

### Count sets of labels

In [None]:
# Combinations of labels to look for in every session
labels_sets = [["COVID 19", "pneumonia"], ["COVID 19", "infiltrates"], ["pneumonia", "infiltrates"],
               ["COVID 19"], ["pneumonia"], ["infiltrates"], ["normal"]]
labels_sets_counter = Counter()  # Number of sessions that contain each labels set

for labels_str in labels_df["Labels"]:
    labels_list = get_labels_from_str(labels_str, verbose=True)  # str -> list of labels

    # Keep the sets that are fully contained in the labels of the session
    # and count each of them under its " + "-joined name
    matched_sets = [" + ".join(l_set) for l_set in labels_sets
                    if set(l_set).issubset(labels_list)]
    labels_sets_counter.update(matched_sets)

In [None]:
print("\nLabels sets count (sorted):")
# most_common() lists the labels sets from the most to the least frequent
for label, count in labels_sets_counter.most_common():
    print(f'"{label}": {count}')

## Analyze COVID tests labels

### Load TSV with COVID tests results

In [None]:
# Load the COVID tests table; the cells below read its
# "test", "result", "date" and "participant" columns
tests_df = pd.read_csv(tests_tsv_path, sep="\t")
tests_df.head()

Look for nan values

In [None]:
# Number of missing values in every column of the tests table
print("NaNs count by column:")
print(tests_df.isna().sum())

### Count tests types

In [None]:
# Occurrences of every distinct value of the "test" column
test_types_counts = tests_df["test"].value_counts()

print("Test types count:")
print(test_types_counts)

### Count tests results

In [None]:
# Occurrences of every distinct value of the "result" column
test_res_counts = tests_df["result"].value_counts()

print("Test results count:")
print(test_res_counts)

## Use COVID tests to improve data labeling

For each subject we compare the labels of each session with the results available from the COVID tests.
If we detect that a session does not have the 'COVID 19' label but the subject has a POSITIVE COVID test with approximately the same date, the session is a candidate to receive the 'COVID 19' label. The processing cell in this section logs and counts these sessions.

### Prepare data

In [None]:
# Get only the positive tests (copy, so the new "date" column does not touch tests_df)
posi_tests = tests_df[tests_df["result"] == "POSITIVO"].copy()

# Convert dates from strings to datetime objects
posi_tests["date"] = pd.to_datetime(posi_tests["date"], format="%d.%m.%Y")

# Group tests by subject ID
#  - Note: A subject can have several tests
#  - The column is passed as a plain string: with a one-element list,
#    pandas >= 2.0 yields 1-tuples as group keys when iterating, and the
#    processing cell needs the subject ID itself to build file paths
posi_tests_by_sub = posi_tests.groupby("participant")

# Group the sessions labels by subject ID (string key for the same reason,
# so get_group() can be called with the plain subject ID)
labels_by_sub = labels_df.groupby("PatientID")

### Configure parameters

In [None]:
# Range of days before and after a COVID test in which the test result
# is taken as valid to label a session
valid_prev_days = 0  # Number of days before the COVID test
valid_post_days = 0  # Number of days after the COVID test

# Log level: 0 prints only errors, any value > 0 prints the full logs
verbose = 1

### Process tests data

In [None]:
# Number of sessions that qualify for the extra 'COVID 19' label.
# NOTE: this cell only logs and counts them; labels_df is not modified.
fixed_labels = 0

# Sessions with at least one of these labels are never changed (use [] to disable)
labels_to_avoid = ["exclude", "normal"]
# Only sessions with at least one of these labels are candidates (use [] to disable)
mandatory_labels = ["COVID 19 uncertain", "infiltrates", "pneumonia"]

# Iterate over the groups of tests of each subject
#  - IMPORTANT: "posi_tests_by_sub" only contains positive tests
for sub_id, sub_tests in posi_tests_by_sub:
    # pandas >= 2.0 yields 1-tuples as keys when the groupby was built from a
    # one-element list; unwrap them so the ID can be used in paths and filters
    if isinstance(sub_id, tuple):
        sub_id = sub_id[0]

    # Load subject sessions data
    sub_sessions_tsv = os.path.join(subjects_path, sub_id, f"{sub_id}_sessions.tsv")
    # Check if the data exists
    if not os.path.isfile(sub_sessions_tsv):
        print("---------------------------------------------------")
        if not os.path.isdir(os.path.dirname(sub_sessions_tsv)):
            print(f'| Error: Missing subject directory of "{sub_id}"')
        else:
            print(f'| Error: Missing file "{sub_sessions_tsv}"')
        print("---------------------------------------------------")
        continue  # skip the subject
    sub_sessions_df = pd.read_csv(sub_sessions_tsv, sep="\t")

    # Convert sessions dates from strings to datetime objects
    sub_sessions_df["study_date"] = pd.to_datetime(sub_sessions_df["study_date"], format="%Y%m%d")

    # Get the labels of every session of the subject. A boolean filter is used
    # instead of get_group() so a subject without labels is reported and
    # skipped instead of raising a KeyError.
    sub_sessions_labels = labels_df.loc[labels_df["PatientID"] == sub_id, ["ReportID", "Labels"]]
    if sub_sessions_labels.empty:
        print("---------------------------------------------------")
        print(f'| Error: Missing labels for subject "{sub_id}"')
        print("---------------------------------------------------")
        continue  # skip the subject

    # Compare the labels of each session with the COVID tests
    for _, sess_row in sub_sessions_df.iterrows():
        sess_id = sess_row["session_id"]
        sess_date = sess_row["study_date"]

        # Get the list of labels of the current session
        sess_labels_row = sub_sessions_labels[sub_sessions_labels["ReportID"] == sess_id]
        if sess_labels_row.empty:
            # The session has no row in the labels table: nothing to compare
            print("---------------------------------------------------")
            print(f'| Error: Missing labels for session "{sess_id}" of subject "{sub_id}"')
            print("---------------------------------------------------")
            continue  # skip the session
        labels_str = sess_labels_row["Labels"].iloc[0]
        labels_list = get_labels_from_str(labels_str, verbose=True)  # Get a list object from the str

        # Skip sessions with the COVID label (Nothing to change here)
        if 'COVID 19' in labels_list:
            continue
        # Skip sessions with at least one of the labels to avoid
        if any(label in labels_to_avoid for label in labels_list):
            continue
        # Skip sessions without at least one mandatory label
        if len(mandatory_labels) and not any(label in mandatory_labels for label in labels_list):
            continue

        # Look if any of the tests can affect the session labels (by time difference)
        for test_date in sub_tests["date"]:
            # Compute time difference in days (negative: session before the test)
            days_diff = (sess_date - test_date).days

            # Check if is a valid difference to fix the label
            if -valid_prev_days <= days_diff <= valid_post_days:
                if verbose:
                    print(f"FIX: Sess ID: {sess_id} Days diff: {days_diff} - session labels: {labels_list}")
                fixed_labels += 1
                break  # Don't look for more tests

In [None]:
# fixed_labels is the number of sessions that matched a positive test within the
# valid range of days; the processing cell counts them without editing labels_df
print(f"Sessions labels modified adding 'COVID 19': {fixed_labels}")