In [1]:
import torch
import numpy as np
from sklearn.preprocessing import StandardScaler
from momentfm.utils.data import load_from_tsfile
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from momentfm import MOMENTPipeline
from tqdm import tqdm
import time
from collections import Counter

  from .autonotebook import tqdm as notebook_tqdm


In [2]:

class TS_Dataset(Dataset):
    def __init__(self, file_path):
        """
        Parameters
        ----------
        file_path : str
            Path to the time series (TS) file.
        """
        self.file_path = file_path
        
        # Load data from the TS file
        self.data, self.labels, meta_data = load_from_tsfile(file_path, return_meta_data=True)
        print(f"data.shape={self.data.shape}")
        # Assume data has shape: (n_samples, n_channels, series_length)
        self.n_samples, self.n_channels, self.series_length = self.data.shape
        
        # Normalize each channel of each sample independently
        mean = np.mean(self.data, axis=-1, keepdims=True)
        std = np.std(self.data, axis=-1, keepdims=True)
        std_adj = np.where(std == 0, 1, std)  # Avoid division by zero
        self.data = (self.data - mean) / std_adj
        
        self._length = self.n_samples
        self.n_classes = len(meta_data['class_values'])
        
        # Print dataset summary
        print(f"Dataset Loaded: {file_path} | Samples: {self.n_samples}, Channels: {self.n_channels}, Series Length: {self.series_length}, Classes: {self.n_classes}")
        
    def __len__(self):
        return self._length
    
    def __getitem__(self, idx):
        return self.data[idx], int(self.labels[idx])
    
    @property
    def timeseries(self):
        return self.data
    
    @property
    def labels_prop(self):
        return self.labels


def train_epoch(model, device, train_dataloader, criterion, optimizer, scheduler, reduction='mean'):
    '''
    Train only classification head
    '''
    model.to(device)
    model.train()
    losses = []
    total_correct = 0  # Track total correct predictions

    for batch_x, batch_labels in train_dataloader:
        optimizer.zero_grad()
        batch_x = batch_x.to(device).float()
        batch_labels = batch_labels.to(device)

        # Forward pass with mixed precision
        with torch.autocast(device_type='cuda', dtype=torch.bfloat16 if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8 else torch.float32):
            output = model(x_enc=batch_x, reduction=reduction)
            loss = criterion(output.logits, batch_labels)
        
        # Backward pass
        loss.backward()
        optimizer.step()
        scheduler.step()

        # Calculate accuracy
        predictions = output.logits.argmax(dim=1)  # Get predicted class
        total_correct += (predictions == batch_labels).sum().item()  # Count correct predictions
        losses.append(loss.item())
    
    # Calculate average loss and accuracy
    avg_loss = np.mean(losses)
    accuracy = total_correct / len(train_dataloader.dataset)  # Accuracy over the entire dataset
    return avg_loss, accuracy


def evaluate_epoch(dataloader, model, criterion, device, phase='val', reduction='mean'):
    model.eval()
    model.to(device)
    total_loss, total_correct = 0, 0

    with torch.no_grad():
        for batch_x, batch_labels in dataloader:
            batch_x = batch_x.to(device).float()
            batch_labels = batch_labels.to(device)

            output = model(x_enc=batch_x, reduction=reduction)
            loss = criterion(output.logits, batch_labels)
            total_loss += loss.item()
            #total_correct += (output.logits.argmax(dim=1) == batch_labels).sum().item()
            predictions = output.logits.argmax(dim=1)
            total_correct += (predictions == batch_labels).sum().item()
            print(f"P: {[p.item() for p in predictions]} C: {[b.item() for b in batch_labels]}")
            #for i in range(len(batch_labels)):
                #print(f"[{i:02d}] P: {predictions[i].item()} C: {batch_labels[i].item()}")
    
    avg_loss = total_loss / len(dataloader)
    accuracy = total_correct / len(dataloader.dataset)
    return avg_loss, accuracy

## `ptbxl_classification` 소스코드에서 코드 가져옴

In [3]:

#train_dataset = TS_Dataset("../label_data/m_w2_train.ts")
#test_dataset = TS_Dataset("../label_data/m_w2_test.ts")
train_dataset = TS_Dataset("../label_data/TP_W3_train.ts")
test_dataset = TS_Dataset("../label_data/TP_W3_test.ts")

batch_size = 64
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
val_loader = test_loader

print(f"Num Train Set: {len(train_loader)}")
print(f"n_channels={train_dataset.n_channels},num_class={train_dataset.n_classes}")

# Initialize the model and move it to GPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

model = MOMENTPipeline.from_pretrained(
    "AutonLab/MOMENT-1-large", 
    model_kwargs={
        'task_name': 'classification',
        'n_channels': train_dataset.n_channels,
        'num_class': train_dataset.n_classes,
        'freeze_encoder': True, # Freeze the patch embedding layer
        'freeze_embedder': True, # Freeze the transformer encoder
        'freeze_head': False, # The linear forecasting head must be trained
        ## NOTE: Disable gradient checkpointing to supress the warning when linear probing the model as MOMENT encoder is frozen
        'enable_gradient_checkpointing': False,
        # Choose how embedding is obtained from the model: One of ['mean', 'concat']
        # Multi-channel embeddings are obtained by either averaging or concatenating patch embeddings 
        # along the channel dimension. 'concat' results in embeddings of size (n_channels * d_model), 
        # while 'mean' results in embeddings of size (d_model)
        'reduction': 'mean',
    },
    # local_files_only=True,  # Whether or not to only look at local files (i.e., do not try to download the model).
    ).to(device)

model.init()


data.shape=(12291, 19, 384)
Dataset Loaded: ../label_data/TP_W3_train.ts | Samples: 12291, Channels: 19, Series Length: 384, Classes: 3
data.shape=(599, 19, 384)
Dataset Loaded: ../label_data/TP_W3_test.ts | Samples: 599, Channels: 19, Series Length: 384, Classes: 3
Num Train Set: 193
n_channels=19,num_class=3
Using device: cuda




In [4]:
for batch_x, batch_labels in train_loader:
    print(f"Batch shape: {batch_x.shape} {batch_labels}")  # Should be (batch_size, n_channels, 380)
    break  # Only check the first batch

Batch shape: torch.Size([64, 19, 384]) tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])


In [5]:
label_counts = Counter(train_dataset.labels_prop.tolist())
print(f"label_counts={label_counts}")
total = sum(label_counts.values())
weights = [(total / label_counts[str(i)]) if label_counts[str(i)] > 0 else 0 for i in range(train_dataset.n_classes)]
print(f"weights={weights}")

label_counts=Counter({'1': 11710, '0': 581})
weights=[21.15490533562823, 1.0496157130657557, 0]


In [None]:
epoch = 10
if False:
    # train dataset의 라벨을 가져와서 클래스별 샘플 수 세기
    label_counts = Counter(train_dataset.labels_prop)
    total = sum(label_counts.values())

    # 클래스 가중치 계산: 적은 클래스에 더 높은 가중치를 부여
    weights = [(total / label_counts[str(i)]) if label_counts[str(i)] > 0 else 0 for i in range(train_dataset.n_classes)]
    print(f"weights={weights}")
    class_weights = torch.tensor(weights, dtype=torch.float).to(device)
    criterion = torch.nn.CrossEntropyLoss(weight=class_weights)
else:
    criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.head.parameters(), lr=1e-4)
scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=1e-3, total_steps=epoch * len(train_loader))
#device = 'cuda:3'

t1 = time.time()
for i in tqdm(range(epoch)):
    train_loss, train_acc = train_epoch(model, device, train_loader, criterion, optimizer, scheduler)
    val_loss, val_acc = evaluate_epoch(val_loader, model, criterion, device, phase='test')
    print(f'[{time.time() - t1:.2f}] Epoch {i}, train loss: {train_loss:.4f}, val loss: {val_loss:.4f}, train acc: {train_acc:.4f}, val accuracy: {val_acc:.4f}')

test_loss, test_accuracy = evaluate_epoch(test_loader, model, criterion, device, phase='test')
print(f'Test loss: {test_loss}, test accuracy: {test_accuracy}')
print(f"Time taken: {time.time() - t1:.2f} seconds, finished at {time.ctime()}")

weights=[21.15490533562823, 1.0496157130657557, 0]


  0%|          | 0/30 [00:00<?, ?it/s]

P: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] C: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1]
P: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] C: [1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
P: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] C: [1, 1, 1, 1, 1, 1

  3%|▎         | 1/30 [09:06<4:24:13, 546.67s/it]

P: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] C: [1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0]
[546.67] Epoch 0, train loss: nan, val loss: nan, train acc: 0.0473, val accuracy: 0.0484
P: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] C: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1]
P: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] C: [1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1

  7%|▋         | 2/30 [18:13<4:15:10, 546.81s/it]

P: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] C: [1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0]
[1093.58] Epoch 1, train loss: nan, val loss: nan, train acc: 0.0473, val accuracy: 0.0484
P: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] C: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1]
P: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] C: [1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 

 10%|█         | 3/30 [27:20<4:06:05, 546.87s/it]

P: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] C: [1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0]
[1640.52] Epoch 2, train loss: nan, val loss: nan, train acc: 0.0473, val accuracy: 0.0484
P: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] C: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1]
P: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] C: [1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 

 13%|█▎        | 4/30 [36:27<3:56:59, 546.91s/it]

P: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] C: [1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0]
[2187.49] Epoch 3, train loss: nan, val loss: nan, train acc: 0.0473, val accuracy: 0.0484
P: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] C: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1]
P: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] C: [1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 

 17%|█▋        | 5/30 [45:34<3:47:53, 546.92s/it]

P: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] C: [1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0]
[2734.43] Epoch 4, train loss: nan, val loss: nan, train acc: 0.0473, val accuracy: 0.0484
P: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] C: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1]
P: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] C: [1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 

 20%|██        | 6/30 [54:41<3:38:45, 546.91s/it]

P: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] C: [1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0]
[3281.33] Epoch 5, train loss: nan, val loss: nan, train acc: 0.0473, val accuracy: 0.0484


In [None]:
path = f'w3_epoch{epoch}.pth'
state_dict = model.state_dict()
print(f"Loaded state_dict head.linear.weight shape: {state_dict['head.linear.weight'].shape}")
print(f"Loaded state_dict head.linear.bias shape: {state_dict['head.linear.bias'].shape}")
torch.save(state_dict, path)
print(f"Model saved to {path}")

In [None]:
from sklearn.metrics import precision_score, recall_score, f1_score

#test_dataset = TS_Dataset("../label_data/m_w2_test.ts")
test_dataset = TS_Dataset("../label_data/3_250301~22_train.ts")

batch_size = 1
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)


# Evaluate the model on the test dataset
model.eval()  # Set the model to evaluation mode
test_loss = 0
correct = 0
total = 0

all_labels = []
all_predictions = []
t1 = time.time()

with torch.no_grad():  # Disable gradient computation for evaluation
    for data, labels in test_dataloader:
        # Move data to the appropriate device
        data = data.to(device, dtype=torch.float32)
        labels = labels.to(device)

        # Forward pass
        output = model(x_enc=data, reduction='mean')

        if output is None or output.logits is None:
            raise ValueError("The model's output is None. Check the model's forward implementation.")

        logits = output.logits

        # Compute loss
        loss = criterion(logits, labels)
        test_loss += loss.item()

        # Get predictions
        _, predicted = torch.max(logits, dim=1)  # Get the predicted class indices
        print(f"predicted={predicted} labels={labels}")
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

        # Save all labels and predictions for metric computation
        all_labels.extend(labels.cpu().numpy())
        all_predictions.extend(predicted.cpu().numpy())

# Calculate metrics
test_loss /= len(test_dataloader)
accuracy = 100 * correct / total
precision = precision_score(all_labels, all_predictions, average="weighted")
recall = recall_score(all_labels, all_predictions, average="weighted")
f1 = f1_score(all_labels, all_predictions, average="weighted")

# Print metrics
print(f"Num Test Set: {len(test_dataloader)}")
print(f"Test Loss: {test_loss:.3f}")
print(f"Test Accuracy: {accuracy:.2f}%")
print(f"Precision: {precision:.2f}")
print(f"Recall: {recall:.2f}")
print(f"F1 Score: {f1:.2f}")
print(f"Time taken: {time.time() - t1:.2f} seconds, finished at {time.ctime()}")
