In [None]:
### Build the torch dataset ###
# TensorDataset pairs two tensors row by row: the first two columns of `data`
# become the inputs (X), every remaining column becomes the target (Y).
dataset = TensorDataset(data[:, :2], data[:, 2:])

# Sanity check: shapes of the first (X, Y) sample
first_X, first_Y = dataset[0]
print(first_X.shape)
print(first_Y.shape)

### Split sizes: 70 % of the observations for training, the rest for testing ###
n_total = len(dataset)
train_size = int(0.7 * n_total)
test_size = n_total - train_size
print(f"Train size: {train_size}, Test size: {test_size}")

### Random split ###
# NOTE(review): no explicit generator is passed, so the split follows the
# global torch RNG state - confirm a seed is set earlier in the notebook,
# otherwise train/test membership changes from run to run.
train_dataset, test_dataset = random_split(dataset, [train_size, test_size])

### DataLoaders ###
# Only the loaders used for fitting are shuffled; the test loader keeps a
# fixed order. `all_loader` iterates over the full, unsplit dataset.
train_loader = DataLoader(train_dataset, batch_size = 1024, shuffle = True)
test_loader = DataLoader(test_dataset, batch_size = 1024, shuffle = False)
all_loader = DataLoader(dataset, batch_size = 1024, shuffle = True)

In [None]:
# ----- Training loop -----
# Runs only when RETRAIN is set: fits a fresh HelmholtzResNN on train_loader
# and records the sample-weighted mean MSE on the train and test splits after
# every epoch (kept in train_losses / test_losses for saving and plotting).
if RETRAIN:
    model = HelmholtzResNN().to(device)
    loss_function = nn.MSELoss()
    # lr = 5e-3 gave a somewhat jittery loss curve, hence 1e-3
    optim = torch.optim.AdamW(model.parameters(), lr = 1e-3, weight_decay = 1e-7)

    # earlier runs used 10 epochs
    epochs = 30

    # per-epoch loss history
    train_losses, test_losses = [], []

    # dataset sizes are constant, so look them up once
    n_train_samples = len(train_loader.dataset)
    n_test_samples = len(test_loader.dataset)

    for ep in range(1, epochs + 1):
        # ------------------ TRAIN ------------------
        model.train()
        train_loss_sum = 0.0

        # Parameter updates happen on the training split only.
        # NOTE(review): unlike the eval pass below, X_batch is fed without
        # requires_grad here - presumably the model copes with that in train
        # mode; confirm against HelmholtzResNN.
        for X_batch, Y_batch in train_loader:
            X_batch, Y_batch = X_batch.to(device), Y_batch.to(device)

            optim.zero_grad(set_to_none = True)
            Y_hat = model(X_batch)
            loss = loss_function(Y_hat, Y_batch)
            loss.backward()
            # gradient-norm clipping (max norm 1.0) is currently switched off
            optim.step()

            # the loss is a batch mean; weight it by the batch size so that
            # batches of unequal size contribute proportionally
            train_loss_sum += X_batch.size(0) * loss.item()

        # ------------------ EVAL ------------------
        model.eval()
        test_loss_sum = 0.0

        for X_batch, Y_batch in test_loader:
            X_batch, Y_batch = X_batch.to(device), Y_batch.to(device)

            # The model uses autograd internally and needs an input that
            # requires grad - presumably why this pass is not wrapped in
            # torch.no_grad().
            X_batch_grad = X_batch.detach().clone().requires_grad_(True)

            Y_hat = model(X_batch_grad)
            loss = loss_function(Y_hat, Y_batch)
            test_loss_sum += X_batch.size(0) * loss.item()

        # ------------------ BOOKKEEPING ------------------
        epoch_avg_train_loss = train_loss_sum / n_train_samples
        epoch_avg_test_loss = test_loss_sum / n_test_samples

        # keep the history for plotting later
        train_losses.append(epoch_avg_train_loss)
        test_losses.append(epoch_avg_test_loss)

        # One line per epoch: ~ 1.5 min per epoch
        # (~14 min for 10 epochs on GPU)
        print(f"[epoch {ep:03d}] train_loss = {epoch_avg_train_loss:.6f} | test_loss = {epoch_avg_test_loss:.6f}")

In [None]:
if RETRAIN:
    # Persist the freshly trained weights and the per-epoch loss history.
    # Local import keeps this cell runnable without touching the import cell.
    import os

    # Create the output directory up front: torch.save and to_csv do not
    # create missing parent directories, and failing here would only surface
    # after the whole training run has finished.
    os.makedirs("trained_model", exist_ok = True)

    # Only the state_dict is stored, not the full module object.
    torch.save(model.state_dict(), "trained_model/helmholtz_resnn.pth")
    # One row per epoch, columns train_loss / test_loss, no index column.
    pd.DataFrame({'train_loss': train_losses, 'test_loss': test_losses}).to_csv("trained_model/helmholtz_resnn_loss.csv", index=False)

In [None]:
if RETRAIN:
    # Loss curves for the run that just finished, drawn from the
    # train_losses / test_losses lists filled by the training loop.
    # NOTE: this rebinds `epochs` from the epoch count (an int in the
    # training cell) to the x-axis range 1..N.
    epochs = range(1, len(train_losses) + 1)

    fig, ax = plt.subplots(figsize = (8, 6))
    ax.plot(epochs, train_losses, label = "Train Loss", marker = "o")
    ax.plot(epochs, test_losses, label = "Test Loss", marker = "s")

    ax.set(xlabel = "Epoch", ylabel = "Loss", title = "Training and Test Loss")
    ax.legend()

    ax.grid(True, linestyle = "--", alpha = 0.6)
    # fixed y-range: anything above 0.5 falls outside the visible area
    ax.set_ylim(0, 0.5)
    plt.show()