In [1]:
from sklearn.model_selection import train_test_split
from scripts.tools import return_files_in_directory, human_sort
import os 
from scripts.config import DATA_DIR, DEVICE
import torch
from torchvision.models import resnet50
from scripts.dataset import Colonoscopy_Dataset
from torch.utils.data import DataLoader
from scripts.embeddings import ResnetFeatureExtractor,generate_embedding_masks_for_dataset

In [2]:
# Collect the raw colonoscopy images and their bounding-box masks.
image_files = return_files_in_directory(DATA_DIR + "/original", ".tif")
box_files = return_files_in_directory(DATA_DIR + "/boxmasks", ".png")
# Ensure files are in correct order: Colonoscopy_Dataset receives two parallel
# lists (presumably pairing image i with mask i), so both must be sorted the
# same way. human_sort is called for its side effect (in-place sort).
human_sort(image_files)
human_sort(box_files)
# Positional pairing cannot detect a missing image or mask on its own; fail
# loudly instead of letting every subsequent pair shift out of alignment.
assert len(image_files) == len(box_files), (
    f"Found {len(image_files)} images but {len(box_files)} box masks"
)

# The whole dataset is processed here; no train/val/test split is applied.
dataset = Colonoscopy_Dataset(image_files, box_files)
data_loader = DataLoader(dataset, batch_size=64, shuffle=False, num_workers=0)

# Output directories: generated masks go to TESTING_DIR; the stored
# foreground/background embeddings are read from EMBEDDING_DIR.
TESTING_DIR = DATA_DIR + "/testing/robust_boxshrink"
EMBEDDING_DIR = DATA_DIR + "/mean_embeddings/"
# exist_ok=True avoids the check-then-create race and makes re-runs a no-op.
os.makedirs(TESTING_DIR, exist_ok=True)
os.makedirs(EMBEDDING_DIR, exist_ok=True)

In [3]:
## Get mean embeddings
# Load the saved foreground and background embedding tensors (in that order)
# and average each over dim 0 to obtain one reference vector per class.
# Assumes dim 0 indexes individual embeddings — TODO confirm against the
# code that produced these files.
# NOTE(review): torch.load unpickles the file; only load embeddings that this
# project produced itself.
f, b = (
    torch.load(EMBEDDING_DIR + embedding_name)
    for embedding_name in ("foreground_embedding.pt", "background_embedding.pt")
)
mean_f, mean_b = f.mean(dim=0), b.mean(dim=0)

In [4]:
# Build an ImageNet-pretrained ResNet-50, wrap it as a feature extractor, and
# run the project's mask-generation helper over every image in `dataset`
# (output presumably written to TESTING_DIR as PNGs, per save_as_png=True).
resnet = resnet50(weights="ResNet50_Weights.IMAGENET1K_V2")
resnet.eval()  # inference only: BatchNorm layers use their stored running statistics
feature_extract_model = ResnetFeatureExtractor(resnet)

feature_extract_model.to(DEVICE)
# Pure inference: disable autograd so activation graphs are not retained for a
# backward pass (harmless if the helper already disables it internally).
# NOTE(review): the helper prints every mask tensor (see cell output); consider
# removing that debug print in scripts.embeddings.
with torch.no_grad():
    res = generate_embedding_masks_for_dataset(dataset, TESTING_DIR, feature_extract_model, mean_f, mean_b, save_as_png=True)

Generating Embedding Masks:   0%|          | 0/10 [00:00<?, ?batch/s]

mask.shape = torch.Size([288, 384])
tensor([[  0,   0,   0,  ...,   0,   0,   0],
        [  0,   0,   0,  ...,   0,   0,   0],
        [  0,   0,   0,  ...,   0,   0,   0],
        ...,
        [255, 255, 255,  ...,   0,   0,   0],
        [255, 255, 255,  ...,   0,   0,   0],
        [255, 255, 255,  ...,   0,   0,   0]])
mask.shape = torch.Size([288, 384])
tensor([[  0,   0,   0,  ...,   0,   0,   0],
        [  0,   0,   0,  ...,   0,   0,   0],
        [  0,   0,   0,  ...,   0,   0,   0],
        ...,
        [255, 255, 255,  ...,   0,   0,   0],
        [255, 255, 255,  ...,   0,   0,   0],
        [255, 255, 255,  ...,   0,   0,   0]])
mask.shape = torch.Size([288, 384])
tensor([[  0,   0,   0,  ...,   0,   0,   0],
        [  0,   0,   0,  ...,   0,   0,   0],
        [  0,   0,   0,  ...,   0,   0,   0],
        ...,
        [255, 255, 255,  ...,   0,   0,   0],
        [255, 255, 255,  ...,   0,   0,   0],
        [255, 255, 255,  ...,   0,   0,   0]])
mask.shape = torch.Siz