In [None]:
import pandas as pd
import numpy as np

from pathlib import Path

import torch
import timm
from torchvision.transforms import v2 as T
from torchvision.transforms.v2 import functional as TF
from torchvision.io import read_image
from torchvision.tv_tensors import Image

from tqdm.autonotebook import tqdm
import time

@torch.inference_mode()
def get_latent(model, imgloader, *, device='cuda', progress=True):
    """Embed every batch from *imgloader* with *model* and stack the results.

    Each batch is moved to *device*, passed through ``model.forward_features``,
    averaged over dimension 1, moved back to the CPU, and all batches are
    concatenated along dimension 0.

    Args:
        model: module exposing ``forward_features``; its parameters must
            already live on *device*.
        imgloader: iterable of image tensors (e.g. a DataLoader over
            ``ImgDataset``).
        device: device to run inference on. Defaults to ``'cuda'``, matching
            the previous hard-coded ``.cuda()`` call.
        progress: wrap *imgloader* in a tqdm progress bar.

    Returns:
        NumPy array of the per-batch pooled features stacked along axis 0.
    """
    all_latents = []
    batches = tqdm(imgloader) if progress else imgloader
    for imgs in batches:
        tokens = model.forward_features(imgs.to(device))
        # NOTE(review): for ViTs whose forward_features returns
        # (batch, tokens, dim), this mean also pools any CLS/register tokens
        # together with the patch tokens -- confirm that is the intended pooling.
        all_latents.append(tokens.mean(1).cpu())
    return torch.concat(all_latents).numpy()


class ImgDataset(torch.utils.data.Dataset):
    def __init__(self, img_dir: str | Path | list[Path], transform: torch.nn.Module = None):
        self.img_dir = img_dir
        if isinstance(img_dir, (str, Path)):
            self.img_paths = list(Path(img_dir).rglob('*.jpg'))
        else:
            self.img_paths = img_dir
        self.transform = transform or get_default_training_transforms()

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, idx):
        img_path = self.img_paths[idx]
        img = read_image(str(img_path))
        img = self.transform(img)
        return img

# ---- ROP split tables and image locations ----------------------------------
dftrain = pd.read_csv('/mnt/c/Users/Justin/Desktop/retina_imaging_datasets/ROP_dataset/dftrain.csv')
dftest = pd.read_csv('/mnt/c/Users/Justin/Desktop/retina_imaging_datasets/ROP_dataset/dftest.csv')
# Both splits currently read from the same image folder.
img_pathtrain = '/mnt/c/Users/Justin/Desktop/retina_imaging_datasets/ROP_dataset/image/'
img_pathtest = '/mnt/c/Users/Justin/Desktop/retina_imaging_datasets/ROP_dataset/image/'

# One path per CSV row, in row order. The loaders below do not shuffle, so
# row i of each saved feature matrix corresponds to row i of its CSV.
imgstrain = [Path(img_pathtrain) / image_id for image_id in dftrain['img_name']]
imgstest = [Path(img_pathtest) / image_id for image_id in dftest['img_name']]

# ---- retfoundgreen.pth checkpoint, 392x392 inputs --------------------------
# The checkpoint holds a pickled nn.Module (the result is used directly as a
# model), not a state_dict. Since PyTorch 2.6, torch.load defaults to
# weights_only=True and would refuse it, so opt out explicitly. Only do this
# for files you trust. map_location='cpu' makes loading independent of the
# device the checkpoint was saved from; the model is moved to the GPU next.
model = torch.load(r"/mnt/c/Users/Justin/Downloads/retfoundgreen.pth",
                   map_location='cpu', weights_only=False).eval()
model.cuda()

# Resize to the model input size, scale integer pixels to [0, 1] floats, then
# map to [-1, 1] via mean=std=0.5 (broadcast over channels).
transforms = T.Compose([T.Resize((392, 392), antialias=True),
                        T.ToDtype(torch.float32, scale=True),
                        T.Normalize((0.5,), (0.5,))])
imgdatasettrain = ImgDataset(img_dir=imgstrain, transform=transforms)
imgloadertrain = torch.utils.data.DataLoader(imgdatasettrain, batch_size=1, shuffle=False, num_workers=0)
imgdatasettest = ImgDataset(img_dir=imgstest, transform=transforms)
imgloadertest = torch.utils.data.DataLoader(imgdatasettest, batch_size=1, shuffle=False, num_workers=0)

# perf_counter is monotonic, so the measured duration cannot jump with
# wall-clock adjustments the way time.time() can.
start = time.perf_counter()
Xtrain = get_latent(model, imgloadertrain)
stop = time.perf_counter()
print(f'Took: {stop-start}s')
np.save('ROP_Train_X_ours.npy', Xtrain)

start = time.perf_counter()
Xtest = get_latent(model, imgloadertest)
stop = time.perf_counter()
print(f'Took: {stop-start}s')
np.save('ROP_Test_X_ours.npy', Xtest)

In [None]:
# ---- retfoundgreen_224.pth checkpoint, 224x224 inputs ----------------------
# Reuses imgstrain / imgstest from the first cell.
# The checkpoint holds a pickled nn.Module, not a state_dict. Since PyTorch
# 2.6, torch.load defaults to weights_only=True and would refuse it, so opt
# out explicitly. Only do this for trusted files. Load onto the CPU first,
# then move the model to the GPU.
model = torch.load(r"/mnt/c/Users/Justin/Downloads/retfoundgreen_224.pth",
                   map_location='cpu', weights_only=False).eval()
model.cuda()

# Same preprocessing as the 392 run, only the resize target differs.
transforms = T.Compose([T.Resize((224, 224), antialias=True),
                        T.ToDtype(torch.float32, scale=True),
                        T.Normalize((0.5,), (0.5,))])
imgdatasettrain = ImgDataset(img_dir=imgstrain, transform=transforms)
imgloadertrain = torch.utils.data.DataLoader(imgdatasettrain, batch_size=1, shuffle=False, num_workers=0)
imgdatasettest = ImgDataset(img_dir=imgstest, transform=transforms)
imgloadertest = torch.utils.data.DataLoader(imgdatasettest, batch_size=1, shuffle=False, num_workers=0)

# Monotonic clock for elapsed time.
start = time.perf_counter()
Xtrain = get_latent(model, imgloadertrain)
stop = time.perf_counter()
print(f'Took: {stop-start}s')
np.save('ROP_Train_X_ours224.npy', Xtrain)

start = time.perf_counter()
Xtest = get_latent(model, imgloadertest)
stop = time.perf_counter()
print(f'Took: {stop-start}s')
np.save('ROP_Test_X_ours224.npy', Xtest)

In [None]:
# ---- DINOv2 ViT-S/14 (4 registers) baseline, 392x392 inputs ----------------
# Reuses imgstrain / imgstest from the first cell.
model = timm.create_model('vit_small_patch14_reg4_dinov2.lvd142m',
                          pretrained=True, img_size=(392, 392), num_classes=0).cuda().eval()

# NOTE(review): this reuses the mean=std=0.5 normalization of the RETFound-Green
# runs. timm's pretrained_cfg for this checkpoint may specify different
# (ImageNet) statistics -- check model.pretrained_cfg['mean'/'std'] and confirm
# the shared normalization is intentional for the comparison.
transforms = T.Compose([T.Resize((392, 392), antialias=True),
                        T.ToDtype(torch.float32, scale=True),
                        T.Normalize((0.5,), (0.5,))])
imgdatasettrain = ImgDataset(img_dir=imgstrain, transform=transforms)
imgloadertrain = torch.utils.data.DataLoader(imgdatasettrain, batch_size=1, shuffle=False, num_workers=0)
imgdatasettest = ImgDataset(img_dir=imgstest, transform=transforms)
imgloadertest = torch.utils.data.DataLoader(imgdatasettest, batch_size=1, shuffle=False, num_workers=0)

# Monotonic clock for elapsed time.
start = time.perf_counter()
Xtrain = get_latent(model, imgloadertrain)
stop = time.perf_counter()
print(f'Took: {stop-start}s')
np.save('ROP_Train_X_dino392.npy', Xtrain)

start = time.perf_counter()
Xtest = get_latent(model, imgloadertest)
stop = time.perf_counter()
print(f'Took: {stop-start}s')
np.save('ROP_Test_X_dino392.npy', Xtest)

In [None]:
# ---- DINOv2 ViT-S/14 (4 registers) baseline, 224x224 inputs ----------------
# Reuses imgstrain / imgstest from the first cell.
model = timm.create_model('vit_small_patch14_reg4_dinov2.lvd142m',
                          pretrained=True, img_size=(224, 224), num_classes=0).cuda().eval()

# NOTE(review): mean=std=0.5 is shared with the RETFound-Green runs. timm's
# pretrained_cfg for this checkpoint may specify different (ImageNet)
# statistics -- confirm the shared normalization is intentional.
transforms = T.Compose([T.Resize((224, 224), antialias=True),
                        T.ToDtype(torch.float32, scale=True),
                        T.Normalize((0.5,), (0.5,))])
imgdatasettrain = ImgDataset(img_dir=imgstrain, transform=transforms)
imgloadertrain = torch.utils.data.DataLoader(imgdatasettrain, batch_size=1, shuffle=False, num_workers=0)
imgdatasettest = ImgDataset(img_dir=imgstest, transform=transforms)
imgloadertest = torch.utils.data.DataLoader(imgdatasettest, batch_size=1, shuffle=False, num_workers=0)

# Monotonic clock for elapsed time.
start = time.perf_counter()
Xtrain = get_latent(model, imgloadertrain)
stop = time.perf_counter()
print(f'Took: {stop-start}s')
np.save('ROP_Train_X_dino224.npy', Xtrain)

start = time.perf_counter()
Xtest = get_latent(model, imgloadertest)
stop = time.perf_counter()
print(f'Took: {stop-start}s')
np.save('ROP_Test_X_dino224.npy', Xtest)