## Script to split dataset into train/valid/test parts

In [None]:
# add the parent of the working directory to sys.path before importing project modules
import os
import sys
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), "..")))
from utils import data_utils,config_utils
from global_constants import gt_column_name,colname_in_pred
import pandas as pd

In [None]:
# Experiment configuration; every path and split proportion used below is read from it.
config_file = (
    '../../config_files/experiments/breast_test/unimodal/'
    'config_UNI_Least_confidence_sampling_foundational_model_diversity_norm_im_0_norm_ge_0_ICF_0.json'
)
my_config = config_utils.read_config_from_user_file(config_file)

### Column Name Standardization

Since the framework consistently refers to the ground-truth column by the name stored in `gt_column_name` (e.g. `Pathologist Annotations`), it is important to ensure that this naming is uniform across all inputs.

In this first iteration, we explicitly rename the column provided by the pathologist to align with the expected format.  
To preserve the original data, the unmodified annotations will be saved in a backup file.



In [3]:
# Read the raw annotation table. Its second column is expected to hold the
# pathologist's annotations, under whatever name the export gave it.
annotations_path = my_config["original_annotations_test_only"]
df = pd.read_csv(annotations_path)

columns = df.columns.tolist()
if len(columns) < 2:
    # Raise instead of calling os._exit(1): os._exit terminates the kernel
    # process immediately (no cleanup, no flushing of buffered output), so the
    # explanation would likely never be shown and all notebook state is lost.
    raise ValueError("Pathologists annotation column not found in original_annotations_test_only file.")

column_name_for_rename = columns[1]

if column_name_for_rename != gt_column_name:
    # Derive the backup name from the real file extension. A plain
    # str.replace(".csv", ...) would rewrite every ".csv" occurring in the path
    # and would leave the path unchanged for other extensions, in which case
    # the "backup" would be the very file that is overwritten below.
    path_root, path_ext = os.path.splitext(annotations_path)
    backup_file = f"{path_root}_backup{path_ext}"
    df.to_csv(backup_file, index=False)  # keep an untouched copy of the original file

    # Rename the annotation column and overwrite the original file in place.
    df = df.rename(columns={column_name_for_rename: gt_column_name})
    df.to_csv(annotations_path, index=False)
    print(f"Column name {column_name_for_rename} was renamed to {gt_column_name}")
else:
    print(f"Column name {column_name_for_rename} is already {gt_column_name}, no changes made.")



Column name Pathologist Annotations is already Pathologist Annotations, no changes made.


In [None]:

# Load the (possibly renamed) annotation file; only the first returned value
# is kept. No patch directory is given and class weights are not computed.
original_data, _, _ = data_utils.load_Loupe_annotations(
    csv_file_path=my_config["original_annotations_test_only"],
    patch_dir=None,
    calculate_class_weights=False,
)

In [7]:
# Summary of the label column: per-class counts, number of distinct classes
# and the total number of annotated patches.
separator = '-' * 50
labels = original_data[gt_column_name]
class_names = labels.unique()

print("Class distribution in original data:")
print(separator)
print(original_data.groupby(gt_column_name).count())
print(labels.shape[0])
print(separator)
print("Total number of classes: ", len(class_names))
print("Classes: ", class_names)
print(separator)
print("Total number of patches: ", original_data.shape[0])

Class distribution in original data:
--------------------------------------------------
                                     Barcode
Pathologist Annotations                     
Adipose tissue                           874
Duct_non neoplastic                        7
Necrosis                                  49
atypical ductal hyperplasia                4
exclude                                   54
stroma1_immune cell&vessel enriched      667
stroma2_mod cell_collagen low            253
stroma3_cell low_collagen mod/high       843
stroma4_cell low_collagen high           875
tumor1&2_DCIS                            408
tumor3_invasive&stroma                   958
4992
--------------------------------------------------
Total number of classes:  11
Classes:  ['stroma4_cell low_collagen high' 'stroma3_cell low_collagen mod/high'
 'stroma2_mod cell_collagen low' 'stroma1_immune cell&vessel enriched'
 'tumor3_invasive&stroma' 'Adipose tissue' 'exclude' 'tumor1&2_DCIS'
 'Necrosis' 'Duct_non

### Test data split

In [8]:
# Carve out the test set: the second frame returned by the splitter is used,
# with its size governed by my_config["test_split"].
_, my_test_data = data_utils.split_annotations_train_val(
    original_data,
    valid_proportion=my_config["test_split"],
)

print('Test data')
print(my_test_data.groupby(gt_column_name).count())
print('Total number of patches:', my_test_data.shape[0])

# Every class of the full dataset has to be represented in the test set;
# fail on the first class (in order of appearance) that is missing.
unique_classes = original_data[gt_column_name].unique()
unique_classes_test = my_test_data[gt_column_name].unique()
missing_in_test = [label for label in unique_classes if label not in unique_classes_test]
if missing_in_test:
    raise ValueError('Class {} is not present in test set'.format(missing_in_test[0]))
        



Test data
                                     Barcode
Pathologist Annotations                     
Adipose tissue                            88
Duct_non neoplastic                        1
Necrosis                                   5
atypical ductal hyperplasia                1
exclude                                    6
stroma1_immune cell&vessel enriched       67
stroma2_mod cell_collagen low             26
stroma3_cell low_collagen mod/high        85
stroma4_cell low_collagen high            88
tumor1&2_DCIS                             41
tumor3_invasive&stroma                    96
Total number of patches: 504


Remove test data from the rest

In [9]:
# print the number per class in the original data
print(original_data.groupby(gt_column_name).count())
print('Total number of patches:', original_data.shape[0])
# remove test data from original data
original_data_without_test = original_data[~original_data['Barcode'].isin(my_test_data['Barcode'])]
print(original_data_without_test.groupby(gt_column_name).count())
print('Total number of patches:', original_data_without_test.shape[0])


                                     Barcode
Pathologist Annotations                     
Adipose tissue                           874
Duct_non neoplastic                        7
Necrosis                                  49
atypical ductal hyperplasia                4
exclude                                   54
stroma1_immune cell&vessel enriched      667
stroma2_mod cell_collagen low            253
stroma3_cell low_collagen mod/high       843
stroma4_cell low_collagen high           875
tumor1&2_DCIS                            408
tumor3_invasive&stroma                   958
Total number of patches: 4992
                                     Barcode
Pathologist Annotations                     
Adipose tissue                           786
Duct_non neoplastic                        6
Necrosis                                  44
atypical ductal hyperplasia                3
exclude                                   48
stroma1_immune cell&vessel enriched      600
stroma2_mod cell_collagen

Save test & remaining train data

In [10]:
# Display the remaining (non-test) rows as a visual sanity check before saving.
original_data_without_test

Unnamed: 0,Barcode,Pathologist Annotations
0,AACACCTACTATCGAA-1,stroma4_cell low_collagen high
1,AACACGTGCATCGCAC-1,stroma3_cell low_collagen mod/high
3,AACAGGAAGAGCATAG-1,stroma3_cell low_collagen mod/high
4,AACAGGATTCATAGTT-1,stroma3_cell low_collagen mod/high
6,AACAGGTTATTGCACC-1,stroma1_immune cell&vessel enriched
...,...,...
4987,TGTTGGAACGAGGTCA-1,stroma4_cell low_collagen high
4988,TGTTGGAAGCTCGGTA-1,stroma1_immune cell&vessel enriched
4989,TGTTGGATGGACTTCT-1,tumor3_invasive&stroma
4990,TGTTGGCCAGACCTAC-1,stroma3_cell low_collagen mod/high


In [11]:
# save the original data without test data 
original_data_without_test.to_csv(my_config["train_data"], index=False)
# save the test data
my_test_data.to_csv(my_config["test_data"], index=False)



In [12]:
# delete variables
del original_data
del my_test_data
del original_data_without_test


### Validation data split

In [None]:
# Reload the data saved under "train_data" (test rows already removed) and
# take the second frame returned by the splitter as the validation set,
# sized by my_config["valid_split"].
original_data, _, _ = data_utils.load_Loupe_annotations(
    csv_file_path=my_config["train_data"],
    patch_dir=None,
    calculate_class_weights=False,
)

_, my_val_data = data_utils.split_annotations_train_val(
    original_data,
    valid_proportion=my_config["valid_split"],
)

print('Validation data')
print(my_val_data.groupby(gt_column_name).count())
print('Total number of patches:', my_val_data.shape[0])


In [16]:
# Display the data reloaded from my_config["train_data"] (test rows already removed).
original_data

Unnamed: 0,Barcode,Pathologist Annotations
0,AACACCTACTATCGAA-1,stroma4_cell low_collagen high
1,AACACGTGCATCGCAC-1,stroma3_cell low_collagen mod/high
2,AACAGGAAGAGCATAG-1,stroma3_cell low_collagen mod/high
3,AACAGGATTCATAGTT-1,stroma3_cell low_collagen mod/high
4,AACAGGTTATTGCACC-1,stroma1_immune cell&vessel enriched
...,...,...
4483,TGTTGGAACGAGGTCA-1,stroma4_cell low_collagen high
4484,TGTTGGAAGCTCGGTA-1,stroma1_immune cell&vessel enriched
4485,TGTTGGATGGACTTCT-1,tumor3_invasive&stroma
4486,TGTTGGCCAGACCTAC-1,stroma3_cell low_collagen mod/high


In [17]:
# check if test set contains all classes
unique_classes = original_data[gt_column_name].unique()
# get unique classes in test set
unique_classes_test = my_val_data[gt_column_name].unique()
# check if all classes are present in test set
for class_name in unique_classes:
    if class_name not in unique_classes_test:
        raise ValueError('Class {} is not present in test set'.format(class_name))

Remove validation data from the rest

In [18]:
# print the number per class in the original data
print(original_data.groupby(gt_column_name).count())
print('Total number of patches:', original_data.shape[0])
# remove test data from original data
original_data_without_valid = original_data[~original_data['Barcode'].isin(my_val_data['Barcode'])]
print(original_data_without_valid.groupby(gt_column_name).count())
print('Total number of patches:', original_data_without_valid.shape[0])


                                     Barcode
Pathologist Annotations                     
Adipose tissue                           786
Duct_non neoplastic                        6
Necrosis                                  44
atypical ductal hyperplasia                3
exclude                                   48
stroma1_immune cell&vessel enriched      600
stroma2_mod cell_collagen low            227
stroma3_cell low_collagen mod/high       758
stroma4_cell low_collagen high           787
tumor1&2_DCIS                            367
tumor3_invasive&stroma                   862
Total number of patches: 4488
                                     Barcode
Pathologist Annotations                     
Adipose tissue                           707
Duct_non neoplastic                        5
Necrosis                                  39
atypical ductal hyperplasia                2
exclude                                   43
stroma1_immune cell&vessel enriched      540
stroma2_mod cell_collagen

### Save validation and training data

In [19]:
# save the original data without test data
original_data_without_valid.to_csv(my_config["train_data"], index=False)
# save validation data
my_val_data.to_csv(my_config["valid_data"], index=False)


In [20]:
# delete variables
del original_data
del my_val_data
del original_data_without_valid

## Check for overlaps between sets

In [None]:
# load train, valid and test data
train_data,_,_ = data_utils.load_Loupe_annotations(csv_file_path=my_config["train_data"],patch_dir=None, calculate_class_weights=False)
valid_data,_,_ = data_utils.load_Loupe_annotations(csv_file_path=my_config["valid_data"],patch_dir=None, calculate_class_weights=False)
test_data,_,_ = data_utils.load_Loupe_annotations(csv_file_path=my_config["test_data"],patch_dir=None, calculate_class_weights=False)

In [22]:
# find overlap between train, valid and test data
train_valid_overlap = train_data[train_data['Barcode'].isin(valid_data['Barcode'])]
train_test_overlap = train_data[train_data['Barcode'].isin(test_data['Barcode'])]
valid_test_overlap = valid_data[valid_data['Barcode'].isin(test_data['Barcode'])]
print('Overlap between train and valid:', train_valid_overlap.shape[0])
print('Overlap between train and test:', train_test_overlap.shape[0])
print('Overlap between valid and test:', valid_test_overlap.shape[0])

Overlap between train and valid: 0
Overlap between train and test: 0
Overlap between valid and test: 0


In [23]:
# check if test set contains all classes
unique_classes = train_data[gt_column_name].unique()
# get unique classes in test set
unique_classes_test = test_data[gt_column_name].unique()
# check if all classes are present in test set
for class_name in unique_classes:
    if class_name not in unique_classes_test:
        print('Class {} is not present in test set'.format(class_name))

# check if valid set contains all classes
unique_classes_valid = valid_data[gt_column_name].unique()
for clss_name in unique_classes:
    if clss_name not in unique_classes_valid:
        print('Class {} is not present in valid set'.format(clss_name))


    