In [1]:
import os
import torch
import torchvision
from PIL import Image
from matplotlib import pyplot as plt
from torch.utils.data import DataLoader
from PIL import Image
from torch.utils.data import Dataset
import json
from os.path import exists
import copy
import numpy as np
import torch.nn as nn
from tqdm import tqdm
from torch import optim
import logging
import random
import torch.nn.functional as F

# Reproducibility: seed Python's `random`, NumPy and PyTorch with the same
# fixed value so that repeated runs draw the same random numbers.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
# Last expression of the cell: `torch.manual_seed` returns the Generator
# object, which the notebook echoes as the cell output.
torch.manual_seed(SEED)

<torch._C.Generator at 0x2b017ec31b0>

In [2]:
# --- Runtime / diffusion hyper-parameters ---------------------------------
# Use the GPU when one is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Number of diffusion timesteps (length of the beta schedule).
NOISE_STEPS = 1000 
# Endpoints of the linear beta schedule built by MyDDPM.noise_schedule().
BETA_START = 1e-4
BETA_END = 0.02
# (channels, height, width) of the generated images.
IMAGE_CHW = (3, 64, 64)
# Number of rows in the UNet label-embedding table.
NUM_CLASSES = 22
# Weight given to the old average in EMA.update_average().
EMA_BETA = 0.995

# --- Filesystem locations --------------------------------------------------
# Directory holding the checkpoints, and directory receiving sampled images.
MODEL_PATH = "./models"
OUTPUT_PATH = "./output"
# JSON file loaded by sample_500() to map a decoded label to a class index.
CLASS_LEGEND_FILE = "legenda_classi.json"
# Text file read by sample_500(), one `image_name;label` entry per line.
TEST_FILE = "test.txt"
# Checkpoint file names (joined with MODEL_PATH) for raw and EMA weights.
MODEL_FILE = "ckpt.pt"
EMA_MODEL_FILE = "ema_ckpt.pt"

In [3]:
def save_images(images, path, **kwargs):
    """Arrange `images` into a single grid and write it to `path`.

    Any extra keyword arguments are passed straight through to
    `torchvision.utils.make_grid`.
    """
    tiled = torchvision.utils.make_grid(images, **kwargs)
    # Move axis 0 to the end (presumably CHW -> HWC) and bring the data to
    # host memory before handing the array to PIL.
    pixels = tiled.permute(1, 2, 0).to('cpu').numpy()
    Image.fromarray(pixels).save(path)

In [4]:
def _letter_suffix(index):
    """Return a spreadsheet-style label for a 0-based index: A..Z, AA, AB, ..."""
    label = ""
    index += 1
    while index > 0:
        index, remainder = divmod(index - 1, 26)
        label = chr(ord("A") + remainder) + label
    return label


def sample_500(ddpm, ema=True, images_per_class=10):
    """Generate `images_per_class` images for every entry of TEST_FILE.

    Each non-empty line of TEST_FILE has the form `image_name;label`, where
    `label` is a string of binary digits. Its integer value (as a string) is
    the key looked up in the JSON legend CLASS_LEGEND_FILE to obtain the class
    index passed to `ddpm.sample`. Every sampled image is written to
    OUTPUT_PATH as `<image_name>_<suffix>.jpg` with suffixes A, B, C, ...

    Args:
        ddpm: object exposing `sample(n=..., labels=..., ema=...)`.
        ema: forwarded to `ddpm.sample`.
        images_per_class: how many images to sample per line of TEST_FILE.

    Raises:
        FileNotFoundError: if CLASS_LEGEND_FILE (or TEST_FILE) does not exist.
    """
    if not os.path.exists(CLASS_LEGEND_FILE):
        # Raise instead of calling exit(): inside a notebook exit() targets the
        # kernel itself rather than reporting an error to the caller.
        raise FileNotFoundError(f"File labels non trovato! ({CLASS_LEGEND_FILE})")

    os.makedirs(OUTPUT_PATH, exist_ok=True)

    with open(CLASS_LEGEND_FILE) as json_file:
        legend_labels = json.load(json_file)

    with open(TEST_FILE, 'r') as f:
        lines = f.readlines()

    for line in tqdm(lines, position=0):
        line = line.strip()
        if not line:
            # Tolerate blank / trailing empty lines in the test file.
            continue
        image_name, label = line.split(';')
        print("Classe: ", label)
        # The label is a binary string; its integer value is the legend key.
        label_number = legend_labels[str(int(label, 2))]
        labels = (torch.ones(images_per_class) * label_number).long().to(DEVICE)
        sampled_images = ddpm.sample(n=len(labels), labels=labels, ema=ema)
        for i, image in tqdm(enumerate(sampled_images), position=0):
            # Suffixes are generated on demand, so any images_per_class works
            # (the first ten are still A..J).
            image_name_output = image_name + "_" + _letter_suffix(i)
            file_img = os.path.join(OUTPUT_PATH, f"{image_name_output}.jpg")
            save_images(image, file_img)

In [5]:
class EMA:
    """Exponential moving average over a model's parameters.

    The decay factor is the module-level constant EMA_BETA. `step` counts how
    many times `step_ema` has been called.
    """

    def __init__(self):
        super().__init__()
        self.step = 0

    def update_model_average(self, ma_model, current_model):
        """Blend each parameter of `current_model` into `ma_model` in place."""
        param_pairs = zip(current_model.parameters(), ma_model.parameters())
        for live_param, averaged_param in param_pairs:
            averaged_param.data = self.update_average(
                averaged_param.data, live_param.data
            )

    def update_average(self, old, new):
        """Return the EMA of `old` and `new`; `new` itself when `old` is None."""
        return new if old is None else old * EMA_BETA + (1 - EMA_BETA) * new

    def step_ema(self, ema_model, model, step_start_ema=2000):
        """Advance one step.

        For the first `step_start_ema` calls the EMA model simply mirrors
        `model`; afterwards its parameters are averaged towards it.
        """
        if self.step >= step_start_ema:
            self.update_model_average(ema_model, model)
        else:
            self.reset_parameters(ema_model, model)
        self.step += 1

    def reset_parameters(self, ema_model, model):
        """Overwrite the EMA model's state with a copy of `model`'s state."""
        ema_model.load_state_dict(model.state_dict())


class SelfAttention(nn.Module):
    """Multi-head self-attention over the spatial positions of a feature map.

    The input is flattened to a sequence with one token per spatial position,
    passed through a pre-norm 4-head attention layer and a small feed-forward
    block (both with residual connections), then reshaped back to a feature
    map of the same size.

    Args:
        channels: number of feature channels (the attention embedding size).
        size: nominal spatial side length. Kept as an attribute for
            compatibility; `forward` reads the real height/width from its
            input, so other resolutions also work.
    """

    def __init__(self, channels, size):
        super(SelfAttention, self).__init__()
        self.channels = channels
        self.size = size
        self.mha = nn.MultiheadAttention(channels, 4, batch_first=True)
        self.ln = nn.LayerNorm([channels])
        self.ff_self = nn.Sequential(
            nn.LayerNorm([channels]),
            nn.Linear(channels, channels),
            nn.GELU(),
            nn.Linear(channels, channels),
        )

    def forward(self, x):
        """Apply self-attention to `x` and return a tensor of the same shape.

        `x` is expected to end in (channels, height, width).
        """
        # Take the spatial dims from the input rather than `self.size`: with a
        # hard-coded size, a different resolution would either fail or be
        # silently folded into the batch dimension by the -1 in the reshape.
        height, width = x.shape[-2:]
        x = x.reshape(-1, self.channels, height * width).swapaxes(1, 2)
        x_ln = self.ln(x)
        attention_value, _ = self.mha(x_ln, x_ln, x_ln)
        attention_value = attention_value + x
        attention_value = self.ff_self(attention_value) + attention_value
        return attention_value.swapaxes(2, 1).reshape(-1, self.channels, height, width)


class DoubleConv(nn.Module):
    """Two 3x3 bias-free convolutions, each followed by a single-group GroupNorm.

    A GELU sits between the two stages. With `residual=True` the input is
    added to the result and a final GELU is applied.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels produced by the second convolution.
        mid_channels: channels between the two convolutions; falls back to
            `out_channels` when falsy.
        residual: whether to add the input back before a final GELU.
    """

    def __init__(self, in_channels, out_channels, mid_channels=None, residual=False):
        super().__init__()
        self.residual = residual
        hidden_channels = mid_channels if mid_channels else out_channels
        stages = [
            nn.Conv2d(in_channels, hidden_channels, kernel_size=3, padding=1, bias=False),
            nn.GroupNorm(1, hidden_channels),
            nn.GELU(),
            nn.Conv2d(hidden_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.GroupNorm(1, out_channels),
        ]
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        """Run the conv stack; wrap it in a GELU'd skip connection if residual."""
        convolved = self.double_conv(x)
        return F.gelu(x + convolved) if self.residual else convolved


class Down(nn.Module):
    """Encoder stage: 2x max-pool, two DoubleConv blocks, plus an embedding bias.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels of the returned feature map.
        emb_dim: size of the conditioning vector `t` passed to `forward`.
    """

    def __init__(self, in_channels, out_channels, emb_dim=256):
        super().__init__()
        stages = [
            nn.MaxPool2d(2),
            DoubleConv(in_channels, in_channels, residual=True),
            DoubleConv(in_channels, out_channels),
        ]
        self.maxpool_conv = nn.Sequential(*stages)

        # Projects the conditioning vector to one value per output channel.
        self.emb_layer = nn.Sequential(nn.SiLU(), nn.Linear(emb_dim, out_channels))

    def forward(self, x, t):
        """Downsample `x` and add the projected embedding `t` at every position."""
        features = self.maxpool_conv(x)
        height, width = features.shape[-2], features.shape[-1]
        projected = self.emb_layer(t)
        # Tile the per-channel embedding over the whole spatial grid.
        tiled = projected[:, :, None, None].repeat(1, 1, height, width)
        return features + tiled


class Up(nn.Module):
    """Decoder stage: 2x bilinear upsample, skip concat, convs, embedding bias.

    Args:
        in_channels: channels after concatenating the skip tensor with the
            upsampled input.
        out_channels: channels of the returned feature map.
        emb_dim: size of the conditioning vector `t` passed to `forward`.
    """

    def __init__(self, in_channels, out_channels, emb_dim=256):
        super().__init__()

        self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
        stages = [
            DoubleConv(in_channels, in_channels, residual=True),
            DoubleConv(in_channels, out_channels, in_channels // 2),
        ]
        self.conv = nn.Sequential(*stages)

        # Projects the conditioning vector to one value per output channel.
        self.emb_layer = nn.Sequential(nn.SiLU(), nn.Linear(emb_dim, out_channels))

    def forward(self, x, skip_x, t):
        """Upsample `x`, fuse it with `skip_x`, and add the projected embedding."""
        upsampled = self.up(x)
        # The skip connection comes first along the channel axis.
        fused = self.conv(torch.cat([skip_x, upsampled], dim=1))
        height, width = fused.shape[-2], fused.shape[-1]
        tiled = self.emb_layer(t)[:, :, None, None].repeat(1, 1, height, width)
        return fused + tiled


class UNet(nn.Module):
    """Conditional U-Net that predicts the noise contained in a noisy image.

    Three Down stages and three Up stages, each followed by self-attention,
    surround a three-block bottleneck. Every stage receives a conditioning
    vector made of a sinusoidal timestep encoding plus, when labels are given,
    a learned class embedding.

    The SelfAttention sizes (32, 16, 8, 16, 32, 64) match a 64x64 input, each
    Down stage halving and each Up stage doubling the resolution.

    Args:
        c_in: channels of the input image.
        c_out: channels of the predicted noise.
        time_dim: size of the conditioning vector.
        num_classes: number of class labels; `None` creates no label embedding.
        device: stored as `self.device` for compatibility. Tensors created in
            `pos_encoding` follow the device of the input instead.
    """

    def __init__(self, c_in=IMAGE_CHW[0], c_out=IMAGE_CHW[0], time_dim=256, num_classes=NUM_CLASSES, device=DEVICE):
        super().__init__()
        self.device = device
        self.time_dim = time_dim
        # Attribute names and creation order are kept as-is so that existing
        # checkpoints (state_dict keys) keep loading.
        # `emb_dim=time_dim` keeps the stages consistent with a non-default
        # time_dim (the stages' own default is 256).
        self.inc = DoubleConv(c_in, 64)
        self.down1 = Down(64, 128, emb_dim=time_dim)
        self.sa1 = SelfAttention(128, 32)
        self.down2 = Down(128, 256, emb_dim=time_dim)
        self.sa2 = SelfAttention(256, 16)
        self.down3 = Down(256, 256, emb_dim=time_dim)
        self.sa3 = SelfAttention(256, 8)

        self.bot1 = DoubleConv(256, 512)
        self.bot2 = DoubleConv(512, 512)
        self.bot3 = DoubleConv(512, 256)

        self.up1 = Up(512, 128, emb_dim=time_dim)
        self.sa4 = SelfAttention(128, 16)
        self.up2 = Up(256, 64, emb_dim=time_dim)
        self.sa5 = SelfAttention(64, 32)
        self.up3 = Up(128, 64, emb_dim=time_dim)
        self.sa6 = SelfAttention(64, 64)
        self.outc = nn.Conv2d(64, c_out, kernel_size=1)

        if num_classes is not None:
            self.label_emb = nn.Embedding(num_classes, time_dim)

    def pos_encoding(self, t, channels):
        """Sinusoidal encoding of `t` (shape (batch, 1)) -> (batch, channels).

        The first half of the output holds sines, the second half cosines.
        """
        # Allocate on the input's device rather than `self.device`, so the
        # module keeps working after being moved with `.to(...)`.
        inv_freq = 1.0 / (
            10000
            ** (torch.arange(0, channels, 2, device=t.device).float() / channels)
        )
        pos_enc_a = torch.sin(t.repeat(1, channels // 2) * inv_freq)
        pos_enc_b = torch.cos(t.repeat(1, channels // 2) * inv_freq)
        pos_enc = torch.cat([pos_enc_a, pos_enc_b], dim=-1)
        return pos_enc

    def forward(self, x, t, y=None):
        """Predict the noise in `x`.

        Args:
            x: batch of noisy images.
            t: one timestep index per batch element.
            y: optional class labels, one per batch element; `None` gives the
                unconditional prediction.

        Returns:
            A tensor with `c_out` channels and the spatial size of `x`.
        """
        t = t.unsqueeze(-1).type(torch.float)
        t = self.pos_encoding(t, self.time_dim)

        if y is not None:
            t = t + self.label_emb(y)

        # Encoder.
        x1 = self.inc(x)
        x2 = self.down1(x1, t)
        x2 = self.sa1(x2)
        x3 = self.down2(x2, t)
        x3 = self.sa2(x3)
        x4 = self.down3(x3, t)
        x4 = self.sa3(x4)

        # Bottleneck.
        x4 = self.bot1(x4)
        x4 = self.bot2(x4)
        x4 = self.bot3(x4)

        # Decoder with skip connections from the encoder.
        x = self.up1(x4, x3, t)
        x = self.sa4(x)
        x = self.up2(x, x2, t)
        x = self.sa5(x)
        x = self.up3(x, x1, t)
        x = self.sa6(x)
        output = self.outc(x)
        return output

In [6]:
# DDPM class
class MyDDPM(nn.Module):
    """Denoising diffusion model wrapping a UNet and an EMA copy of it.

    On construction the UNet is created on DEVICE and, if checkpoint files
    exist under MODEL_PATH, the raw and EMA weights are loaded from them. The
    linear beta schedule and its derived `alpha` / `alpha_hat` tensors are
    precomputed on DEVICE.
    """

    def __init__(self):
        super(MyDDPM, self).__init__()
        self.network = UNet().to(DEVICE)
        self.file_network = os.path.join(MODEL_PATH, MODEL_FILE)
        if os.path.exists(self.file_network):
            self.load_model()

        self.ema = EMA()
        # The EMA copy is never trained directly: eval mode, no gradients.
        self.ema_network = copy.deepcopy(self.network).eval().requires_grad_(False)
        self.file_ema_network = os.path.join(MODEL_PATH, EMA_MODEL_FILE)
        if os.path.exists(self.file_ema_network):
            self.load_ema_model()

        self.beta = self.noise_schedule().to(DEVICE)
        self.alpha = 1. - self.beta
        self.alpha_hat = torch.cumprod(self.alpha, dim=0)

    def noise_schedule(self):
        """Return the linear beta schedule with NOISE_STEPS entries."""
        return torch.linspace(BETA_START, BETA_END, NOISE_STEPS)

    def sample_timesteps(self, n):
        """Draw `n` random timestep indices in [1, NOISE_STEPS)."""
        return torch.randint(low=1, high=NOISE_STEPS, size=(n,))

    def sample(self, n, labels, ema=False, cfg_scale=3):
        """Generate `n` images conditioned on `labels` by reverse diffusion.

        Args:
            n: number of images to generate.
            labels: class labels, one per image, passed to the network.
            ema: use the EMA weights instead of the raw network.
            cfg_scale: classifier-free guidance strength; values <= 0 disable
                guidance (conditional prediction only).

        Returns:
            A uint8 tensor of shape (n, *IMAGE_CHW) with values in [0, 255].
        """
        print(f"Sampling {n} new images....")
        model = self.ema_network if ema else self.network
        # Remember the mode so it can be restored afterwards; unconditionally
        # calling train() would leave the EMA copy in training mode.
        was_training = model.training
        model.eval()
        try:
            with torch.no_grad():
                x = torch.randn((n, IMAGE_CHW[0], IMAGE_CHW[1], IMAGE_CHW[2])).to(DEVICE)
                for i in tqdm(reversed(range(1, NOISE_STEPS)), position=0):
                    t = (torch.ones(n) * i).long().to(DEVICE)
                    predicted_noise = model(x, t, labels)
                    if cfg_scale > 0:
                        # Both guidance branches must come from the same
                        # weights: use `model`, not `self.backward`, which
                        # always evaluates the raw network.
                        uncond_predicted_noise = model(x, t, None)
                        predicted_noise = torch.lerp(uncond_predicted_noise, predicted_noise, cfg_scale)
                    alpha = self.alpha[t][:, None, None, None]
                    alpha_hat = self.alpha_hat[t][:, None, None, None]
                    alpha_hat_prec = self.alpha_hat[t-1][:, None, None, None]
                    beta = self.beta[t][:, None, None, None]
                    # No fresh noise is added on the final step.
                    if i > 1:
                        noise = torch.randn_like(x)
                    else:
                        noise = torch.zeros_like(x)
                    # Estimate x0 from the predicted noise, clamp it to
                    # [-1, 1], then combine it with x_t using the posterior
                    # mean coefficients and add sqrt(beta)-scaled noise.
                    val = 1/torch.sqrt(alpha_hat) * x - torch.sqrt((1-alpha_hat)/alpha_hat) * predicted_noise
                    x = (torch.sqrt(alpha_hat_prec)*beta / (1 - alpha_hat)) * val.clamp(-1,1) + (1-alpha_hat_prec)*torch.sqrt(alpha)/(1-alpha_hat) * x + torch.sqrt(beta) * noise
        finally:
            model.train(was_training)
        # Map [-1, 1] to [0, 255] uint8.
        x = (x.clamp(-1, 1) + 1) / 2
        x = (x * 255).type(torch.uint8)
        return x

    def load_model(self):
        """Load the raw network weights from `self.file_network`."""
        # map_location lets a checkpoint saved on GPU load on a CPU-only host.
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        self.network.load_state_dict(torch.load(self.file_network, map_location=DEVICE))
        print("Modello caricato!")

    def load_ema_model(self):
        """Load the EMA network weights from `self.file_ema_network`."""
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        self.ema_network.load_state_dict(torch.load(self.file_ema_network, map_location=DEVICE))
        print("Modello EMA caricato!")

    def save_model(self):
        """Save the raw network weights, creating MODEL_PATH if necessary."""
        os.makedirs(MODEL_PATH, exist_ok=True)
        torch.save(self.network.state_dict(), self.file_network)
        print("Modello salvato!")

    def save_ema_model(self):
        """Save the EMA network weights, creating MODEL_PATH if necessary."""
        os.makedirs(MODEL_PATH, exist_ok=True)
        torch.save(self.ema_network.state_dict(), self.file_ema_network)
        print("Modello EMA salvato!")

    def step_ema(self):
        """Advance the EMA of the raw network by one step."""
        self.ema.step_ema(self.ema_network, self.network)

    def forward(self, x0, t):
        """Forward diffusion: noise `x0` to timestep `t`.

        Returns:
            A pair `(x_t, eps)` where `eps` is the Gaussian noise drawn and
            `x_t = sqrt(alpha_hat[t]) * x0 + sqrt(1 - alpha_hat[t]) * eps`.
        """
        sqrt_alpha_hat = torch.sqrt(self.alpha_hat[t])[:, None, None, None]
        sqrt_one_minus_alpha_hat = torch.sqrt(1 - self.alpha_hat[t])[:, None, None, None]
        eps = torch.randn_like(x0)
        return sqrt_alpha_hat * x0 + sqrt_one_minus_alpha_hat * eps, eps

    def backward(self, x, t, c=None):
        """Return the raw network's noise estimate for `x` at timesteps `t`.

        `c` holds optional class labels (None for the unconditional estimate).
        """
        return self.network(x, t, c)

In [None]:
# Build the DDPM wrapper (its constructor loads the checkpoints found under
# MODEL_PATH, if any) and sample images for every entry of TEST_FILE, using
# the function's defaults (EMA weights, 10 images per entry).
sample_500(MyDDPM())