### Run feature extraction on OAI images

In [2]:
# List what is mounted under /datasets/ (shell command run through IPython's `!`).
!ls /datasets/

osteoarthritis-initiative


In [3]:
import os
import io
import time
import boto3
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
import nibabel as nib
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm

from thesisproject.models import UNet
from thesisproject.predict import Predict
from thesisproject.data import ImagePairDataset, extract_features

In [4]:
"""
AWS_S3_CREDS = {
    "aws_access_key_id": "AKIA5M53U5MGSUORAVXQ",
    "aws_secret_access_key": "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"
}

session = boto3.Session(**AWS_S3_CREDS)
s3 = session.resource('s3')

bucket = "osteoarthritis-initiative"
oai_bucket = s3.Bucket(bucket)
"""

uploaded_files = os.listdir("/datasets/osteoarthritis-initiative/")# Or read from oai-files.txt
"""
for bucket_object in oai_bucket.objects.all():
    uploaded_files.append(bucket_object.key)
"""

'\nfor bucket_object in oai_bucket.objects.all():\n    uploaded_files.append(bucket_object.key)\n'

In [5]:
class Square_pad:
    """Zero-pad a tensor so that every dimension is as long as its longest edge.

    The existing data stays roughly centred; when a dimension needs an odd
    amount of padding, the extra element is placed in front of the data.
    Tensors of any rank >= 1 are accepted (a 3-D volume becomes a cube).
    """

    def __call__(self, image: torch.Tensor) -> torch.Tensor:
        """Return ``image`` padded with zeros to ``max(image.shape)`` on each axis.

        Raises:
            ValueError: if ``image`` has no dimensions (0-d tensor).
        """
        target = max(image.shape)

        # F.pad takes (before, after) pairs starting from the LAST dimension,
        # hence the reversed walk over the shape.
        padding = []
        for edge in reversed(image.shape):
            missing = target - edge
            # ceil(missing / 2) in front, floor(missing / 2) behind.
            padding.extend((missing - missing // 2, missing // 2))

        return F.pad(image, tuple(padding), "constant", 0)
    
def test_collate(image):
    return image

def filename_to_subject_info(filename):
    """Parse subject id, knee side and visit number from an image filename.

    The name is read at fixed character offsets:

    * characters 0-6 are the integer subject id,
    * character 8 is the first letter of the knee label; ``"R"`` means the
      right knee, anything else is treated as the left knee,
    * the two-digit visit number starts at offset 15 for a right knee and at
      offset 14 otherwise (the knee label is presumably ``RIGHT``/``LEFT``,
      which would explain the one-character shift -- confirm against the
      dataset naming scheme).

    Args:
        filename: Name of the image file (no directory part).

    Returns:
        Tuple ``(subject_id, is_right, visit)`` of ``(int, bool, int)``.

    Raises:
        IndexError: if ``filename`` has fewer than 9 characters.
        ValueError: if the id or visit fields are not decimal numbers.
    """
    subject_id = int(filename[:7])
    is_right = filename[8] == "R"
    visit_start = 15 if is_right else 14
    visit = int(filename[visit_start:visit_start + 2])
    return subject_id, is_right, visit

# Shared Square_pad instance; the extraction loop applies it to every scan tensor.
volume_transform = Square_pad()

In [6]:
# Names handed to the network as `class_names`.
label_keys = [
    "Lateral femoral cart.",
    "Lateral meniscus",
    "Lateral tibial cart.",
    "Medial femoral cartilage",
    "Medial meniscus",
    "Medial tibial cart.",
    "Patellar cart.",
    "Tibia",
]

# Use the first CUDA device when one is available, otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# NOTE(review): the recorded run of this cell failed with
#   TypeError: __init__() missing 1 required positional argument: 'image_size'
# so this call does not match the installed UNet signature. Check
# thesisproject.models.UNet and supply the missing argument.
# NOTE(review): 8 class names are passed next to the positional value 9 --
# presumably 9 output classes including a background class; confirm.
net = UNet(1, 9, 384, class_names=label_keys)
net.to(device)

print("loaded U-net")

TypeError: __init__() missing 1 required positional argument: 'image_size'

In [7]:
# Restore the trained weights. map_location remaps the stored tensors onto
# `device`, so a checkpoint saved on a GPU still loads when this notebook
# falls back to the CPU.
checkpoint = torch.load("model_saves/unet-checkpoint.pt", map_location=device)
net.load_state_dict(checkpoint["model_state_dict"])

# Prediction wrapper around the network; the extraction loop calls it once per scan.
predict = Predict(net, batch_size=8, show_progress=False)

NameError: name 'net' is not defined

In [6]:
# Subject table indexed by "subject_id_and_knee"; the extraction loop looks each
# scan's row up by that key and copies its "TKR" value into the output.
subjects_df = pd.read_csv("../subjects.csv", index_col="subject_id_and_knee")

In [7]:
# Resume support: rows already present in feature_extract.csv are kept in `df`
# and their files are skipped, so an interrupted run continues where it stopped.
if os.path.exists("feature_extract.csv"):
    df = pd.read_csv("feature_extract.csv")
    computed_files = df["filename"].tolist()
else:
    df = pd.DataFrame()
    computed_files = []

# Sorted so that the processing order is the same on every run
# (iteration order of a set of strings is not stable between kernels).
files_to_compute = sorted(set(uploaded_files) - set(computed_files))

print(f"{len(files_to_compute)}/{len(uploaded_files)} files left for feature extraction.")

1633/1925 files left for feature extraction.


In [None]:
# Scratch directory for images downloaded from S3. The download below is
# currently disabled (scans are read from the local dataset mount), so the
# directory normally stays empty.
os.makedirs("tmp_img", exist_ok=True)

pbar = tqdm(total=len(uploaded_files), unit="images")
pbar.update(len(computed_files))
try:
    for filename in files_to_compute:
        pbar.set_description(f"{filename} (prediction)")
        # Load s3 file as nib object
        #s3_object = oai_bucket.Object(filename).download_file("tmp_img/" + filename)

        # Parse the filename once; the knee side drives the flip below.
        subject_id, is_right, visit = filename_to_subject_info(filename)

        nii_file = nib.load(f"/datasets/osteoarthritis-initiative/{filename}")
        scan = nii_file.get_fdata()

        # Flip coronal plane
        scan = np.flip(scan, axis=1).copy()

        # Right knees are additionally flipped along axis 2.
        if is_right:
            scan = np.flip(scan, axis=2).copy()

        scan_tensor = volume_transform(torch.from_numpy(scan).float().to(device))

        # Min-max normalise to [0, 1]; skip the division for a constant volume,
        # which would otherwise turn every voxel into NaN.
        scan_tensor -= scan_tensor.min()
        peak = scan_tensor.max()
        if peak > 0:
            scan_tensor /= peak

        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            prediction = predict(scan_tensor)

        pbar.set_description(f"{filename} (extract)")
        extracted_features = extract_features(scan_tensor.detach().cpu().numpy(), prediction.detach().cpu().numpy())

        # Plain string key such as "<subject_id>-R". (A stray trailing comma
        # used to turn this into a 1-tuple, which ended up in the CSV as a
        # tuple repr -- rows written by that version need to be cleaned up.)
        subject_id_and_knee = str(subject_id) + ("-R" if is_right else "-L")
        subject_row = subjects_df.loc[subject_id_and_knee]

        row_df = pd.DataFrame([{
            "subject_id_and_knee": subject_id_and_knee,
            "is_right": is_right,
            "visit": visit,
            "filename": filename,
            "TKR": subject_row["TKR"],
            **extracted_features}])

        # Rewrite the CSV after every image so an interrupted run can resume.
        df = pd.concat([df, row_df], ignore_index=True)
        df.to_csv("feature_extract.csv", index=False)
        pbar.update(1)

        # The temporary copy only exists when the S3 download is enabled;
        # removing it unconditionally aborted the run after the first image.
        tmp_path = os.path.join("tmp_img", filename)
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
finally:
    # Clear any leftover temporary images, also when the loop was interrupted.
    for leftover in os.listdir("tmp_img"):
        os.remove(os.path.join("tmp_img", leftover))

    pbar.close()

  0%|          | 0/1925 [00:00<?, ?images/s]

  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
