In [None]:
import torch
from torch.utils.data import DataLoader,Dataset
from torchvision.transforms import transforms
from path import Path
from PIL import Image

# Directory whose files are fed to the dataset below.
images = Path(r'JPEGImages')

class CatData(Dataset):
    """Dataset that loads image files and applies a transform to each one.

    Every item is returned as ``(features, 0)``: the label is always the
    constant ``0``.
    """

    def __init__(self, files, transformer):
        """Store the file list and the per-image transform.

        Args:
            files: Indexable, sized collection of image paths.
            transformer: Callable applied to each loaded PIL image.
        """
        super().__init__()
        self.transformer = transformer
        self.files = files

    def __getitem__(self, index):
        """Load the image at ``index`` and return ``(features, 0)``."""
        # Read the image.
        image_path = self.files[index]
        # The context manager releases the file handle deterministically.
        # Converting to RGB forces the pixel data to load before the file is
        # closed and gives every sample three channels, so grayscale, palette
        # and RGBA files can be batched together with ordinary RGB ones.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert('RGB')
        features = self.transformer(image)
        return features, 0

    def __len__(self):
        """Return the number of image files in the dataset."""
        return len(self.files)

# Preprocessing: resize each image to 410x410, then convert it to a tensor.
trans = transforms.Compose(
    [transforms.Resize((410, 410)), transforms.ToTensor()]
)

# Dataset over every file in the image directory, served in shuffled
# batches of 16.
train_dataset = CatData(files=images.files(), transformer=trans)
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)


def get_mean_std(loader):
    """Compute the per-channel mean and standard deviation of a dataset.

    Uses ``Var[X] = E[X**2] - E[X]**2`` with the expectations taken over
    every pixel of every sample, so a final batch that is smaller than the
    others is weighted by its true size rather than counted as a full batch.

    Args:
        loader: Iterable yielding ``(data, label)`` pairs, where ``data`` is
            a 4-D tensor whose dimension 1 is the channel axis (dimensions
            0, 2 and 3 are reduced). The label is ignored.

    Returns:
        A ``(mean, std)`` tuple of 1-D tensors with one entry per channel.
        The result has the dtype of ``data`` when that is a floating-point
        dtype, otherwise the default floating-point dtype.

    Raises:
        ValueError: If the loader yields no data, since the statistics are
            undefined in that case.
    """
    channels_sum = 0
    channels_squared_sum = 0
    pixels_per_channel = 0
    out_dtype = torch.get_default_dtype()

    for data, _ in loader:
        if data.is_floating_point():
            out_dtype = data.dtype
        # Accumulate in float64: the running sums grow large and the final
        # subtraction E[X**2] - E[X]**2 is prone to cancellation error.
        values = data.double()
        channels_sum = channels_sum + values.sum(dim=[0, 2, 3])
        channels_squared_sum = (
            channels_squared_sum + (values ** 2).sum(dim=[0, 2, 3])
        )
        pixels_per_channel += data.size(0) * data.size(2) * data.size(3)

    if pixels_per_channel == 0:
        raise ValueError("cannot compute mean/std: the loader yielded no data")

    mean = channels_sum / pixels_per_channel
    # Rounding can push the variance marginally below zero for (near-)constant
    # channels; clamp so the square root never produces NaN.
    variance = (channels_squared_sum / pixels_per_channel - mean ** 2).clamp(min=0)
    std = variance.sqrt()

    return mean.to(out_dtype), std.to(out_dtype)


# Compute the per-channel statistics over the whole loader and display them,
# one tensor per line.
mean, std = get_mean_std(train_loader)

print(mean, std, sep='\n')