In [None]:
import os

# Processing
import numpy as np
import pandas as pd

# PyTorch packages
import torchxrayvision as xrv
import torch, torchvision
# Run on the first CUDA GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# To load images from disk. Any other image-loading library works as well
import skimage

# To retrieve all image paths in a directory
import glob

# Progress bar (optional)
import tqdm

# Plotting
import matplotlib.pyplot as plt

In [None]:
# Root directory of the dataset. Every .jpg file below it (at any depth) is
# picked up, so point this at wherever your copy of the images is stored.
path = '/content/physionet.org/files/mimic-cxr-jpg/2.0.0/files'

In [None]:
from torch.utils.data import Dataset

class MIMICDataset(Dataset):
    """Image-only dataset over every ``.jpg`` file found under a directory.

    No labels are loaded; each item is a single image tensor.

    Args:
        img_dir: Root directory that is searched recursively for ``*.jpg`` files.
        transform: Optional callable applied to each normalized image after a
            leading axis has been added (``image[None, ...]``).
        target_transform: Accepted for signature compatibility and stored on
            the instance, but never applied because no targets are produced.

    Attributes:
        dir: Sorted list of the image file paths that were found.
    """
    def __init__(self, img_dir, transform=None, target_transform=None):
        self.img_dir = img_dir
        # "*" is a wildcard; "**" (with recursive=True) additionally matches any
        # number of intermediate subdirectories. glob returns files in an
        # OS-dependent order, so sort to make item order reproducible.
        pattern = os.path.join(img_dir, '**', '*.jpg')
        self.dir = sorted(glob.glob(pattern, recursive=True))
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        """Return the number of image files that were found."""
        return len(self.dir)

    def __getitem__(self, idx):
        """Load, normalize and optionally transform the image at position ``idx``.

        Returns:
            A tensor built from the resulting NumPy array. The leading axis is
            only added when a ``transform`` is set; without one the array
            returned by ``normalize`` is converted as-is.

        Raises:
            IndexError: If ``idx`` is outside the list of found files.
        """
        image = skimage.io.imread(self.dir[idx])
        # 255 is passed as the maximum pixel value of the input image
        # (presumably 8-bit JPEGs -- confirm against the torchxrayvision docs).
        image = xrv.datasets.normalize(image, 255)
        if self.transform:
            image = self.transform(image[None, ...])
        return torch.from_numpy(image)

# Preprocessing pipeline: torchxrayvision's XRayCenterCrop followed by
# XRayResizer(224), applied in that order to every image.
transform = torchvision.transforms.Compose(
    [
        xrv.datasets.XRayCenterCrop(),
        xrv.datasets.XRayResizer(224),
    ]
)
data = MIMICDataset(img_dir=path, transform=transform)

In [None]:
# Pretrained torchxrayvision DenseNet using the "densenet121-res224-chex" weights.
model = xrv.models.DenseNet(weights="densenet121-res224-chex")
model.to(device)
model.eval()  # inference only: make sure dropout / batch-norm are in eval mode

# Filled in by the forward hook: name -> most recent output of the hooked layer.
activation = {}

def get_activation(name):
    """Build a forward hook that stores the hooked module's output.

    The output is written to ``activation[name]`` after being detached from
    the autograd graph and copied to the CPU, so it can be converted to NumPy
    no matter which device the model runs on.
    """
    def hook(model, input, output):
        activation[name] = output.detach().cpu()
    return hook

# Capture what ``model.classifier`` itself returns on every forward pass.
model.classifier.register_forward_hook(get_activation('classifier'))

latent = []
# no_grad: nothing is back-propagated, so skip building the autograd graph.
with torch.no_grad():
    for image in tqdm.tqdm(data, total=len(data.dir)):
        # Tensor.to() is not in-place: it returns the moved tensor, which must
        # be used as the model input. Indexing with None adds the batch axis.
        batch = image[None, ...].to(device)
        output = model(batch)  # return value unused; the hook records the activation
        latent.append(activation['classifier'])

if not latent:
    raise RuntimeError(f"No .jpg images were found under {data.img_dir!r}")

out = torch.stack(latent).numpy()

Downloading weights...
If this fails you can run `wget https://github.com/mlmed/torchxrayvision/releases/download/v1/chex-densenet121-d121-tw-lr001-rot45-tr15-sc15-seed0-best.pt -O /root/.torchxrayvision/models_data/chex-densenet121-d121-tw-lr001-rot45-tr15-sc15-seed0-best.pt`
[██████████████████████████████████████████████████]


2it [00:02,  1.04s/it]


In [None]:
# Flatten the stacked activations into 18-column rows and label each row with
# the base file name of the corresponding path in data.dir.
features = out.reshape(-1, 18)
file_names = [os.path.basename(image_path) for image_path in data.dir]

df = pd.DataFrame(features)
df.insert(loc=0, column='file_name', value=file_names)
df.to_csv('processed_data.csv')