In [1]:
e_V_mean = 0.0
e_V_std = 0.7556

In [2]:
import torch
from torch.utils.data import DataLoader, Dataset, random_split
import numpy as np

class VoxelGridDataset(Dataset):
    def __init__(self, numpy_file1,numpy_file2,d):
        S = np.load(numpy_file1)
        U = np.load(numpy_file2)

        self.samples = U[:,:d] * S[:d]
        self.samples = self.samples.astype(np.float32)  
        # np.random.shuffle(self.samples)
        self.samples = torch.from_numpy(self.samples)


        data_min = torch.amin(self.samples,dim=0, keepdim= True)
        data_max = torch.amax(self.samples,dim=0, keepdim= True)

        self.samples = ((self.samples - data_min) / (data_max - data_min))*6-3
        print(self.samples.shape)
        
        
    def __len__(self):
        return self.samples.size(0)

    def __getitem__(self, idx):
        return self.samples[idx]

additioanl_name = '_afterTao'
d = 512
full_dataset = VoxelGridDataset(f'../data/S{additioanl_name}.npy', f'../data/U{additioanl_name}.npy', d=d)

# additioanl_name = ''
VT = np.load(f'../data/VT{additioanl_name}.npy')[:d,:]
VT = VT.astype(np.float32)  
VT = torch.from_numpy(VT)

torch.Size([1000, 512])


In [4]:
import torch
import os
import numpy as np
import tqdm
import matplotlib.pyplot as plt
import yaml
import argparse
import copy
import sys
sys.path.append('../') 
from models.dpm_model import * 
from models.dpm_utils import * 
!pip install wandb
import wandb
import torchvision.utils as tvu
import shutil
import glob
import time
import torch.utils.data as data 
import torch.optim as optim
import torch.nn as nn

Collecting wandb
  Using cached wandb-0.18.5-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (9.7 kB)
Collecting docker-pycreds>=0.4.0 (from wandb)
  Using cached docker_pycreds-0.4.0-py2.py3-none-any.whl.metadata (1.8 kB)
Collecting sentry-sdk>=2.0.0 (from wandb)
  Using cached sentry_sdk-2.18.0-py2.py3-none-any.whl.metadata (9.9 kB)
Collecting setproctitle (from wandb)
  Using cached setproctitle-1.3.3-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (9.9 kB)
Using cached wandb-0.18.5-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (16.0 MB)
Using cached docker_pycreds-0.4.0-py2.py3-none-any.whl (9.0 kB)
Using cached sentry_sdk-2.18.0-py2.py3-none-any.whl (317 kB)
Using cached setproctitle-1.3.3-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (30 kB)
Installing collected packages: setproctitle, sentry-sdk, docker-pycreds, wandb
Successfully installed do

In [5]:
def dict2namespace(config):
    namespace = argparse.Namespace()
    for key, value in config.items():
        if isinstance(value, dict):
            new_value = dict2namespace(value)
        else:
            new_value = value
        setattr(namespace, key, new_value)
    return namespace

In [6]:
class EMAHelper(object):
    def __init__(self, mu=0.999):
        self.mu = mu
        self.shadow = {}

    def register(self, module):
        if isinstance(module, nn.DataParallel):
            module = module.module
        for name, param in module.named_parameters():
            if param.requires_grad:
                self.shadow[name] = param.data.clone()

    def update(self, module):
        if isinstance(module, nn.DataParallel):
            module = module.module
        for name, param in module.named_parameters():
            if param.requires_grad:
                self.shadow[name].data = (
                    1. - self.mu) * param.data + self.mu * self.shadow[name].data

    def ema(self, module):
        if isinstance(module, nn.DataParallel):
            module = module.module
        for name, param in module.named_parameters():
            if param.requires_grad:
                param.data.copy_(self.shadow[name].data)

    def ema_copy(self, module):
        if isinstance(module, nn.DataParallel):
            inner_module = module.module
            module_copy = type(inner_module)(
                inner_module.config).to(inner_module.config.device)
            module_copy.load_state_dict(inner_module.state_dict())
            module_copy = nn.DataParallel(module_copy, inner_module.args.dataparallel)
        else:
            module_copy = type(module)(module.config).to(module.config.device)
            module_copy.load_state_dict(module.state_dict())
        self.ema(module_copy)
        return module_copy

    def state_dict(self):
        return self.shadow

    def load_state_dict(self, state_dict):
        self.shadow = state_dict

def get_optimizer(config, parameters):
    if config.optim.optimizer == 'Adam':
        return optim.Adam(parameters, lr=config.optim.lr, weight_decay=config.optim.weight_decay,
                          betas=(config.optim.beta1, 0.999), amsgrad=config.optim.amsgrad,
                          eps=config.optim.eps)
    elif config.optim.optimizer == 'RMSProp':
        return optim.RMSprop(parameters, lr=config.optim.lr, weight_decay=config.optim.weight_decay)
    elif config.optim.optimizer == 'SGD':
        return optim.SGD(parameters, lr=config.optim.lr, momentum=0.9)
    else:
        raise NotImplementedError(
            'Optimizer {} not understood.'.format(config.optim.optimizer))
    
def noise_estimation_loss(net, images,  labels=None, augment_pipe=None):
    rnd_normal = torch.randn([images.shape[0],], device=images.device)
    P_mean = args.P_mean
    P_std = args.P_std                
    sigma = (rnd_normal * P_std + P_mean).exp()
    reshaped_sigma = sigma.reshape(images.shape[0],  1, 1, 1, 1)
    n = torch.randn_like(images) * reshaped_sigma
    D_yn = net(images + n, sigma)           
    L_alphas = ((D_yn - images) ** 2)
    return L_alphas.mean()


def noise_estimation_loss_wReg(net, images, V,  labels=None, augment_pipe=None):
    rnd_normal = torch.randn([images.shape[0],], device=images.device)
    P_mean = args.P_mean
    P_std = args.P_std                
    sigma = (rnd_normal * P_std + P_mean).exp()
    reshaped_sigma = sigma.reshape(images.shape[0],  1, 1, 1, 1)
    n = torch.randn_like(images) * reshaped_sigma
    D_yn = net(images + n, sigma)           
    lamda = 0.7
    L_0= ((D_yn - images) ** 2).mean()
    mask = torch.randperm(V.shape[1])[:10000]
    L_1= ((D_yn @ V[:,mask] - images @ V[:,mask]) ** 2).mean()
    loss = L_1
    return loss, L_0.mean().detach().cpu().numpy(), L_1.detach().cpu().numpy()


In [7]:
class Diffusion(object):
    def __init__(self, args, config, device=None, 
                ):
        self.args = args
        self.config = config
        self.sigma_min = self.args.sigma_min
        self.sigma_max = self.args.sigma_max
        self.loss_his = []
        self.val_loss_his = []
        self.val_loss_steps = []
        self.P_mean = self.args.P_mean
        self.P_std = self.args.P_std
        
        if device is None:
            device = (
                self.config.device
                if torch.cuda.is_available()
                else torch.device("cpu")
            )
        self.device = device
        self.num_timesteps = config.diffusion.num_diffusion_timesteps
        
    def round_sigma(self, sigma):
        return torch.as_tensor(sigma)
    
    def train(self, train_loader, val_data_loader, e_V_mean, e_V_std, V=None):
        args, config = self.args, self.config
        tb_logger = self.config.tb_logger
        
        wandb.init(project='SVD-GEN', dir=args.exp, name='svd_gen_after_tao')
        model = Model(config)

        total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        print(f"Number of parameters in the model: {total_params}")
        print(model)

        model = model.to(self.device)
        V=V.to(self.device)
        model = torch.nn.DataParallel(model,self.args.dataparallel)
        optimizer = get_optimizer(self.config, model.parameters())

        if self.config.model.ema:
            ema_helper = EMAHelper(mu=self.config.model.ema_rate)
            ema_helper.register(model)
        else:
            ema_helper = None


        if self.config.optim.lr_decay:
            scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer,
                                                        gamma=self.config.optim.lr_decay)
        else:
            scheduler = None


        start_epoch, step = 0, 0
        if self.args.resume_training:
            states = torch.load(os.path.join(self.args.log_path, "ckpt.pth"))
            model.load_state_dict(states[0])

            states[1]["param_groups"][0]["eps"] = self.config.optim.eps
            states[1]["param_groups"][0]["lr"] = self.config.optim.lr
            optimizer.load_state_dict(states[1])
            start_epoch = states[2]
            step = states[3]
            if self.config.model.ema:
                ema_helper.load_state_dict(states[4])

        for epoch in range(start_epoch, self.config.training.n_epochs):
            data_start = time.time()
            data_time = 0
            total_loss = 0
            num_batches = 0
            if scheduler is not None:
                scheduler.step(epoch)


            for i, x in enumerate(train_loader):

                optimizer.zero_grad()
                
                data_time += time.time() - data_start
                model.train()
                step += 1
                x = x.to(self.device)
                loss_registry = {'simple': noise_estimation_loss_wReg,}

                loss, l0, l1 = loss_registry[config.model.type](net = model, images = x, V=V)
                wandb.log({"Train-Loss": loss, "Train-Loss-alpha":l0, "Train-Loss-shape":l1,}, step=step)
                total_loss += loss.mean().item()
                num_batches += 1
                    
                print(
                    f"step: {step}, loss: {loss.item()}, data time: {data_time / (i+1)}"
                )
                
                try:
                    torch.nn.utils.clip_grad_norm_(
                        model.parameters(), config.optim.grad_clip
                    )
                except Exception:
                    pass

                if self.config.model.ema:
                    ema_helper.update(model)

                if step % self.config.training.snapshot_freq == 0 or step == 1:
                    states = [
                        model.state_dict(),
                        optimizer.state_dict(),
                        epoch,
                        step,
                    ]
                    if self.config.model.ema:
                        states.append(ema_helper.state_dict())

                    torch.save(
                        states,
                        os.path.join(self.args.log_path, "ckpt_{}.pth".format(step)),
                    )
                    torch.save(states, os.path.join(self.args.log_path, "ckpt.pth"))
                
                if step % self.config.training.validation_freq == 0 :   
                    model.eval()
                    self.sample_sequence(model, training = True, step = step)
                
                loss.backward()
                optimizer.step()
                
                data_start = time.time()
                
            average_loss = total_loss / num_batches
            # run["train/loss"].append(average_loss)

            total_loss = 0
            num_batches = 0
            for i, x in enumerate(val_data_loader):
                
                model.eval()
                x = x.to(self.device)
                loss_registry = {'simple': noise_estimation_loss_wReg,}

                val_loss, l0, l1 = loss_registry[config.model.type](net = model, images = x, V=V)
                wandb.log({"Val-Loss": val_loss, "Val-Loss-alpha":l0, "Val-Loss-shape":l1,}, step=step)
                total_loss += val_loss.mean().item()
                num_batches += 1
                
            average_loss = total_loss / num_batches
            # run["val/loss"].append(average_loss)    


    def sample(self):
        model = Model(self.config)
        if self.args.ckpt_id is None:
            states = torch.load(
                os.path.join(self.args.log_path, "ckpt.pth"),
                map_location=self.config.device,
            )
        else:
            states = torch.load(
                os.path.join(
                    self.args.log_path, f"ckpt_{self.args.ckpt_id}.pth"
                ),
                map_location=self.config.device,
            )
            print('load ckpt: ', self.args.ckpt_id)
            
        model = model.to(self.device)
        model = torch.nn.DataParallel(model, self.args.dataparallel)
#         model = model.to(self.device)
        model.load_state_dict(states[0], strict=True)
        print('load epoch: ', states[3])

        model.eval()

       
        if self.args.interpolation:
            self.sample_interpolation(model)
        elif self.args.sequence:
            self.sample_sequence(model)
        else:
            raise NotImplementedError("Sample procedeure not defined")


    def sample_sequence(self, model, training = False, step = None):
        config = self.config

        x = torch.randn(
            config.sampling.batch_size,
            config.data.input_size,
            device=self.device,
        )
        
        data_start = time.time()
        data_time = 0
        # NOTE: This means that we are producing each predicted x0, not x_{t-1} at timestep t.
        with torch.no_grad():
            x = self.sample_alpha(x, model)
        data_time += time.time() - data_start
        print(f"the sample time of {self.config.sampling.batch_size} images with {self.args.sample_type}taks {data_time}")
        generated_feature = x
        np.save(os.path.join(args.generated_folder, "generated_feature_{}.npy".format(step)), generated_feature)



    def sample_interpolation(self, model):
        config = self.config

        def slerp(z1, z2, alpha):
            theta = torch.acos(torch.sum(z1 * z2) / (torch.norm(z1) * torch.norm(z2)))
            return (
                torch.sin((1 - alpha) * theta) / torch.sin(theta) * z1
                + torch.sin(alpha * theta) / torch.sin(theta) * z2
            )

        z1 = torch.randn(
            config.sampling.batch_size,
            config.data.input_size,
            device=self.device,
        )
        z2 = torch.randn(
            config.sampling.batch_size,
            config.data.input_size,
            device=self.device,
        )
        
        alpha = torch.arange(0.0, 1.01, 0.1).to(z1.device)
        z_ = []
        for i in range(alpha.size(0)):
            z_.append(slerp(z1, z2, alpha[i]))

        x = torch.cat(z_, dim=0)
        xs = []

        # Hard coded here, modify to your preferences
        with torch.no_grad():
            for i in range(0, x.size(0), 8):
                xs.append(self.sample_alpha(x[i : i + 8], model))
        x = inverse_data_transform(config, torch.cat(xs, dim=0))
        for i in range(x.size(0)):
            tvu.save_image(x[i], os.path.join(self.args.generated_folder, f"{i}.png"))

    def sample_alpha(self, x, model):
        data_start = time.time()
        data_time = 0
        if self.args.sample_type == "deterministic":
             x = edm_sampler(latents = x, num_steps = self.args.timesteps, net = model, randn_like=torch.randn_like, net_sigma_min = self.sigma_min, net_sigma_max = self.sigma_max )
        elif self.args.sample_type == "stochastic":
             x = edm_sampler(latents = x, num_steps = self.args.timesteps, net = model, randn_like=torch.randn_like, net_sigma_min = self.sigma_min, net_sigma_max = self.sigma_max, S_churn=40, S_min=0.05, S_max=50, S_noise=1.003)
        else:
            raise NotImplementedError
        return x
        data_time += time.time() - data_start
        print(data_time)

    def test(self):
        pass


In [8]:
def edm_sampler(
    latents, num_steps, net,  class_labels=None, randn_like=torch.randn_like, sigma_min=0.002, sigma_max=80, rho=5,
    S_churn=0, S_min=0, S_max=float('inf'), S_noise=1, net_sigma_min= 0.002, net_sigma_max=80
):
    # Adjust noise levels based on what's supported by the network.
    sigma_min = max(sigma_min, net_sigma_min)
    sigma_max = min(sigma_max, net_sigma_max)

    # Time step discretization.
    step_indices = torch.arange(num_steps, device=latents.device)
    t_steps = (sigma_max ** (1 / rho) + step_indices / (num_steps - 1) * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
    t_steps = torch.cat([torch.as_tensor(t_steps), torch.zeros_like(t_steps[:1])]) # t_N = 0

    # Main sampling loop.
    
    x_next = latents * t_steps[0]
    print('latents:', latents.shape, 'x_next:', x_next.shape, 't_steps:', t_steps.shape)
    for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])): # 0, ..., N-1
        
        t_cur = torch.ones(x_next.shape[0], device=t_cur.device) * t_cur
        t_next = torch.ones(x_next.shape[0], device=t_next.device) * t_next
        
        x_cur = x_next

        # Increase noise temporarily.
        gamma = min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= t_cur[0] <= S_max else 0
        t_hat = torch.as_tensor(t_cur + gamma * t_cur)
        x_hat = (x_cur + (t_hat[0] ** 2 - t_cur[0] ** 2).sqrt() * S_noise * randn_like(x_cur))
        print('x_hat:', x_hat.shape, 't_hat:', t_hat.shape)

        # Euler step.
#         denoised = net(x_hat, t_hat, class_labels).to(torch.float64)
        denoised = net(x_hat, t_hat)
        d_cur = (x_hat - denoised) / t_hat[0]
        x_next = x_hat + (t_next[0] - t_hat[0]) * d_cur

        # Apply 2nd order correction.
        if i < num_steps - 1:
            denoised = net(x_next, t_next)
#             denoised = net(x_next, t_next, class_labels)
            d_prime = (x_next - denoised) / t_next[0]
            x_next = x_hat + (t_next[0] - t_hat[0]) * (0.5 * d_cur + 0.5 * d_prime)

    return x_next.cpu()

In [9]:
torch.set_printoptions(sci_mode=False)
args = {
    'config': 'plane.yml',
    'seed': 1234,
    'exp': 'Version0_TAO',
    'doc': 'log_folder',
    'comment': "",
    'verbose': "info",
    'sequence': False,
    'test': False,
    'sample': False,
    'fid': False,
    'interpolation': False,
    'resume_training': False,
    'generated_folder': "Version0_TAO/alphas",
    'ni': False,
    'use_pretrained': False,
    'sample_type': "stochastic",
#     deterministic, stochastic
    'timesteps': 64,
    'dataparallel': [0],
    'P_mean': -1.2,
    'P_std':1.2,
    'sigma_min': 0.002,
    'sigma_max': 80
#     'ckpt_id': '91000',
#     'skip': 500,
}

args['log_path']=os.path.join(args['exp'], "logs", args['doc'])
args = dict2namespace(args)
with open(args.config, "r") as f:
    config = yaml.safe_load(f)
config['tb_logger'] = None
config = dict2namespace(config)
tb_path = os.path.join(args.exp, "tensorboard", args.doc)

In [10]:
# torch_device = "cuda:" + str(args.dataparallel[0])
torch_device = "cuda"
device = torch.device(torch_device) if torch.cuda.is_available() else torch.device("cpu")
print("Using device: {}".format(device))
config.device = device

# set random seed
torch.manual_seed(args.seed)
np.random.seed(args.seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(args.seed)

torch.backends.cudnn.benchmark = True

if not args.test and not args.sample:
    if not args.resume_training:
        if os.path.exists(args.log_path):
            overwrite = False
            if args.ni:
                overwrite = True
            if overwrite:
                shutil.rmtree(args.log_path)
                os.makedirs(args.log_path)
                if os.path.exists(tb_path):
                    shutil.rmtree(tb_path)
            else:
                print("Folder exists. Program halted.")
        else:
            os.makedirs(args.log_path)
        
        if os.path.exists(args.generated_folder):
            if args.ni:
                overwrite = True
            if overwrite:
                shutil.rmtree(args.generated_folder)
                os.makedirs(args.generated_folder)
            else:
                print("Folder exists. Program halted.")
        else:
            os.makedirs(args.generated_folder)

        with open(os.path.join(args.log_path, "config.yml"), "w") as f:
            yaml.dump(config, f, default_flow_style=False)

Using device: cuda


In [11]:

train_size = int(0.9 * len(full_dataset))
val_size = len(full_dataset) - train_size

train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size])

train_data_loader = DataLoader(
    dataset=train_dataset,
    batch_size = 64,
    num_workers=16,
    shuffle=True,
    drop_last=True
)

val_data_loader = DataLoader(
    dataset=val_dataset,
    batch_size=32,
    num_workers=16,
    shuffle=False,
    drop_last=True
)

In [13]:
runner = Diffusion(args, config, config.device)

In [14]:
runner.train(train_data_loader,val_data_loader, e_V_mean, e_V_std, V=VT)

[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mjiajie7745[0m ([33mjiajie7745-bmw-group[0m). Use [1m`wandb login --relogin`[0m to force relogin


Number of parameters in the model: 85886464
Model(
  (temb): Module(
    (dense): ModuleList(
      (0-1): 2 x Linear(in_features=512, out_features=512, bias=True)
    )
  )
  (layer_in): Linear(in_features=512, out_features=512, bias=True)
  (down): ModuleList(
    (0): Module(
      (block): ModuleList(
        (0-1): 2 x ResnetBlock(
          (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (dense1): Linear(in_features=512, out_features=512, bias=True)
          (temb_proj): Linear(in_features=512, out_features=512, bias=True)
          (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
          (dense2): Linear(in_features=512, out_features=512, bias=True)
        )
      )
    )
    (1): Module(
      (block): ModuleList(
        (0): ResnetBlock(
          (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (dense1): Linear(in_features=512, out_features=1024, bias=True)
 



step: 1, loss: 0.00010809814557433128, data time: 0.25632357597351074
step: 2, loss: 9.638075425755233e-05, data time: 0.1288444995880127
step: 3, loss: 9.654399764258415e-05, data time: 0.08803629875183105
step: 4, loss: 8.647984941489995e-05, data time: 0.06614726781845093
step: 5, loss: 8.144745515892282e-05, data time: 0.05303115844726562
step: 6, loss: 8.752731810091063e-05, data time: 0.044308980305989586
step: 7, loss: 7.625211583217606e-05, data time: 0.03806080136980329
step: 8, loss: 8.283132046926767e-05, data time: 0.033424586057662964
step: 9, loss: 6.998563185334206e-05, data time: 0.029772811465793185
step: 10, loss: 7.791686221025884e-05, data time: 0.026924800872802735
step: 11, loss: 7.199481478892267e-05, data time: 0.024549549276178532
step: 12, loss: 6.52909220661968e-05, data time: 0.022749602794647217
step: 13, loss: 6.700376980006695e-05, data time: 0.021047427104069635
step: 14, loss: 7.917386392364278e-05, data time: 0.019585319927760532




step: 15, loss: 7.135505438782275e-05, data time: 0.35440564155578613
step: 16, loss: 7.200318941613659e-05, data time: 0.17805004119873047
step: 17, loss: 7.636816008016467e-05, data time: 0.12155469258626302
step: 18, loss: 7.32626867829822e-05, data time: 0.09131366014480591
step: 19, loss: 7.34006825950928e-05, data time: 0.07317118644714356
step: 20, loss: 6.915393169037998e-05, data time: 0.061057329177856445
step: 21, loss: 7.309031207114458e-05, data time: 0.05241901533944266
step: 22, loss: 6.702589598717168e-05, data time: 0.04594936966896057
step: 23, loss: 7.095761247910559e-05, data time: 0.040896972020467125
step: 24, loss: 6.972174742259085e-05, data time: 0.03686568737030029
step: 25, loss: 7.690413622185588e-05, data time: 0.033577377145940605
step: 26, loss: 7.058155460981652e-05, data time: 0.03110325336456299
step: 27, loss: 7.485557580366731e-05, data time: 0.028755004589374248
step: 28, loss: 8.122758299577981e-05, data time: 0.02681657246180943
step: 29, loss: 6.