<a href="https://colab.research.google.com/github/lacykaltgr/continual-learning-ait/blob/experiment/generator.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
'''Download the files '''
'''Only for colab'''

!wget https://github.com/lacykaltgr/continual-learning-ait/archive/refs/heads/experiment.zip
!unzip experiment.zip
!find continual-learning-ait-experiment -type f ! -name "main.ipynb" -exec cp {} . \;

!rm -r stable_diffusion
!rm -r models
!mkdir stable_diffusion
!mkdir models
!mv diffusion_model.py stable_diffusion/
!mv autoencoder_kl.py stable_diffusion/
!mv layers.py stable_diffusion/
!mv stable_diffusion.py stable_diffusion/
!mv constants.py stable_diffusion/
!mv encoder.h5 models/
!mv classifier.h5 models/

--2023-05-11 14:25:26--  https://github.com/lacykaltgr/continual-learning-ait/archive/refs/heads/experiment.zip
Resolving github.com (github.com)... 140.82.113.4
Connecting to github.com (github.com)|140.82.113.4|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://codeload.github.com/lacykaltgr/continual-learning-ait/zip/refs/heads/experiment [following]
--2023-05-11 14:25:26--  https://codeload.github.com/lacykaltgr/continual-learning-ait/zip/refs/heads/experiment
Resolving codeload.github.com (codeload.github.com)... 140.82.112.9
Connecting to codeload.github.com (codeload.github.com)|140.82.112.9|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: unspecified [application/zip]
Saving to: ‘experiment.zip’

experiment.zip          [ <=>                ]   1.58M  --.-KB/s    in 0.1s    

2023-05-11 14:25:26 (14.8 MB/s) - ‘experiment.zip’ saved [1659606]

Archive:  experiment.zip
6ab06d34a435387547f849642cd1aca6fb12d185
   creatin

In [2]:
from keras.models import load_model
import tensorflow as tf
import keras
import numpy as np
from keras.layers import Conv2D, Conv2DTranspose
import math
from stable_diffusion.constants import _ALPHAS_CUMPROD

In [3]:
def load_cifar_10():
    (X_train, y_train), (X_test, y_test) = tf.keras.datasets.cifar10.load_data()
    n_classes = 10
    X_train = (X_train / 127.5) -1
    X_test = (X_test / 127.5) -1
    y_train = tf.keras.utils.to_categorical(y_train, n_classes)
    y_test = tf.keras.utils.to_categorical(y_test, n_classes)
    return (X_train, y_train), (X_test, y_test)

In [4]:
encoder = load_model("models/encoder.h5")
classifier = load_model("models/classifier.h5")

encoder.compile(
    optimizer=keras.optimizers.Adam(learning_rate=5e-3),
    loss=keras.losses.CategoricalCrossentropy(),
    metrics=['accuracy']
)

classifier.compile(
    optimizer=keras.optimizers.Adam(learning_rate=5e-3),
    loss=keras.losses.CategoricalCrossentropy(),
    metrics=['accuracy']
)



In [5]:
def apply_seq(x: object, layers: object) -> object:
    for l in layers:
        x = l(x)
    return x

In [6]:
class ResBlock(keras.layers.Layer):
    def __init__(self, channels, out_channels):
        super().__init__()
        self.in_layers = [
            keras.layers.GroupNormalization(epsilon=1e-5),
            keras.activations.swish,
            Conv2D(out_channels, 3, strides=(1, 1), padding='same'),
        ]
        self.emb_layers = [
            keras.activations.swish,
            keras.layers.Dense(out_channels),
        ]
        self.out_layers = [
            keras.layers.GroupNormalization(epsilon=1e-5),
            keras.activations.swish,
            Conv2D(out_channels, 3, strides=(1, 1), padding='same'),
        ]
        self.skip_connection = (
            Conv2D(out_channels, 3, strides=(1, 1), padding='same') if channels != out_channels else lambda x: x
        )

    def call(self, inputs):
        x, emb = inputs
        h = apply_seq(x, self.in_layers)
        emb_out = apply_seq(emb, self.emb_layers)
        h = h + emb_out[:, None, None]
        h = apply_seq(h, self.out_layers)
        skip_x = self.skip_connection(x)
        ret = skip_x + h
        return ret

In [7]:
class UNetModel(keras.Model):
    def __init__(self):
        print("UNetModel init")
        super().__init__()
        self.img_height = 32
        self.img_width = 32
        self.ntype = tf.float32
        self.time_embed = [
            keras.layers.Dense(128),
            keras.activations.swish,
            keras.layers.Dense(128),
        ]
        self.input_blocks = [
            [Conv2D( 32, 3, strides=(1, 1), padding='same')],

            [ResBlock(32, 32)],
            [ResBlock(32, 32)],
            [Conv2D(64, 3, strides=(2, 2), padding='same'),], #downsample

            [ResBlock(32, 64)], 
            [ResBlock(64, 64)], 
            [Conv2D(128, 3, strides=(2, 2), padding='same'),], #downsample

            [ResBlock(64, 128)],
            [ResBlock(128, 128)],
            [Conv2D(128, 3, strides=(2, 2), padding='same')], #downsample

            [ResBlock(128, 128)],
            [ResBlock(128, 128)],
        ]
        self.middle_block = [
            ResBlock(128, 128),
            ResBlock(128, 128),
        ]
        self.output_blocks = [
            [ResBlock(256, 128)],
            [ResBlock(256, 128)],

            [
                ResBlock(256, 128),
                Conv2DTranspose(128, 2, strides=(2,2)),
                Conv2D(128, 3, strides=(1,1), padding='same')
            ],
            [ResBlock(256, 128)], 
            [ResBlock(256, 128)],

            [
                ResBlock(192, 128),
                Conv2DTranspose(128, 2, strides=(2,2)),
                Conv2D(64, 2, strides=(1,1), padding='valid')
            ],
            [ResBlock(192, 64)], 
            [ResBlock(128, 64)], 

            [
                ResBlock(96, 64),
                Conv2DTranspose(64, 3, strides=(2,2)),
                Conv2D(64, 2, strides=(1,1), padding='valid')
            ],
            [ResBlock(96, 32)], 
            [ResBlock(64, 32)],

            [ResBlock(64, 32)],
        ]
        self.out = [
            keras.layers.GroupNormalization(epsilon=1e-5),
            keras.activations.swish,
            Conv2D(8, 3, strides=(1,1), padding='same'),
        ]

    def call(self, inputs):
        x, t_emb = inputs
        emb = apply_seq(t_emb, self.time_embed)

        def apply(x, layer):
            return layer([x, emb]) if isinstance(layer, ResBlock) else layer(x)

        saved_inputs = []
        for b in self.input_blocks:
            for layer in b:
                x = apply(x, layer)
            saved_inputs.append(x)

        for layer in self.middle_block:
            x = apply(x, layer)

        for b in self.output_blocks:
            skip = saved_inputs.pop()
            x = tf.concat([x, skip], axis=-1)
            for layer in b:
                x = apply(x, layer)

        return apply_seq(x, self.out)

    def initialize(self, params, input_latent=None, batch_size=64):
        timesteps = np.arange(1, params['num_steps']+ 1)
        input_lat_noise_t = timesteps[int(len(timesteps)* params["input_latent_strength"])]
        latent, alphas, alphas_prev = self.get_starting_parameters(
            timesteps, batch_size, input_latent=input_latent, input_lat_noise_t=input_lat_noise_t
        )
        timesteps = timesteps[: int(len(timesteps)*params["input_latent_strength"])]
        return latent, alphas, alphas_prev, timesteps


    def get_x_prev(self, x, e_t, a_t, a_prev, temperature):
        sigma_t = 0
        sqrt_one_minus_at = math.sqrt(1 - a_t)
        pred_x0 = x - sqrt_one_minus_at * e_t / math.sqrt(a_t)

        # Direction pointing to x_t
        dir_xt = math.sqrt(1.0 - a_prev - sigma_t**2) * e_t
        #noise = sigma_t * tf.random.normal(x.shape, seed=seed) * temperature
        x_prev = math.sqrt(a_prev) * pred_x0 + dir_xt
        return x_prev


    def get_model_output(self, latent, timestep, batch_size):
        timesteps = tf.convert_to_tensor([timestep], dtype=tf.float32)
        t_emb = self.timestep_embedding(timesteps)
        t_emb = tf.repeat(t_emb, repeats=batch_size, axis=0)
        latent = self.call([latent, t_emb])
        return latent


    def timestep_embedding(self, timesteps, dim=320, max_period=10000):
        half = dim // 2
        freqs = np.exp(
            -math.log(max_period) * np.arange(0, half, dtype="float32") / half
        )
        args = np.array(timesteps) * freqs
        embedding = np.concatenate([np.cos(args), np.sin(args)])
        return tf.convert_to_tensor(embedding.reshape(1, -1), dtype=self.ntype)



    # for model with input latent

    def add_noise(self, x, t, noise=None):
        if len(x.shape) == 3:
            x = tf.expand_dims(x, axis=0)
        batch_size, w, h, c = x.shape[0], x.shape[1], x.shape[2], x.shape[3]
        if noise is None:
            noise = tf.random.normal((batch_size, w, h, c), dtype=tf.float32)
        sqrt_alpha_prod = tf.cast(_ALPHAS_CUMPROD[t] ** 0.5, tf.float32)
        sqrt_one_minus_alpha_prod = (1 - _ALPHAS_CUMPROD[t]) ** 0.5

        return sqrt_alpha_prod * x + sqrt_one_minus_alpha_prod * noise

    def get_starting_parameters(self, timesteps, batch_size,  input_latent=None, input_lat_noise_t=None):
        n_h = self.img_height // 8
        n_w = self.img_width // 8
        alphas = [_ALPHAS_CUMPROD[t] for t in timesteps]
        alphas_prev = [1.0] + alphas[:-1]
        if input_latent is None:
            latent = tf.random.normal((batch_size, n_h, n_w, 8))
        else:
            input_latent = tf.cast(input_latent, self.ntype)
            #latent = tf.repeat(input_latent , batch_size , axis=0)
            latent = self.add_noise(input_latent, input_lat_noise_t)
        return latent, alphas, alphas_prev


In [8]:
def get_one_hot_predictions(mem_pred):
    maximum = np.argmax(mem_pred, axis=1)
    num_classes = mem_pred.shape[1]
    mem_true = np.zeros_like(mem_pred)
    mem_true[np.arange(len(maximum)), maximum] = 1
    return mem_true

In [66]:
'''Generate samples and train the diffusion model at the same time'''

#plusz lehetne itt még kritérium hogy ne menjen olyan messze az alaptól --- similarity loss
#plusz még lehetne talán egy discriminator is, hogy valós reprezentációkat tanuljon meg


def generate(cls=classifier, input_latent=None, train=True, coeff=1.0):

    batch_size = params['batch_size'] if train else 64
    latent, alphas, alphas_prev, timesteps = model.initialize(params, input_latent, batch_size)


    for index, timestep in reversed(list(enumerate(timesteps))):
        if train:
            with tf.GradientTape() as tape:
                e_t = model.get_model_output(
                    latent,
                    timestep,
                    batch_size,
                )
                a_t, a_prev = alphas[index], alphas_prev[index]
                latent = model.get_x_prev(latent, e_t,  a_t, a_prev, params["temperature"])

                pred = cls(latent)
                pred_true = get_one_hot_predictions(pred) #ezt nem fixen kell mecsinálni
                confidence_loss = coeff*tf.reduce_mean(tf.keras.losses.categorical_crossentropy(pred_true, pred))
                print(confidence_loss)
                similarity_loss = 0.1 * tf.reduce_mean(tf.square(latent - e_t))
                print(similarity_loss)
                loss = confidence_loss + similarity_loss
            grads = tape.gradient(loss, model.trainable_variables)
            tf.keras.optimizers.legacy.Adam(learning_rate=params["gen_lr"]).apply_gradients(zip(grads, model.trainable_variables))
        else:
            e_t = model.get_model_output(
                latent,
                timestep,
                batch_size,
            )
            a_t, a_prev = alphas[index], alphas_prev[index]
            latent = model.get_x_prev(latent, e_t,  a_t, a_prev, params["temperature"])

    return latent

In [67]:
model = UNetModel()

UNetModel init


In [64]:
params = {
    "num_steps": 3,
    "input_latent_strength": 0.9,
    "temperature": 0.9,
    "batch_size": 256,
    "gen_lr": 2e-3,
    "n_epoch": 1,
}

In [12]:
(X_train, y_train), (X_test, y_test) = load_cifar_10()

Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz


In [68]:
for epoch in range(params["n_epoch"]):
    loss = []
    for i in range(0, X_train.shape[0], params["batch_size"]):
        X_batch = X_train[i:i+params["batch_size"]]
        y_batch = y_train[i:i+params["batch_size"]]
        latent = encoder(X_batch)
        mem_x = generate(input_latent=latent, train=True)
        mem_pred = classifier(mem_x)
        unique_values, counts = np.unique(np.argmax(mem_pred, axis=1), return_counts=True)
        for value, count in zip(unique_values, counts):
            print("Value:", value, "Count:", count)
        #print(np.unique(np.argmax(mem_pred, axis=1)))
        mem_pred_true = get_one_hot_predictions(mem_pred)
        mem_loss = tf.keras.losses.categorical_crossentropy(mem_pred_true, mem_pred)
        loss.append(np.mean(mem_loss))
        print(np.mean(mem_loss))
    print("Loss on generate: ",  np.mean(loss))

tf.Tensor(0.13102618, shape=(), dtype=float32)
tf.Tensor(0.05313134, shape=(), dtype=float32)
tf.Tensor(0.11086766, shape=(), dtype=float32)
tf.Tensor(0.03974482, shape=(), dtype=float32)
Value: 0 Count: 16
Value: 1 Count: 38
Value: 2 Count: 24
Value: 3 Count: 24
Value: 4 Count: 23
Value: 5 Count: 22
Value: 6 Count: 33
Value: 7 Count: 25
Value: 8 Count: 22
Value: 9 Count: 29
0.11086766
tf.Tensor(0.14926352, shape=(), dtype=float32)
tf.Tensor(0.04855547, shape=(), dtype=float32)
tf.Tensor(0.13393158, shape=(), dtype=float32)
tf.Tensor(0.034818947, shape=(), dtype=float32)
Value: 0 Count: 31
Value: 1 Count: 30
Value: 2 Count: 17
Value: 3 Count: 21
Value: 4 Count: 36
Value: 5 Count: 21
Value: 6 Count: 21
Value: 7 Count: 25
Value: 8 Count: 25
Value: 9 Count: 29
0.13393158
tf.Tensor(0.17173126, shape=(), dtype=float32)
tf.Tensor(0.029068053, shape=(), dtype=float32)
tf.Tensor(0.17189915, shape=(), dtype=float32)
tf.Tensor(0.041004803, shape=(), dtype=float32)
Value: 0 Count: 27
Value: 1 Cou

KeyboardInterrupt: ignored