In [2]:
# Enable IPython's autoreload extension in mode 2, which reloads imported
# modules before each cell is executed.
%load_ext autoreload
%autoreload 2

In [3]:
from pathlib import Path

import torch
import torchvision
from torchvision.utils import make_grid

from DDPM import *
from UNet import *

In [7]:
# Sample a grid of images from a trained conditional DDPM checkpoint, save the
# stored intermediate sample grids as an animated GIF, and show the final grid.

# Use the first CUDA device when available, otherwise fall back to the CPU.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# Build the checkpoint path from its components instead of embedding backslashes
# in a normal string literal: '\D' and '\m' are invalid escape sequences, and a
# backslash-separated path only resolves on Windows.
modelPath = Path('DiffusionData') / 'DeepConditional1BigData-g_ep-150_BS-128_ts-500_bt-(0.0001, 0.02)' / 'model.pth'
gifPath = Path('./samples/sample.gif')

# Pixel values below this threshold are zeroed before the final grid is shown.
pixelThreshold = 0.25

print('Loading model...')
numClasses = 16
ddpm = DDPM(
    model=UNetDeepFullyConditional(numClasses=numClasses),
    betas=(1e-4, 0.02),
    numTimesteps=500,
    dropoutRate=0.4,
    device=device,
    numClasses=numClasses,
)

# map_location remaps the stored tensors onto the selected device, so a
# checkpoint written from a GPU run still loads when only the CPU is available.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
ddpm.load_state_dict(torch.load(modelPath, map_location=device))
ddpm = ddpm.to(device)

# The sample grid is rows x rows, so rows**2 images are generated.
rows = 4

print('Generating samples...')
ddpm.eval()
with torch.no_grad():
    generatedSamples, storedSamples = ddpm.sample(
        rows**2,
        (1, 28, 28),
        classifierGuidance=2,
        classLabels=torch.tensor([1], dtype=torch.int64, device=device),
    )

# Convert every stored sample batch into one PIL image grid. The converter is
# referenced through the explicitly imported torchvision package (matching the
# final grid below) rather than a bare `transforms` name that this notebook
# never imports itself.
toPilImage = torchvision.transforms.ToPILImage()
frames = []
for frame in storedSamples:
    imGrid = make_grid(frame, nrow=rows, normalize=True)
    frames.append(toPilImage(imGrid))

# Create the output directory first so the save does not fail on a fresh checkout.
gifPath.parent.mkdir(parents=True, exist_ok=True)
frames[0].save(gifPath, save_all=True, append_images=frames[1:], duration=100, loop=0)

print(generatedSamples.shape)

# Zero out low-valued pixels in place before building the final grid.
generatedSamples[generatedSamples < pixelThreshold] = 0

grid = make_grid(generatedSamples, nrow=rows, normalize=True)
img = toPilImage(grid)
img.show()

Loading model...
Generating samples...
torch.Size([16, 1, 28, 28])
