In [2]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import torch
import torchaudio
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import random
import time
import os
%matplotlib qt

In [3]:
class Plotter():
    def __init__(self, fs=8000, win_size=96, n_wins=10, n_bands=129, n_classes=2, msd_labels=None, FIG_SIZE=(12, 8), blit=True):
        self.blit = blit
        self.win_size = win_size
        self.n_wins = n_wins
        self.n_bands = n_bands
        self.n_classes = n_classes
        self.msd_labels = msd_labels
        self.fs = fs
        
        self.signal_length = 1024
        self.signal = np.zeros(self.signal_length)
        self.time = np.linspace(0, self.signal_length / self.fs, self.signal_length)

        self.spec = np.zeros((n_bands, win_size * n_wins))

        self.act = np.zeros((n_classes, n_wins))

        self.fig, (self.ax_signal, self.ax_spec, self.ax_act) = plt.subplots(3, 1, figsize=FIG_SIZE)

        self.line_signal, = self.ax_signal.plot(self.time, self.signal, color='blue')
        self.ax_signal.set_title("Señal en Tiempo")
        self.ax_signal.set_xlabel("Tiempo (s)")
        self.ax_signal.set_ylabel("Amplitud")
        self.ax_signal.set_xlim(0, self.signal_length / self.fs)
        self.ax_signal.set_ylim(-1.5, 1.5)

        self.rect = patches.Rectangle((0, -1.5), self.signal_length / self.fs, 3.0, linewidth=2, edgecolor='green', facecolor='none', alpha=0.7)
        self.ax_signal.add_patch(self.rect)

        self.img_spec = self.ax_spec.imshow(self.spec, vmin=0, vmax=1, interpolation="None", cmap="jet", aspect='auto')
        self.ax_spec.set_title("Espectrograma")
        self.ax_spec.invert_yaxis()

        self.img_act = self.ax_act.imshow(self.act, vmin=0, vmax=1, interpolation="none", aspect='auto')
        self.ax_act.set_title("Activaciones de Clase")

        if msd_labels is not None:
            self.ax_act.set_yticks(np.linspace(0, len(msd_labels), len(msd_labels), endpoint=False))
            self.ax_act.set_yticklabels(msd_labels)
            self.ax_act.set_ylim(-0.5, len(msd_labels)-0.5)

        self.fig.tight_layout()
        self.fig.canvas.draw()

        if self.blit:
            self.axbackground_signal = self.fig.canvas.copy_from_bbox(self.ax_signal.bbox)
            self.axbackground_spec = self.fig.canvas.copy_from_bbox(self.ax_spec.bbox)
            self.axbackground_act = self.fig.canvas.copy_from_bbox(self.ax_act.bbox)

        plt.ion()
        plt.show()

    def update(self, new_signal, new_spec_col, new_act_col, label):
        self.signal = np.delete(self.signal, [k for k in range(len(new_signal))], axis=0)
        self.signal = np.concatenate((self.signal, new_signal), axis=0)
        self.line_signal.set_ydata(self.signal)

        self.spec = np.delete(self.spec, [k for k in range(self.win_size)], axis=1)
        self.spec = np.concatenate((self.spec, new_spec_col), axis=1)
        self.img_spec.set_data(self.spec)
        self.img_spec.autoscale()

        self.act = np.delete(self.act, 0, axis=1)
        self.act = np.concatenate((self.act, new_act_col), axis=1)
        self.img_act.set_data(self.act)
        self.img_act.autoscale()

        start_time = 0
        end_time = self.signal_length / self.fs
        min_amp = self.signal.min()
        max_amp = self.signal.max()

        self.rect.set_xy((start_time, min_amp))
        self.rect.set_width(end_time - start_time)
        self.rect.set_height(max_amp - min_amp)

        if label == 0:
            self.rect.set_edgecolor('red')
        elif label == 1:
            self.rect.set_edgecolor('green')
        else:
            self.rect.set_edgecolor('black')

        if self.blit:
            self.fig.canvas.restore_region(self.axbackground_signal)
            self.fig.canvas.restore_region(self.axbackground_spec)
            self.fig.canvas.restore_region(self.axbackground_act)

            self.ax_signal.draw_artist(self.line_signal)
            self.ax_signal.draw_artist(self.rect)
            self.ax_spec.draw_artist(self.img_spec)
            self.ax_act.draw_artist(self.img_act)

            self.fig.canvas.blit(self.ax_signal.bbox)
            self.fig.canvas.blit(self.ax_spec.bbox)
            self.fig.canvas.blit(self.ax_act.bbox)
        else:
            self.fig.canvas.draw()

        self.fig.canvas.flush_events()

In [4]:
fs = 8000
duration = 1.0
t = np.linspace(0, duration, int(fs * duration), endpoint=False)

In [5]:
def generate_signal(func, freq=5, phase=0):
    global t

    if func == 'sine':
        return np.sin(2 * np.pi * freq * t + phase)
    elif func == 'tangent':
        return np.tan(2 * np.pi * freq * t + phase)
    else:
        raise ValueError("Función desconocida. Usa 'sine' o 'tangent'.")

In [6]:
n_fft = 256
hop_length = 128
window = torch.hamming_window(n_fft)

In [7]:
def compute_stft(signal):
    signal_tensor = torch.tensor(signal, dtype=torch.float32)
    stft = torch.stft(signal_tensor, n_fft=n_fft, hop_length=hop_length, window=window, return_complex=True)
    magnitude = torch.abs(stft)
    magnitude = magnitude / (magnitude.max() + 1e-8)
    return magnitude

In [8]:
class TrigDataset(Dataset):
    def __init__(self, num_samples=1000, transform=None):
        self.num_samples = num_samples
        self.transform = transform
        self.classes = ['sine', 'tangent']

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        label = random.randint(0, 1)
        func = self.classes[label]

        freq = random.uniform(1, 50)
        phase = random.uniform(0, 2 * np.pi)
        
        signal = generate_signal(func, freq, phase)

        stft = compute_stft(signal)

        if self.transform:
            stft = self.transform(stft)

        stft = stft.unsqueeze(0)

        return stft, label

In [9]:
class STFTClassifier(nn.Module):
    def __init__(self, n_classes=2):
        super(STFTClassifier, self).__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)

        self.fc1 = nn.Linear(32 * 32 * 15, 128)
        self.fc2 = nn.Linear(128, n_classes)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

In [10]:
def train_model():
    train_dataset = TrigDataset(num_samples=3000)
    test_dataset = TrigDataset(num_samples=600)

    batch_size = 16
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    model = STFTClassifier(n_classes=2)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    num_epochs = 10
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    for epoch in range(num_epochs):
        running_loss = 0.0
        model.train()
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()

            outputs = model(inputs)
            loss = criterion(outputs, labels)

            loss.backward()
            optimizer.step()

            running_loss += loss.item() * inputs.size(0)

        epoch_loss = running_loss / len(train_dataset)

        model.eval()
        correct = 0
        total = 0
        confusion_matrix = np.zeros((2, 2), dtype=int)

        with torch.no_grad():
            for inputs, labels in test_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

                for t_label, p_label in zip(labels.view(-1), predicted.view(-1)):
                    confusion_matrix[t_label.long(), p_label.long()] += 1

        accuracy = correct / total * 100
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}, Test Accuracy: {accuracy:.2f}%")

    model.eval()
    correct = 0
    total = 0
    confusion_matrix = np.zeros((2, 2), dtype=int)

    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            for t_label, p_label in zip(labels.view(-1), predicted.view(-1)):
                confusion_matrix[t_label.long(), p_label.long()] += 1

    accuracy = correct / total * 100
    print(f"Precisión en el conjunto de prueba: {accuracy:.2f}%")
    print("Matriz de Confusión:")
    print(confusion_matrix)

    torch.save(model.state_dict(), "stft_classifier.pth")
    print("Modelo guardado como 'stft_classifier.pth'.")

In [11]:
def real_time_plotting(model_path="stft_classifier.pth"):
    model = STFTClassifier(n_classes=2)
    model.load_state_dict(torch.load(model_path))
    model.eval()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    labels = ['Seno', 'Tangente']
    plotter = Plotter(fs=8000, n_bands=129, n_classes=2, msd_labels=labels)

    try:
        while True:
            label = random.randint(0, 1)
            func = ['sine', 'tangent'][label]

            freq = random.uniform(1, 50)
            phase = random.uniform(0, 2 * np.pi)

            signal = generate_signal(func, freq, phase)

            stft = compute_stft(signal).unsqueeze(0).unsqueeze(0)
            stft = stft.to(device)

            with torch.no_grad():
                output = model(stft)
                probabilities = F.softmax(output, dim=1).cpu().numpy()[0]
                predicted_label = np.argmax(probabilities)

            spec_col = stft.cpu().numpy().squeeze(0).squeeze(0)
            act_col = probabilities.reshape(-1, 1)

            win_size = plotter.win_size
            if spec_col.shape[1] > win_size:
                spec_col = spec_col[:, -win_size:]
            elif spec_col.shape[1] < win_size:
                padding = win_size - spec_col.shape[1]
                spec_col = np.pad(spec_col, ((0,0),(padding,0)), mode='constant')

            if len(signal) > plotter.signal_length:
                signal_plot = signal[-plotter.signal_length:]
            else:
                padding = plotter.signal_length - len(signal)
                signal_plot = np.pad(signal, (padding, 0), 'constant')

            plotter.update(signal_plot, spec_col, act_col, predicted_label)

            time.sleep(0.5)

    except KeyboardInterrupt:
        print("Finalizando la visualización en tiempo real.")

In [None]:
if __name__ == "__main__":
    modelo_guardado = "stft_classifier.pth"
    if not os.path.exists(modelo_guardado):
        print("Entrenando el modelo...")
        train_model()
    else:
        print("Modelo ya entrenado. Iniciando la visualización en tiempo real.")

    real_time_plotting(model_path=modelo_guardado)

Modelo ya entrenado. Iniciando la visualización en tiempo real.


  model.load_state_dict(torch.load(model_path))
