In [41]:
import torch as tch
import torch.nn as nn
import torchvision
import numpy as np
import matplotlib.pyplot as plt
from toolbox import disp, disp_loss


In [42]:
# Pick the GPU when one is available, otherwise fall back to the CPU.
device = tch.device("cuda" if tch.cuda.is_available() else "cpu")

# Seed torch's global RNG (CPU and CUDA) so data shuffling, noise sampling and
# weight initialisation are reproducible under Restart & Run All.
SEED = 0
tch.manual_seed(SEED)

In [43]:
# ToTensor maps pixels to [0, 1]; Normalize(mean=0.5, std=0.5) then maps them to [-1, 1].
transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.5,), (0.5,)),
])
# TODO: decide whether data augmentation is worth adding.

train_dataset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)

# The model is trained on a single digit. Apply the same restriction to the test
# split so any evaluation uses the same distribution as training.
TARGET_DIGIT = 1
label_mask = train_dataset.targets == TARGET_DIGIT
train_dataset.data = train_dataset.data[label_mask]
train_dataset.targets = train_dataset.targets[label_mask]

test_label_mask = test_dataset.targets == TARGET_DIGIT
test_dataset.data = test_dataset.data[test_label_mask]
test_dataset.targets = test_dataset.targets[test_label_mask]

batch_size = 100
train_loader = tch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = tch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [44]:
# Diffusion hyper-parameters.
n_directions = 6
T = 1000                          # number of diffusion steps
beta_1, beta_T = 10**-4, 10**-2   # endpoints of the linear beta schedule
beta_1_tensor = tch.tensor(beta_1, device=device)
height, width = 28, 28            # image size

# Linear variance schedule beta_1..beta_T (T values), alpha_t = 1 - beta_t,
# and the running product \bar{alpha}_t = prod_{s <= t} alpha_s.
betas = tch.linspace(beta_1, beta_T, T, device=device)
alphas = 1 - betas
alphas_cumprod = alphas.cumprod(dim=0)
diffusion_scheduler = alphas_cumprod  # alias of \bar{alpha}_t (same tensor object)

In [45]:

class Network(nn.Module):
    """Small dense + transposed-convolution noise predictor for 1x28x28 images.

    The timestep is injected by adding it to every pixel of the input. The
    flattened result feeds two paths whose outputs are summed:

    * a skip path: ``dense1`` (784 -> 784) reshaped to 1x28x28;
    * a decoder path: ``dense2`` (784 -> 64*7*7), reshaped to 64x7x7, then
      upsampled by transposed convolutions 7 -> 14 -> 28 down to 1 channel.
    """

    def __init__(self):
        super().__init__()

        self.flatten = nn.Flatten()
        # Skip path: 784 -> 784, reshaped back to an image.
        self.dense1 = nn.Linear(784, 784)
        self.reshape1 = nn.Unflatten(1, (1, 28, 28))
        # Decoder path: 784 -> 64x7x7 -> 64x14x14 -> 32x28x28 -> 1x28x28.
        self.dense2 = nn.Linear(784, 7 * 7 * 64)
        self.reshape2 = nn.Unflatten(1, (64, 7, 7))
        self.conv_transpose1 = nn.ConvTranspose2d(64, 64, 3, stride=2, padding=1, output_padding=1)
        self.conv_transpose2 = nn.ConvTranspose2d(64, 32, 3, stride=2, padding=1, output_padding=1)
        self.conv_transpose3 = nn.ConvTranspose2d(32, 1, 3, padding=1)
        self.relu = nn.ReLU()

    def forward(self, x, t):
        """Predict the noise component of ``x`` at timestep ``t``.

        Args:
            x: Image batch; its per-sample size must flatten to 784 values,
                i.e. (B, 1, 28, 28).
            t: Timesteps. Accepted as shape (B, 1, 1, 1), as a flat (B,)
                tensor (reshaped to (B, 1, 1, 1) here), or as a Python
                number / 0-d tensor shared by the whole batch.

        Returns:
            Tensor of shape (B, 1, 28, 28).
        """
        if not tch.is_tensor(t):
            t = tch.as_tensor(t, device=x.device)
        if t.dim() == 1:
            # A flat (B,) tensor would otherwise broadcast against the image
            # width (last dim) instead of the batch dimension.
            t = t.view(-1, 1, 1, 1)

        # NOTE(review): t is added un-scaled to pixel values. With large t this
        # term dwarfs the image content. A timestep embedding (or t / T) would
        # be gentler, but changing it would invalidate already-trained weights.
        combined = x + t
        h = self.flatten(combined)
        h = self.dense1(h)
        skip = self.reshape1(h)

        h = self.relu(self.dense2(h))
        h = self.reshape2(h)
        h = self.relu(self.conv_transpose1(h))
        h = self.relu(self.conv_transpose2(h))
        h = self.conv_transpose3(h)

        return h + skip


In [46]:
# Instantiate the noise predictor on the same device the training batches are
# moved to. The optimizer must be built after the move so it tracks those parameters.
modifiedUnet = Network().to(device)

# Optimizer, learning-rate scheduler and loss.
optimizer = tch.optim.AdamW(modifiedUnet.parameters(), lr=1e-3,
                            weight_decay=1e-4)
# NOTE(review): the training loop calls scheduler.step() once per epoch. With
# step_size=10000 the learning rate therefore never decays in a 100-epoch run.
# Confirm whether a per-batch step or a smaller step_size was intended.
scheduler = tch.optim.lr_scheduler.StepLR(optimizer, step_size=10000,
                                          gamma=0.2)
loss_func = tch.nn.MSELoss(reduction='mean')

total_params = sum(p.numel() for p in modifiedUnet.parameters())
print(f"modifiedUnet has {total_params:,} parameters.")


modifiedUnet has 3,132,881 parameters.


In [47]:
import time

def estimate_remaining_time(start_time, current_epoch, total_epochs):
    """Print an ETA for the rest of training as ``HH:MM:SS``.

    The average wall-clock time per completed epoch (epochs are 0-indexed, so
    ``current_epoch + 1`` have finished) is multiplied by the number of epochs
    still to run.

    Args:
        start_time: ``time.time()`` value captured when training started.
        current_epoch: Index of the epoch that just finished (0-based).
        total_epochs: Total number of epochs in the run.

    Returns:
        None. The estimate is printed to stdout.
    """
    epochs_done = current_epoch + 1
    seconds_per_epoch = (time.time() - start_time) / epochs_done
    eta = seconds_per_epoch * (total_epochs - epochs_done)

    hh = int(eta // 3600)
    mm = int(eta % 3600 // 60)
    ss = int(eta % 60)
    print(f"Estimated remaining time: {hh:02d}:{mm:02d}:{ss:02d}")

In [48]:
def normalize_data(images, eps=1e-12):
    """Rescale each sample so its largest absolute value becomes 1.

    Every sample along dim 0 is divided by its own max |value| over all
    remaining dimensions (channels, height, width for a 4-D batch). Each
    output sample therefore lies in [-1, 1].

    Args:
        images: Tensor with a leading batch dimension and at least one more
            dimension, e.g. (B, C, H, W).
        eps: Lower bound on the per-sample divisor. An all-zero sample stays
            zero instead of becoming NaN (0 / 0).

    Returns:
        Tensor with the same shape as ``images``.
    """
    reduce_dims = tuple(range(1, images.dim()))
    max_vals = tch.amax(tch.abs(images), dim=reduce_dims, keepdim=True)
    return images / max_vals.clamp_min(eps)

In [49]:
# Training loop.
running_loss = 0.0   # summed noise-prediction loss over the whole run
epoch_loss = 0.0     # summed noise-prediction loss over the current epoch
n_epoch = 100
log_every = 2        # epochs between loss / ETA reports

# Make sure the model is in training mode, e.g. after the evaluation cell has run.
modifiedUnet.train()
start_time = time.time()
for epoch in range(n_epoch):
    for images, _ in train_loader:
        images = images.to(device)
        current_batch_size = images.shape[0]

        # One random timestep per sample. randint's upper bound is exclusive, so
        # t is drawn from [1, T - 1]. Reshaped to (B, 1, 1, 1) to broadcast over pixels.
        timesteps = tch.randint(1, T, size=(current_batch_size,), device=device).view(-1, 1, 1, 1)
        noise = tch.randn(current_batch_size, 1, height, width, device=device)

        # NOTE(review): this is not the DDPM forward process
        # x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps.
        # Noise is scaled by the raw timestep, and each sample is then rescaled by
        # normalize_data. The target `noise` is therefore not the noise actually
        # present in `noised_image`, and loss_val and loss_image cannot both reach
        # zero. Confirm the intended formulation.
        noised_image = normalize_data(images + timesteps * noise)
        predicted_noise = modifiedUnet(noised_image, timesteps)

        loss_val = loss_func(noise, predicted_noise)                    # noise-prediction term
        loss_image = loss_func(images, noised_image - predicted_noise)  # reconstruction term
        loss = loss_val + loss_image

        # Gradient step.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Only the noise-prediction term is tracked, not the full objective `loss`.
        running_loss += loss_val.item()
        epoch_loss += loss_val.item()

    print(f"epoch= {epoch}", end="\r", flush=True)
    if epoch % log_every == 0:
        disp_loss(epoch_loss, epoch)
        estimate_remaining_time(start_time, epoch, n_epoch)

    epoch_loss = 0.0
    scheduler.step()
print("Finished training")
# Save the model weights after training.
tch.save(modifiedUnet.state_dict(), 'model_weights.pth')


Epoch: 0 -- Total loss: 520519.47044
Estimated remaining time: 00:28:14


KeyboardInterrupt: 

In [None]:
# Load the trained weights. map_location remaps tensors saved on another device
# (e.g. a GPU checkpoint opened on a CPU-only machine) onto `device`.
# tch.load unpickles the file, so only load checkpoints you trust.
state_dict = tch.load("model_weights.pth", map_location=device)

# Load the state dictionary into the model and make sure it lives on `device`.
modifiedUnet.load_state_dict(state_dict)
modifiedUnet = modifiedUnet.to(device)

# Set the model to evaluation mode before sampling.
modifiedUnet.eval()

In [None]:

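# How to generate an image from noise.
# The commented-out sampler below follows the DDPM sampling algorithm (Ho et al., 2020),
# with a normalize_data call after each step. It is kept for reference and expects a
# noise_predictor(images, labels, times) signature.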
# def generate_image(n_images, labels, noise_predictor, device=device):
#     labels = tch.tensor(labels).to(device)
#     images = tch.randn(n_images, 1, height, width).to(device)
#     images = normalize_data(images)
#     i = T - 1
#     while i >= 0:
#         times = tch.tensor([i] * n_images).to(device)
#         alpha_bar = diffusion_scheduler[i].to(device)
#         if i > 0:
#             noise = tch.randn(n_images, 1, height, width).to(device)
#         else:
#             noise = images * 0
#         images = (1 / tch.sqrt(alphas[i])) * (images - (1 - alphas[i]) / tch.sqrt(1 - diffusion_scheduler[i]) * noise_predictor(images, labels, times)) + tch.sqrt(betas[i]) * noise
#         images = normalize_data(images)
#         i = i - 1  # Decrement i for the next iteration
#     return images.detach().cpu().numpy()


def generate_image(n_images, labels, noise_predictor, device=device):
    """Generate images with a single denoising step from pure noise.

    Draws ``n_images`` Gaussian images of size ``height`` x ``width`` and rescales
    each with ``normalize_data``. It then subtracts the noise that
    ``noise_predictor`` predicts at timestep ``T``.

    Args:
        n_images: Number of images to generate.
        labels: Currently unused. Kept for compatibility with the class-conditional
            sampler signature.
        noise_predictor: Callable ``(images, times) -> predicted_noise``.
        device: Device on which sampling runs.

    Returns:
        NumPy array on the CPU with the shape of the predictor's output,
        (n_images, 1, height, width) for ``Network``.
    """
    # Inference only: no autograd graph is needed.
    with tch.no_grad():
        images = normalize_data(tch.randn(n_images, 1, height, width, device=device))
        # Shape (n_images, 1, 1, 1) so each timestep broadcasts over its own image.
        # A flat (n_images,) tensor would broadcast against the image width instead.
        # NOTE(review): training draws t from [1, T - 1], so t = T is one step
        # outside the training range. Confirm intent.
        times = tch.full((n_images, 1, 1, 1), T, dtype=tch.long, device=device)
        images = images - noise_predictor(images, times)
    return images.cpu().numpy()

In [None]:
# Sample six images, drop the singleton channel axis, and show them in one row.
images = generate_image(
    n_images=6,
    labels=[1, 2, 3, 4, 5, 6],
    noise_predictor=modifiedUnet,
).squeeze(axis=1)
disp(images, shape=(1, 6), scale=1)