In [1]:
import os
import torch
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from airfrans.simulation import Simulation
import numpy as np
import json

class AirfRANSDataset(Dataset):
    """In-memory PyTorch dataset over the AirfRANS simulations of one split.

    Every simulation listed in ``manifest.json`` under the requested
    task/split key is loaded eagerly in ``__init__``. Each one is cached as a
    single float32 tensor with one row per mesh node and the columns
    described by the class-level layout constants below.

    ``__getitem__`` returns ``(x, y, name)``: input features, target fields
    and the simulation name.

    Args:
        root: Directory containing ``manifest.json`` and the simulation data
            read by ``airfrans.simulation.Simulation``.
        task: Task name used to build the manifest key ``f"{task}_{split}"``.
            For ``'scarce'`` the test split is read from ``'full_test'``.
        train: Load the train split if True, otherwise the test split.

    Raises:
        KeyError: if ``manifest.json`` has no entry for the task/split key.
    """

    # Column layout of every cached sample tensor. It mirrors the
    # concatenation order in ``_simulation_to_tensor``. Widths assume the
    # AirfRANS per-node arrays are (N, 2) for position/normals/velocity and
    # (N, 1) for sdf/pressure/nu_t.
    #   0-1   position
    #   2-3   inlet (free-stream) velocity, same vector on every node
    #   4     sdf
    #   5-6   normals
    #   7-8   velocity
    #   9     pressure
    #   10    nu_t
    #   11    surface flag (cached, but not returned by __getitem__)
    INPUT_COLUMNS = slice(0, 7)
    TARGET_COLUMNS = slice(7, 11)
    SURFACE_COLUMN = 11

    def __init__(self, root, task='scarce', train=True):
        self.root = root
        self.task = task
        self.train = train

        # The scarce task is evaluated on the full test set, so its test
        # split is looked up under 'full_test'.
        manifest_task = 'full' if task == 'scarce' and not train else task
        split = 'train' if train else 'test'
        key = f"{manifest_task}_{split}"

        # JSON is UTF-8 by spec; never fall back to the platform default
        # encoding (e.g. cp1252 on Windows).
        with open(os.path.join(root, 'manifest.json'), 'r', encoding='utf-8') as f:
            splits = json.load(f)
        try:
            manifest = splits[key]
        except KeyError:
            raise KeyError(
                f"manifest.json has no split {key!r}; "
                f"available splits: {sorted(splits)}"
            ) from None

        self.names = manifest
        progress = tqdm(manifest, desc=f'Loading AirfRANS ({manifest_task}, {split})')
        self.samples = [
            self._simulation_to_tensor(Simulation(root=root, name=name))
            for name in progress
        ]

    @staticmethod
    def _simulation_to_tensor(sim):
        """Stack one simulation's per-node fields into a float32 tensor.

        The column order must stay in sync with the class-level layout
        constants (``INPUT_COLUMNS``, ``TARGET_COLUMNS``, ``SURFACE_COLUMN``).
        """
        # Free-stream velocity vector (speed * [cos aoa, sin aoa]) broadcast
        # against an array shaped like sdf, so every node gets the same vector.
        # NOTE(review): np.cos/np.sin take radians; assumes
        # Simulation.angle_of_attack is stored in radians -- confirm.
        aoa = sim.angle_of_attack
        inlet_velocity = (
            np.array([np.cos(aoa), np.sin(aoa)]) * sim.inlet_velocity
        ).reshape(1, 2) * np.ones_like(sim.sdf)

        data = np.concatenate([
            sim.position,
            inlet_velocity,
            sim.sdf,
            sim.normals,
            sim.velocity,
            sim.pressure,
            sim.nu_t,
            sim.surface.reshape(-1, 1),
        ], axis=-1)
        return torch.tensor(data, dtype=torch.float32)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(x, y, name)`` for sample ``idx``.

        ``x`` and ``y`` are slices (views) of the cached tensor. Modifying
        them in place therefore also modifies the cached dataset.
        """
        data = self.samples[idx]
        x = data[:, self.INPUT_COLUMNS]
        y = data[:, self.TARGET_COLUMNS]
        return x, y, self.names[idx]

# Dataset location. Set the AIRFRANS_ROOT environment variable to use another
# path without editing this cell; the original path remains the default.
root = os.environ.get("AIRFRANS_ROOT", r"C:\airfran\Dataset")
dataset = AirfRANSDataset(root, task='scarce', train=True)
# batch_size=1: presumably node counts differ between simulations, so the
# default collate could not stack several samples -- keep 1 unless a custom
# collate_fn is supplied.
loader = DataLoader(dataset, batch_size=1, shuffle=True)

# Smoke test: print the name and tensor shapes of one shuffled batch.
for x, y, name in loader:
    print(name, x.shape, y.shape)
    break


Loading AirfRANS (scarce, train): 100%|██████████| 200/200 [00:44<00:00,  4.45it/s]

('airFoil2D_SST_40.175_10.442_1.962_3.054_0.0_6.878',) torch.Size([1, 197386, 7]) torch.Size([1, 197386, 4])



