In [77]:
import torch
import wandb
#!pip install wandb
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
from torch.utils.data import DataLoader, Dataset, random_split
import os

In [78]:
RANDOM_SEED = 42

In [79]:
def dataset_splitting(dataset: Dataset, train_fraction, val_fraction, test_fraction):
    """Split the dataset into training, validation, and test sets.
    Input:
    - dataset: the dataset to split
    - train_fraction: the fraction of the dataset to use for training
    - val_fraction: the fraction of the dataset to use for validation
    - test_fraction: the fraction of the dataset to use for testing
    Output:
    - train_set (torch.utils.data.Subset): the training set
    - val_set (torch.utils.data.Subset): the validation set
    - test_set (torch.utils.data.Subset): the test set"""
        
    # Define the sizes of the training, validation, and test sets
    num_samples = len(dataset)
    train_size = int(train_fraction * num_samples)
    val_size = int(val_fraction * num_samples)
    # Check if the test fraction is correct
    if (train_fraction + val_fraction + test_fraction != 1):
        raise ValueError("The sum of the training, validation, and test fractions must be equal to 1.")
    test_size = num_samples - train_size - val_size

    generator = torch.Generator().manual_seed(RANDOM_SEED)
    
    # Split the dataset into training, validation, and test sets
    train_set, val_set, test_set = random_split(dataset, [train_size, val_size, test_size], generator=generator)

    # The obtained objects are of type torch.utils.data.Subset
    return train_set, val_set, test_set


class CBIS_DDSM_Dataset(Dataset):

    def __init__(self, csv_file, root_dir, transform=None):
        """
        Arguments:
            csv_file (string): Path to the csv file with annotations.
            root_dir (string): Directory with all the images.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        self.frame = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform

        # remove the benign without callback cases
        self.frame = self.frame[self.frame["pathology"] != "BENIGN_WITHOUT_CALLBACK"]
        self.frame = self.frame.reset_index(drop=True)
        # 1 for malignant, 0 for benign
        self.frame["pathology_id"] = self.frame["pathology"].map({"BENIGN": 0, "MALIGNANT": 1})

        # We want to obtain a split of 60% training, 20% validation, and 20% testing
        self.train_set, self.val_set, self.test_set = dataset_splitting(self, 0.6, 0.2, 0.2)

    def __len__(self):
        return len(self.frame)

    def __getitem__(self, idx):
        print("called with idx", idx)   
        if torch.is_tensor(idx):
            idx = idx.tolist()

        img_name = os.path.join(self.root_dir,
                                self.frame.iloc(idx)["image_path"])
        image = plt.imread(img_name)
        label = self.frame.iloc(idx)["pathology_id"]
        sample = {'image': image, 'label': label}

        if self.transform:
            sample = self.transform(sample)

        return sample
    
    def get_loaders(self, batch_size, shuffle, drop_last, num_workers):
        return {"train": DataLoader(self.train_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, num_workers=num_workers),
           "val": DataLoader(self.val_set, batch_size=batch_size, shuffle=False, drop_last=False, num_workers=num_workers),
           "test": DataLoader(self.test_set, batch_size=batch_size, shuffle=False, drop_last=False, num_workers=num_workers)}
    
    def get_trainset(self):
        """Return the train set."""
        return self.train_set
    

    def get_valset(self):
        """Return the validation set."""
        return self.val_set
    
    def get_testset(self):
        """Return the test set."""
        return self.test_set
    
    def get_set_len(self):
        """Return the length of the train, val and test set."""
        return len(self.train_set), len(self.val_set), len(self.test_set)

In [80]:
dataset = CBIS_DDSM_Dataset(csv_file="../data/processed/calc_dicom.csv", root_dir="../data/processed", transform=None)

dataset.__len__()

train_set = dataset.get_trainset()
val_set = dataset.get_valset()
test_set = dataset.get_testset()

lengths = dataset.get_set_len()
lengths

(798, 266, 267)

In [81]:
dataset.frame.idxmax()

subject_id      518
image_path      129
pathology         1
pathology_id      1
dtype: int64

In [82]:
loaders = dataset.get_loaders(batch_size=32, shuffle=True, drop_last=True, num_workers=0)

In [83]:
print("length of train loader", loaders["train"])
print(train_set.indices)
print(dataset.frame.iloc[train_set.indices])
for data in loaders["train"]:
    img = data["image"]
    label = data["label"]

    plt.imshow(img[0])
    plt.title(label[0])
    plt.show()

length of train loader <torch.utils.data.dataloader.DataLoader object at 0x763c6b961a60>
[259, 318, 828, 1121, 113, 640, 476, 1247, 750, 764, 1324, 945, 795, 12, 1178, 143, 901, 534, 306, 725, 1173, 87, 911, 514, 1046, 654, 757, 36, 4, 1242, 1233, 511, 942, 374, 600, 1167, 1265, 595, 1209, 1217, 922, 1253, 864, 594, 110, 1272, 134, 1113, 887, 626, 884, 240, 1179, 1106, 114, 862, 1071, 713, 1022, 1151, 719, 801, 1135, 383, 762, 94, 42, 715, 774, 1312, 325, 1160, 1298, 24, 1269, 404, 359, 68, 693, 470, 1310, 772, 245, 1082, 492, 927, 1215, 1039, 819, 900, 379, 958, 1136, 490, 832, 1276, 1062, 1240, 463, 847, 473, 674, 963, 994, 213, 722, 561, 198, 947, 357, 72, 545, 493, 50, 837, 494, 496, 402, 630, 875, 125, 587, 445, 241, 1076, 308, 936, 899, 479, 290, 910, 799, 276, 701, 1050, 1007, 827, 261, 565, 699, 301, 1061, 218, 802, 1219, 759, 123, 197, 1148, 690, 777, 293, 161, 669, 148, 612, 40, 588, 497, 182, 385, 1321, 10, 1, 417, 955, 735, 1297, 41, 412, 252, 335, 1329, 814, 563, 176, 98, 

ValueError: No axis named 1242 for object type DataFrame

In [None]:
missing_indices = set(range(max(dataset.frame.index))) - set(dataset.frame.index)
print(missing_indices)

set()
