In [1]:
# Package
import os
import torch
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torchvision.datasets import DatasetFolder 
from PIL import Image
from tqdm import tqdm
from models import *
from function import *
import numpy as np


!nvidia-smi
print("CUDA Available : ", torch.cuda.is_available())
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

Wed Sep 27 15:40:46 2023       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 537.42                 Driver Version: 537.42       CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                     TCC/WDDM  | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA GeForce RTX 3090      WDDM  | 00000000:2B:00.0  On |                  N/A |
|  0%   53C    P8              37W / 350W |    491MiB / 24576MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                    

In [2]:
# Argument
epochs = 100
T = 1000
batch_size = 32
learning_rate = 1e-3
out_path = 'results'
reverse_path = 'caisa_xt_T1000'
if not os.path.isdir(out_path):
    os.mkdir(out_path)
    
if not os.path.isdir(reverse_path):
    os.mkdir(reverse_path)


In [3]:
# Preprocessing
test_tfm = transforms.Compose([
    transforms.Resize([160,160]),
    transforms.ToTensor(),
])
#test_tfm = transforms.Compose([
#    transforms.ToTensor(),
#])


toPIL = transforms.ToPILImage()

In [4]:
# Casia
train_set = DatasetFolder("../../dataset/casia100_dataset/casia", loader=lambda x: Image.open(x), extensions="jpg", transform=test_tfm)
train_set_size = int(len(train_set) * 0.8)
train_set, test_set = torch.utils.data.random_split(train_set, [train_set_size, len(train_set) - train_set_size])

#train_set = torchvision.datasets.CIFAR10(root='../../dataset/cifar10_dataset', train=True, download=True, transform=test_tfm)
#test_set = torchvision.datasets.CIFAR10(root='../../dataset/cifar10_dataset', train=False, download=True, transform=test_tfm)

train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=0)
test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=0)

Debugger :  class -> idx 
{'0': 0, '1': 1, '2': 2, '3': 3, '4': 4, '5': 5, '6': 6, '7': 7, '8': 8, '9': 9, '10': 10, '11': 11, '12': 12, '13': 13, '14': 14, '15': 15, '16': 16, '17': 17, '18': 18, '19': 19, '20': 20, '21': 21, '22': 22, '23': 23, '24': 24, '25': 25, '26': 26, '27': 27, '28': 28, '29': 29, '30': 30, '31': 31, '32': 32, '33': 33, '34': 34, '35': 35, '36': 36, '37': 37, '38': 38, '39': 39, '40': 40, '41': 41, '42': 42, '43': 43, '44': 44, '45': 45, '46': 46, '47': 47, '48': 48, '49': 49, '50': 50, '51': 51, '52': 52, '53': 53, '54': 54, '55': 55, '56': 56, '57': 57, '58': 58, '59': 59, '60': 60, '61': 61, '62': 62, '63': 63, '64': 64, '65': 65, '66': 66, '67': 67, '68': 68, '69': 69, '70': 70, '71': 71, '72': 72, '73': 73, '74': 74, '75': 75, '76': 76, '77': 77, '78': 78, '79': 79, '80': 80, '81': 81, '82': 82, '83': 83, '84': 84, '85': 85, '86': 86, '87': 87, '88': 88, '89': 89, '90': 90, '91': 91, '92': 92, '93': 93, '94': 94, '95': 95, '96': 96, '97': 97, '98': 98, '99

In [5]:
forward = ForwardProcess(T)
reverse = Unet(3, 3).to(device)
#criterion = diceLoss(3)
#criterion = nn.L1Loss()
optimizer = torch.optim.Adam(reverse.parameters(), lr=learning_rate)

for epoch in range(epochs):
    reverse.train()
    train_loss = 0
    
    for batch in tqdm(train_loader):
        optimizer.zero_grad()
        imgs, labels = batch
        imgs, labels = imgs.to(device), labels.to(device)
        timesteps = torch.randint(0, T, (imgs.size(0), ))
        #print(timesteps.shape)
        results, noise = forward.NoisePredictor(imgs, timesteps)

        output = reverse(results, timesteps.to(device))
        loss = F.l1_loss(noise, output)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
        #break
    print(f'Epoch : {epoch} | Train | loss = {train_loss}')

    reverse.eval()
    
    #img = torch.randn((1, 3, imgs.size(2), imgs.size(3)), device=device)
    timesteps = torch.full((1,), T-1, device='cpu', dtype=torch.long)
    img, noise = forward.NoisePredictor(imgs[0].unsqueeze(0), timesteps)
    #print(img.shape)
    for i in range(0,T)[::-1]:
        t = torch.full((1,), i, device='cpu', dtype=torch.long)
        with torch.no_grad():
            img = forward.output_sample(reverse, img, t)
        if i % 10 == 0:
            img_tmp = img.squeeze(0)
            img_tmp = toPIL(img_tmp)
            img_tmp.save(f'{reverse_path}/img_{epoch}_{i}.png')
        
    print(f'Epoch : {epoch} | Eval | Generate successfully !')
            

100%|██████████| 399/399 [02:16<00:00,  2.93it/s]


Epoch : 0 | Train | loss = 0.004857020519700131
Epoch : 0 | Eval | Generate successfully !


100%|██████████| 399/399 [02:12<00:00,  3.01it/s]


Epoch : 1 | Train | loss = 0.002842729029210234
Epoch : 1 | Eval | Generate successfully !


100%|██████████| 399/399 [02:12<00:00,  3.02it/s]


Epoch : 2 | Train | loss = 0.0022593253167552917
Epoch : 2 | Eval | Generate successfully !


100%|██████████| 399/399 [02:12<00:00,  3.01it/s]


Epoch : 3 | Train | loss = 0.002054003612954928
Epoch : 3 | Eval | Generate successfully !


100%|██████████| 399/399 [02:12<00:00,  3.01it/s]


Epoch : 4 | Train | loss = 0.0017996034928750599
Epoch : 4 | Eval | Generate successfully !


100%|██████████| 399/399 [02:12<00:00,  3.02it/s]


Epoch : 5 | Train | loss = 0.0017185998432376235
Epoch : 5 | Eval | Generate successfully !


100%|██████████| 399/399 [02:12<00:00,  3.02it/s]


Epoch : 6 | Train | loss = 0.0015855832347663162
Epoch : 6 | Eval | Generate successfully !


100%|██████████| 399/399 [02:12<00:00,  3.01it/s]


Epoch : 7 | Train | loss = 0.0015397293029134533
Epoch : 7 | Eval | Generate successfully !


100%|██████████| 399/399 [02:12<00:00,  3.02it/s]


Epoch : 8 | Train | loss = 0.0015040662553250533
Epoch : 8 | Eval | Generate successfully !


100%|██████████| 399/399 [02:12<00:00,  3.02it/s]


Epoch : 9 | Train | loss = 0.0014270426212233064
Epoch : 9 | Eval | Generate successfully !


100%|██████████| 399/399 [02:12<00:00,  3.02it/s]


Epoch : 10 | Train | loss = 0.0014112620044860618
Epoch : 10 | Eval | Generate successfully !


100%|██████████| 399/399 [02:12<00:00,  3.01it/s]


Epoch : 11 | Train | loss = 0.0013823718545763293
Epoch : 11 | Eval | Generate successfully !


100%|██████████| 399/399 [02:12<00:00,  3.02it/s]


Epoch : 12 | Train | loss = 0.0013825337315199655
Epoch : 12 | Eval | Generate successfully !


100%|██████████| 399/399 [02:12<00:00,  3.02it/s]


Epoch : 13 | Train | loss = 0.0013662610670698535
Epoch : 13 | Eval | Generate successfully !


100%|██████████| 399/399 [02:12<00:00,  3.02it/s]


Epoch : 14 | Train | loss = 0.0013207123386492772
Epoch : 14 | Eval | Generate successfully !


100%|██████████| 399/399 [02:12<00:00,  3.02it/s]


Epoch : 15 | Train | loss = 0.001305319087506042
Epoch : 15 | Eval | Generate successfully !


100%|██████████| 399/399 [02:12<00:00,  3.02it/s]


Epoch : 16 | Train | loss = 0.001333378666771445
Epoch : 16 | Eval | Generate successfully !


100%|██████████| 399/399 [02:12<00:00,  3.02it/s]


Epoch : 17 | Train | loss = 0.0013018681663726067
Epoch : 17 | Eval | Generate successfully !


100%|██████████| 399/399 [02:12<00:00,  3.02it/s]


Epoch : 18 | Train | loss = 0.0012596271151699427
Epoch : 18 | Eval | Generate successfully !


100%|██████████| 399/399 [02:12<00:00,  3.02it/s]


Epoch : 19 | Train | loss = 0.0012898358417380413
Epoch : 19 | Eval | Generate successfully !


100%|██████████| 399/399 [02:12<00:00,  3.02it/s]


Epoch : 20 | Train | loss = 0.0012496573390879786
Epoch : 20 | Eval | Generate successfully !


100%|██████████| 399/399 [02:12<00:00,  3.02it/s]


Epoch : 21 | Train | loss = 0.0012178278128076132
Epoch : 21 | Eval | Generate successfully !


100%|██████████| 399/399 [02:12<00:00,  3.02it/s]


Epoch : 22 | Train | loss = 0.0012456715802624443
Epoch : 22 | Eval | Generate successfully !


100%|██████████| 399/399 [02:12<00:00,  3.02it/s]


Epoch : 23 | Train | loss = 0.0012533831445550676
Epoch : 23 | Eval | Generate successfully !


100%|██████████| 399/399 [02:12<00:00,  3.01it/s]


Epoch : 24 | Train | loss = 0.0012235568780064722
Epoch : 24 | Eval | Generate successfully !


100%|██████████| 399/399 [02:12<00:00,  3.01it/s]


Epoch : 25 | Train | loss = 0.001202397528128233
Epoch : 25 | Eval | Generate successfully !


100%|██████████| 399/399 [02:12<00:00,  3.02it/s]


Epoch : 26 | Train | loss = 0.0012062733458766473
Epoch : 26 | Eval | Generate successfully !


100%|██████████| 399/399 [02:12<00:00,  3.01it/s]


Epoch : 27 | Train | loss = 0.0012061455817187988
Epoch : 27 | Eval | Generate successfully !


100%|██████████| 399/399 [02:12<00:00,  3.02it/s]


Epoch : 28 | Train | loss = 0.0011871774169032179
Epoch : 28 | Eval | Generate successfully !


100%|██████████| 399/399 [02:12<00:00,  3.02it/s]


Epoch : 29 | Train | loss = 0.0012030244290115507
Epoch : 29 | Eval | Generate successfully !


100%|██████████| 399/399 [02:12<00:00,  3.02it/s]


Epoch : 30 | Train | loss = 0.0011870750764997953
Epoch : 30 | Eval | Generate successfully !


100%|██████████| 399/399 [02:12<00:00,  3.02it/s]


Epoch : 31 | Train | loss = 0.0011705540043143675
Epoch : 31 | Eval | Generate successfully !


100%|██████████| 399/399 [02:12<00:00,  3.02it/s]


Epoch : 32 | Train | loss = 0.001177149322563501
Epoch : 32 | Eval | Generate successfully !


100%|██████████| 399/399 [02:12<00:00,  3.02it/s]


Epoch : 33 | Train | loss = 0.001194202232771221
Epoch : 33 | Eval | Generate successfully !


100%|██████████| 399/399 [02:12<00:00,  3.02it/s]


Epoch : 34 | Train | loss = 0.0011691812140864923
Epoch : 34 | Eval | Generate successfully !


100%|██████████| 399/399 [02:12<00:00,  3.02it/s]


Epoch : 35 | Train | loss = 0.001152459570605318
Epoch : 35 | Eval | Generate successfully !


100%|██████████| 399/399 [02:12<00:00,  3.02it/s]


Epoch : 36 | Train | loss = 0.0011756568793433934
Epoch : 36 | Eval | Generate successfully !


100%|██████████| 399/399 [02:12<00:00,  3.02it/s]


Epoch : 37 | Train | loss = 0.0011559870830642424
Epoch : 37 | Eval | Generate successfully !


100%|██████████| 399/399 [02:12<00:00,  3.01it/s]


Epoch : 38 | Train | loss = 0.0011303971269087214
Epoch : 38 | Eval | Generate successfully !


100%|██████████| 399/399 [02:12<00:00,  3.02it/s]


Epoch : 39 | Train | loss = 0.0011608116489478118
Epoch : 39 | Eval | Generate successfully !


100%|██████████| 399/399 [02:12<00:00,  3.02it/s]


Epoch : 40 | Train | loss = 0.0011341678974627514
Epoch : 40 | Eval | Generate successfully !


100%|██████████| 399/399 [02:13<00:00,  2.99it/s]


Epoch : 41 | Train | loss = 0.0011356477664762406
Epoch : 41 | Eval | Generate successfully !


100%|██████████| 399/399 [02:12<00:00,  3.00it/s]


Epoch : 42 | Train | loss = 0.0011398969642692288
Epoch : 42 | Eval | Generate successfully !


100%|██████████| 399/399 [02:12<00:00,  3.02it/s]


Epoch : 43 | Train | loss = 0.0011540553769664513
Epoch : 43 | Eval | Generate successfully !


100%|██████████| 399/399 [02:12<00:00,  3.02it/s]


Epoch : 44 | Train | loss = 0.0011223927157773556
Epoch : 44 | Eval | Generate successfully !


100%|██████████| 399/399 [02:12<00:00,  3.02it/s]


Epoch : 45 | Train | loss = 0.0011051743649744248
Epoch : 45 | Eval | Generate successfully !


100%|██████████| 399/399 [02:12<00:00,  3.02it/s]


Epoch : 46 | Train | loss = 0.0011247876653043651
Epoch : 46 | Eval | Generate successfully !


100%|██████████| 399/399 [02:12<00:00,  3.02it/s]


Epoch : 47 | Train | loss = 0.0011279576327600587
Epoch : 47 | Eval | Generate successfully !


100%|██████████| 399/399 [02:12<00:00,  3.02it/s]


Epoch : 48 | Train | loss = 0.0011461225067542645
Epoch : 48 | Eval | Generate successfully !


100%|██████████| 399/399 [02:12<00:00,  3.02it/s]


Epoch : 49 | Train | loss = 0.0011265367238825095
Epoch : 49 | Eval | Generate successfully !


100%|██████████| 399/399 [02:12<00:00,  3.02it/s]


Epoch : 50 | Train | loss = 0.0011085805830627201
Epoch : 50 | Eval | Generate successfully !


100%|██████████| 399/399 [02:12<00:00,  3.02it/s]


Epoch : 51 | Train | loss = 0.0011132382961313787
Epoch : 51 | Eval | Generate successfully !


100%|██████████| 399/399 [02:12<00:00,  3.02it/s]


Epoch : 52 | Train | loss = 0.0011305181910976235
Epoch : 52 | Eval | Generate successfully !


100%|██████████| 399/399 [02:12<00:00,  3.02it/s]


Epoch : 53 | Train | loss = 0.0011219320180849638
Epoch : 53 | Eval | Generate successfully !


100%|██████████| 399/399 [02:12<00:00,  3.02it/s]


Epoch : 54 | Train | loss = 0.0011159511687136856
Epoch : 54 | Eval | Generate successfully !


100%|██████████| 399/399 [02:12<00:00,  3.02it/s]


Epoch : 55 | Train | loss = 0.0011295424130511032
Epoch : 55 | Eval | Generate successfully !


100%|██████████| 399/399 [02:12<00:00,  3.02it/s]


Epoch : 56 | Train | loss = 0.0011237405693591833
Epoch : 56 | Eval | Generate successfully !


100%|██████████| 399/399 [02:12<00:00,  3.02it/s]


Epoch : 57 | Train | loss = 0.0011119947011340315
Epoch : 57 | Eval | Generate successfully !


100%|██████████| 399/399 [02:12<00:00,  3.02it/s]


Epoch : 58 | Train | loss = 0.0010780391234739381
Epoch : 58 | Eval | Generate successfully !


100%|██████████| 399/399 [02:12<00:00,  3.01it/s]


Epoch : 59 | Train | loss = 0.001093075061549435
Epoch : 59 | Eval | Generate successfully !


100%|██████████| 399/399 [02:12<00:00,  3.02it/s]


Epoch : 60 | Train | loss = 0.0011062275537933822
Epoch : 60 | Eval | Generate successfully !


100%|██████████| 399/399 [02:12<00:00,  3.02it/s]


Epoch : 61 | Train | loss = 0.0010934082110894338
Epoch : 61 | Eval | Generate successfully !


100%|██████████| 399/399 [02:12<00:00,  3.02it/s]


Epoch : 62 | Train | loss = 0.0011143024268396792
Epoch : 62 | Eval | Generate successfully !


100%|██████████| 399/399 [02:12<00:00,  3.02it/s]


Epoch : 63 | Train | loss = 0.0010791670859865508
Epoch : 63 | Eval | Generate successfully !


100%|██████████| 399/399 [02:12<00:00,  3.02it/s]


Epoch : 64 | Train | loss = 0.001105709012382963
Epoch : 64 | Eval | Generate successfully !


100%|██████████| 399/399 [02:12<00:00,  3.01it/s]


Epoch : 65 | Train | loss = 0.0011057322490456152
Epoch : 65 | Eval | Generate successfully !


100%|██████████| 399/399 [02:12<00:00,  3.02it/s]


Epoch : 66 | Train | loss = 0.001093350289474987
Epoch : 66 | Eval | Generate successfully !


100%|██████████| 399/399 [02:12<00:00,  3.01it/s]


Epoch : 67 | Train | loss = 0.0011089469300690424
Epoch : 67 | Eval | Generate successfully !


100%|██████████| 399/399 [02:12<00:00,  3.01it/s]


Epoch : 68 | Train | loss = 0.001084310249832856
Epoch : 68 | Eval | Generate successfully !


100%|██████████| 399/399 [02:12<00:00,  3.02it/s]


Epoch : 69 | Train | loss = 0.00108387236570718
Epoch : 69 | Eval | Generate successfully !


100%|██████████| 399/399 [02:12<00:00,  3.01it/s]


Epoch : 70 | Train | loss = 0.0011070109573519899
Epoch : 70 | Eval | Generate successfully !


100%|██████████| 399/399 [02:12<00:00,  3.01it/s]


Epoch : 71 | Train | loss = 0.0010922256332488564
Epoch : 71 | Eval | Generate successfully !


100%|██████████| 399/399 [02:12<00:00,  3.01it/s]


Epoch : 72 | Train | loss = 0.0010824902043792958
Epoch : 72 | Eval | Generate successfully !


100%|██████████| 399/399 [02:12<00:00,  3.02it/s]


Epoch : 73 | Train | loss = 0.0010911466509232384
Epoch : 73 | Eval | Generate successfully !


100%|██████████| 399/399 [02:12<00:00,  3.02it/s]


Epoch : 74 | Train | loss = 0.001075221363436058
Epoch : 74 | Eval | Generate successfully !


100%|██████████| 399/399 [02:12<00:00,  3.01it/s]


Epoch : 75 | Train | loss = 0.0010695405972324292
Epoch : 75 | Eval | Generate successfully !


100%|██████████| 399/399 [02:12<00:00,  3.01it/s]


Epoch : 76 | Train | loss = 0.0010741807396164498
Epoch : 76 | Eval | Generate successfully !


100%|██████████| 399/399 [02:12<00:00,  3.02it/s]


Epoch : 77 | Train | loss = 0.0010788130040856707
Epoch : 77 | Eval | Generate successfully !


100%|██████████| 399/399 [02:12<00:00,  3.01it/s]


Epoch : 78 | Train | loss = 0.0010726105396936527
Epoch : 78 | Eval | Generate successfully !


100%|██████████| 399/399 [02:12<00:00,  3.01it/s]


Epoch : 79 | Train | loss = 0.0010899009948129419
Epoch : 79 | Eval | Generate successfully !


100%|██████████| 399/399 [02:12<00:00,  3.02it/s]


Epoch : 80 | Train | loss = 0.001105032469558454
Epoch : 80 | Eval | Generate successfully !


100%|██████████| 399/399 [02:12<00:00,  3.02it/s]


Epoch : 81 | Train | loss = 0.0010997055176603136
Epoch : 81 | Eval | Generate successfully !


100%|██████████| 399/399 [02:12<00:00,  3.02it/s]


Epoch : 82 | Train | loss = 0.001059486871463217
Epoch : 82 | Eval | Generate successfully !


100%|██████████| 399/399 [02:12<00:00,  3.02it/s]


Epoch : 83 | Train | loss = 0.0010903674404718867
Epoch : 83 | Eval | Generate successfully !


100%|██████████| 399/399 [02:12<00:00,  3.02it/s]


Epoch : 84 | Train | loss = 0.0010726632305562334
Epoch : 84 | Eval | Generate successfully !


100%|██████████| 399/399 [02:12<00:00,  3.02it/s]


Epoch : 85 | Train | loss = 0.0010867233984771546
Epoch : 85 | Eval | Generate successfully !


100%|██████████| 399/399 [02:12<00:00,  3.02it/s]


Epoch : 86 | Train | loss = 0.0010643106291531675
Epoch : 86 | Eval | Generate successfully !


100%|██████████| 399/399 [02:12<00:00,  3.02it/s]


Epoch : 87 | Train | loss = 0.001068279028330517
Epoch : 87 | Eval | Generate successfully !


100%|██████████| 399/399 [02:12<00:00,  3.02it/s]


Epoch : 88 | Train | loss = 0.0010787662516193513
Epoch : 88 | Eval | Generate successfully !


100%|██████████| 399/399 [02:12<00:00,  3.02it/s]


Epoch : 89 | Train | loss = 0.0010646859582158434
Epoch : 89 | Eval | Generate successfully !


100%|██████████| 399/399 [02:12<00:00,  3.01it/s]


Epoch : 90 | Train | loss = 0.0010489207985834685
Epoch : 90 | Eval | Generate successfully !


100%|██████████| 399/399 [02:12<00:00,  3.01it/s]


Epoch : 91 | Train | loss = 0.0010726282234961681
Epoch : 91 | Eval | Generate successfully !


100%|██████████| 399/399 [02:12<00:00,  3.00it/s]


Epoch : 92 | Train | loss = 0.0010651497236826037
Epoch : 92 | Eval | Generate successfully !


100%|██████████| 399/399 [02:12<00:00,  3.01it/s]


Epoch : 93 | Train | loss = 0.0010654120102799513
Epoch : 93 | Eval | Generate successfully !


100%|██████████| 399/399 [02:12<00:00,  3.01it/s]


Epoch : 94 | Train | loss = 0.0010488315209565139
Epoch : 94 | Eval | Generate successfully !


100%|██████████| 399/399 [02:12<00:00,  3.01it/s]


Epoch : 95 | Train | loss = 0.0010680849580974754
Epoch : 95 | Eval | Generate successfully !


100%|██████████| 399/399 [02:12<00:00,  3.01it/s]


Epoch : 96 | Train | loss = 0.0010568070591150803
Epoch : 96 | Eval | Generate successfully !


100%|██████████| 399/399 [02:12<00:00,  3.01it/s]


Epoch : 97 | Train | loss = 0.0010515299428030761
Epoch : 97 | Eval | Generate successfully !


100%|██████████| 399/399 [02:12<00:00,  3.01it/s]


Epoch : 98 | Train | loss = 0.0010408673881198342
Epoch : 98 | Eval | Generate successfully !


100%|██████████| 399/399 [02:12<00:00,  3.01it/s]


Epoch : 99 | Train | loss = 0.0010705519268832146
Epoch : 99 | Eval | Generate successfully !


In [6]:
'''
for batch in tqdm(test_loader):
        imgs, labels = batch
        imgs, labels = imgs.to(device), labels.to(device)
        #timesteps = torch.randint(0, T, (imgs.size(0), ))
        timesteps = torch.full((imgs.size(0),), T-1, device='cpu', dtype=torch.long)
        #print(timesteps)
        results, noise = forward.NoisePredictor(imgs, timesteps)
        
        img = results[0].unsqueeze(0)
        #print(img.shape)
        for i in range(0,T)[::-1]:
            t = torch.full((1,), i, device='cpu', dtype=torch.long)
            with torch.no_grad():
                img = forward.output_sample(reverse, img, t)
            #img = torch.clamp(img, -1.0, 1.0)
            if i % 10 == 0:
                img_tmp = img.squeeze(0)
                img_tmp = toPIL(img_tmp)
                img_tmp.save(f'{reverse_path}/img_{epoch}_{i}.png')
                
'''


"\nfor batch in tqdm(test_loader):\n        imgs, labels = batch\n        imgs, labels = imgs.to(device), labels.to(device)\n        #timesteps = torch.randint(0, T, (imgs.size(0), ))\n        timesteps = torch.full((imgs.size(0),), T-1, device='cpu', dtype=torch.long)\n        #print(timesteps)\n        results, noise = forward.NoisePredictor(imgs, timesteps)\n        \n        img = results[0].unsqueeze(0)\n        #print(img.shape)\n        for i in range(0,T)[::-1]:\n            t = torch.full((1,), i, device='cpu', dtype=torch.long)\n            with torch.no_grad():\n                img = forward.output_sample(reverse, img, t)\n            #img = torch.clamp(img, -1.0, 1.0)\n            if i % 10 == 0:\n                img_tmp = img.squeeze(0)\n                img_tmp = toPIL(img_tmp)\n                img_tmp.save(f'{reverse_path}/img_{epoch}_{i}.png')\n                \n"