In [1]:
import pickle as pkl
import torch

In [2]:
with open("../data/pickles/50salads_target.pkl", "rb") as f:
    data = pkl.load(f)
with open("../data/pickles/50_salads_one_hot.pkl", "rb") as f:
    data_oh = pkl.load(f)

In [3]:
from torch.utils.data import DataLoader
from dataset.dataset import SaladsDataset
from denoisers.UnetDenoiser import UnetDenoiser
from ddpm.ddmp_multinomial import Diffusion

salads_dataset = SaladsDataset(data_oh)

In [4]:
denoiser = UnetDenoiser(in_ch=20, out_ch=20, max_input_dim=salads_dataset.sequence_length).to('cuda').float()
checkpoint = torch.load('unet_eval_train_2/last.ckpt')
denoiser.load_state_dict(checkpoint['model_state'])

<All keys matched successfully>

In [5]:
loader = DataLoader(
    salads_dataset,
    batch_size=1,
    shuffle=True
)

test = next(iter(loader)).to('cuda')
test

tensor([[[0., 0., 0.,  ..., 1., 0., 0.],
         [0., 0., 0.,  ..., 1., 0., 0.],
         [0., 0., 0.,  ..., 1., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 1., 0.],
         [0., 0., 0.,  ..., 0., 1., 0.],
         [0., 0., 0.,  ..., 0., 1., 0.]]], device='cuda:0',
       dtype=torch.float64)

In [6]:
diffuser = Diffusion(noise_steps=200)

In [7]:
t = torch.tensor([199]).to('cuda') 
x_t, eps = diffuser.noise_data(test, t)

In [8]:
eps_hat = denoiser(x_t.permute(0, 2, 1).float(), t)

In [9]:
from torch import nn

criterion = nn.MSELoss()
criterion(eps_hat, eps.permute(0, 2, 1))

tensor(0.0709, device='cuda:0', dtype=torch.float64,
       grad_fn=<MseLossBackward0>)

In [10]:
sqrt_alpha_hat = torch.sqrt(diffuser.alpha_hat[t])[:, None, None]
sqrt_one_minus_alpha_hat = torch.sqrt(1 - diffuser.alpha_hat[t])[:, None, None]
x_hat = diffuser.denoise(denoiser, x_t.permute(0, 2, 1).float(), t)
x_hat

  t_tensor = torch.tensor(t).long().to(self.device)


tensor([[[  5.7263, -20.7785,  -9.4052,  ...,   0.5287,  -8.2503, -13.0716],
         [  6.4067,   7.0115,  -2.1612,  ...,  -2.2935,   1.7831,  -0.3032],
         [ 10.0543,  13.9898,  -0.2321,  ..., -14.1483,  16.3022, -10.3928],
         ...,
         [ 10.6984,  15.1259,   6.2504,  ..., -14.2516, -18.4344,   6.0720],
         [ -0.5810,  -6.5327, -14.3074,  ...,   3.1594,   1.1031,  -2.7491],
         [-17.1704,   0.7972,   0.2797,  ...,   2.2053, -14.4310,  17.4099]]],
       device='cuda:0')

In [21]:
torch.softmax(x_hat, dim=1)[0][:, 0]

tensor([1.0950e-08, 2.1622e-08, 8.2994e-07, 6.6863e-14, 3.6213e-20, 6.1329e-09,
        5.8262e-10, 1.0000e+00, 1.6127e-17, 7.3481e-26, 6.2130e-11, 1.0840e-08,
        5.7231e-15, 2.6599e-11, 3.5455e-08, 1.6197e-10, 4.9502e-09, 1.5804e-06,
        1.9962e-11, 1.2460e-18], device='cuda:0')

In [11]:
(x_t - sqrt_one_minus_alpha_hat * eps) / sqrt_alpha_hat

tensor([[[0., 0., 0.,  ..., 1., 0., 0.],
         [0., 0., 0.,  ..., 1., 0., 0.],
         [0., 0., 0.,  ..., 1., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 1.],
         [0., 0., 0.,  ..., 0., 0., 1.],
         [0., 0., 0.,  ..., 0., 0., 1.]]], device='cuda:0',
       dtype=torch.float64)

In [23]:
eps_hat

tensor([[[-0.5872,  0.9667,  0.9136,  ...,  0.1632,  0.0249,  0.3153],
         [-0.2047, -1.3089, -1.4075,  ..., -0.1091, -1.3485,  0.6684],
         [-0.2906, -1.1066,  0.4978,  ..., -0.6633, -1.2183, -0.2044],
         ...,
         [ 0.0032, -1.1601,  0.2727,  ...,  0.6827,  0.9684, -0.6331],
         [-0.8048,  0.4859, -0.6551,  ...,  0.5285, -1.4002,  0.2146],
         [ 0.7995,  0.3563, -0.1769,  ..., -0.7588,  0.8646, -0.9706]]],
       device='cuda:0', grad_fn=<ConvolutionBackward0>)