In [27]:
import cv2
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import torch

from collections import defaultdict
from einops import rearrange
from importlib import import_module
from skimage.exposure import match_histograms
from skp.toolbox.classes import Ensemble
from skp.toolbox.functions import load_kfold_ensemble_as_list
from transformers import AutoModel
from tqdm import tqdm

In [28]:
# Prefer the GPU, but fall back to CPU so the notebook still runs on machines without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [29]:
# Model that returns a crop box (x, y, w, h) for an image; used by the "Crop" experiments.
# NOTE(review): trust_remote_code executes code fetched from the Hub - consider pinning `revision=`.
crop_model = AutoModel.from_pretrained(
    "ianpan/bone-age-crop",
    trust_remote_code=True,
).eval().to(device)

In [47]:
# skp training config whose checkpoints are evaluated by the Ensemble comparison cells.
cfg_name = "boneage.cfg_female_channel_reg_cls_match_hist_cropped_uncropped"
cfg = import_module(f"skp.configs.{cfg_name}").cfg

# Training run id and folds to load; edit these to evaluate another run or more folds.
RUN_ID = "c6043bd4"
FOLDS = [0]
# os.path.join does not depend on cfg.save_dir ending with a path separator.
weights_paths = [
    os.path.join(cfg.save_dir, cfg_name, RUN_ID, f"fold{fold}", "checkpoints", "last.ckpt")
    for fold in FOLDS
]
fold_models = load_kfold_ensemble_as_list(cfg, weights_paths=weights_paths)
# Ensemble over the "logits1" head with softmax applied - presumably per-class probabilities (verify in skp).
model = Ensemble(fold_models, output_name="logits1", activation_fn="softmax").eval().to(device)

# Reference image for histogram matching; cv2.imread returns None (no exception) on a bad path.
ref_img = cv2.imread(cfg.ref_image_match_hist, cfg.cv2_load_flag)
assert ref_img is not None, f"could not read reference image {cfg.ref_image_match_hist}"

In [59]:
model = AutoModel.from_pretrained("ianpan/bone-age", trust_remote_code=True)
model = model.eval().to(device)

A new version of the following files was downloaded from https://huggingface.co/ianpan/bone-age:
- configuration.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.
A new version of the following files was downloaded from https://huggingface.co/ianpan/bone-age:
- modeling.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.


In [48]:
df = pd.read_csv("/mnt/stor/datasets/bone-age/test.csv")
df.head()

Unnamed: 0,pid,female,bone_age
0,4360,0,168.934249
1,4361,0,169.652678
2,4362,0,73.256112
3,4363,0,152.862669
4,4364,0,135.456954


In [63]:
# Original image - no crop or histogram matching
preds = defaultdict(list)
for row_idx, row in tqdm(df.iterrows(), total=len(df)):
    preds["pid"].append(int(row.pid))
    img0 = cv2.imread(os.path.join("/mnt/stor/datasets/bone-age/test", f"{int(row.pid)}.png"), 0)
    img = rearrange(img0, "h w -> h w 1")
    img = cfg.val_transforms(image=img)["image"]
    img = rearrange(img, "h w c -> 1 c h w")
    img = torch.from_numpy(img).float()
    # female_ch = torch.zeros_like(img)
    # if row.female:
    #     female_ch[...] = 255
    # img = torch.cat([img, female_ch], dim=1)
    with torch.inference_mode():
        bone_age = model(img.to(device), torch.tensor([row.female]).to(device))[0].cpu()
    preds["bone_age_pred"].append(bone_age)

100%|██████████| 200/200 [00:06<00:00, 30.59it/s]


In [64]:
# Score the Hugging Face model. The frame is named pred_df_hf (not pred_df1) so these predictions
# cannot leak into the Ensemble cells, which reuse pred_df1..pred_df4 for the averaged result.
# validate="one_to_one" fails loudly if a pid is duplicated on either side of the join.
pred_df_hf = df.merge(pd.DataFrame(preds), on="pid", validate="one_to_one")
pred_df_hf["mae"] = (pred_df_hf.bone_age_pred - pred_df_hf.bone_age).abs()
pred_df_hf["mae"].mean()  # 4.42, 4.67

4.420377197265625

In [49]:
# Original image - no crop or histogram matching
preds = defaultdict(list)
for row_idx, row in tqdm(df.iterrows(), total=len(df)):
    preds["pid"].append(int(row.pid))
    img0 = cv2.imread(os.path.join("/mnt/stor/datasets/bone-age/test", f"{int(row.pid)}.png"), 0)
    img = rearrange(img0, "h w -> h w 1")
    img = cfg.val_transforms(image=img)["image"]
    img = rearrange(img, "h w c -> 1 c h w")
    img = torch.from_numpy(img).float()
    female_ch = torch.zeros_like(img)
    if row.female:
        female_ch[...] = 255
    img = torch.cat([img, female_ch], dim=1)
    with torch.inference_mode():
        bone_age = model({"x": img.to(device)})[0].cpu()
        bone_age = (bone_age * torch.arange(240)).sum().numpy()
    preds["bone_age_pred"].append(bone_age)

  6%|▌         | 12/200 [00:00<00:03, 57.50it/s]

100%|██████████| 200/200 [00:03<00:00, 56.96it/s]


In [50]:
# Join the Ensemble's no-crop / no-match predictions onto the ground truth and report mean absolute error.
pred_df1 = (
    df.merge(pd.DataFrame(preds), on="pid")
      .assign(mae=lambda frame: (frame["bone_age_pred"] - frame["bone_age"]).abs())
)
pred_df1["mae"].mean() # 4.42, 4.67

4.667528762547231

In [51]:
# Crop, no histogram matching
preds = defaultdict(list)
for row_idx, row in tqdm(df.iterrows(), total=len(df)):
    preds["pid"].append(int(row.pid))
    img0 = cv2.imread(os.path.join("/mnt/stor/datasets/bone-age/test", f"{int(row.pid)}.png"), 0)
    img = rearrange(img0, "h w -> h w 1")
    img = crop_model.preprocess(img)
    img = rearrange(img, "h w c -> 1 c h w")
    img = torch.from_numpy(img).float()
    with torch.inference_mode():
        box = crop_model(img.to(device), torch.tensor([img0.shape[:2]]).to(device)).cpu().numpy()
    x, y, w, h = box[0]
    cropped_img0 = img0[y: y + h, x: x + w]
    cropped_img = rearrange(cropped_img0, "h w -> h w 1")
    cropped_img = cfg.val_transforms(image=cropped_img)["image"]
    cropped_img = rearrange(cropped_img, "h w c -> 1 c h w")
    cropped_img = torch.from_numpy(cropped_img).float()
    female_ch = torch.zeros_like(cropped_img)
    if row.female:
        female_ch[...] = 255
    cropped_img = torch.cat([cropped_img, female_ch], dim=1)
    with torch.inference_mode():
        bone_age = model({"x": cropped_img.to(device)})[0].cpu()
        bone_age = (bone_age * torch.arange(240)).sum().numpy()
    preds["bone_age_pred"].append(bone_age)

  3%|▎         | 6/200 [00:00<00:03, 52.64it/s]

100%|██████████| 200/200 [00:03<00:00, 51.26it/s]


In [52]:
# Join the Ensemble's crop-only predictions onto the ground truth and report mean absolute error.
pred_df2 = (
    df.merge(pd.DataFrame(preds), on="pid")
      .assign(mae=lambda frame: (frame["bone_age_pred"] - frame["bone_age"]).abs())
)
pred_df2["mae"].mean() # 4.47, 4.84

4.8375959294328235

In [53]:
# No crop, yes histogram matching
preds = defaultdict(list)
for row_idx, row in tqdm(df.iterrows(), total=len(df)):
    preds["pid"].append(int(row.pid))
    img0 = cv2.imread(os.path.join("/mnt/stor/datasets/bone-age/test", f"{int(row.pid)}.png"), 0)
    img0 = match_histograms(img0, ref_img)
    img = rearrange(img0, "h w -> h w 1")
    img = cfg.val_transforms(image=img)["image"]
    img = rearrange(img, "h w c -> 1 c h w")
    img = torch.from_numpy(img).float()
    female_ch = torch.zeros_like(img)
    if row.female:
        female_ch[...] = 255
    img = torch.cat([img, female_ch], dim=1)
    with torch.inference_mode():
        bone_age = model({"x": img.to(device)})[0].cpu()
        bone_age = (bone_age * torch.arange(240)).sum().numpy()
    preds["bone_age_pred"].append(bone_age)

  7%|▋         | 14/200 [00:00<00:04, 40.96it/s]

100%|██████████| 200/200 [00:05<00:00, 38.00it/s]


In [54]:
# Join the Ensemble's histogram-matched (uncropped) predictions onto the ground truth and report mean absolute error.
pred_df3 = (
    df.merge(pd.DataFrame(preds), on="pid")
      .assign(mae=lambda frame: (frame["bone_age_pred"] - frame["bone_age"]).abs())
)
pred_df3["mae"].mean() # 4.34, 4.59

4.590002495312414

In [55]:
# Crop and histogram matching
preds = defaultdict(list)
for row_idx, row in tqdm(df.iterrows(), total=len(df)):
    preds["pid"].append(int(row.pid))
    img0 = cv2.imread(os.path.join("/mnt/stor/datasets/bone-age/test", f"{int(row.pid)}.png"), 0)
    img = rearrange(img0, "h w -> h w 1")
    img = crop_model.preprocess(img)
    img = rearrange(img, "h w c -> 1 c h w")
    img = torch.from_numpy(img).float()
    with torch.inference_mode():
        box = crop_model(img.to(device), torch.tensor([img0.shape[:2]]).to(device)).cpu().numpy()
    x, y, w, h = box[0]
    cropped_img0 = img0[y: y + h, x: x + w]
    cropped_img0 = match_histograms(cropped_img0, ref_img)
    cropped_img = rearrange(cropped_img0, "h w -> h w 1")
    cropped_img = cfg.val_transforms(image=cropped_img)["image"]
    cropped_img = rearrange(cropped_img, "h w c -> 1 c h w")
    cropped_img = torch.from_numpy(cropped_img).float()
    female_ch = torch.zeros_like(cropped_img)
    if row.female:
        female_ch[...] = 255
    cropped_img = torch.cat([cropped_img, female_ch], dim=1)
    with torch.inference_mode():
        bone_age = model({"x": cropped_img.to(device)})[0].cpu()
        bone_age = (bone_age * torch.arange(240)).sum().numpy()
    preds["bone_age_pred"].append(bone_age)

  2%|▏         | 4/200 [00:00<00:05, 36.59it/s]

100%|██████████| 200/200 [00:04<00:00, 40.64it/s]


In [56]:
# Join the Ensemble's crop + histogram-matched predictions onto the ground truth and report mean absolute error.
pred_df4 = (
    df.merge(pd.DataFrame(preds), on="pid")
      .assign(mae=lambda frame: (frame["bone_age_pred"] - frame["bone_age"]).abs())
)
pred_df4["mae"].mean() # 4.16, 4.45

4.450715427270441

In [None]:
# Average the four preprocessing variants (original, crop, histogram match, crop + histogram match).
variant_dfs = [pred_df1, pred_df2, pred_df3, pred_df4]
# The average is positional, so every variant must list the same pids in the same order.
for variant_df in variant_dfs[1:]:
    assert np.array_equal(variant_df.pid.to_numpy(), pred_df1.pid.to_numpy()), "pid order differs between variants"
y_true = pred_df1.bone_age.to_numpy(dtype=float)
# dtype=float converts any 0-d arrays/tensors the prediction loops may have stored into a numeric array.
y_pred = np.mean([variant_df.bone_age_pred.to_numpy(dtype=float) for variant_df in variant_dfs], axis=0)
np.mean(np.abs(y_true - y_pred))  # 4.22, 4.42

4.423161336917886