In [1]:
import argparse
import torch
import torch.nn as nn
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
from torch.optim import Adam
from torch.utils.data import DataLoader, random_split
from audio_encoders_pytorch.audio_encoders_pytorch import AutoEncoder1d, STFT, Bottleneck, VariationalBottleneck
import utils.load_datasets
import os
from tqdm import tqdm
from accelerate import Accelerator
import random
import optuna
from torch.utils.data import random_split
import numpy as np
from torch import Tensor
from einops import rearrange
class STFTAutoEncoder1d(AutoEncoder1d):
    def __init__(
        self,
        in_channels: int,
        channels: int,
        multipliers: Sequence[int],
        factors: Sequence[int],
        num_blocks: Sequence[int],
        resnet_groups: int = 8,
        stft_num_fft: int = 1024,
        stft_hop_length: int = 256,
        stft_win_length: Optional[int] = None,
        bottleneck: Union[Bottleneck, List[Bottleneck]] = [],
        bottleneck_channels: Optional[int] = None,
    ):
        self.frequency_channels = stft_num_fft // 2 + 1
        super().__init__(
            in_channels=in_channels * self.frequency_channels * 2,  # Real and imaginary parts
            channels=channels,
            multipliers=multipliers,
            factors=factors,
            num_blocks=num_blocks,
            resnet_groups=resnet_groups,
            bottleneck=bottleneck,
            bottleneck_channels=bottleneck_channels,
        )
        self.stft = STFT(
            num_fft=stft_num_fft,
            hop_length=stft_hop_length,
            window_length=stft_win_length,
            use_complex=False,
        )

    def encode(self, x: Tensor, with_info: bool = False) -> Union[Tensor, Tuple[Tensor, Any]]:
        stft_real, stft_imag = self.stft.encode(x)
        stft_combined = torch.cat([stft_real, stft_imag], dim=1)
        stft_flat = rearrange(stft_combined, "b c f l -> b (c f) l")
        return super().encode(stft_flat, with_info=with_info)

    def decode(self, z: Tensor, with_info: bool = False) -> Union[Tensor, Tuple[Tensor, Any]]:
        stft_flat, info = super().decode(z, with_info=True)
        stft_combined = rearrange(stft_flat, "b (c f) l -> b c f l", f=self.frequency_channels)
        stft_real, stft_imag = stft_combined.chunk(2, dim=1)
        waveform = self.stft.decode(stft_real, stft_imag)
        return (waveform, info) if with_info else waveform

    def forward(self, x: Tensor, with_info: bool = False) -> Union[Tensor, Tuple[Tensor, Any]]:
        z, info_encoder = self.encode(x, with_info=True)
        y, info_decoder = self.decode(z, with_info=True)
        info = {
            **dict(latent=z),
            **prefix_dict("encoder_", info_encoder),
            **prefix_dict("decoder_", info_decoder),
        }
        return (y, info) if with_info else y

    def loss(self, x: Tensor, with_info: bool = False) -> Union[Tensor, Tuple[Tensor, Dict]]:
        y, info = self(x, with_info=True)
        loss = F.mse_loss(x, y)
        return (loss, info) if with_info else loss
    
stft_autoencoder = STFTAutoEncoder1d(
    in_channels=2,  # Assuming mono audio
    channels=64,
    multipliers=[1, 2, 4, 8],
    factors=[4, 4, 2],
    num_blocks=[2, 2, 2],
    stft_num_fft=1024,
    stft_hop_length=256,
    bottleneck=[VariationalBottleneck(channels=512)],  # Optional: add a variational bottleneck
)



In [3]:
def setup_training(config, model):
    optimizer = Adam(model.parameters(), lr=config['learning_rate'], betas=tuple(config['adam_betas']))
    criterion = nn.MSELoss()
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, config['gamma'])
    return optimizer, criterion, scheduler



def setup_dataloader(batch_size, num_workers, val_split=0.2):
    dataset = utils.load_datasets.DeepSig2018Dataset(
        "/ext/trey/experiment_diffusion/experiment_rfdiffusion/dataset/GOLD_XYZ_OSC.0001_1024.hdf5")
    val_size = int(len(dataset) * val_split)
    train_size = len(dataset) - val_size
    train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, pin_memory=True,
                              num_workers=num_workers)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, pin_memory=True, num_workers=num_workers)

    return train_loader, val_loader

def evaluate_model(model, data_loader, accelerator):
    model.eval()
    total_loss = 0.0
    num_samples = 0

    with torch.no_grad():
        for x, _ in data_loader:
            x = x.to(accelerator.device)
            y = model.encode(x)
            y = model.decode(y)
            loss = torch.nn.functional.mse_loss(y, x)
            total_loss += loss.item() * x.size(0)
            num_samples += x.size(0)

    avg_loss = total_loss / num_samples
    return avg_loss


def setup_accelerator(config):
    accelerator = Accelerator(log_with="wandb")
    run_name = str(random.randint(0, 10e5))
    accelerator.init_trackers(
        config['project_name'],
        config=config,
        init_kwargs={"wandb": {"name": run_name}}
    )
    return accelerator, run_name

def train_model(model, optimizer, criterion, scheduler, train_loader, val_loader, accelerator, config):
    model, optimizer, train_loader, val_loader, scheduler = accelerator.prepare(
        model, optimizer, train_loader, val_loader, scheduler
    )
    num_training_steps = config['epochs'] * len(train_loader)
    print(num_training_steps)
    progress_bar = tqdm(range(num_training_steps), disable=not accelerator.is_local_main_process)

    model.train()
    step = 1

    for epoch in range(config['epochs']):
        for x, _ in train_loader:
            y = model(x)
            loss = criterion(y, x)

            accelerator.backward(loss)
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

            progress_bar.update(1)
            accelerator.log({"training_loss": loss, "learning_rate": scheduler.get_last_lr()[0]}, step=step)
            step += 1

        if epoch % config['save_every'] == 0 and accelerator.is_main_process:
            save_checkpoint(accelerator.unwrap_model(model), optimizer, epoch, config['model_save_dir'], f'model_epoch_{epoch}.pth')
            validation_loss = evaluate_model(model, val_loader, accelerator)
            wandb.log({"validation_loss": validation_loss})

    accelerator.end_training()
    
def save_checkpoint(model, optimizer, epoch, save_dir, filename):
    checkpoint_path = os.path.join(save_dir, filename)
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict()
    }, checkpoint_path)
    
def main():
    config = {
        'learning_rate':1e4,
        'epochs':10,
        'project_name':'autoencoder_feature',
        'adam_betas':(0.9,0.999),
        'gamma':0.9,
        'base_save_dir': "/home/trey/experiment_rfdiffusion/models/rfdiffusion_diffusion",
        'num_workers':8
        
    }

    # Construct model_save_dir
    config['model_save_dir'] = os.path.join(config['base_save_dir'], config['project_name'])
    os.makedirs(config['model_save_dir'], exist_ok=True)

    accelerator, run_name = setup_accelerator(config)

    model = STFTAutoEncoder1d(
    in_channels=2,  # Assuming mono audio
    channels=64,
    multipliers=[1, 2, 4, 8],
    factors=[4, 4, 2],
    num_blocks=[2, 2, 2],
    stft_num_fft=1024,
    stft_hop_length=256,
    bottleneck=[VariationalBottleneck(channels=512)],  # Optional: add a variational bottleneck
    )
    optimizer, criterion, scheduler = setup_training(config, model)
    train_loader, val_loader = setup_dataloader(256,8)


    print(f"Training on {accelerator.num_processes} GPUs")
    print(f"Number of parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
    print(f"Models will be saved in: {config['model_save_dir']}")

    train_model(model, optimizer, criterion, scheduler, train_loader, val_loader, accelerator, config)

    if accelerator.is_main_process:
        final_checkpoint_path = os.path.join(config['model_save_dir'], f'model_{run_name}.pth')
        save_checkpoint(accelerator.unwrap_model(model), optimizer, config['epochs'], config['model_save_dir'],
                        final_checkpoint_path)

    print("Training complete and models saved.")


if __name__ == "__main__":
    main()

Widget Javascript not detected.  It may not be installed or enabled properly. Reconnecting the current kernel may help.


Training on 1 GPUs
Number of parameters: 24391500
Models will be saved in: /home/trey/experiment_rfdiffusion/models/rfdiffusion_diffusion/autoencoder_feature
79880



  0%|                                                                                                                                                                                                                                                                                               | 0/79880 [00:00<?, ?it/s][A

RuntimeError: istft requires a complex-valued input tensor matching the output from stft with return_complex=True.