In [None]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from datetime import datetime
from torch import amp
from torch.cuda.amp import GradScaler

In [2]:
class FileLogger:
    """Minimal plain-text logger that prefixes every message with a timestamp.

    The file at ``path`` is truncated when the logger is constructed and a
    start banner is written; each ``log`` call then appends one line. The file
    is opened and closed on every call, so nothing is held open between calls.
    """

    def __init__(self, path="training.log"):
        """Create (or truncate) the log file.

        Args:
            path: log file path. Parent directories are created if needed.
        """
        self.path = path
        log_dir = os.path.dirname(path)
        # A bare filename (like the default) has dirname "" and os.makedirs("")
        # raises FileNotFoundError, so only create a directory when there is one.
        if log_dir:
            os.makedirs(log_dir, exist_ok=True)
        with open(self.path, 'w', encoding='utf-8') as f:
            f.write(f"== Log started at {datetime.now()} ==\n")

    def log(self, msg):
        """Append ``msg`` as a single ``[YYYY-mm-dd HH:MM:SS] msg`` line."""
        timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        with open(self.path, 'a', encoding='utf-8') as f:
            f.write(f"[{timestamp}] {msg}\n")
            
def init_weights(m):
    """Kaiming-normal weight init and zero bias for conv / transposed-conv / linear layers.

    Meant to be used as ``model.apply(init_weights)``; all other module types
    (e.g. BatchNorm, activations) are left untouched.

    ``nn.ConvTranspose1d`` is included so that applying this to ``Generator1D``
    (which is built only from transposed convolutions and BatchNorm) actually
    initializes it instead of being a silent no-op.
    """
    if isinstance(m, (nn.Conv1d, nn.ConvTranspose1d, nn.Linear)):
        # With the default a=0, the 'leaky_relu' gain equals the ReLU gain
        # sqrt(2); it does not account for the 0.2 slope used elsewhere.
        nn.init.kaiming_normal_(m.weight, nonlinearity='leaky_relu')
        if m.bias is not None:
            nn.init.zeros_(m.bias)

In [None]:
class MultiNpyIQDataset(Dataset):
    """Map-style dataset over the ``<digits>.npy`` files of a single folder.

    Only files ending in ``.npy`` whose stem is made of digits are kept; they
    are ordered by the integer value of that stem (so ``2.npy`` precedes
    ``10.npy``). Each item is the file's array converted to a float32 tensor.
    """

    def __init__(self, folder_path):
        selected = []
        for name in os.listdir(folder_path):
            stem = os.path.splitext(name)[0]
            if name.endswith(".npy") and stem.isdigit():
                selected.append(os.path.join(folder_path, name))

        def numeric_stem(full_path):
            # Integer value of the file name without directory or extension.
            return int(os.path.splitext(os.path.basename(full_path))[0])

        selected.sort(key=numeric_stem)
        self.file_paths = selected

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        # Read-only memory map; astype() makes a writable float32 copy that
        # torch.from_numpy can wrap without sharing the mapped buffer.
        mapped = np.load(self.file_paths[idx], mmap_mode='r')
        as_float = mapped.astype(np.float32)
        return torch.from_numpy(as_float)

In [None]:
class Generator1D(nn.Module):
    """1-D transposed-convolution generator: latent ``(B, z_dim)`` -> ``(B, 2, out_len)``.

    The latent vector is treated as a length-1 sequence and upsampled by four
    ``ConvTranspose1d`` layers to a fixed native length
    (1 -> 4 -> 8 -> 32 -> 16384). A final ``Tanh`` bounds values to [-1, 1],
    and linear interpolation resamples the result to ``out_len``
    (the default 12582912 is 768x the native length).
    """

    def __init__(self, z_dim=512, out_len=12582912):
        super().__init__()
        self.out_len = out_len

        def upsample_block(c_in, c_out, kernel, stride, pad):
            # transposed conv -> batch norm -> in-place ReLU
            return [
                nn.ConvTranspose1d(c_in, c_out, kernel, stride, pad),
                nn.BatchNorm1d(c_out),
                nn.ReLU(True),
            ]

        stages = (
            upsample_block(z_dim, 256, 4, 1, 0)
            + upsample_block(256, 128, 4, 2, 1)
            + upsample_block(128, 64, 4, 4, 0)
            + [nn.ConvTranspose1d(64, 2, 512, 512), nn.Tanh()]
        )
        self.net = nn.Sequential(*stages)

    def forward(self, z):
        # (B, z_dim) -> (B, z_dim, 1): the latent becomes a 1-sample sequence.
        native = self.net(z.unsqueeze(2))
        return F.interpolate(native, size=self.out_len, mode='linear')

class Discriminator1D(nn.Module):
    """Fully-convolutional 1-D discriminator returning a flat vector of logits.

    Expects 2-channel input, presumably ``(B, 2, L)``. The final ``Conv1d``
    yields one raw logit per output position, and ``forward`` flattens
    everything to 1-D, so the result has ``B * L_out`` elements.
    ``in_len`` is accepted for signature compatibility but is not used.
    """

    def __init__(self, in_len=12582912):
        super().__init__()

        def downsample_block(c_in, c_out, kernel, stride, pad, norm=True):
            # conv -> (optional) batch norm -> in-place LeakyReLU(0.2)
            block = [nn.Conv1d(c_in, c_out, kernel, stride, pad)]
            if norm:
                block.append(nn.BatchNorm1d(c_out))
            block.append(nn.LeakyReLU(0.2, inplace=True))
            return block

        self.net = nn.Sequential(
            *downsample_block(2, 64, 16, 4, 6, norm=False),
            *downsample_block(64, 128, 8, 2, 3),
            *downsample_block(128, 256, 8, 2, 3),
            nn.Conv1d(256, 1, 4),
        )

    def forward(self, x):
        logits = self.net(x)
        return logits.view(-1)

In [None]:
def _set_requires_grad(module, flag):
    """Toggle ``requires_grad`` on every parameter of ``module``."""
    for param in module.parameters():
        param.requires_grad_(flag)


def _optimizer_step(opt, scaler, loss):
    """Zero ``opt``'s grads, backprop ``loss`` through ``scaler``, then step and update it."""
    opt.zero_grad()
    scaler.scale(loss).backward()
    scaler.step(opt)
    scaler.update()


def train_gan(
    data_dir,
    epochs=100,
    batch_size=1,
    z_dim=512,
    save_interval=5,
    ckpt_dir="checkpoints",
    label_smooth=0.9,
    warmup_epochs=3,
    log_path="logs/gan_training.log",
    num_workers=4,
    seed=None,
):
    """Train ``Generator1D`` / ``Discriminator1D`` on the ``.npy`` files in ``data_dir``.

    During the first ``warmup_epochs`` epochs only the generator is updated.
    After that, every batch does one discriminator step followed by one
    generator step. Both use ``BCEWithLogitsLoss``. Mixed precision
    (autocast + GradScaler) is enabled only when training on CUDA.

    Args:
        data_dir: folder passed to ``MultiNpyIQDataset``.
        epochs: number of epochs (1-based in logs and checkpoint names).
        batch_size: DataLoader batch size.
        z_dim: latent size given to ``Generator1D``.
        save_interval: write ``G_epoch{N}.pt`` / ``D_epoch{N}.pt`` every this many
            epochs, and always after the final epoch.
        ckpt_dir: checkpoint directory (created if missing).
        label_smooth: target for *real* samples in the discriminator loss
            (one-sided label smoothing). The generator is trained towards 1.0.
        warmup_epochs: number of leading generator-only epochs.
        log_path: file written by ``FileLogger``; epoch lines report epoch-mean losses.
        num_workers: DataLoader worker processes. Set 0 if workers fail to start
            (e.g. spawn-based platforms may not be able to pickle classes
            defined inside a notebook).
        seed: if not None, passed to ``torch.manual_seed`` for reproducibility.

    Raises:
        ValueError: if ``data_dir`` contains no usable ``.npy`` files.
    """
    # Only honored if set before the CUDA caching allocator initializes;
    # setdefault keeps any value the user already exported.
    os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
    if seed is not None:
        torch.manual_seed(seed)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # autocast/GradScaler used to be hard-wired to 'cuda', which on a CPU run
    # only produced warnings; enable them explicitly only for CUDA.
    use_amp = device.type == "cuda"
    logger = FileLogger(log_path)
    os.makedirs(ckpt_dir, exist_ok=True)

    dataset = MultiNpyIQDataset(data_dir)
    if len(dataset) == 0:
        # Otherwise no batch ever runs and the epoch log hits undefined losses.
        raise ValueError(f"no '<integer>.npy' files found in {data_dir!r}")
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True,
                        num_workers=num_workers, pin_memory=(device.type == "cuda"))

    G = Generator1D(z_dim=z_dim).to(device)
    D = Discriminator1D().to(device)

    G.apply(init_weights)
    D.apply(init_weights)

    g_opt = torch.optim.Adam(G.parameters(), lr=2e-4, betas=(0.5, 0.9))
    d_opt = torch.optim.Adam(D.parameters(), lr=1e-5, betas=(0.5, 0.9))
    loss_fn = nn.BCEWithLogitsLoss()

    scaler_g = GradScaler(enabled=use_amp)
    scaler_d = GradScaler(enabled=use_amp)

    def _autocast():
        # Device-aware mixed-precision context (a no-op when use_amp is False).
        return amp.autocast(device_type=device.type, enabled=use_amp)

    best_g_loss = float('inf')

    for epoch in range(1, epochs + 1):
        in_warmup = epoch <= warmup_epochs
        # Per-epoch running sums kept on-device to avoid a host sync per batch.
        d_sum = torch.zeros((), device=device)
        g_sum = torch.zeros((), device=device)
        n_batches = 0

        for real in loader:
            real = real.to(device, non_blocking=True)
            B = real.size(0)

            if in_warmup:
                d_loss = torch.zeros((), device=device)
            else:
                # Discriminator step on a fake batch generated without grad tracking.
                z = torch.randn(B, z_dim, device=device)
                with torch.no_grad(), _autocast():
                    fake = G(z)
                with _autocast():
                    d_real = D(real)
                    d_fake = D(fake)
                    real_tgt = torch.full_like(d_real, label_smooth)
                    fake_tgt = torch.zeros_like(d_fake)
                    d_loss = 0.5 * (loss_fn(d_real, real_tgt) + loss_fn(d_fake, fake_tgt))
                _optimizer_step(d_opt, scaler_d, d_loss)

            # Generator step. D is frozen so backward skips D's weight gradients,
            # which the next d_opt.zero_grad() would discard anyway.
            z = torch.randn(B, z_dim, device=device)
            _set_requires_grad(D, False)
            try:
                with _autocast():
                    pred_fake = D(G(z))
                    # One-sided label smoothing: only D's real targets are smoothed.
                    g_loss = loss_fn(pred_fake, torch.ones_like(pred_fake))
                _optimizer_step(g_opt, scaler_g, g_loss)
            finally:
                _set_requires_grad(D, True)

            d_sum += d_loss.detach().float()
            g_sum += g_loss.detach().float()
            n_batches += 1

        # Epoch means rather than the last (noisy) batch; n_batches >= 1 because
        # the dataset is non-empty and the loader does not drop the last batch.
        d_mean = (d_sum / n_batches).item()
        g_mean = (g_sum / n_batches).item()
        logger.log(f"Epoch {epoch}/{epochs} | D_loss={d_mean:.4f} | G_loss={g_mean:.4f}")

        if epoch % save_interval == 0 or epoch == epochs:
            torch.save(G.state_dict(), os.path.join(ckpt_dir, f"G_epoch{epoch}.pt"))
            torch.save(D.state_dict(), os.path.join(ckpt_dir, f"D_epoch{epoch}.pt"))

        # NOTE(review): G loss is a weak proxy for sample quality in GANs; "best"
        # here only means the lowest epoch-mean G loss after warm-up.
        if not in_warmup and g_mean < best_g_loss:
            best_g_loss = g_mean
            torch.save(G.state_dict(), os.path.join(ckpt_dir, "G_best.pt"))
            torch.save(D.state_dict(), os.path.join(ckpt_dir, "D_best.pt"))
            logger.log(f"[BEST] Updated G_loss={best_g_loss:.4f} at epoch {epoch}")


In [6]:
# Launch training with this notebook's settings; label_smooth and
# warmup_epochs are not set here and keep their train_gan defaults.
TRAIN_CONFIG = dict(
    data_dir="../data",
    epochs=100,
    batch_size=1,
    z_dim=512,
    save_interval=5,
    ckpt_dir="checkpoints",
    log_path="logs/gan_training.log",
)
train_gan(**TRAIN_CONFIG)

  scaler_g = GradScaler()
  scaler_d = GradScaler()
