In [42]:
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import pytorch_lightning as pl
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import TensorBoardLogger
import matplotlib.pyplot as plt

In [43]:
def qpsk_augmentation(x, noise_std=0.05, rotation_prob=0.5):
    """Return a randomly perturbed copy of an I/Q signal.

    Adds i.i.d. Gaussian noise with standard deviation ``noise_std`` and, with
    probability ``rotation_prob``, rotates the (I, Q) pair by a uniformly
    random phase in [0, 2*pi).

    Args:
        x: Float tensor whose second-to-last axis holds the (I, Q) channels,
            e.g. ``[2, T]`` for one signal or ``[B, 2, T]`` for a batch (the
            2x2 rotation broadcasts over leading axes). ``x`` is not modified.
        noise_std: Standard deviation of the additive white Gaussian noise.
        rotation_prob: Probability in [0, 1] of applying the phase rotation.

    Returns:
        A new tensor with the same shape, dtype and device as ``x``.
    """
    x = x.clone()
    x += torch.randn_like(x) * noise_std  # white noise
    if torch.rand(1).item() < rotation_prob:
        phase_shift = 2 * np.pi * torch.rand(1).item()
        cos_p, sin_p = np.cos(phase_shift), np.sin(phase_shift)
        # Build the rotation in x's own dtype: a hard-coded float32 matrix makes
        # ``@`` raise a dtype mismatch for e.g. float64 inputs.
        rotation = torch.tensor(
            [[cos_p, -sin_p],
             [sin_p, cos_p]],
            dtype=x.dtype, device=x.device,
        )
        x = rotation @ x
    return x

In [44]:
class QPSKDataset(Dataset):
    """Synthetic dataset of baseband QPSK I/Q signals returned as two views.

    Items are generated on the fly from fresh random bits (``np.random``), so
    indexing the same ``idx`` twice yields different signals.

    Args:
        num_samples: Number of items reported by ``__len__``.
        time: Signal duration (same time unit as ``1 / sample_rate``).
        symbol_rate: Symbols per unit time.
        sample_rate: Samples per unit time.
        carrying_freq: Carrier frequency. Stored on the instance, but the
            generated samples are baseband I/Q and do not use it.
        augment_fn: Callable applied independently to each of the two views
            returned by ``__getitem__``; ``None`` returns the clean signal.
    """

    def __init__(self, num_samples, time, symbol_rate, sample_rate, carrying_freq, augment_fn=None):
        self.num_samples = num_samples
        self.time = time
        self.symbol_rate = symbol_rate
        self.sample_rate = sample_rate
        self.carrying_freq = carrying_freq
        self.augment_fn = augment_fn

    def __len__(self):
        return self.num_samples

    def generate_qpsk(self):
        """Generate one random rectangular-pulse QPSK baseband signal.

        Bits are Gray-mapped pairwise ``(b0, b1) -> (I, Q)``:
        (0,0)->(+1,+1), (0,1)->(-1,+1), (1,1)->(-1,-1), (1,0)->(+1,-1),
        i.e. ``I = 1 - 2*b1`` and ``Q = 1 - 2*b0``. Each symbol is held for
        ``round(sample_rate / symbol_rate)`` samples.

        Returns:
            float32 array of shape ``[2, T]`` (rows I and Q) with
            ``T = round(time * sample_rate)`` and values in {-1, +1}.

        Raises:
            ValueError: if ``sample_rate / symbol_rate`` rounds to less than 1.
        """
        # round() rather than int(): products like 0.29 * 100 are 28.999... in
        # floating point and would otherwise lose a sample.
        num_samples = int(round(self.time * self.sample_rate))
        samples_per_symbol = int(round(self.sample_rate / self.symbol_rate))
        if samples_per_symbol < 1:
            raise ValueError(
                f"sample_rate ({self.sample_rate}) must be at least symbol_rate ({self.symbol_rate})"
            )

        # Ceil-divide so the symbols cover the whole window; the excess is
        # trimmed below. This keeps T fixed even when sample_rate/symbol_rate
        # is not an integer (that case used to crash on a length mismatch).
        num_symbols = -(-num_samples // samples_per_symbol)
        bits = np.random.randint(0, 2, size=num_symbols * 2).reshape(-1, 2)

        # Vectorised Gray mapping (see docstring).
        i_vals = 1 - 2 * bits[:, 1]
        q_vals = 1 - 2 * bits[:, 0]

        i_samples = np.repeat(i_vals, samples_per_symbol)[:num_samples]
        q_samples = np.repeat(q_vals, samples_per_symbol)[:num_samples]
        return np.stack([i_samples, q_samples], axis=0).astype(np.float32)  # [2, T]

    def __getitem__(self, idx):
        """Return two independently augmented views of one fresh signal (``idx`` is unused)."""
        clean = torch.from_numpy(self.generate_qpsk())
        # Separate copies so an in-place augment_fn cannot leak into the other view.
        view_1, view_2 = clean.clone(), clean.clone()
        if self.augment_fn is not None:
            view_1, view_2 = self.augment_fn(view_1), self.augment_fn(view_2)
        return view_1, view_2


class QPSKDataModule(pl.LightningDataModule):
    """LightningDataModule serving on-the-fly QPSK view pairs for training and validation.

    Args:
        batch_size: Batch size of both loaders.
        num_workers: DataLoader worker processes (0 = load in the main process).
        train_size: Items per training epoch (default 10000).
        val_size: Items per validation run (default 1000).
        **signal_params: Forwarded to ``QPSKDataset`` (``time``,
            ``symbol_rate``, ``sample_rate``, ``carrying_freq``).

    NOTE(review): ``QPSKDataset`` draws new random signals on every access,
    so the validation data differs between validation runs.
    """

    def __init__(self, batch_size=32, num_workers=4, train_size=10000, val_size=1000, **signal_params):
        super().__init__()
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.train_size = train_size
        self.val_size = val_size
        self.signal_params = signal_params

    def setup(self, stage=None):
        """Create the training and validation datasets."""
        self.dataset = QPSKDataset(
            num_samples=self.train_size,
            augment_fn=qpsk_augmentation,
            **self.signal_params
        )
        self.val_dataset = QPSKDataset(
            num_samples=self.val_size,
            augment_fn=qpsk_augmentation,
            **self.signal_params
        )

    def _loader(self, dataset, shuffle, drop_last):
        """Build a DataLoader with this module's shared batch/worker settings."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
            drop_last=drop_last,
            # Reuse worker processes across epochs instead of re-spawning them;
            # DataLoader only accepts this when num_workers > 0.
            persistent_workers=self.num_workers > 0,
        )

    def train_dataloader(self):
        # drop_last keeps every training batch full-sized.
        return self._loader(self.dataset, shuffle=True, drop_last=True)

    def val_dataloader(self):
        return self._loader(self.val_dataset, shuffle=False, drop_last=False)


In [45]:
class Encoder(nn.Module):
    """1-D convolutional encoder producing a diagonal-Gaussian latent.

    Args:
        input_channels: Channels of the input signal (2 for I/Q).
        hidden_dim: Width of the convolutional feature maps.
        latent_dim: Size of the returned mean and log-variance vectors.

    ``forward`` expects ``[B, input_channels, T]`` and returns
    ``(mu, logvar)``, each ``[B, latent_dim]``.
    """

    def __init__(self, input_channels=2, hidden_dim=64, latent_dim=2):
        super().__init__()
        # Layer order fixes the ``net.<i>`` state_dict keys; keep it stable so
        # saved checkpoints still load.
        feature_layers = [
            nn.Conv1d(input_channels, hidden_dim, kernel_size=5, padding=2),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Conv1d(hidden_dim, hidden_dim, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool1d(1),  # average over time -> [B, hidden_dim, 1]
        ]
        self.net = nn.Sequential(*feature_layers)
        self.mu_head = nn.Linear(hidden_dim, latent_dim)
        self.logvar_head = nn.Linear(hidden_dim, latent_dim)

    def forward(self, x):
        pooled = self.net(x)
        features = pooled.squeeze(-1)  # drop the length-1 time axis
        return self.mu_head(features), self.logvar_head(features)


def reparameterize(mu, logvar):
    """Sample ``z ~ N(mu, exp(logvar))`` via the reparameterization trick.

    Computed as ``mu + eps * sigma`` with ``eps ~ N(0, I)``, so gradients
    flow back to both ``mu`` and ``logvar``.
    """
    sigma = (0.5 * logvar).exp()
    noise = torch.randn_like(sigma)
    return mu + noise * sigma


class Denoiser(nn.Module):
    """One-hidden-layer MLP whose output is projected to unit L2 norm.

    Args:
        latent_dim: Size of the input and output vectors.
        hidden_dim: Width of the hidden layer (default 64).

    ``forward`` maps ``[..., latent_dim]`` to ``[..., latent_dim]`` with unit
    norm along the last axis, up to the 1e-8 epsilon (for ``latent_dim=2``
    this is the unit circle).
    """

    def __init__(self, latent_dim=2, hidden_dim=64):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, latent_dim),
        )

    def forward(self, z):
        out = self.mlp(z)
        # The epsilon avoids division by zero if the MLP outputs an all-zero vector.
        norm = torch.norm(out, dim=-1, keepdim=True) + 1e-8
        return out / norm  # Ensure output is on unit circle


class QPSKDenoiserVSSL(pl.LightningModule):
    """Variational self-supervised QPSK denoiser with an EMA ("mean") teacher.

    Each training pair is two augmented views ``(x1, x2)`` of one signal. The
    teacher ``Encoder`` maps ``x1`` to a Gaussian that acts as the prior for
    the student ``Encoder``'s posterior on ``x2``. A ``Denoiser`` maps the
    student's latent onto the unit circle. The teacher gets no gradient
    updates: its parameters track the student's through an exponential
    moving average (EMA) applied once per training step.

    Args:
        ema_decay: EMA factor for the teacher (closer to 1 = slower tracking).
        lr: Adam learning rate.
    """

    def __init__(self, ema_decay=0.99, lr=1e-3):
        super().__init__()
        self.save_hyperparameters()
        self.student = Encoder()
        self.teacher = Encoder()
        self.denoiser = Denoiser()
        self.ema_decay = ema_decay
        self.lr = lr
        # Step counter kept as a buffer so it is saved with checkpoints.
        self.register_buffer("global_step_float", torch.tensor(0.0))

        # Initialize teacher with student weights (decay=0 copies them), then
        # freeze it: it is only ever updated by the EMA, never by the optimizer.
        self._update_teacher(0.0)
        self.teacher.requires_grad_(False)

    @torch.no_grad()
    def _update_teacher(self, decay=None):
        """EMA step on teacher parameters: ``t <- decay * t + (1 - decay) * s``.

        Args:
            decay: Override for ``self.ema_decay``; ``0.0`` copies the student.
        """
        decay = self.ema_decay if decay is None else decay
        for t_param, s_param in zip(self.teacher.parameters(), self.student.parameters()):
            # In-place update under no_grad: no per-step reallocation of .data.
            t_param.mul_(decay).add_(s_param, alpha=1.0 - decay)

    def forward(self, x):
        """Denoise ``x``; samples the latent in training mode, uses its mean otherwise."""
        mu, logvar = self.student(x)
        z = reparameterize(mu, logvar) if self.training else mu
        return self.denoiser(z)

    def training_step(self, batch, batch_idx):
        x1, x2 = batch  # each shape: [B, 2, T]

        # Teacher posterior on view 1 is the prior for the student (no grad).
        with torch.no_grad():
            p_mu, p_logvar = self.teacher(x1)

        # Student posterior on view 2, sampled with the reparameterization trick.
        mu2, logvar2 = self.student(x2)
        z = reparameterize(mu2, logvar2)

        # Denoise
        z_hat = self.denoiser(z)

        # Closed-form KL( N(mu2, exp(logvar2)) || N(p_mu, exp(p_logvar)) ),
        # summed over latent dims and averaged over the batch.
        kl = 0.5 * torch.sum(
            torch.exp(logvar2 - p_logvar)
            + (mu2 - p_mu) ** 2 / torch.exp(p_logvar)
            - 1 + p_logvar - logvar2,
            dim=1
        ).mean()

        # Likelihood: denoised z_hat should match a unit-norm teacher sample.
        with torch.no_grad():
            z_target = reparameterize(p_mu, p_logvar)
            z_target = z_target / (torch.norm(z_target, dim=-1, keepdim=True) + 1e-8)

        recon_loss = F.mse_loss(z_hat, z_target)

        phase_smoothness = self._phase_smoothness(z_hat)
        constellation_error = self._calculate_constellation_error(z_hat)

        loss = kl + recon_loss + 0.1 * phase_smoothness + 0.1 * constellation_error

        self.log_dict({
            "train/loss": loss,
            "train/kl": kl,
            "train/recon": recon_loss,
            "train/phase_smoothness": phase_smoothness,
            "train/constellation_error": constellation_error
        }, prog_bar=True)

        # EMA update
        self._update_teacher()
        self.global_step_float += 1.0

        return loss

    @staticmethod
    def _phase_smoothness(z_hat):
        """Mean squared wrapped phase difference between consecutive rows of ``z_hat``.

        NOTE(review): consecutive rows are independent (shuffled) samples, not
        time steps, so this term tends to pull all predictions toward one
        common phase. Kept for loss compatibility; reconsider its definition.
        """
        if z_hat.shape[0] < 2:
            # diff of a single row is empty and its mean would be NaN.
            return z_hat.new_zeros(())
        phase = torch.atan2(z_hat[:, 1], z_hat[:, 0])  # shape: [B]
        phase_diff = torch.diff(phase)
        # Wrap to (-pi, pi] so a step across the +/-pi branch cut counts as
        # small instead of ~2*pi.
        phase_diff = torch.atan2(torch.sin(phase_diff), torch.cos(phase_diff))
        return (phase_diff ** 2).mean()

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        # Anneal over the real training length when the Trainer knows it; a
        # fixed T_max longer than max_epochs stops halfway down the cosine.
        t_max = 200
        try:
            max_epochs = self.trainer.max_epochs
        except (RuntimeError, AttributeError):  # no Trainer attached
            max_epochs = None
        if max_epochs is not None and max_epochs > 0:
            t_max = max_epochs
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=t_max)
        return [optimizer], [scheduler]

    def _calculate_constellation_error(self, z_hat):
        """Mean distance from each normalized point to its nearest ideal QPSK point."""
        z_norm = z_hat / (torch.norm(z_hat, dim=-1, keepdim=True) + 1e-8)

        # Ideal points on the unit circle at +/-45 and +/-135 degrees.
        ideal_points = torch.tensor([[1, 1], [-1, 1], [-1, -1], [1, -1]],
                                    dtype=z_hat.dtype, device=z_hat.device) * (np.sqrt(2)/2)

        distances = torch.cdist(z_norm, ideal_points)
        min_distances = distances.min(dim=1)[0]

        return min_distances.mean()

    def validation_step(self, batch, batch_idx):
        x1, x2 = batch

        with torch.no_grad():
            # Clean reference is the teacher's posterior mean. Do not call
            # reparameterize(mu, zeros): logvar=0 means std=1 and would add
            # unit-variance noise to the reference.
            clean_mu, _ = self.teacher(x1)
            clean_z = clean_mu / (torch.norm(clean_mu, dim=-1, keepdim=True) + 1e-8)

            # Deterministic student estimate: denoise the posterior mean.
            mu, _ = self.student(x2)
            z_hat = self.denoiser(mu)
            z_hat = z_hat / (torch.norm(z_hat, dim=-1, keepdim=True) + 1e-8)

            mse = F.mse_loss(z_hat, clean_z)
            phase_error = self._phase_error(z_hat, clean_z)
            constellation_error = self._calculate_constellation_error(z_hat)

            self.log_dict({
                "val/mse": mse,
                "val/phase_error": phase_error,
                "val/constellation_error": constellation_error
            }, prog_bar=True)

            if batch_idx == 0:
                self._log_constellation(z_hat[:16], clean_z[:16])

    def _phase_error(self, pred, target):
        """Mean absolute phase difference between 2-D points, wrapped to [0, pi]."""
        pred_phase = torch.atan2(pred[:,1], pred[:,0])
        target_phase = torch.atan2(target[:,1], target[:,0])
        phase_diff = torch.atan2(torch.sin(pred_phase - target_phase),
                                 torch.cos(pred_phase - target_phase))
        return torch.abs(phase_diff).mean()

    def _log_constellation(self, pred, target):
        """Log a scatter of predicted vs. clean latent points as a figure.

        Skipped when there is no logger or its ``experiment`` has no
        ``add_figure`` (e.g. non-TensorBoard loggers), instead of crashing.
        """
        experiment = getattr(self.logger, "experiment", None)
        if not hasattr(experiment, "add_figure"):
            return

        fig, ax = plt.subplots(figsize=(6,6))

        circle = plt.Circle((0, 0), 1, fill=False, linestyle='--', alpha=0.3)
        ax.add_patch(circle)

        ideal_points = np.array([[1,1], [-1,1], [-1,-1], [1,-1]]) * (np.sqrt(2)/2)
        ax.scatter(ideal_points[:,0], ideal_points[:,1], c='g', marker='x', label='Ideal')

        pred = pred.detach().cpu().numpy()
        target = target.detach().cpu().numpy()
        ax.scatter(target[:,0], target[:,1], c='b', alpha=0.5, label='Clean')
        ax.scatter(pred[:,0], pred[:,1], c='r', alpha=0.5, label='Denoised')

        ax.set_xlim(-1.2, 1.2)
        ax.set_ylim(-1.2, 1.2)
        ax.set_xlabel("I")
        ax.set_ylabel("Q")
        ax.set_aspect("equal")
        ax.grid(True)
        ax.legend()
        ax.set_title("QPSK Constellation")

        experiment.add_figure("constellation", fig, self.global_step)
        plt.close(fig)

    def predict_step(self, batch, batch_idx=None, dataloader_idx=0):
        """Return unit-norm denoised latents for a batch of noisy signals.

        Accepts either a bare tensor or a ``(noisy, ...)`` tuple/list. The
        input is moved to the model's device. Uses the student's posterior
        mean so predictions are deterministic.
        """
        if isinstance(batch, (list, tuple)):
            noisy_signal = batch[0]
        else:
            noisy_signal = batch
        noisy_signal = noisy_signal.to(self.device)

        with torch.no_grad():
            mu, _ = self.student(noisy_signal)
            z_hat = self.denoiser(mu)
            z_hat = z_hat / (torch.norm(z_hat, dim=-1, keepdim=True) + 1e-8)

        return z_hat
    
    

In [47]:
def train_model(max_epochs=100, batch_size=128, num_workers=4, seed=42, signal_params=None):
    """Train a ``QPSKDenoiserVSSL`` on synthetic QPSK data and return it.

    Args:
        max_epochs: Number of training epochs.
        batch_size: Batch size for the data module.
        num_workers: DataLoader worker processes. Set to 0 if worker processes
            cannot import notebook-defined classes on your platform.
        seed: Passed to ``pl.seed_everything`` (including DataLoader workers)
            for reproducible runs; ``None`` skips seeding.
        signal_params: Keyword arguments for ``QPSKDataset``; defaults to
            ``time=2.0, symbol_rate=6, sample_rate=480, carrying_freq=5.0``.

    Returns:
        The trained model. TensorBoard logs go to ``tb_logs/qpsk_vssl``.
    """
    if signal_params is None:
        signal_params = dict(
            time=2.0,
            symbol_rate=6,
            sample_rate=480,
            carrying_freq=5.0
        )
    if seed is not None:
        pl.seed_everything(seed, workers=True)

    datamodule = QPSKDataModule(batch_size=batch_size, num_workers=num_workers, **signal_params)
    model = QPSKDenoiserVSSL(ema_decay=0.995, lr=1e-3)

    logger = TensorBoardLogger(save_dir="tb_logs", name="qpsk_vssl")
    trainer = Trainer(
        max_epochs=max_epochs,
        logger=logger,
        check_val_every_n_epoch=5,
        accelerator='auto',
        devices=1
    )
    trainer.fit(model, datamodule)
    return model

In [48]:
def visualize_results(model, num_examples=3, signal_params=None):
    """Plot noisy I/Q samples next to the model's denoised latent point.

    Args:
        model: Trained model exposing ``eval()`` and
            ``predict_step(noisy_signals)`` returning ``[B, 2]`` points.
        num_examples: Number of signals to plot (at most the 10 generated).
        signal_params: Keyword arguments for ``QPSKDataset``; defaults to
            ``time=2.0, symbol_rate=6, sample_rate=480, carrying_freq=5.0``.
    """
    if signal_params is None:
        signal_params = dict(
            time=2.0,
            symbol_rate=6,
            sample_rate=480,
            carrying_freq=5.0
        )
    test_dataset = QPSKDataset(
        num_samples=10,
        augment_fn=qpsk_augmentation,
        **signal_params
    )

    noisy_signals, _ = next(iter(DataLoader(test_dataset, batch_size=len(test_dataset))))

    model.eval()
    # The prediction may live on the model's device (e.g. GPU); matplotlib needs CPU data.
    denoised = model.predict_step(noisy_signals).detach().cpu()

    for i in range(min(num_examples, len(noisy_signals))):
        fig, (ax_noisy, ax_denoised) = plt.subplots(1, 2, figsize=(12, 4))

        # All T time samples of one noisy signal, plotted as I vs. Q.
        ax_noisy.scatter(noisy_signals[i, 0], noisy_signals[i, 1], alpha=0.5)
        ax_noisy.set_title(f"Noisy Signal {i+1}")
        ax_noisy.set_xlabel("I")
        ax_noisy.set_ylabel("Q")
        ax_noisy.grid(True)

        # A single denoised latent point per signal.
        ax_denoised.scatter(denoised[i, 0], denoised[i, 1], c='r')
        ax_denoised.set_title(f"Denoised Signal {i+1}")
        ax_denoised.set_xlabel("I")
        ax_denoised.set_ylabel("Q")
        ax_denoised.grid(True)
        ax_denoised.set_xlim(-1.1, 1.1)
        ax_denoised.set_ylim(-1.1, 1.1)

        plt.show()
        plt.close(fig)  # release figure memory after display

In [None]:
# Train end to end, then plot a few denoised examples from the trained model.
trained_model = train_model()
visualize_results(model=trained_model)