In [1]:
%pwd

'c:\\Users\\lenovo\\Desktop\\Stage\\text-to-3D_Model_Generation\\research'

In [2]:
import os
os.chdir("..")
%pwd

'c:\\Users\\lenovo\\Desktop\\Stage\\text-to-3D_Model_Generation'

In [3]:
from dataclasses import dataclass
from pathlib import Path

@dataclass(frozen=True)
class DataSplitConfig:
    local_data_file: Path
    output_dir : Path
    train_ratio: float
    test_ratio: float
    val_ratio: float

In [4]:
from textTo3DModelGen.constants import *
from textTo3DModelGen.utils.common import read_yaml, create_directories

  from .autonotebook import tqdm as notebook_tqdm


In [5]:
class ConfigurationManager:
    def __init__(
            self, 
            config_filepath = CONFIG_FILE_PATH,
            params_filepath = HYPER_PARAMS_FILE_PATH):
        
        self.config = read_yaml(config_filepath)
        self.params = read_yaml(params_filepath)

        create_directories([
            self.config.artifacts_root
        ])

    def get_data_split_config(self) -> DataSplitConfig:
        config = self.config.data_split

        create_directories([
            config.output_dir
        ])

        data_split_config = DataSplitConfig(
            local_data_file= config.local_data_file,
            output_dir= config.output_dir,
            train_ratio= config.train_ratio,
            test_ratio= config.test_ratio,
            val_ratio= config.val_ratio
        )

        return data_split_config


In [6]:
from textTo3DModelGen import logger
from textTo3DModelGen.utils.common import save_list_to_textfile
import random
import pandas as pd

In [7]:
class DataSplit:
    def __init__(self, config: DataSplitConfig):
        self.config = config

    def split_data(self):
        try:
            data = pd.read_csv(self.config.local_data_file)
            logger.info(f"Read of data from {self.config.local_data_file} is done.")

            uids = data["uids"].to_list()

            random.shuffle(uids)
            logger.info(f"Shuffle the list of uids is done.")

            # Compute the indices for splitting
            train_size = int(self.config.train_ratio * len(uids))
            val_size = int(self.config.val_ratio * len(uids))
            test_size = len(uids) - train_size - val_size 
            logger.info(f"train_size: {train_size} || val_size: {val_size} || test_size: {test_size}.")

            # Split the list
            train_uids = uids[:train_size]
            val_uids = uids[train_size:train_size + val_size]
            test_uids = uids[train_size + val_size:]

            train_filename = os.path.join(self.config.output_dir, "train.txt")
            test_filename = os.path.join(self.config.output_dir, "test.txt")
            val_filename = os.path.join(self.config.output_dir, "val.txt")

            # Save the train, validation, and test IDs to text files
            save_list_to_textfile(train_filename, train_uids, "train_uids")
            save_list_to_textfile(test_filename, test_uids, "test_uids")
            save_list_to_textfile(val_filename, val_uids, "val_uids")
        except Exception as e:
            raise e

In [8]:
try:
    config = ConfigurationManager()
    data_split_config = config.get_data_split_config()
    data_split = DataSplit(config= data_split_config)
    data_split.split_data()
except Exception as e:
    raise e

[2024-09-02 16:02:26,743: INFO: yaml file: config\config.yaml loaded successfully]
[2024-09-02 16:02:26,745: INFO: yaml file: hyper_params.yaml loaded successfully]
[2024-09-02 16:02:26,747: INFO: created directory at: artifacts]
[2024-09-02 16:02:26,749: INFO: created directory at: artifacts/data_split]
[2024-09-02 16:02:26,860: INFO: Read of data from artifacts/data_ingestion/objaverse_with_description.csv is done.]
[2024-09-02 16:02:26,870: INFO: Shuffle the list of uids is done.]
[2024-09-02 16:02:26,872: INFO: train_size: 1 || val_size: 0 || test_size: 1.]
[2024-09-02 16:02:26,875: INFO: train_uids saved to artifacts/data_split\train.txt.]
[2024-09-02 16:02:26,878: INFO: test_uids saved to artifacts/data_split\test.txt.]
[2024-09-02 16:02:26,881: INFO: val_uids saved to artifacts/data_split\val.txt.]
