In [1]:
# Reload edited project modules (e.g. the `src.*` imports used below) before
# each cell executes, so changes to library code are picked up without a
# kernel restart.
%load_ext autoreload
%autoreload 2

In [2]:
import sys

# Make the project root ('../../' relative to the current working directory)
# importable so that the `src.*` modules used later resolve. The membership
# guard keeps the cell idempotent: re-running it does not pile up duplicate
# entries on sys.path.
base_path = '../../'
if base_path not in sys.path:
    sys.path.append(base_path)

In [26]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import numpy as np
from tqdm import tqdm

# 3D U-Net with a small parameter count
class SmallUNet3D(nn.Module):
    """Small two-level 3D encoder-decoder for dense (per-voxel) regression.

    Maps an input volume of shape (N, in_channels, D, H, W) to an output of
    shape (N, out_channels, D, H, W) -- the spatial size always matches the
    input, including odd D/H/W.

    Layout: conv -> max-pool -> conv -> trilinear upsample -> conv -> 1x1 conv.
    By default there is NO skip connection, so despite the name this is a
    plain encoder-decoder. Pass ``use_skip=True`` to concatenate the
    full-resolution encoder features into the decoder, as in a U-Net.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of regressed output channels.
        base_filters: Channel width of the first encoder level; the second
            level uses ``2 * base_filters``.
        use_skip: If True, concatenate ``enc1`` features with the upsampled
            bottleneck before ``dec1``. This changes ``dec1``'s weight shape,
            so checkpoints are only interchangeable between models built with
            the same setting.
    """

    def __init__(self, in_channels=1, out_channels=5, base_filters=16, use_skip=False):
        super().__init__()
        self.use_skip = use_skip

        self.enc1 = nn.Conv3d(in_channels, base_filters, kernel_size=3, padding=1)
        self.enc2 = nn.Conv3d(base_filters, base_filters * 2, kernel_size=3, padding=1)

        self.pool = nn.MaxPool3d(2)
        # Kept for backward compatibility (it has no parameters). forward()
        # resizes to an explicit target size instead, so that odd spatial
        # sizes round-trip correctly.
        self.upsample = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True)

        dec1_in = base_filters * 3 if use_skip else base_filters * 2
        self.dec1 = nn.Conv3d(dec1_in, base_filters, kernel_size=3, padding=1)
        self.dec2 = nn.Conv3d(base_filters, out_channels, kernel_size=1)  # Final regression output

        self.act = nn.ReLU()

    def forward(self, x):
        """Run the network; the output has the same spatial size as ``x``."""
        e1 = self.act(self.enc1(x))
        e2 = self.act(self.enc2(self.pool(e1)))

        # MaxPool3d(2) floors odd sizes, so a fixed x2 upsample would come back
        # one voxel short on every odd axis and mismatch the target. Resizing
        # to e1's exact spatial shape fixes that. For even sizes it is
        # identical to scale_factor=2 (align_corners=True uses (in-1)/(out-1)).
        up = nn.functional.interpolate(
            e2, size=e1.shape[2:], mode='trilinear', align_corners=True
        )
        if self.use_skip:
            up = torch.cat([up, e1], dim=1)

        d1 = self.act(self.dec1(up))
        return self.dec2(d1)  # No activation for regression




# Instantiate the model; training happens in a later cell. Seed torch's RNG
# first so the initial weights (and therefore the loss curve) are reproducible.
SEED = 0
torch.manual_seed(SEED)
model = SmallUNet3D()


In [18]:
from src.pyvista_flow_field_dataset import PyvistaFlowFieldDataset
from src.voxel_flow_field_dataset import VoxelFlowFieldDataset, VoxelFlowFieldDatasetConfig
# Build the voxel dataset. The Hugging Face path needs network access (it
# fails with a name-resolution error when offline). Set USE_CACHED_VOXELS to
# reuse voxels from VOXEL_DIR instead.
NUM_SAMPLES = 3
VOXEL_DIR = 'datasets/voxels'
USE_CACHED_VOXELS = False

if USE_CACHED_VOXELS:
    # NOTE(review): assumes VoxelFlowFieldDataset(path) loads a dataset already
    # built in `path` by an earlier run -- confirm against
    # src.voxel_flow_field_dataset.
    ds_voxel = VoxelFlowFieldDataset(VOXEL_DIR)
else:
    ds_pv = PyvistaFlowFieldDataset.load_from_huggingface(num_samples=NUM_SAMPLES)
    ds_voxel = VoxelFlowFieldDataset(VOXEL_DIR, VoxelFlowFieldDatasetConfig(ds_pv))
ds_voxel.normalize()

ValueError: Error getting the list of files in repository 'peteole/CoolMucSmall': (MaxRetryError('HTTPSConnectionPool(host=\'huggingface.co\', port=443): Max retries exceeded with url: /api/datasets/peteole/CoolMucSmall/tree/main?recursive=True&expand=False (Caused by NameResolutionError("<urllib3.connection.HTTPSConnection object at 0x17e203d50>: Failed to resolve \'huggingface.co\' ([Errno 8] nodename nor servname provided, or not known)"))'), '(Request ID: 0a4141a8-33b2-4c66-823f-158fa69a5e0c)')

In [19]:

from torch.utils.data import DataLoader

# Shuffled mini-batches over the voxel dataset's default loadable view.
BATCH_SIZE = 3
train_dataset = ds_voxel.get_default_loadable_dataset()
dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

In [29]:

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
criterion = nn.MSELoss()
lr = 1e-3
epochs = 100
optimizer = optim.Adam(model.parameters(), lr=lr)


def _to_channels_first(x, y, device):
    """Move a batch to `device` as float tensors in Conv3d's (N, C, D, H, W) layout.

    NOTE(review): assumes `x` is a single-channel grid shaped (N, D, H, W) and
    `y` is channels-last, shaped (N, D, H, W, C) -- confirm against
    VoxelFlowFieldDataset.get_default_loadable_dataset().
    """
    x = x.to(device).float().unsqueeze(1)            # (N, D, H, W)    -> (N, 1, D, H, W)
    y = y.to(device).float().permute(0, 4, 1, 2, 3)  # (N, D, H, W, C) -> (N, C, D, H, W)
    return x, y


for epoch in range(epochs):
    model.train()
    epoch_loss = 0.0
    n_seen = 0

    for x, y in tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}"):
        x, y = _to_channels_first(x, y, device)
        optimizer.zero_grad()
        outputs = model(x)
        loss = criterion(outputs, y)
        loss.backward()
        optimizer.step()

        # Weight each batch's mean loss by its size so that a smaller final
        # batch does not skew the epoch average.
        epoch_loss += loss.item() * x.size(0)
        n_seen += x.size(0)

    # max(..., 1) avoids ZeroDivisionError on an empty dataloader
    print(f"Epoch {epoch+1}, Loss: {epoch_loss / max(n_seen, 1)}")

Epoch 1/100: 100%|██████████| 1/1 [00:00<00:00,  2.83it/s]


Epoch 1, Loss: 0.6820264458656311


Epoch 2/100: 100%|██████████| 1/1 [00:00<00:00,  3.29it/s]


Epoch 2, Loss: 2.4455931186676025


Epoch 3/100: 100%|██████████| 1/1 [00:00<00:00,  3.19it/s]


Epoch 3, Loss: 0.819566547870636


Epoch 4/100: 100%|██████████| 1/1 [00:00<00:00,  3.31it/s]


Epoch 4, Loss: 1.006456732749939


Epoch 5/100: 100%|██████████| 1/1 [00:00<00:00,  3.51it/s]


Epoch 5, Loss: 1.4585212469100952


Epoch 6/100: 100%|██████████| 1/1 [00:00<00:00,  3.33it/s]


Epoch 6, Loss: 1.3300859928131104


Epoch 7/100: 100%|██████████| 1/1 [00:00<00:00,  3.46it/s]


Epoch 7, Loss: 1.012536883354187


Epoch 8/100: 100%|██████████| 1/1 [00:00<00:00,  3.51it/s]


Epoch 8, Loss: 0.795452892780304


Epoch 9/100: 100%|██████████| 1/1 [00:00<00:00,  3.49it/s]


Epoch 9, Loss: 0.7420609593391418


Epoch 10/100: 100%|██████████| 1/1 [00:00<00:00,  3.53it/s]


Epoch 10, Loss: 0.7963640093803406


Epoch 11/100: 100%|██████████| 1/1 [00:00<00:00,  3.52it/s]


Epoch 11, Loss: 0.8794823288917542


Epoch 12/100: 100%|██████████| 1/1 [00:00<00:00,  2.97it/s]


Epoch 12, Loss: 0.9254394769668579


Epoch 13/100: 100%|██████████| 1/1 [00:00<00:00,  3.48it/s]


Epoch 13, Loss: 0.9076719880104065


Epoch 14/100: 100%|██████████| 1/1 [00:00<00:00,  3.47it/s]


Epoch 14, Loss: 0.850120484828949


Epoch 15/100: 100%|██████████| 1/1 [00:00<00:00,  3.28it/s]


Epoch 15, Loss: 0.78790682554245


Epoch 16/100: 100%|██████████| 1/1 [00:00<00:00,  3.04it/s]


Epoch 16, Loss: 0.7477678656578064


Epoch 17/100: 100%|██████████| 1/1 [00:00<00:00,  2.27it/s]


Epoch 17, Loss: 0.7383223176002502


Epoch 18/100: 100%|██████████| 1/1 [00:00<00:00,  1.87it/s]


Epoch 18, Loss: 0.7528688907623291


Epoch 19/100: 100%|██████████| 1/1 [00:00<00:00,  2.30it/s]


Epoch 19, Loss: 0.7767745852470398


Epoch 20/100: 100%|██████████| 1/1 [00:00<00:00,  2.06it/s]


Epoch 20, Loss: 0.7953836917877197


Epoch 21/100: 100%|██████████| 1/1 [00:00<00:00,  2.05it/s]


Epoch 21, Loss: 0.8002187013626099


Epoch 22/100: 100%|██████████| 1/1 [00:00<00:00,  2.09it/s]


Epoch 22, Loss: 0.7906236052513123


Epoch 23/100: 100%|██████████| 1/1 [00:00<00:00,  1.76it/s]


Epoch 23, Loss: 0.7718489766120911


Epoch 24/100: 100%|██████████| 1/1 [00:00<00:00,  2.16it/s]


Epoch 24, Loss: 0.7515672445297241


Epoch 25/100: 100%|██████████| 1/1 [00:00<00:00,  2.08it/s]


Epoch 25, Loss: 0.7364195585250854


Epoch 26/100: 100%|██████████| 1/1 [00:00<00:00,  2.35it/s]


Epoch 26, Loss: 0.7298973202705383


Epoch 27/100: 100%|██████████| 1/1 [00:00<00:00,  2.11it/s]


Epoch 27, Loss: 0.7317894697189331


Epoch 28/100: 100%|██████████| 1/1 [00:00<00:00,  2.03it/s]


Epoch 28, Loss: 0.738735556602478


Epoch 29/100: 100%|██████████| 1/1 [00:00<00:00,  2.03it/s]


Epoch 29, Loss: 0.7462075352668762


Epoch 30/100: 100%|██████████| 1/1 [00:00<00:00,  2.17it/s]


Epoch 30, Loss: 0.7501504421234131


Epoch 31/100: 100%|██████████| 1/1 [00:00<00:00,  2.09it/s]


Epoch 31, Loss: 0.748520016670227


Epoch 32/100: 100%|██████████| 1/1 [00:00<00:00,  2.12it/s]


Epoch 32, Loss: 0.7417916655540466


Epoch 33/100: 100%|██████████| 1/1 [00:00<00:00,  2.29it/s]


Epoch 33, Loss: 0.7325006723403931


Epoch 34/100: 100%|██████████| 1/1 [00:00<00:00,  1.77it/s]


Epoch 34, Loss: 0.7239240407943726


Epoch 35/100: 100%|██████████| 1/1 [00:00<00:00,  2.07it/s]


Epoch 35, Loss: 0.7188690304756165


Epoch 36/100: 100%|██████████| 1/1 [00:00<00:00,  1.90it/s]


Epoch 36, Loss: 0.7180801630020142


Epoch 37/100: 100%|██████████| 1/1 [00:00<00:00,  1.97it/s]


Epoch 37, Loss: 0.7202916741371155


Epoch 38/100: 100%|██████████| 1/1 [00:00<00:00,  2.10it/s]


Epoch 38, Loss: 0.7231580018997192


Epoch 39/100: 100%|██████████| 1/1 [00:00<00:00,  2.17it/s]


Epoch 39, Loss: 0.7242196202278137


Epoch 40/100: 100%|██████████| 1/1 [00:00<00:00,  2.17it/s]


Epoch 40, Loss: 0.7223427891731262


Epoch 41/100: 100%|██████████| 1/1 [00:00<00:00,  2.08it/s]


Epoch 41, Loss: 0.7181198000907898


Epoch 42/100: 100%|██████████| 1/1 [00:00<00:00,  2.15it/s]


Epoch 42, Loss: 0.713248074054718


Epoch 43/100: 100%|██████████| 1/1 [00:00<00:00,  1.82it/s]


Epoch 43, Loss: 0.7094648480415344


Epoch 44/100: 100%|██████████| 1/1 [00:00<00:00,  1.78it/s]


Epoch 44, Loss: 0.7076796293258667


Epoch 45/100: 100%|██████████| 1/1 [00:00<00:00,  2.18it/s]


Epoch 45, Loss: 0.7075963616371155


Epoch 46/100: 100%|██████████| 1/1 [00:00<00:00,  2.33it/s]


Epoch 46, Loss: 0.708216667175293


Epoch 47/100: 100%|██████████| 1/1 [00:00<00:00,  2.10it/s]


Epoch 47, Loss: 0.7084781527519226


Epoch 48/100: 100%|██████████| 1/1 [00:00<00:00,  2.39it/s]


Epoch 48, Loss: 0.7077293395996094


Epoch 49/100: 100%|██████████| 1/1 [00:00<00:00,  1.90it/s]


Epoch 49, Loss: 0.7059589624404907


Epoch 50/100: 100%|██████████| 1/1 [00:00<00:00,  2.34it/s]


Epoch 50, Loss: 0.7037123441696167


Epoch 51/100: 100%|██████████| 1/1 [00:00<00:00,  2.00it/s]


Epoch 51, Loss: 0.7017093300819397


Epoch 52/100: 100%|██████████| 1/1 [00:00<00:00,  1.88it/s]


Epoch 52, Loss: 0.7004512548446655


Epoch 53/100: 100%|██████████| 1/1 [00:00<00:00,  1.73it/s]


Epoch 53, Loss: 0.6999898552894592


Epoch 54/100: 100%|██████████| 1/1 [00:00<00:00,  1.97it/s]


Epoch 54, Loss: 0.6999602913856506


Epoch 55/100: 100%|██████████| 1/1 [00:00<00:00,  1.74it/s]


Epoch 55, Loss: 0.6998646855354309


Epoch 56/100: 100%|██████████| 1/1 [00:00<00:00,  1.83it/s]


Epoch 56, Loss: 0.6993200182914734


Epoch 57/100: 100%|██████████| 1/1 [00:00<00:00,  1.88it/s]


Epoch 57, Loss: 0.6982825994491577


Epoch 58/100: 100%|██████████| 1/1 [00:00<00:00,  2.14it/s]


Epoch 58, Loss: 0.6970713138580322


Epoch 59/100: 100%|██████████| 1/1 [00:00<00:00,  2.18it/s]


Epoch 59, Loss: 0.6960994601249695


Epoch 60/100: 100%|██████████| 1/1 [00:00<00:00,  2.23it/s]


Epoch 60, Loss: 0.6955845355987549


Epoch 61/100: 100%|██████████| 1/1 [00:00<00:00,  2.14it/s]


Epoch 61, Loss: 0.6954137086868286


Epoch 62/100: 100%|██████████| 1/1 [00:00<00:00,  2.34it/s]


Epoch 62, Loss: 0.695264458656311


Epoch 63/100: 100%|██████████| 1/1 [00:00<00:00,  2.03it/s]


Epoch 63, Loss: 0.6948596835136414


Epoch 64/100: 100%|██████████| 1/1 [00:00<00:00,  2.26it/s]


Epoch 64, Loss: 0.6941652894020081


Epoch 65/100: 100%|██████████| 1/1 [00:00<00:00,  2.16it/s]


Epoch 65, Loss: 0.6933809518814087


Epoch 66/100: 100%|██████████| 1/1 [00:00<00:00,  2.18it/s]


Epoch 66, Loss: 0.6927607655525208


Epoch 67/100: 100%|██████████| 1/1 [00:00<00:00,  1.67it/s]


Epoch 67, Loss: 0.6924075484275818


Epoch 68/100: 100%|██████████| 1/1 [00:00<00:00,  2.07it/s]


Epoch 68, Loss: 0.6922082901000977


Epoch 69/100: 100%|██████████| 1/1 [00:00<00:00,  2.29it/s]


Epoch 69, Loss: 0.6919369101524353


Epoch 70/100: 100%|██████████| 1/1 [00:00<00:00,  1.90it/s]


Epoch 70, Loss: 0.6914701461791992


Epoch 71/100: 100%|██████████| 1/1 [00:00<00:00,  2.01it/s]


Epoch 71, Loss: 0.6908888220787048


Epoch 72/100: 100%|██████████| 1/1 [00:00<00:00,  1.79it/s]


Epoch 72, Loss: 0.6903788447380066


Epoch 73/100: 100%|██████████| 1/1 [00:00<00:00,  1.79it/s]


Epoch 73, Loss: 0.6900505423545837


Epoch 74/100: 100%|██████████| 1/1 [00:00<00:00,  1.89it/s]


Epoch 74, Loss: 0.689832866191864


Epoch 75/100: 100%|██████████| 1/1 [00:00<00:00,  1.83it/s]


Epoch 75, Loss: 0.6895773410797119


Epoch 76/100: 100%|██████████| 1/1 [00:00<00:00,  1.92it/s]


Epoch 76, Loss: 0.6892004013061523


Epoch 77/100: 100%|██████████| 1/1 [00:00<00:00,  1.73it/s]


Epoch 77, Loss: 0.6887378692626953


Epoch 78/100: 100%|██████████| 1/1 [00:00<00:00,  1.57it/s]


Epoch 78, Loss: 0.6882916688919067


Epoch 79/100: 100%|██████████| 1/1 [00:00<00:00,  1.84it/s]


Epoch 79, Loss: 0.687930703163147


Epoch 80/100: 100%|██████████| 1/1 [00:00<00:00,  1.90it/s]


Epoch 80, Loss: 0.6876388788223267


Epoch 81/100: 100%|██████████| 1/1 [00:00<00:00,  1.93it/s]


Epoch 81, Loss: 0.6873493790626526


Epoch 82/100: 100%|██████████| 1/1 [00:00<00:00,  1.99it/s]


Epoch 82, Loss: 0.6870127320289612


Epoch 83/100: 100%|██████████| 1/1 [00:00<00:00,  2.39it/s]


Epoch 83, Loss: 0.6866310834884644


Epoch 84/100: 100%|██████████| 1/1 [00:00<00:00,  1.97it/s]


Epoch 84, Loss: 0.6862481236457825


Epoch 85/100: 100%|██████████| 1/1 [00:00<00:00,  1.95it/s]


Epoch 85, Loss: 0.6859066486358643


Epoch 86/100: 100%|██████████| 1/1 [00:00<00:00,  2.00it/s]


Epoch 86, Loss: 0.6856057643890381


Epoch 87/100: 100%|██████████| 1/1 [00:00<00:00,  2.19it/s]


Epoch 87, Loss: 0.6853033900260925


Epoch 88/100: 100%|██████████| 1/1 [00:00<00:00,  2.15it/s]


Epoch 88, Loss: 0.6849656701087952


Epoch 89/100: 100%|██████████| 1/1 [00:00<00:00,  2.11it/s]


Epoch 89, Loss: 0.6845988035202026


Epoch 90/100: 100%|██████████| 1/1 [00:00<00:00,  1.84it/s]


Epoch 90, Loss: 0.6842367649078369


Epoch 91/100: 100%|██████████| 1/1 [00:00<00:00,  2.39it/s]


Epoch 91, Loss: 0.6839020848274231


Epoch 92/100: 100%|██████████| 1/1 [00:00<00:00,  2.07it/s]


Epoch 92, Loss: 0.6835809946060181


Epoch 93/100: 100%|██████████| 1/1 [00:00<00:00,  2.02it/s]


Epoch 93, Loss: 0.6832476258277893


Epoch 94/100: 100%|██████████| 1/1 [00:00<00:00,  2.11it/s]


Epoch 94, Loss: 0.6828892230987549


Epoch 95/100: 100%|██████████| 1/1 [00:00<00:00,  2.21it/s]


Epoch 95, Loss: 0.682553231716156


Epoch 96/100: 100%|██████████| 1/1 [00:00<00:00,  1.99it/s]


Epoch 96, Loss: 0.6822530031204224


Epoch 97/100: 100%|██████████| 1/1 [00:00<00:00,  2.09it/s]


Epoch 97, Loss: 0.681964635848999


Epoch 98/100: 100%|██████████| 1/1 [00:00<00:00,  2.11it/s]


Epoch 98, Loss: 0.6816698908805847


Epoch 99/100: 100%|██████████| 1/1 [00:00<00:00,  1.76it/s]


Epoch 99, Loss: 0.6813727021217346


Epoch 100/100: 100%|██████████| 1/1 [00:00<00:00,  1.98it/s]

Epoch 100, Loss: 0.6810876727104187



