In [7]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [8]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import numpy as np
import functools
import tqdm
from model import DiffWave, DW4BSS

In [9]:
device = torch.device('cuda:1')

def marginal_prob_std(t, sigma):
    t = torch.tensor(t, device=device)
    return torch.sqrt((sigma**(2*t) - 1.) / 2. / np.log(sigma))

def diffusion_coeff(t, sigma):
    return torch.tensor(sigma**t, device=device)

sigma = 30.0
marginal_prob_std_fn = functools.partial(marginal_prob_std, sigma=sigma)
diffusion_coeff_fn = functools.partial(diffusion_coeff, sigma=sigma)

In [10]:
def loss_fn(model, x, marginal_prob_std, eps=1e-5):
    random_t = torch.rand(x.shape[0], device=x.device) * (1.-eps) + eps
    z = torch.randn_like(x)
    std = marginal_prob_std(random_t)
    perturbed_x = x + z * std[:, None, None]
    score = model(perturbed_x, random_t)
    loss = torch.mean(
        torch.sum((score * std[:, None, None] + z)**2, dim=(1, 2)))
    return loss

# Train

In [11]:
from params import params_diffwave_cat, params_diffwave3, params_gau_cat, params_gau3
from datasets_bss import train_dataset
from torch.optim import Adam
from GAU import DW4BSS_GAU, GAUnet3
from model import DW4BSS, DiffWave3 

BSS_net = DiffWave3(params_diffwave3, marginal_prob_std_fn).to(device)
# BSS_net = DW4BSS_GAU(gau_params, marginal_prob_std_fn).to(device)
BSS_net = BSS_net.float()

try:
    the_ckpt = torch.load(r'/home/wyl/projects/_EEG_score/ckpt/diffwave3_ckpt_1.pth')
    BSS_net.load_state_dict(the_ckpt, strict=False)
    print('::: model loaded :::')
except:
    print('::: model does not exist :::')

lr = 2e-4
batch_size = 64
n_epoch = 40

BSS_dataset = train_dataset()
BSS_dataloader = DataLoader(BSS_dataset, batch_size=64, shuffle=True)

optimizer = Adam(BSS_net.parameters(), lr=lr)

tqdm_epoch = tqdm.trange(n_epoch)
for epoch in tqdm_epoch:
    avg_loss = 0.
    num_items = 0
    for _, x in BSS_dataloader:
        x = x.squeeze(1)
        x = x.to(device).float()
        loss = loss_fn(BSS_net, x, marginal_prob_std_fn)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        avg_loss += loss.item() * x.shape[0]
        num_items += x.shape[0]
    tqdm_epoch.set_description('Average Loss: {:5f}'.format(avg_loss / num_items))
    torch.save(BSS_net.state_dict(), r'/home/wyl/projects/_EEG_score/ckpt/diffwave3_ckpt_1.pth')

::: model does not exist :::


  t = torch.tensor(t, device=device)
Average Loss: 254.607459: 100%|██████████| 100/100 [08:23<00:00,  5.03s/it]
