In [1]:
import torch
from torch.utils.data import DataLoader
from neuralop.models import GINO
from tqdm import tqdm
from cfd_dataset import *
from tensorly import tenalg
tenalg.set_backend("einsum")

def train_gino(model, train_loader, val_loader, input_geom, latent_queries, num_epochs, learning_rate, device):
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    criterion = torch.nn.MSELoss()
    
    for epoch in range(num_epochs):
        model.train()
        train_loss = 0
        print('here')

        for x, y, output_queries in tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}'):
        # for (x, y, output_queries) in enumerate(train_loader):
            x = x.squeeze(0)
            y = y.squeeze(0)
            output_queries = output_queries.squeeze(0)
            x, y = x.to(device), y.to(device)
            output_queries = output_queries.to(device)
            
            x.requires_grad_(True)  # Enable gradient computation for x
            
            # print(f"x shape: {x.shape}")
            # print(f"y shape: {y.shape}")
            # print(f"input_geom shape: {input_geom.shape}")
            # print(f"latent_queries shape: {latent_queries.shape}")
            # print(f"output_queries shape: {output_queries.shape}")

            optimizer.zero_grad()
            ada_in = torch.randn(1, device=device)  # Random ada_in for each batch
            output = model(x, input_geom, latent_queries, output_queries, ada_in=ada_in)
            loss = criterion(output, y)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()

            # Check if all parameters have gradients
            n_unused_params = sum(1 for param in model.parameters() if param.grad is None)
            assert n_unused_params == 0, f"{n_unused_params} parameters were unused!"

            # Check if output is finite
            assert output.isfinite().all(), "Output contains non-finite values!"

            if x.shape[0] > 1:
                # assert x[1:] accumulates no grad
                assert not x.grad[1:].nonzero().any()

        train_loss /= len(train_loader)

        # Validation
        model.eval()
        val_loss = 0
        with torch.no_grad():
            for x, y, output_queries in val_loader:
                x = x.squeeze(0)
                y = y.squeeze(0)
                output_queries = output_queries.squeeze(0)
                x, y = x.to(device), y.to(device)
                output_queries = output_queries.to(device)
                ada_in = torch.randn(1, device=device)
                output = model(x, input_geom, latent_queries, output_queries, ada_in=ada_in)
                loss = criterion(output, y)
                val_loss += loss.item()

        val_loss /= len(val_loader)

        print(f'Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}')

    return model

# def main():
# Hyperparameters
batch_size = 1
num_epochs = 10
learning_rate = 1e-3
num_train_samples = 250  # Adjust based on your dataset
num_val_samples = 50     # Adjust based on your dataset

# Create datasets
train_dataset, val_dataset = create_datasets('data/', num_train_samples=250, num_val_samples=50, shuffle=True, seed=42)


# Wrap datasets to only return x, y, and output_queries
train_dataset_wrapped = CFDDatasetWrapper(train_dataset)
val_dataset_wrapped = CFDDatasetWrapper(val_dataset)

# Create data loaders
train_loader = DataLoader(train_dataset_wrapped, batch_size=batch_size, shuffle=True, num_workers=1)
val_loader = DataLoader(val_dataset_wrapped, batch_size=batch_size, shuffle=False, num_workers=1)

# Get input_geom and latent_queries from the first item of the dataset
_, _, input_geom, latent_queries, _ = train_dataset[0]

# Initialize GINO model
model = GINO(
    in_channels=1,  # 1 for pressure
    out_channels=1,  # 1 for smoke concentration
    gno_coord_dim=3,
    gno_radius=0.3,  # Added from test case
    projection_channels=16,  # Added from test case
    in_gno_mlp_hidden_layers=[16, 16],  # Added from test case
    out_gno_mlp_hidden_layers=[16, 16],  # Added from test case
    in_gno_transform_type="nonlinear",  # Added from test case
    out_gno_transform_type="nonlinear",  # Added from test case
    fno_n_modes=(16, 16, 16),
    fno_hidden_channels=64,
    fno_lifting_channels=16,  # Added from test case
    fno_projection_channels=16,  # Added from test case
    fno_norm="ada_in",  # Added from test case
)

# Train the model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Move input_geom and latent_queries to the device
input_geom = input_geom.to(device)
latent_queries = latent_queries.to(device)
    
trained_model = train_gino(model, train_loader, val_loader, input_geom, latent_queries, num_epochs, learning_rate, device)

# Save the trained model
torch.save(trained_model.state_dict(), 'gino_cfd_model.pth')

# if __name__ == '__main__':
#     main()

Using device: cpu
here


Epoch 1/10: 100%|██████████| 250/250 [01:01<00:00,  4.04it/s]


Epoch 1/10, Train Loss: 0.4203, Val Loss: 0.0932
here


Epoch 2/10: 100%|██████████| 250/250 [01:00<00:00,  4.13it/s]


Epoch 2/10, Train Loss: 0.0937, Val Loss: 0.0452
here


Epoch 3/10: 100%|██████████| 250/250 [00:58<00:00,  4.25it/s]


Epoch 3/10, Train Loss: 0.0395, Val Loss: 0.0233
here


Epoch 4/10: 100%|██████████| 250/250 [00:57<00:00,  4.37it/s]


Epoch 4/10, Train Loss: 0.0298, Val Loss: 0.0150
here


Epoch 5/10: 100%|██████████| 250/250 [00:57<00:00,  4.32it/s]


Epoch 5/10, Train Loss: 0.0211, Val Loss: 0.0244
here


Epoch 6/10: 100%|██████████| 250/250 [00:58<00:00,  4.30it/s]


Epoch 6/10, Train Loss: 0.0270, Val Loss: 0.0145
here


Epoch 7/10: 100%|██████████| 250/250 [01:00<00:00,  4.14it/s]


Epoch 7/10, Train Loss: 0.0140, Val Loss: 0.0092
here


Epoch 8/10: 100%|██████████| 250/250 [00:56<00:00,  4.43it/s]


Epoch 8/10, Train Loss: 0.0239, Val Loss: 0.0108
here


Epoch 9/10: 100%|██████████| 250/250 [00:55<00:00,  4.48it/s]


Epoch 9/10, Train Loss: 0.0277, Val Loss: 0.0127
here


Epoch 10/10: 100%|██████████| 250/250 [00:57<00:00,  4.32it/s]


Epoch 10/10, Train Loss: 0.0089, Val Loss: 0.0202


In [2]:
dataset = CFDDataset('data/', num_train_samples)
velocity = dataset[0][0]  # Try to load the first item


TypeError: 'int' object is not iterable