In [1]:
import os

# Move from the notebooks folder to the project root so that relative paths
# ("data/raw/...") and `src.*` imports resolve. Guarded on the presence of the
# `src/` package so re-running this cell does not keep climbing the directory tree.
if not os.path.isdir("src"):
    os.chdir("..")
    print(f"Changed working directory to: {os.getcwd()}")
else:
    print(f"Already at project root: {os.getcwd()}")

Changed working directory to: /home/jovyan/work/FlareSense


In [2]:
import itertools
import zipfile

import torch
import mlflow
import dagshub
import torchmetrics
import pytorch_lightning as pl
from huggingface_hub import snapshot_download
from torchvision import transforms
from tqdm.notebook import tqdm

import src.utils.data15min as data
from src.models.ResNet50BinaryClassifier import ResNet50BinaryClassifier

# Permit faster reduced-precision (e.g. TF32) float32 matmuls where the hardware supports it.
torch.set_float32_matmul_precision("high")
# Have MLflow automatically record params/metrics/model artifacts for Lightning training runs.
mlflow.pytorch.autolog()

2023-12-27 13:36:18.462399: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-12-27 13:36:18.529158: I tensorflow/core/util/util.cc:169] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2023-12-27 13:36:18.544633: E tensorflow/stream_executor/cuda/cuda_blas.cc:2981] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2023-12-27 13:36:18.898128: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvi

In [3]:
# Local folder the dataset snapshot is downloaded into and extracted under.
DATA_FOLDER_PATH = "data/raw/exported/"

# Instrument archives to fetch from the dataset repository.
INSTRUMENTS = [
    "Australia-ASSA_02",
    "Australia-ASSA_62",
]

In [4]:
# Fetch the selected instrument archives plus the shared metadata file from the Hub.
snapshot_download(
    "StellarMilk/ecallisto-bursts",
    repo_type="dataset",
    allow_patterns=[f"{instrument}.zip" for instrument in INSTRUMENTS] + ["metadata.csv"],
    local_dir=DATA_FOLDER_PATH,
    revision="main",
)

for instrument in INSTRUMENTS:
    extracted_dir = os.path.join(DATA_FOLDER_PATH, instrument)
    # An existing folder is taken as "already extracted".
    # NOTE(review): assumes each archive unpacks into a top-level folder named after the
    # instrument; otherwise this check never matches and the archive is re-extracted every run.
    if os.path.exists(extracted_dir):
        print(f"Skipping {instrument}")
        continue

    print(f"Unzipping {instrument}")
    # Pure-Python extraction: raises on a missing/corrupt archive instead of the shell
    # `!unzip`, whose non-zero exit status was silently ignored by the notebook.
    archive_path = os.path.join(DATA_FOLDER_PATH, f"{instrument}.zip")
    with zipfile.ZipFile(archive_path) as archive:
        archive.extractall(DATA_FOLDER_PATH)

Fetching 3 files:   0%|          | 0/3 [00:00<?, ?it/s]

Skipping Australia-ASSA_02
Skipping Australia-ASSA_62


In [5]:
def make_data_module(use_augmented_data: bool):
    """Build and set up an ``ECallistoDataModule`` with this notebook's shared settings.

    Args:
        use_augmented_data: Forwarded unchanged to ``ECallistoDataModule(use_augmented_data=...)``.

    Returns:
        The data module, after ``setup()`` has been called on it.
    """
    module = data.ECallistoDataModule(
        data_folder=DATA_FOLDER_PATH,
        batch_size=64,
        num_workers=12,
        val_ratio=0.15,
        test_ratio=0.15,
        img_size=(224, 224),
        use_augmented_data=use_augmented_data,
        # Filter on the same instruments that were downloaded, instead of a hard-coded copy
        # that could drift from the config cell. A fresh list per module avoids aliasing.
        filter_instruments=list(INSTRUMENTS),
        seed=0,
    )
    module.setup()
    return module


data_module = make_data_module(use_augmented_data=False)
data_module_aug = make_data_module(use_augmented_data=True)

In [6]:
# (heading, data module, split attribute) for each class-count table, in display order.
# Validation and test counts are read from the non-augmented module only.
summaries = [
    ("Train dataset without augmented data:", data_module, "train_dataset"),
    ("Train dataset with augmented data:", data_module_aug, "train_dataset"),
    ("Validation dataset:", data_module, "val_dataset"),
    ("Test dataset:", data_module, "test_dataset"),
]
last_index = len(summaries) - 1

for position, (heading, split_module, split_attr) in enumerate(summaries):
    print(heading)
    counts_text = getattr(split_module, split_attr).metadata.type.value_counts().to_string(header=False)
    # All tables but the last are followed by a blank separator line.
    if position == last_index:
        print(counts_text)
    else:
        print(counts_text, "\n")

Train dataset without augmented data:
no_burst    29967
III          1569
VI             54
II             28
V              28
IV              5
I               2 

Train dataset with augmented data:
no_burst    29967
III         18277
VI            333
V             324
II            253
IV             36
I              24 

Validation dataset:
no_burst    6422
III          329
VI            13
II            10
V              6
I              1
IV             1 

Test dataset:
no_burst    6369
III          385
VI            15
II             6
V              5
IV             2


In [None]:
# Grid search over optimiser hyper-parameters, with and without augmented training data.
# NOTE(review): lr / weight_decay are not logged explicitly here; presumably mlflow autolog
# records them from the optimizer's defaults — verify in the DagsHub UI.
lst_lr = [1e-3, 3e-4, 1e-4, 3e-5]
lst_weight_decay = [3e-3, 1e-3, 3e-4, 1e-4]
lst_use_data_augmentation = [False, True]
MAX_EPOCHS = 30
GRID_SEED = 0

# Tracking setup does not depend on the loop variables, so do it once.
dagshub.init("FlareSense", "FlareSense", mlflow=True)

for use_data_augmentation in lst_use_data_augmentation:
    # Distinct name so the `data_module` built earlier (used by the class-count summary)
    # is not silently replaced by the last module of this loop.
    grid_data_module = data.ECallistoDataModule(
        data_folder=DATA_FOLDER_PATH,
        batch_size=64,
        num_workers=12,
        val_ratio=0.15,
        test_ratio=0.15,
        img_size=(224, 224),
        use_augmented_data=use_data_augmentation,
        filter_instruments=list(INSTRUMENTS),
        seed=0,
    )
    grid_data_module.setup()

    for current_lr, current_weight_decay in itertools.product(lst_lr, lst_weight_decay):
        # Re-seed per configuration so runs differ only in their hyper-parameters.
        pl.seed_everything(GRID_SEED)
        model = ResNet50BinaryClassifier(lr=current_lr, weight_decay=current_weight_decay)

        # The context manager ends the run even if fit/test raises; a bare start_run()
        # would leave the run active and make the next iteration's start_run() fail.
        with mlflow.start_run() as run:
            mlflow.log_params({
                "model": "ResNet50",
                "batch_size": grid_data_module.batch_size,
                "val_ratio": grid_data_module.val_ratio,
                "test_ratio": grid_data_module.test_ratio,
                "use_data_augmentation": grid_data_module.use_augmented_data,
                "filter_instruments": grid_data_module.filter_instruments,
            })

            run_id = run.info.run_id
            print(f"Run ID: {run_id}")
            print(f"Link: https://dagshub.com/FlareSense/FlareSense/experiments/#/experiment/m_{run_id}")

            trainer = pl.Trainer(max_epochs=MAX_EPOCHS, log_every_n_steps=1)

            trainer.fit(
                model,
                train_dataloaders=grid_data_module.train_dataloader(),
                val_dataloaders=grid_data_module.val_dataloader(),
            )

            trainer.test(model, dataloaders=grid_data_module.test_dataloader())