In [None]:
from importlib.resources import path
import os
import pathlib
import shutil
from random import shuffle

In [1]:
class SegmentationFileSeperator:
    """Split an image / segmentation-map dataset into Train, Test and Valid sets.

    ``original_image_dir`` and ``original_segmap_dir`` are expected to contain
    one sub-directory per class ("section"). Inside a class, the image and the
    segmentation map at the same position in sorted filename order form a pair.
    Pairs are shuffled together and *copied* (the originals are left in place)
    into ``<target path>/<split>/<section>``.
    """

    # Split directory names created under both target paths.
    _SPLITS = ("Train", "Test", "Valid")

    def __init__(self, target_image_path, original_image_dir, target_segmap_path, original_segmap_dir):
        self.target_image_path = target_image_path
        self.original_image_dir = original_image_dir
        self.target_segmap_path = target_segmap_path
        self.original_segmap_dir = original_segmap_dir

    def _sections(self):
        """Return the sorted class sub-directory names of ``original_image_dir``."""
        # Skip stray files (e.g. ``.DS_Store``) so they are not treated as classes.
        return sorted(
            name
            for name in os.listdir(self.original_image_dir)
            if os.path.isdir(os.path.join(self.original_image_dir, name))
        )

    def dataset_segregate(self):
        """Create the empty Train/Test/Valid directories under both target paths.

        Raises:
            FileExistsError: if a split directory already exists. This keeps a
                re-run from mixing freshly copied files into a previous split.
        """
        for split in self._SPLITS:
            os.makedirs(os.path.join(self.target_image_path, split))
            os.makedirs(os.path.join(self.target_segmap_path, split))

    def class_maker(self):
        """Create one directory per class inside every split of both target paths."""
        for split in self._SPLITS:
            for section in self._sections():
                os.makedirs(os.path.join(self.target_image_path, split, section))
                os.makedirs(os.path.join(self.target_segmap_path, split, section))

    def shuffle_together(self, x, y):
        """Shuffle two sequences with the same random permutation.

        The pairs ``(x[i], y[i])`` stay together. If the inputs differ in length,
        the extra trailing items of the longer one are dropped (``zip`` semantics).

        Returns:
            A tuple ``(a, b)`` of tuples. Both are empty when either input is empty.
        """
        # zip() returns a one-shot iterator in Python 3, but random.shuffle
        # needs a mutable sequence, so materialize it first.
        pairs = list(zip(x, y))
        if not pairs:
            # zip(*[]) yields nothing, so "a, b = ..." would fail to unpack.
            return (), ()
        shuffle(pairs)
        a, b = zip(*pairs)
        return a, b

    def file_mover(self, train_pr, valid_pr):
        """Copy every image/segmap pair into its Train, Valid or Test directory.

        Each class is shuffled exactly once and then split by position:
        indices ``[0, int(train_pr*n))`` go to Train,
        ``[int(train_pr*n), int((train_pr+valid_pr)*n))`` go to Valid, and the
        rest go to Test. With one shuffle and cumulative boundaries, the three
        splits are disjoint and together cover every file.

        Args:
            train_pr: Fraction of each class used for training.
            valid_pr: Fraction of each class used for validation.

        Raises:
            ValueError: if a class has a different number of images and segmaps.
        """
        # section -> (shuffled images, shuffled segmaps, {split: (start, end)})
        plan = {}
        for section in self._sections():
            images = sorted(os.listdir(os.path.join(self.original_image_dir, section)))
            segmaps = sorted(os.listdir(os.path.join(self.original_segmap_dir, section)))
            if len(images) != len(segmaps):
                raise ValueError(
                    f"Class '{section}' has {len(images)} images but "
                    f"{len(segmaps)} segmentation maps; they must pair one-to-one."
                )
            n = len(images)
            train_end = int(train_pr * n)
            # Cumulative boundary: no file is lost or duplicated between splits.
            valid_end = max(train_end, int((train_pr + valid_pr) * n))
            shuffled_images, shuffled_segmaps = self.shuffle_together(images, segmaps)
            bounds = {
                "Train": (0, train_end),
                "Valid": (train_end, valid_end),
                "Test": (valid_end, n),
            }
            plan[section] = (shuffled_images, shuffled_segmaps, bounds)

        for split in self._SPLITS:
            print(f"Moving to dir: {split}")
            for section, (images, segmaps, bounds) in plan.items():
                start, end = bounds[split]
                # Slicing clamps out-of-range bounds instead of raising IndexError.
                for image_name, segmap_name in zip(images[start:end], segmaps[start:end]):
                    shutil.copy(
                        os.path.join(self.original_image_dir, section, image_name),
                        os.path.join(self.target_image_path, split, section, image_name),
                    )
                    shutil.copy(
                        os.path.join(self.original_segmap_dir, section, segmap_name),
                        os.path.join(self.target_segmap_path, split, section, segmap_name),
                    )

    def print_statistics(self):
        """Print how many files ended up in each split/class directory.

        Image directories are reported first, then segmentation-map directories.
        """
        for target_root in (self.target_image_path, self.target_segmap_path):
            for split in self._SPLITS:
                for section in self._sections():
                    num = len(os.listdir(os.path.join(target_root, split, section)))
                    print(f"Files in {split} -> {section} Directory : {num}")
                print(" ")

    @staticmethod
    def _proportion_error(train_pr, valid_pr):
        """Return a message explaining why the proportions are invalid, or None."""
        if train_pr < 0 or train_pr > 1:
            return "Train proportion value not valid. Please enter a value greater than 0 and less than 1."
        if valid_pr < 0 or valid_pr > 1:
            return "The validation proportion is not valid. Please enter a value greater than 0 and less than 1."
        if valid_pr > train_pr:
            return "Validation proportion is greater than training data proportion, please enter a value less than the training proportion."
        if valid_pr + train_pr >= 1:
            return "The sum of the validation and training proportion is greater than or equal to one, this is not valid. Please enter values such that their sum is strictly less than 1."
        return None

    def run(self, train_pr: float, valid_pr: float):
        """Validate the proportions, build the directory tree and copy the files.

        If the proportions are invalid, an explanation is printed and the method
        returns before anything is created on disk.
        """
        error = self._proportion_error(train_pr, valid_pr)
        if error is not None:
            print(error)
            return
        self.dataset_segregate()
        self.class_maker()
        self.file_mover(train_pr=train_pr, valid_pr=valid_pr)

    def test_proportions(self, train_pr: float, valid_pr: float):
        """Print whether the proportions are valid.

        Returns:
            True when the proportions are valid, otherwise False.
        """
        error = self._proportion_error(train_pr, valid_pr)
        if error is not None:
            print(error)
            return False
        print("The proportions enterred are valid.")
        return True
        
        
        
        
