In [None]:
import sys
from pathlib import Path

parent_dir = Path.cwd().parent
sys.path.insert(0, str(parent_dir))

from model_finetuning.ModelClassDL_AF import *

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

import numpy as np
import pandas as pd
from scipy.io import loadmat
from fairseq_signals.models.wav2vec2.wav2vec2_cmsc import Wav2Vec2CMSCModel
from omegaconf import OmegaConf
import os
from scipy.special import expit

In [None]:
# Locations of the inputs for this notebook (placeholders — set these before running):
#   - the foundation-model checkpoint (its stored config is used to build the encoder),
#   - the fine-tuned AF classifier weights,
#   - the directory holding the ECG files to score.
path_to_foundational_model = "path/to/fm"
path_to_finetuned_weights_dl_af = "path/to/finetuned"
path_to_ecgs = "path/to/ecgs"

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}", flush=True)

In [None]:
# Build the wav2vec2-CMSC encoder from the model config stored in the foundation-model
# checkpoint. Only the checkpoint's 'cfg' entry is used in this cell.
# map_location="cpu" lets a checkpoint saved from a GPU be opened on a CPU-only machine
# (the notebook explicitly supports a CPU fallback for `device`).
# SECURITY: weights_only=False unpickles arbitrary Python objects from the file — only
# load checkpoints that come from a trusted source.
ckpt_encoder = torch.load(path_to_foundational_model, map_location="cpu", weights_only=False)
cfg_encoder = OmegaConf.create(ckpt_encoder['cfg']["model"])
# Override the stored `saliency` flag for inference.
# NOTE(review): confirm the exact meaning of this flag in fairseq_signals.
cfg_encoder["saliency"] = False
encoder = Wav2Vec2CMSCModel(cfg_encoder)

In [None]:
# Assemble the AF classifier around `encoder` and load the fine-tuned weights into it.
# map_location="cpu" makes loading independent of the device the weights were saved on;
# load_state_dict copies the tensors into the model, which is then moved to `device`.
state_dict = torch.load(path_to_finetuned_weights_dl_af, map_location="cpu")
# feature_dim=768 is presumably the encoder's embedding size — confirm against cfg_encoder.
# num_outputs=1: a single output per ECG (converted to a probability further below).
finetuned_model = ECGClassificationModel(encoder, feature_dim=768, num_outputs=1)
finetuned_model.load_state_dict(state_dict)
finetuned_model.to(device)

In [None]:
# Load every ECG file in `path_to_ecgs` and stack them into a single tensor.
# Each file is read with scipy's `loadmat`, and the array stored under "feats" is used as
# the model input. The file list is sorted so that the row order of X_test (and thus of
# the predictions and the IDs list) is reproducible across runs and filesystems.
ecg_dirs = sorted(os.listdir(path_to_ecgs))
test_samples = []
IDs = []
for ecg_file in ecg_dirs:
    # os.path.join inserts the separator, so path_to_ecgs works with or without a trailing "/".
    ecg_path = os.path.join(path_to_ecgs, ecg_file)
    ecg = loadmat(ecg_path)["feats"]
    test_samples.append(torch.from_numpy(ecg).float())
    IDs.append(ecg_file)

if not test_samples:
    raise ValueError(f"No ECG files found in {path_to_ecgs!r}")

# torch.stack requires every ECG array to have the same shape.
X_test = torch.stack(test_samples, dim=0)
print(X_test.size())

In [None]:
# Run the fine-tuned model over X_test in mini-batches and collect the raw outputs.
test_dataset = TensorDataset(X_test)
# shuffle defaults to False, so prediction rows stay in the same order as IDs.
test_loader = DataLoader(test_dataset, batch_size=32)  # Adjust batch_size as needed

# Collect predictions in batches
predictions = []

finetuned_model.eval()
with torch.no_grad():
    for batch in test_loader:
        # Send inputs to the configured device (was hard-coded to "cuda", which broke
        # the CPU fallback chosen in the config cell).
        batch_x = batch[0].to(device)
        outputs = finetuned_model(batch_x).cpu().numpy()
        predictions.append(outputs)
    # Presumably shape (num_ecgs, 1) given num_outputs=1 — confirm.
    all_predictions = np.concatenate(predictions, axis=0)

In [None]:
# Squash the raw model outputs (presumably logits) into probabilities in (0, 1)
# with the logistic sigmoid, as a flat 1-D array aligned with IDs.
all_predictions_transformed = expit(np.ravel(all_predictions))

In [None]:
# One row per ECG file: its file name and the predicted AF probability.
results = pd.DataFrame(
    data={
        "ID": IDs,
        "Pred. AF Probability": all_predictions_transformed,
    }
)