In [12]:
import numpy as np
import torch
import pandas as pd
import pathlib
import os
from pathlib import Path

In [13]:
class ChallengeDataset1():
    def __init__(self,lc_path, params_path=None, transform=None, start_ind=0,
                 max_size=int(1e9), shuffle=True, seed=None, device=None):
        self.lc_path = lc_path # 数据所在的文件夹
        self.transform = transform # 数据的处理方法
        self.device = device # 设备
        self.files = sorted(
            [p for p in os.listdir(self.lc_path) if p.endswith('txt')]) # 文件夹下的文件名
        if shuffle:
            np.random.seed(seed)
            np.random.shuffle(self.files)
        self.files = self.files[start_ind:start_ind+max_size] # 选择部分，并且shuffle

        if params_path is not None:
            self.params_path = params_path   # 如果有标签，那么读取标签。
        else:
            self.params_path = None   # 否则没有标签，设为none。
            self.params_files = None


    def __getitem__(self, idx: int):
        item_lc_path = Path(self.lc_path) / self.files[idx]
        lc = np.loadtxt(item_lc_path)
        if self.transform:
            lc = self.transform(lc)
        if self.params_path is not None:
            item_params_path = Path(self.params_path) / self.files[idx]
            target = np.loadtxt(item_params_path)
        else:
            target = torch.Tensor()
        #lc 表示模型的输入，从noisy_train中可以得到。
        #target表示模型的输出，从params_train中可以得到。
        return lc
    def __len__(self):
        return len(self.files)

In [41]:
dataset_train = ChallengeDataset1("data/noisy_train/home/ucapats/Scratch/ml_data_challenge/training_set/noisy_train"
                                  ,"data/params_train/home/ucapats/Scratch/ml_data_challenge/training_set/params_train", shuffle=True, start_ind=0,
                                   max_size=10000, transform=lambda x:x, device='cpu')

In [42]:
a=np.array(dataset_train)

In [33]:
a.mean()

0.9940419875550011

In [34]:
a.std()

0.0367698530490506

In [38]:
a.mean(axis=(0,2))

array([0.99421436, 0.99422582, 0.99416736, 0.99409498, 0.99413781,
       0.99403749, 0.99409564, 0.99409953, 0.99402775, 0.99406303,
       0.99406908, 0.99409773, 0.99407221, 0.99408032, 0.99413919,
       0.99412296, 0.99415566, 0.99401831, 0.99405101, 0.9940012 ,
       0.9940578 , 0.99401299, 0.99399763, 0.99401205, 0.99402021,
       0.99398729, 0.99388536, 0.99392285, 0.99398041, 0.99395112,
       0.99400585, 0.99407681, 0.99398763, 0.99402793, 0.99402445,
       0.99394029, 0.99398748, 0.99393739, 0.99408596, 0.99396185,
       0.99406379, 0.99403255, 0.99392942, 0.99418762, 0.99413247,
       0.99398518, 0.99400593, 0.99405125, 0.99401915, 0.99396681,
       0.99406159, 0.99405765, 0.99384969, 0.99405496, 0.99407647])

In [40]:
a.std(axis=(0,2))

array([0.03583693, 0.03569331, 0.03566646, 0.03569547, 0.03570023,
       0.03568121, 0.03567544, 0.03568862, 0.03573369, 0.03610245,
       0.03614205, 0.03597644, 0.03601348, 0.03604199, 0.03602117,
       0.03607119, 0.03599659, 0.03616262, 0.03604043, 0.03618137,
       0.03619029, 0.03627466, 0.03620288, 0.03621327, 0.03621314,
       0.0363212 , 0.0363201 , 0.03625863, 0.03624343, 0.03640044,
       0.03636861, 0.03654736, 0.03646759, 0.03651567, 0.03662979,
       0.03665317, 0.03675937, 0.03694968, 0.03673388, 0.03701295,
       0.0370219 , 0.03729575, 0.03824422, 0.0495555 , 0.03598641,
       0.03604815, 0.03611789, 0.03622259, 0.03632392, 0.03647351,
       0.0366248 , 0.03676271, 0.0371457 , 0.03777215, 0.04406461])

In [43]:
a-a.mean()/a.std()

array([[[-26.41593102, -26.19880869, -26.10683045, ..., -25.92054633,
         -25.92166634, -25.91968471],
        [-26.41641217, -26.1998755 , -26.10756768, ..., -25.9205149 ,
         -25.92026549, -25.92002951],
        [-26.41641633, -26.19825949, -26.10678871, ..., -25.91955601,
         -25.9202042 , -25.91979524],
        ...,
        [-26.41404933, -26.20178157, -26.11308355, ..., -25.92223501,
         -25.92393554, -25.92292873],
        [-26.41569081, -26.19573989, -26.10831601, ..., -25.92341719,
         -25.91942326, -25.92286722],
        [-26.41420512, -26.19167467, -26.09889997, ..., -25.91930343,
         -25.92086405, -25.91770417]],

       [[-26.40983641, -26.19451715, -26.10778698, ..., -25.91388766,
         -25.92199335, -25.91533093],
        [-26.41803203, -26.19786132, -26.10637103, ..., -25.92019892,
         -25.91593025, -25.92275076],
        [-26.41887262, -26.20098119, -26.1068008 , ..., -25.92068114,
         -25.91726425, -25.92066834],
        ...,
