In [1]:
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import torchvision
import matplotlib.pyplot as plt
import numpy as np
from diffusion import DiffusionPipeline
from diffusion.unet import UNet
import os
from diffusion.config import Config
import time
import json

In [2]:
def train(dataloader: DataLoader, pipe: DiffusionPipeline, optimizer):
    losses = []
    size = len(dataloader.dataset)
    pipe.unet.train()
    for batch, (X, y) in enumerate(dataloader):
        X= X.cuda()
        loss = pipe.predict_eps(X)
        # backpropagation
        optimizer.zero_grad()
        loss.backward()
        losses.append(loss.item())
        optimizer.step()
        # print loss
        if batch % 10 == 0:
            loss, current = loss.item(), (batch + 1) * len(X)
            print(f'loss: {loss:>7f}, [{current:>5d}/{size:>5d}]', end='\r')
    return sum(losses) / len(losses)

In [3]:
bs=128
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.5,), (0.5,))])

trainset = datasets.MNIST('./F_MNIST_data/', download=True, train=True, transform=transform)
testset = datasets.MNIST('./F_MNIST_data/', download=True, train=False, transform=transform)
trainloader = DataLoader(trainset, batch_size=bs, shuffle=True)
testloader = DataLoader(testset, batch_size=bs, shuffle=True)



In [4]:
config = Config(
    hidden_dim_list=(16, 32, 64)
)


pipe = DiffusionPipeline(device='cuda', config=config)
model = pipe.unet

epochs = 100
learning_rate = 1e-3
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.7)
timestamp = time.strftime("%Y%m%d-%H%M%S")
# create the folder if not exists
if not os.path.exists(f'./exp_{timestamp}'):
    os.makedirs(f'./exp_{timestamp}')
# save config in json
with open(f'./exp_{timestamp}/config.json', 'w') as f:
    json.dump(config.__dict__, f, indent=2)
    
losses = []
for t in range(epochs):
    print(f'\nEpoch {t+1}-------------------------------')
    loss = train(trainloader, pipe, optimizer)
    losses.append(loss)
    scheduler.step()
    # sample(model, t+1, n_classes)
    if (t+1) % 10 == 0:
        torch.save(model.state_dict(), f'./exp_{timestamp}/checkpoint_{t+1}.pth')
        
# plot losses line chart
plt.plot(losses)


Epoch 1-------------------------------
loss: 0.111089, [59008/60000]
Epoch 2-------------------------------
loss: 0.088250, [59008/60000]
Epoch 3-------------------------------
loss: 0.092896, [59008/60000]
Epoch 4-------------------------------
loss: 0.054912, [59008/60000]
Epoch 5-------------------------------
loss: 0.071897, [59008/60000]
Epoch 6-------------------------------
loss: 0.068523, [59008/60000]
Epoch 7-------------------------------
loss: 0.071238, [59008/60000]
Epoch 8-------------------------------
loss: 0.061659, [59008/60000]
Epoch 9-------------------------------
loss: 0.050357, [59008/60000]
Epoch 10-------------------------------
loss: 0.056092, [59008/60000]
Epoch 11-------------------------------
loss: 0.040676, [59008/60000]
Epoch 12-------------------------------
loss: 0.063934, [59008/60000]
Epoch 13-------------------------------
loss: 0.066684, [59008/60000]
Epoch 14-------------------------------
loss: 0.051242, [59008/60000]
Epoch 15--------------------

KeyboardInterrupt: 