In [1]:
import os
import sys
import json
import time
import wave
import random
import numpy as np
from typing import Dict, Tuple

import torch
import torchaudio
from torch.utils.data import DataLoader

from torchvision import transforms
from PIL import Image
import PIL


sys.path.append('../training')
from utils import data_utils, utils, audio_utils
from datasets.loader import Dataset

import matplotlib.pyplot as plt
import IPython.display as ipd

In [2]:
# Load the F2V training configuration and wrap it in the project's HParams object.
config_path = '../training/configs/f2v.json'
# JSON text is UTF-8 by specification; be explicit rather than relying on the
# platform's default encoding (which differs e.g. on Windows).
with open(config_path, "r", encoding="utf-8") as f:
    data = f.read()
config = json.loads(data)
args = utils.HParams(**config)

In [3]:
class TMP_Dataset(torch.utils.data.Dataset):
    """Image-only dataset built from per-split meta files.

    For every name in ``img_datasets`` the meta file
    ``<meta_root>/<name>_<mode>.txt`` is read with ``data_utils.load_text``.
    The returned entries are concatenated into ``self.data_files``.
    ``__getitem__`` uses each entry directly as an image path.

    Args:
        args: hyper-parameter object; only ``args.features.image.size`` is read.
        meta_root: directory containing the meta files.
        mode: split name used in the meta file name (e.g. ``'train'``).
        img_datasets: dataset names whose meta files are loaded.
        sample_rate: only used to derive ``self.max_len``; the image pipeline
            in this class does not use it.
    """

    def __init__(self,
                 args,
                 meta_root='filelists',
                 mode='train',
                 img_datasets=['VGG_Face'],
                 sample_rate=16000,
                 ):
        self.args = args
        self.mode = mode
        # Copy so the shared default list (or the caller's list) is never
        # aliased and mutated through this instance.
        self.img_datasets = list(img_datasets)
        self.sample_rate = sample_rate
        self.max_sec = 4
        self.max_len = sample_rate * self.max_sec
        self.data_files = []
        for dset in self.img_datasets:
            files = data_utils.load_text(self._meta_file_path(meta_root, dset, mode))
            self.data_files += files
        self.data_files_len = len(self.data_files)
        img_size = args.features.image.size
        # NOTE(review): resizing to an exact (size, size) square already yields a
        # size x size image, so the CenterCrop below is a no-op and the aspect
        # ratio is not preserved. Kept as-is so the computed statistics do not
        # change; use Resize(img_size) if an aspect-preserving resize + crop
        # was intended.
        self.trans = transforms.Compose([
            transforms.Resize((img_size, img_size), interpolation=PIL.Image.BICUBIC),
            transforms.CenterCrop(img_size),
            transforms.ToTensor(),
        ])

    @staticmethod
    def _meta_file_path(meta_root, dset, mode):
        """Return the meta-file path ``<meta_root>/<dset>_<mode>.txt``."""
        # Format only the file name. Formatting the joined path would also
        # interpret any '{' or '}' contained in meta_root.
        return os.path.join(meta_root, '{}_{}.txt'.format(dset, mode))

    def get_image(self, index):
        """Load the image at ``self.data_files[index]`` as an RGB PIL image.

        Raises:
            IndexError: if ``index`` is out of range. This also lets plain
                ``for x in dataset`` iteration terminate.
        """
        with Image.open(self.data_files[index]) as img:
            # convert() returns a loaded copy, so it stays valid after the file
            # is closed. Forcing RGB guarantees 3 channels even for grayscale,
            # palette or RGBA source files.
            return img.convert('RGB')

    def __getitem__(self, index):
        """Return the transformed image tensor for ``index``."""
        return self.trans(self.get_image(index))

    def __len__(self):
        """Number of image paths loaded from the meta files."""
        return len(self.data_files)


In [4]:
# Training split of VGG_Face; the meta file read is <meta_root>/VGG_Face_train.txt.
tmpset = TMP_Dataset(
    args,
    mode='train',
    meta_root='../training/filelists/VGG_Face',
    img_datasets=["VGG_Face"],
)

In [5]:
def _pooled_channel_std(dataset):
    """Return the per-channel std over all pixels of all (C, H, W) items in ``dataset``."""
    count = 0
    channel_sum = None
    channel_sq_sum = None
    for x in dataset:
        # float64 accumulation avoids precision loss over many pixels.
        arr = x.numpy().astype(np.float64)
        if channel_sum is None:
            channel_sum = np.zeros(arr.shape[0])
            channel_sq_sum = np.zeros(arr.shape[0])
        channel_sum += arr.sum(axis=(1, 2))
        channel_sq_sum += np.square(arr).sum(axis=(1, 2))
        count += arr.shape[1] * arr.shape[2]
    if channel_sum is None:
        raise ValueError("calculate_stdv: dataset is empty")
    mean = channel_sum / count
    # Clamp tiny negative variances caused by floating-point cancellation.
    var = np.maximum(channel_sq_sum / count - np.square(mean), 0.0)
    return tuple(np.sqrt(var))


def calculate_stdv(dataset, pooled=False):
    """Estimate per-channel standard deviations of a dataset of image tensors.

    By default (``pooled=False``) each channel's value is the *mean of the
    per-image standard deviations*. Each image's std is taken over its axes 1
    and 2. This is not the dataset-wide std: it ignores how the per-image means
    vary, so it under-estimates the statistic usually fed to
    ``transforms.Normalize``. Pass ``pooled=True`` to get the std over every
    pixel of every image instead.

    Args:
        dataset: iterable of tensors exposing ``.numpy()``, each indexed as
            (channel, height, width) — e.g. the output of ``ToTensor``.
        pooled: if True, compute the pooled per-channel std over all pixels.

    Returns:
        Tuple with one std per channel, i.e. ``(std_r, std_g, std_b)`` for RGB.

    Raises:
        ValueError: if ``dataset`` yields no items.
    """
    if pooled:
        return _pooled_channel_std(dataset)
    # Per-image std over axes 1 and 2 -> shape (num_images, num_channels).
    per_image_std = np.array([np.std(x.numpy(), axis=(1, 2)) for x in dataset])
    if len(per_image_std) == 0:
        raise ValueError("calculate_stdv: dataset is empty")
    # Average each channel column on its own (same reduction as the original
    # per-channel code, so RGB results are bit-identical).
    return tuple(per_image_std[:, c].mean() for c in range(per_image_std.shape[1]))

In [6]:
# Per-channel (R, G, B) stds, averaged over the per-image stds of the training split.
calculate_stdv(dataset=tmpset)

(0.26429313, 0.23981343, 0.23351036)