# Classification with FNO
Train an FNO timeseries classifier on the FordA dataset from the UCR/UEA archive.

Much of this comes from the following Keras tutorial: https://keras.io/examples/timeseries/timeseries_classification_from_scratch/

The dataset we are using here is called FordA. The data comes from the UCR archive. The dataset contains 3601 training instances and another 1320 testing instances. Each timeseries corresponds to a measurement of engine noise captured by a motor sensor. For this task, the goal is to automatically detect the presence of a specific issue with the engine. The problem is a balanced binary classification task. The full description of this dataset can be found here: http://www.j-wichard.de/publications/FordPaper.pdf

Future extension: add the features described in the paper, namely the autocorrelation values and the spectral density features, as separate input channels, similar to the work we will do later.

In [5]:
import importlib
import model_utils
import data_utils
importlib.reload(model_utils)
importlib.reload(data_utils)
from model_utils import FNOClassifier
from data_utils import CustomDataset, RandomSample, RandomTimeTranslateFill0, RandomTimeTranslateReflect, RandomNoise

import os
import optuna
import datetime
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, auc

import torch
import torch.utils.data as data
from torch.utils.data import DataLoader
import torchvision.transforms.v2 as transforms

import pytorch_lightning as pl
from lightning import Trainer
from lightning.pytorch.loggers import TensorBoardLogger
from lightning.pytorch.callbacks import EarlyStopping, LearningRateMonitor, ModelCheckpoint
pl.__version__

'2.1.3'

In [6]:
# Check if CUDA is available.
# CUDA_VISIBLE_DEVICES is only honoured if it is set before CUDA is initialised,
# and the torch.cuda queries below can trigger that initialisation, so the
# restriction to GPU 0 has to happen first. setdefault leaves any value that the
# environment (e.g. a job scheduler) already provides untouched.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "0")
if torch.cuda.is_available():
    device = torch.device("cuda")
    print(f"GPU is available on device {torch.cuda.get_device_name(0)} with device count: {torch.cuda.device_count()}")
else:
    device = torch.device("cpu")
    print("GPU is not available")

GPU is available on device NVIDIA GeForce RTX 2080 Ti with device count: 1


## Load the data

In [25]:
# Read the data
def readucr(filename):
    """Load one UCR-archive TSV file into features and integer labels.

    Every row of the file holds the class label in its first column,
    followed by that sample's series values.

    Args:
        filename: Path or URL that ``np.loadtxt`` can read.

    Returns:
        A tuple ``(x, y)``: ``x`` is the 2-D float array of series values
        (one row per sample) and ``y`` holds the labels cast to ``int``.
    """
    # Named `table` rather than `data` so it does not shadow the
    # `torch.utils.data` alias imported at the top of the notebook.
    table = np.loadtxt(filename, delimiter="\t")
    labels, features = table[:, 0], table[:, 1:]
    return features, labels.astype(int)

root_url = "https://raw.githubusercontent.com/hfawaz/cd-diagram/master/FordA/"

x_train, y_train = readucr(root_url + "FordA_TRAIN.tsv")
x_test, y_test = readucr(root_url + "FordA_TEST.tsv")

# Reshape the data to be ready for multivariate time-series data (multiple channels)
# Shape is (samples, channels, sequence length)
x_train = x_train.reshape((x_train.shape[0], 1, x_train.shape[1]))
x_test = x_test.reshape((x_test.shape[0], 1, x_test.shape[1]))
print("x_train shape: ", x_train.shape)
print("x_test shape: ", x_test.shape)

# Standardize the labels to non-negative integers: map -1 -> 0 so the expected
# labels are 0 and 1.
y_train[y_train == -1] = 0
y_test[y_test == -1] = 0

# Count the number of classes
num_classes = len(np.unique(y_train))
print("Number of classes: " + str(num_classes))

# Canonicalize the data: shift every series so its first sample is 0
x_train -= x_train[:, :, 0].reshape(-1, 1, 1)
x_test -= x_test[:, :, 0].reshape(-1, 1, 1)

# Use 20% of training data for validation. The split is sequential, not random:
# the first 80% of the rows are used for training, the remaining rows for validation.
train_set_size = int(len(x_train) * 0.8)
valid_set_size = len(x_train) - train_set_size
print("Training set size: " + str(train_set_size))
print("Validation set size: " + str(valid_set_size))

# Min-max scale with statistics from the training portion ONLY, so the
# validation rows do not leak into the normalization. Training rows end up in
# [0, 1]; validation/test values may fall slightly outside that range.
min_val = np.min(x_train[:train_set_size])
max_val = np.max(x_train[:train_set_size])
x_train = (x_train - min_val) / (max_val - min_val)
x_test = (x_test - min_val) / (max_val - min_val)

# TODO: I'm not sure if this is possible, as I can't find the exact sampling frequency of the data online
# Add bandpass filtering
# lowcut = 0.1
# highcut = 0.3
# fs = 500
# order = 2
# x_train = ButterBandpassFilter(x_train, lowcut, highcut, fs, order)
# x_test = ButterBandpassFilter(x_test, lowcut, highcut, fs, order)

# # split the x_train and y_train set into two
# seed = torch.Generator().manual_seed(42)
# x_train, x_valid = data.random_split(x_train, [train_set_size, valid_set_size], generator=seed)
# y_train, y_valid = data.random_split(y_train, [train_set_size, valid_set_size], generator=seed)

# Don't randomly split the data, sequentially split the data
train_x = torch.tensor(x_train[:train_set_size], dtype=torch.float32)
valid_x = torch.tensor(x_train[train_set_size:], dtype=torch.float32)
# Labels are stored as float32 (consumed by the model's BCELoss)
train_y = torch.tensor(y_train[:train_set_size], dtype=torch.float32)
valid_y = torch.tensor(y_train[train_set_size:], dtype=torch.float32)
test_x = torch.tensor(x_test, dtype=torch.float32)
test_y = torch.tensor(y_test, dtype=torch.float32)

# Print the per-class sample counts in train and valid (unique labels, counts)
print("Train class ratio: ", torch.unique(train_y, return_counts=True))
print("Valid class ratio: ", torch.unique(valid_y, return_counts=True))

x_train shape:  (3601, 1, 500)
x_test shape:  (1320, 1, 500)
Number of classes: 2
Training set size: 2880
Validation set size: 721


In [26]:
# Create train, valid, and test data loaders
batch_size = 64 # too large of a batchsize crashes the kernel (memory issues due to fft and irfft)
workers = 0
data_augmentation = None
n_sample = 400  # Sequence length produced by the "randomsample" augmentation


def _make_train_transform(name):
    """Build the training-time transform for the augmentation option ``name``.

    Args:
        name: None (no augmentation) or one of "randomsample", "randomnoise",
            "randomtimetranslatefill0", "randomtimetranslatereflect",
            "randomnoise_randomtimetranslatefill0",
            "randomnoise_randomtimetranslatereflect".

    Returns:
        The transform to hand to ``CustomDataset``, or None for no augmentation.

    Raises:
        ValueError: If ``name`` is not a supported option. This stops a typo from
            silently training without augmentation while the logged
            hyperparameters claim otherwise.
    """
    if name is None:
        return None
    if name == "randomsample":
        # Can't be used with other transforms as it changes the shape of the data
        return transforms.RandomApply([RandomSample(n_sample=n_sample)], p=1)
    if name == "randomnoise":
        return transforms.RandomApply([RandomNoise(mean=0, std=0.1)], p=0.8)
    if name == "randomtimetranslatefill0":
        return transforms.RandomApply([RandomTimeTranslateFill0(max_shift=100)], p=0.8)
    if name == "randomtimetranslatereflect":
        return transforms.RandomApply([RandomTimeTranslateReflect(max_shift=100)], p=0.8)
    if name == "randomnoise_randomtimetranslatefill0":
        return transforms.Compose([
            transforms.RandomApply([RandomNoise(mean=0, std=0.1)], p=0.5),
            transforms.RandomApply([RandomTimeTranslateFill0(max_shift=100)], p=0.5),
        ])
    if name == "randomnoise_randomtimetranslatereflect":
        return transforms.Compose([
            transforms.RandomApply([RandomNoise(mean=0, std=0.1)], p=0.5),
            transforms.RandomApply([RandomTimeTranslateReflect(max_shift=100)], p=0.5),
        ])
    raise ValueError(f"Unknown data_augmentation option: {name!r}")


train_transform = _make_train_transform(data_augmentation)
# RandomSample yields sequences of length n_sample; every other option keeps the full length
seq_length = n_sample if data_augmentation == "randomsample" else train_x.shape[2]

train_loader = DataLoader(
    CustomDataset(train_x, train_y, transform=train_transform),
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,  # presumably to keep training batch sizes uniform — confirm still needed
    num_workers=workers,
)

# The evaluation loaders keep the final partial batch (drop_last=False) so that
# every validation and test sample is scored. With drop_last=True and the
# current split sizes, 721 % 64 = 17 validation samples and 1320 % 64 = 40 test
# samples were silently skipped.
valid_loader = DataLoader(
    CustomDataset(valid_x, valid_y),
    batch_size=batch_size,
    shuffle=False,
    drop_last=False,
    num_workers=workers,
)

test_loader = DataLoader(
    CustomDataset(test_x, test_y),
    batch_size=batch_size,
    shuffle=False,
    drop_last=False,
    num_workers=workers,
)

# Print the size of a batch and type of data
for x, y in train_loader:
    print("train loader")
    print("Sample batch of data (batch size, # channels, sequence length): " + str(x.shape))
    print("Sample batch of labels: " + str(y.shape))
    break

# Check validation loader too
for x, y in valid_loader:
    print("valid loader")
    print("Sample batch of data (batch size, # channels, sequence length): " + str(x.shape))
    print("Sample batch of labels: " + str(y.shape))
    break

train loader
Sample batch of data (batch size, # channels, sequence length): torch.Size([64, 1, 500])
Sample batch of labels: torch.Size([64])
valid loader
Sample batch of data (batch size, # channels, sequence length): torch.Size([64, 1, 500])
Sample batch of labels: torch.Size([64])


## Train and test a model

In [27]:
# ---- Model hyperparameters ----
modes = 15
channels = [256, 256, 256, 16]
pool_type = "avg"
# Pooling window spans the whole sequence. seq_length already equals n_sample
# when the RandomSample augmentation is selected, so this follows automatically.
pooling = seq_length
proj_dim = 16  # Dimension to project to initially
p_dropout = 0.5
add_noise = False

# ---- Optimizer and learning-rate schedule ----
# lr schedule options are reducelronplateau, steplr, exponentiallr, cosineannealinglr, and cosineannealingwarmrestarts
# optimizer options are sgd or adam
optimizer = "adam"
scheduler = "reducelronplateau"
lr = 1e-2
momentum = 0  # Only used for SGD optimizer

# Gather every constructor argument in one mapping, then build the classifier
classifier_kwargs = dict(
    modes=modes,
    lr=lr,
    channels=channels,
    pooling=pooling,
    optimizer=optimizer,
    scheduler=scheduler,
    momentum=momentum,
    pool_type=pool_type,
    seq_length=seq_length,
    proj_dim=proj_dim,
    p_dropout=p_dropout,
    add_noise=add_noise,
)
classifier = FNOClassifier(**classifier_kwargs)

# Show the model architecture
print(classifier)

FNOClassifier(
  (loss): BCELoss()
  (project): Sequential(
    (0): Linear(in_features=500, out_features=8000, bias=True)
    (1): Dropout(p=0.5, inplace=False)
  )
  (fno_layer_0): Sequential(
    (0): SpectralConv1d(
      (weight): ModuleList(
        (0): ComplexDenseTensor(shape=torch.Size([16, 256, 8]), rank=None)
      )
    )
    (1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (fno_layer_1): Sequential(
    (0): SpectralConv1d(
      (weight): ModuleList(
        (0): ComplexDenseTensor(shape=torch.Size([256, 256, 8]), rank=None)
      )
    )
    (1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (fno_layer_2): Sequential(
    (0): SpectralConv1d(
      (weight): ModuleList(
        (0): ComplexDenseTensor(shape=torch.Size([256, 256, 8]), rank=None)
      )
    )
    (1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (fno_layer_3): Sequential(
    (0):

In [28]:
# Create a tensorboard logger
experiment_name = "fixed_data_split_ford"
save_directory = "../logs/"

# Make sure save_directory/experiment_name exists. exist_ok=True avoids the
# check-then-create race of os.path.exists followed by os.makedirs.
os.makedirs(os.path.join(save_directory, experiment_name), exist_ok=True)

# Each run is logged under its own timestamped version directory
logger = TensorBoardLogger(save_dir=save_directory, name=experiment_name, version=datetime.datetime.now().strftime("%Y_%m_%d-%H_%M_%S"))
logger.log_hyperparams({
    "modes": modes,
    "lr": lr,
    "proj_dim": proj_dim,
    "channels": channels,
    "pool_type": pool_type,
    "pooling": pooling,
    "lr_scheduler": scheduler,
    "batchsize": batch_size,
    "optimizer": optimizer,
    "momentum": momentum,
    "p_dropout": p_dropout,
    "add_noise": add_noise,
    "data_augmentation": data_augmentation,
    "neuralop_or_mine": "neuralop_package"
})
print("Tensorboard logs will be saved to: " + logger.log_dir)

callbacks = [
    # Stop after 20 validation checks without a val_loss improvement
    EarlyStopping(monitor="val_loss", patience=20, mode="min"),
    LearningRateMonitor(logging_interval="step"),
    # Keep the checkpoint with the lowest val_loss. Lightning's default
    # checkpoint callback only tracks the most recent epoch, so a later
    # trainer.test(...) without a model would otherwise evaluate weights from
    # `patience` epochs after the best validation loss.
    ModelCheckpoint(monitor="val_loss", mode="min"),
]

# Train the model
trainer = Trainer(max_epochs=500,
                  logger=logger,
                  callbacks=callbacks,
                  accelerator="auto"
)

trainer.fit(model=classifier,
            train_dataloaders=train_loader,
            val_dataloaders=valid_loader
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name        | Type       | Params | In sizes      | Out sizes    
---------------------------------------------------------------------------
0 | loss        | BCELoss    | 0      | ?             | ?            
1 | project     | Sequential | 4.0 M  | [1, 1, 500]   | [1, 1, 8000] 
2 | fno_layer_0 | Sequential | 66.3 K | [1, 16, 500]  | [1, 256, 500]
3 | fno_layer_1 | Sequential | 1.0 M  | [1, 256, 500] | [1, 256, 500]
4 | fno_layer_2 | Sequential | 1.0 M  | [1, 256, 500] | [1, 256, 500]
5 | fno_layer_3 | Sequential | 65.6 K | [1, 256, 500] | [1, 16, 500] 
6 | dropout     | Dropout    | 0      | [1, 8000]     | [1, 8000]    
7 | fc          | Linear     | 4.0 M  | [1, 8000]     | [1, 500]     
8 | pool        | AvgPool1d  | 0      | [1, 500]      | [1, 1]       
-----------------------------

Tensorboard logs will be saved to: ../logs/fixed_data_split_ford/2024_04_25-03_42_04


SLURM auto-requeueing enabled. Setting signal handlers.


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/home/linneamw/sadow_koastore/personal/linneamw/anaconda3/envs/fno/lib/python3.11/site-packages/lightning/pytorch/loops/fit_loop.py:298: The number of training batches (45) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

/home/linneamw/sadow_koastore/personal/linneamw/anaconda3/envs/fno/lib/python3.11/site-packages/lightning/pytorch/trainer/call.py:54: Detected KeyboardInterrupt, attempting graceful shutdown...


In [None]:
# Test the model on the held-out test_loader.
# NOTE(review): neither a model nor ckpt_path is passed, so Lightning chooses
# which checkpoint from the preceding trainer.fit call is evaluated. Confirm it
# is the best-val_loss checkpoint rather than the last epoch's weights.
trainer.test(dataloaders=test_loader)

In [None]:
# # Get the predictions
# y_pred = []
# y_true = []

# for x, y in test_loader:
#     x = x.to(classifier.device)
#     y = y.to(classifier.device)
#     logits = F.softmax(classifier(x), dim=1)
#     preds = torch.argmax(logits, dim=1)
#     y_pred.extend(logits.cpu().detach().numpy())
#     y_true.extend(y.cpu().detach().numpy())

# # Compute the ROC curve and AUC
# fpr, tpr, _ = roc_curve(y_true, y_pred)
# roc_auc = auc(fpr, tpr)

# # Plot the ROC curve
# plt.figure()
# plt.plot(fpr, tpr, color="darkorange", lw=2, label="ROC curve (area = %0.2f)" % roc_auc)
# plt.plot([0, 1], [0, 1], color="navy", lw=2, linestyle="--")
# plt.xlim([0.0, 1.0])
# plt.ylim([0.0, 1.05])
# plt.xlabel("False Positive Rate")
# plt.ylabel("True Positive Rate")
# plt.title("Receiver Operating Characteristic")
# plt.legend(loc="lower right")


# # Save the figure to the logs directory
# plt.savefig(logger.log_dir + "/roc_curve.png")

## Optuna

In [None]:
# # Create an optuna study to optimize the number of modes, channels, and pooling size
# def objective(trial):
#     # Optimize the number of modes
#     modes = trial.suggest_int("modes", 5, 15)

#     # Optimize the number of fno_layers
#     fno_layers = trial.suggest_int("fno_layers", 2, 5)

#     # Optimize the number of channels
#     channels = [trial.suggest_int(f"n_channels_{i}", 16, 256) for i in range(fno_layers)]

#     # Optimize the pooling size
#     pooling = trial.suggest_int("pooling", 2, 500)

#     # Optimize the learning rate
#     lr = trial.suggest_float("lr", 1e-4, 1e-1, log=True)

#     # Create the model
#     model = FNOClassifier(modes=modes, channels=channels, pooling=pooling)

#     # Create a learning rate scheduler and early stopping callback
#     callbacks = [
#         EarlyStopping(monitor="val_acc", patience=20, mode="max"),
#         LearningRateMonitor(logging_interval="step")
#     ]

#     # Create a tensorboard logger
#     experiment_name = "optuna"
#     logger = TensorBoardLogger(save_dir="logs/", name=experiment_name, version=datetime.datetime.now().strftime("%Y_%m_%d-%H_%M_%S"))

#     logger.log_hyperparams({"modes": modes, "lr": lr, "channels": channels, "pooling": pooling, "lr_scheduler": "ReduceLROnPlateau", "patience": 10, "min_lr": 1e-6, "factor": 0.5, "batchsize": batch_size})

#     # Create a trainer
#     trainer = Trainer(
#         max_epochs=100,
#         logger=logger,
#         callbacks=callbacks,
#         accelerator='auto'
#     )

#     # Train the model
#     trainer.fit(model, train_loader, valid_loader)

#     # Test the model
#     result = trainer.test(dataloaders=test_loader)

#     # Return the validation accuracy
#     return result[0]["test_acc"]

In [None]:
# study = optuna.create_study(direction="maximize")
# study.optimize(objective, n_trials=500)

# # Print the best hyperparameters
# print(study.best_params)