In [None]:
import os

# Move up to the project root so that the `src.*` imports and the relative
# `data/...` paths used further down resolve correctly.
# The guard keeps this cell idempotent: without it, every re-run would climb
# one directory higher and break the later cells.
if not os.path.isdir("src"):
    os.chdir("..")
print(f"Changed working directory to: {os.getcwd()}")

In [None]:
import torch
import mlflow
import torchmetrics
import src.utils.data as data

from tqdm.notebook import tqdm
from torchvision import transforms

In [None]:
# Pick the compute device: prefer CUDA, then Apple MPS, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

In [None]:
# Point MLflow at the remote tracking server before resolving the run.
mlflow.set_tracking_uri('https://dagshub.com/FlareSense/Flaresense.mlflow')

# MLflow run ID of the model to evaluate; change it to evaluate a different run.
run_id = "115ecd0ffec84cbfbfb6431b4d879464"

# Load the logged PyTorch model, move it to the selected device and switch it
# to evaluation mode. `model.eval()` is the last expression, so the notebook
# displays the returned module.
model = mlflow.pytorch.load_model("runs:/" + run_id + "/model/")
model = model.to(device)
model.eval()

In [None]:
data_folder_path = "data/raw/burst_images/"

# Per-image preprocessing: convert to a PIL image, resize to (193, 240),
# then convert back to a tensor.
image_transform = transforms.Compose(
    [
        transforms.ToPILImage(),
        transforms.Resize((193, 240), antialias=True),
        transforms.ToTensor(),
    ]
)

# Build the data module restricted to a single instrument.
# NOTE(review): val_ratio/test_ratio/split_by_date are project-defined options;
# their exact semantics live in src.utils.data — confirm there.
data_module = data.ECallistoDataModule(
    data_folder=data_folder_path,
    transform=image_transform,
    batch_size=32,
    num_workers=0,
    val_ratio=0.2,
    test_ratio=0.2,
    split_by_date=True,
    filter_instruments=["australia_assa_02"],
)
data_module.setup()

# Only the validation loader is needed for this evaluation.
val_loader = data_module.val_dataloader()

In [None]:
# Run the model over the validation loader and collect, batch by batch,
# the binary targets and the thresholded predictions.
val_labels_list, val_preds_list = [], []

with torch.no_grad():  # evaluation only, no gradient tracking
    for images, info in tqdm(val_loader):
        images = images.to(device)

        # Encode the string labels as 0 ("no_burst") or 1 (any other label)
        # and shape them as a column vector of 32-bit ints on the device.
        binary_labels = torch.tensor(
            [0 if label == "no_burst" else 1 for label in info['label']],
            dtype=torch.int32,
            device=device,
        ).unsqueeze(1)

        # expand() only broadcasts size-1 dimensions, so this presumes
        # single-channel input images — TODO confirm against the dataset.
        outputs = model(images.expand(-1, 3, -1, -1))
        # NOTE(review): thresholding at 0.5 presumes the model emits
        # probabilities rather than raw logits — confirm.
        predictions = (outputs >= 0.5).int()

        val_labels_list.append(binary_labels)
        val_preds_list.append(predictions)

In [None]:
# Sanity check: number of collected prediction batches (one entry is appended
# per iteration of the validation loop).
len(val_preds_list)

In [None]:
# Sanity check: length of the validation loader — presumably its batch count,
# to compare against the number of collected prediction batches.
len(val_loader)

In [None]:
# All validation batches have been collected; join them along the batch
# dimension into one labels tensor and one predictions tensor.
val_labels, val_preds = (
    torch.cat(chunks, dim=0) for chunks in (val_labels_list, val_preds_list)
)

In [None]:
# Initialize the binary ConfusionMatrix metric from torchmetrics and move it
# to the same device as the collected tensors.
confmat_metric = torchmetrics.ConfusionMatrix(task="binary", num_classes=2)
confmat_metric = confmat_metric.to(device)

# Evaluate the metric on (predictions, targets); as the last expression of the
# cell, the result is displayed by the notebook.
confmat_metric(val_preds, val_labels)