In [1]:
from typing import Callable

from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader, Dataset
import torch

from utils.load_data import *
import numpy as np

# Subject-independent split: subjects 1-8 are used for training and
# subject 9 is held out as the test subject.
BCIC_dataset = load_BCIC(
    train_sub=list(range(1, 9)),
    test_sub=[9],
    alg_name='Tensor_CSPNet',
    scenario='subject-independent',
)

(
    train_x,
    train_y,
    test_x,
    test_y,
) = BCIC_dataset.generate_training_valid_test_set_subject_independent()


def _l2_normalize_trials(trials):
    """Return a copy of ``trials`` with every trial scaled to unit L2 norm.

    A "trial" is one slice along the first axis. Its norm is
    ``np.linalg.norm(trial)``, i.e. the 2-norm of all of its entries
    flattened.

    Edge cases:
      * All-zero trials are returned as zeros instead of NaN (0 / 0).
      * Floating inputs keep their dtype.
      * Integer inputs are promoted to float64 so the quotient is not
        truncated back to integers.
    """
    trials = np.asarray(trials)
    if np.issubdtype(trials.dtype, np.floating):
        out_dtype = trials.dtype
    else:
        out_dtype = np.float64
    normalized = np.zeros(trials.shape, dtype=out_dtype)
    for trial_idx, trial in enumerate(trials):
        norm = np.linalg.norm(trial)
        # Leave zero-norm trials as zeros rather than dividing by zero.
        if norm > 0:
            normalized[trial_idx] = trial / norm
    return normalized


train_x_normalized = _l2_normalize_trials(train_x)

# Group the normalized training trials by class label (0-3). `data_x` /
# `data_y` additionally keep every trial, in the original order, regardless
# of label.
data_x_0, data_y_0 = [], []
data_x_1, data_y_1 = [], []
data_x_2, data_y_2 = [], []
data_x_3, data_y_3 = [], []
data_x, data_y = [], []

# Dispatch table: label -> (sample list, label list) for that class.
_buckets_by_label = {
    0: (data_x_0, data_y_0),
    1: (data_x_1, data_y_1),
    2: (data_x_2, data_y_2),
    3: (data_x_3, data_y_3),
}

for sample_idx, label in enumerate(train_y):
    sample = train_x_normalized[sample_idx]
    bucket = _buckets_by_label.get(int(label))
    if bucket is not None:
        bucket_x, bucket_y = bucket
        bucket_x.append(sample)
        bucket_y.append(label)
    data_x.append(sample)
    data_y.append(label)
    
        
class Simple_dataset(Dataset):
    """Map-style dataset pairing each sample in ``x`` with the label in ``y``.

    ``dataset[idx]`` returns ``(torch.tensor(x[idx], dtype=float32), y[idx])``.
    The tensor is a fresh copy, because ``torch.tensor`` always copies.
    """

    def __init__(self, x, y):
        """
        Args:
            x: indexable collection of samples (e.g. a list of numpy arrays).
            y: indexable collection of labels, the same length as ``x``.

        Raises:
            ValueError: if ``x`` and ``y`` differ in length.
        """
        if len(x) != len(y):
            raise ValueError(
                f"x and y must have the same length, got {len(x)} and {len(y)}"
            )
        # Store `x` itself. A trailing comma (`self.x = x,`) would store the
        # 1-tuple `(x,)`, which breaks indexing.
        self.x = x
        self.y = y

    def __len__(self):
        return len(self.y)

    def __getitem__(self, idx):
        return torch.tensor(self.x[idx], dtype=torch.float32), self.y[idx]
        
# Training loader.
# - `batch_size=1` is the DataLoader default the notebook originally relied
#   on; it is kept, but made explicit.
# - `data_x` follows the order of `train_y`, so the samples are shuffled each
#   epoch to avoid feeding the model long runs in a fixed order.
# - The seeded generator keeps that shuffle reproducible across runs.
BATCH_SIZE = 1
simple_dataset = Simple_dataset(data_x, data_y)
dataloader = DataLoader(
    simple_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    generator=torch.Generator().manual_seed(0),
)


2023-04-25 14:53:30.215254: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
from diffusers.models import UNet3DConditionModel as UNet3DModel
from torch import nn
from typing import Tuple



class UNet(nn.Module):
    """Thin wrapper around diffusers' ``UNet3DConditionModel``.

    ``forward`` passes every argument straight to the wrapped model, asks
    for a dict-style output (``return_dict=True``) and returns only its
    ``.sample`` field.

    NOTE(review): three of the down blocks and three of the up blocks are
    cross-attention blocks. In diffusers these expect
    ``encoder_hidden_states`` to be supplied; confirm callers pass it.
    """

    def __init__(self, image_size: Tuple[int, int, int, int] = (1, 9, 22, 22)) -> None:
        """
        Args:
            image_size: forwarded unchanged as the wrapped model's
                ``sample_size``. It defaults to an immutable tuple rather
                than a shared list.
                NOTE(review): diffusers documents ``sample_size`` as an int,
                and the meaning of this 4-tuple is not established here;
                confirm.
        """
        super().__init__()

        self.model_fn = UNet3DModel(
            sample_size=image_size,
            in_channels=9,
            out_channels=9,
            down_block_types=(
                "CrossAttnDownBlock3D",
                "CrossAttnDownBlock3D",
                "CrossAttnDownBlock3D",
                "DownBlock3D",
            ),
            up_block_types=(
                "UpBlock3D",
                "CrossAttnUpBlock3D",
                "CrossAttnUpBlock3D",
                "CrossAttnUpBlock3D",
            ),
            block_out_channels=(128, 256, 256, 512),
        )

    def forward(self, *args, **kwargs):
        """Run the wrapped UNet and return only its ``.sample`` output."""
        return self.model_fn(*args, **kwargs, return_dict=True).sample

In [3]:
from typing import List, Union

from pytorch_lightning import LightningModule
from torch import optim, Tensor
from torch.nn import functional as F
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity

from .consistency_models.consistency_models import (
    ConsistencySamplingAndEditing,
    ConsistencyTraining,
    ema_decay_rate_schedule,
    karras_schedule,
    timesteps_schedule,
)
from consistency_models.utils import update_ema_model


# Online network (trained by the optimizer) and its EMA / target copy.
#
# The EMA network must start from the same weights as the online network.
# Two independently initialised UNets would make the EMA target an unrelated
# random model at the start of training.
#
# In this notebook the EMA copy is only changed via `update_ema_model`; the
# optimizer is built on `unet.parameters()` alone. Its gradients are
# therefore disabled.
# NOTE(review): assumes `consistency_training` does not need gradients
# through the EMA model; confirm.
unet = UNet()
ema_unet = UNet()
ema_unet.load_state_dict(unet.state_dict())
ema_unet.requires_grad_(False)

def loss_fn(
    predicted: Tensor,
    target: Tensor,
    distance_fn: Union[Callable[[Tensor, Tensor], Tensor], None] = None,
) -> Tensor:
    """Consistency-training loss: a distance term plus an out-of-range penalty.

    Args:
        predicted: prediction returned by ``consistency_training``.
        target: matching target returned by ``consistency_training``.
        distance_fn: callable ``(a, b) -> scalar tensor``. It is applied to
            both inputs after clamping them to [-1, 1], and defaults to
            ``F.mse_loss``. A perceptual metric such as
            ``LearnedPerceptualImagePatchSimilarity(net_type="squeeze")`` can
            be passed instead.
            NOTE(review): LPIPS expects 3-channel image batches, and the
            UNet here uses 9 channels; confirm before using it.

    Returns:
        The sum of two terms:

        * the distance term ``distance_fn(clamp(predicted), clamp(target))``;
        * the MSE between ``predicted`` and its detached clamp. This second
          term is non-zero only where ``predicted`` leaves [-1, 1], so it
          pushes out-of-range predictions back inside.
    """
    def clamp(x: Tensor) -> Tensor:
        return x.clamp(min=-1.0, max=1.0)

    if distance_fn is None:
        distance_fn = F.mse_loss

    distance_loss = distance_fn(clamp(predicted), clamp(target))
    # The detached clamp acts as a constant, so gradients flow only through
    # `predicted`.
    overflow_loss = F.mse_loss(predicted, clamp(predicted).detach())
    return distance_loss + overflow_loss

# Adam over the online network's parameters only.
optimizer = torch.optim.Adam(
    params=unet.parameters(),
    lr=2e-5,
    betas=(0.5, 0.999),
)

# Noise-schedule hyper-parameters for consistency training.
# NOTE(review): each trial was scaled to unit L2 norm earlier, so its
# per-element std is likely far below `sigma_data=0.5`; confirm.
consistency_training = ConsistencyTraining(
    sigma_min=0.1,        # smallest noise std
    sigma_max=10.0,       # largest noise std
    rho=7.0,              # Karras-schedule hyper-parameter
    sigma_data=0.5,       # std of the data
    initial_timesteps=2,  # discrete timesteps at the start of training
    final_timesteps=150,  # discrete timesteps at the end of training
)

# Training loop.
# `step` counts passes over `dataloader`; it is the step value handed to
# `consistency_training` and to the timestep / EMA schedules.
TOTAL_TRAINING_STEPS = 100

for step in range(TOTAL_TRAINING_STEPS):
    # Both schedules depend only on `step`, so evaluate them once per pass
    # instead of once per batch. The 2 / 150 timestep bounds match the
    # `ConsistencyTraining` configuration.
    num_timesteps = timesteps_schedule(
        step,
        TOTAL_TRAINING_STEPS,
        initial_timesteps=2,
        final_timesteps=150,
    )
    ema_decay_rate = ema_decay_rate_schedule(
        num_timesteps,
        initial_ema_decay_rate=0.95,
        initial_timesteps=2,
    )

    # Class labels are not used by consistency training.
    for batch, _labels in dataloader:
        # Reset gradients before every update. Zeroing only once per pass
        # would let earlier batches' gradients leak into later optimizer
        # steps.
        optimizer.zero_grad()

        # Forward pass. Extra keyword arguments are passed on to the model,
        # which is useful for conditioning.
        predicted, target = consistency_training(
            unet,
            ema_unet,
            batch,
            step,
            total_training_steps=TOTAL_TRAINING_STEPS,
        )

        # Loss, backward pass and weight update.
        loss = loss_fn(predicted, target)
        loss.backward()
        optimizer.step()

        # Move the EMA network towards the freshly updated online network.
        update_ema_model(ema_unet, unet, ema_decay_rate)
