# ECG_FM

In [1]:
from data import ptbxl
import os
from model.huggingface.utils import download_model
from fairseq_signals import models, tasks
from fairseq_signals.utils import checkpoint_utils
import torch
import numpy as np
from utils import setCWDToProjectDir

In [2]:
setCWDToProjectDir()
print(f"Our current working directory is {os.getcwd()}")

Our current working directory is D:\cardiovascular-ai


In [3]:
pt, yml = download_model('wanglab/ecg-fm-preprint')

In [5]:
# Load the model
model = models.build_model_from_checkpoint(checkpoint_path=pt)

# Load the model state
model.eval()

dataset = ptbxl.PTBXL(sampling_rate=ptbxl.SamplingRate.HZ_500)

record = dataset.load_record(1)

ecg_signal = record.data

print("ECG signal shape:", ecg_signal.shape )

if len(ecg_signal) < 2500:
    raise ValueError("ECG record is shorter than 5 seconds.")

# Extract a 5-second segment (2500 samples at 500 Hz)
segment = ecg_signal[:2500]

# Normalize the segment
ecg_tensor = torch.tensor(segment, dtype=torch.float32).unsqueeze(0)

print("ECG tensor shape before repeat:", ecg_tensor.shape)

ecg_tensor = ecg_tensor.transpose(1,2)

# Run inference with the model
with torch.no_grad():
    output = model(source=ecg_tensor)  # Output shape: [1, 26]
    logits = output["out"]      # extract the logits tensor (check your model's output keys)
    
    probabilities = torch.softmax(logits, dim=-1)
    predicted_class = torch.argmax(probabilities, dim=-1)

print("Logits:", logits)
print("Probabilities:", probabilities)
predicted_class_idx = predicted_class.item()

print("Predicted class index:", predicted_class.item())
print("Confidence score:", probabilities[0][predicted_class_idx].item())





ECG signal shape: (5000, 12)
ECG tensor shape before repeat: torch.Size([1, 2500, 12])
Logits: tensor([[ -8.4053, -12.6755, -26.7521, -28.0138, -12.1268, -14.4197, -11.4347,
         -12.1979,  -1.2968, -10.6023, -17.8656,   3.5591,  -9.6794,  -3.4919,
           5.2512,  -3.2053, -10.8229, -20.2688, -13.5364,  -7.0771, -22.4763,
          -5.7198,  -7.8527,  -9.2080,  -2.5858,  -6.1212]])
Probabilities: tensor([[9.8818e-07, 1.3814e-08, 1.0640e-14, 3.0129e-15, 2.3912e-08, 2.4144e-09,
         4.7774e-08, 2.2270e-08, 1.2079e-03, 1.0982e-07, 7.6964e-11, 1.5521e-01,
         2.7637e-07, 1.3449e-04, 8.4291e-01, 1.7913e-04, 8.8087e-08, 6.9594e-12,
         5.8403e-09, 3.7295e-06, 7.6538e-13, 1.4492e-05, 1.7172e-06, 4.4284e-07,
         3.3284e-04, 9.7010e-06]])
Predicted class index: 14
Confidence score: 0.8429063558578491
