In [1]:
import numpy as np
import torch
import os
import sys
import random
from PIL import Image
from torch.utils.data import Dataset

DATASIZE = 2000
# !!!如果改变此参数,split_datasets中的也需要改
TRAIN_PRECENT = 0.8
CLASS_NUM = 10


class handWritten_Dataset(Dataset):
    def __init__(self, data_path, transform=None, train=True):
        """
        mfeat-fou的Dataset
        :param data_path: str, 数据集所在路径
        :param transform: torch.transform，数据预处理
        """

        self.label_name = {"0": 0, "1": 1}
        self.class_size = self.get_class_size(train)
        self.data_info = self.get_img_info(data_path)
        self.transform = transform


    def __getitem__(self, index):
        data, label = self.data_info[index]

        if self.transform is not None:
            data = self.transform(data)

        return data, label, index

    def __len__(self):
        return len(self.data_info)

    # @staticmethod
    def get_class_size(self, train):
        class_size = int(DATASIZE*TRAIN_PRECENT/CLASS_NUM)
        if not train:
            class_size = 200 - class_size
        return class_size

    # @staticmethod
    def get_img_info(self, file_path):
        data_info = list()
        train_data = np.load(file_path)
        for i, data in enumerate(train_data):
            label = int(i / self.class_size)
            data_info.append((data, label))

        # print(data_info[0])
        return data_info

    def get_features_size(self):
        return len(self.data_info[0][0])




## 创建一个规律的dataset,以便观察

In [2]:
dataset = [[i] for i in range(1600)]
dataset = np.array(dataset)

mydataset_path = './mydataset.npy' 
np.save(mydataset_path, dataset)

dataset = torch.tensor(dataset).to("cuda:0")
print(dataset)

tensor([[   0],
        [   1],
        [   2],
        ...,
        [1597],
        [1598],
        [1599]], device='cuda:0')


In [3]:
file_name = 'mfeat-zer'
# train_file_path = '../datasets/Handwritten/mfeat/splite/' + file_name + '/train_data.npy'
train_file_path = mydataset_path

In [4]:
train_data = handWritten_Dataset(data_path=train_file_path)
train_loader = torch.utils.data.DataLoader(dataset=train_data, batch_size=2, shuffle=True)

In [6]:
# total = torch.tensor()

count = 2
for i, data in enumerate(train_loader):
    print("# ============================ the {} data".format(i))
    print(data)
    
    _, _, index = data
    
    index.to("cuda:0")
    print("get reconstruct represesnetation")
#     print(dataset[index.numpy()])
    print(dataset[index])

    if i > count:
        break

[tensor([[ 361],
        [1536]]), tensor([2, 9]), tensor([ 361, 1536])]
get reconstruct represesnetation
tensor([[ 361],
        [1536]], device='cuda:0')
[tensor([[ 458],
        [1307]]), tensor([2, 8]), tensor([ 458, 1307])]
get reconstruct represesnetation
tensor([[ 458],
        [1307]], device='cuda:0')
[tensor([[425],
        [170]]), tensor([2, 1]), tensor([425, 170])]
get reconstruct represesnetation
tensor([[425],
        [170]], device='cuda:0')
[tensor([[678],
        [856]]), tensor([4, 5]), tensor([678, 856])]
get reconstruct represesnetation
tensor([[678],
        [856]], device='cuda:0')


In [8]:
# total = torch.tensor()

count = 4
for i, data in enumerate(train_loader):
    print("# ============================ the {} data".format(i))
    print(data[2])
    _, _, index = data
    print(index)
    
  

tensor([354, 564])
tensor([354, 564])
tensor([949, 783])
tensor([949, 783])
tensor([  2, 866])
tensor([  2, 866])
tensor([ 384, 1546])
tensor([ 384, 1546])
tensor([1136, 1225])
tensor([1136, 1225])
tensor([1033,  529])
tensor([1033,  529])
tensor([ 511, 1030])
tensor([ 511, 1030])
tensor([1192,  891])
tensor([1192,  891])
tensor([1159,  952])
tensor([1159,  952])
tensor([443, 249])
tensor([443, 249])
tensor([ 725, 1302])
tensor([ 725, 1302])
tensor([302, 513])
tensor([302, 513])
tensor([1517, 1391])
tensor([1517, 1391])
tensor([802, 785])
tensor([802, 785])
tensor([600, 585])
tensor([600, 585])
tensor([ 43, 664])
tensor([ 43, 664])
tensor([ 483, 1142])
tensor([ 483, 1142])
tensor([1504, 1433])
tensor([1504, 1433])
tensor([1185,  280])
tensor([1185,  280])
tensor([ 64, 363])
tensor([ 64, 363])
tensor([1303,  774])
tensor([1303,  774])
tensor([543, 391])
tensor([543, 391])
tensor([970, 728])
tensor([970, 728])
tensor([1575, 1029])
tensor([1575, 1029])
tensor([1005, 1064])
tensor([1005, 1

tensor([1516, 1245])
tensor([1516, 1245])
tensor([531, 477])
tensor([531, 477])
tensor([1558,  408])
tensor([1558,  408])
tensor([1579,  147])
tensor([1579,  147])
tensor([ 334, 1087])
tensor([ 334, 1087])
tensor([1375,  741])
tensor([1375,  741])
tensor([1505,   71])
tensor([1505,   71])
tensor([460, 190])
tensor([460, 190])
tensor([1236, 1300])
tensor([1236, 1300])
tensor([812, 319])
tensor([812, 319])
tensor([1396,  931])
tensor([1396,  931])
tensor([227, 428])
tensor([227, 428])
tensor([1190, 1501])
tensor([1190, 1501])
tensor([720,  78])
tensor([720,  78])
tensor([1428, 1256])
tensor([1428, 1256])
tensor([ 671, 1598])
tensor([ 671, 1598])
tensor([1081,  294])
tensor([1081,  294])
tensor([1254, 1413])
tensor([1254, 1413])
tensor([1069,  134])
tensor([1069,  134])
tensor([117, 977])
tensor([117, 977])
tensor([882,  39])
tensor([882,  39])
tensor([1401, 1090])
tensor([1401, 1090])
tensor([809, 371])
tensor([809, 371])
tensor([1372, 1333])
tensor([1372, 1333])
tensor([1512, 1278])
ten

tensor([797, 991])
tensor([797, 991])
tensor([ 405, 1536])
tensor([ 405, 1536])
tensor([ 583, 1595])
tensor([ 583, 1595])
tensor([332,   4])
tensor([332,   4])
tensor([1023,  195])
tensor([1023,  195])
tensor([484, 196])
tensor([484, 196])
tensor([825, 867])
tensor([825, 867])
tensor([251,  77])
tensor([251,  77])
tensor([1458, 1497])
tensor([1458, 1497])
tensor([524, 565])
tensor([524, 565])
tensor([414, 142])
tensor([414, 142])
tensor([556, 244])
tensor([556, 244])
tensor([ 115, 1257])
tensor([ 115, 1257])
tensor([1535,  494])
tensor([1535,  494])
tensor([1599,  515])
tensor([1599,  515])
tensor([519, 508])
tensor([519, 508])
tensor([ 989, 1589])
tensor([ 989, 1589])
tensor([534, 157])
tensor([534, 157])
tensor([ 516, 1439])
tensor([ 516, 1439])
tensor([365, 799])
tensor([365, 799])
tensor([1414,  490])
tensor([1414,  490])
tensor([1022,  131])
tensor([1022,  131])
tensor([1145, 1317])
tensor([1145, 1317])
tensor([592, 178])
tensor([592, 178])
tensor([1252, 1489])
tensor([1252, 1489]

tensor([1492, 1057])
tensor([1547,  860])
tensor([1547,  860])
tensor([1350, 1452])
tensor([1350, 1452])
tensor([1091,  648])
tensor([1091,  648])
tensor([999, 420])
tensor([999, 420])
tensor([1182,   17])
tensor([1182,   17])
tensor([ 975, 1382])
tensor([ 975, 1382])
tensor([ 630, 1387])
tensor([ 630, 1387])
tensor([1293,   55])
tensor([1293,   55])
tensor([1107,  900])
tensor([1107,  900])
tensor([743,  37])
tensor([743,  37])
tensor([ 966, 1124])
tensor([ 966, 1124])
tensor([  12, 1282])
tensor([  12, 1282])
tensor([983, 330])
tensor([983, 330])
tensor([339, 454])
tensor([339, 454])
tensor([422, 520])
tensor([422, 520])
tensor([1330,  400])
tensor([1330,  400])
tensor([1534,  226])
tensor([1534,  226])
tensor([220, 376])
tensor([220, 376])
tensor([1041,  982])
tensor([1041,  982])
tensor([120,  68])
tensor([120,  68])
tensor([1376, 1093])
tensor([1376, 1093])
tensor([ 661, 1199])
tensor([ 661, 1199])
tensor([791, 239])
tensor([791, 239])
tensor([1525, 1305])
tensor([1525, 1305])
ten

In [14]:
arr2 = np.array([[7,8,9], [10,11,12]])
data = np.empty((2,3), dtype = int) 
print(data)


[[ 7  8  9]
 [10 11 12]]


In [15]:
data = np.empty((2,3), dtype = int) 
print(data)

[[ 7  8  9]
 [10 11 12]]


In [12]:

arr1 = np.array([[1,2,3], [4,5,6]])
arr2 = np.array([[7,8,9], [10,11,12]])
data = np.empty(arr2.shape, dtype = int) 
print(data)
data = np.concatenate([data, arr2], axis=0) # axis参数指明合并的轴向，0表示按行，1表示按列
print(data.shape)
data

[[ 7  8  9]
 [10 11 12]]
(4, 3)


array([[ 7,  8,  9],
       [10, 11, 12],
       [ 7,  8,  9],
       [10, 11, 12]])

In [61]:
data = np.concatenate([data, arr2], axis=0) # axis参数指明合并的轴向，0表示按行，1表示按列
data

array([[ 1,  2,  3],
       [ 4,  5,  6],
       [ 7,  8,  9],
       [10, 11, 12],
       [ 7,  8,  9],
       [10, 11, 12],
       [ 7,  8,  9],
       [10, 11, 12]])