In [3]:
import pandas as pd
import numpy as np
import os
import torch
from torch.utils.data import Dataset
from utils import *

In [16]:
# Module-level list holding the CSV paths found by the most recent
# load_batdata() call (replaced on every call, never accumulated).
file_paths = []
train_dir = 'F:/git_repo/WKN_SSO/new_data/Train/'
test_dir = 'F:/git_repo/WKN_SSO/new_data/Test/'

def load_batdata(root_dir, device=None):
    """Load the ``amplitude`` column of the CSV files in ``root_dir`` as a tensor.

    Every entry of ``root_dir`` whose name ends with ``.csv`` is read with
    pandas; its ``amplitude`` column is passed through ``min_max_scale``
    (imported from ``utils``), reshaped to a single row and converted to a
    float32 tensor.

    Args:
        root_dir: Directory containing the ``.csv`` files.
        device: Device the tensor is moved to. When ``None`` (default), CUDA
            is used if available, otherwise the CPU.

    Returns:
        A float32 tensor of shape ``(1, n_samples)`` for the LAST CSV file in
        the directory listing (see the NOTE in the loop below).

    Raises:
        FileNotFoundError: If ``root_dir`` contains no ``.csv`` file.
    """
    csv_paths = [
        os.path.normpath(os.path.join(root_dir, filename))
        for filename in os.listdir(root_dir)
        if filename.endswith('.csv')
    ]

    # Replace the contents of the shared list instead of appending to it, so
    # repeated calls (or calls for another directory) do not re-process paths
    # collected by earlier calls. Slice assignment keeps the same list object
    # for any cell that still reads ``file_paths``.
    file_paths[:] = csv_paths

    if not csv_paths:
        raise FileNotFoundError(f"no .csv files found in {root_dir!r}")

    if device is None:
        # Fall back to the CPU so the notebook also runs without a GPU.
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    inputs = None
    for file_path in csv_paths:
        data = pd.read_csv(file_path)
        amplitude = data['amplitude'].values.astype(float)
        scaled = min_max_scale(amplitude).reshape(1, -1)
        # NOTE(review): each iteration overwrites the previous tensor, so only
        # the last listed file is returned. Kept as-is to preserve the return
        # shape; confirm whether all files were meant to be stacked.
        inputs = torch.from_numpy(scaled.astype(np.float32)).to(device)

    return inputs

# Quick check: load the training directory and show the shape of the returned tensor.
train_data = load_batdata(root_dir=train_dir)
print(train_data.shape)

F:\git_repo\WKN_SSO\new_data\Train\time_0.csv 0.0
F:\git_repo\WKN_SSO\new_data\Train\time_10.csv 2.0
F:\git_repo\WKN_SSO\new_data\Train\time_100.csv 20.0
F:\git_repo\WKN_SSO\new_data\Train\time_1000.csv 200.0
F:\git_repo\WKN_SSO\new_data\Train\time_1005.csv 201.0
F:\git_repo\WKN_SSO\new_data\Train\time_1010.csv 202.0
F:\git_repo\WKN_SSO\new_data\Train\time_1015.csv 203.0
F:\git_repo\WKN_SSO\new_data\Train\time_1020.csv 204.0
F:\git_repo\WKN_SSO\new_data\Train\time_1025.csv 205.0
F:\git_repo\WKN_SSO\new_data\Train\time_1030.csv 206.0
F:\git_repo\WKN_SSO\new_data\Train\time_1035.csv 207.0
F:\git_repo\WKN_SSO\new_data\Train\time_1040.csv 208.0
F:\git_repo\WKN_SSO\new_data\Train\time_1045.csv 209.0
F:\git_repo\WKN_SSO\new_data\Train\time_105.csv 21.0
F:\git_repo\WKN_SSO\new_data\Train\time_1050.csv 210.0
F:\git_repo\WKN_SSO\new_data\Train\time_1055.csv 211.0
F:\git_repo\WKN_SSO\new_data\Train\time_1060.csv 212.0
F:\git_repo\WKN_SSO\new_data\Train\time_1065.csv 213.0
F:\git_repo\WKN_SSO\new

In [24]:
def folder_total(root_dir):
    """Count the entries of ``root_dir`` whose names end with ``.csv``.

    Args:
        root_dir: Directory to scan (non-recursive).

    Returns:
        The number of matching entries, as an ``int``.
    """
    csv_names = [entry for entry in os.listdir(root_dir) if entry.endswith('.csv')]
    return len(csv_names)

def get_hi(root_dir, file_path, hi_type=1, two_stage_hp=(0.6, 0.6), file_step=5):
    """Return the health-indicator (HI) label for one CSV file.

    The file's position is derived from the number after the last ``_`` in
    its name (e.g. ``time_10.csv`` -> 10), divided by ``file_step``. That
    position indexes an HI curve with one point per CSV file in ``root_dir``.

    Args:
        root_dir: Directory whose ``.csv`` count sets the curve length.
        file_path: Path of the file to label; ``\\`` and ``/`` separators
            are both accepted.
        hi_type: ``1`` for a linear curve from 1 down to 0; ``2`` for the
            curve returned by ``two_stage_hi`` (imported from ``utils``).
        two_stage_hp: The two leading arguments forwarded to
            ``two_stage_hi`` when ``hi_type == 2``.
        file_step: Increment between consecutive file numbers (default 5).

    Returns:
        The curve value at the file's position.

    Raises:
        ValueError: If ``hi_type`` is not 1 or 2, or the file name has no
            integer suffix.
        IndexError: If the derived position lies outside the curve.
    """
    if hi_type not in (1, 2):
        # Previously this fell through and failed with UnboundLocalError.
        raise ValueError(f"hi_type must be 1 or 2, got {hi_type!r}")

    folder_tot = folder_total(root_dir)

    # Isolate the file name regardless of which path separator is used.
    base_name = file_path.replace('\\', '/').rsplit('/', 1)[-1]
    file_num = int(int(base_name.split('_')[-1].split('.')[0]) / file_step)

    if hi_type == 1:
        hi = np.linspace(1, 0, folder_tot)
    else:
        hi = two_stage_hi(two_stage_hp[0], two_stage_hp[1], folder_tot)

    return hi[file_num]

In [33]:
class BatteryDataSet(Dataset):
    """Map-style dataset that yields one min-max-scaled signal per CSV file.

    Each item is built from the ``amplitude`` column of one ``.csv`` file in
    ``root_dir``, returned as a float32 tensor of shape ``(1, n_samples)``.
    In ``'train'`` mode the item also carries the label computed by
    ``get_hi`` for that file.

    Args:
        root_dir: Directory containing the ``.csv`` files.
        transform: Optional callable applied to the input tensor.
        mode: ``'train'`` (items are ``(inputs, label)``) or ``'test'``
            (items are ``inputs`` only).
        label_style: Forwarded to ``get_hi`` as ``hi_type``.
        two_stage_hp: Forwarded to ``get_hi`` as ``two_stage_hp``.

    Raises:
        ValueError: If ``mode`` is neither ``'train'`` nor ``'test'``.
    """

    def __init__(self, root_dir, transform=None, mode='train', label_style=2, two_stage_hp=(0.6, 0.6)):
        if mode not in ('train', 'test'):
            raise ValueError(f"mode must be 'train' or 'test', got {mode!r}")
        self.root_dir = root_dir
        self.transform = transform
        self.mode = mode
        self.label_style = label_style
        self.two_stage_hp = two_stage_hp
        self.file_paths = self.load_batdata()

    # The methods below must live at class level. When nested inside
    # __init__ they are only local functions, so self.load_batdata() fails
    # and the dataset has no usable __getitem__/__len__.
    def load_batdata(self):
        """Return the normalized paths of the ``.csv`` files in ``self.root_dir``."""
        # os.listdir() returns entries in arbitrary order; sorting makes the
        # index -> file mapping deterministic across runs and platforms.
        return [
            os.path.normpath(os.path.join(self.root_dir, filename))
            for filename in sorted(os.listdir(self.root_dir))
            if filename.endswith('.csv')
        ]

    def __getitem__(self, index):
        """Return the sample at ``index`` (with its label in ``'train'`` mode)."""
        file_path = self.file_paths[index]
        data = pd.read_csv(file_path)
        inputs = data['amplitude'].values.astype(float)
        inputs = min_max_scale(inputs).reshape(1, -1)
        inputs = torch.from_numpy(inputs.astype(np.float32))

        if self.transform:
            inputs = self.transform(inputs)

        if self.mode == 'train':
            label = get_hi(self.root_dir, file_path, self.label_style, self.two_stage_hp)
            label = torch.tensor(label, dtype=torch.float32)
            return inputs, label

        if self.mode == 'test':
            return inputs

        # Reached only if self.mode was changed after construction; fail
        # loudly instead of silently returning None.
        raise ValueError(f"mode must be 'train' or 'test', got {self.mode!r}")

    def __len__(self):
        """Return the number of CSV files found in ``root_dir``."""
        return len(self.file_paths)
        
# Dataset locations (absolute, machine-specific paths).
train_dir = 'F:/git_repo/WKN_SSO/new_data/Train/'
test_dir = 'F:/git_repo/WKN_SSO/new_data/Test/'

# Build the training dataset, leaving every optional argument at its default.
train_dataset = BatteryDataSet(root_dir=train_dir)