In [None]:
import os
from pathlib import Path

# This notebook sits one directory below the project root. Later cells import
# `src.*` and read `data/...` relative to the working directory, so move up to
# the root. Guarded on the presence of `src/` so that re-running this cell
# does not keep climbing the directory tree.
if not Path("src").is_dir():
    os.chdir("..")
print(f"Working directory: {os.getcwd()}")

In [None]:
import contextlib

import dagshub
import mlflow
import pytorch_lightning as pl
import torch
import torchmetrics
from torchvision import transforms
from tqdm.notebook import tqdm

import src.utils.data15min as data
from src.models.ResNet50BinaryClassifier import ResNet50BinaryClassifier

# Let torch use its faster "high" float32 matmul precision mode (trades a little
# precision for speed on supported hardware).
torch.set_float32_matmul_precision("high")

# Turn on MLflow autologging for PyTorch / Lightning training runs.
mlflow.pytorch.autolog()

In [None]:
# Seed the global RNGs via Lightning *before* building the model, so that weight
# initialisation and training are reproducible across Restart & Run All. The
# same seed is passed to the data module (presumably it controls the split).
SEED = 0
pl.seed_everything(SEED)

model = ResNet50BinaryClassifier(lr=1e-4, weight_decay=1e-2)

# Relative to the project root (the working directory set by the first cell).
data_folder_path = "data/raw/exported/"

data_module = data.ECallistoDataModule(
    data_folder=data_folder_path,
    batch_size=64,
    num_workers=0,
    val_ratio=0.15,
    test_ratio=0.15,
    img_size=(224, 224),
    use_augmented_data=True,
    filter_instruments=[],
    seed=SEED,
)
data_module.setup()

In [None]:
# Label balance per split: count of each value in the `type` column of every
# split's metadata (presumably a pandas DataFrame — it supports value_counts()).
split_datasets = {
    "Train": data_module.train_dataset,
    "Validation": data_module.val_dataset,
    "Test": data_module.test_dataset,
}
for split_name, split_dataset in split_datasets.items():
    print(f"{split_name} dataset:")
    print(split_dataset.metadata.type.value_counts().to_string(header=False), end="\n\n")

In [None]:
# Set to True to record this run on the FlareSense DagsHub MLflow server.
# When False, no explicit MLflow run is opened here (the autologging enabled
# in the imports cell may still log to a local MLflow store on its own).
USE_DAGSHUB_TRACKING = False
MAX_EPOCHS = 50

if USE_DAGSHUB_TRACKING:
    dagshub.init("FlareSense", "FlareSense", mlflow=True)
    # mlflow.start_run() returns a context manager; using it with `with` ends
    # the run even when training raises, so a failed cell does not leave a
    # dangling active run that blocks the next start_run().
    run_context = mlflow.start_run()
else:
    run_context = contextlib.nullcontext()

with run_context:
    if USE_DAGSHUB_TRACKING:
        # NOTE(review): min_factor_val_test, max_factor_val_test,
        # noburst_to_burst_ratio and split_by_date are not passed to the
        # data-module constructor above — confirm they exist on
        # ECallistoDataModule before enabling tracking.
        mlflow.log_params(
            {
                "model": "ResNet50",
                "batch_size": data_module.batch_size,
                "val_ratio": data_module.val_ratio,
                "test_ratio": data_module.test_ratio,
                "min_factor_val_test": data_module.min_factor_val_test,
                "max_factor_val_test": data_module.max_factor_val_test,
                "noburst_to_burst_ratio": data_module.noburst_to_burst_ratio,
                "split_by_date": data_module.split_by_date,
                "filter_instruments": data_module.filter_instruments,
            }
        )
        run_id = mlflow.active_run().info.run_id
        print(f"Run ID: {run_id}")
        print(f"Link: https://dagshub.com/FlareSense/FlareSense/experiments/#/experiment/m_{run_id}")

    # `model` is created in an earlier cell, so re-running only this cell
    # continues training from the already-fitted weights.
    trainer = pl.Trainer(max_epochs=MAX_EPOCHS, log_every_n_steps=1)
    trainer.fit(
        model,
        train_dataloaders=data_module.train_dataloader(),
        val_dataloaders=data_module.val_dataloader(),
    )
    test_results = trainer.test(model, dataloaders=data_module.test_dataloader())

test_results