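This notebook is used to:
1) Create a balanced dataset
2) Split the data into training/validation and development sets (the synthetic data additionally gets a held-out "real world" test set)
3) Save the splits to the combined split data folder (real data) and the synthetic data folder (synthetic data)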

# Here we will create patches from the annotated COCO data
1) Load in the COCO annotations and labels
2) Load in the images that are labeled
3) Create patches from the images using the COCO labels and save each patch to a new file
4) Create a new CSV of the patch file names and their labels
5) Resize every patch to 224x224 and save the resized copies to a separate folder

## Import Necessary Libraries

In [None]:
# Standard library
import json
import os
import random

# Third-party: image I/O, arrays, tables
import cv2
import numpy as np
import pandas as pd

# For Creating a split dataset
from sklearn.model_selection import train_test_split


## Set Save and Load Paths

In [None]:
# ---------------------------- Input locations ----------------------------
# Folder holding the COCO-format JSON annotation files (one "<set name>.json" per set).
path_to_annotations = "your_path_here/JSON_annotations/"

# Folder(s) with the source images. Entry i of image_folder_data must be the folder for the
# i-th annotation set in image_sets (the patch-extraction cell indexes both lists with the same i).
image_folder_data5 = "your_path_here/CLAHE_corrected_data/"
image_folder_data = [
    image_folder_data5,
    # Add further image folders here, in the same order as their annotation sets.
]

# ---------------------------- Output locations ---------------------------
patches_image_directory = "your_path_here/All_Patches/"  # every extracted patch is written here
cryo_data_folder = "your_path_here/Cryo_Data/"  # Data sent here is not split into train, test, develop
# Holds names_labels_df.csv and the per-split CSVs of patch names and labels.
combined_split_data = "your_path_here/Train_test_develop_split_combined_background/"

# Holds the synthetic inputs (capsids.npy, labels.npy) and the split .npy files written later.
synthetic_data = "your_path_here/Synthetic_data/"

# Real Images

## Get Image Patches

Loop Through COCO-Formatted Cryo Annotations and Get Patch Images

In [None]:
# The decode_fast function is referenced from the following link:
# https://gist.github.com/akTwelve/dc0bbbf26fb14493898fc74cd2aa7f74

def decode_fast(mask_rle, shape):
    """Decode run-length counts into a 0/255 uint8 mask.

    Runs alternate between 0 and 255, starting with 0 (the first count is the
    number of leading 0 pixels), and pixels are filled in column-major order,
    as in COCO's uncompressed RLE. Pixels not covered by the counts stay 0.

    Args:
        mask_rle: 1-D sequence or array of non-negative integer run lengths.
        shape: (height, width) of the output mask.

    Returns:
        np.ndarray of dtype uint8 and shape (height, width).

    Raises:
        ValueError: if the run lengths add up to more than height * width pixels.
    """
    # shape is (height, width), not (width, height)
    height, width = shape
    n_pixels = height * width

    counts = np.asarray(mask_rle)
    if counts.size == 0:
        return np.zeros((height, width), dtype=np.uint8)

    # Even-indexed runs are background (0), odd-indexed runs are foreground (255).
    run_values = np.zeros(counts.size, dtype=np.uint8)
    run_values[1::2] = 255

    # Expand the runs directly in uint8; this avoids building a full-image float64
    # temporary (8 bytes per pixel) just to cast it back to uint8.
    decoded = np.repeat(run_values, counts)
    if decoded.size > n_pixels:
        raise ValueError(
            f"RLE counts cover {decoded.size} pixels but shape {shape} has only {n_pixels}"
        )

    # Pixels after the last run are background.
    flat = np.zeros(n_pixels, dtype=np.uint8)
    flat[:decoded.size] = decoded
    # order='F': the runs walk down columns, not across rows.
    return flat.reshape((height, width), order='F')

In [None]:
image_set_5 = "100F_All_Labels"
# Annotation sets to process: "<name>.json" is read from path_to_annotations, and
# entry i here is paired with the image folder image_folder_data[i].
image_sets = [image_set_5]

datasets = []

# Load each COCO JSON file. Each one is also published as a global data1, data2, ...
# (kept for any later cell that refers to those names).
for i, val in enumerate(image_sets):
    with open(path_to_annotations + val + ".json", encoding="utf-8") as f:
        dataset_json = json.load(f)
    globals()['data' + str(i + 1)] = dataset_json
    datasets.append(dataset_json)

# Choose a background type. True keeps the background (plain bounding-box crop);
# False blacks out every pixel outside the annotation's segmentation mask before cropping.
background = True

# cv2.imwrite returns False instead of raising when the target folder is missing.
os.makedirs(patches_image_directory, exist_ok=True)

# Parallel lists: labels[k] is the category_id of the patch saved as patch_names[k].
patch_names = []
labels = []

for i, data in enumerate(datasets):
    # Index the image records by id once (first record wins, like the previous linear
    # search) instead of scanning data['images'] for every annotation.
    images_by_id = {}
    for img in data['images']:
        images_by_id.setdefault(img['id'], img)

    # One-image cache: annotations of the same image are usually consecutive, so each
    # image is read from disk once per run of annotations rather than once per annotation.
    loaded_image_id = None
    original_image = None

    for annotation in data['annotations']:
        image_id = annotation['image_id']
        image_info = images_by_id.get(image_id)
        if not image_info:
            continue  # annotation refers to an image not listed in this file

        if original_image is None or image_id != loaded_image_id:
            image_file = f"{image_folder_data[i]}{image_info['file_name']}"
            original_image = cv2.imread(image_file)
            if original_image is None:  # cv2.imread signals failure with None, not an exception
                raise FileNotFoundError(f"Could not read image: {image_file}")
            loaded_image_id = image_id

        # COCO bbox is [x, y, width, height]. Negative corners are clamped to 0 so the
        # slice cannot wrap around to the opposite edge of the image.
        x, y, width, height = map(int, annotation['bbox'])
        x0, y0 = max(x, 0), max(y, 0)
        x1, y1 = x + width, y + height

        if background:
            ########################## Create Patches Using BBOX ###########################
            patch = original_image[y0:y1, x0:x1]
        else:
            ##################### Create Patches Using Segmentation ########################
            # Run-length counts -> 0/255 mask covering the whole image.
            mask_rle = np.array(annotation['segmentation']['counts'])
            mask = decode_fast(mask_rle, (image_info['height'], image_info['width']))
            # Zero every pixel outside the mask, then crop to the bbox.
            masked_image = cv2.bitwise_and(original_image, original_image, mask=mask)
            patch = masked_image[y0:y1, x0:x1]

        patch_name = f"{image_sets[i]}_{image_id}_{annotation['id']}.png"
        if patch.size == 0:
            # cv2.imwrite rejects empty images; skip degenerate boxes instead of crashing.
            print(f"Skipping {patch_name}: empty bounding box {annotation['bbox']}")
            continue

        patch_filename = f"{patches_image_directory}{patch_name}"
        if not cv2.imwrite(patch_filename, patch):
            raise OSError(f"Could not write patch: {patch_filename}")

        # Record the patch only once it exists on disk, so the CSV never lists a missing file.
        patch_names.append(patch_name)
        labels.append(annotation['category_id'])

    # When done with the annotations for one dataset, print a message
    print(f"Finished processing {image_sets[i]}")

# Create a df of the patch names and labels and save it as a csv
names_labels_df = pd.DataFrame({'patch_name': patch_names, 'label': labels})
os.makedirs(combined_split_data, exist_ok=True)
names_labels_df.to_csv(f"{combined_split_data}names_labels_df.csv", index=False)

In [None]:
# Load each patch listed in names_labels_df one at a time, resize it to 224x224, and save it
# under the same file name in a separate folder. Reading one image at a time keeps memory
# low for large patches. (This step could also be folded into the patch-extraction cell.)

resized_patches_directory = "your_path_here/All_Patches_resized/"
resize_shape = (224, 224)  # cv2.resize takes (width, height)

# Create the output folder: cv2.imwrite returns False (it does not raise) when it is missing.
os.makedirs(resized_patches_directory, exist_ok=True)

names_labels_df = pd.read_csv(f"{combined_split_data}names_labels_df.csv")

for patch_name in names_labels_df['patch_name']:
    patch_file = f"{patches_image_directory}{patch_name}"
    patch = cv2.imread(patch_file)
    if patch is None:  # cv2.imread signals failure with None
        raise FileNotFoundError(f"Could not read patch: {patch_file}")

    resized_patch = cv2.resize(patch, resize_shape)

    resized_patch_filename = f"{resized_patches_directory}{patch_name}"
    if not cv2.imwrite(resized_patch_filename, resized_patch):
        raise OSError(f"Could not write resized patch: {resized_patch_filename}")


## Load the df

In [None]:
# Reload the table of patch file names and labels written by the patch-extraction cell,
# so the balancing and splitting cells below can run without re-extracting the patches.
names_labels_df = pd.read_csv(
    filepath_or_buffer=f"{combined_split_data}names_labels_df.csv",
)

## Create a Balanced Dataset

In [None]:
# Show how many patches of each label came from each annotation set.
# Patch names are built as f"{set_name}_{image_id}_{annotation_id}.png", so everything before
# the last two underscores is the set name (assumes the ids themselves contain no "_").
# Comparing that exact prefix, instead of str.contains (which treats the set name as a regex
# and would also match "set10_..." when looking for "set1"), counts each patch under one set.
patch_sets = names_labels_df['patch_name'].str.rsplit('_', n=2).str[0]
for val in image_sets:
    dataset = names_labels_df[patch_sets == val]
    print(f"Dataset: {val}")
    print(dataset['label'].value_counts())
    print("\n")

In [None]:
############################ categorize the patches into their prelabeled class ############################

# Sort the patches by label, then build a class-balanced random subset.
# Category ids: 1 Full, 2 Partial, 3 Empty (COCO "categories":
# [{"id":1,"name":"Full"},{"id":2,"name":"Partial"},{"id":3,"name":"Empty"}]),
# plus 4 Aggregation, 5 Ice, 6 Broken, 7 Background for datasets that also label debris.
# Any label other than 1-6 goes into the Background bucket and is written out as 7.
class_names = {1: "Full", 2: "Partial", 3: "Empty", 4: "Aggregation", 5: "Ice", 6: "Broken", 7: "Background"}

names_by_label = {label: [] for label in class_names}
relabeled_as_background = set()
for patch_name, label in zip(names_labels_df['patch_name'], names_labels_df['label']):
    if label in (1, 2, 3, 4, 5, 6):
        names_by_label[int(label)].append(patch_name)
    else:
        names_by_label[7].append(patch_name)
        if label != 7:
            relabeled_as_background.add(label)

if relabeled_as_background:
    print(f"Warning: labels {sorted(relabeled_as_background)} were grouped with Background and will be saved as label 7")

# Balance on the smallest class that actually has patches. Classes with no patches at all
# (e.g. a dataset without debris labels) are left out instead of shrinking every class to 0.
present_labels = [label for label, names in names_by_label.items() if names]
if not present_labels:
    raise ValueError("names_labels_df has no patches to balance")
missing_labels = [label for label in class_names if label not in present_labels]
if missing_labels:
    print(f"Warning: no patches for {[class_names[label] for label in missing_labels]}; these classes are excluded")
min_num = min(len(names_by_label[label]) for label in present_labels)

################# Create a new list of file names and labels that are evenly split but make it random #################
# Seeded so the same subset is drawn every run; classes are sampled in label order 1-7.
random.seed(0)
balanced_names = {
    label: random.sample(names, min_num) if names else []
    for label, names in names_by_label.items()
}

# Per-class lists kept as separate variables for any later cell that uses them.
(full_names, partial_names, empty_names, aggregation_names,
 ice_names, broken_names, background_names) = balanced_names.values()
(full_labels, partial_labels, empty_labels, aggregation_labels,
 ice_labels, broken_labels, background_labels) = (
    [label] * len(names) for label, names in balanced_names.items()
)

# Print the length of the lists
for label, names in balanced_names.items():
    print(f"{class_names[label]}: {len(names)}")

# Concatenate the classes in label order (1 Full ... 7 Background)
even_file_names = [name for names in balanced_names.values() for name in names]
even_labels = [label for label, names in balanced_names.items() for _ in names]

## Partition the Dataset into Training/Validation and Development Sets

In [None]:
# Stratified 80/20 split of the balanced real-data patches:
# 80% training/validation, 20% development. stratify keeps every class equally
# represented in both partitions; random_state fixes the split across runs.
(tra_val_file_names, dev_file_names,
 tra_val_labels, dev_labels) = train_test_split(
    even_file_names,
    even_labels,
    test_size=0.2,
    random_state=42,
    stratify=even_labels,
)

# One two-column table (patch_name, label) per partition.
tra_val_df = pd.DataFrame({'patch_name': tra_val_file_names, 'label': tra_val_labels})
dev_df = pd.DataFrame({'patch_name': dev_file_names, 'label': dev_labels})

# Save each table next to the full names/labels CSV.
for split_df, csv_name in (
    (tra_val_df, "tra_val_names_labels_df.csv"),
    (dev_df, "dev_names_labels_df.csv"),
):
    split_df.to_csv(f"{combined_split_data}{csv_name}", index=False)

# Synthetic Data

In [None]:
############################## Load in the synthetic data ###################################
# Synthetic capsid images and their class labels.
synthetic_images = np.load(f"{synthetic_data}capsids.npy")
synthetic_labels = np.load(f"{synthetic_data}labels.npy")

################### Split the Synthetic data into three partitions ##########################
# Two stratified splits produce 70% training/cross-validation, 20% development and
# 10% "real world" (rw) test data:
#   1) hold out 20% of all samples as the development set;
#   2) hold out 12.5% of the remaining 80% (= 10% overall) as the rw test set.
synthetic_train_images, synthetic_dev_images, synthetic_train_labels, synthetic_dev_labels = train_test_split(
    synthetic_images,
    synthetic_labels,
    test_size=0.2,
    random_state=42,
    stratify=synthetic_labels,
)
synthetic_train_images, synthetic_rw_images, synthetic_train_labels, synthetic_rw_labels = train_test_split(
    synthetic_train_images,
    synthetic_train_labels,
    test_size=0.125,
    random_state=42,
    stratify=synthetic_train_labels,
)

# Save every partition into the synthetic data folder.
for npy_name, partition in (
    ("synthetic_train_images.npy", synthetic_train_images),
    ("synthetic_train_labels.npy", synthetic_train_labels),
    ("synthetic_dev_images.npy", synthetic_dev_images),
    ("synthetic_dev_labels.npy", synthetic_dev_labels),
    ("synthetic_rw_images.npy", synthetic_rw_images),
    ("synthetic_rw_labels.npy", synthetic_rw_labels),
):
    np.save(f"{synthetic_data}{npy_name}", partition)

In [None]:
############################ Load in the split synthetic data  ##############################

def _load_synthetic_as_list(file_name):
    """Load ``file_name`` from the synthetic data folder and return it via ndarray.tolist()."""
    # NOTE(review): tolist() makes every array element a separate Python object, which uses
    # far more memory than the .npy array; kept as lists in case later code expects them.
    return np.load(f"{synthetic_data}{file_name}").tolist()


synthetic_train_images = _load_synthetic_as_list("synthetic_train_images.npy")
synthetic_train_labels = _load_synthetic_as_list("synthetic_train_labels.npy")

synthetic_dev_images = _load_synthetic_as_list("synthetic_dev_images.npy")
synthetic_dev_labels = _load_synthetic_as_list("synthetic_dev_labels.npy")

synthetic_rw_images = _load_synthetic_as_list("synthetic_rw_images.npy")
synthetic_rw_labels = _load_synthetic_as_list("synthetic_rw_labels.npy")