In [1]:
import os

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import DataLoader
from tqdm import tqdm

from diffusion import Diffusion
from unet import UNet

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Convert PIL images to float tensors; ToTensor scales uint8 pixels into [0, 1].
# NOTE(review): many DDPM implementations expect inputs in [-1, 1] — confirm
# which range Diffusion.sample_noise_image assumes.
transform = transforms.Compose([transforms.ToTensor()])

In [3]:
from torch.utils.data import DataLoader

In [4]:
# Download (if not cached under data/) and load both CIFAR-10 splits with the shared transform.
train_dataset, test_dataset = (
    datasets.CIFAR10(root='data/', train=is_train, download=True, transform=transform)
    for is_train in (True, False)
)

Files already downloaded and verified
Files already downloaded and verified


In [5]:
# Mini-batch loaders: reshuffle the training split every epoch, keep the test split in order.
train_loader, test_loader = (
    DataLoader(split, batch_size=16, shuffle=is_train)
    for split, is_train in ((train_dataset, True), (test_dataset, False))
)

In [6]:
def save_images(images, path, **kwargs):
    """Tile a batch of images into a single grid and write it to ``path``.

    Args:
        images: batch tensor accepted by ``torchvision.utils.make_grid``
            (presumably (N, C, H, W) — TODO confirm against Diffusion's sampler output).
            uint8 tensors are written as-is; floating tensors are assumed to be in [0, 1].
        path: destination file. PIL infers the format from the extension. Missing
            parent directories are created.
        **kwargs: forwarded to ``torchvision.utils.make_grid`` (e.g. ``nrow``, ``padding``).
    """
    grid = torchvision.utils.make_grid(images, **kwargs)
    # ``.numpy()`` needs a CPU tensor that is not tracked by autograd.
    grid = grid.detach().to('cpu')
    if grid.dtype != torch.uint8:
        # PIL cannot build an RGB image from float arrays. Assume [0, 1] floats
        # (the range ToTensor produces). NOTE(review): rescale first if the
        # sampler returns values in [-1, 1].
        grid = grid.clamp(0, 1).mul(255).round().to(torch.uint8)
    ndarr = grid.permute(1, 2, 0).numpy()
    # PIL raises FileNotFoundError when the target directory (e.g. ./runs) does not exist.
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    Image.fromarray(ndarr).save(path)

In [7]:
import torch.optim as optim

In [8]:
def train(
    device = 'cpu',
    lr = 1e-4,
    img_size = 32,
    epochs = 10,
    loader = None,
    out_dir = './runs',
):
    """Train a fresh UNet noise predictor and return it.

    Each step samples timesteps, noises the batch with ``diffusion.sample_noise_image``
    and regresses the UNet's output onto the returned noise with an MSE loss. After
    every epoch a batch of images is generated with ``diffusion.sample_initial_image``
    and written to ``out_dir``.

    Args:
        device: torch device for the model and the data batches.
        lr: Adam learning rate.
        img_size: image side length passed to ``Diffusion``.
        epochs: number of passes over ``loader``.
        loader: iterable of ``(images, labels)`` batches. Defaults to the
            notebook-global ``train_loader`` (the original implicit dependency).
        out_dir: directory that receives ``epoch{N}_sampled_images.jpg`` files.

    Returns:
        The trained ``UNet``, so the weights are not lost when the function exits.
    """
    if loader is None:
        loader = train_loader

    model = UNet().to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.MSELoss()
    diffusion = Diffusion(img_size=img_size, device=device)
    for epoch in range(epochs):
        # Sampling at the end of the previous epoch switches to eval mode; restore training mode.
        model.train()
        # The epoch label is loop-invariant, so set it once instead of on every batch.
        with tqdm(loader, unit="batch", desc=f"Epoch {epoch+1}") as tepoch:
            for images, _ in tepoch:
                images = images.to(device)
                t = diffusion.sample_timesteps(images.shape[0]).to(device)
                x_t, noise = diffusion.sample_noise_image(images, t)
                predicted_noise = model(x_t, t)
                loss = criterion(noise, predicted_noise)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                tepoch.set_postfix(loss=loss.item())

        # Generation needs no gradients; no_grad saves memory and yields tensors
        # that can be converted to numpy. It is harmless if the sampler already uses no_grad.
        model.eval()
        with torch.no_grad():
            # Generate as many images as the last training batch contained.
            sampled_images = diffusion.sample_initial_image(model, images.shape[0])
        save_images(sampled_images, path = '{}/epoch{}_sampled_images.jpg'.format(out_dir, epoch+1))
    return model

In [9]:
img_size: int = 32  # side length in pixels of the square CIFAR-10 images

In [10]:
device: str = 'cpu'  # device used by the exploratory cells below

In [11]:
# Grab a single training batch to sanity-check the diffusion pipeline.
images, labels = next(iter(train_loader))
# Match train(): move inputs to the configured device so the following cells
# keep working when `device` is not 'cpu'.
images = images.to(device)

In [12]:
# Untrained model and diffusion helper for the sanity-check cells (train() builds its own).
model = UNet().to(device=device)
diffusion = Diffusion(device=device, img_size=img_size)

In [13]:
# Presumably one random diffusion timestep per image in the batch.
t = diffusion.sample_timesteps(len(images)).to(device)

In [14]:
# Noise the batch at timestep t. The second value is the noise that train() uses
# as the MSE target. NOTE(review): presumably x_t is the noised batch — confirm.
(x_t, noise) = diffusion.sample_noise_image(images, t)

In [15]:
# Forward pass only: a sanity check needs no autograd graph.
with torch.no_grad():
    predicted_noise = model(x_t, t)
# train() compares this output against `noise` with MSELoss, so the shapes must agree.
assert predicted_noise.shape == noise.shape, (predicted_noise.shape, noise.shape)

In [None]:
# Pass the config cells' settings explicitly instead of relying on train()'s defaults.
# Binding the result keeps it after the hours-long run and stops it being echoed as cell output.
trained_model = train(device=device, img_size=img_size)

Epoch 1: 100%|████████████| 3125/3125 [1:38:15<00:00,  1.89s/batch, loss=0.0197]
