In [None]:
"""
This script implements a Contextual Bandit–style classification approach using a DQN architecture
on HUMS2023 vibration signals. The goal is to classify each signal segment as "Normal" or "Faulty" based
on a simple reward structure (reward = +1 if the predicted label matches the ground truth, otherwise –1).
Although framed as a DQN, the discount factor (gamma) is set to 0, effectively reducing this to a
contextual bandit that selects an action (label) for a given signal.

Note:

    - Training sequences are loaded from two folders:
        1. './data/Contextual_bandit/Normal_19_20'  (labels = 0 for normal signals, days 19–20)
        2. './data/Contextual_bandit/Faulty_26_27'  (labels = 1 for faulty signals, days 26–27)
      Unlabeled test data for days 21–25 are loaded from:
        './data/Contextual_bandit/unknown21_25'
    - The model input is the flattened signal (dimensionality 4095), and the output is a binary choice (0 = Normal, 1 = Faulty).
    - Final predictions on the test set are saved as a NumPy file.
"""


%reset -f
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import glob, os
from scipy.io import loadmat
from sklearn.metrics import classification_report, confusion_matrix

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def load_signals(folder):
    """
    Load and flatten vibration signals from .mat files in a specified folder.

    Files that do not contain an 'xr' variable are skipped. They are also left out of
    the returned file list, so that signals[i] always originates from files[i].

    Args:
        folder (str): Path to the directory containing .mat files. Each .mat file
                      should contain a variable 'xr' representing a 1D vibration signal.

    Returns:
        signals (np.ndarray): Array of shape (N, signal_length), where N is the number of .mat files
        that contain 'xr' and signal_length is the length of each flattened 'xr' array
        (assumes every 'xr' has the same length).
        files (List[str]):     Sorted list of full file paths corresponding to each signal.
    """
    signals = []
    files = []
    # Sort so the order is reproducible regardless of the filesystem's listing order.
    for file in sorted(glob.glob(os.path.join(folder, '*.mat'))):
        data = loadmat(file)
        if 'xr' not in data:
            # Skip the path as well, otherwise signals and files fall out of step.
            continue
        signals.append(data['xr'].astype(np.float32).flatten())
        files.append(file)
    return np.array(signals), files

# Labelled training folders (the folder names carry the recording days) and the unlabeled test folder.
train_normal, _ = load_signals('./data/Contextual_bandit/Normal_19_20')   # normal signals, days 19–20
train_faulty, _ = load_signals('./data/Contextual_bandit/Faulty_26_27')   # faulty signals, days 26–27
test_data, test_files = load_signals('./data/Contextual_bandit/unknown21_25') # unlabeled test (Day 21-25)

# Stack both classes row-wise; label 0 marks a normal signal, label 1 a faulty one.
num_normal = len(train_normal)
num_faulty = len(train_faulty)
X_train = np.vstack((train_normal, train_faulty))
y_train = np.hstack((np.zeros(num_normal), np.ones(num_faulty)))

# A single scalar mean/std is taken over the whole training array and reused
# for the test signals, so both sets share the same scaling.
mean = X_train.mean()
std = X_train.std()
X_train = (X_train - mean) / std
test_data = (test_data - mean) / std

# Move everything onto the compute device as tensors (labels as int64).
X_train = torch.tensor(X_train, dtype=torch.float32, device=device)
y_train = torch.tensor(y_train, dtype=torch.long, device=device)
test_data = torch.tensor(test_data, dtype=torch.float32, device=device)

class DQN(nn.Module):
    """
    Fully connected Q-network used for contextual bandit classification.

    A flattened 1D signal goes through two ReLU hidden layers (hidden_dim units, then
    hidden_dim // 2 units) and a final linear layer that emits one Q-value per action:
        0: classify as Normal
        1: classify as Faulty

    Args:
        input_dim (int):   Dimensionality of the flattened input signal (default 4095).
        hidden_dim (int):  Number of units in the first hidden layer (default 256);
                           the second hidden layer has hidden_dim // 2 units.
    """
    def __init__(self, input_dim=4095, hidden_dim=256):
        super().__init__()
        narrow_dim = hidden_dim // 2
        layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, narrow_dim),
            nn.ReLU(),
            nn.Linear(narrow_dim, 2),
        ]
        # Kept under the attribute name "fc" so state-dict keys stay "fc.<index>.*".
        self.fc = nn.Sequential(*layers)

    def forward(self, x):
        """
        Compute the Q-values for a batch of signals.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, input_dim).

        Returns:
            torch.Tensor: Output Q-values of shape (batch_size, 2).
        """
        q_values = self.fc(x)
        return q_values

# Exploration schedule: epsilon starts at 1.0 and is multiplied by
# epsilon_decay until it reaches the floor epsilon_min.
epsilon = 1.0
epsilon_min = 0.05
epsilon_decay = 0.995
gamma = 0          # discount factor; 0 means only the immediate reward counts
batch_size = 64    # number of stored experiences drawn per replay step
memory = []        # replay buffer of (state, action, reward) tuples

# policy_net is the network being optimised; target_net starts out as a copy
# of its weights and is kept in eval mode.
policy_net = DQN().to(device)
target_net = DQN().to(device)
target_net.load_state_dict(policy_net.state_dict())
target_net.eval()

# Q-values are regressed onto the observed rewards with a squared-error loss.
criterion = nn.MSELoss()
optimizer = optim.Adam(policy_net.parameters(), lr=1e-3)

def remember(state, action, reward):
    """
    Append one experience to the module-level replay memory.

    Args:
        state (torch.Tensor): Tensor of shape (input_dim,) representing the current state.
        action (int):         Chosen action (0 for normal, 1 for faulty).
        reward (int or float): Scalar reward received for this action.
    """
    experience = (state, action, reward)
    memory.append(experience)

def replay():
    """
    Sample a random batch from memory and perform a gradient descent step on the
    DQN using the stored (state, action, reward) tuples. The Q-value of the chosen
    action is regressed directly onto the immediate reward (no bootstrapped
    next-state term is used).

    If the replay buffer has fewer than batch_size entries, this function returns immediately.
    """
    if len(memory) < batch_size:
        return

    # Draw distinct indices so no experience appears twice in the same batch.
    batch = np.random.choice(len(memory), batch_size, replace=False)
    states, actions, rewards = zip(*(memory[i] for i in batch))
    states = torch.stack(states)
    # gather() requires int64 indices. Stored actions may be NumPy integers whose
    # width is platform dependent, so pin the dtype instead of relying on inference.
    actions = torch.tensor(actions, dtype=torch.long, device=device).unsqueeze(1)
    rewards = torch.tensor(rewards, dtype=torch.float32, device=device).unsqueeze(1)

    # Pick, per sample, the Q-value of the action that was actually taken.
    q_values = policy_net(states).gather(1, actions)
    loss = criterion(q_values, rewards)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

# Training: every epoch visits each training signal once in random order,
# picks a label epsilon-greedily, stores the outcome and runs one replay step.
epochs = 300
for epoch in range(epochs):
    perm = torch.randperm(len(X_train))
    total_reward = 0
    for idx in perm:
        state = X_train[idx]
        label = y_train[idx].item()

        if np.random.rand() < epsilon:
            # Explore: random label, stored as a plain int like the greedy branch.
            action = int(np.random.choice([0, 1]))
        else:
            # Exploit: action selection needs no gradients, so skip graph building.
            with torch.no_grad():
                q_vals = policy_net(state.unsqueeze(0))
            action = q_vals.argmax().item()

        # +1 for a correct label, -1 otherwise.
        reward = 1 if action == label else -1
        remember(state, action, reward)
        total_reward += reward

        replay()

    # Decay exploration once per epoch, never below the floor.
    epsilon = max(epsilon_min, epsilon * epsilon_decay)

    if epoch % 10 == 0 or epoch == epochs - 1:
        print(f'Epoch {epoch}, Total Reward: {total_reward}, Epsilon: {epsilon:.3f}')

# Copy the trained weights into the target network once training has finished.
target_net.load_state_dict(policy_net.state_dict())

# Inference on the unlabeled test signals: one forward pass per signal,
# taking the action with the larger Q-value as the predicted label.
policy_net.eval()
with torch.no_grad():
    predictions = [
        policy_net(signal.unsqueeze(0)).argmax().item()
        for signal in test_data
    ]

# Report each file's verdict and collect the same mapping for saving.
results = {}
for fname, pred in zip(test_files, predictions):
    verdict = 'Faulty' if pred else 'Normal'
    print(f"{fname}: {verdict}")
    results[fname] = verdict

# np.save pickles the dict into an object array; reading it back needs
# np.load(..., allow_pickle=True).item().
# NOTE(review): the file name says "test26_27" although the test signals were
# loaded from the unknown21_25 folder — confirm the intended name.
np.save('predictions_sensorRF2_training1920_test26_27.npy', results)
