In [1]:
# Paths & URLs

import os

# Enable CUDA stacktrace reporting for debugging
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

# Directorio base
#PATH_BASE = '/content/drive/MyDrive/proximity'
PATH_BASE = 'C:\\Users\\User\\Documents\\Proyecto Proximity'

# Data release 70
DR70_PATH = os.path.join(PATH_BASE, 'DR70')
DR70_CT_PATH = os.path.join(DR70_PATH, 'datalake_sorted')
DR70_LABELS_PATH = os.path.join(DR70_PATH, 'labels.csv')

# Data release 176
DR176_PATH = os.path.join(PATH_BASE, 'DR176')
DR176_CT_PATH = os.path.join(DR176_PATH, 'DR176_studies')
DR176_LABELS_PATH = os.path.join(PATH_BASE, 'reports_with_label.csv')


# CTs in Nibabel format
CT_NIBABEL_PATH = os.path.join(PATH_BASE, 'DR70', 'CTs')

# Embeddings visuales de CTs
#CT_EMBEDDINGS_PATH = DATA_RELEASE_PATH + '/visual_embeddings'


# Etiquetas de los CTs del data release actual
#CT_LABELS_CSV_PATH = DATA_RELEASE_PATH + '/labels.csv'

# Data release (CTs + etiquetas) organizados en un DataFrame
#CT_DATASET_DF_HDF_PATH = os.path.join(PATH_BASE, 'dataset_df.h5')
#CT_DATASET_DF_PICKLE_PATH = os.path.join(PATH_BASE, 'dataset_df.pickle')

# URLs de modelos visuales
#RESNET18_URL = 'microsoft/resnet-18'

# Path que contiene los resnet50 embeddings de CTs del data release actual
#CT_RESNET18_EMBEDDINGS_PATH = os.path.join(DR70_PATH, 'visual_embeddings', 'resnet18')
#CT_RESNET18_EMBEDDINGS_PATH = os.path.join(DR70_PATH, 'visual_embeddings', 'resnet18', 'reshaped_averaged')

# Path de modelos entrenados en base a tripletas
TRIPLET_MODELS_PATH = os.path.join(PATH_BASE, 'retrieval_models', 'triplets')
TRIPLET_CHECKPOINTS_PATH = os.path.join(PATH_BASE, 'retrieval_models', 'triplets', 'checkpoints')

In [24]:
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
import os
import pickle
import pandas as pd



#Set seeds
np.random.seed(0)
torch.manual_seed(0)
torch.cuda.manual_seed(0)
torch.cuda.manual_seed_all(0)


"""CT volume preprocessing functions"""

#############################################
# Pixel Values (on torch Tensors for speed) #-----------------------------------
#############################################
def normalize(ctvol, lower_bound, upper_bound): #Done testing
    """Clip images and normalize"""
    #formula https://stats.stackexchange.com/questions/70801/how-to-normalize-data-to-0-1-range
    ctvol = torch.clamp(ctvol, lower_bound, upper_bound)
    ctvol = (ctvol - lower_bound) / (upper_bound - lower_bound)
    return ctvol

def torchify_pixelnorm_pixelcenter(ctvol, pixel_bounds):
    """Normalize using specified pixel_bounds and then center on the ImageNet
    mean. Used in 2019_10 dataset preparation"""
    #Cast to torch Tensor
    #use torch Tensor instead of numpy array because addition, subtraction,
    #multiplication, and division are faster in torch Tensors than np arrays
    ctvol = torch.from_numpy(ctvol).type(torch.float)
    
    #Clip Hounsfield units and normalize pixel values
    ctvol = normalize(ctvol, pixel_bounds[0], pixel_bounds[1])
    
    #Center on the ImageNet mean since you are using an ImageNet pretrained
    #feature extractor:
    ctvol = ctvol - 0.449
    return ctvol

###########
# Padding #---------------------------------------------------------------------
###########
def pad_slices(ctvol, max_slices): #Done testing
    """For <ctvol> of shape (slices, side, side) pad the slices to shape
    max_slices for output of shape (max_slices, side, side)"""
    padding_needed = max_slices - ctvol.shape[0]
    assert (padding_needed >= 0), 'Image slices exceed max_slices by'+str(-1*padding_needed)
    if padding_needed > 0:
        before_padding = int(padding_needed/2.0)
        after_padding = padding_needed - before_padding
        ctvol = np.pad(ctvol, pad_width = ((before_padding, after_padding), (0,0), (0,0)),
                     mode = 'constant', constant_values = np.amin(ctvol))
        assert ctvol.shape[0]==max_slices
    return ctvol

def pad_sides(ctvol, max_side_length): #Done testing
    """For <ctvol> of shape (slices, side, side) pad the sides to shape
    max_side_length for output of shape (slices, max_side_length,
    max_side_length)"""
    needed_padding = 0
    for side in [1,2]:
        padding_needed = max_side_length - ctvol.shape[side]
        if padding_needed > 0:
            before_padding = int(padding_needed/2.0)
            after_padding = padding_needed - before_padding
            if side == 1:
                ctvol = np.pad(ctvol, pad_width = ((0,0), (before_padding, after_padding), (0,0)),
                         mode = 'constant', constant_values = np.amin(ctvol))
                needed_padding += 1
            elif side == 2:
                ctvol = np.pad(ctvol, pad_width = ((0,0), (0,0), (before_padding, after_padding)),
                         mode = 'constant', constant_values = np.amin(ctvol))
                needed_padding += 1
    if needed_padding == 2: #if both sides needed to be padded, then they
        #should be equal (but it's possible one side or both were too large
        #in which case we wouldn't expect them to be equal)
        assert ctvol.shape[1]==ctvol.shape[2]==max_side_length
    return ctvol

def pad_volume(ctvol, max_slices, max_side_length):
    """Pad <ctvol> to a minimum size of
    [max_slices, max_side_length, max_side_length], e.g. [402, 308, 308]
    Used in 2019_10 dataset preparation"""
    if ctvol.shape[0] < max_slices:
        ctvol = pad_slices(ctvol, max_slices)
    if ctvol.shape[1] < max_side_length:
        ctvol = pad_sides(ctvol, max_side_length)
    return ctvol

###########################
# Reshaping to 3 Channels #-----------------------------------------------------
###########################
def sliceify(ctvol): #Done testing
    """Given a numpy array <ctvol> with shape [slices, square, square]
    reshape to 'RGB' [max_slices/3, 3, square, square]"""
    return np.reshape(ctvol, newshape=[int(ctvol.shape[0]/3), 3, ctvol.shape[1], ctvol.shape[2]])

def reshape_3_channels(ctvol):
    """Reshape grayscale <ctvol> to a 3-channel image
    Used in 2019_10 dataset preparation"""
    if ctvol.shape[0]%3 == 0:
        ctvol = sliceify(ctvol)
    else:
        if (ctvol.shape[0]-1)%3 == 0:
            ctvol = sliceify(ctvol[:-1,:,:])
        elif (ctvol.shape[0]-2)%3 == 0:
            ctvol = sliceify(ctvol[:-2,:,:])
    return ctvol

##################################
# Cropping and Data Augmentation #----------------------------------------------
##################################
def crop_specified_axis(ctvol, max_dim, axis): #Done testing
    """Crop 3D volume <ctvol> to <max_dim> along <axis>"""
    dim = ctvol.shape[axis]
    if dim > max_dim:
        amount_to_crop = dim - max_dim
        part_one = int(amount_to_crop/2.0)
        part_two = dim - (amount_to_crop - part_one)
        if axis == 0:
            return ctvol[part_one:part_two, :, :]
        elif axis == 1:
            return ctvol[:, part_one:part_two, :]
        elif axis == 2:
            return ctvol[:, :, part_one:part_two]
    else:
        return ctvol

def single_crop_3d_fixed(ctvol, max_slices, max_side_length):
    """Crop a single 3D volume to shape [max_slices, max_side_length,
    max_side_length]"""
    ctvol = crop_specified_axis(ctvol, max_slices, 0)
    ctvol = crop_specified_axis(ctvol, max_side_length, 1)
    ctvol = crop_specified_axis(ctvol, max_side_length, 2)
    return ctvol

def single_crop_3d_augment(ctvol, max_slices, max_side_length):
    """Crop a single 3D volume to shape [max_slices, max_side_length,
    max_side_length] with randomness in the centering and random
    flips or rotations"""
    #Introduce random padding so that the centered crop will be slightly random
    ctvol = rand_pad(ctvol)
    
    #Obtain the center crop
    ctvol = single_crop_3d_fixed(ctvol, max_slices, max_side_length)
    
    #Flip and rotate
    ctvol = rand_flip(ctvol)
    ctvol = rand_rotate(ctvol)
    
    #Make contiguous array to avoid Pytorch error
    return np.ascontiguousarray(ctvol)

def rand_pad(ctvol):
    """Introduce random padding between 0 and 15 pixels on each of the 6 sides
    of the <ctvol>"""
    randpad = np.random.randint(low=0,high=15,size=(6))
    ctvol = np.pad(ctvol, pad_width = ((randpad[0],randpad[1]), (randpad[2],randpad[3]), (randpad[4], randpad[5])),
                         mode = 'constant', constant_values = np.amin(ctvol))
    return ctvol
    
def rand_flip(ctvol):
    """Flip <ctvol> along a random axis with 50% probability"""
    if np.random.randint(low=0,high=100) < 50:
        chosen_axis = np.random.randint(low=0,high=3) #0, 1, and 2 are axis options
        ctvol =  np.flip(ctvol, axis=chosen_axis)
    return ctvol

def rand_rotate(ctvol):
    """Rotate <ctvol> some random amount axially with 50% probability"""
    if np.random.randint(low=0,high=100) < 50:
        chosen_k = np.random.randint(low=0,high=4)
        ctvol = np.rot90(ctvol, k=chosen_k, axes=(1,2))
    return ctvol

###########################################
# 2019_10 Dataset Preprocessing Sequences #-------------------------------------
###########################################
def prepare_ctvol_2019_10_dataset(ctvol, pixel_bounds, data_augment, num_channels,
                                  crop_type):
    """Pad, crop, possibly augment, reshape to correct
    number of channels, cast to torch tensor (to speed up subsequent operations),
    Clip Hounsfield units, normalize pixel values, center on the
    ImageNet mean, and return as a torch tensor (for crop_type='single')
    
    <pixel_bounds> is a list of ints e.g. [-1000,200] Hounsfield units. Used for
        pixel value clipping and normalization.
    <data_augment> is True to employ data augmentation, and False otherwise
    <num_channels> is an int, e.g. 3 to reshape the grayscale volume into
        a volume of 3-channel images
    <crop_type>: if 'single' then return the volume as one 3D numpy array."""
    max_slices = 402
    max_side_length = 512
    assert num_channels == 3 or num_channels == 1
    assert crop_type == 'single'
    
    #Padding to minimum size [max_slices, max_side_length, max_side_length]
    ctvol = pad_volume(ctvol, max_slices, max_side_length)
    
    #Cropping, and data augmentation if indicated
    if crop_type == 'single':
        if data_augment is True:
            ctvol = single_crop_3d_augment(ctvol, max_slices, max_side_length)
        else:
            ctvol = single_crop_3d_fixed(ctvol, max_slices, max_side_length)
        #Reshape to 3 channels if indicated
        if num_channels == 3:
            ctvol = reshape_3_channels(ctvol)
        #Cast to torch tensor and deal with pixel values
        output = torchify_pixelnorm_pixelcenter(ctvol, pixel_bounds)
    
    return output



###################################################
# PACE Dataset for Data Stored in 2019-10-BigData #-----------------------------
###################################################
class CTDataset_2019_10(Dataset):    
    def __init__(self, setname, label_type_ld,
                 label_meanings, num_channels, pixel_bounds,
                 data_augment, crop_type,
                 selected_note_acc_files):
        """CT Dataset class that works for preprocessed data in 2019-10-BigData.
        A single example (for crop_type == 'single') is a 4D CT volume:
            if num_channels == 3, shape [134,3,420,420]
            if num_channels == 1, shape [402,420,420]
        
        Variables:
        <setname> is either 'train' or 'valid' or 'test'
        <label_type_ld> is 'disease_new'
        <label_meanings>: list of strings indicating which labels should
            be kept. Alternatively, can be the string 'all' in which case
            all labels are kept.
        <num_channels>: number of channels to reshape the image to.
            == 3 if the model uses a pretrained feature extractor.
            == 1 if the model uses only 3D convolutions.
        <pixel_bounds>: list of ints e.g. [-1000,200]
            Determines the lower bound, upper bound of pixel value clipping
            and normalization.
        <data_augment>: if True, perform data augmentation.
        <crop_type>: is 'single' for an example consisting of one 4D numpy array
        <selected_note_acc_files>: This should be a dictionary
            with key equal to setname and value that is a string. If the value
            is a path to a file, the file must be a CSV. Only note accessions
            in this file will be used. If the value is not a valid file path,
            all available note accs will be used, i.e. the model will be
            trained on the whole dataset."""
        self.setname = setname
        self.define_subsets_list()
        self.label_type_ld = label_type_ld
        self.label_meanings = label_meanings
        self.num_channels = num_channels
        self.pixel_bounds = pixel_bounds
        if self.setname == 'train':
            self.data_augment = data_augment
        else:
            self.data_augment = False
        print('For dataset',self.setname,'data_augment is',self.data_augment)
        self.crop_type = crop_type
        assert self.crop_type == 'single'
        self.selected_note_acc_files = selected_note_acc_files
        
        #Define location of the CT volumes
        self.main_clean_path = DR176_CT_PATH
        self.volume_log_df = pd.read_csv('./load_dataset/fakedata/CT_Scan_Preprocessing_Log_File_FINAL_SMALL.csv',header=0,index_col=0)
        
        #Get the example ids
        self.volume_accessions = self.get_volume_accessions()
                        
        #Get the ground truth labels
        self.labels_df = self.get_labels_df()
    
    # Pytorch Required Methods #------------------------------------------------
    def __len__(self):
        return len(self.volume_accessions)
        
    def __getitem__(self, idx):
        """Return a single sample at index <idx>. The sample is a Python
        dictionary with keys 'data' and 'gr_truth' for the image and label,
        respectively"""
        return self._get_pace(self.volume_accessions[idx])
    
    # Volume Accession Methods #------------------------------------------------
    def get_note_accessions(self):
        setname_file = self.selected_note_acc_files[self.setname]
        if os.path.isfile(setname_file):
            print('\tObtaining note accessions from',setname_file)
            sel_accs = pd.read_csv(setname_file,header=0)            
            assert sorted(list(set(sel_accs['Subset_Assigned'].values.tolist())))==sorted(self.subsets_list)
            note_accs = sel_accs.loc[:,'Accession'].values.tolist()
            print('\tTotal theoretical note accessions in subsets:',len(note_accs))
            return note_accs
        else: 
            print('\tObtaining note accessions from complete identifiers file')
            #Read in identifiers file, which contains note_accessions
            #Columns are MRN, Accession, Set_Assigned, Set_Should_Be, Subset_Assigned
            all_ids = pd.read_csv('./load_dataset/fakedata/all_identifiers.csv',header=0)
           
            #Extract the note_accessions
            note_accs = []
            for subset in self.subsets_list: #e.g. ['imgvalid_a','imgvalid_b']
                subset_note_accs = all_ids[all_ids['Subset_Assigned']==subset].loc[:,'Accession'].values.tolist()
                note_accs += subset_note_accs
            print('\tTotal theoretical note accessions in subsets:',len(note_accs))
            return note_accs
    
    def get_volume_accessions(self):
        note_accs = self.get_note_accessions()
        #Translate note_accessions to volume_accessions based on what data has been
        #preprocessed successfully. volume_log_df has note accessions as the
        #index, and the column 'full_filename_npz' for the volume accession.
        #The column 'status' should equal 'success' if the volume has been
        #preprocessed correctly.
        print('\tTotal theoretical volumes in whole dataset:',self.volume_log_df.shape[0])
        self.volume_log_df = self.volume_log_df[self.volume_log_df['status']=='success']
        print('\tTotal successfully preprocessed volumes in whole dataset:',self.volume_log_df.shape[0])
        volume_accs = []
        for note_acc in note_accs:
            if note_acc in self.volume_log_df.index.values.tolist():
                volume_accs.append(self.volume_log_df.at[note_acc,'full_filename_npz'])
        print('\tFinal total successfully preprocessed volumes in requested subsets:',len(volume_accs))
        #According to this thread: https://github.com/pytorch/pytorch/issues/13246
        #it is better to use a numpy array than a list to reduce memory leaks.
        return np.array(volume_accs)
    
    # Ground Truth Label Methods #----------------------------------------------
    def get_labels_df(self):
        #Get the ground truth labels based on requested label type.
        labels_df = read_in_labels(self.label_type_ld, self.setname)
        
        #Now filter the ground truth labels based on the desired label meanings:
        if self.label_meanings != 'all': #i.e. if you want to filter
            labels_df = labels_df[self.label_meanings]
        return labels_df
    
    # Fetch a CT Volume (__getitem__ implementation) #--------------------------
    def _get_pace(self, volume_acc):
        """<volume_acc> is for example RHAA12345_6.npz"""
        #Load compressed npz file: [slices, square, square]
        ctvol = np.load(os.path.join(self.main_clean_path, volume_acc))['ct']
        
        #Prepare the CT volume data (already torch Tensors)
        data = prepare_ctvol_2019_10_dataset(ctvol, self.pixel_bounds, self.data_augment, self.num_channels, self.crop_type)
        
        #Get the ground truth:
        note_acc = self.volume_log_df[self.volume_log_df['full_filename_npz']==volume_acc].index.values.tolist()[0]
        gr_truth = self.labels_df.loc[note_acc, :].values
        gr_truth = torch.from_numpy(gr_truth).squeeze().type(torch.float)
        
        #When training on only one abnormality you must unsqueeze to prevent
        #a dimensions error when training the model:
        if len(self.label_meanings)==1:
            gr_truth = gr_truth.unsqueeze(0)
        
        #Create the sample
        sample = {'data': data, 'gr_truth': gr_truth, 'volume_acc': volume_acc}
        return sample
    
    # Sanity Check #------------------------------------------------------------
    def define_subsets_list(self):
        assert self.setname in ['train','valid','test']
        if self.setname == 'train':
            self.subsets_list = ['imgtrain']
        elif self.setname == 'valid':
            self.subsets_list = ['imgvalid_a']
        elif self.setname == 'test':
            self.subsets_list = ['imgtest_a','imgtest_b','imgtest_c','imgtest_d']
        print('Creating',self.setname,'dataset with subsets',self.subsets_list)

#######################
# Ground Truth Labels #---------------------------------------------------------
#######################

def read_in_labels(label_type_ld, setname):
    """Return a pandas dataframe with the dataset labels.
    Accession numbers are the index and labels (e.g. "pneumonia") are the columns.
    <setname> can be 'train', 'valid', or 'test'."""
    assert label_type_ld == 'disease_new'
    labels_file = './load_dataset/fakedata/2019-12-18_duke_disease/img'+setname+'_BinaryLabels.csv'
    return pd.read_csv(labels_file, header=0, index_col = 0)
    

'\nsize_diff = len(dataset) - train_size - val_size - test_size\ntrain_size += size_diff\n\n# Use random_split to split the dataset\ntrain_dataset, val_dataset, test_dataset = random_split(dataset, [train_size, val_size, test_size])\n\n\ntraining_sampler = RandomSampler(train_dataset, replacement=False, generator=torch.Generator().manual_seed(42))\nbatch_sampler = BatchSampler(sampler=training_sampler, batch_size=8, drop_last=True)\n\ntraining_dataloader = torch.utils.data.DataLoader(\n    train_dataset, \n    batch_size=8, \n    shuffle=True, \n    generator=torch.Generator().manual_seed(42)\n)\nct_embedding_model = CTModel()\ntriplet_model = TripletModel(model=ct_embedding_model, mode=\'HN\', seed=42)\n\ne = dataset.get_updated_embeddings(triplet_model, torch.device("cuda"))\n\nprint(dataset.get_embeddings([0]))\n\n\ntraining_sampler = RandomSampler(train_dataset, replacement=False, generator=torch.Generator().manual_seed(42))\nval_sampler = RandomSampler(val_dataset, replacement=Fal

In [23]:
dataset[0]

In [11]:
import numpy as np

def recall(labels_gt, labels_prediction):
    """
    set of ground truth labels and list of prediction lists. 
    This is an independent of k way to calculate the recall.
    """
    n = len(labels_prediction)
    recall = 0
    for i in range(n):
        recall += len(set(labels_gt) & set(labels_prediction[i]))/len(labels_gt)
    return recall/n

def get_batch_data(data, index, size):
    """
    For minibatch training
    """
    column_1 = []
    column_2 = []
    for i in range(index, index + size):
        line = data[i]
        # anchor image
        column_1.append(int(line[0]))
        # positive image
        column_2.append(int(line[1]))
    return np.array(column_1), np.array(column_2)


In [2]:
import torch
import torch.nn as nn
from transformers import ConvNextImageProcessor, ResNetForImageClassification, ResNetConfig
import os
import nibabel as nib
import numpy as np
from scipy.ndimage import zoom
from torchvision import models, utils
from torchvision.transforms import v2

class CTModel(nn.Module):
    def __init__(self):
        super(CTModel, self).__init__()
        
        resnet = models.resnet18(pretrained=True)
        self.features = nn.Sequential(*(list(resnet.children())[:-2]))

        #conv input torch.Size([1,134,512,14,14])
        self.reducingconvs = nn.Sequential(
            nn.Conv3d(134, 64, kernel_size = (3,3,3), stride=(3,1,1), padding=0),
            nn.ReLU(),
            
            nn.Conv3d(64, 32, kernel_size = (3,3,3), stride=(3,1,1), padding=0),
            nn.ReLU(),
            
            nn.Conv3d(32, 16, kernel_size = (3,2,2), stride=(3,2,2), padding=0),
            nn.ReLU())
        
        self.fc = nn.Sequential(
            nn.Linear(16*18*6*6, 512),
            nn.ReLU(True),
            nn.Dropout(0.5),
            
            nn.Linear(512, 256), 
            nn.ReLU(True),
            nn.Dropout(0.5),
            
            nn.Linear(256, 128))

        def forward(self, x):
            shape = list(x.size())
            #example shape: [1,134,3,420,420]
            #example shape: [2,134,3,420,420]
            batch_size = int(shape[0])
            x = x.view(batch_size*134,3,512,512)
            x = self.features(x)
            x = x.view(batch_size,134,512,16,16)
            x = self.reducingconvs(x)
            #output is shape [batch_size, 16, 18, 6, 6]
            x = x.view(batch_size, 16*18*6*6)
            x = self.fc(x)
            return x
        
        # Convolutional Layers
        self.conv1 = nn.Sequential( 
            nn.Conv3d(512, 256, kernel_size=(3, 3, 3), stride=(3, 1, 1), padding=0),
            nn.BatchNorm3d(256),
            nn.ReLU()
        )
        self.conv2 = nn.Sequential( 
            nn.Conv3d(256, 128, kernel_size=(3, 3, 3), stride=(3, 1, 1), padding=0),
            nn.BatchNorm3d(128),
            nn.ReLU()
        )
        self.conv3 = nn.Sequential( 
            nn.Conv3d(128, 64, kernel_size=(3, 2, 2), stride=(3, 2, 2), padding=0),
            nn.BatchNorm3d(64),
            nn.ReLU(),
        )
        self.fc1 = nn.Sequential(
            nn.Flatten(),
            nn.Linear(64*4*6*6, 1024)
        )

            
    def old_forward(self, x):

        # ResNet
        x = self.resnet(x)
        
        # Convolutions
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)

        # Fully Connected Layers
        x = self.fc1(x)

        return x



In [127]:
ct_embedding_model = CTModel()
x = torch.randn(1, 384, 512, 512)


device = torch.device("cuda")
ct_embedding_model.to(device)
x = x.to(device)

out = ct_embedding_model(x)

print(out.shape)

tensor([[[ 0.0000,  0.0000, -2.1103,  ..., -2.1365, -2.1140, -2.0976],
         [-2.1154, -2.1440, -2.1242,  ..., -2.1486, -2.1064, -2.1364],
         [-2.1331, -2.0919, -2.0996,  ..., -2.1276, -2.1228, -2.1061],
         ...,
         [-2.1036, -2.1358, -2.1417,  ..., -2.1247, -2.0925, -2.1146],
         [-2.0729, -2.1120, -2.0932,  ..., -2.1170, -2.1351, -2.1147],
         [-2.1249, -2.0949, -2.1047,  ..., -2.1516, -2.1088, -2.1403]],

        [[-2.0411, -2.0560, -2.0577,  ..., -2.0542, -2.0417, -2.0401],
         [-2.0028, -2.0176, -2.0425,  ..., -2.0406, -2.0270, -2.0252],
         [-2.0171, -2.0501, -2.0557,  ..., -2.0297, -2.0359, -2.0082],
         ...,
         [-2.0195, -2.0319, -2.0438,  ..., -2.0374, -2.0019, -2.0422],
         [-2.0233, -2.0393, -2.0251,  ..., -2.0455, -2.0344, -2.0686],
         [-2.0587, -2.0272, -2.0577,  ..., -2.0522, -2.0340, -2.0501]],

        [[-1.8047, -1.8476, -1.8089,  ..., -1.8045, -1.7889, -1.8189],
         [-1.8062, -1.8002, -1.8065,  ..., -1

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument tensors in method wrapper_CUDA_cat)

In [37]:
print(torch.cuda.is_available())

True


In [66]:
from torch.utils.data import Dataset, random_split, RandomSampler, BatchSampler
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
import time
import logging

random.seed(42)

logging.basicConfig(level=logging.INFO)

class TripletModel:
    def __init__(
        self,
        model,
        mode='HN',
        seed=42,
    ):
        """
        mode: (negative sampling) "random" or "HN"
        estimator: "Linear" or "net"
        """

        # dim latent space
        #self.factor = factor 

        #VSE++
        #self.batch_size = 30  

        #ResNet18
        self.batch_size = 8

        #VSE++
        #self.epochs = 30

        # ResNet18
        self.epochs = 1
        
        self.model = model
        
        self.device = None
        
        if not torch.cuda.is_available():
            self.device = torch.device("cuda")
            self.model.to(self.device)
        
        #VSE++
        self.optimizer = torch.optim.Adam(params=self.model.parameters(), lr=0.0002)
        self.scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, step_size=15, gamma=0.1)

        # FF net or LINEAR
        #self.estimator = estimator
        
        # Save checkpoints to this path
        self.checkpoint_path = TRIPLET_MODELS_PATH
        
        # Negative sampling: 'HN' or 'random'
        self.mode = mode
        
        # Triplet loss' alpha
        self.alpha = 0.2
        
        self.seed = seed
    
    def get_triplets(self, ids_list):
        # construct positive pairs and negative-compatibles from the data points in the mini-batch
        #print('ids_list:', ids_list)
        positive_pairs, negative_compatibles = self.dataset.get_batch_positive_negative_pairs(ids_list)
        triplets = list()

        if len(positive_pairs) == 0 or len(negative_compatibles) == 0:
            logging.warning(f"Cannot build triplets from current batch! {len(positive_pairs)} positives and {len(negative_compatibles)} negative pair candidates found.")
            return None #skip batch

        if self.mode == "random":
            negative_pairs = [[a, random.choice(negative_compatibles[neg_id])] for (a, neg_id) in negative_compatibles]
            # a_p and a_n should always be the same value, so it doesn't matter which one we choose to build the triplet
            triplets = torch.tensor([[a_p, p, n] for ([a_p, p], [a_n, n]) in zip(positive_pairs, negative_pairs)])

        elif self.mode == "HN":
            anchors_list = [a for (a, _) in positive_pairs]
            #embeddings = self.dataset.get_embeddings(ids_list)
            anchors_embs = self.dataset.get_embeddings(anchors_list)
            negative_pairs = list()
            for anchor_id, anchor_emb in zip(anchors_list, anchors_embs):
                negatives_compatibles_ids = negative_compatibles[anchor_id]
                n_compatibles = len(negatives_compatibles_ids)
                negatives_compatibles_embs = self.dataset.get_embeddings(negatives_compatibles_ids)
                negatives_compatibles_embs = negatives_compatibles_embs.squeeze(1)
                anchor_emb_repeat = anchor_emb.repeat(n_compatibles, 1)
                #anchor_emb_repeat = torch.fill(torch.empty(n_compatibles), anchor_emb)
                a_n_pairwise_similarities = torch.nn.functional.cosine_similarity(anchor_emb_repeat, negatives_compatibles_embs)
                most_similar_id = torch.argmax(a_n_pairwise_similarities)
                negative_pairs.append([anchor_id, negatives_compatibles_ids[most_similar_id]])
                #print('\nSTART ANCHOR\n')
                #print('anchors_list', anchors_list)
                #print('anchor_id', anchor_id)
                #print('negatives_compatibles_ids', negatives_compatibles_ids)
                #print('anchor_emb_repeat', anchor_emb_repeat.shape, anchor_emb_repeat)
                #print('negatives_compatibles_embs', negatives_compatibles_embs.shape, negatives_compatibles_embs)
                #print('a_n_pairwise_similarities', a_n_pairwise_similarities)
                #print('most_similar_id', most_similar_id)
                #print('\nEND ANCHOR\n')
            # a_p and a_n should always be the same value, so it doesn't matter which one we choose to build the triplet
            triplets = torch.tensor([[a_p, p, n] for ([a_p, p], [a_n, n]) in zip(positive_pairs, negative_pairs)])
        if self.device:
            triplets = triplets.to(self.device)
        return triplets
        
        
    def training(self, dataset, train_frac=0.6, batch_size=8, epochs=1):
        """
        Training process
        """
        
        assert train_frac > 0.0 and 1.0 >= train_frac
        
        self.dataset = dataset
        self.batch_size = batch_size
        self.epochs = epochs
        
        # Define the sizes of train, validation, and test sets
        train_size = int(train_frac * len(self.dataset))  # percentage of the data for training
        val_size = int((1.0 - train_frac)/1 * len(self.dataset)) # half of the remaining for validation
        test_size = len(self.dataset) - train_size - val_size  # the rest of the remaining for testing

        '''
        If necessary, adjust the size of the training set when the sum of 
        the sizes of the three sets differs from the total dataset size.
        '''
        size_diff = len(self.dataset) - train_size - val_size - test_size
        train_size += size_diff

        # Use random_split to split the dataset
        self.train_dataset, self.val_dataset, self.test_dataset = random_split(dataset=self.dataset, lengths=[train_size, val_size, test_size], generator=torch.Generator().manual_seed(self.seed))
        
        # Create a RandomSampler with a predetermined seed and no replacement
        #self.training_sampler = RandomSampler(self.train_dataset, replacement=False, generator=torch.Generator().manual_seed(seed))
        self.validation_sampler = RandomSampler(self.val_dataset, replacement=False, generator=torch.Generator().manual_seed(self.seed))
        self.test_sampler = RandomSampler(self.test_dataset, replacement=False, generator=torch.Generator().manual_seed(self.seed))

        #self.train_batch_sampler = BatchSampler(sampler=self.training_sampler, batch_size=self.batch_size, drop_last=True)
        self.val_batch_sampler = BatchSampler(sampler=self.validation_sampler, batch_size=len(self.val_dataset), drop_last=True)
        self.test_batch_sampler = BatchSampler(sampler=self.test_sampler, batch_size=len(self.test_dataset), drop_last=True)
        
        # self.dataset.shape := (number of data points, number of classes)
        #self.n_classes = self.dataset.shape[1]

        self.recalls = [] # recall 1, 10, 25
        self.loss_per_epoch = []
        self.step_losses = []
        self.val_loss_per_epoch = []
        
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        self.tb_writer = SummaryWriter(os.path.join(TRIPLET_MODELS_PATH, "runs", "ct_retrieval_trainer_{}".format(timestamp)))
        
        start_time = time.time()
        
        for epoch in range(self.epochs):
            logging.info(f"\n--------------------------------\nEpoch {epoch+1} of {self.epochs}\n--------------------------------\n")

            batch_counter = 0 # counter of processed batches across training
            
            epoch_losses = list()
            
            # change the random seed and reconstruct training_sampler and train_batch_sampler at each epoch
            # this way we can assemble never-before-seen triplets to train on.
            self.training_sampler = RandomSampler(self.train_dataset, replacement=False, generator=torch.Generator().manual_seed(self.seed + epoch))
            self.train_batch_sampler = BatchSampler(sampler=self.training_sampler, batch_size=self.batch_size, drop_last=True)

            # recalculate dataset embeddings at this epoch
            dataset_embeddings = self.dataset.get_updated_embeddings(self.model, self.device)
            
            for batch_index, training_batch in enumerate(self.train_batch_sampler):
                batch_counter += 1
                logging.info(f"=> Batch {batch_counter}:")

                # load batch data
                dataset_batch_ids, batch_embeddings, batch_labels = self.train_dataset[training_batch]

                batch_id_map = dict((i, j) for (i, j) in zip(dataset_batch_ids.tolist(), training_batch))

                # dataset_batch_ids != training_batch bacause dataset_batch_ids corresponds to the original
                # unsplitted dataset while training_batch maps to the training dataset

                if self.device is not None:
                    batch_embeddings = batch_embeddings.to(self.device)
                    
                batch_triplets = self.get_triplets(dataset_batch_ids.tolist())
                batch_triplets = [[batch_id_map[a], batch_id_map[p], batch_id_map[n]] for [a, p, n] in batch_triplets.tolist()]
                batch_triplets = torch.tensor(batch_triplets)
                #print('training_batch:', training_batch)
                #print('dataset_batch_ids:', dataset_batch_ids)
                #print('batch_triplets:', batch_triplets)
                
                if batch_triplets is None:
                    # skip empty batch
                    continue
                
                # perform forward pass over all samples in the batch
                self.model.train(True)
                self.optimizer.zero_grad()
                outputs = self.model(batch_embeddings)

                # map the dataset IDs referenced in the triplets to the indices in the batched model output
                #print('batch_ids:', mini_batch)
                #print('triplets before mapping:\n', triplets)
                #triplets = self._map_triplets(mini_batch, triplets)
                #print('triplets after mapping:\n', triplets)
                
                # calculate loss and gradients
                # IMPORTANT: the loss is calculated for each triplet (3-tuples of samples), therefore,
                # if there is a sample which is not part of a triplet, then it will not be needed in
                # calculation.
                batch_loss = self.loss(outputs, training_batch, batch_triplets, alpha=self.alpha)
                batch_loss.requires_grad = True
                batch_loss.backward()

                # update model weights
                self.optimizer.step()

                self.tb_writer.add_scalar('Steps: Loss/train', batch_loss, global_step=batch_counter)
                
                epoch_losses.append(batch_loss)
                self.step_losses.append(batch_loss)
                logging.info(f"Batch {batch_counter} loss: {round(batch_loss.item(), 4)}")
            
            # finished iterating the batches
            epoch_mean_loss = (sum(epoch_losses) / len(epoch_losses)).item()
            self.tb_writer.add_scalar('Epochs: Loss/train', epoch_mean_loss, global_step=epoch+1)
            self.loss_per_epoch.append(epoch_mean_loss)

            logging.info(f'Training loss: {round(epoch_mean_loss, 4)}')
            
            
            # evaluate on validation set
            logging.info("Evaluating...")
            self.model.eval()
            val_epoch_losses = list()
            with torch.no_grad():
                for val_batch_index, val_batch in enumerate(self.val_batch_sampler):

                    val_batch_ids, val_batch_embeddings, val_batch_labels = self.val_dataset[val_batch]

                    val_batch_id_map = dict((i, j) for (i, j) in zip(val_batch_ids.tolist(), val_batch))
                    

                    if self.device is not None:
                        val_batch_embeddings = val_batch_embeddings.to(self.device)
                        
                    val_batch_triplets = self.get_triplets(val_batch_ids.tolist())
                    val_batch_triplets = [[val_batch_id_map[a], val_batch_id_map[p], val_batch_id_map[n]] for [a, p, n] in val_batch_triplets.tolist()]
                    val_batch_triplets = torch.tensor(val_batch_triplets)
                    
                    if val_batch_triplets is None:
                        # skip validation if there was no triplets to evaluate
                        continue

                    # map the dataset IDs referenced in the triplets to the indices in the batched model output

                    #val_triplets = self._map_triplets(val_batch, val_triplets)

                    

                    # perform evaluation on validation data
                    val_outputs = self.model(val_batch_embeddings)

                    # calculate validation loss
                    val_batch_loss = self.loss(val_outputs, val_batch, val_batch_triplets, alpha=self.alpha)
                    val_epoch_losses.append(val_batch_loss)
                    logging.info(f"Validation batch {val_batch_index+1} loss: {round(val_batch_loss.item(), 4)}")
                
                # finished iterating the batches
                val_mean_loss = (sum(val_epoch_losses) / len(val_epoch_losses)).item()
                self.tb_writer.add_scalars(
                    'Training vs. Validation Loss',
                    {
                        'Training': epoch_mean_loss, 
                        'Validation': val_mean_loss,
                    }, 
                    epoch + 1
                )
                self.val_loss_per_epoch.append(val_mean_loss)

                logging.info(f'Validation loss: {round(val_mean_loss, 4)}')
            
                # end batch iteration
            logging.info("Done evaluating.")
            if(self.val_loss_per_epoch[-1] == min(self.val_loss_per_epoch)):
                logging.info("Best validation loss achieved! Saving checkpoint...")
                self.save(
                    epoch,
                    self.loss_per_epoch, 
                    self.step_losses, 
                    self.val_loss_per_epoch, 
                    self.recalls,
                    time.time()-start_time,
                    os.path.join(TRIPLET_CHECKPOINTS_PATH, f'ct_retrieval_trainer_{timestamp}.pth'),
                )
                logging.info("Checkpoint saved.")

            # update optimizer scheduler ater completing an epoch
            self.scheduler.step()
            # end epoch iteration
        #print('train_dataset:', self.train_dataset[0:len(self.train_dataset)][0])
        #print('val_dataset:', self.val_dataset[0:len(self.val_dataset)][0])

    def _map_triplets(self, batch_ids, triplets):
        a_tensor = triplets.select(1, 0)
        p_tensor = triplets.select(1, 1)
        n_tensor = triplets.select(1, 2)
        a_items = [e.item() for e in a_tensor]
        p_items = [e.item() for e in p_tensor]
        n_items = [e.item() for e in n_tensor]
        a_index = list()
        p_index = list()
        n_index = list()
        for (a, p, n) in zip(a_items, p_items, n_items):
            a_index.append(batch_ids.index(a))
            p_index.append(batch_ids.index(p))
            n_index.append(batch_ids.index(n))
        mapped_triplets = torch.tensor([a_index, p_index, n_index])
        mapped_triplets = mapped_triplets.transpose(0, 1)
        return mapped_triplets

    def _max(self, a, b):
        if a > b:
            return a
        return b
    
    def loss(self, embeddings, embeddings_ids, triplets_ids, alpha=0.2):
        
        """
        for computing the triplet loss [alpha + negative_similarity - positive_similarity]_{+}
        """
        anchors_ids = torch.select(triplets_ids, 1, 0).tolist()
        positives_ids = torch.select(triplets_ids, 1, 1).tolist()
        negatives_ids = torch.select(triplets_ids, 1, 2).tolist()

        ids_map = dict((j, i) for i, j in enumerate(embeddings_ids))
        
        anchors = embeddings[[ids_map[a] for a in anchors_ids]]
        positives = embeddings[[ids_map[p] for p in positives_ids]]
        negatives = embeddings[[ids_map[n] for n in negatives_ids]]
        
        cos = torch.nn.CosineSimilarity(dim=1, eps=1e-6)
        
        positive_similarities = cos(anchors, positives)
        negative_similarities = cos(anchors, negatives)
        
        unrectified_batch_loss = alpha + negative_similarities - positive_similarities
        rectified_batch_loss = unrectified_batch_loss.detach().map_(torch.zeros(len(unrectified_batch_loss)), self._max)
        #rectified_batch_loss = torch.func.vmap(torch.max)(unrectified_batch_loss, torch.tensor(0))
        batch_loss = torch.mean(rectified_batch_loss)
        if self.device:
            batch_loss = batch_loss.to(device)
        return batch_loss
    
    def save(self, epoch, epoch_loss, step_loss, val_loss, recalls, running_time, directory):
        """
        For saving logs of the experiment
        """
        
        torch.save(
            {
                'model_state_dict': self.model.state_dict(),
                'optimizer_state_dict': self.scheduler.state_dict(),
                'epoch': epoch,
                'epoch_loss': epoch_loss,
                'step_loss': step_loss,
                'val_loss': val_loss,
                'recalls': recalls,
                'running_time': running_time,
            },
            directory
        )

In [29]:
'''    
        def predict(self,query,split):
        """
        the split defines where we are going to search
        """
        if type(query) is not list:
            query= [query]

        i_input = self.visual_encoder(tf.nn.embedding_lookup(visual_matrix, query))
        if split=="val":
                c_eval= self.visual_encoder(visual_matrix[self.val_ids])
        if split=="test":
                c_eval= self.visual_encoder(visual_matrix[self.test_ids])
        if split=="train":
                c_eval=self.visual_encoder(visual_matrix[self.train_ids])

        rating= tf.matmul(i_input, tf.transpose(c_eval) )
        return np.reshape(rating, [-1])

    def metric(self, split, k=10):
        """
        to compute metrics
        """
        if split == "val":
            ids, gt = self.val_ids, self.labels_val
        if split == "test":
            ids, gt = self.test_ids, self.labels_test

        metric=0

        for i in range(len(ids)): #querys
            top= np.argsort(self.predict(ids[i],split))[::-1][1:k+1] #first one must be the query
            #get labels
            prediction= [gt[j] for j in top]
            metric+=recall(gt[i], prediction)
        return metric/len(ids) #average

    def inference(self):
        """
        testing phase
        """
        results=(self.metric("test",k=1), self.metric("test",k=10), self.metric("test",k=25))
        print(results)
        np.save(self.DIS_MODEL_FILE+"test_metrics.npy", np.array([results], dtype=object) )
'''

'    \n        def predict(self,query,split):\n        """\n        the split defines where we are going to search\n        """\n        if type(query) is not list:\n            query= [query]\n\n        i_input = self.visual_encoder(tf.nn.embedding_lookup(visual_matrix, query))\n        if split=="val":\n                c_eval= self.visual_encoder(visual_matrix[self.val_ids])\n        if split=="test":\n                c_eval= self.visual_encoder(visual_matrix[self.test_ids])\n        if split=="train":\n                c_eval=self.visual_encoder(visual_matrix[self.train_ids])\n\n        rating= tf.matmul(i_input, tf.transpose(c_eval) )\n        return np.reshape(rating, [-1])\n\n    def metric(self, split, k=10):\n        """\n        to compute metrics\n        """\n        if split == "val":\n            ids, gt = self.val_ids, self.labels_val\n        if split == "test":\n            ids, gt = self.test_ids, self.labels_test\n\n        metric=0\n\n        for i in range(len(ids)):

In [65]:
ct_embedding_model = CTModel()
triplet_model = TripletModel(model=ct_embedding_model, mode='HN', seed=42)

# Create an instance of the dataset
dr70_dataset = CustomDataset(CT_RESNET18_EMBEDDINGS_PATH, DR70_LABELS_PATH)

triplet_model.training(dataset=dr70_dataset, train_frac=0.7, batch_size=8, epochs=10)

INFO:root:
--------------------------------
Epoch 1 of 10
--------------------------------

INFO:root:=> Batch 1:
INFO:root:Generating triplets for batch [67, 47, 18, 55, 39, 12, 45, 28]
INFO:root:Could not construct a positive pair for anchor ID 45 in this batch. Skipping...
INFO:root:Batch 1 loss: 0.2585
INFO:root:=> Batch 2:
INFO:root:Generating triplets for batch [50, 48, 53, 62, 15, 44, 19, 11]
INFO:root:Batch 2 loss: 0.2311
INFO:root:=> Batch 3:
INFO:root:Generating triplets for batch [25, 34, 56, 6, 41, 31, 21, 61]
INFO:root:Could not construct a positive pair for anchor ID 25 in this batch. Skipping...
INFO:root:Could not construct a positive pair for anchor ID 56 in this batch. Skipping...
INFO:root:Batch 3 loss: 0.2167
INFO:root:=> Batch 4:
INFO:root:Generating triplets for batch [43, 3, 42, 35, 68, 32, 30, 23]
INFO:root:Could not construct a positive pair for anchor ID 35 in this batch. Skipping...
INFO:root:Batch 4 loss: 0.2619
INFO:root:=> Batch 5:
INFO:root:Generating tri

train_dataset: tensor([39, 32, 18, 61, 60, 12, 67, 55, 15, 62, 14,  3, 53, 31,  4, 19, 47, 44,
        21, 25, 49, 41, 23, 26, 68, 58, 30, 42,  6, 48, 57, 16, 17,  8, 28, 43,
        65, 11, 50,  0, 40, 63, 35,  9, 51, 56, 45, 34])
val_dataset: tensor([10, 66, 29, 38, 52, 64,  7, 20, 37, 36, 27, 13, 33, 59,  5, 46,  1, 54,
        22,  2])
