In [None]:
#@title Define if we are on Colab and mount drive { display-mode: "form" }
# Detect Colab by attempting the Colab-only import. Only ImportError means
# "not on Colab": a bare `except:` would also hide Drive-mount failures and
# swallow KeyboardInterrupt, silently reporting IN_COLAB = False.
try:
  from google.colab import drive
except ImportError:
  IN_COLAB = False
else:
  # Mount errors now surface here instead of being mistaken for "not on Colab"
  drive.mount('/content/gdrive')
  IN_COLAB = True

In [None]:
#@title (COLAB ONLY) Clone GitHub repo { display-mode: "form" }

# On Colab, fetch the project repository and make its root the working directory.
# IN_COLAB is set by the first cell of the notebook.
# NOTE(review): `git clone` fails if the folder already exists (e.g. when re-running
# this cell) while `%cd` still runs — confirm that is acceptable.
if IN_COLAB:
  !git clone https://github.com/lluissalord/radiology_ai.git

  %cd radiology_ai

In [None]:
#@title Setup environment and Colab general variables { display-mode: "form" }
# %%capture
# Runs the cells of colab_setup.ipynb from this notebook.
# NOTE(review): names used later but never defined here (e.g. PATH_PREFIX, models_folder,
# raw_preprocess_folder, KNEE_SVM_MODEL_PATH) presumably come from this notebook or from
# move_raw_preprocess.ipynb — confirm.
%run colab_setup.ipynb

In [None]:
#@title Move images from Drive to temporary folder here to be able to train models { display-mode: "form" }
# %%capture
# Runs the cells of move_raw_preprocess.ipynb from this notebook.
# NOTE(review): presumably this is what fills the folder referenced later as
# `raw_preprocess_folder` — confirm.
%run move_raw_preprocess.ipynb

In [None]:
import os
import pandas as pd
import random
from tqdm import tqdm

from fastai.basics import *
from fastai.callback import *
from fastai.vision.all import *
from fastai.vision.widgets import *
from fastai.medical.imaging import *

In [None]:
from utils import *
from preprocessing import *

from APL_losses import *

In [None]:
# from fastprogress.fastprogress import NBMasterBar, NBProgressBar

# master_bar, progress_bar = NBMasterBar, NBProgressBar

In [None]:
# ---- Notebook configuration: every tunable value lives in this cell ----

# Image sizes: TRAIN_RESIZE is passed as `resize` to KneeLocalizer; the other two
# are the size and `min_scale` given to RandomResizedCropGPU
TRAIN_RESIZE = 256
RANDOM_RESIZE_CROP = 256
RANDOM_MIN_SCALE = 0.5

# Batch size of the training DataLoaders and of the inference DataLoaders
BATCH_SIZE = 32

# N_TRAIN: None keeps every filtered file, otherwise the number of files randomly drawn
# N_SAMPLES_BIN: `n_samples` given to init_bins (only used when HIST_SCALED is active)
N_TRAIN = None
N_SAMPLES_BIN = 50 # None

# Histogram clipping (XRayPreprocess) and its `cut_min` / `cut_max` arguments
HIST_CLIPPING = True
HIST_CLIPPING_CUT_MIN = 5.
HIST_CLIPPING_CUT_MAX = 99.

# Optional item transforms. CLAHE_SCALED takes precedence over HIST_SCALED (if/elif),
# and HIST_SCALED_SELF makes HistScaled_all receive bins=None instead of computed bins
KNEE_LOCALIZER = True
CLAHE_SCALED = True
HIST_SCALED = False
HIST_SCALED_SELF = False

# Whether to load a previously saved model before training / save the model after training
USE_SAVED_MODEL = False
SAVE_MODEL = False

# MODEL names the saved file and is used to rebuild the model when USE_SAVED_MODEL is set
# NOTE(review): `models_folder` is not defined in this notebook; presumably it comes from
# one of the notebooks executed with %run — confirm.
MODEL = resnet18
MODEL_VERSION = 0
MODEL_SAVE_NAME = f'{MODEL.__name__}_v{MODEL_VERSION}.pkl'
MODEL_SAVE_PATH = os.path.join(models_folder, MODEL_SAVE_NAME)

In [None]:
# Metadata rules per label: each key is a label and its value is handed to
# `df_check_DICOM` together with the metadata DataFrame.
# NOTE(review): presumably a row matches when each DICOM field takes one of the listed
# values and 'function' returns True for the row — confirm against df_check_DICOM.
# Only 'ap' is active; the commented-out entries are other label definitions kept for reference.
all_check_DICOM_dict = {
    'ap': {
        'Modality': ['CR', 'DR', 'DX'],
        'SeriesDescription': ['RODILLA AP', 'TIBIA AP DIRECTO', 'Rodilla AP', 'rodilla AP', 'W098bDER Rodilla a.p.', 'T098aDER Rodilla a.p.', 'rodilla 1P AP', 'xeonllo DEREITO AP', 'xeonllo ESQUERDO AP'],
        'BodyPartExamined': ['LOWER LIMB', 'KNEE'],
        # Row predicate on the image shape: Rows / Columns ratio of at least 0.83
        'function': lambda row: row.Rows/row.Columns >= 0.83,
    },
    # 'lat': {
    #     'Modality': ['CR', 'DR', 'DX'],
    #     'SeriesDescription': ['RODILLA LAT', 'TIBIA LAT DIRECTO', 'RODILLA LAT EN CARGA', 'T Rodilla lat', 'rodilla LAT', 'rodilla  LAT', 'W098bDER Rodilla lat.', 'T098aDER Rodilla lat', 'rodilla 1P LAT', 'xeonllo DEREITO LAT', 'xeonllo ESQUERDO LAT', 'TOBILLO EN CARGA LAT', 'PIE LAT EN CARGA', 'rodilla LAT dcha', 'rodilla LAT izda'],
    #     'BodyPartExamined': ['LOWER LIMB', 'KNEE']
    # },
    # 'two': {
    #     'Modality': ['CR', 'DR', 'DX'],
    #     'SeriesDescription': ['RODILLAS AP', 'rodilla AP y LAT', 'ambas rodillas AP', 'ambas rodillas LAT', 'rodilla (telemando) AP y LAT', 'rodilla AP y LAT', 'Rodillas LAT', 'Rodilla AP y LAT', 'ambolos dous xeonllos AP', 'ambolos dous xeonllos LAT', 'rodilla seriada', 'Rodillas AP', 'Rodillas LAT'],
    #     'BodyPartExamined': ['LOWER LIMB', 'KNEE']
    # },
    # 'other': {
    #     'Modality': ['CR', 'DR', 'DX'],
    #     'BodyPartExamined': ['THORAX', 'UPPER LIMB', 'KNEE STANDING',
    #    'RIBS', 'HAND', 'HIP', 'PIE EN CARGA', 'FOOT', 'ANKLE',
    #    'ELBOW', 'PELVIS', 'LSPINE', 'CSPINE']
    # }
}

# Class names: every label defined above plus the catch-all 'other'
targets = list(all_check_DICOM_dict.keys()) + ['other']

In [None]:
metadata_raw_path = os.path.join(PATH_PREFIX, 'metadata_raw.csv')
metadata_df = pd.read_csv(metadata_raw_path)


def _preprocessed_png_path(original_fname):
    "Path of the '.png' named after `original_fname` inside `raw_preprocess_folder`, with '/' separators"
    png_name = os.path.split(original_fname)[-1] + '.png'
    joined = os.path.join(raw_preprocess_folder, png_name)
    return os.path.normpath(joined).replace(os.sep, '/')


# Point every metadata row at its PNG in the preprocess folder instead of the listed file
metadata_df.fname = metadata_df.fname.apply(_preprocessed_png_path)

metadata_labels_path = os.path.join(PATH_PREFIX, 'metadata_labels.csv')
metadata_labels = pd.read_csv(metadata_labels_path)

# Rows without a probability are the ones taken as already reviewed; index them by file name
has_no_prob = metadata_labels['Prob'].isnull()
reviewed_labels = (
    metadata_labels[has_no_prob]
    .rename({'Path': 'fname'}, axis=1)
    .set_index('fname')
)

# Column to use as the prediction: 'Final_pred' when the file has it, 'Pred' otherwise
pred_col = 'Final_pred' if 'Final_pred' in reviewed_labels.columns else 'Pred'

# Initialize lists containing the filenames for each class
# (maps label -> L of file names; rebuilt from scratch every time this cell runs)
all_fnames = {}

# For every label: start from the metadata matches, drop the files that a review assigned
# to a different label, then add the files that a review assigned to this label.
for label, check_DICOM_dict in all_check_DICOM_dict.items():
    # Check DICOM which according to the metadata should be that label
    match_df = df_check_DICOM(metadata_df, check_DICOM_dict)
    
    # Remove cases that have been reviewed and selected as DIFFERENT from the current label
    # (left merge against the reviews with another label: rows that found no such review
    # keep a null `pred_col` and are the only ones retained)
    match_df = match_df.merge(reviewed_labels[reviewed_labels[pred_col] != label], how='left', left_on='fname', right_index=True)
    match_df = match_df[match_df[pred_col].isnull()]

    # Add cases that have been reviewed and selected as EQUAL from the current label
    # (reviewed rows go first, so drop_duplicates keeps them over a metadata match of the same file)
    match_df = pd.concat(
        [
            reviewed_labels[reviewed_labels[pred_col] == label].reset_index(),
            match_df
        ]
    ).drop_duplicates('fname').reset_index(drop=True)
    all_fnames[label] = L(list(match_df.fname))

# Every file named in the metadata is a candidate, minus the undesired files
raw_fnames = L(filter_fnames(L(list(metadata_df.fname)), metadata_raw_path))

# Files already claimed by any of the explicit labels
labelled_fnames = set()
for claimed_fnames in all_fnames.values():
    labelled_fnames |= set(claimed_fnames)

# Whatever remains (compared using '/' separators) is labelled as other
normalized_raw_fnames = set(raw_fnames.map(lambda path: str(path).replace(os.sep, '/')))
other_fnames = L(normalized_raw_fnames - labelled_fnames)

# Filter again so that no undesired file ends up in the other class
other_fnames = filter_fnames(other_fnames, metadata_raw_path)
all_fnames['other'] = L(other_fnames)

# Select the corresponding part for training
if N_TRAIN is None:
    fnames = raw_fnames
else:
    # Draw the subset WITHOUT replacement: random.choices samples with replacement, so the
    # same image could be selected several times. k is capped at the population size
    # because random.sample raises ValueError when asked for more items than available.
    fnames = random.sample(list(raw_fnames), k=min(N_TRAIN, len(raw_fnames)))

In [None]:
# Class names, followed by the number of files assigned to each class
print(targets)
[len(class_fnames) for class_fnames in all_fnames.values()]

In [None]:
# Oversampling is disabled for now: it caused issues when relabeling data and, with only 2 labels, the classes are already reasonably balanced

# # Oversampling of all classes to meet the biggest one or reach max_n_times its own size
# max_samples = max([len(fnames) for _, fnames in all_fnames.items()])
# max_n_times = 4
# for label, fnames in all_fnames.items():
#     k = min(max_samples-len(fnames), max_n_times * len(fnames))
#     all_fnames[label] = all_fnames[label] + random.choices(all_fnames[label], k=k)

# print(targets)
# [len(all_fnames[label]) for label in all_fnames]

In [None]:
# One small frame per class (file name + class name), stacked into a single labelled table
labels_concat = [
    pd.DataFrame(list(class_fnames), columns=['fname']).assign(Target=class_name)
    for class_name, class_fnames in all_fnames.items()
]

# Indexed by file name, while keeping `fname` available as a regular column too
labels_df = pd.concat(labels_concat).set_index('fname', drop=False)

In [None]:
# Item transforms are chained: a stage is told to expect a numpy input only when
# another stage has already been added before it
item_tfms = []

if HIST_CLIPPING:
    item_tfms += [
        XRayPreprocess(
            PIL_cls=PILImageBW,
            cut_min=HIST_CLIPPING_CUT_MIN,
            cut_max=HIST_CLIPPING_CUT_MAX,
            np_input=len(item_tfms) > 0,
            np_output=True,
        )
    ]

if KNEE_LOCALIZER:
    item_tfms += [
        KneeLocalizer(
            KNEE_SVM_MODEL_PATH,
            PIL_cls=PILImageBW,
            resize=TRAIN_RESIZE,
            np_input=len(item_tfms) > 0,
            np_output=True,
        )
    ]

# Batch transforms, in order: flip, default augmentations (zero padding), random resized crop, normalization
batch_tfms = [Flip()]
batch_tfms += aug_transforms(pad_mode=PadMode.Zeros)
batch_tfms += [
    RandomResizedCropGPU(RANDOM_RESIZE_CROP, min_scale=RANDOM_MIN_SCALE),
    Normalize(),
]

In [None]:
# Histogram scaling DICOM on the fly
# Appends the last item transform: CLAHE when CLAHE_SCALED, otherwise histogram scaling
# when HIST_SCALED. Re-running this cell appends the transform again to `item_tfms`.

if CLAHE_SCALED:
    item_tfms.append(CLAHE_Transform(PIL_cls=PILImageBW, grayscale=True, np_input=len(item_tfms) > 0, np_output=False))
elif HIST_SCALED:
    if HIST_SCALED_SELF:
        # No shared bins are computed; HistScaled_all receives bins=None
        bins = None
    else:
        # bins = init_bins(fnames=L(list(final_df['Original'].values)), n_samples=100)
        # The previous `all_valid_raw_preprocess` line was removed: it referenced
        # `unlabel_all_df`, which is not defined in this notebook (NameError), and its
        # result was never used.
        every_fname = L([fname for class_fnames in all_fnames.values() for fname in class_fnames])
        bins = init_bins(every_fname, n_samples=N_SAMPLES_BIN, isDCM=False)
    # item_tfms.append(HistScaled(bins))
    item_tfms.append(HistScaled_all(bins))

In [None]:
# Grayscale image -> category pipeline reading both columns of `labels_df`,
# holding out a random 20% of the rows for validation
classification_block = DataBlock(
    blocks=(ImageBlock(PILImageBW), CategoryBlock),
    get_x=ColReader('fname'),
    get_y=ColReader('Target'),
    splitter=RandomSplitter(0.2),
    item_tfms=item_tfms,
    batch_tfms=batch_tfms,
)

dls = classification_block.dataloaders(
    labels_df,
    bs=BATCH_SIZE,
    num_workers=0,
    shuffle_train=True,
    drop_last=True,
)

# Visual sanity check of one augmented batch
dls.show_batch(max_n=25, cmap=plt.cm.bone)

In [None]:
# Fallback value, kept so that commenting out the assignment below passes loss_func=None
# to the learner (NOTE(review): presumably fastai then picks its default loss — confirm)
loss_func = None

# Loss actually used for training.
# NOTE(review): NCEandRCE presumably comes from `APL_losses`; the meaning of the two `1`
# arguments is not visible here, the third one is the number of classes — confirm.
loss_func = NCEandRCE(1, 1, len(targets))

In [None]:
# Define the callbacks that will be used during training
callback_fns = [
    MixUp(),
    # partial(OverSamplingCallback),
    # ShowGraphCallback(),
    EarlyStoppingCallback(monitor='val_loss', min_delta=0.05, patience=2),
]

# Adapt metrics depending on the number of labels
if len(targets) == 2:
    average = 'binary'
    roc_auc = RocAucBinary()
else:
    average = 'macro'
    roc_auc = RocAuc()

f1_score = F1Score(average=average)
precision = Precision(average=average)
recall = Recall(average=average)

# Build the learner from the configured MODEL instead of a hard-coded resnet18, so that the
# trained architecture always agrees with MODEL_SAVE_NAME and with the model rebuilt by
# create_model(MODEL, ...) when USE_SAVED_MODEL is set.
# NOTE(review): fastai v2 documents the callbacks argument as `cbs`; `callback_fns` is the
# fastai v1 name. Confirm this keyword is accepted by the installed fastai version and that
# the callbacks above are really attached to the learner.
learn = cnn_learner(
    dls,
    MODEL,
    metrics=[
        error_rate,
        roc_auc,
        f1_score,
        precision,
        recall
    ],
    loss_func=loss_func,
    callback_fns=callback_fns,
    # Single input channel, since the images are loaded as PILImageBW
    config={'n_in': 1}
)

# Regularization by using float precision of 16 bits
# This helps to not overfit because is more difficult to "memorize" images, but enough to learn
learn = learn.to_fp16()

In [None]:
# Optionally replace the freshly created model and optimizer with the saved ones
if USE_SAVED_MODEL:
    # Fresh model for MODEL with one output per class, plus a copy of the learner's optimizer
    # NOTE(review): `learn.opt` may still be unset at this point (no training has run) — confirm
    model_load = create_model(MODEL, len(targets))
    opt_load = copy(learn.opt)

    # Makes sure the folder exists; the file at MODEL_SAVE_PATH itself must already be there
    if not os.path.exists(models_folder):
        os.makedirs(models_folder)

    # Loads the saved state into `model_load` / `opt_load`, then plugs both into the learner
    load_model(file=MODEL_SAVE_PATH, model=model_load, opt=opt_load, device=0)
    learn.model = model_load
    learn.opt = opt_load

In [None]:
# Learning-rate finder; its result is not stored (the training cell uses a fixed learning rate)
learn.lr_find()

In [None]:
%%time
# 2 epochs with the pretrained body frozen, then 5 epochs unfrozen, base learning rate 0.025
learn.fine_tune(5, 0.025, freeze_epochs=2)

In [None]:
# Display up to 25 sample predictions of the trained learner
learn.show_results(max_n=25)

In [None]:
if SAVE_MODEL:
    # Save to MODEL_SAVE_PATH, the exact file that the USE_SAVED_MODEL cell loads. The model
    # used to be written to the relative 'models/' folder, which the load step only finds
    # when `models_folder` happens to be that same folder.
    os.makedirs(models_folder, exist_ok=True)
    save_model(file=MODEL_SAVE_PATH, model=learn.model, opt=learn.opt)

In [None]:
# Disabled: this uses too much RAM for the session to handle
# interp = Interpretation.from_learner(learn)
# losses, idx = interp.top_losses()
# interp.plot_top_losses(25, figsize=(15,10))

In [None]:
# Generating DataLoader and select the paths that will be used for inference from raw_fnames
raw_dls = DataBlock(
    blocks=(ImageBlock(PILImageBW)),
    get_x=ColReader('fname'),
    item_tfms=item_tfms,
).dataloaders(pd.DataFrame(list(raw_fnames),columns=['fname']), bs=BATCH_SIZE, num_workers=0, shuffle_train=True, drop_last=True)

# paths = raw_dls.train.items + raw_dls.valid.items
# Every raw file name, in the order of the train split followed by the valid split
paths = list(raw_dls.train.items.iloc[:,0].values) + list(raw_dls.valid.items.iloc[:,0].values)


def _first_target(path):
    "Current label of `path` in `labels_df`; the first one when `path` is indexed more than once"
    target = labels_df.loc[path, 'Target']
    # A repeated index makes `.loc` return a Series. Use `.iloc[0]`: a plain `[0]` on this
    # string-indexed Series relies on the deprecated positional fallback.
    return target.iloc[0] if isinstance(target, pd.Series) else target


# Single lookup per path (previously the same `.loc` lookup ran up to three times)
labels = [_first_target(path) for path in tqdm(paths, desc='Matching labels')]

# Add DataSet from paths to the Test set of the learner
dl = learn.dls.test_dl(paths)

# Calculate predictions and probabilities
preds, _ = learn.tta(dl=dl)
# preds, _ = learn.get_preds(dl=dl)
# Highest class probability of each image and the index of that class
max_probs, targs = preds.max(1)

In [None]:
# Minimum probability required to accept a prediction without manual review, depending on
# whether the prediction agrees ('Correct_label') or not ('Wrong_label') with the current label
class_threshold = {
    'Correct_label': 0.95,
    'Wrong_label': 0.95,
}

metadata_labels_path = os.path.join(PATH_PREFIX, 'metadata_labels.csv')
metadata_labels = pd.read_csv(metadata_labels_path)
reviewed_labels = metadata_labels[metadata_labels['Prob'].isnull()].rename({'Path': 'fname'}, axis=1)
reviewed_labels = reviewed_labels.set_index('fname')
# A path listed more than once would make `.loc` return several rows; keep the first one
reviewed_labels = reviewed_labels[~reviewed_labels.index.duplicated(keep='first')]

# The CSV written at the end of this cell has no 'Final_pred' column, so the reviewed value
# may live in 'Pred'. Reading 'Final_pred' unconditionally raised a KeyError that the old
# `except KeyError` mistook for "not reviewed", overwriting reviewed labels with model output.
reviewed_pred_col = 'Final_pred' if 'Final_pred' in reviewed_labels.columns else 'Pred'

data = {
    'Path': [],
    'Label': [],
    'Raw_pred': [],
    'Pred': [],
    'Prob': [],
}
# (path, predicted class, current label, probability) of every prediction below its threshold
to_be_reviewed = []
for label, prob, targ, path_str in tqdm(zip(labels, max_probs, targs, paths), total=len(labels)):
    path = Path(path_str)
    raw_pred = targets[targ]

    if path_str in reviewed_labels.index:
        # Already reviewed: keep the reviewed value and mark it with a missing probability
        pred = reviewed_labels.loc[path_str, reviewed_pred_col]
        prob = np.nan
    else:
        # Set prob and pred according to the thresholds
        prob = float(prob)

        # Agreement with the current label and disagreement use their own threshold
        threshold_key = 'Correct_label' if label == raw_pred else 'Wrong_label'
        if prob >= class_threshold[threshold_key]:
            pred = raw_pred
        else:
            pred = 'Unsure_' + raw_pred + '_' + label
            to_be_reviewed.append((path, raw_pred, label, prob))

    data['Path'].append(os.path.normpath(path).replace(os.sep, '/'))
    data['Label'].append(label)
    data['Raw_pred'].append(raw_pred)
    data['Pred'].append(pred)
    data['Prob'].append(prob)

# Overwrites metadata_labels.csv with one row per inferred image
df = pd.DataFrame(data)
df.to_csv(metadata_labels_path, index=False)

In [None]:
def _open_thumb(fn, h, w):
    "Open image `fn`, turn it into a thumbnail with `to_thumb(h, w)` and convert it to RGBA"
    thumbnail = Image.open(fn).to_thumb(h, w)
    return thumbnail.convert('RGBA')

class ImagesCleanerDefaultPred(ImagesCleaner):
    "An `ImagesCleaner` that captions every image and presets its `Dropdown` to the prediction"

    def set_fns(self, fns, preds, labels, probs):
        "Show at most `max_n` of `fns`, each captioned `pred/label/prob` with its `Dropdown` set to `pred`"
        self.fns = L(fns)[:self.max_n]

        # Thumbnails are opened one by one, in this process
        thumbnails = [_open_thumb(fn, h=self.height, w=self.width) for fn in self.fns]

        boxes = []
        for thumbnail, pred, label, prob in zip(thumbnails, preds, labels, probs):
            caption = Label(f'{pred}/{label}/{prob:.4f}')
            picture = widget(thumbnail, height=f'{self.height}px')
            chooser = Dropdown(options=self.opts, layout={'width': 'max-content'}, value=pred)
            boxes.append(VBox([caption, picture, chooser]))
        self.widget.children = boxes

    def values(self):
        "Selected `Dropdown` value of every displayed image (the `Dropdown` is the last child of each box)"
        dropdowns = L(self.widget.children).itemgot(-1)
        return dropdowns.attrgot('value')

In [None]:
# Check the unsure with lowest probability
# (first 100 rows whose Pred is not a plain class name and that still have a probability,
# ordered by raw prediction and then by ascending probability)
is_unsure = ~df['Pred'].isin(targets)
has_prob = df['Prob'].notnull()
df_to_review = (
    df[is_unsure & has_prob]
    .sort_values(['Raw_pred', 'Prob'])
    .iloc[:100]
)

# # Check the OTHER cases which the model is totally sure and are also confirmed by metadata
# df_to_review = df_rev[(df_rev['Label'] == 'other') & (df_rev['Pred'] == 'other') & (df['Prob'].notnull())].sort_values('Prob', ascending=False).iloc[:100]

# # Check the AP cases which the model is totally sure and are also confirmed by metadata
# df_to_review = df_rev[(df_rev['Label'] == 'ap') & (df_rev['Pred'] == 'ap') & (df['Prob'].notnull())].sort_values('Prob', ascending=False).iloc[:100]

# Review widget showing every selected row, each dropdown preset to the raw prediction
w = ImagesCleanerDefaultPred(targets, max_n=len(df_to_review.index))
review_columns = ('Path', 'Raw_pred', 'Label', 'Prob')
w.set_fns(*(list(df_to_review[column]) for column in review_columns))
w

In [None]:
# Preview of the choices currently made in the review widget; the next cell iterates this
# same call as (position, chosen class) pairs
w.change()

In [None]:
# Final prediction defaults to the automatic one; reviewed rows are overwritten below
df['Final_pred'] = df['Pred']
for i, pred in w.change():
    # `i` is the position inside the widget, i.e. inside df_to_review; map it back to df's index
    idx = df_to_review.iloc[i].name
    df.loc[idx, 'Final_pred'] = pred
    # A missing probability is what marks a row as reviewed
    df.loc[idx, 'Prob'] = np.nan

    # Update label image if required
    # `labels_df` is indexed by the string path. Indexing it with a `Path` object never
    # matches an existing row, so `.loc` appended a new row instead of relabelling the
    # image: keep the key as a string.
    path_str = df.loc[idx, 'Path']
    # NOTE(review): this compares the parent FOLDER name with the class name; confirm it is
    # still meaningful now that labels come from the metadata and not from folders.
    if Path(path_str).parent.name != pred:
        labels_df.loc[path_str, 'Target'] = pred
        labels_df.loc[path_str, 'fname'] = path_str

df.to_csv(metadata_labels_path, index=False)