# Train the conditional model (DiffWave conditioned on spectrograms).

# In [1]:
import torch
import numpy as np
from tqdm import tqdm

import os 
import yaml
import logging
import shutil

import dataset
from model import DiffWave
from params import params

from torch.utils.data.dataloader import DataLoader
import torch.nn as nn


def read_yaml(yaml_path):
    """Load a YAML file and return its parsed contents.

    The file is decoded explicitly as UTF-8 rather than with the platform's
    locale encoding, so a UTF-8 config containing non-ASCII text (e.g.
    Chinese paths or comments) loads the same way on every OS.

    Args:
        yaml_path: path to the YAML file.

    Returns:
        Whatever ``yaml.safe_load`` produces for the document (``None`` for
        an empty file).
    """
    with open(yaml_path, 'r', encoding='utf-8') as file:
        return yaml.safe_load(file)

# Compute device: use the GPU whenever PyTorch can see one.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Run configuration and the directories it names.
config_info = read_yaml('config.yaml')

model_dir = config_info['model_dir']
os.makedirs(model_dir, exist_ok=True)  # checkpoints and the log file live here
dataset_dir = config_info['signal_dir']
spectrogram_dir = config_info['spectrogram_dir']

# Append INFO-level training logs to <model_dir>/training.log.
logging.basicConfig(
    filename=os.path.join(model_dir, 'training.log'),
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
)

batch_size = config_info['batch_size']
epoches = config_info['epoches']  # key spelling must match config.yaml

# Diffusion noise schedule. A fixed linear beta schedule (50 steps from
# 1e-4 to 0.05) is used; reading config_info['noise_schedule'] is disabled.
noise_schedule = np.linspace(1e-4, 0.05, 50).tolist()
beta = np.array(noise_schedule)
# Cumulative product of (1 - beta_t) over steps, as a float32 CPU tensor.
noise_level = torch.tensor(np.cumprod(1 - beta).astype(np.float32))

loss_fn = nn.L1Loss()
summary_writer = None  # placeholder; not used in the visible code

# Training data: shuffled every epoch; an incomplete final batch is dropped.
trainDataset = dataset.ctg_dataset(dataset_dir, spectrogram_dir)
train_loader = DataLoader(trainDataset, batch_size=batch_size,
                          shuffle=True, drop_last=True)

# Learning rate from config.yaml.
learning_rate = config_info['learning_rate']
model = DiffWave(params).to(device)

# Build the optimizer named by config.yaml's 'optimizer' key, using the
# configured learning rate. Unknown names fail here with a clear message
# instead of leaving `opt` undefined until training starts.
optimizer_name = config_info['optimizer']
if optimizer_name == 'torch.optim.AdamW':
    opt = torch.optim.AdamW(model.parameters(), lr=learning_rate)
elif optimizer_name == 'torch.optim.Adam':
    opt = torch.optim.Adam(model.parameters(), lr=learning_rate)
else:
    raise ValueError(
        f"Unsupported optimizer in config.yaml: {optimizer_name!r} "
        "(expected 'torch.optim.AdamW' or 'torch.optim.Adam')"
    )
    
def train_unconditional(noise_level, save_every=2):
    """Train DiffWave with spectrogram conditioning (despite the function name).

    Each batch from the module-level ``train_loader`` is unpacked as
    ``(signal, spectrogram)``; the spectrogram is passed to the model as its
    conditioning input.

    Before training, ``config.yaml`` and ``params.py`` are copied into
    ``model_dir`` so checkpoints can be traced back to their settings. For
    every batch a diffusion step ``t`` is sampled per example. The clean
    signal is mixed with Gaussian noise as
    ``sqrt(noise_level[t]) * signal + sqrt(1 - noise_level[t]) * noise``.
    The model is then optimized with ``loss_fn`` to predict that noise.

    Args:
        noise_level: 1-D tensor indexed by diffusion step. It is presumably
            the module-level cumprod(1 - beta) schedule. Its length must be
            at least ``len(noise_schedule)``, since ``t`` is drawn from
            ``[0, len(noise_schedule))``.
        save_every: write ``model.state_dict()`` to ``<model_dir>/<epoch>.pth``
            after every ``save_every``-th epoch (1-based). Defaults to 2.

    Raises:
        ValueError: if ``save_every`` is less than 1.

    Uses module-level globals: ``model``, ``opt``, ``loss_fn``,
    ``train_loader``, ``device``, ``epoches``, ``model_dir`` and
    ``noise_schedule``.
    """
    if save_every < 1:
        raise ValueError(f'save_every must be >= 1, got {save_every}')

    shutil.copy('config.yaml', os.path.join(model_dir, 'config.yaml'))
    shutil.copy('params.py', os.path.join(model_dir, 'params.py'))

    # The schedule is constant during training: move it to the device once
    # instead of on every batch.
    noise_level = noise_level.to(device)
    num_steps = len(noise_schedule)

    for epoch in range(epoches):
        model.train()
        # Per-batch losses as Python floats. Storing floats rather than loss
        # tensors keeps no autograd history alive. It also avoids
        # Tensor.numpy(), which raises for CUDA tensors.
        train_loss = []
        for step, (signal, spectrogram) in enumerate(train_loader):
            signal = signal.to(device)
            spectrogram = spectrogram.to(device)
            opt.zero_grad()

            # Unpacking enforces a 2-D (batch, time) signal, which the
            # unsqueeze(1) broadcasting below relies on.
            N, _ = signal.shape

            # TODO: check why the beta schedule runs from 1e-4 to 0.05 and
            # double-check the forward-diffusion formula below.
            t = torch.randint(0, num_steps, [N], device=signal.device)
            noise_scale = noise_level[t].unsqueeze(1)
            noise = torch.randn_like(signal)
            noisy_signal = noise_scale**0.5 * signal + (1.0 - noise_scale)**0.5 * noise

            # TODO: `spectrogram` is the conditioning input of the original
            # DiffWave model; this needs to be revised later.
            predicted = model(noisy_signal, t, spectrogram=spectrogram)
            loss = loss_fn(noise, predicted.squeeze(1))

            loss_value = loss.item()
            logging.info(f'Batch {step}/{len(train_loader)} loss : {loss_value}')
            train_loss.append(loss_value)

            loss.backward()
            opt.step()

        if (epoch + 1) % save_every == 0:
            savename = os.path.join(model_dir, f'{epoch + 1}.pth')
            torch.save(model.state_dict(), savename)

        # If the loader yielded no batches there are no losses; report NaN
        # explicitly instead of taking the mean of an empty array.
        epoch_loss = float(np.mean(train_loss)) if train_loss else float('nan')
        logging.info(f'Epoch {epoch}/{epoches} loss : {epoch_loss}')

# Launch training for this cell using the module-level noise schedule.
train_unconditional(noise_level=noise_level)
    

    
    

# Train the unconditional model (no spectrogram conditioning).

# In [3]:
def train_unconditional(noise_level, save_every=1):
    """Train DiffWave without conditioning (the model gets ``spectrogram=None``).

    Before training, ``config.yaml`` and ``params.py`` are copied into
    ``model_dir`` so checkpoints can be traced back to their settings. For
    every batch a diffusion step ``t`` is sampled per example. The clean
    signal is mixed with Gaussian noise as
    ``sqrt(noise_level[t]) * signal + sqrt(1 - noise_level[t]) * noise``.
    The model is then optimized with ``loss_fn`` to predict that noise.

    Args:
        noise_level: 1-D tensor indexed by diffusion step. It is presumably
            the module-level cumprod(1 - beta) schedule. Its length must be
            at least ``len(noise_schedule)``, since ``t`` is drawn from
            ``[0, len(noise_schedule))``.
        save_every: write ``model.state_dict()`` to ``<model_dir>/<epoch>.pth``
            after every ``save_every``-th epoch (1-based). Defaults to 1,
            i.e. a checkpoint every epoch.

    Raises:
        ValueError: if ``save_every`` is less than 1.

    Uses module-level globals: ``model``, ``opt``, ``loss_fn``,
    ``train_loader``, ``device``, ``epoches``, ``model_dir`` and
    ``noise_schedule``.
    """
    if save_every < 1:
        raise ValueError(f'save_every must be >= 1, got {save_every}')

    shutil.copy('config.yaml', os.path.join(model_dir, 'config.yaml'))
    shutil.copy('params.py', os.path.join(model_dir, 'params.py'))

    # The schedule is constant during training: move it to the device once
    # instead of on every batch.
    noise_level = noise_level.to(device)
    num_steps = len(noise_schedule)

    for epoch in range(epoches):
        model.train()
        # Per-batch losses as Python floats. Storing floats rather than loss
        # tensors keeps no autograd history alive. It also avoids
        # Tensor.numpy(), which raises for CUDA tensors.
        train_loss = []
        for step, batch in enumerate(train_loader):
            # NOTE(review): the loader built earlier in this file yields
            # (signal, spectrogram) pairs. Keep only the signal in that case;
            # a loader yielding bare signal tensors is used as-is.
            signal = batch[0] if isinstance(batch, (list, tuple)) else batch
            signal = signal.to(device)
            opt.zero_grad()

            # Unpacking enforces a 2-D (batch, time) signal, which the
            # unsqueeze(1) broadcasting below relies on.
            N, _ = signal.shape

            # TODO: check why the beta schedule runs from 1e-4 to 0.05 and
            # double-check the forward-diffusion formula below.
            t = torch.randint(0, num_steps, [N], device=signal.device)
            noise_scale = noise_level[t].unsqueeze(1)
            noise = torch.randn_like(signal)
            noisy_signal = noise_scale**0.5 * signal + (1.0 - noise_scale)**0.5 * noise

            # No conditioning input: the original DiffWave `spectrogram`
            # argument is explicitly disabled here.
            predicted = model(noisy_signal, t, spectrogram=None)
            loss = loss_fn(noise, predicted.squeeze(1))

            loss_value = loss.item()
            logging.info(f'Batch {step}/{len(train_loader)} loss : {loss_value}')
            train_loss.append(loss_value)

            loss.backward()
            opt.step()

        if (epoch + 1) % save_every == 0:
            savename = os.path.join(model_dir, f'{epoch + 1}.pth')
            torch.save(model.state_dict(), savename)

        # If the loader yielded no batches there are no losses; report NaN
        # explicitly instead of taking the mean of an empty array.
        epoch_loss = float(np.mean(train_loss)) if train_loss else float('nan')
        logging.info(f'Epoch {epoch}/{epoches} loss : {epoch_loss}')


# Run the unconditional trainer defined in this cell. No function named
# `train` exists in this file.
train_unconditional(noise_level)
    

    
    