In [13]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.io import decode_image
from torchvision import transforms
import matplotlib.pyplot as plt
import os
import pandas as pd
from PIL import Image
from tqdm import tqdm
import numpy as np

In [14]:
class ConvICNN(nn.Module):
    def __init__(self, linear_layers=4, convex_layers=5, downscale=[(2048, 128), (128, 64), (64, 32), (32, 1)], beta=1e-6):
        super().__init__()
        self.beta = beta
        self.linear_layers = linear_layers
        self.convex_layers = convex_layers
        self.downscale = downscale

        # Linear block
        # Here we use either identity activation or average pooling
        self.conv_direct = nn.Sequential(nn.Conv2d(3, 128, kernel_size=3, padding=1, bias=False))
        self.conv_sqr = nn.Sequential(nn.Conv2d(3, 128, kernel_size=3, padding=1, bias=False))
        for i in range(linear_layers - 2):
            self.conv_direct.append(nn.Conv2d(128, 128, kernel_size=3, padding=1, bias=False))
            self.conv_sqr.append(nn.Conv2d(128, 128, kernel_size=3, padding=1, bias=False))
        self.conv_direct.append(nn.Conv2d(128, 128, kernel_size=3, padding=1))
        self.conv_sqr.append(nn.Conv2d(128, 128, kernel_size=3, padding=1))

        # Convexity-preserving layers
        self.convex_pre_act = nn.CELU()
        self.convex = nn.Sequential()
        for i in range(convex_layers):
            self.convex.append(nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=1))#, padding=1, offset=1))
            self.convex.append(nn.CELU())

        self.linear = nn.Sequential()
        for i in range(len(downscale) - 1):
            self.linear.append(nn.Linear(downscale[i][0], downscale[i][1]))
            self.linear.append(nn.CELU())
        self.linear.append(nn.Linear(downscale[-1][0], downscale[-1][1]))


    def forward(self, x):
        x_input = x
        linear = self.conv_direct(x)
        linear += torch.square(self.conv_sqr(x))
        out = self.convex(self.convex_pre_act(linear))
        out = out.reshape(out.size(0), -1)
        out = self.linear(out)
        strong_convexity = (self.beta / 2) * torch.sum(torch.square(x.reshape(x.size(0), -1)), dim=1, keepdim=True)
        return out + strong_convexity

    def update_weights(self):
        for i in range(0, self.convex_layers):
            # each Conv2d
            self.convex[i*2].weight.data.clamp_(min=0)

        for i in range(0, len(self.downscale)):
            # each Linear
            self.linear[i*2].weight.data.clamp_(min=0)

    def gradient(self, x):
        """Compute gradient of the network w.r.t. input"""
        #if not x.requires_grad:
        #    x = x.clone().requires_grad_(True)
        y = self.forward(x)
        grad = torch.autograd.grad(
            outputs=y,
            inputs=x,
            grad_outputs=torch.ones_like(y),
            create_graph=True,
            retain_graph=True
        )[0]
        return grad#.detach()

In [15]:
class W2GNLoss:
    """
    Wasserstein-2 Generative Network Loss
    """
    def __init__(self, psi_theta, psi_omega, lambda_cycle=1.0, devices=['cuda'], compute_extra_reg=False):
        self.lambda_cycle = lambda_cycle
        self.devices = devices
        self.dev0 = devices[0]
        self.dev1 = devices[1] if len(devices) > 1 else devices[0]
        self.compute_extra_reg = compute_extra_reg

        # Initialize primal and dual potentials
        self.psi_theta = psi_theta.to(self.dev0)
        self.psi_omega = psi_omega.to(self.dev1)

    def theta_device(self):
        return self.dev0

    def omega_device(self):
        return self.dev1

    def to_theta(self, t: torch.Tensor):
        """
        Transfer tensor to psi_theta device
        """
        if self.dev0 != self.dev1:
            return t.to(self.dev0, non_blocking=True)
        else:
            return t

    def to_omega(self, t: torch.Tensor):
        """
        Transfer tensor to psi_omega device
        """
        if self.dev0 != self.dev1:
            return t.to(self.dev1, non_blocking=True)
        else:
            return t

    def compute_correlations(self, X, Y, grad_psi_omega):
        """
        Compute Monte-Carlo estimate of correlations
        L_Corr = (1/K) * [sum psi_theta(x) + sum(<grad_psi_omega(y), y> - psi_theta(grad_psi_omega(y)))]
        """
        batch_size = X.shape[0]

        # Term 1: E_P[psi_theta(x)]
        term1 = self.psi_theta(X).mean()

        # Term 2: E_Q[<grad_psi_omega(y), y> - psi_theta(grad_psi_omega(y))]
        # Compute gradient of conjugate potential

        # Inner product <grad_psi_omega(y), y>
        inner_prod = torch.sum(grad_psi_omega * self.to_omega(Y), dim=(1, 2, 3), keepdim=True)

        # Evaluate primal potential at gradient
        psi_at_grad = self.psi_theta(self.to_theta(grad_psi_omega))

        term2 = (self.to_theta(inner_prod) - psi_at_grad).mean()

        return term1 + term2 # THETA

    def compute_cycle_consistency(self, X, Y):
        """
        Compute cycle consistency regularization
        R_Y = E_Q[||grad_psi_theta(grad_psi_omega(y)) - y||^2] + E_Q[||grad_psi_omega(grad_psi_theta(x)) - x||^2]
        """
        # Forward: omega -> theta
        grad_psi_omega = self.psi_omega.gradient(self.to_omega(Y))
        grad_psi_theta = self.psi_theta.gradient(self.to_theta(grad_psi_omega))

        # Compute squared difference
        diff = grad_psi_theta - self.to_theta(Y).detach()
        cycle_loss_y = torch.mean(diff ** 2)

        # Same for X
        grad_psi_theta_x = self.psi_theta.gradient(self.to_theta(X))
        grad_psi_omega_x = self.psi_omega.gradient(self.to_omega(grad_psi_theta_x))

        diff_x = grad_psi_omega_x - self.to_omega(X).detach()
        
        cycle_loss_x = torch.mean(diff_x ** 2)
        cycle_loss = cycle_loss_y + self.to_theta(cycle_loss_x)
        
        return cycle_loss # THETA

    def compute_extra_R(self, X):
        """
        Compute extra regularization term
        R_X = (1/K) * [sum ||grad_psi_omega(grad_psi_theta(x)) - x||^2]
        """
        grad_psi_theta = self.psi_theta.gradient(self.to_theta(X))
        grad_psi_omega = self.psi_omega.gradient(self.to_omega(grad_psi_theta))

        # Compute squared difference
        diff = grad_psi_omega - self.to_omega(X)
        r_reg = torch.mean(torch.sum(diff ** 2, dim=(1, 2, 3)))
        return r_reg # OMEGA

    def compute_reg(self):
        """
        Compute regularization of model params
        """
        reg_total = 0.0
        for p in self.psi_theta.parameters():
            reg_total += torch.sum(torch.abs(p))
        for p in self.psi_omega.parameters():
            reg_total += self.to_theta(torch.sum(torch.abs(p)))
        return reg_total

    def compute_loss(self, opt, X, Y):
        """
        Compute loss and calculate gradients. It is optimal to store X at device 0 and Y at device 1
        Args:
            opt: optimizer
            X: batch from source distribution P
            Y: batch from target distribution Q
        """
        X.requires_grad_(True)
        Y.requires_grad_(True)
        opt[0].zero_grad()
        opt[1].zero_grad()

        # Compute cycle consistency
        loss_cycle = self.compute_cycle_consistency(X, Y)

        grad_psi_omega = self.psi_omega.gradient(self.to_omega(Y))
        grad_psi_omega_d = grad_psi_omega.detach() # heuristic from Appendix C.1

        opt[0].zero_grad()
        opt[1].zero_grad()
        
        # Compute neg W dist & correlations
        loss_W = torch.mean(self.psi_theta(self.to_theta(X)) - self.psi_theta(self.to_theta(grad_psi_omega_d)))
        #with torch.no_grad():
        #    loss_W_const = (- (self.to_omega(X) ** 2).sum(dim=(1, 2, 3)) / 2).mean() + \
        #    ((grad_psi_omega_d * self.to_omega(Y)).sum(dim=(1, 2, 3)) - (self.to_omega(Y) ** 2).sum(dim=(1, 2, 3)) / 2).mean()
        #loss_corr = self.compute_correlations(X, Y, grad_psi_omega_d)

        # Total loss
        if self.compute_extra_reg:
            loss_R = 1e-10 * self.compute_reg() #self.to_theta(self.compute_extra_R(X))
            loss_R_item = loss_R.item()
            #loss_total = loss_corr + (self.lambda_cycle / 2) * loss_cycle + loss_R
            loss_total = loss_W + (self.lambda_cycle / 2) * loss_cycle + loss_R
        else:
            loss_R = None
            loss_R_item = None
            loss_total = loss_W + (self.lambda_cycle / 2) * loss_cycle

        # Gradient step
        loss_total.backward()
        opt[0].step()
        opt[1].step()
        opt[0].zero_grad()
        opt[1].zero_grad()

        return {
            'loss_corr': loss_W.item(),
            'loss_cycle': loss_cycle.item(),
            'loss_total': loss_total.item(),
            'loss_R': loss_R_item
        }

    def pretrain_loss(self, X):
        X.requires_grad_(True)
        grad_psi_theta = self.psi_theta.gradient(self.to_theta(X))
        loss_mse = F.mse_loss(grad_psi_theta, X)
        loss_mse.backward()

        return loss_mse

    def apply_pretrain_weights(self):
        self.psi_omega.load_state_dict(self.psi_theta.state_dict())

    def generate(self, X):
        """Generate samples: g(x) = grad_psi_theta(x)"""
        return self.psi_theta.gradient(X).detach()

    def inverse(self, Y):
        """Inverse mapping: g^{-1}(y) = grad_psi_omega(y)"""
        return self.psi_omega.gradient(Y).detach()

In [16]:
class W2GNTrainer:
    def __init__(self, loss_c, lr=1e-3, pre_lr=1e-3, betas_theta=(0.8, 0.99), betas_omega=(0.4, 0.4)):
        self.loss_c = loss_c
        self.optimizers = [
            torch.optim.Adam(
                list(self.loss_c.psi_theta.parameters()),
                lr=lr, betas=betas_theta
            ),
            torch.optim.Adam(
                list(self.loss_c.psi_omega.parameters()),
                lr=lr, betas=betas_omega
        )]

        self.preoptimizer = torch.optim.Adam(
           list(self.loss_c.psi_theta.parameters()),
           lr=pre_lr
        )

        self.history = {'loss_corr': [], 'loss_cycle': [], 'loss_total': [], 'loss_R': []}
        self.history_pre = []

    def step(self, X, Y):
        #self.optimizer.zero_grad()
        losses = self.loss_c.compute_loss(self.optimizers, X, Y)
        #self.optimizer.step()
        # Clamp weights in convexity-preserving layers to be non-negative
        with torch.no_grad():
            self.loss_c.psi_theta.update_weights()
            self.loss_c.psi_omega.update_weights()

        for x in losses.keys():
            if losses[x] is not None:
                self.history[x].append(losses[x])

        return losses

    def pretrain_step(self, X):
        self.preoptimizer.zero_grad()
        loss = self.loss_c.pretrain_loss(X)
        self.preoptimizer.step()
        with torch.no_grad():
            self.loss_c.psi_theta.update_weights()

        self.history_pre.append(loss)

        return loss

    def apply_pretrain(self):
        self.loss_c.apply_pretrain_weights()
        del self.preoptimizer

    def generate(self, X):
        """Generate samples: g(x) = grad_psi_theta(x)"""
        #with torch.no_grad():
        return self.loss_c.psi_theta.gradient(X)

    def inverse(self, Y):
        """Inverse mapping: g^{-1}(y) = grad_psi_omega(y)"""
        #with torch.no_grad():
        return self.loss_c.psi_omega.gradient(Y)

In [17]:
class DiffusionNoiseSchedule:
    """Noise schedule for diffusion process"""
    def __init__(self, timesteps=1000, beta_start=0.0001, beta_end=0.02):
        self.timesteps = timesteps
        self.betas = np.linspace(beta_start, beta_end, timesteps)
        self.alphas = 1 - self.betas
        self.alpha_bars = np.cumprod(self.alphas)
        
    def add_noise(self, x, t):
        """Add noise to image at timestep t"""
        alpha_bar = self.alpha_bars[t]
        noise = torch.randn_like(x)
        noisy = np.sqrt(alpha_bar) * x + np.sqrt(1 - alpha_bar) * noise
        return noisy, noise
    
    def get_alpha_bar(self, t):
        return self.alpha_bars[t]


In [18]:
class DiffTrainer:
    """Diffusion model trainer using W2GN framework"""
    def __init__(self, loss_c, noise_schedule, lr=1e-3, betas_theta=(0.8, 0.99), betas_omega=(0.4, 0.4)):
        self.loss_c = loss_c
        self.noise_schedule = noise_schedule
        self.optimizers = [
            torch.optim.Adam(
                list(self.loss_c.psi_theta.parameters()),
                lr=lr, betas=betas_theta
            ),
            torch.optim.Adam(
                list(self.loss_c.psi_omega.parameters()),
                lr=lr, betas=betas_omega
            )
        ]
        
        self.history = {
            'loss_corr': [], 
            'loss_cycle': [], 
            'loss_total': [], 
            'loss_R': []
        }
    
    def step(self, clean_images, t_min=0, t_max=None):
        """
        Training step: denoise images
        Args:
            clean_images: batch of clean images (X)
            t_min: minimum timestep for noise
            t_max: maximum timestep for noise
        """
        if t_max is None:
            t_max = self.noise_schedule.timesteps - 1
            
        batch_size = clean_images.shape[0]
        
        # Sample random timesteps for each image in batch
        t = np.random.randint(t_min, t_max, size=batch_size)
        
        # Add noise to create Y (noisy images)
        noisy_images = []
        for i, img in enumerate(clean_images):
            noisy, _ = self.noise_schedule.add_noise(img.unsqueeze(0), t[i])
            noisy_images.append(noisy)
        noisy_images = torch.cat(noisy_images, dim=0)
        
        # X = clean images, Y = noisy images
        X = clean_images
        Y = noisy_images.to(self.loss_c.omega_device())
        
        # Training step
        losses = self.loss_c.compute_loss(self.optimizers, X, Y)
        
        # Update weights
        with torch.no_grad():
            self.loss_c.psi_theta.update_weights()
            self.loss_c.psi_omega.update_weights()
        
        # Record history
        for key in losses.keys():
            if losses[key] is not None:
                self.history[key].append(losses[key])
        
        return losses
    
    def denoise(self, noisy_image):
        """
        Denoise a single image
        Args:
            noisy_image: noisy input image
        Returns:
            denoised image
        """
        with torch.no_grad():
            denoised = self.loss_c.generate(noisy_image)
        return denoised
    
    def progressive_denoise(self, image, steps=10):
        """
        Progressive denoising from high noise to clean
        Args:
            image: input image
            steps: number of denoising steps
        Returns:
            list of progressively denoised images
        """
        results = [image]
        current = image
        
        timesteps = np.linspace(self.noise_schedule.timesteps-1, 0, steps).astype(int)
        
        for t in timesteps:
            # Add noise at level t
            noisy, _ = self.noise_schedule.add_noise(current, t)
            # Denoise
            denoised = self.denoise(noisy)
            results.append(denoised)
            current = denoised
            
        return results

In [7]:
import kagglehub

# Download latest version
path = kagglehub.dataset_download("balraj98/summer2winter-yosemite")

In [8]:
class S2WDataset(torch.utils.data.Dataset):
    def __init__(self, root, target_shape, is_val=False):
        super().__init__()
        self.root = root
        self.target_shape = target_shape
        self.is_val = is_val

        self.trainA = "trainA"
        self.testA = "testA"
        self.trainB = "trainB"
        self.testB = "testB"

        if not is_val:
            #self.df = pd.read_csv(os.path.join(root, "metadata.csv"))
            self.imagesA = os.listdir(os.path.join(root, self.trainA))
            self.imagesB = os.listdir(os.path.join(root, self.trainB))
        else:
            self.imagesA = os.listdir(os.path.join(root, self.testA))
            self.imagesB = os.listdir(os.path.join(root, self.testB))

        self.transforms_A = transforms.Compose([
            transforms.RandomRotation(10),
            transforms.CenterCrop(200),
            transforms.Resize((128, 128)),
            transforms.RandomHorizontalFlip(0.5),
            transforms.ColorJitter(0.025, 0.025, 0.025, 0.025),
            #transforms.ToTensor(),
            #transforms.Lambda(lambda x: x+torch.randn_like(x) * 0.01),
            #transforms.Lambda(lambda t: t.clamp(0,1)),
        ])
        self.transforms_B = self.transforms_A

        self.pics = min(len(self.imagesA), len(self.imagesB))

    def __len__(self):
        return self.pics

    def get(self, idx, pic_dirA, pic_dirB):
        img_pathA = os.path.join(self.root, pic_dirA, self.imagesA[idx])
        img_pathB = os.path.join(self.root, pic_dirB, self.imagesB[idx])

        # Load the image
        imgA = self.transforms_A(Image.open(img_pathA))#.resize((self.target_shape, self.target_shape))
        imgB = self.transforms_B(Image.open(img_pathB))#.resize((self.target_shape, self.target_shape))

        #return torch.stack([self.transforms_A(imgA), self.transforms_B(imgB)])#.permute(0, 3, 1, 2)

        out = torch.as_tensor(np.stack([
            np.asarray(imgA, dtype=np.float32), np.asarray(imgB, dtype=np.float32)
        ]) / 128.0 - 1.0).permute(0, 3, 1, 2)

        #out = out + torch.randn_like(out) * 0.01
        #out = out.clamp(0, 1)
        return out, {}

    @staticmethod
    def take(t, A: bool):
        if A:
            return t[:, 0]
        else:
            return t[:, 1]

    @staticmethod
    def to_image(t):
        return (t.permute(0, 2, 3, 1) + 1.0) * 128.0

    def get_train(self, idx):
        return self.get(idx, self.trainA, self.trainB)

    def get_val(self, idx):
        return self.get(idx, self.testA, self.testB)

    def __getitem__(self, idx):
        if self.is_val:
            return self.get_val(idx)
        return self.get_train(idx)

In [9]:
batch_size = 8
dataset = S2WDataset(path, 128)

datasetVal = S2WDataset(path, 128, is_val=True)

dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=batch_size,
    num_workers=1,
    pin_memory=True,
    shuffle=True
)

dataVal = torch.utils.data.DataLoader(
    datasetVal,
    batch_size=batch_size,
    num_workers=1,
    pin_memory=True,
    shuffle=True
)

In [10]:
psi_theta = ConvICNN()
psi_omega = ConvICNN()

In [19]:
loss_c = W2GNLoss(psi_theta, psi_omega, lambda_cycle=35000.0, compute_extra_reg=False, devices=["cuda:0", "cuda:1"])
ns = DiffusionNoiseSchedule()
trainer = DiffTrainer(loss_c, ns, lr=1e-4)

In [20]:
# Main training loop
epochs = 30
for epoch in range(epochs):
    d_iter = iter(dataloader)

    for i in tqdm(range(len(dataloader)), desc=f"Epoch {epoch+1}/{epochs}", leave=False):
        imgs = next(d_iter)[0]
        X = dataloader.dataset.take(imgs, A=True)
        #Y = dataloader.dataset.take(imgs, A=False)
        #print(X.shape, Y.shape)
        # Training step
        losses = trainer.step(X.to(trainer.loss_c.theta_device()), t_min=100, t_max=900)
        #losses = trainer.step(X.to(trainer.loss_c.theta_device()), Y.to(trainer.loss_c.omega_device()))

        loss_R = f"{losses['loss_R']:.4f}" if losses.get('loss_R') is not None else '?'
        print(f"\nEpoch {epoch}: " +
            f"W Dist={losses['loss_corr']:.4f}, " +
            f"Cycle={losses['loss_cycle']:.4f}, " +
            f"Reg={loss_R}, " +
            f"Total={losses['loss_total']:.4f}")

Epoch 1/30:   1%|▎                                   | 1/121 [00:01<02:36,  1.30s/it]


Epoch 0: W Dist=0.0059, Cycle=55578.7891, Reg=?, Total=972628800.0000


Epoch 1/30:   2%|▌                                   | 2/121 [00:02<02:15,  1.14s/it]


Epoch 0: W Dist=-168599.7812, Cycle=1284363.5000, Reg=?, Total=22476193792.0000


Epoch 1/30:   2%|▉                                   | 3/121 [00:03<02:07,  1.08s/it]


Epoch 0: W Dist=-191532.8750, Cycle=707656.3125, Reg=?, Total=12383794176.0000


Epoch 1/30:   3%|█▏                                  | 4/121 [00:04<02:03,  1.05s/it]


Epoch 0: W Dist=-55666.3945, Cycle=370090.0000, Reg=?, Total=6476519424.0000


Epoch 1/30:   4%|█▍                                  | 5/121 [00:05<02:00,  1.03s/it]


Epoch 0: W Dist=11677.1152, Cycle=75816.8672, Reg=?, Total=1326806784.0000


Epoch 1/30:   5%|█▊                                  | 6/121 [00:06<01:58,  1.03s/it]


Epoch 0: W Dist=11513.1768, Cycle=101454.4375, Reg=?, Total=1775464192.0000


Epoch 1/30:   6%|██                                  | 7/121 [00:07<01:56,  1.02s/it]


Epoch 0: W Dist=12419.6875, Cycle=83500.6406, Reg=?, Total=1461273600.0000


Epoch 1/30:   7%|██▍                                 | 8/121 [00:08<01:54,  1.01s/it]


Epoch 0: W Dist=10908.3486, Cycle=106934.5781, Reg=?, Total=1871366016.0000


Epoch 1/30:   7%|██▋                                 | 9/121 [00:09<01:53,  1.02s/it]


Epoch 0: W Dist=13611.6738, Cycle=57771.3672, Reg=?, Total=1011012544.0000


Epoch 1/30:   8%|██▉                                | 10/121 [00:10<01:51,  1.01s/it]


Epoch 0: W Dist=10784.1729, Cycle=60982.4297, Reg=?, Total=1067203328.0000


Epoch 1/30:   9%|███▏                               | 11/121 [00:11<01:50,  1.01s/it]


Epoch 0: W Dist=11345.3369, Cycle=80385.5625, Reg=?, Total=1406758784.0000


Epoch 1/30:  10%|███▍                               | 12/121 [00:12<01:50,  1.01s/it]


Epoch 0: W Dist=13668.4824, Cycle=56809.6016, Reg=?, Total=994181696.0000


Epoch 1/30:  11%|███▊                               | 13/121 [00:13<01:48,  1.00s/it]


Epoch 0: W Dist=9089.8418, Cycle=57634.0859, Reg=?, Total=1008605568.0000


Epoch 1/30:  12%|████                               | 14/121 [00:14<01:47,  1.01s/it]


Epoch 0: W Dist=8842.7520, Cycle=57658.2148, Reg=?, Total=1009027584.0000


Epoch 1/30:  12%|████▎                              | 15/121 [00:15<01:46,  1.00s/it]


Epoch 0: W Dist=12163.4912, Cycle=58359.4453, Reg=?, Total=1021302464.0000


Epoch 1/30:  13%|████▋                              | 16/121 [00:16<01:45,  1.00s/it]


Epoch 0: W Dist=13416.4824, Cycle=57784.6641, Reg=?, Total=1011245056.0000


Epoch 1/30:  14%|████▉                              | 17/121 [00:17<01:44,  1.00s/it]


Epoch 0: W Dist=9676.7891, Cycle=50792.0234, Reg=?, Total=888870080.0000


Epoch 1/30:  15%|█████▏                             | 18/121 [00:18<01:44,  1.02s/it]


Epoch 0: W Dist=11389.6289, Cycle=53602.5547, Reg=?, Total=938056128.0000


Epoch 1/30:  16%|█████▍                             | 19/121 [00:19<01:43,  1.01s/it]


Epoch 0: W Dist=7130.0073, Cycle=58326.3281, Reg=?, Total=1020717824.0000


Epoch 1/30:  17%|█████▊                             | 20/121 [00:20<01:42,  1.01s/it]


Epoch 0: W Dist=8348.3789, Cycle=50374.9648, Reg=?, Total=881570176.0000


Epoch 1/30:  17%|██████                             | 21/121 [00:21<01:41,  1.01s/it]


Epoch 0: W Dist=14430.2402, Cycle=51785.8984, Reg=?, Total=906267648.0000


Epoch 1/30:  18%|██████▎                            | 22/121 [00:22<01:39,  1.01s/it]


Epoch 0: W Dist=10906.4072, Cycle=56766.8594, Reg=?, Total=993430912.0000


Epoch 1/30:  19%|██████▋                            | 23/121 [00:23<01:38,  1.00s/it]


Epoch 0: W Dist=9454.2559, Cycle=56283.3750, Reg=?, Total=984968512.0000


Epoch 1/30:  20%|██████▉                            | 24/121 [00:24<01:37,  1.01s/it]


Epoch 0: W Dist=8598.7744, Cycle=54606.9453, Reg=?, Total=955630144.0000


Epoch 1/30:  21%|███████▏                           | 25/121 [00:25<01:36,  1.01s/it]


Epoch 0: W Dist=18134.5430, Cycle=58317.5469, Reg=?, Total=1020575168.0000


Epoch 1/30:  21%|███████▌                           | 26/121 [00:26<01:36,  1.01s/it]


Epoch 0: W Dist=7842.8672, Cycle=59866.2891, Reg=?, Total=1047667904.0000


Epoch 1/30:  22%|███████▊                           | 27/121 [00:27<01:41,  1.08s/it]


Epoch 0: W Dist=7488.5425, Cycle=48774.5547, Reg=?, Total=853562176.0000


Epoch 1/30:  23%|████████                           | 28/121 [00:28<01:38,  1.06s/it]


Epoch 0: W Dist=-37280.0234, Cycle=490291.8125, Reg=?, Total=8580069376.0000


Epoch 1/30:  24%|████████▍                          | 29/121 [00:29<01:36,  1.04s/it]


Epoch 0: W Dist=5741.6265, Cycle=60452.4609, Reg=?, Total=1057923840.0000


Epoch 1/30:  25%|████████▋                          | 30/121 [00:30<01:33,  1.03s/it]


Epoch 0: W Dist=3897.0173, Cycle=57530.1523, Reg=?, Total=1006781568.0000


Epoch 1/30:  26%|████████▉                          | 31/121 [00:31<01:32,  1.02s/it]


Epoch 0: W Dist=1324.7413, Cycle=45635.6602, Reg=?, Total=798625408.0000


Epoch 1/30:  26%|█████████▎                         | 32/121 [00:32<01:29,  1.01s/it]


Epoch 0: W Dist=2661.4346, Cycle=49611.7109, Reg=?, Total=868207616.0000


Epoch 1/30:  27%|█████████▌                         | 33/121 [00:33<01:29,  1.01s/it]


Epoch 0: W Dist=2209.9307, Cycle=52084.5820, Reg=?, Total=911482432.0000


Epoch 1/30:  28%|█████████▊                         | 34/121 [00:34<01:28,  1.01s/it]


Epoch 0: W Dist=-8151.3906, Cycle=46763.8906, Reg=?, Total=818359936.0000


Epoch 1/30:  29%|██████████                         | 35/121 [00:35<01:27,  1.01s/it]


Epoch 0: W Dist=-26619.9570, Cycle=50761.9023, Reg=?, Total=888306688.0000


Epoch 1/30:  30%|██████████▍                        | 36/121 [00:36<01:26,  1.02s/it]


Epoch 0: W Dist=-20012.3438, Cycle=44275.6094, Reg=?, Total=774803136.0000


Epoch 1/30:  31%|██████████▋                        | 37/121 [00:38<01:31,  1.08s/it]


Epoch 0: W Dist=-10281.1738, Cycle=50663.2305, Reg=?, Total=886596224.0000


Epoch 1/30:  31%|██████████▉                        | 38/121 [00:39<01:27,  1.06s/it]


Epoch 0: W Dist=-8035.1211, Cycle=43258.0234, Reg=?, Total=757007360.0000


Epoch 1/30:  32%|███████████▎                       | 39/121 [00:40<01:26,  1.06s/it]


Epoch 0: W Dist=-3923.5498, Cycle=49543.8203, Reg=?, Total=867012928.0000


Epoch 1/30:  33%|███████████▌                       | 40/121 [00:41<01:24,  1.04s/it]


Epoch 0: W Dist=-501.3381, Cycle=39900.9492, Reg=?, Total=698266112.0000


Epoch 1/30:  34%|███████████▊                       | 41/121 [00:42<01:23,  1.04s/it]


Epoch 0: W Dist=-2555.0027, Cycle=38355.9453, Reg=?, Total=671226496.0000


Epoch 1/30:  35%|████████████▏                      | 42/121 [00:43<01:21,  1.03s/it]


Epoch 0: W Dist=-7830.3521, Cycle=41647.6602, Reg=?, Total=728826240.0000


Epoch 1/30:  36%|████████████▍                      | 43/121 [00:44<01:19,  1.02s/it]


Epoch 0: W Dist=-6758.0762, Cycle=36249.6250, Reg=?, Total=634361664.0000


Epoch 1/30:  36%|████████████▋                      | 44/121 [00:45<01:19,  1.03s/it]


Epoch 0: W Dist=-5580.2314, Cycle=32289.2168, Reg=?, Total=565055744.0000


Epoch 1/30:  37%|█████████████                      | 45/121 [00:46<01:17,  1.02s/it]


Epoch 0: W Dist=-9975.0889, Cycle=34605.3750, Reg=?, Total=605584064.0000


Epoch 1/30:  38%|█████████████▎                     | 46/121 [00:47<01:16,  1.02s/it]


Epoch 0: W Dist=-10434.7979, Cycle=28913.2812, Reg=?, Total=505972000.0000


Epoch 1/30:  39%|█████████████▌                     | 47/121 [00:48<01:15,  1.01s/it]


Epoch 0: W Dist=-7972.7598, Cycle=35494.7266, Reg=?, Total=621149696.0000


Epoch 1/30:  40%|█████████████▉                     | 48/121 [00:49<01:13,  1.01s/it]


Epoch 0: W Dist=-9820.0742, Cycle=29450.9414, Reg=?, Total=515381664.0000


Epoch 1/30:  40%|██████████████▏                    | 49/121 [00:50<01:12,  1.01s/it]


Epoch 0: W Dist=-14130.9355, Cycle=30750.2969, Reg=?, Total=538116032.0000


Epoch 1/30:  41%|██████████████▍                    | 50/121 [00:51<01:11,  1.01s/it]


Epoch 0: W Dist=-12671.3301, Cycle=29393.4238, Reg=?, Total=514372256.0000


Epoch 1/30:  42%|██████████████▊                    | 51/121 [00:52<01:10,  1.01s/it]


Epoch 0: W Dist=-16573.9863, Cycle=30846.2637, Reg=?, Total=539793024.0000


Epoch 1/30:  43%|███████████████                    | 52/121 [00:53<01:09,  1.01s/it]


Epoch 0: W Dist=-9512.5586, Cycle=26783.3789, Reg=?, Total=468699616.0000


Epoch 1/30:  44%|███████████████▎                   | 53/121 [00:54<01:08,  1.01s/it]


Epoch 0: W Dist=-7706.8989, Cycle=28001.4316, Reg=?, Total=490017344.0000


Epoch 1/30:  45%|███████████████▌                   | 54/121 [00:55<01:12,  1.08s/it]


Epoch 0: W Dist=-11169.4629, Cycle=32405.4238, Reg=?, Total=567083712.0000


Epoch 1/30:  45%|███████████████▉                   | 55/121 [00:56<01:09,  1.06s/it]


Epoch 0: W Dist=-9734.2891, Cycle=29974.1387, Reg=?, Total=524537696.0000


Epoch 1/30:  46%|████████████████▏                  | 56/121 [00:57<01:07,  1.04s/it]


Epoch 0: W Dist=-13684.1094, Cycle=28191.2305, Reg=?, Total=493332832.0000


Epoch 1/30:  47%|████████████████▍                  | 57/121 [00:58<01:06,  1.03s/it]


Epoch 0: W Dist=-13028.4590, Cycle=31034.3477, Reg=?, Total=543088000.0000


Epoch 1/30:  48%|████████████████▊                  | 58/121 [00:59<01:04,  1.03s/it]


Epoch 0: W Dist=-13514.5996, Cycle=27025.4316, Reg=?, Total=472931552.0000


Epoch 1/30:  49%|█████████████████                  | 59/121 [01:00<01:03,  1.03s/it]


Epoch 0: W Dist=-10649.8066, Cycle=24191.4492, Reg=?, Total=423339712.0000


Epoch 1/30:  50%|█████████████████▎                 | 60/121 [01:01<01:02,  1.02s/it]


Epoch 0: W Dist=-14790.7803, Cycle=27313.1797, Reg=?, Total=477965856.0000


Epoch 1/30:  50%|█████████████████▋                 | 61/121 [01:02<01:00,  1.01s/it]


Epoch 0: W Dist=-17360.1719, Cycle=24983.4531, Reg=?, Total=437193056.0000


Epoch 1/30:  51%|█████████████████▉                 | 62/121 [01:03<00:59,  1.02s/it]


Epoch 0: W Dist=-20456.4961, Cycle=27401.9863, Reg=?, Total=479514304.0000


Epoch 1/30:  52%|██████████████████▏                | 63/121 [01:04<00:58,  1.02s/it]


Epoch 0: W Dist=-13719.5752, Cycle=28076.7305, Reg=?, Total=491329056.0000


Epoch 1/30:  53%|██████████████████▌                | 64/121 [01:05<00:58,  1.02s/it]


Epoch 0: W Dist=-11644.7129, Cycle=27559.7305, Reg=?, Total=482283648.0000


Epoch 1/30:  54%|██████████████████▊                | 65/121 [01:06<00:57,  1.03s/it]


Epoch 0: W Dist=-11317.1289, Cycle=27445.0840, Reg=?, Total=480277632.0000


Epoch 1/30:  55%|███████████████████                | 66/121 [01:07<00:55,  1.02s/it]


Epoch 0: W Dist=-9503.9668, Cycle=28865.2734, Reg=?, Total=505132768.0000


Epoch 1/30:  55%|███████████████████▍               | 67/121 [01:08<00:55,  1.02s/it]


Epoch 0: W Dist=-11931.6943, Cycle=24379.1484, Reg=?, Total=426623168.0000


Epoch 1/30:  56%|███████████████████▋               | 68/121 [01:09<00:54,  1.02s/it]


Epoch 0: W Dist=-14625.4473, Cycle=24935.0156, Reg=?, Total=436348160.0000


Epoch 1/30:  57%|███████████████████▉               | 69/121 [01:10<00:52,  1.02s/it]


Epoch 0: W Dist=-11418.8027, Cycle=29845.0234, Reg=?, Total=522276480.0000


Epoch 1/30:  58%|████████████████████▏              | 70/121 [01:11<00:52,  1.03s/it]


Epoch 0: W Dist=-12775.9668, Cycle=25608.5430, Reg=?, Total=448136736.0000


Epoch 1/30:  59%|████████████████████▌              | 71/121 [01:13<00:54,  1.09s/it]


Epoch 0: W Dist=-12637.3398, Cycle=20168.5898, Reg=?, Total=352937696.0000


Epoch 1/30:  60%|████████████████████▊              | 72/121 [01:14<00:52,  1.07s/it]


Epoch 0: W Dist=-13305.4492, Cycle=20625.7266, Reg=?, Total=360936896.0000


Epoch 1/30:  60%|█████████████████████              | 73/121 [01:15<00:53,  1.12s/it]


Epoch 0: W Dist=-14211.8789, Cycle=23159.7109, Reg=?, Total=405280736.0000


Epoch 1/30:  61%|█████████████████████▍             | 74/121 [01:16<00:51,  1.09s/it]


Epoch 0: W Dist=-16147.4893, Cycle=21848.3438, Reg=?, Total=382329856.0000


Epoch 1/30:  62%|█████████████████████▋             | 75/121 [01:17<00:49,  1.07s/it]


Epoch 0: W Dist=-16177.7539, Cycle=19667.2578, Reg=?, Total=344160832.0000


Epoch 1/30:  63%|█████████████████████▉             | 76/121 [01:18<00:47,  1.05s/it]


Epoch 0: W Dist=-14222.6533, Cycle=22063.8477, Reg=?, Total=386103136.0000


Epoch 1/30:  64%|██████████████████████▎            | 77/121 [01:19<00:45,  1.04s/it]


Epoch 0: W Dist=-14292.7305, Cycle=18090.8730, Reg=?, Total=316575968.0000


Epoch 1/30:  64%|██████████████████████▌            | 78/121 [01:20<00:44,  1.03s/it]


Epoch 0: W Dist=-14116.7197, Cycle=18161.7949, Reg=?, Total=317817312.0000


Epoch 1/30:  65%|██████████████████████▊            | 79/121 [01:21<00:46,  1.10s/it]


Epoch 0: W Dist=-14349.3086, Cycle=17154.1758, Reg=?, Total=300183744.0000


Epoch 1/30:  66%|███████████████████████▏           | 80/121 [01:22<00:43,  1.07s/it]


Epoch 0: W Dist=-15347.4814, Cycle=18504.1016, Reg=?, Total=323806432.0000


Epoch 1/30:  67%|███████████████████████▍           | 81/121 [01:23<00:42,  1.05s/it]


Epoch 0: W Dist=-14425.7441, Cycle=18296.6914, Reg=?, Total=320177664.0000


Epoch 1/30:  68%|███████████████████████▋           | 82/121 [01:24<00:40,  1.04s/it]


Epoch 0: W Dist=-13547.1562, Cycle=16653.0586, Reg=?, Total=291414976.0000


Epoch 1/30:  69%|████████████████████████           | 83/121 [01:25<00:39,  1.04s/it]


Epoch 0: W Dist=-15629.6914, Cycle=19337.8906, Reg=?, Total=338397472.0000


Epoch 1/30:  69%|████████████████████████▎          | 84/121 [01:26<00:38,  1.03s/it]


Epoch 0: W Dist=-12438.6094, Cycle=17924.9023, Reg=?, Total=313673344.0000


Epoch 1/30:  70%|████████████████████████▌          | 85/121 [01:27<00:37,  1.03s/it]


Epoch 0: W Dist=-13623.3848, Cycle=16832.6797, Reg=?, Total=294558272.0000


Epoch 1/30:  71%|████████████████████████▉          | 86/121 [01:28<00:35,  1.02s/it]


Epoch 0: W Dist=-13940.7832, Cycle=16237.5225, Reg=?, Total=284142688.0000


Epoch 1/30:  72%|█████████████████████████▏         | 87/121 [01:29<00:34,  1.02s/it]


Epoch 0: W Dist=-14641.3789, Cycle=18880.8145, Reg=?, Total=330399584.0000


Epoch 1/30:  73%|█████████████████████████▍         | 88/121 [01:30<00:33,  1.01s/it]


Epoch 0: W Dist=-13050.3408, Cycle=15715.0098, Reg=?, Total=274999616.0000


Epoch 1/30:  74%|█████████████████████████▋         | 89/121 [01:31<00:32,  1.02s/it]


Epoch 0: W Dist=-13358.2861, Cycle=17399.9902, Reg=?, Total=304486496.0000


Epoch 1/30:  74%|██████████████████████████         | 90/121 [01:32<00:31,  1.01s/it]


Epoch 0: W Dist=-14104.3330, Cycle=17320.4922, Reg=?, Total=303094496.0000


Epoch 1/30:  75%|██████████████████████████▎        | 91/121 [01:33<00:30,  1.01s/it]


Epoch 0: W Dist=-14804.2480, Cycle=15564.3145, Reg=?, Total=272360672.0000


Epoch 1/30:  76%|██████████████████████████▌        | 92/121 [01:34<00:29,  1.01s/it]


Epoch 0: W Dist=-13543.5820, Cycle=15083.4424, Reg=?, Total=263946704.0000


Epoch 1/30:  77%|██████████████████████████▉        | 93/121 [01:35<00:28,  1.01s/it]


Epoch 0: W Dist=-17288.5234, Cycle=16920.3984, Reg=?, Total=296089696.0000


Epoch 1/30:  78%|███████████████████████████▏       | 94/121 [01:36<00:27,  1.01s/it]


Epoch 0: W Dist=-13672.2061, Cycle=16623.6523, Reg=?, Total=290900256.0000


Epoch 1/30:  79%|███████████████████████████▍       | 95/121 [01:37<00:26,  1.03s/it]


Epoch 0: W Dist=-15098.1025, Cycle=16245.4316, Reg=?, Total=284279936.0000


Epoch 1/30:  79%|███████████████████████████▊       | 96/121 [01:39<00:27,  1.10s/it]


Epoch 0: W Dist=-17232.9062, Cycle=16521.5273, Reg=?, Total=289109472.0000


Epoch 1/30:  80%|████████████████████████████       | 97/121 [01:40<00:25,  1.08s/it]


Epoch 0: W Dist=-13048.0762, Cycle=14606.7324, Reg=?, Total=255604768.0000


Epoch 1/30:  81%|████████████████████████████▎      | 98/121 [01:41<00:24,  1.08s/it]


Epoch 0: W Dist=-15255.1553, Cycle=14455.6289, Reg=?, Total=252958256.0000


Epoch 1/30:  82%|████████████████████████████▋      | 99/121 [01:42<00:24,  1.13s/it]


Epoch 0: W Dist=-15384.6680, Cycle=14340.6953, Reg=?, Total=250946768.0000


Epoch 1/30:  83%|████████████████████████████      | 100/121 [01:43<00:22,  1.09s/it]


Epoch 0: W Dist=-13885.5566, Cycle=14211.5293, Reg=?, Total=248687872.0000


Epoch 1/30:  83%|████████████████████████████▍     | 101/121 [01:44<00:21,  1.07s/it]


Epoch 0: W Dist=-16980.2773, Cycle=15398.1465, Reg=?, Total=269450560.0000


Epoch 1/30:  84%|████████████████████████████▋     | 102/121 [01:45<00:20,  1.05s/it]


Epoch 0: W Dist=-16072.0469, Cycle=17003.5488, Reg=?, Total=297546048.0000


Epoch 1/30:  85%|████████████████████████████▉     | 103/121 [01:46<00:18,  1.04s/it]


Epoch 0: W Dist=-16198.5244, Cycle=15302.3262, Reg=?, Total=267774512.0000


Epoch 1/30:  86%|█████████████████████████████▏    | 104/121 [01:47<00:17,  1.04s/it]


Epoch 0: W Dist=-19455.0117, Cycle=15823.1777, Reg=?, Total=276886144.0000


Epoch 1/30:  87%|█████████████████████████████▌    | 105/121 [01:48<00:16,  1.03s/it]


Epoch 0: W Dist=-15925.1885, Cycle=15093.5684, Reg=?, Total=264121520.0000


Epoch 1/30:  88%|█████████████████████████████▊    | 106/121 [01:49<00:15,  1.03s/it]


Epoch 0: W Dist=-18425.4824, Cycle=15181.9043, Reg=?, Total=265664896.0000


Epoch 1/30:  88%|██████████████████████████████    | 107/121 [01:50<00:14,  1.03s/it]


Epoch 0: W Dist=-16027.6445, Cycle=14739.1602, Reg=?, Total=257919264.0000


Epoch 1/30:  89%|██████████████████████████████▎   | 108/121 [01:51<00:13,  1.03s/it]


Epoch 0: W Dist=-16869.0332, Cycle=13842.7266, Reg=?, Total=242230848.0000


Epoch 1/30:  90%|██████████████████████████████▋   | 109/121 [01:52<00:12,  1.03s/it]


Epoch 0: W Dist=-16011.6777, Cycle=13046.2363, Reg=?, Total=228293120.0000


Epoch 1/30:  91%|██████████████████████████████▉   | 110/121 [01:53<00:11,  1.02s/it]


Epoch 0: W Dist=-15725.8682, Cycle=12196.5742, Reg=?, Total=213424320.0000


Epoch 1/30:  92%|███████████████████████████████▏  | 111/121 [01:54<00:10,  1.03s/it]


Epoch 0: W Dist=-14842.3516, Cycle=13979.1455, Reg=?, Total=244620192.0000


Epoch 1/30:  93%|███████████████████████████████▍  | 112/121 [01:55<00:09,  1.03s/it]


Epoch 0: W Dist=-17642.5664, Cycle=14406.1016, Reg=?, Total=252089136.0000


Epoch 1/30:  93%|███████████████████████████████▊  | 113/121 [01:57<00:08,  1.10s/it]


Epoch 0: W Dist=-17631.1738, Cycle=14921.7441, Reg=?, Total=261112896.0000


Epoch 1/30:  94%|████████████████████████████████  | 114/121 [01:58<00:07,  1.07s/it]


Epoch 0: W Dist=-17204.1445, Cycle=15917.1162, Reg=?, Total=278532320.0000


Epoch 1/30:  95%|████████████████████████████████▎ | 115/121 [01:59<00:06,  1.06s/it]


Epoch 0: W Dist=-15336.0205, Cycle=12862.6211, Reg=?, Total=225080528.0000


Epoch 1/30:  96%|████████████████████████████████▌ | 116/121 [02:00<00:05,  1.12s/it]


Epoch 0: W Dist=-17400.0625, Cycle=13280.4961, Reg=?, Total=232391280.0000


Epoch 1/30:  97%|████████████████████████████████▉ | 117/121 [02:01<00:04,  1.08s/it]


Epoch 0: W Dist=-18501.8359, Cycle=13663.5703, Reg=?, Total=239093984.0000


Epoch 1/30:  98%|█████████████████████████████████▏| 118/121 [02:02<00:03,  1.06s/it]


Epoch 0: W Dist=-18678.4453, Cycle=14457.8311, Reg=?, Total=252993376.0000


Epoch 1/30:  98%|█████████████████████████████████▍| 119/121 [02:03<00:02,  1.05s/it]


Epoch 0: W Dist=-20178.1074, Cycle=14525.9014, Reg=?, Total=254183104.0000


Epoch 1/30:  99%|█████████████████████████████████▋| 120/121 [02:04<00:01,  1.04s/it]


Epoch 0: W Dist=-18756.9863, Cycle=14564.6758, Reg=?, Total=254863072.0000


                                                                                     


Epoch 0: W Dist=-19621.0176, Cycle=14205.0840, Reg=?, Total=248569360.0000


Epoch 2/30:   1%|▎                                   | 1/121 [00:01<02:06,  1.06s/it]


Epoch 1: W Dist=-18329.1680, Cycle=13873.3398, Reg=?, Total=242765104.0000


Epoch 2/30:   2%|▌                                   | 2/121 [00:02<02:03,  1.03s/it]


Epoch 1: W Dist=-20843.0293, Cycle=13373.1797, Reg=?, Total=234009792.0000


Epoch 2/30:   2%|▉                                   | 3/121 [00:03<02:00,  1.03s/it]


Epoch 1: W Dist=-20670.1172, Cycle=15064.7852, Reg=?, Total=263613072.0000


Epoch 2/30:   3%|█▏                                  | 4/121 [00:04<02:10,  1.11s/it]


Epoch 1: W Dist=-18971.3809, Cycle=14613.3418, Reg=?, Total=255714512.0000


Epoch 2/30:   4%|█▍                                  | 5/121 [00:05<02:04,  1.08s/it]


Epoch 1: W Dist=-23391.4258, Cycle=13783.6074, Reg=?, Total=241189744.0000


Epoch 2/30:   5%|█▊                                  | 6/121 [00:06<02:01,  1.06s/it]


Epoch 1: W Dist=-19512.4492, Cycle=16052.0898, Reg=?, Total=280892064.0000


Epoch 2/30:   6%|██                                  | 7/121 [00:07<01:58,  1.04s/it]


Epoch 1: W Dist=-17373.4961, Cycle=13646.5410, Reg=?, Total=238797088.0000


Epoch 2/30:   7%|██▍                                 | 8/121 [00:08<01:56,  1.03s/it]


Epoch 1: W Dist=-20777.4336, Cycle=13881.2119, Reg=?, Total=242900432.0000


Epoch 2/30:   7%|██▋                                 | 9/121 [00:09<01:56,  1.04s/it]


Epoch 1: W Dist=-21628.2461, Cycle=14492.0352, Reg=?, Total=253588976.0000


Epoch 2/30:   8%|██▉                                | 10/121 [00:10<02:04,  1.12s/it]


Epoch 1: W Dist=-16228.4512, Cycle=13667.0176, Reg=?, Total=239156576.0000


Epoch 2/30:   9%|███▏                               | 11/121 [00:12<02:08,  1.16s/it]


Epoch 1: W Dist=-20322.5605, Cycle=15497.0527, Reg=?, Total=271178112.0000


Epoch 2/30:  10%|███▍                               | 12/121 [00:13<02:03,  1.13s/it]


Epoch 1: W Dist=-18606.8984, Cycle=13673.0010, Reg=?, Total=239258912.0000


Epoch 2/30:  11%|███▊                               | 13/121 [00:14<01:58,  1.10s/it]


Epoch 1: W Dist=-19739.1738, Cycle=11510.9316, Reg=?, Total=201421552.0000


Epoch 2/30:  12%|████                               | 14/121 [00:15<01:54,  1.07s/it]


Epoch 1: W Dist=-24311.0801, Cycle=13852.1699, Reg=?, Total=242388672.0000


Epoch 2/30:  12%|████▎                              | 15/121 [00:16<01:51,  1.06s/it]


Epoch 1: W Dist=-17773.3008, Cycle=15293.8945, Reg=?, Total=267625376.0000


Epoch 2/30:  13%|████▋                              | 16/121 [00:17<01:49,  1.04s/it]


Epoch 1: W Dist=-18444.8027, Cycle=13985.9561, Reg=?, Total=244735776.0000


Epoch 2/30:  14%|████▉                              | 17/121 [00:18<01:55,  1.11s/it]


Epoch 1: W Dist=-22371.0273, Cycle=13962.8652, Reg=?, Total=244327776.0000


                                                                                     

KeyboardInterrupt: 

In [21]:
Val_iter = iter(dataVal)
X_test = dataVal.dataset.take(next(Val_iter)[0], A=True).to(trainer.loss_c.theta_device())
X_test.requires_grad_(True)

cols = 8
noisy = []
for i in range(cols):
    noisy.append(trainer.noise_schedule.add_noise(X_test[i], t=900))
X_noise = torch.stack(noisy)

Y_generated = trainer.denoise(X_noise).detach().cpu()

# Visualize
fig, axes = plt.subplots(nrows=2, ncols=cols, figsize=(16, 4))
fig.suptitle('Place at summer / place at winter:')

for i in range(cols):
    X_img = np.asarray(dataVal.dataset.to_image(X_noise.detach().cpu())[i], dtype=np.uint8)
    _ = axes[0][i].imshow(X_img)

    Y_img = np.asarray(dataVal.dataset.to_image(Y_generated.cpu())[i], dtype=np.uint8)
    _ = axes[1][i].imshow(Y_img)

#plt.tight_layout()
fig.show()

AttributeError: 'DiffTrainer' object has no attribute 'generate'