In [None]:
# Put the network in evaluation mode before it is traced for export.
model.eval()

# One example with 2 channels and 1024 samples; only its shape/dtype matter to the exporter.
dummy_input = torch.randn(1, 2, 1024, device=device)
onnx_path = "../models/amc_model.onnx"

# Axis 0 of both the input and the output is declared dynamic so the exported
# graph accepts any batch size, not just the batch of one used for tracing.
torch.onnx.export(
    model,
    dummy_input,
    onnx_path,
    input_names=['input'],
    output_names=['output'],
    dynamic_axes={
        'input': {0: 'batch_size'},
        'output': {0: 'batch_size'},
    },
)

print(f"Model exported to {onnx_path}")

## 5. Export to ONNX
Export the trained PyTorch model to ONNX format for use in the Inference UI.

In [None]:
# Objective and optimiser: cross-entropy on the class logits, Adam with a fixed learning rate.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

num_epochs = 5

for epoch in range(num_epochs):
    model.train()
    # Per-epoch accumulators, reset at the start of every pass over the training data.
    running_loss, correct, total = 0.0, 0, 0

    for inputs, labels in train_loader:
        inputs = inputs.to(device)
        # Targets are moved to the device and cast to int64 class indices for the loss.
        labels = labels.to(device).long()

        # One optimisation step: clear old gradients, forward, backward, update.
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Bookkeeping for the epoch summary printed below.
        running_loss += loss.item()
        predicted = outputs.data.max(1)[1]  # index of the largest logit per sample
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    # Loss is averaged over batches; accuracy is over all samples seen this epoch.
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(train_loader):.4f}, Accuracy: {100 * correct / total:.2f}%")

## 4. Training Loop
Train the model for a few epochs.

In [None]:
class SimpleCNN1D(nn.Module):
    """Small two-stage 1D CNN mapping 2-channel sequences to class logits.

    Layer order: Conv1d(2->64) -> ReLU -> MaxPool(2) -> Conv1d(64->128) -> ReLU
    -> MaxPool(2) -> flatten -> Linear(->256) -> ReLU -> Linear(->num_classes).

    Args:
        num_classes: Number of output logits (one per class).
        input_length: Number of samples in each input sequence. Defaults to
            1024. It fixes the input width of the first fully connected layer,
            so ``forward`` rejects sequences of any other pooled length.

    Raises:
        ValueError: If ``input_length`` is too short to leave at least one
            sample after the two pooling stages (i.e. fewer than 4 samples).
    """

    def __init__(self, num_classes, input_length=1024):
        super().__init__()
        self.conv1 = nn.Conv1d(2, 64, kernel_size=3, padding=1)
        self.relu = nn.ReLU()
        self.pool = nn.MaxPool1d(2)
        self.conv2 = nn.Conv1d(64, 128, kernel_size=3, padding=1)
        # The convolutions keep the length (kernel 3, padding 1); each pooling
        # stage halves it with floor division, e.g. 1024 -> 512 -> 256.
        pooled_length = (input_length // 2) // 2
        if pooled_length < 1:
            raise ValueError(
                f"input_length must be at least 4 to survive two pooling stages, got {input_length}"
            )
        self.flatten_size = 128 * pooled_length
        self.fc1 = nn.Linear(self.flatten_size, 256)
        self.fc2 = nn.Linear(256, num_classes)

    def forward(self, x):
        """Compute class logits for a batch of sequences.

        Args:
            x: Tensor laid out as (batch, 2, input_length). An unbatched
                (2, input_length) tensor is also accepted and treated as a
                batch of one.

        Returns:
            Tensor of raw logits with shape (batch, num_classes).

        Raises:
            RuntimeError: If the sequence length does not match the
                ``input_length`` the model was built for (raised by ``fc1``).
        """
        if x.dim() == 2:
            # Unbatched input: add the batch axis so the result is (1, num_classes).
            x = x.unsqueeze(0)
        x = self.pool(self.relu(self.conv1(x)))
        x = self.pool(self.relu(self.conv2(x)))
        # Flatten per sample and keep the batch axis fixed. Reshaping with
        # view(-1, flatten_size) would silently fold a wrong sequence length
        # into extra batch rows instead of failing.
        x = torch.flatten(x, start_dim=1)
        x = self.relu(self.fc1(x))
        return self.fc2(x)

# Build the classifier with one output per encoded label, then move its
# parameters to the selected device before printing the layer summary.
model = SimpleCNN1D(num_classes=start_classes)
model = model.to(device)
print(model)

## 3. Define Model
A simple 1D CNN architecture for modulation classification.

In [None]:
# Swap the last two axes so samples are laid out as (N, Channels, Length),
# and store them as float32.
X_transposed = np.transpose(X, (0, 2, 1)).astype(np.float32)

# 80/20 split, stratified on the encoded labels, with a fixed seed so the
# partition is the same on every run.
X_train, X_test, y_train, y_test = train_test_split(
    X_transposed,
    y_encoded,
    test_size=0.2,
    random_state=42,
    stratify=y_encoded,
)

# Wrap the NumPy arrays as tensors and build the batch iterators.
batch_size = 64

train_data = TensorDataset(torch.from_numpy(X_train), torch.from_numpy(y_train))
test_data = TensorDataset(torch.from_numpy(X_test), torch.from_numpy(y_test))

# Only the training batches are reshuffled; the test order stays fixed.
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False)

print(f"Train batches: {len(train_loader)}, Test batches: {len(test_loader)}")

## 2. Preprocessing & Splitting
Split the data into training and held-out test sets (80/20, stratified by class). PyTorch 1D convolutions require input in `(N, C, L)` format, but our data is stored as `(N, 1024, 2)`, i.e. `(N, Length, Channels)`.
Common RF CNNs like ResNet1D likewise expect `(N, Channels, Length)`.
So we first transpose `(N, 1024, 2)` -> `(N, 2, 1024)`.

In [None]:
DATA_PATH = '../models/data/rf_dataset.pkl'

# NOTE: unpickling can execute arbitrary code, so only load files you generated yourself.
with open(DATA_PATH, 'rb') as fh:
    dataset = pickle.load(fh)

# The pickle holds a dict with the samples, their labels and the SNR values.
X, y, snr = dataset['data'], dataset['labels'], dataset['snr']

print(f"Loaded dataset: X shape {X.shape}, y shape {y.shape}")

# Encode labels as consecutive integer class indices.
le = LabelEncoder()
y_encoded = le.fit_transform(y)
# Every fitted class occurs in y, so this equals the number of distinct encoded values.
start_classes = len(le.classes_)
print(f"Classes: {le.classes_}")

## 1. Load Dataset
Load the pickle file generated from `dataset_generation.ipynb`.

In [None]:
import pickle
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
import matplotlib.pyplot as plt

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")