In [272]:
import math
import tqdm
import mlx.data as dx
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim

from pathlib import Path
from dataclasses import dataclass
from functools import partial


# Data

In [273]:
@dataclass
class Hyperparameters:
    train_data_root: str = '/Users/noel/dataset/VOCdevKit/VOC2012/JPEGImages'
    crop_size: int = 96
    upscale_factor: int = 4
    learning_rate: float = 0.001
    batch_size: int = 64
    num_epochs: int = 10
    dropout_rate: float = 0.5
    num_workers: int = 4

hp = Hyperparameters()

In [274]:
def get_img_files(root: Path):
    images = []
    for ext in ['*.png', '*.jpg', '*.jpeg', '*.PNG', '*.JPG', '*.JPEG']:
        images += list(root.rglob(ext))

    return [{"image": str(img).encode("ascii")} for img in images]


In [275]:
dataset = (
    dx.buffer_from_vector(get_img_files(root=Path(hp.train_data_root)))
    .to_stream()
    .load_image('image')
    .image_random_crop('image', hp.crop_size, hp.crop_size, 'hr_image')
    .image_resize('image', hp.crop_size // hp.upscale_factor, hp.crop_size // hp.upscale_factor, 'lr_image')
    .key_transform("hr_image", lambda x: x.astype("float32"))
    .key_transform("lr_image", lambda x: x.astype("float32"))
    .batch(hp.batch_size)
    .prefetch(4, 2)
)

In [276]:
next(dataset).keys()


dict_keys(['hr_image', 'image', 'lr_image'])

# Modeling

![image.png](https://raw.githubusercontent.com/crlotwhite/ML_Study/main/%EB%85%BC%EB%AC%B8%EA%B5%AC%ED%98%84/generative/assets/Cap%202024-08-15%2001-58-41-395.jpg)

In [277]:
class ResidualBlock(nn.Module):
    def __init__(self, in_channels):
        super().__init__()

        self.model = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=in_channels,
                      kernel_size=3, padding=1),
            nn.BatchNorm(num_features=in_channels),
            nn.PReLU(),
            nn.Conv2d(in_channels=in_channels, out_channels=in_channels,
                      kernel_size=3, padding=1),
            nn.BatchNorm(num_features=in_channels)
        )

    def __call__(self, x):
        return self.model(x) + x

In [278]:
class PixelShuffle(nn.Module):
    def __init__(self, upscale_factor):
        super().__init__()
        
        self.upscale_factor = upscale_factor

    def __call__(self, x):
        batch, H_in, W_in, C_in = x.shape[:-3], x.shape[-3], x.shape[-2], x.shape[-1]
        C_out = C_in // self.upscale_factor**2 
        H_out = H_in * self.upscale_factor
        W_out = W_in * self.upscale_factor
        
        x = mx.reshape(x, shape= (*batch, self.upscale_factor, self.upscale_factor, H_in, W_in, C_out))
        x = mx.einsum('b c u v h w -> b c h u w v', x)
        
        return mx.reshape(x, shape=(*batch, H_out, W_out, C_out))

    

In [279]:
# pixelshuffle test
model = PixelShuffle(4)

input_image = mx.random.randint(0, 1, shape=(64, 24, 24, 64))
output_image = model(input_image)

print(input_image.shape)
print(output_image.shape) # expect: (1, 1, 12, 12)

(64, 24, 24, 64)
(64, 96, 96, 4)


In [280]:
# pixelshuffle test
upscale_factor = 3
model = PixelShuffle(upscale_factor)

input_image = mx.random.randint(0, 1, shape=(1, 4, 4, 9))
output_image = model(input_image)

print(input_image.shape)
print(output_image.shape) # expect: (1, 12, 12, 1)

(1, 4, 4, 9)
(1, 12, 12, 1)


In [281]:
class UpsampleBlock(nn.Module):
    def __init__(self, in_channels, up_scale):
        super().__init__()

        self.model = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, 
                      out_channels=in_channels * up_scale**2,
                      kernel_size=3, padding=1),
            PixelShuffle(up_scale),
            nn.PReLU()
        )

    def __call__(self, x):
        return self.model(x)

In [282]:
class Generator(nn.Module):
    def __init__(self, scale_factor):
        super().__init__()

        num_upsamples = int(math.log(scale_factor, 2))

        self.model = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=64, kernel_size=9, padding=4),
            nn.PReLU(),
            *[ResidualBlock(64) for _ in range(5)],
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=1),
            nn.BatchNorm(num_features=64),
            *[UpsampleBlock(64, 2) for _ in range(num_upsamples)],
            nn.Conv2d(in_channels=64, out_channels=3, kernel_size=9, padding=4),
            nn.Tanh()
        )
        
    def __call__(self, x):
        return (self.model(x) + 1) / 2
        

In [283]:
class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()

        self.model = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels, 
                      kernel_size=3, padding=1),
            nn.BatchNorm(num_features=out_channels),
            nn.LeakyReLU(0.2),
            nn.Conv2d(in_channels=out_channels, out_channels=out_channels, 
                      kernel_size=3, stride=2, padding=1),
            nn.BatchNorm(num_features=out_channels),
            nn.LeakyReLU(0.2),
        )
        
    def __call__(self, x):
        return self.model(x)

In [284]:
class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()

        self.model = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, padding=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=1),
            nn.BatchNorm(num_features=64),
            nn.LeakyReLU(0.2),
            *[ConvBlock(in_channels, out_channels) 
              for in_channels, out_channels in [(64, 128), (128, 256), (256, 512)]],
            nn.AvgPool2d(kernel_size=12),
            nn.Conv2d(in_channels=512, out_channels=1024, kernel_size=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(in_channels=1024, out_channels=1, kernel_size=1),
        )

    def __call__(self, x):
        batch_size = x.shape[0]
        return nn.sigmoid(self.model(x).reshape(batch_size, -1))
        

In [285]:
pool = nn.AvgPool2d(kernel_size=12)
x = mx.random.randint(0, 1, shape=(64, 12, 12, 512))# (64, 12, 12, 512)
pool(x).shape

(64, 1, 1, 512)

# Loss
![](https://raw.githubusercontent.com/crlotwhite/ML_Study/main/%EB%85%BC%EB%AC%B8%EA%B5%AC%ED%98%84/generative/assets/Cap%202024-08-15%2002-45-29-962.png)
![](https://raw.githubusercontent.com/crlotwhite/ML_Study/main/%EB%85%BC%EB%AC%B8%EA%B5%AC%ED%98%84/generative/assets/Cap%202024-08-15%2002-43-50-138.png)

In [286]:
class TVLoss(nn.Module):
    def __init__(self, tv_loss_weight=1):
        super().__init__()

        self.tv_loss_weight = tv_loss_weight

    def __call__(self, x):
        batch_size, h_x, w_x = x.shape[0], x.shape[2], x.shape[3]
        count_h = self.tensor_size(x[:, :, 1:, :])
        count_w = self.tensor_size(x[:, :, :, 1:])
        h_tv = mx.power((x[:, :, 1:, :] - x[:, :, :h_x - 1, :]), 2).sum()
        w_tv = mx.power((x[:, :, :, 1:] - x[:, :, :, :w_x - 1]), 2).sum()
        return self.tv_loss_weight * 2 * (h_tv / count_h + w_tv / count_w) / batch_size

    def tensor_size(self, t):
        return t.shape[1] * t.shape[2] * t.shape[3]

In [287]:
import h5py, requests, os

VGG16_WEIGHTS_LINK = (
    "https://storage.googleapis.com/tensorflow/keras-applications/"
    "vgg16/vgg16_weights_tf_dim_ordering_tf_kernels.h5"
)

class VGG16(nn.Module):     
    def __init__(self, weight_path=None):
        super().__init__()
        
        self.model = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(256, 512, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        
        self.classifier = nn.Sequential(
            nn.Linear(25088, 4096),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(4096, 4096),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Softmax()
        )
        
        if weight_path is not None:
            self.load_weights()
    
    def __call__(self, x):
        return self.classifier(self.model(x))
        

    def load_weights(self, weights_path=None) -> None:
        if weights_path is None:
            weights_path = 'resources/vgg16_weights_tf_dim_ordering_tf_kernels.h5'
            if not os.path.exists(weights_path):
                os.makedirs('resources', exist_ok=True)
                r = requests.get(VGG16_WEIGHTS_LINK, allow_redirects=True)
                open(weights_path, 'wb').write(r.content)
                
        weight = h5py.File(weights_path, 'r')
        for key in weight.keys():
            params = weight[key]
            params_keys = list(params.keys())
            try:
                layer = getattr(self, key)
                if len(layer.weight.shape) == 4:
                    layer.weight = mx.array(params[params_keys[0]][...]).transpose((3, 0, 1, 2))
                elif len(layer.weight.shape) == 2:
                    layer.weight = mx.array(params[params_keys[0]][...]).transpose((1, 0))
                layer.bias = mx.array(params[params_keys[1]][...])
            except:
                pass
        weight.close()
        return self

In [288]:
class GeneratorLoss(nn.Module):
    def __init__(self):
        vgg = VGG16()
        loss_net = vgg.model
        loss_net.eval()
        loss_net.freeze()
        self.loss_net = loss_net
        self.tv_loss = TVLoss()
    
    def __call__(self, D_out, fake_img, real_img):
        content_loss = nn.losses.mse_loss(self.loss_net(fake_img), 
                                          self.loss_net(real_img))
        adversarial_loss = mx.mean(1-D_out)
        perceptual_loss = content_loss * 0.006 + adversarial_loss * 0.001
        tv_loss = self.tv_loss(fake_img)
        
        return perceptual_loss + tv_loss * 2e-8
        
        

In [289]:
mx.set_default_device(mx.gpu)

G = Generator(4)
mx.eval(G.parameters())
D = Discriminator()
mx.eval(D.parameters())

G_loss = GeneratorLoss()

G_opt = optim.Adam(0.001)
D_opt = optim.Adam(0.001)



In [291]:
def discriminator_loss(hr_img, fake_img):
    return hr_img.mean() - 1 + fake_img.mean()

def d_forward(D_model, hr_img, fake_img):
    D_real = D_model(hr_img)
    D_fake = D_model(fake_img) 
    loss = discriminator_loss(D_real, D_fake)
    return loss, (D_real, D_fake)

def g_forward(D_model, hr_img, fake_img):
    D_out = D_model(fake_img).mean()
    loss = G_loss(D_out, fake_img, hr_img)
    return loss, D_out

g_loss_and_grad_fn = nn.value_and_grad(G, g_forward)
d_loss_and_grad_fn = nn.value_and_grad(D, d_forward)
    
state = [G.state, D.state, G_opt.state, D_opt.state]

G_losses = []
D_losses = []

@partial(mx.compile, inputs=state, outputs=state)
def step(lr_img, hr_img):
    fake_img = G(lr_img)
    (loss_g, _), grad = d_loss_and_grad_fn(D, hr_img, fake_img)
    D_opt.update(D, grad)
    (loss_d, _), grad = g_loss_and_grad_fn(D, hr_img, fake_img)
    G_opt.update(G, grad)
    return loss_g, loss_d

for batch in tqdm.tqdm(dataset, total=17123):
    hr_img = mx.array(batch['hr_image'])
    lr_img = mx.array(batch['lr_image'])

    loss_g, loss_d = step(lr_img, hr_img)
    mx.eval(state)
    
    G_losses.append(loss_g)
    D_losses.append(loss_d)

  0%|                                                 | 0/17123 [00:00<?, ?it/s]

  0%|                                     | 32/17123 [02:51<25:23:21,  5.35s/it]


KeyboardInterrupt: 

In [None]:
import matplotlib.pyplot as plt

plt.figure(figsize=(10, 5))
plt.plot(D_losses, label='Discriminator Loss')
plt.plot(G_losses, label='Generator Loss')
plt.title('Training Losses')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.grid()
plt.show()