In [2]:
# Paths & URLs

import os

# Enable CUDA stacktrace reporting for debugging
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

# Directorio base
#PATH_BASE = '/content/drive/MyDrive/proximity'
PATH_BASE = 'C:\\Users\\User\\Documents\\Proyecto Proximity'

# Data release actual
DR70_PATH = os.path.join(PATH_BASE, 'DR70')
DR70_CT_PATH = os.path.join(DR70_PATH, 'datalake_sorted')
DR70_LABELS_PATH = os.path.join(DR70_PATH, 'labels.csv')


# CTs in Nibabel format
CT_NIBABEL_PATH = os.path.join(PATH_BASE, 'DR70', 'CTs')

# Embeddings visuales de CTs
#CT_EMBEDDINGS_PATH = DATA_RELEASE_PATH + '/visual_embeddings'


# Etiquetas de los CTs del data release actual
#CT_LABELS_CSV_PATH = DATA_RELEASE_PATH + '/labels.csv'

# Data release (CTs + etiquetas) organizados en un DataFrame
#CT_DATASET_DF_HDF_PATH = os.path.join(PATH_BASE, 'dataset_df.h5')
#CT_DATASET_DF_PICKLE_PATH = os.path.join(PATH_BASE, 'dataset_df.pickle')

# URLs de modelos visuales
#RESNET18_URL = 'microsoft/resnet-18'

# Path que contiene los resnet50 embeddings de CTs del data release actual
#CT_RESNET18_EMBEDDINGS_PATH = os.path.join(DR70_PATH, 'visual_embeddings', 'resnet18')
#CT_RESNET18_EMBEDDINGS_PATH = os.path.join(DR70_PATH, 'visual_embeddings', 'resnet18', 'reshaped_averaged')

# Path de modelos entrenados en base a tripletas
TRIPLET_MODELS_PATH = os.path.join(PATH_BASE, 'retrieval_models', 'triplets')
TRIPLET_CHECKPOINTS_PATH = os.path.join(PATH_BASE, 'retrieval_models', 'triplets', 'checkpoints')

In [2]:
# DATASET LOADING & PREPARATION

import pandas as pd
import numpy as np

def read_labels_csv(csv_path):
    labels_df = pd.read_csv(
        dataframe_path, 
        header=0, 
        index_col=0, 
        dtype={'ct': str, 'condensacion': int, 'nodulos': int, 'quistes': int}
    )
    return labels_df

def split_dataset(dataframe_path):
    labels_df = read_labels_csv(dataframe_path)

    train_size = int(0.8 * len(labels_df))
    val_size = int(0.2 * len(labels_df))
    test_size = int(0.0 * len(labels_df))

    '''
    If necessary, adjust the size of the training set when the sum of 
    the sizes of the three sets differs from the total dataset size.
    '''
    size_diff = len(labels_df) - train_size - val_size - test_size
    train_size += size_diff

    train_df, val_df, test_df = np.split(
        labels_df.sample(frac=1.0, random_state=31),
        [train_size, train_size + val_size]
    )
    return (train_df, val_df, test_df)

train_df, val_df, test_df = split_dataset(DR70_LABELS_PATH)

NameError: name 'dataframe_path' is not defined

In [22]:
print(train_df.shape, val_df.shape, test_df.shape)

(56, 4) (14, 4) (0, 4)


In [4]:
import os
import pandas as pd
import numpy as np
import torch
import torchio as tio
from torch.utils.data import Dataset, random_split, RandomSampler, BatchSampler
from scipy.spatial.distance import hamming
import random
import logging

# Set seeds
random.seed(0)
np.random.seed(0)
torch.manual_seed(0)
torch.cuda.manual_seed(0)
torch.cuda.manual_seed_all(0)

logging.basicConfig(level=logging.INFO)

class CustomDataset(Dataset):
    def __init__(self, data_dir, labels_csv_dir):
        self.data_dir = data_dir
        self.labels_df = pd.read_csv(
            labels_csv_dir, 
            header=0, 
            index_col=0, 
            dtype={'CT': str, 'condensacion': int, 'nodulos': int, 'quistes': int}
        )
        
        '''
        delete rows from the dataset that don't have a corresponding embedding--they
        may not exist because they were invalid or their CTs too short to be processed by 
        the visual encoder)
        '''
        to_be_deleted_ids = list()
        for idx, ct_id in zip(self.labels_df.index, self.labels_df['CT']):
            file_path = os.path.join(self.data_dir, f"{ct_id}.npy")
            if not os.path.exists(file_path):
                to_be_deleted_ids.append(idx)
                logging.warning(f"Data point with dataset_id={idx} will be discarded because its corresponding visual embedding file does not exist: {file_path}")
        if len(to_be_deleted_ids) > 0:
            self.labels_df = self.labels_df.drop(index=to_be_deleted_ids)
            self.labels_df = self.labels_df.reset_index(drop=True)
        self.labels_df['embedding'] = None
        #print(self.labels_df.head())

    def __len__(self):
        return len(self.labels_df)

    def __getitem__(self, idx):
        #print('__getitem__',idx)
        if hasattr(idx, '__iter__'):
            return self.__getitems__(idx)
        else:
            return self.__getitems__([idx])

    def __getitems__(self, ids):
        ids_list = list()
        samples_list = list()
        labels_list = list()
        for idx in ids:
            content, labels = self.load_data(idx)
            ids_list.append(idx)
            samples_list.append(content)
            labels_list.append(labels)
        ids_list = torch.tensor(ids)
        samples_list = torch.stack(samples_list)
        labels_list = torch.stack(labels_list)
        return (ids_list.detach().clone(), samples_list.detach().clone(), labels_list.detach().clone())

    
    def get_updated_embeddings(self, torch_model, device=None):
        embeddings = list()
        with torch.inference_mode():
            for i, data_sample in self.labels_df.iterrows():
                data_input, data_labels = self.load_data(i)
                if data_input.shape[0] != 1:
                    data_input = data_input.unsqueeze(0)
                if device:
                    data_output = torch_model(data_input.to(device))
                else:
                    data_output = torch_model(data_input)
                embeddings.append(data_output)
        self.labels_df['embedding'] = embeddings
        return torch.stack(embeddings)
    
    def get_embeddings(self, ids):
        if not hasattr(ids, '__iter__'):
            ids = [ids]
        selected_data_samples = self.labels_df.iloc[ids]
        #print('selected_data_samples[\'embedding\'].shape:', selected_data_samples['embedding'].shape)
        #print('selected_data_samples[\'embedding\']:', selected_data_samples['embedding'])
        selected_embeddings = torch.stack(selected_data_samples['embedding'].to_list())
        return selected_embeddings
    
    def load_data(self, idx):
        #print('load_data',idx)
        ct_id = self.labels_df['CT'].iloc[idx]
        file_path = os.path.join(self.data_dir, f"{ct_id}.npy")
        # WARN: some embedding files associated to CTs in the dataset may not exist: the corresponding CTs are not valid as they contained too few slices to be processed by the visual encoder
        if os.path.exists(file_path):
            sample_content = np.load(file_path)
            if sample_content.shape[0] == 1:
                sample_content = sample_content.squeeze(0)
            sample_content = torch.from_numpy(sample_content).detach().clone()
            sample_labels = np.array(self.labels_df.iloc[idx].iloc[1:4], dtype=int)
            sample_labels = torch.from_numpy(sample_labels).detach().clone()
            return (sample_content, sample_labels)
        
    def get_batch_positive_negative_pairs(self, ids_list):
        logging.info(f"Generating triplets for batch {ids_list}")
        assert(len(ids_list) > 0)
        positives_dict = dict() # positives_dict := { anchor_id : list_of_positives_ids }
        negatives_dict = dict() # negatives_dict _= { anchor_id : (ids not contained in list_of_positives_ids) }
        positive_candidates_label_vectors = list(np.array(self.labels_df.iloc[ids_list].iloc[:,1:4], dtype=int))
        
        for j in ids_list:
            positives_list = list()
            negatives_list = list()
            anchor_label_vector = list(np.array(self.labels_df.iloc[j].iloc[1:4], dtype=int))
            for i, positive_candidate_label_vector in zip(ids_list, positive_candidates_label_vectors):
                if i == j:
                    continue
                distance = hamming(anchor_label_vector, positive_candidate_label_vector)
                if distance == 0.0:
                    positives_list.append(i)
                else:
                    negatives_list.append(i)
            positives_dict[j] = positives_list
            negatives_dict[j] = negatives_list
        
        positive_pairs = list() # positive_pairs := list of pairs [anchor_id, positive_example_id]
        negative_candidates = dict() # negative_candidates := dict of form {anchor_id : [negative_example_1, negative_example_2, ...]]}
        for k in positives_dict:
            if len(positives_dict[k]) > 0:
                positive_pairs.append([k, random.choice(positives_dict[k])])
            else:
                logging.info(f"Could not construct a positive pair for anchor ID {k} in this batch. Skipping...")
                continue
            if len(negatives_dict[k]) > 0:
                negative_candidates[k] = negatives_dict[k]
            else:
                logging.info(f"Could not construct a negative pair for anchor ID {k} in this batch. Skipping...")
                continue
        return positive_pairs, negative_candidates





'''
# Create an instance of the dataset
dataset = CustomDataset(CT_RESNET18_EMBEDDINGS_PATH, DR70_LABELS_PATH)

# Define the sizes of train, validation, and test sets
train_size = int(0.7 * len(dataset))  # 70% of the data for training
val_size = int(0.15 * len(dataset))     # 15% of the data for validation
test_size = len(dataset) - train_size - val_size  # Remaining data for testing

'''
#If necessary, adjust the size of the training set when the sum of 
#the sizes of the three sets differs from the total dataset size.
'''
size_diff = len(dataset) - train_size - val_size - test_size
train_size += size_diff

# Use random_split to split the dataset
train_dataset, val_dataset, test_dataset = random_split(dataset, [train_size, val_size, test_size])


training_sampler = RandomSampler(train_dataset, replacement=False, generator=torch.Generator().manual_seed(42))
batch_sampler = BatchSampler(sampler=training_sampler, batch_size=8, drop_last=True)

training_dataloader = torch.utils.data.DataLoader(
    train_dataset, 
    batch_size=8, 
    shuffle=True, 
    generator=torch.Generator().manual_seed(42)
)
ct_embedding_model = CTModel()
triplet_model = TripletModel(model=ct_embedding_model, mode='HN', seed=42)

e = dataset.get_updated_embeddings(triplet_model, torch.device("cuda"))

print(dataset.get_embeddings([0]))


training_sampler = RandomSampler(train_dataset, replacement=False, generator=torch.Generator().manual_seed(42))
val_sampler = RandomSampler(val_dataset, replacement=False, generator=torch.Generator().manual_seed(42))

print(train_dataset[:][0])
print(val_dataset[:][0])
print([e for e in training_sampler])
print([e for e in val_sampler])


# You can now create data loaders for each split if needed

# Create a RandomSampler with seed 42 and no replacement
training_sampler = RandomSampler(train_dataset, replacement=False, generator=torch.Generator().manual_seed(42))
batch_sampler = BatchSampler(sampler=training_sampler, batch_size=8, drop_last=True)
        


import time
lim = 99
for batch in batch_sampler:
    if lim == 0:
        break
    lim -= 1
    print('')
    print('Batch:', batch)  # Do whatever you want with the batch data
    ids = list()
    sample_data = list()
    labels_data = list()
    idx, sample, labels = dataset[batch]
    print('Labels:', labels_data)
    positive, negative = dataset.get_batch_positive_negative_pairs(batch)
    print('Positive pairs:', positive)
    print('Negative pairs:', negative)
    print('sample shape', sample.shape)
    time.sleep(0.5) # just to make the logging and prints to output in the correct order!
'''


'\nsize_diff = len(dataset) - train_size - val_size - test_size\ntrain_size += size_diff\n\n# Use random_split to split the dataset\ntrain_dataset, val_dataset, test_dataset = random_split(dataset, [train_size, val_size, test_size])\n\n\ntraining_sampler = RandomSampler(train_dataset, replacement=False, generator=torch.Generator().manual_seed(42))\nbatch_sampler = BatchSampler(sampler=training_sampler, batch_size=8, drop_last=True)\n\ntraining_dataloader = torch.utils.data.DataLoader(\n    train_dataset, \n    batch_size=8, \n    shuffle=True, \n    generator=torch.Generator().manual_seed(42)\n)\nct_embedding_model = CTModel()\ntriplet_model = TripletModel(model=ct_embedding_model, mode=\'HN\', seed=42)\n\ne = dataset.get_updated_embeddings(triplet_model, torch.device("cuda"))\n\nprint(dataset.get_embeddings([0]))\n\n\ntraining_sampler = RandomSampler(train_dataset, replacement=False, generator=torch.Generator().manual_seed(42))\nval_sampler = RandomSampler(val_dataset, replacement=Fal

In [24]:
val_dataset[2][0]

tensor([66])

In [11]:
import numpy as np

def recall(labels_gt, labels_prediction):
    """
    set of ground truth labels and list of prediction lists. 
    This is an independent of k way to calculate the recall.
    """
    n = len(labels_prediction)
    recall = 0
    for i in range(n):
        recall += len(set(labels_gt) & set(labels_prediction[i]))/len(labels_gt)
    return recall/n

def get_batch_data(data, index, size):
    """
    For minibatch training
    """
    column_1 = []
    column_2 = []
    for i in range(index, index + size):
        line = data[i]
        # anchor image
        column_1.append(int(line[0]))
        # positive image
        column_2.append(int(line[1]))
    return np.array(column_1), np.array(column_2)


In [2]:
import torch
import torch.nn as nn
from transformers import ConvNextImageProcessor, ResNetForImageClassification, ResNetConfig
import os
import nibabel as nib
import numpy as np
from scipy.ndimage import zoom
from torchvision import models, utils
from torchvision.transforms import v2

class NiftiLoader:
    def __init__(self, folder_path, ct_target_depth=384, resize_strategy='nn_interpolation'):
        self.folder_path = folder_path
        self.ct_target_depth = ct_target_depth
        self.file_list = [file_name for file_name in os.listdir(self.folder_path) if file_name.endswith(".nii.gz")]
        self.resize_strategy = resize_strategy
    
    def __iter__(self):
        self.current_idx = 0
        return self

    def get_ct_id_dict(self):
        ct_id_dict = dict((file_name.split('.nii.gz')[0], i) for (i, file_name) in enumerate(self.file_list))
        return ct_id_dict

    def get_ct_info(self, ct_id):
        file_name = self.file_list[ct_id]
        file_path = os.path.join(self.folder_path, file_name)

        # Load the NIfTI file
        nifti_image = nib.load(file_path)
        return nifti_image.header

    def _fix_depth(self, pixel_ndarray):
        if pixel_ndarray.shape[0] < self.ct_target_depth:
            zoom_factor = self.ct_target_depth / pixel_ndarray.shape[0]
            frame_diff = self.ct_target_depth - pixel_ndarray.shape[0]
            x_size = pixel_ndarray.shape[1]
            y_size = pixel_ndarray.shape[2]
            if self.resize_strategy == "zero_padding":
                pixel_ndarray = np.append(pixel_ndarray, np.zeros(shape=(frame_diff, x_size, y_size)), axis=0)
            elif self.resize_strategy == "nn_interpolation":
                pixel_ndarray = zoom(pixel_ndarray, (zoom_factor, 1, 1), order=0)
            else:
                raise ValueError(f"Resize strategy '{self.resize_strategy}' not supported. Allowed values are 'zero_padding' and 'nn_interpolation'.")
        else:
            fixed_pixel_ndarray = pixel_ndarray.astype(np.int16)[0:self.ct_target_depth]
        return fixed_pixel_ndarray
    
    def __next__(self):
        if self.current_idx < len(self.file_list):
            file_name = self.file_list[self.current_idx]
            file_path = os.path.join(self.folder_path, file_name)

            # Extract CT_ID from the file name
            ct_id = file_name.split('.nii.gz')[0]

            # Load the NIfTI file
            nifti_image = nib.load(file_path)

            # Access the 3D numpy array from the NIfTI image
            ct_volume = nifti_image.get_fdata()

            ct_volume = self._fix_depth(ct_volume)

            # Increment the index for the next iteration
            self.current_idx += 1

            # Return the tuple (CT_ID, CT_3D_array)
            return ct_id, ct_volume
        else:
            raise StopIteration

    def __getitem__(self, idx):
        if idx < len(self.file_list):
            file_name = self.file_list[idx]
            file_path = os.path.join(self.folder_path, file_name)

            # Extract CT_ID from the file name
            ct_id = file_name.split('.nii.gz')[0]

            # Load the NIfTI file
            nifti_image = nib.load(file_path)

            # Access the 3D numpy array from the NIfTI image
            ct_volume = nifti_image.get_fdata()

            ct_volume = self._fix_depth(ct_volume)

            # Return the tuple (CT_ID, CT_3D_array)
            return ct_id, ct_volume
        else:
            raise IndexError("Index out of range")


class ResNetImagePreprocessor(nn.Module):
    def __init__(self, ct_shape):
        super().__init__()
        self.image_preprocessor = ConvNextImageProcessor(
            do_resize=False,
            do_rescale=True,
            rescale_factor=1/255,
            do_normalize=True,
            image_mean=[0.485, 0.456, 0.406],
            image_std=[0.229, 0.224, 0.225],
        )
        self.n_slices = ct_shape[0]
        self.slice_h = ct_shape[1]
        self.slice_w = ct_shape[2]
        preprocessed_slice_thrices = torch.empty((0, self.n_slices // 3, self.slice_h, self.slice_w), dtype=torch.float32)
        self.register_buffer('preprocessed_slice_thrices', preprocessed_slice_thrices, persistent=False)
        preprocessed_img = torch.empty((3, self.slice_h, self.slice_w), dtype=torch.float32)
        self.register_buffer('preprocessed_img', preprocessed_img, persistent=False)
    
    def forward(self, x):
        for ct in x:
            for i in range(self.n_slices // 3):
                threechannel_img = ct[ 3*i : 3*i + 3 ]
                print(self.preprocessed_img)
                preprocess_output = self.image_preprocessor.preprocess(threechannel_img)['pixel_values']
                self.preprocessed_img = torch.tensor(np.array(preprocess_output)).detach().clone()
                print(self.preprocessed_img.shape)
                print(self.preprocessed_img)
                self.preprocessed_img = self.preprocessed_img.unsqueeze(0)
                self.preprocessed_slice_thrices = torch.cat((self.preprocessed_slice_thrices, self.preprocessed_img), 0)
        return self.preprocessed_slice_thrices

class ResNetCTFeatureExtractor(nn.Module):
    def __init__(self):
        super().__init__()
        # Load the ResNet model and modify the its configuration to make it output its hidden states
        resnet_config = ResNetConfig.from_pretrained("microsoft/resnet-18", output_hidden_states=True)
        self.resnet_model = ResNetForImageClassification(resnet_config)
        resnet_hidden_states = torch.empty((0, 512, 128, 6, 6), dtype=torch.float32)
        self.register_buffer('resnet_hidden_states', resnet_hidden_states, persistent=False)

    def forward(self, x):
        n_images = x.shape[0]
        for img in range(n_images):
            resnet_hidden_states = torch.cat((resnet_hidden_states, self.resnet_model(preprocessed_img)['hidden_states'][-1]), 0)
        return resnet_hidden_states

class CTModel(nn.Module):
    def __init__(self):
        super(CTModel, self).__init__()
        
        resnet = models.resnet18(pretrained=True)
        self.features = nn.Sequential(*(list(resnet.children())[:-2]))

        resnet_config = ResNetConfig.from_pretrained("microsoft/resnet-18", output_hidden_states=True)
        self.resnet_model = ResNetForImageClassification(resnet_config)

        #conv input torch.Size([1,134,512,14,14])
        self.reducingconvs = nn.Sequential(
            nn.Conv3d(134, 64, kernel_size = (3,3,3), stride=(3,1,1), padding=0),
            nn.ReLU(),
            
            nn.Conv3d(64, 32, kernel_size = (3,3,3), stride=(3,1,1), padding=0),
            nn.ReLU(),
            
            nn.Conv3d(32, 16, kernel_size = (3,2,2), stride=(3,2,2), padding=0),
            nn.ReLU())
        
        self.fc = nn.Sequential(
            nn.Linear(16*18*6*6, 512),
            nn.ReLU(True),
            nn.Dropout(0.5),
            
            nn.Linear(512, 256), 
            nn.ReLU(True),
            nn.Dropout(0.5),
            
            nn.Linear(256, 128))

        def forward(self, x):
            shape = list(x.size())
            #example shape: [1,134,3,420,420]
            #example shape: [2,134,3,420,420]
            batch_size = int(shape[0])
            x = x.view(batch_size*134,3,512,512)
            x = self.features(x)
            x = x.view(batch_size,134,512,16,16)
            x = self.reducingconvs(x)
            #output is shape [batch_size, 16, 18, 6, 6]
            x = x.view(batch_size, 16*18*6*6)
            x = self.fc(x)
            return x
        
        # Convolutional Layers
        self.conv1 = nn.Sequential( 
            nn.Conv3d(512, 256, kernel_size=(3, 3, 3), stride=(3, 1, 1), padding=0),
            nn.BatchNorm3d(256),
            nn.ReLU()
        )
        self.conv2 = nn.Sequential( 
            nn.Conv3d(256, 128, kernel_size=(3, 3, 3), stride=(3, 1, 1), padding=0),
            nn.BatchNorm3d(128),
            nn.ReLU()
        )
        self.conv3 = nn.Sequential( 
            nn.Conv3d(128, 64, kernel_size=(3, 2, 2), stride=(3, 2, 2), padding=0),
            nn.BatchNorm3d(64),
            nn.ReLU(),
        )
        self.fc1 = nn.Sequential(
            nn.Flatten(),
            nn.Linear(64*4*6*6, 1024)
        )

            
    def old_forward(self, x):

        # ResNet
        x = self.resnet(x)
        
        # Convolutions
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)

        # Fully Connected Layers
        x = self.fc1(x)

        return x



In [127]:
ct_embedding_model = CTModel()
x = torch.randn(1, 384, 512, 512)


device = torch.device("cuda")
ct_embedding_model.to(device)
x = x.to(device)

out = ct_embedding_model(x)

print(out.shape)

tensor([[[ 0.0000,  0.0000, -2.1103,  ..., -2.1365, -2.1140, -2.0976],
         [-2.1154, -2.1440, -2.1242,  ..., -2.1486, -2.1064, -2.1364],
         [-2.1331, -2.0919, -2.0996,  ..., -2.1276, -2.1228, -2.1061],
         ...,
         [-2.1036, -2.1358, -2.1417,  ..., -2.1247, -2.0925, -2.1146],
         [-2.0729, -2.1120, -2.0932,  ..., -2.1170, -2.1351, -2.1147],
         [-2.1249, -2.0949, -2.1047,  ..., -2.1516, -2.1088, -2.1403]],

        [[-2.0411, -2.0560, -2.0577,  ..., -2.0542, -2.0417, -2.0401],
         [-2.0028, -2.0176, -2.0425,  ..., -2.0406, -2.0270, -2.0252],
         [-2.0171, -2.0501, -2.0557,  ..., -2.0297, -2.0359, -2.0082],
         ...,
         [-2.0195, -2.0319, -2.0438,  ..., -2.0374, -2.0019, -2.0422],
         [-2.0233, -2.0393, -2.0251,  ..., -2.0455, -2.0344, -2.0686],
         [-2.0587, -2.0272, -2.0577,  ..., -2.0522, -2.0340, -2.0501]],

        [[-1.8047, -1.8476, -1.8089,  ..., -1.8045, -1.7889, -1.8189],
         [-1.8062, -1.8002, -1.8065,  ..., -1

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument tensors in method wrapper_CUDA_cat)

In [37]:
print(torch.cuda.is_available())

True


In [66]:
from torch.utils.data import Dataset, random_split, RandomSampler, BatchSampler
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
import time
import logging

random.seed(42)

logging.basicConfig(level=logging.INFO)

class TripletModel:
    def __init__(
        self,
        model,
        mode='HN',
        seed=42,
    ):
        """
        mode: (negative sampling) "random" or "HN"
        estimator: "Linear" or "net"
        """

        # dim latent space
        #self.factor = factor 

        #VSE++
        #self.batch_size = 30  

        #ResNet18
        self.batch_size = 8

        #VSE++
        #self.epochs = 30

        # ResNet18
        self.epochs = 1
        
        self.model = model
        
        self.device = None
        
        if not torch.cuda.is_available():
            self.device = torch.device("cuda")
            self.model.to(self.device)
        
        #VSE++
        self.optimizer = torch.optim.Adam(params=self.model.parameters(), lr=0.0002)
        self.scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, step_size=15, gamma=0.1)

        # FF net or LINEAR
        #self.estimator = estimator
        
        # Save checkpoints to this path
        self.checkpoint_path = TRIPLET_MODELS_PATH
        
        # Negative sampling: 'HN' or 'random'
        self.mode = mode
        
        # Triplet loss' alpha
        self.alpha = 0.2
        
        self.seed = seed
    
    def get_triplets(self, ids_list):
        # construct positive pairs and negative-compatibles from the data points in the mini-batch
        #print('ids_list:', ids_list)
        positive_pairs, negative_compatibles = self.dataset.get_batch_positive_negative_pairs(ids_list)
        triplets = list()

        if len(positive_pairs) == 0 or len(negative_compatibles) == 0:
            logging.warning(f"Cannot build triplets from current batch! {len(positive_pairs)} positives and {len(negative_compatibles)} negative pair candidates found.")
            return None #skip batch

        if self.mode == "random":
            negative_pairs = [[a, random.choice(negative_compatibles[neg_id])] for (a, neg_id) in negative_compatibles]
            # a_p and a_n should always be the same value, so it doesn't matter which one we choose to build the triplet
            triplets = torch.tensor([[a_p, p, n] for ([a_p, p], [a_n, n]) in zip(positive_pairs, negative_pairs)])

        elif self.mode == "HN":
            anchors_list = [a for (a, _) in positive_pairs]
            #embeddings = self.dataset.get_embeddings(ids_list)
            anchors_embs = self.dataset.get_embeddings(anchors_list)
            negative_pairs = list()
            for anchor_id, anchor_emb in zip(anchors_list, anchors_embs):
                negatives_compatibles_ids = negative_compatibles[anchor_id]
                n_compatibles = len(negatives_compatibles_ids)
                negatives_compatibles_embs = self.dataset.get_embeddings(negatives_compatibles_ids)
                negatives_compatibles_embs = negatives_compatibles_embs.squeeze(1)
                anchor_emb_repeat = anchor_emb.repeat(n_compatibles, 1)
                #anchor_emb_repeat = torch.fill(torch.empty(n_compatibles), anchor_emb)
                a_n_pairwise_similarities = torch.nn.functional.cosine_similarity(anchor_emb_repeat, negatives_compatibles_embs)
                most_similar_id = torch.argmax(a_n_pairwise_similarities)
                negative_pairs.append([anchor_id, negatives_compatibles_ids[most_similar_id]])
                #print('\nSTART ANCHOR\n')
                #print('anchors_list', anchors_list)
                #print('anchor_id', anchor_id)
                #print('negatives_compatibles_ids', negatives_compatibles_ids)
                #print('anchor_emb_repeat', anchor_emb_repeat.shape, anchor_emb_repeat)
                #print('negatives_compatibles_embs', negatives_compatibles_embs.shape, negatives_compatibles_embs)
                #print('a_n_pairwise_similarities', a_n_pairwise_similarities)
                #print('most_similar_id', most_similar_id)
                #print('\nEND ANCHOR\n')
            # a_p and a_n should always be the same value, so it doesn't matter which one we choose to build the triplet
            triplets = torch.tensor([[a_p, p, n] for ([a_p, p], [a_n, n]) in zip(positive_pairs, negative_pairs)])
        if self.device:
            triplets = triplets.to(self.device)
        return triplets
        
        
    def training(self, dataset, train_frac=0.6, batch_size=8, epochs=1):
        """
        Training process
        """
        
        assert train_frac > 0.0 and 1.0 >= train_frac
        
        self.dataset = dataset
        self.batch_size = batch_size
        self.epochs = epochs
        
        # Define the sizes of train, validation, and test sets
        train_size = int(train_frac * len(self.dataset))  # percentage of the data for training
        val_size = int((1.0 - train_frac)/1 * len(self.dataset)) # half of the remaining for validation
        test_size = len(self.dataset) - train_size - val_size  # the rest of the remaining for testing

        '''
        If necessary, adjust the size of the training set when the sum of 
        the sizes of the three sets differs from the total dataset size.
        '''
        size_diff = len(self.dataset) - train_size - val_size - test_size
        train_size += size_diff

        # Use random_split to split the dataset
        self.train_dataset, self.val_dataset, self.test_dataset = random_split(dataset=self.dataset, lengths=[train_size, val_size, test_size], generator=torch.Generator().manual_seed(self.seed))
        
        # Create a RandomSampler with a predetermined seed and no replacement
        #self.training_sampler = RandomSampler(self.train_dataset, replacement=False, generator=torch.Generator().manual_seed(seed))
        self.validation_sampler = RandomSampler(self.val_dataset, replacement=False, generator=torch.Generator().manual_seed(self.seed))
        self.test_sampler = RandomSampler(self.test_dataset, replacement=False, generator=torch.Generator().manual_seed(self.seed))

        #self.train_batch_sampler = BatchSampler(sampler=self.training_sampler, batch_size=self.batch_size, drop_last=True)
        self.val_batch_sampler = BatchSampler(sampler=self.validation_sampler, batch_size=len(self.val_dataset), drop_last=True)
        self.test_batch_sampler = BatchSampler(sampler=self.test_sampler, batch_size=len(self.test_dataset), drop_last=True)
        
        # self.dataset.shape := (number of data points, number of classes)
        #self.n_classes = self.dataset.shape[1]

        self.recalls = [] # recall 1, 10, 25
        self.loss_per_epoch = []
        self.step_losses = []
        self.val_loss_per_epoch = []
        
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        self.tb_writer = SummaryWriter(os.path.join(TRIPLET_MODELS_PATH, "runs", "ct_retrieval_trainer_{}".format(timestamp)))
        
        start_time = time.time()
        
        for epoch in range(self.epochs):
            logging.info(f"\n--------------------------------\nEpoch {epoch+1} of {self.epochs}\n--------------------------------\n")

            batch_counter = 0 # counter of processed batches across training
            
            epoch_losses = list()
            
            # change the random seed and reconstruct training_sampler and train_batch_sampler at each epoch
            # this way we can assemble never-before-seen triplets to train on.
            self.training_sampler = RandomSampler(self.train_dataset, replacement=False, generator=torch.Generator().manual_seed(self.seed + epoch))
            self.train_batch_sampler = BatchSampler(sampler=self.training_sampler, batch_size=self.batch_size, drop_last=True)

            # recalculate dataset embeddings at this epoch
            dataset_embeddings = self.dataset.get_updated_embeddings(self.model, self.device)
            
            for batch_index, training_batch in enumerate(self.train_batch_sampler):
                batch_counter += 1
                logging.info(f"=> Batch {batch_counter}:")

                # load batch data
                dataset_batch_ids, batch_embeddings, batch_labels = self.train_dataset[training_batch]

                batch_id_map = dict((i, j) for (i, j) in zip(dataset_batch_ids.tolist(), training_batch))

                # dataset_batch_ids != training_batch bacause dataset_batch_ids corresponds to the original
                # unsplitted dataset while training_batch maps to the training dataset

                if self.device is not None:
                    batch_embeddings = batch_embeddings.to(self.device)
                    
                batch_triplets = self.get_triplets(dataset_batch_ids.tolist())
                batch_triplets = [[batch_id_map[a], batch_id_map[p], batch_id_map[n]] for [a, p, n] in batch_triplets.tolist()]
                batch_triplets = torch.tensor(batch_triplets)
                #print('training_batch:', training_batch)
                #print('dataset_batch_ids:', dataset_batch_ids)
                #print('batch_triplets:', batch_triplets)
                
                if batch_triplets is None:
                    # skip empty batch
                    continue
                
                # perform forward pass over all samples in the batch
                self.model.train(True)
                self.optimizer.zero_grad()
                outputs = self.model(batch_embeddings)

                # map the dataset IDs referenced in the triplets to the indices in the batched model output
                #print('batch_ids:', mini_batch)
                #print('triplets before mapping:\n', triplets)
                #triplets = self._map_triplets(mini_batch, triplets)
                #print('triplets after mapping:\n', triplets)
                
                # calculate loss and gradients
                # IMPORTANT: the loss is calculated for each triplet (3-tuples of samples), therefore,
                # if there is a sample which is not part of a triplet, then it will not be needed in
                # calculation.
                batch_loss = self.loss(outputs, training_batch, batch_triplets, alpha=self.alpha)
                batch_loss.requires_grad = True
                batch_loss.backward()

                # update model weights
                self.optimizer.step()

                self.tb_writer.add_scalar('Steps: Loss/train', batch_loss, global_step=batch_counter)
                
                epoch_losses.append(batch_loss)
                self.step_losses.append(batch_loss)
                logging.info(f"Batch {batch_counter} loss: {round(batch_loss.item(), 4)}")
            
            # finished iterating the batches
            epoch_mean_loss = (sum(epoch_losses) / len(epoch_losses)).item()
            self.tb_writer.add_scalar('Epochs: Loss/train', epoch_mean_loss, global_step=epoch+1)
            self.loss_per_epoch.append(epoch_mean_loss)

            logging.info(f'Training loss: {round(epoch_mean_loss, 4)}')
            
            
            # evaluate on validation set
            logging.info("Evaluating...")
            self.model.eval()
            val_epoch_losses = list()
            with torch.no_grad():
                for val_batch_index, val_batch in enumerate(self.val_batch_sampler):

                    val_batch_ids, val_batch_embeddings, val_batch_labels = self.val_dataset[val_batch]

                    val_batch_id_map = dict((i, j) for (i, j) in zip(val_batch_ids.tolist(), val_batch))
                    

                    if self.device is not None:
                        val_batch_embeddings = val_batch_embeddings.to(self.device)
                        
                    val_batch_triplets = self.get_triplets(val_batch_ids.tolist())
                    val_batch_triplets = [[val_batch_id_map[a], val_batch_id_map[p], val_batch_id_map[n]] for [a, p, n] in val_batch_triplets.tolist()]
                    val_batch_triplets = torch.tensor(val_batch_triplets)
                    
                    if val_batch_triplets is None:
                        # skip validation if there was no triplets to evaluate
                        continue

                    # map the dataset IDs referenced in the triplets to the indices in the batched model output

                    #val_triplets = self._map_triplets(val_batch, val_triplets)

                    

                    # perform evaluation on validation data
                    val_outputs = self.model(val_batch_embeddings)

                    # calculate validation loss
                    val_batch_loss = self.loss(val_outputs, val_batch, val_batch_triplets, alpha=self.alpha)
                    val_epoch_losses.append(val_batch_loss)
                    logging.info(f"Validation batch {val_batch_index+1} loss: {round(val_batch_loss.item(), 4)}")
                
                # finished iterating the batches
                val_mean_loss = (sum(val_epoch_losses) / len(val_epoch_losses)).item()
                self.tb_writer.add_scalars(
                    'Training vs. Validation Loss',
                    {
                        'Training': epoch_mean_loss, 
                        'Validation': val_mean_loss,
                    }, 
                    epoch + 1
                )
                self.val_loss_per_epoch.append(val_mean_loss)

                logging.info(f'Validation loss: {round(val_mean_loss, 4)}')
            
                # end batch iteration
            logging.info("Done evaluating.")
            if(self.val_loss_per_epoch[-1] == min(self.val_loss_per_epoch)):
                logging.info("Best validation loss achieved! Saving checkpoint...")
                self.save(
                    epoch,
                    self.loss_per_epoch, 
                    self.step_losses, 
                    self.val_loss_per_epoch, 
                    self.recalls,
                    time.time()-start_time,
                    os.path.join(TRIPLET_CHECKPOINTS_PATH, f'ct_retrieval_trainer_{timestamp}.pth'),
                )
                logging.info("Checkpoint saved.")

            # update optimizer scheduler ater completing an epoch
            self.scheduler.step()
            # end epoch iteration
        #print('train_dataset:', self.train_dataset[0:len(self.train_dataset)][0])
        #print('val_dataset:', self.val_dataset[0:len(self.val_dataset)][0])

    def _map_triplets(self, batch_ids, triplets):
        a_tensor = triplets.select(1, 0)
        p_tensor = triplets.select(1, 1)
        n_tensor = triplets.select(1, 2)
        a_items = [e.item() for e in a_tensor]
        p_items = [e.item() for e in p_tensor]
        n_items = [e.item() for e in n_tensor]
        a_index = list()
        p_index = list()
        n_index = list()
        for (a, p, n) in zip(a_items, p_items, n_items):
            a_index.append(batch_ids.index(a))
            p_index.append(batch_ids.index(p))
            n_index.append(batch_ids.index(n))
        mapped_triplets = torch.tensor([a_index, p_index, n_index])
        mapped_triplets = mapped_triplets.transpose(0, 1)
        return mapped_triplets

    def _max(self, a, b):
        if a > b:
            return a
        return b
    
    def loss(self, embeddings, embeddings_ids, triplets_ids, alpha=0.2):
        
        """
        for computing the triplet loss [alpha + negative_similarity - positive_similarity]_{+}
        """
        anchors_ids = torch.select(triplets_ids, 1, 0).tolist()
        positives_ids = torch.select(triplets_ids, 1, 1).tolist()
        negatives_ids = torch.select(triplets_ids, 1, 2).tolist()

        ids_map = dict((j, i) for i, j in enumerate(embeddings_ids))
        
        anchors = embeddings[[ids_map[a] for a in anchors_ids]]
        positives = embeddings[[ids_map[p] for p in positives_ids]]
        negatives = embeddings[[ids_map[n] for n in negatives_ids]]
        
        cos = torch.nn.CosineSimilarity(dim=1, eps=1e-6)
        
        positive_similarities = cos(anchors, positives)
        negative_similarities = cos(anchors, negatives)
        
        unrectified_batch_loss = alpha + negative_similarities - positive_similarities
        rectified_batch_loss = unrectified_batch_loss.detach().map_(torch.zeros(len(unrectified_batch_loss)), self._max)
        #rectified_batch_loss = torch.func.vmap(torch.max)(unrectified_batch_loss, torch.tensor(0))
        batch_loss = torch.mean(rectified_batch_loss)
        if self.device:
            batch_loss = batch_loss.to(device)
        return batch_loss
    
    def save(self, epoch, epoch_loss, step_loss, val_loss, recalls, running_time, directory):
        """
        For saving logs of the experiment
        """
        
        torch.save(
            {
                'model_state_dict': self.model.state_dict(),
                'optimizer_state_dict': self.scheduler.state_dict(),
                'epoch': epoch,
                'epoch_loss': epoch_loss,
                'step_loss': step_loss,
                'val_loss': val_loss,
                'recalls': recalls,
                'running_time': running_time,
            },
            directory
        )

In [29]:
'''    
        def predict(self,query,split):
        """
        the split defines where we are going to search
        """
        if type(query) is not list:
            query= [query]

        i_input = self.visual_encoder(tf.nn.embedding_lookup(visual_matrix, query))
        if split=="val":
                c_eval= self.visual_encoder(visual_matrix[self.val_ids])
        if split=="test":
                c_eval= self.visual_encoder(visual_matrix[self.test_ids])
        if split=="train":
                c_eval=self.visual_encoder(visual_matrix[self.train_ids])

        rating= tf.matmul(i_input, tf.transpose(c_eval) )
        return np.reshape(rating, [-1])

    def metric(self, split, k=10):
        """
        to compute metrics
        """
        if split == "val":
            ids, gt = self.val_ids, self.labels_val
        if split == "test":
            ids, gt = self.test_ids, self.labels_test

        metric=0

        for i in range(len(ids)): #querys
            top= np.argsort(self.predict(ids[i],split))[::-1][1:k+1] #first one must be the query
            #get labels
            prediction= [gt[j] for j in top]
            metric+=recall(gt[i], prediction)
        return metric/len(ids) #average

    def inference(self):
        """
        testing phase
        """
        results=(self.metric("test",k=1), self.metric("test",k=10), self.metric("test",k=25))
        print(results)
        np.save(self.DIS_MODEL_FILE+"test_metrics.npy", np.array([results], dtype=object) )
'''

'    \n        def predict(self,query,split):\n        """\n        the split defines where we are going to search\n        """\n        if type(query) is not list:\n            query= [query]\n\n        i_input = self.visual_encoder(tf.nn.embedding_lookup(visual_matrix, query))\n        if split=="val":\n                c_eval= self.visual_encoder(visual_matrix[self.val_ids])\n        if split=="test":\n                c_eval= self.visual_encoder(visual_matrix[self.test_ids])\n        if split=="train":\n                c_eval=self.visual_encoder(visual_matrix[self.train_ids])\n\n        rating= tf.matmul(i_input, tf.transpose(c_eval) )\n        return np.reshape(rating, [-1])\n\n    def metric(self, split, k=10):\n        """\n        to compute metrics\n        """\n        if split == "val":\n            ids, gt = self.val_ids, self.labels_val\n        if split == "test":\n            ids, gt = self.test_ids, self.labels_test\n\n        metric=0\n\n        for i in range(len(ids)):

In [65]:
ct_embedding_model = CTModel()
triplet_model = TripletModel(model=ct_embedding_model, mode='HN', seed=42)

# Create an instance of the dataset
dr70_dataset = CustomDataset(CT_RESNET18_EMBEDDINGS_PATH, DR70_LABELS_PATH)

triplet_model.training(dataset=dr70_dataset, train_frac=0.7, batch_size=8, epochs=10)

INFO:root:
--------------------------------
Epoch 1 of 10
--------------------------------

INFO:root:=> Batch 1:
INFO:root:Generating triplets for batch [67, 47, 18, 55, 39, 12, 45, 28]
INFO:root:Could not construct a positive pair for anchor ID 45 in this batch. Skipping...
INFO:root:Batch 1 loss: 0.2585
INFO:root:=> Batch 2:
INFO:root:Generating triplets for batch [50, 48, 53, 62, 15, 44, 19, 11]
INFO:root:Batch 2 loss: 0.2311
INFO:root:=> Batch 3:
INFO:root:Generating triplets for batch [25, 34, 56, 6, 41, 31, 21, 61]
INFO:root:Could not construct a positive pair for anchor ID 25 in this batch. Skipping...
INFO:root:Could not construct a positive pair for anchor ID 56 in this batch. Skipping...
INFO:root:Batch 3 loss: 0.2167
INFO:root:=> Batch 4:
INFO:root:Generating triplets for batch [43, 3, 42, 35, 68, 32, 30, 23]
INFO:root:Could not construct a positive pair for anchor ID 35 in this batch. Skipping...
INFO:root:Batch 4 loss: 0.2619
INFO:root:=> Batch 5:
INFO:root:Generating tri

train_dataset: tensor([39, 32, 18, 61, 60, 12, 67, 55, 15, 62, 14,  3, 53, 31,  4, 19, 47, 44,
        21, 25, 49, 41, 23, 26, 68, 58, 30, 42,  6, 48, 57, 16, 17,  8, 28, 43,
        65, 11, 50,  0, 40, 63, 35,  9, 51, 56, 45, 34])
val_dataset: tensor([10, 66, 29, 38, 52, 64,  7, 20, 37, 36, 27, 13, 33, 59,  5, 46,  1, 54,
        22,  2])
