# ============================================
# **Cell 1/5: Realistic Synthetic Drilling + Engineering Logs (PyTorch)**
# ============================================

In [None]:
!pip install -q librosa soundfile scikit-learn

import os
import numpy as np
import matplotlib.pyplot as plt

import librosa

from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

# Report the installed PyTorch version.
print("PyTorch version:", torch.__version__)
# Use the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

## Global parameters

In [None]:
# Sampling rate used for every synthetic waveform, in samples per second.
SAMPLE_RATE = 16000
# Length of one generated segment, in seconds.
SEGMENT_DURATION = 2.0
# Number of samples in one segment: 16000 * 2.0 = 32000.
SEGMENT_SAMPLES = int(SAMPLE_RATE * SEGMENT_DURATION)

## Simple frequency-domain band-pass filter

In [None]:
def bandpass_filter(wav, sr, f_low, f_high):
    """
    Brick-wall band-pass filter applied in the frequency domain.

    The real FFT of `wav` is taken, every bin whose frequency lies outside
    the closed interval [f_low, f_high] is multiplied by zero, and the
    result is transformed back to the time domain.

    wav: 1-D real-valued signal.
    sr: sampling rate; bin frequencies are computed with sample spacing 1/sr.
    f_low, f_high: inclusive pass-band edges, in the same unit as `sr`.

    Returns a float32 array with the same number of samples as `wav`.
    """
    num_samples = len(wav)
    spectrum = np.fft.rfft(wav)
    bin_freqs = np.fft.rfftfreq(num_samples, d=1.0 / sr)
    # A bin is kept when it is neither below the low edge nor above the high edge.
    keep = ~((bin_freqs < f_low) | (bin_freqs > f_high))
    # Passing n explicitly restores odd-length signals to their original length.
    restored = np.fft.irfft(spectrum * keep, n=num_samples)
    return restored.astype(np.float32)

## Noise components: mud pump, top drive, drillstring

In [None]:
def mud_pump_noise(n_samples):
    """
    Synthesize `n_samples` segments of mud-pump-like noise.

    Each segment is a fundamental drawn uniformly from 3-7 (amplitude 0.8)
    plus its 2nd and 3rd harmonics (amplitudes 0.4 and 0.2, random phases)
    and white Gaussian noise scaled by 0.1, band-passed to 1-30.

    Returns a float32 array of shape (n_samples, SEGMENT_SAMPLES).
    """
    t = np.linspace(0, SEGMENT_DURATION, SEGMENT_SAMPLES, endpoint=False)
    two_pi = 2 * np.pi
    segments = []
    for _ in range(n_samples):
        stroke_hz = np.random.uniform(3, 7)
        # The fundamental has no phase offset; only the harmonics get one.
        wave = 0.8 * np.sin(two_pi * stroke_hz * t)
        for harmonic, gain in ((2, 0.4), (3, 0.2)):
            phase = np.random.uniform(0, two_pi)
            wave += gain * np.sin(two_pi * harmonic * stroke_hz * t + phase)
        hiss = 0.1 * np.random.randn(SEGMENT_SAMPLES)
        noisy = (wave + hiss).astype(np.float32)
        segments.append(bandpass_filter(noisy, SAMPLE_RATE, 1, 30))
    return np.stack(segments, axis=0)

def top_drive_noise(n_samples):
    """
    Synthesize `n_samples` segments of top-drive-like noise.

    Each segment is a fundamental drawn uniformly from 60-150 (amplitude
    0.6) plus its 2nd and 3rd harmonics (amplitudes 0.3 and 0.15, random
    phases) and white Gaussian noise scaled by 0.1, band-passed to 40-400.

    Returns a float32 array of shape (n_samples, SEGMENT_SAMPLES).
    """
    t = np.linspace(0, SEGMENT_DURATION, SEGMENT_SAMPLES, endpoint=False)
    two_pi = 2 * np.pi
    segments = []
    for _ in range(n_samples):
        rotation_hz = np.random.uniform(60, 150)
        # The fundamental has no phase offset; only the harmonics get one.
        wave = 0.6 * np.sin(two_pi * rotation_hz * t)
        for harmonic, gain in ((2, 0.3), (3, 0.15)):
            phase = np.random.uniform(0, two_pi)
            wave += gain * np.sin(two_pi * harmonic * rotation_hz * t + phase)
        hiss = 0.1 * np.random.randn(SEGMENT_SAMPLES)
        noisy = (wave + hiss).astype(np.float32)
        segments.append(bandpass_filter(noisy, SAMPLE_RATE, 40, 400))
    return np.stack(segments, axis=0)

def drillstring_resonance(n_samples):
    """
    Synthesize `n_samples` segments of drillstring-resonance-like noise.

    Each segment sums 2 to 4 sinusoidal modes (frequency uniform in 20-80,
    amplitude uniform in 0.2-0.7, random phase), adds white Gaussian noise
    scaled by 0.15, and is band-passed to 10-120.

    Returns a float32 array of shape (n_samples, SEGMENT_SAMPLES).
    """
    t = np.linspace(0, SEGMENT_DURATION, SEGMENT_SAMPLES, endpoint=False)
    two_pi = 2 * np.pi
    segments = []
    for _ in range(n_samples):
        mode_count = np.random.randint(2, 5)  # upper bound is exclusive
        # Accumulate in float32; every in-place add below casts back to it.
        acc = np.zeros_like(t, dtype=np.float32)
        for _mode in range(mode_count):
            freq = np.random.uniform(20, 80)
            gain = np.random.uniform(0.2, 0.7)
            phi = np.random.uniform(0, two_pi)
            acc += gain * np.sin(two_pi * freq * t + phi)
        acc += 0.15 * np.random.randn(SEGMENT_SAMPLES)
        filtered = bandpass_filter(acc, SAMPLE_RATE, 10, 120)
        segments.append(filtered.astype(np.float32))
    return np.stack(segments, axis=0)

## Engineering logs generator (RPM, WOB, Torque, SPP, ROP, Flow)

In [None]:
def generate_engineering_logs(n_samples, mode):
    """
    Draw synthetic engineering-log rows for one drilling condition.

    Every feature is sampled independently from a normal distribution whose
    mean and standard deviation depend on `mode`.

    n_samples: number of rows to generate; 0 yields an empty (0, 6) array.
    mode: 'normal', 'stick_slip', 'bit_bounce'

    Returns a float32 array of shape (n_samples, 6) with columns
    [RPM, WOB (kN), Torque (kNm), SPP (bar), ROP (m/h), Flow (l/min)].

    Raises ValueError when `mode` is not one of the three supported
    conditions, instead of silently returning all-zero rows.
    """
    # Per-condition (means, standard deviations), in column order
    # RPM, WOB, Torque, SPP, ROP, Flow.
    profiles = {
        "normal": (
            (120, 80, 15, 220, 10, 1800),
            (5, 5, 1.5, 10, 1.0, 80),
        ),
        # More RPM variation; higher, more variable torque.
        "stick_slip": (
            (115, 85, 20, 230, 9, 1800),
            (10, 8, 3.0, 12, 1.5, 80),
        ),
        # Higher WOB; slower, less stable ROP.
        "bit_bounce": (
            (120, 95, 18, 240, 8, 1800),
            (5, 10, 2.5, 15, 2.0, 100),
        ),
    }
    if mode not in profiles:
        raise ValueError(
            f"Unknown mode {mode!r}; expected one of {sorted(profiles)}"
        )

    means, stds = profiles[mode]
    # One vectorized draw; the explicit size keeps the result 2-D even when
    # n_samples is 0, so it can still be concatenated with other classes.
    logs = np.random.normal(means, stds, size=(n_samples, len(means)))
    return logs.astype(np.float32)

## Drilling conditions (signal + engineering logs)

In [None]:
def generate_normal_drilling_v2(n_samples):
    """
    Generate `n_samples` 'normal' drilling segments with matching logs.

    The waveform is top-drive noise + 0.7 * mud-pump noise + 0.5 *
    drillstring resonance, plus white noise scaled by 0.1; each row is then
    divided by its own peak absolute value (plus 1e-6).

    Returns (signals, logs): signals is float32 of shape
    (n_samples, SEGMENT_SAMPLES); logs come from
    generate_engineering_logs(n_samples, "normal").
    """
    top_drive = top_drive_noise(n_samples)
    mud_pump = mud_pump_noise(n_samples)
    string_res = drillstring_resonance(n_samples)

    mix = top_drive + 0.7 * mud_pump + 0.5 * string_res
    hiss = np.random.randn(n_samples, SEGMENT_SAMPLES).astype(np.float32)
    mix += 0.1 * hiss

    # Row-wise peak normalization; the epsilon guards an all-zero row.
    peak = np.max(np.abs(mix), axis=1, keepdims=True)
    mix = mix / (peak + 1e-6)

    normal_logs = generate_engineering_logs(n_samples, "normal")
    return mix.astype(np.float32), normal_logs

def generate_stick_slip_v2(n_samples):
    """
    Generate `n_samples` 'stick_slip' segments with matching logs.

    Starting from a normal-drilling waveform, each segment gets a slow
    amplitude modulation (0.5-3.0, depth 0.5), an additive low-frequency
    sinusoid (0.2-1.0, amplitude 0.3), 2 to 4 Hanning-shaped bursts of
    800-1599 samples, white noise scaled by 0.15, and a final peak
    normalization.

    Returns (signals, logs): signals is float32 of shape
    (n_samples, SEGMENT_SAMPLES); logs come from
    generate_engineering_logs(n_samples, "stick_slip").
    """
    # The logs that come with the baseline describe the normal condition and
    # are not used; stick-slip logs are generated at the end.
    baseline, _ = generate_normal_drilling_v2(n_samples)
    t = np.linspace(0, SEGMENT_DURATION, SEGMENT_SAMPLES, endpoint=False)
    two_pi = 2 * np.pi

    segments = []
    for row in baseline:
        sig = row.copy()

        # Slow amplitude modulation of the whole segment.
        mod_hz = np.random.uniform(0.5, 3.0)
        mod_phase = np.random.uniform(0, two_pi)
        sig *= 1.0 + 0.5 * np.sin(two_pi * mod_hz * t + mod_phase)

        # Additive low-frequency oscillation.
        swing_hz = np.random.uniform(0.2, 1.0)
        swing_phase = np.random.uniform(0, two_pi)
        sig += 0.3 * np.sin(two_pi * swing_hz * t + swing_phase)

        # Positive Hanning-shaped bursts at random positions.
        burst_count = np.random.randint(2, 5)
        for _burst in range(burst_count):
            burst_len = np.random.randint(800, 1600)
            burst_start = np.random.randint(0, SEGMENT_SAMPLES - burst_len)
            burst = np.hanning(burst_len) * np.random.uniform(0.5, 1.2)
            sig[burst_start:burst_start + burst_len] += burst

        sig += 0.15 * np.random.randn(SEGMENT_SAMPLES)
        peak = np.max(np.abs(sig))
        segments.append((sig / (peak + 1e-6)).astype(np.float32))

    stacked = np.stack(segments, axis=0)
    return stacked, generate_engineering_logs(n_samples, "stick_slip")

def generate_bit_bounce_v2(n_samples):
    """
    Generate `n_samples` 'bit_bounce' segments with matching logs.

    Starting from a normal-drilling waveform, each segment receives 5 to 14
    short Hanning-shaped impacts (80-199 samples, amplitude 0.8-1.5), is
    band-passed to 20-600, gets white noise scaled by 0.2, and is finally
    peak-normalized.

    Returns (signals, logs): signals is float32 of shape
    (n_samples, SEGMENT_SAMPLES); logs come from
    generate_engineering_logs(n_samples, "bit_bounce").
    """
    # The logs that come with the baseline describe the normal condition and
    # are not used; bit-bounce logs are generated at the end.
    baseline, _ = generate_normal_drilling_v2(n_samples)

    segments = []
    for row in baseline:
        sig = row.copy()

        impact_count = np.random.randint(5, 15)
        for _impact in range(impact_count):
            # The drawn index is where the pulse begins, not its midpoint.
            first = np.random.randint(400, SEGMENT_SAMPLES - 400)
            span = np.random.randint(80, 200)
            pulse = np.hanning(span) * np.random.uniform(0.8, 1.5)
            last = min(first + span, SEGMENT_SAMPLES)
            sig[first:last] += pulse[:last - first]

        sig = bandpass_filter(sig, SAMPLE_RATE, 20, 600)
        sig += 0.2 * np.random.randn(SEGMENT_SAMPLES)
        peak = np.max(np.abs(sig))
        segments.append((sig / (peak + 1e-6)).astype(np.float32))

    stacked = np.stack(segments, axis=0)
    return stacked, generate_engineering_logs(n_samples, "bit_bounce")

## Build dataset

In [None]:
# Number of synthetic segments generated for each drilling condition.
N_PER_CLASS = 300

# One (waveforms, engineering logs) pair per condition.
Xn_sig, Xn_logs = generate_normal_drilling_v2(N_PER_CLASS)
Xs_sig, Xs_logs = generate_stick_slip_v2(N_PER_CLASS)
Xb_sig, Xb_logs = generate_bit_bounce_v2(N_PER_CLASS)

# Stack the three conditions in the order normal, stick-slip, bit-bounce.
X_wave = np.concatenate([Xn_sig, Xs_sig, Xb_sig], axis=0)
X_logs = np.concatenate([Xn_logs, Xs_logs, Xb_logs], axis=0)

# Integer labels follow the same order: 0 = normal, 1 = stick_slip, 2 = bit_bounce.
y_normal = np.zeros(N_PER_CLASS, dtype=np.int64)
y_stick  = np.ones(N_PER_CLASS, dtype=np.int64)
y_bounce = np.full(N_PER_CLASS, 2, dtype=np.int64)
y_class = np.concatenate([y_normal, y_stick, y_bounce], axis=0)

# Label names; a name's position in the list is its class index in y_class.
classes = ["normal", "stick_slip", "bit_bounce"]
num_classes = len(classes)
# Number of engineering-log columns per sample.
num_log_features = X_logs.shape[1]

print("Waveform dataset shape:", X_wave.shape)
print("Engineering logs shape:", X_logs.shape)
print("Class distribution:", np.bincount(y_class), "=>", classes)

# Quick visualization
# X_wave[0] belongs to the normal class because normal rows are concatenated first.
plt.figure(figsize=(10,4))
plt.plot(X_wave[0])
plt.title("Example waveform (normal_v2)")
plt.xlabel("Samples")
plt.tight_layout()
plt.show()

# Scatter of log column 0 (RPM) against column 2 (torque), colored by class label.
plt.figure(figsize=(6,4))
plt.scatter(X_logs[:,0], X_logs[:,2], c=y_class, cmap="viridis", alpha=0.5)
plt.xlabel("RPM")
plt.ylabel("Torque")
plt.title("Engineering logs space (RPM vs Torque)")
plt.tight_layout()
plt.show()

print("Cell 1/5 ready (Realistic Synthetic Drilling + Logs).")

# ============================================
# **Cell 2/5: Log-Mel Features + Dual-Input Dataset**
# ============================================

In [None]:
# Mel-spectrogram settings passed to librosa in compute_log_mel.
N_MELS = 64        # number of mel bands (rows of each spectrogram)
HOP_LENGTH = 256   # hop between successive analysis frames, in samples
N_FFT = 1024       # FFT window length, in samples

def compute_log_mel(wav, sr=SAMPLE_RATE):
    """
    Convert one waveform into a log-mel spectrogram.

    A power mel-spectrogram is computed with N_FFT, HOP_LENGTH and N_MELS,
    then converted to decibels with the spectrogram's own maximum as the
    reference (ref=np.max).

    Returns a float32 array laid out as (F, T): mel bands by time frames.
    """
    mel_settings = dict(
        n_fft=N_FFT,
        hop_length=HOP_LENGTH,
        n_mels=N_MELS,
        power=2.0,
    )
    power_mel = librosa.feature.melspectrogram(y=wav, sr=sr, **mel_settings)
    in_db = librosa.power_to_db(power_mel, ref=np.max)
    return in_db.astype(np.float32)  # (F, T)

print("Computing log-mel features...")
mel_list = []
for i, seg in enumerate(X_wave):
    if (i+1) % 200 == 0:
        print(f"{i+1}/{len(X_wave)} processed...")
    m = compute_log_mel(seg)  # (F, T)
    mel_list.append(m)

X_mel = np.stack(mel_list, axis=0)  # (N, F, T)
print("Raw log-mel shape (N, F, T):", X_mel.shape)

# Decide the train/val partition BEFORE estimating normalization statistics,
# so that validation samples never influence the mean/std applied to them.
# Splitting an index array with the same test_size / random_state / stratify
# arguments selects the rows by position, and the statistics below are then
# fitted on the training rows only.
sample_idx = np.arange(len(y_class))
train_idx, val_idx = train_test_split(
    sample_idx, test_size=0.2, random_state=42, stratify=y_class
)

# Normalize mel with a single scalar mean/std estimated on training rows only
mean_mel = X_mel[train_idx].mean()
std_mel = X_mel[train_idx].std() + 1e-6
X_mel = (X_mel - mean_mel) / std_mel

# Add channel dim for CNN: (N, 1, F, T)
X_mel = X_mel[:, np.newaxis, :, :]

# Normalize engineering logs feature-wise, again with training-only statistics
logs_mean = X_logs[train_idx].mean(axis=0, keepdims=True)
logs_std = X_logs[train_idx].std(axis=0, keepdims=True) + 1e-6
X_logs_norm = (X_logs - logs_mean) / logs_std

# Train/val split (apply the partition chosen above to every array)
X_mel_train, X_mel_val = X_mel[train_idx], X_mel[val_idx]
X_logs_train, X_logs_val = X_logs_norm[train_idx], X_logs_norm[val_idx]
y_train, y_val = y_class[train_idx], y_class[val_idx]

print("X_mel_train:", X_mel_train.shape, "X_logs_train:", X_logs_train.shape)
print("X_mel_val:", X_mel_val.shape, "X_logs_val:", X_logs_val.shape)

class DrillingFusionDataset(Dataset):
    """
    In-memory dataset pairing each log-mel spectrogram with its engineering
    logs and class label.

    All three inputs are copied into tensors up front: spectrograms and logs
    as float32, labels as int64. Indexing returns the tuple
    (mel, logs, label) for one sample.
    """
    def __init__(self, X_mel, X_logs, y):
        as_float = dict(dtype=torch.float32)
        self.X_mel = torch.tensor(X_mel, **as_float)     # (N,1,F,T)
        self.X_logs = torch.tensor(X_logs, **as_float)   # (N,Flogs)
        self.y = torch.tensor(y, dtype=torch.long)

    def __len__(self):
        # The sample count is taken from the spectrogram tensor.
        return self.X_mel.size(0)

    def __getitem__(self, idx):
        sample = (self.X_mel[idx], self.X_logs[idx], self.y[idx])
        return sample

# Wrap the two splits as tensor-backed datasets.
train_dataset = DrillingFusionDataset(X_mel_train, X_logs_train, y_train)
val_dataset   = DrillingFusionDataset(X_mel_val,  X_logs_val,  y_val)

# Training batches are reshuffled every epoch. Validation keeps dataset order
# (shuffle=False), so predictions collected from val_loader line up
# index-for-index with X_mel_val / y_val.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=2, pin_memory=True)
val_loader   = DataLoader(val_dataset,   batch_size=32, shuffle=False, num_workers=2, pin_memory=True)

print("Train/Val samples:", len(train_dataset), len(val_dataset))
print("Cell 2/5 ready.")

# ============================================
# **Cell 3/5: Audio Branch (Lite Conformer) + Logs Branch (MLP) + Fusion**
# ============================================

In [None]:
class SpecAugmentTorch(nn.Module):
    """
    Training-time masking for (B, 1, F, T) mel-spectrograms.

    For every sample in the batch, `num_time_masks` spans along the time
    axis and then `num_freq_masks` spans along the frequency axis are set to
    0.0. Span widths are drawn uniformly from 0..mask_param (inclusive); a
    width of 0, or one that would cover the whole axis, is skipped. In eval
    mode the input is returned untouched.
    """
    def __init__(self, time_mask_param=15, freq_mask_param=8,
                 num_time_masks=2, num_freq_masks=2):
        super().__init__()
        self.time_mask_param = time_mask_param
        self.freq_mask_param = freq_mask_param
        self.num_time_masks = num_time_masks
        self.num_freq_masks = num_freq_masks

    @staticmethod
    def _random_span(max_width, extent):
        """Draw a mask span for an axis of length `extent`, or None to skip."""
        width = np.random.randint(0, max_width + 1)
        if not 0 < width < extent:
            return None
        begin = np.random.randint(0, extent - width + 1)
        return slice(begin, begin + width)

    def forward(self, x):
        if not self.training:
            return x
        n_batch, _, n_freq, n_time = x.shape
        masked = x.clone()  # never modify the caller's tensor
        for b in range(n_batch):
            # Time masks
            for _ in range(self.num_time_masks):
                span = self._random_span(self.time_mask_param, n_time)
                if span is not None:
                    masked[b, :, :, span] = 0.0
            # Frequency masks
            for _ in range(self.num_freq_masks):
                span = self._random_span(self.freq_mask_param, n_freq)
                if span is not None:
                    masked[b, :, span, :] = 0.0
        return masked

class LiteConformerBlock(nn.Module):
    """
    One residual block operating on sequences shaped (B, T, D).

    Four pre-norm residual stages run in order: a half-weighted feed-forward
    module, multi-head self-attention, a gated depthwise-convolution module,
    and a second half-weighted feed-forward module. The output has the same
    shape as the input.
    """
    def __init__(self, d_model=128, d_ff=256, num_heads=2, dropout=0.3):
        super().__init__()
        # Sub-modules are created in the order they are applied.
        self.ln1 = nn.LayerNorm(d_model)
        self.ffn1 = self._feed_forward(d_model, d_ff, dropout)

        self.ln2 = nn.LayerNorm(d_model)
        self.mha = nn.MultiheadAttention(
            embed_dim=d_model,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True,
        )
        self.dropout_mha = nn.Dropout(dropout)

        self.ln3 = nn.LayerNorm(d_model)
        # The first pointwise conv doubles the channels so they can be split
        # into a value half and a gate half.
        self.pw_conv1 = nn.Conv1d(d_model, 2 * d_model, kernel_size=1)
        self.dw_conv = nn.Conv1d(d_model, d_model, kernel_size=3, padding=1,
                                 groups=d_model)
        self.bn = nn.BatchNorm1d(d_model)
        self.pw_conv2 = nn.Conv1d(d_model, d_model, kernel_size=1)
        self.dropout_conv = nn.Dropout(dropout)

        self.ln4 = nn.LayerNorm(d_model)
        self.ffn2 = self._feed_forward(d_model, d_ff, dropout)

    @staticmethod
    def _feed_forward(d_model, d_ff, dropout):
        """Linear -> SiLU -> Dropout -> Linear -> Dropout."""
        return nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
            nn.Dropout(dropout),
        )

    def _conv_module(self, seq):
        """Gated depthwise-conv stage; takes and returns (B, T, D)."""
        chan_first = seq.transpose(1, 2)  # (B,D,T)
        value, gate = self.pw_conv1(chan_first).chunk(2, dim=1)
        h = value * torch.sigmoid(gate)
        h = self.dw_conv(h)
        h = F.silu(self.bn(h))
        h = self.dropout_conv(self.pw_conv2(h))
        return h.transpose(1, 2)

    def forward(self, x):
        # x: (B, T, D)
        x = x + 0.5 * self.ffn1(self.ln1(x))

        normed = self.ln2(x)
        attended, _ = self.mha(normed, normed, normed)
        x = x + self.dropout_mha(attended)

        x = x + self._conv_module(self.ln3(x))

        return x + 0.5 * self.ffn2(self.ln4(x))

class AudioBranch(nn.Module):
    """
    CNN front-end + Lite Conformer encoder on log-mel.
    Input: (B,1,F,T) -> Output: (B, d_model)

    d_model: width of the encoder and of the returned embedding.
    num_blocks: number of LiteConformerBlock layers.
    dropout: dropout rate passed to every LiteConformerBlock.
    n_mels: number of mel bins F in the input. It fixes the size of the
        projection layer, which is therefore created here rather than on
        the first forward pass.

    Raises ValueError at construction when n_mels < 4, and in forward()
    when the input's mel dimension does not match `n_mels`.
    """
    def __init__(self, d_model=128, num_blocks=2, dropout=0.3, n_mels=64):
        super().__init__()
        if n_mels < 4:
            raise ValueError("n_mels must be at least 4 to survive two 2x2 poolings")
        self.specaug = SpecAugmentTorch()
        self.conv = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=(3,3), padding=1),
            nn.BatchNorm2d(32),
            nn.SiLU(),
            nn.MaxPool2d(kernel_size=(2,2)),   # F/2, T/2
            nn.Conv2d(32, 64, kernel_size=(3,3), padding=1),
            nn.BatchNorm2d(64),
            nn.SiLU(),
            nn.MaxPool2d(kernel_size=(2,2)),   # F/4, T/4
        )
        self.d_model = d_model
        self.n_mels = n_mels
        # The projection must exist before any optimizer is built from
        # model.parameters() and before state_dict() is saved or loaded;
        # a layer created lazily inside forward() would be missed by both.
        # Each pooling floor-halves the mel axis; the 64 conv channels are
        # flattened together with the remaining mel bins.
        pooled_mels = n_mels // 2 // 2
        self.proj = nn.Linear(64 * pooled_mels, d_model)
        self.blocks = nn.ModuleList([
            LiteConformerBlock(d_model=d_model, d_ff=2*d_model, num_heads=2, dropout=dropout)
            for _ in range(num_blocks)
        ])
        self.ln_out = nn.LayerNorm(d_model)

    def forward(self, x):
        x = self.specaug(x)     # active in training mode only
        x = self.conv(x)        # (B,C,F',T')
        B, C, Fp, Tp = x.shape
        if C * Fp != self.proj.in_features:
            raise ValueError(
                f"AudioBranch was built for n_mels={self.n_mels} "
                f"(projection input {self.proj.in_features}) but received "
                f"features of size {C * Fp} per frame"
            )
        # One token per time frame, with channels and mel bins flattened.
        x = x.permute(0,3,1,2).contiguous().view(B, Tp, C*Fp)  # (B,T',C*F')
        x = self.proj(x)        # (B,T',d_model)
        for blk in self.blocks:
            x = blk(x)
        x = self.ln_out(x)
        x = x.mean(dim=1)       # global average pooling over time
        return x                # (B,d_model)

class LogsBranch(nn.Module):
    """
    Two-layer MLP that embeds the engineering-log vector.
    Input: (B, num_logs) -> Output: (B, d_out)

    Layer order: Linear -> SiLU -> Dropout -> Linear -> SiLU.
    """
    def __init__(self, num_logs, d_hidden=64, d_out=64, dropout=0.2):
        super().__init__()
        stages = [
            nn.Linear(num_logs, d_hidden),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Linear(d_hidden, d_out),
            nn.SiLU(),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        embedding = self.net(x)
        return embedding

class DrillingFusionModel(nn.Module):
    """
    Two-branch classifier:
    - audio_branch: AudioBranch applied to the log-mel input
    - logs_branch: LogsBranch applied to the engineering logs
    - fusion: the two embeddings are concatenated and passed through
      Linear -> SiLU -> Dropout -> Linear to produce class logits
    """
    def __init__(self, num_classes, num_logs, d_audio=128, d_logs=64, dropout=0.3):
        super().__init__()
        self.audio_branch = AudioBranch(d_model=d_audio, num_blocks=2, dropout=dropout)
        self.logs_branch  = LogsBranch(num_logs=num_logs, d_hidden=64, d_out=d_logs, dropout=0.2)

        fused_width = d_audio + d_logs
        head = [
            nn.Linear(fused_width, 128),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Linear(128, num_classes),
        ]
        self.fusion = nn.Sequential(*head)

    def forward(self, mel, logs):
        # mel:  (B,1,F,T)
        # logs: (B,num_logs)
        audio_emb = self.audio_branch(mel)
        logs_emb = self.logs_branch(logs)
        fused = torch.cat((audio_emb, logs_emb), dim=1)
        return self.fusion(fused)

# Build the fusion model from the dataset-derived sizes and move it to the
# selected device.
model = DrillingFusionModel(num_classes=num_classes,
                            num_logs=num_log_features,
                            d_audio=128,
                            d_logs=64,
                            dropout=0.3).to(device)

# Print the module tree for a quick architecture check.
print(model)
print("Cell 3/5 ready.")

# ============================================
# **Cell 4/5: Training Loop for Fusion Model**
# ============================================

In [None]:
def train_one_epoch(model, loader, optimizer, criterion, device):
    """
    Run one optimization pass over `loader`.

    The model is switched to training mode. Each batch is a (mel, logs, y)
    triple that is moved to `device`, followed by a forward pass, backward
    pass and optimizer step.

    Returns (mean loss per sample, accuracy) over all samples seen. Accuracy
    is measured on the logits computed before each optimizer step.
    """
    model.train()
    loss_sum, n_correct, n_seen = 0.0, 0, 0
    for batch in loader:
        mel, logs, y = (part.to(device) for part in batch)

        optimizer.zero_grad()
        logits = model(mel, logs)
        loss = criterion(logits, y)
        loss.backward()
        optimizer.step()

        batch_size = mel.size(0)
        # Weight the batch loss by its size so a short last batch counts fairly.
        loss_sum += loss.item() * batch_size
        n_correct += (logits.argmax(dim=1) == y).sum().item()
        n_seen += batch_size
    return loss_sum / n_seen, n_correct / n_seen

def eval_one_epoch(model, loader, criterion, device):
    """
    Score the model on `loader` without updating it.

    The model is switched to eval mode and gradients are disabled. Each
    batch is a (mel, logs, y) triple that is moved to `device`.

    Returns (mean loss per sample, accuracy) over all samples seen.
    """
    model.eval()
    loss_sum, n_correct, n_seen = 0.0, 0, 0
    with torch.no_grad():
        for batch in loader:
            mel, logs, y = (part.to(device) for part in batch)
            logits = model(mel, logs)
            batch_size = mel.size(0)
            # Weight the batch loss by its size so a short last batch counts fairly.
            loss_sum += criterion(logits, y).item() * batch_size
            n_correct += (logits.argmax(dim=1) == y).sum().item()
            n_seen += batch_size
    return loss_sum / n_seen, n_correct / n_seen

# Multi-class cross-entropy on raw logits; Adam over the parameters the model
# exposes at the moment this line runs.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)

EPOCHS = 20
# Best validation accuracy seen so far; used to decide when to checkpoint.
best_val_acc = 0.0
# Per-epoch history for the plots below.
train_losses, val_losses = [], []
train_accs, val_accs = [], []

for epoch in range(1, EPOCHS+1):
    train_loss, train_acc = train_one_epoch(model, train_loader, optimizer, criterion, device)
    val_loss, val_acc     = eval_one_epoch(model, val_loader,   criterion, device)

    train_losses.append(train_loss)
    val_losses.append(val_loss)
    train_accs.append(train_acc)
    val_accs.append(val_acc)

    print(f"Epoch {epoch:02d}/{EPOCHS} - "
          f"train_loss: {train_loss:.4f} - train_acc: {train_acc:.4f} - "
          f"val_loss: {val_loss:.4f} - val_acc: {val_acc:.4f}")

    # Keep only the weights of the epoch with the strictly highest validation
    # accuracy; ties keep the earlier checkpoint.
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), "best_drilling_fusion_model.pt")

# Learning curves: loss on the left, accuracy on the right.
plt.figure(figsize=(12,4))
plt.subplot(1,2,1)
plt.plot(train_losses, label="train")
plt.plot(val_losses, label="val")
plt.title("Loss"); plt.legend()

plt.subplot(1,2,2)
plt.plot(train_accs, label="train")
plt.plot(val_accs, label="val")
plt.title("Accuracy"); plt.legend()

plt.tight_layout()
plt.show()

print("Cell 4/5 ready.")

# ============================================
# **Cell 5/5: Evaluation, Confusion Matrix, and Visualization**
# ============================================

In [None]:
# Load best checkpoint
# map_location remaps the stored tensors onto the current device.
model.load_state_dict(torch.load("best_drilling_fusion_model.pt", map_location=device))
model.eval()

# Per-batch predictions and targets, concatenated after the loop.
all_targets = []
all_preds = []

with torch.no_grad():
    for mel, logs, y in val_loader:
        mel = mel.to(device)
        logs = logs.to(device)
        logits = model(mel, logs)
        # Predicted class = index of the largest logit; moved to CPU for NumPy.
        preds = logits.argmax(dim=1).cpu().numpy()
        all_preds.append(preds)
        # y is never moved to the device, so it converts to NumPy directly.
        all_targets.append(y.numpy())

all_preds   = np.concatenate(all_preds,   axis=0)
all_targets = np.concatenate(all_targets, axis=0)

# NOTE(review): these metrics are computed on the same validation split that
# selected the checkpoint, so they may be optimistic.
print("=== Drilling Condition Classification Report (Fusion) ===")
print(classification_report(all_targets, all_preds, digits=4, target_names=classes))

# Rows are true classes, columns are predicted classes (see axis labels below).
cm = confusion_matrix(all_targets, all_preds)

import seaborn as sns
plt.figure(figsize=(6,5))
sns.heatmap(cm, annot=True, fmt="d", cmap="Blues",
            xticklabels=classes, yticklabels=classes)
plt.xlabel("Predicted")
plt.ylabel("True")
plt.title("Confusion Matrix - Drilling Conditions (Audio + Logs)")
plt.tight_layout()
plt.show()

# Visualize some examples with both modalities
def visualize_example(idx):
    """
    Plot validation sample `idx`: its normalized log-mel spectrogram (left)
    and its normalized engineering logs as a bar chart (right).

    The title shows the true and predicted class names. The function reads
    the module-level X_mel_val, X_logs_val, y_val, all_preds and classes.
    """
    feature_names = ["RPM","WOB","Torque","SPP","ROP","Flow"]
    spectrogram = X_mel_val[idx, 0]   # (F,T): drop the channel axis
    log_values = X_logs_val[idx]      # (num_logs)
    true_name = classes[y_val[idx]]
    pred_name = classes[all_preds[idx]]

    plt.figure(figsize=(10,4))

    # Left panel: spectrogram with low mel bins at the bottom.
    plt.subplot(1,2,1)
    plt.imshow(spectrogram, aspect="auto", origin="lower", cmap="magma")
    plt.title(f"Log-Mel | True: {true_name} | Pred: {pred_name}")
    plt.xlabel("Time frames")
    plt.ylabel("Mel bins")

    # Right panel: one bar per engineering-log feature.
    plt.subplot(1,2,2)
    bar_positions = range(len(feature_names))
    plt.bar(bar_positions, log_values, tick_label=feature_names)
    plt.xticks(rotation=45)
    plt.title("Engineering Logs (normalized)")

    plt.tight_layout()
    plt.show()

print("Sample validation examples (fusion):")
# Show at most the first five validation samples.
for i in range(min(5, len(X_mel_val))):
    visualize_example(i)

print("Project completed: Fusion of drilling signal + engineering logs with Lite Conformer + MLP.")