In [1]:
import numpy as np

import torch
import torch.nn as nn
import torch.nn.init as init
import torch.optim as optim

from torch.utils.data import Dataset, DataLoader

from tqdm import tqdm

from src.dataloaders import load_data_dict_from_yaml
from src.preprocesses import VTFPreprocessor, InfodrawPreprocessor, TargetPreprocessor, ImagePreprocessor

In [None]:
def preprocess_fpath_target_data(input_yaml, output_npz, threshold=0.99):
    """Extract per-pixel VTF feature rows plus a target column into one ``.npz`` file.

    For every sample listed in ``input_yaml``, a pixel ``(h, w)`` is selected when
    ``infodraw[0, h, w] < threshold``. Each selected pixel contributes one row:
    the VTF values across the leading axis (``vtf[:, h, w]``) followed by the
    target value ``target[0, h, w]`` as the last column. Rows are emitted per
    sample in row-major pixel order (h outer, w inner), and samples are emitted
    in manifest order.

    Args:
        input_yaml: Path passed to ``load_data_dict_from_yaml``. Each entry must
            provide ``'vtf'``, ``'infodraw'`` and ``'target'`` keys.
        output_npz: Destination passed to ``np.savez``. The matrix is stored
            positionally, i.e. under the default key ``'arr_0'``.
        threshold: Pixels whose first infodraw channel is strictly below this
            value are kept. The default 0.99 is the value that was previously
            hard-coded; its origin is undocumented (it was flagged as a magic
            number).

    Raises:
        ValueError: If the manifest lists no samples.
    """
    yaml_data = load_data_dict_from_yaml(input_yaml)

    feature_chunks, target_chunks = [], []
    for data in tqdm(yaml_data):
        # np.asarray is a no-op for numpy arrays and also accepts CPU tensors.
        # NOTE(review): confirm what the preprocessors actually return.
        vtf = np.asarray(VTFPreprocessor.get(data['vtf']))
        infodraw = np.asarray(InfodrawPreprocessor.get(data['infodraw']))
        target = np.asarray(TargetPreprocessor.get(data['target']))

        # Vectorized replacement for the per-pixel double loop. Boolean indexing
        # visits True positions in row-major (h, then w) order, so row order is
        # unchanged. Assumes vtf/infodraw/target share the same spatial (H, W)
        # size — TODO confirm against the preprocessors.
        mask = infodraw[0] < threshold
        feature_chunks.append(vtf[:, mask].T)   # (n_selected, n_vtf_channels)
        target_chunks.append(target[0][mask])   # (n_selected,)

    if not feature_chunks:
        raise ValueError(f"no samples listed in {input_yaml!r}")

    np_fpath_list = np.concatenate(feature_chunks, axis=0)
    # Make the target a column vector so it can be appended as the last column.
    np_target_list = np.concatenate(target_chunks, axis=0)[:, np.newaxis]
    print(f"total nums: {len(np_fpath_list)}")

    np_data = np.concatenate([np_fpath_list, np_target_list], axis=1)
    np.savez(output_npz, np_data)

In [None]:
# Build one .npz feature matrix per dataset split; each split's YAML manifest
# and output file share the same stem under dataset/.
for split_name in ("val", "test", "train_small"):
    preprocess_fpath_target_data(
        input_yaml=f"dataset/{split_name}.yaml",
        output_npz=f"dataset/{split_name}.npz",
    )

In [2]:
# Load the preprocessed training matrix. It was written positionally with
# np.savez, so it is stored under the default key "arr_0". The context manager
# closes the NpzFile handle once the array has been read into memory. The matrix
# is numeric, so pickle support is not needed; allow_pickle=True would permit
# arbitrary code execution from a tampered file.
with np.load("dataset/train_small.npz") as npz_file:
    data = npz_file["arr_0"]

In [3]:
# (rows, columns): one row per selected pixel; the last column is the target
# value and the preceding columns are the VTF features.
np.shape(data)

(46859895, 22)