In [1]:
import os

os.chdir("..")
print(f"Changed working directory to: {os.getcwd()}")

Changed working directory to: /Users/gabriel.torres/Nextcloud/Development/Pro5D/FlareSense


In [2]:
import torch
import mlflow
import dagshub
import torchmetrics
import src.utils.data15min as data
import pytorch_lightning as pl

from huggingface_hub import snapshot_download
from torchvision import transforms
from tqdm.notebook import tqdm
from src.models.ResNet50BinaryClassifier import ResNet50BinaryClassifier

mlflow.pytorch.autolog()
torch.set_float32_matmul_precision("high")

In [3]:
data_folder_path = "data/raw/exported/"

# download data if needed
snapshot_download(
    "StellarMilk/ecallisto-bursts",
    repo_type="dataset",
    allow_patterns=["*62.zip", "*.csv"],
    local_dir=data_folder_path,
    revision="main",
)

# unzip data if needed
instruments = [file for file in os.listdir(data_folder_path) if file.endswith(".zip")]
for instrument in instruments:
    if os.path.exists(f"{data_folder_path}{instrument[:-4]}"):
        print(f"Skipping {instrument}")
        continue
    print(f"Unzipping {instrument}")
    !unzip -q {data_folder_path}{instrument} -d {data_folder_path}

Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

Skipping Australia-ASSA_62.zip


In [4]:
model = ResNet50BinaryClassifier(lr=1e-4, weight_decay=1e-2)

data_module = data.ECallistoDataModule(
    data_folder=data_folder_path,
    batch_size=64,
    num_workers=0,
    val_ratio=0.15,
    test_ratio=0.15,
    img_size=(224, 224),
    use_augmented_data=True,
    filter_instruments=["Australia-ASSA_62"],
    seed=0,
)
data_module.setup()

In [5]:
print("Train dataset:")
print(data_module.train_dataset.metadata.type.value_counts().to_string(header=False), "\n")

print("Validation dataset:")
print(data_module.val_dataset.metadata.type.value_counts().to_string(header=False), "\n")

print("Test dataset:")
print(data_module.test_dataset.metadata.type.value_counts().to_string(header=False))

Train dataset:
no_burst    21353
III         15418
V             265
II            221
VI            182
IV             24 

Validation dataset:
no_burst    4581
III          283
II             8
V              4
VI             2 

Test dataset:
no_burst    4588
III          273
V              7
VI             5
II             3
IV             1


In [6]:
# dagshub.init("FlareSense", "FlareSense", mlflow=True)
# mlflow.start_run()

# mlflow.log_params(
#    {
#        "model": "ResNet50",
#        "batch_size": data_module.batch_size,
#        "val_ratio": data_module.val_ratio,
#        "test_ratio": data_module.test_ratio,
#        "min_factor_val_test": data_module.min_factor_val_test,
#        "max_factor_val_test": data_module.max_factor_val_test,
#        "noburst_to_burst_ratio": data_module.noburst_to_burst_ratio,
#        "split_by_date": data_module.split_by_date,
#        "filter_instruments": data_module.filter_instruments,
#    }
# )

# run_id = mlflow.active_run().info.run_id
# print(f"Run ID: {run_id}")
# print(f"Link: https://dagshub.com/FlareSense/FlareSense/experiments/#/experiment/m_{run_id}")

trainer = pl.Trainer(max_epochs=10, log_every_n_steps=1, fast_dev_run=50)

trainer.fit(
    model,
    train_dataloaders=data_module.train_dataloader(),
    val_dataloaders=data_module.val_dataloader(),
)

trainer.test(model, dataloaders=data_module.test_dataloader())

# mlflow.end_run()

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Running in `fast_dev_run` mode: will run the requested loop using 50 batch(es). Logging and checkpointing is suppressed.
2023/12/26 23:00:34 INFO mlflow.utils.autologging_utils: Created MLflow autologging run with ID '304f8e53ca924d87a6c732fc0fc16c96', which will track hyperparameters, performance metrics, model artifacts, and lineage information for the current pytorch workflow

  | Name      | Type            | Params
----------------------------------------------
0 | precision | BinaryPrecision | 0     
1 | recall    | BinaryRecall    | 0     
2 | resnet50  | ResNet          | 23.5 M
----------------------------------------------
22.1 M    Trainable params
1.4 M     Non-trainable params
23.5 M    Total params
94.040    Total estimated model params size (MB)
/opt/homebrew/lib/python3.11/site-packages/pytorch_lightning/trainer/connector

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_steps=50` reached.
/opt/homebrew/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


Testing: |          | 0/? [00:00<?, ?it/s]

[{'test_loss': 0.0, 'test_precision': 0.0, 'test_recall': 0.0}]