In [1]:
import os, sys
sys.path.append("..")

import torch
from torch import nn
from torch.optim import Adam
import torch.nn.functional as F

import numpy as np

from matplotlib import pyplot as plt

from tqdm import tqdm
from IPython.display import clear_output

from MNIST_models.plotters import plot_trajectories, plot_images

from src.tools import load_dataset

from src.fid import save_model_samples

import wandb
import gc

import os
# Global seed shared by every RNG this notebook draws from, so reruns are reproducible.
SEED = 0xBADBEEF
torch.manual_seed(SEED)  # torch default generator
np.random.seed(SEED)     # legacy numpy global RNG (last expression -> None, so the cell shows no output)

In [2]:
# --- Data --------------------------------------------------------------
batch_size = 64        # samples drawn per sampler call / loader batch
IMG_SIZE = 32          # passed as img_size to load_dataset and as the first arg of every SongUNet
IMG_CHANNELS = 3

# --- Generator latent noise (both passed to G_wrapper) ------------------
ZC = 1
Z_STD = 1.0

# --- Bridge / SDE --------------------------------------------------------
GAMMA = 0.5            # bridge noise scale: x_t gets noise with std sqrt(GAMMA * t * (1 - t))
N_STEPS = 10           # passed as n_steps to SDE_denoiser

# --- Architecture --------------------------------------------------------
model_channels = 32    # base width of G and f; D is built with twice this
TIME_DIM = 128         # NOTE(review): not referenced in the visible cells
UNET_BASE_FACTOR = 48  # NOTE(review): not referenced in the visible cells

# --- Optimisation --------------------------------------------------------
lr = 1e-4              # Adam lr for G and D (the SDE denoiser optimiser uses lr*10)
MAX_STEPS = 50_000     # outer training steps
G_ITERS = 10           # generator updates per outer step
f_ITERS = 2            # denoiser updates before each generator update
D_ITERS = 1            # discriminator updates per outer step

In [3]:
# Prefer the GPU when one is visible, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [4]:
# Load both colored-MNIST domains; load_dataset returns (train sampler, test sampler, train loader, test loader).
sampler3, test_sampler3, loader3, test_loader3 = load_dataset('MNIST-colored_3', './datasets/MNIST', img_size=IMG_SIZE, batch_size=batch_size, device=device)
sampler2, test_sampler2, loader2, test_loader2 = load_dataset('MNIST-colored_2', './datasets/MNIST', img_size=IMG_SIZE, batch_size=batch_size, device=device)

# Training samplers: X is the source domain ('MNIST-colored_2'), Y the target domain ('MNIST-colored_3').
Y_sampler = sampler3
X_sampler = sampler2

# Evaluation loaders (used for FID / LPIPS in trainENOT). These previously pointed at the
# *training* loaders (loader3 / loader2), so the metrics were computed on training data while
# the unpacked test loaders went unused.
Y_loader_test = test_loader3
X_loader_test = test_loader2

In [5]:
from EDM_models.D import SongUNet_D
from EDM_models.G import SongUNet_G
from EDM_models.f import SongUNet_f

from EDM_models.enot import SDE_denoiser, G_wrapper

In [6]:
# Discriminator (twice the base width), stochastic generator and the SDE's denoiser network.
D = SongUNet_D(IMG_SIZE, IMG_CHANNELS, model_channels=model_channels*2).to(device)
# NOTE(review): G's extra input channel presumably carries the ZC (=1) latent channel(s) added by
# G_wrapper; if so, IMG_CHANNELS+1 should track ZC — confirm against EDM_models.enot.G_wrapper.
G = SongUNet_G(IMG_SIZE, IMG_CHANNELS+1, IMG_CHANNELS, model_channels=model_channels).to(device)
G = G_wrapper(G, ZC, Z_STD)
f = SongUNet_f(IMG_SIZE, IMG_CHANNELS, IMG_CHANNELS, model_channels=model_channels).to(device)
sde = SDE_denoiser(denoiser=f, n_steps=N_STEPS).to(device)

# The SDE denoiser is trained with a 10x larger learning rate than G and D.
sde_opt = Adam(sde.parameters(), lr=lr*10)
G_opt = Adam(G.parameters(), lr=lr)
D_opt = Adam(D.parameters(), lr=lr)


def count_params(module):
    """Return the total number of scalar entries across all parameters of `module` (as an int).

    Uses Tensor.numel(), which is 1 for 0-d parameters (np.prod of an empty shape gives the float 1.0).
    """
    return sum(p.numel() for p in module.parameters())


print('D params:', count_params(D))
print('G params:', count_params(G))
print('sde params:', count_params(sde))

D params: 4748865
G params: 3539523
sde params: 3881955


In [7]:
from src.fid import calculate_inception_stats, calculate_fid_from_inception_stats

def calc_fid(G_samples_path, Y_samples_path, num_expected, batch):
    """Compute the FID between two folders of saved samples.

    Args:
        G_samples_path: folder with generated images.
        Y_samples_path: folder with reference (target-domain) images.
        num_expected: number of images expected in each folder.
        batch: maximum batch size used while extracting Inception statistics.

    Returns:
        The value returned by calculate_fid_from_inception_stats for the two folders.
    """
    # Inception (mean, covariance) statistics, generated folder first, then reference folder.
    (mu_gen, sigma_gen), (mu_ref, sigma_ref) = [
        calculate_inception_stats(image_path=path, num_expected=num_expected, max_batch_size=batch)
        for path in (G_samples_path, Y_samples_path)
    ]
    return calculate_fid_from_inception_stats(mu_gen, sigma_gen, mu_ref, sigma_ref)


In [8]:
def _sample_bridge(x0, xN):
    """Draw t ~ U[0, 1) per example and a point x_t on the Brownian bridge between x0 and xN.

    x_t = x0 + t * (xN - x0) + sqrt(GAMMA * t * (1 - t)) * eps,   eps ~ N(0, I).
    The [:, None, None, None] broadcasting requires rank-4 batches (presumably (B, C, H, W)).

    Returns:
        (t, x_t) with t of shape (batch,) on `device`.
    """
    t = torch.rand(x0.shape[0]).to(device)
    noise_std = torch.sqrt(t * (1 - t) * GAMMA)[:, None, None, None]
    xt = x0 + (xN - x0) * t[:, None, None, None] + torch.randn_like(x0) * noise_std
    return t, xt


def _free_memory():
    """Release cached CUDA allocator blocks and run the Python garbage collector."""
    torch.cuda.empty_cache()
    gc.collect()


def _log_figure(key, fig, step):
    """Log a matplotlib figure to wandb under `key`, then close it and free memory."""
    wandb.log({key: wandb.Image(fig)}, step=step)
    plt.close(fig)
    _free_memory()


@torch.no_grad()
def _log_evaluation(X_sampler, G, sde, step, n_var_samples, n_fid_samples, fid_batch):
    """Log evaluation diagnostics for G and the SDE on one fresh X batch (autograd disabled).

    Logs 'G var' / 'sde var' (per-entry variance over `n_var_samples` repeated draws for the same
    inputs, averaged), 'G L2' / 'sde L2' (MSE between inputs and one mapped sample), three figures,
    and 'FID' / 'lpips' from sample folders 'samplesG' and 'samplesY' built from the test loaders.
    """
    clear_output(wait=True)
    X = X_sampler.sample(batch_size)

    G_draws = np.stack([G(X).cpu().numpy() for _ in range(n_var_samples)])
    wandb.log({'G var': G_draws.var(axis=0).mean().item()}, step=step)
    sde_draws = np.stack([sde(X, GAMMA).cpu().numpy() for _ in range(n_var_samples)])
    wandb.log({'sde var': sde_draws.var(axis=0).mean().item()}, step=step)

    G_out = G(X)
    sde_out = sde(X, GAMMA)
    wandb.log({'G L2': F.mse_loss(X, G_out).item(), 'sde L2': F.mse_loss(X, sde_out).item()}, step=step)
    _free_memory()

    _log_figure("trajectories", plot_trajectories(sde, GAMMA, X_sampler, 3), step)
    _log_figure("G_images", plot_images(G, X_sampler, 4, 2), step)
    _log_figure("SDE_images", plot_images(sde, X_sampler, 4, 2, GAMMA), step)

    # NOTE(review): the two positional ints are presumably (batch size, number of samples), mirroring
    # calc_fid's (num_expected, batch) — confirm against src.fid.save_model_samples.
    _l2, lpips_score = save_model_samples('samplesG', G, X_loader_test, fid_batch, n_fid_samples, device, 'samplesY', Y_loader_test)
    fid = calc_fid('samplesG', 'samplesY', n_fid_samples, fid_batch)
    wandb.log({'FID': fid, 'lpips': lpips_score}, step=step)
    _free_memory()


def trainENOT(X_sampler, Y_sampler, G, G_opt, D, D_opt, sde, sde_opt,
              eval_every=50, n_var_samples=100, n_fid_samples=1000, fid_batch=32):
    """Alternating adversarial training of generator G, SDE denoiser and discriminator D
    (presumably the ENOT — entropic neural OT — scheme; confirm against the paper/repo).

    For each of MAX_STEPS outer steps:
      * G_ITERS times:
          - f_ITERS denoiser updates: regress sde.denoiser(x_t, t) onto xN = G(x0), x_t on the
            bridge x0 -> xN. G is run without autograd here since only the denoiser is updated.
          - one G update minimising 2 * (mean(f * (xN - x_t)) - mean(f * f) / 2) - mean(D(xN)),
            with f = sde.denoiser(x_t, t) - x_t.
      * every `eval_every` steps (0/None disables): log diagnostics, figures, FID and LPIPS.
      * D_ITERS discriminator updates on mean(D(G(x0))) - mean(D(x1)) (WGAN-style critic loss).
      * log the last f/G/D losses computed in this step.

    Args:
        X_sampler, Y_sampler: objects with .sample(n) returning source / target batches.
        G, D, sde: generator, discriminator and SDE_denoiser modules.
        G_opt, D_opt, sde_opt: their optimisers.
        eval_every: evaluation period in outer steps.
        n_var_samples: repeated draws used for the 'G var' / 'sde var' diagnostics.
        n_fid_samples: number of images used for FID / LPIPS.
        fid_batch: batch size used for FID / LPIPS sample generation.

    Returns:
        None; all progress is reported through wandb.
    """
    for step in tqdm(range(MAX_STEPS)):
        # Only losses actually computed during this step are logged (an *_ITERS of 0 used to raise NameError).
        f_loss = G_loss = D_loss = None

        for G_iter in range(G_ITERS):

            for f_iter in range(f_ITERS):
                x0 = X_sampler.sample(batch_size)
                with torch.no_grad():  # targets only: no graph through G
                    xN = G(x0)
                t, xt = _sample_bridge(x0, xN)

                f_loss = ((sde.denoiser(xt, t) - xN) ** 2).mean()
                sde_opt.zero_grad(); f_loss.backward(); sde_opt.step()

            x0 = X_sampler.sample(batch_size)
            xN = G(x0)
            t, xt = _sample_bridge(x0, xN)

            f_x_t = sde.denoiser(xt, t) - xt
            E = xN - xt

            G_loss = ((f_x_t * E).mean() - (f_x_t * f_x_t).mean() / 2) * 2 - D(xN).mean()
            G_opt.zero_grad(); G_loss.backward(); G_opt.step()

        if eval_every and step % eval_every == 0:
            _log_evaluation(X_sampler, G, sde, step, n_var_samples, n_fid_samples, fid_batch)

        for D_iter in range(D_ITERS):
            x0 = X_sampler.sample(batch_size)
            x1 = Y_sampler.sample(batch_size)
            with torch.no_grad():  # D step only: no graph through G
                xN = G(x0)
            D_loss = (- D(x1) + D(xN)).mean()
            D_opt.zero_grad(); D_loss.backward(); D_opt.step()

        step_losses = {
            name: value.item()
            for name, value in (('f_loss', f_loss), ('G_loss', G_loss), ('D_loss', D_loss))
            if value is not None
        }
        if step_losses:
            wandb.log(step_losses, step=step)

In [9]:
# Start the W&B run and record the hyperparameters from the config cells,
# so the run can be reproduced from the dashboard alone.
wandb.init(
    project='MNIST_EDM',
    config=dict(
        seed=SEED,
        batch_size=batch_size,
        img_size=IMG_SIZE,
        img_channels=IMG_CHANNELS,
        zc=ZC,
        z_std=Z_STD,
        gamma=GAMMA,
        n_steps=N_STEPS,
        model_channels=model_channels,
        lr=lr,
        G_iters=G_ITERS,
        D_iters=D_ITERS,
        f_iters=f_ITERS,
        max_steps=MAX_STEPS,
    ),
)

[34m[1mwandb[0m: Currently logged in as: [33milyasudakov[0m ([33msudakov[0m). Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


In [10]:
# Launch training. trainENOT has no return statement, so `stats` ends up None;
# losses, metrics and figures are reported through wandb instead.
stats = trainENOT(
    X_sampler, Y_sampler,
    G, G_opt,
    D, D_opt,
    sde, sde_opt,
)

  x = x + self.delta_t*(self.denoiser(x, t) - x)/(1-torch.tensor(t)[:, None, None, None].cuda()) + torch.randn_like(x)*np.sqrt(gamma*self.delta_t)
  x = x + self.delta_t*(self.denoiser(x, t) - x)/(1-torch.tensor(t)[:, None, None, None].cuda())


Setting up [LPIPS] perceptual loss: trunk [vgg], v[0.1], spatial [off]


  self.load_state_dict(torch.load(model_path, map_location='cpu'), strict=False)


Loading model from: /home/sudakovcom/Desktop/NOT/.conda/lib/python3.11/site-packages/lpips/weights/v0.1/vgg.pth


  0%|          | 0/1000 [00:07<?, ?it/s]
  0%|          | 0/1000 [00:06<?, ?it/s]


Loading Inception-v3 model...
Loading images from "samplesG"...
Calculating statistics for 1000 images...


100%|██████████| 1000/1000 [00:05<00:00, 178.92batch/s]


Loading Inception-v3 model...
Loading images from "samplesY"...
Calculating statistics for 1000 images...


100%|██████████| 1000/1000 [00:05<00:00, 180.09batch/s]
  0%|          | 11/50000 [02:23<181:05:39, 13.04s/it]


KeyboardInterrupt: 

Error in callback <bound method _WandbInit._pause_backend of <wandb.sdk.wandb_init._WandbInit object at 0x7b6f948c2650>> (for post_run_cell), with arguments args (<ExecutionResult object at 7b6fc8f95d50, execution_count=10 error_before_exec=None error_in_exec= info=<ExecutionInfo object at 7b6fc910a310, raw_cell="stats = trainENOT(X_sampler, Y_sampler, G, G_opt, .." store_history=True silent=False shell_futures=True cell_id=vscode-notebook-cell://ssh-remote%2B172.23.73.58/home/sudakovcom/Desktop/ENOT_TOY/EDM.ipynb#X11sdnNjb2RlLXJlbW90ZQ%3D%3D> result=None>,),kwargs {}:


BrokenPipeError: [Errno 32] Broken pipe