In [1]:
import torch
import torchxrayvision as xrv
import torchvision.transforms as transforms
from skimage import io
import os
import json

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Use the Apple-silicon GPU backend (MPS) when this PyTorch build and machine
# support it; otherwise fall back to the CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

In [3]:
# Input directory scanned for PNG radiographs, and the JSON file that the
# extracted feature maps are streamed into.
image_folder, output_file = "resized_images/", "extracted_features.json"

In [4]:
# torchxrayvision preprocessing applied to each (1, H, W) image: center-crop,
# then resize to 224 to match the "res224" weights loaded for the model.
transform = transforms.Compose([xrv.datasets.XRayCenterCrop(), xrv.datasets.XRayResizer(224)])

In [5]:
# torchxrayvision DenseNet with the pretrained "densenet121-res224-all" weights,
# moved to the selected device. `.eval()` puts BatchNorm (and any dropout) into
# inference mode so extracted features are deterministic and BatchNorm running
# statistics are not modified by the forward passes below.
model = xrv.models.DenseNet(weights="densenet121-res224-all").to(device).eval()


In [None]:
# Stream one JSON object {png_filename: feature_map_as_nested_lists} to
# `output_file`, one entry per PNG in `image_folder`. Entries are written as
# they are computed, so all feature maps never need to be in memory at once.


def _to_single_channel(img):
    """Return a normalized image as a (1, H, W) array.

    2-D (grayscale) images just gain a leading channel axis. 3-D (H, W, C)
    images are averaged over their colour channels; a trailing alpha channel
    (grayscale+alpha or RGBA, i.e. C in (2, 4)) is dropped first so it does
    not leak into the intensity values.
    """
    if img.ndim == 2:
        return img[None, ...]
    if img.ndim == 3:
        color = img[..., :-1] if img.shape[2] in (2, 4) else img
        return color.mean(2)[None, ...]
    raise ValueError(f"unsupported image shape {img.shape}")


# List the folder once. Sort it so output order is reproducible, and match
# the extension case-insensitively so ".PNG" files are not silently skipped.
png_names = sorted(
    name for name in os.listdir(image_folder) if name.lower().endswith(".png")
)

# no_grad: pure inference, so don't build autograd graphs for every image.
with open(output_file, "w") as f, torch.no_grad():
    f.write("{\n")  # Start of JSON object

    for idx, img_name in enumerate(png_names):
        img_path = os.path.join(image_folder, img_name)

        # Load and preprocess the image.
        img = io.imread(img_path)
        # NOTE(review): 255 assumes 8-bit PNGs; 16-bit images would need their
        # own max value here — confirm the contents of `image_folder`.
        img = xrv.datasets.normalize(img, 255)
        img = _to_single_channel(img)
        img = transform(img)
        img = torch.from_numpy(img).unsqueeze(0).to(device)  # Add batch dim, move to MPS/CPU

        # Extract the DenseNet feature map and bring it back to the CPU for serialization.
        features = model.features(img)
        features = features.cpu().numpy().tolist()

        # Write the separator *before* every entry except the first, so the
        # object never ends with a trailing comma (which would be invalid JSON).
        if idx > 0:
            f.write(",\n")
        # json.dumps on the key escapes quotes/backslashes in filenames.
        f.write(f"{json.dumps(img_name)}: {json.dumps(features)}")

    f.write("\n}\n")  # End of JSON object