In [37]:
import torch
import pandas as pd
import numpy as np
from torch.utils.data import Dataset, DataLoader, random_split

# One seed shared by every random number generator seeded in this notebook.
SEED = 3407

np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
# Make cuDNN deterministic and disable its benchmark mode so that results are
# the same on every run.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [38]:
# Custom dataset class; it has to implement __len__ and __getitem__
class CustomDataset(Dataset):
    """Dataset that serves a pair of values per CSV row.

    The CSV is loaded once with ``pd.read_csv`` and kept in ``self.data``.
    Each item is a 2-tuple taken from two positional columns of that row.

    Args:
        data_path: Path (or file-like object) handed to ``pd.read_csv``.
        columns: Positional indices of the two columns returned per item.
            Defaults to ``(1, 2)``, i.e. the second and third columns.

    Raises:
        ValueError: If either column index is outside the loaded frame, so
            a malformed file is reported at construction time instead of on
            the first item access.
    """

    def __init__(self, data_path, columns=(1, 2)):
        self.data = pd.read_csv(data_path)

        first_col, second_col = columns
        n_cols = self.data.shape[1]
        for col in (first_col, second_col):
            # Same range that positional (iloc) indexing accepts.
            if not -n_cols <= col < n_cols:
                raise ValueError(
                    f"column index {col} is out of range for a CSV with "
                    f"{n_cols} columns"
                )
        self.columns = (first_col, second_col)

    def __len__(self):
        """Return the number of rows in the loaded CSV."""
        return len(self.data)

    def __getitem__(self, i):
        """Return the values of the two configured columns at row ``i``."""
        first_col, second_col = self.columns
        return self.data.iloc[i, first_col], self.data.iloc[i, second_col]


In [39]:
# Create the custom dataset
dataset = CustomDataset(data_path='./ship_data/experiment_data.csv')

# Split proportions: 90% train, 5% validation, the remaining rows for test
train_size = int(0.9 * len(dataset))
val_size = int(0.05 * len(dataset))
test_size = len(dataset) - train_size - val_size

# Split with random_split, using a dedicated, explicitly seeded generator so
# the partition does not depend on how much of the global RNG earlier cells
# (or re-runs of them) have already consumed.
split_generator = torch.Generator().manual_seed(3407)
train_dataset, val_dataset, test_dataset = random_split(
    dataset,
    [train_size, val_size, test_size],
    generator=split_generator,
)

# Create the data loaders; only the training loader shuffles its samples
BATCH_SIZE = 2
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE)

# train_loader, val_loader and test_loader can now be iterated for training,
# validation and testing.

In [55]:
# Display how many samples the random split assigned to the test subset.
len(test_dataset)

118658