# Neural Style Transfer with Determined AI Distributed Training 

In this notebook, we are using [Determined AI Distributed Training Platform](https://www.googleadservices.com/pagead/aclk?sa=L&ai=DChcSEwj0qJ2Hha3-AhXKMq0GHdKjAHsYABAAGgJwdg&ohost=www.google.com&cid=CAESauD28j3TAEQF3m2XI5muKYYFzyP5j_nYonVGVJg5j0l7ImbzKbJzE3317fwr9tHoies3u_WbAXhOvYKWOS-uhOn1TfKKWuaELUEb01YWMkU23PIzqaQO0Rc4vj4ycnsAEANVyXvnv9pTioY&sig=AOD64_2oY2aUJkycrKR4Tq71uJOV_rGjDQ&q&adurl&ved=2ahUKEwjQ9JSHha3-AhXjOX0KHSLcB8gQ0Qx6BAgFEAE) to train a customized neural style transfer CNN Model built with PyTorch to obtain style transfer results.

## 1. Preparation 

In [1]:
!pip install determined





In [3]:
# !pip3 install torch torchvision torchaudio
# commented out to save space

In [4]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

In [5]:
import torch
import torch.nn.functional as F
import torch.optim as optim
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image

In [6]:
import warnings
warnings.filterwarnings('ignore')
# turn off warnings

In [44]:
import logging
logger = logging.getLogger()
logger.setLevel(logging.CRITICAL)

## 2. Problem Definition & Model Architecture

Imganie creating a copy of a painting as a mathematical optimization problem. Given the source image `S`, we want to generate a target image `T`, so that the mean square error **MSE(S-T)** is the smallest, which means this imaage generation problem becomes an optimization problem of finding the optimal `T`.

For this problem, we can randomly initialize an image `T`, perform gradient descent based on this optimization goal, and eventually find the optimal `T`, a target image that is the closest to the source image `S`.

![model-architecture](images/model-architecture.png)

paper link: [Image Style Transfer Using Convolutional Neural Networks](https://openaccess.thecvf.com/content_cvpr_2016/papers/Gatys_Image_Style_Transfer_CVPR_2016_paper.pdf) 

The expected result should be:

![paper image](images/paper-image.png)

## 3. Data Processing

In order to calculate MSE correctly, the shape of all images must be **uniform**. Therefore, we resize images to `(256, 256)`. 

After data preprocessing, the image format is in `(c, h, w)`.

Style Image            |  Content Image
:-------------------------:|:-------------------------:
![](images/andyw.jpg)  |  ![](images/grogu.jpeg)

In [7]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [8]:
img_size = (256, 256)


def read_image(image_path):
    pipeline = transforms.Compose([transforms.Resize((img_size)), transforms.ToTensor()])

    img = Image.open(image_path)
    img = pipeline(img).unsqueeze(0)
    return img.to(device, torch.float)


def save_image(tensor, image_path):
    toPIL = transforms.ToPILImage()
    
    img = tensor.detach().cpu().clone()
    img = toPIL(img.squeeze(0))
    
    img.save(image_path)

In [9]:
style_img = read_image('images/andyw.jpg')
content_img = read_image('images/grogu.jpeg')

In [10]:
input_img = torch.randn(1, 3, *img_size, device=device)
input_img.requires_grad_(True)
optimizer = optim.LBFGS([input_img])

In [11]:
steps = 0
while steps <= 10:
    def closure():
        global steps
        optimizer.zero_grad()
        loss = F.mse_loss(input_img, style_img) + F.mse_loss(input_img, content_img)
        loss.backward()
        
        steps += 1
        if steps % 1 == 0:
            print(f'Step {steps}:')
            print(f'Loss: {loss}')
        return loss

    optimizer.step(closure)

save_image(input_img, 'images/output.jpg')

Step 1:
Loss: 2.5425872802734375
Step 2:
Loss: 2.5425591468811035
Step 3:
Loss: 0.11352577805519104
Step 4:
Loss: 0.11331774294376373
Step 5:
Loss: 0.11331774294376373
Step 6:
Loss: 0.11331774294376373
Step 7:
Loss: 0.11331774294376373
Step 8:
Loss: 0.11331774294376373
Step 9:
Loss: 0.11331774294376373
Step 10:
Loss: 0.11331774294376373
Step 11:
Loss: 0.11331774294376373


The output picture now looks like this: ![out](images/output.jpg). We will be reusing `read_image` and `save_image` in the following model training tasks.

## 4. Model training

### 4.1. Loss Function
Since we are using MSE to cauculate loss in this notebook, we will reuse its `F.mse_loss` directly instead of defining our own `forward()` functions in the `torch.autograd`.

In [12]:
def features(x: torch.Tensor):
    n, c, h, w = x.shape

    features = x.reshape(n * c, h * w)
    features = torch.mm(features, features.T) / n / c / h / w
    
    return features

In [13]:
class ContentLoss(torch.nn.Module):
    def __init__(self, target: torch.Tensor):
        super().__init__()
        self.target = target.detach()

    def forward(self, input):
        self.loss = F.mse_loss(input, self.target)
        return input
    

class StyleLoss(torch.nn.Module):
    def __init__(self, target: torch.Tensor):
        super().__init__()
        self.target = features(target.detach()).detach()

    def forward(self, input):
        G = features(input)
        self.loss = F.mse_loss(G, self.target)
        return input

### 4.2. Normalization

We also use one layer to normalize input distribution so the mean and standard deviation of the input data can be directly used for training.

In [14]:
class Normalization(torch.nn.Module):
    def __init__(self, mean, std):
        super().__init__()
        self.mean = torch.tensor(mean).to(device).reshape(-1, 1, 1)
        self.std = torch.tensor(std).to(device).reshape(-1, 1, 1)

    def forward(self, img):
        return (img - self.mean) / self.std

### 4.3. Model Assemble 

Following, we assemble the norm layer, the loss calculation and features(ONLY) of a pretrained VGG model from `torch.vision` using `torch.nn.Sequential`. If a certain layer is used for MSE calculation, we attach the loss module to it. 

In the following log, we can see the model architecture consists of:

```
 norm => conv_1 => style_loss_1 => relu_1 => conv_2 => style_loss_2 => relu_2 => pool_2 => conv_3 => style_loss_3 => relu_3 => conv_4 => content_loss_4 => style_loss_4 => relu_4 => pool_4 => conv_5 => style_loss_5
```

In [15]:
default_content_layers = ['conv_4']
default_style_layers = ['conv_1', 'conv_2', 'conv_3', 'conv_4', 'conv_5']

def get_model_and_losses(content_img, style_img, content_layers, style_layers):
    num_loss = 0
    expected_num_loss = len(content_layers) + len(style_layers)
    content_losses = []
    style_losses = []

    model = torch.nn.Sequential(
        Normalization([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]))
    cnn = models.vgg19(pretrained=True).features.to(device).eval()
    i = 0
    for layer in cnn.children():
        if isinstance(layer, torch.nn.Conv2d):
            i += 1
            name = f'conv_{i}'
        elif isinstance(layer, torch.nn.ReLU):
            name = f'relu_{i}'
            layer = torch.nn.ReLU(inplace=False)
        elif isinstance(layer, torch.nn.MaxPool2d):
            name = f'pool_{i}'
        elif isinstance(layer, torch.nn.BatchNorm2d):
            name = f'bn_{i}'
        else:
            raise RuntimeError(
                f'Unrecognized layer: {layer.__class__.__name__}')

        model.add_module(name, layer)

        if name in content_layers:
            target = model(content_img)
            content_loss = ContentLoss(target)
            model.add_module(f'content_loss_{i}', content_loss)
            content_losses.append(content_loss)
            num_loss += 1

        if name in style_layers:
            target_feature = model(style_img)
            style_loss = StyleLoss(target_feature)
            model.add_module(f'style_loss_{i}', style_loss)
            style_losses.append(style_loss)
            num_loss += 1

        if num_loss >= expected_num_loss:
            break

    return model, content_losses, style_losses

In [16]:
style_img = read_image('images/andyw.jpg')
content_img = read_image('images/grogu.jpeg')

In [17]:
input_img = torch.randn(1, 3, *img_size, device=device)
model, content_losses, style_losses = get_model_and_losses(
    content_img, style_img, default_content_layers, default_style_layers)

input_img.requires_grad_(True)
model.requires_grad_(False)

Sequential(
  (0): Normalization()
  (conv_1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (style_loss_1): StyleLoss()
  (relu_1): ReLU()
  (conv_2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (style_loss_2): StyleLoss()
  (relu_2): ReLU()
  (pool_2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv_3): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (style_loss_3): StyleLoss()
  (relu_3): ReLU()
  (conv_4): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (content_loss_4): ContentLoss()
  (style_loss_4): StyleLoss()
  (relu_4): ReLU()
  (pool_4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv_5): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (style_loss_5): StyleLoss()
)

### 4.4. Image Generation

In [20]:
optimizer = optim.LBFGS([input_img])
style_weight = 1e5
content_weight = 1.5
steps = 0
prev_loss = 0

while steps <= 1000 and prev_loss < 100:
    def closure():
        with torch.no_grad():
            input_img.clamp_(0, 1)
        global steps
        global prev_loss
        optimizer.zero_grad()
        model(input_img)
        content_loss = 0
        style_loss = 0
        for l in content_losses:
            content_loss += l.loss
        for l in style_losses:
            style_loss += l.loss
        loss = content_weight * content_loss + style_weight * style_loss
        loss.backward()
        steps += 1
        if steps % 1 == 0:
            print(f'Step {steps}:')
            print(f'Prev Loss: {loss}')
            print(f'Loss: {loss}')
            save_image(input_img, f'images/output_{steps}.jpg')
        prev_loss = loss
        return loss

    optimizer.step(closure)
    
with torch.no_grad():
    input_img.clamp_(0, 1)
    
save_image(input_img, 'images/output_final.jpg')

Step 1:
Prev Loss: 2953.3984375
Loss: 2953.3984375
Step 2:
Prev Loss: 2951.595458984375
Loss: 2951.595458984375
Step 3:
Prev Loss: 702.031982421875
Loss: 702.031982421875
Step 4:
Prev Loss: 387.2135009765625
Loss: 387.2135009765625
Step 5:
Prev Loss: 201.2799835205078
Loss: 201.2799835205078
Step 6:
Prev Loss: 123.96598815917969
Loss: 123.96598815917969
Step 7:
Prev Loss: 90.75457763671875
Loss: 90.75457763671875
Step 8:
Prev Loss: 73.46826171875
Loss: 73.46826171875
Step 9:
Prev Loss: 62.302764892578125
Loss: 62.302764892578125
Step 10:
Prev Loss: 54.568519592285156
Loss: 54.568519592285156
Step 11:
Prev Loss: 49.43915939331055
Loss: 49.43915939331055
Step 12:
Prev Loss: 45.83957290649414
Loss: 45.83957290649414
Step 13:
Prev Loss: 42.9759521484375
Loss: 42.9759521484375
Step 14:
Prev Loss: 39.94908905029297
Loss: 39.94908905029297
Step 15:
Prev Loss: 37.23042297363281
Loss: 37.23042297363281
Step 16:
Prev Loss: 34.981529235839844
Loss: 34.981529235839844
Step 17:
Prev Loss: 33.119499

Step 131:
Prev Loss: 12.684412002563477
Loss: 12.684412002563477
Step 132:
Prev Loss: 12.660452842712402
Loss: 12.660452842712402
Step 133:
Prev Loss: 12.637665748596191
Loss: 12.637665748596191
Step 134:
Prev Loss: 12.615509033203125
Loss: 12.615509033203125
Step 135:
Prev Loss: 12.601007461547852
Loss: 12.601007461547852
Step 136:
Prev Loss: 12.57907772064209
Loss: 12.57907772064209
Step 137:
Prev Loss: 12.562492370605469
Loss: 12.562492370605469
Step 138:
Prev Loss: 12.53976821899414
Loss: 12.53976821899414
Step 139:
Prev Loss: 12.51814079284668
Loss: 12.51814079284668
Step 140:
Prev Loss: 12.49918270111084
Loss: 12.49918270111084
Step 141:
Prev Loss: 12.481354713439941
Loss: 12.481354713439941
Step 142:
Prev Loss: 12.463284492492676
Loss: 12.463284492492676
Step 143:
Prev Loss: 12.443449974060059
Loss: 12.443449974060059
Step 144:
Prev Loss: 12.423367500305176
Loss: 12.423367500305176
Step 145:
Prev Loss: 12.404609680175781
Loss: 12.404609680175781
Step 146:
Prev Loss: 12.385761260

Step 258:
Prev Loss: 12.488736152648926
Loss: 12.488736152648926
Step 259:
Prev Loss: 11.914255142211914
Loss: 11.914255142211914
Step 260:
Prev Loss: 11.904149055480957
Loss: 11.904149055480957
Step 261:
Prev Loss: 11.839544296264648
Loss: 11.839544296264648
Step 262:
Prev Loss: 11.808177947998047
Loss: 11.808177947998047
Step 263:
Prev Loss: 11.786237716674805
Loss: 11.786237716674805
Step 264:
Prev Loss: 11.747186660766602
Loss: 11.747186660766602
Step 265:
Prev Loss: 11.719892501831055
Loss: 11.719892501831055
Step 266:
Prev Loss: 11.694402694702148
Loss: 11.694402694702148
Step 267:
Prev Loss: 11.671867370605469
Loss: 11.671867370605469
Step 268:
Prev Loss: 11.65311050415039
Loss: 11.65311050415039
Step 269:
Prev Loss: 11.642206192016602
Loss: 11.642206192016602
Step 270:
Prev Loss: 11.631991386413574
Loss: 11.631991386413574
Step 271:
Prev Loss: 11.62299919128418
Loss: 11.62299919128418
Step 272:
Prev Loss: 11.62265396118164
Loss: 11.62265396118164
Step 273:
Prev Loss: 11.6326961

Step 385:
Prev Loss: 12.319891929626465
Loss: 12.319891929626465
Step 386:
Prev Loss: 12.381745338439941
Loss: 12.381745338439941
Step 387:
Prev Loss: 12.447964668273926
Loss: 12.447964668273926
Step 388:
Prev Loss: 12.43948745727539
Loss: 12.43948745727539
Step 389:
Prev Loss: 12.370811462402344
Loss: 12.370811462402344
Step 390:
Prev Loss: 12.374898910522461
Loss: 12.374898910522461
Step 391:
Prev Loss: 12.637775421142578
Loss: 12.637775421142578
Step 392:
Prev Loss: 12.875200271606445
Loss: 12.875200271606445
Step 393:
Prev Loss: 13.190455436706543
Loss: 13.190455436706543
Step 394:
Prev Loss: 13.250569343566895
Loss: 13.250569343566895
Step 395:
Prev Loss: 14.330364227294922
Loss: 14.330364227294922
Step 396:
Prev Loss: 49.0704345703125
Loss: 49.0704345703125
Step 397:
Prev Loss: 22.916542053222656
Loss: 22.916542053222656
Step 398:
Prev Loss: 16.69277000427246
Loss: 16.69277000427246
Step 399:
Prev Loss: 15.057304382324219
Loss: 15.057304382324219
Step 400:
Prev Loss: 14.608222961

And we can obtain a baby grodu image that is the same style as a [Andy Warhol](https://en.wikipedia.org/wiki/Andy_Warhol) Painting:


Style Image                   | Transfer Image   | Content Image   
:-------------------------:|:-------------------------:|:-------------------------:
<img src="images/andyw.jpg"  width="256" height="256">  |  ![](images/output_700.jpg)  | <img src="images/original.jpeg"  width="256" height="256"> 

## 5. AI Distributed Training with Determined AI

Next, we use [determined](https://github.com/determined-ai/determined) package to enable distributed model training across clusters for more efficient training process and performance optimization. 

There are two ways I tried to initialize a determined AI project: WEB cluster interface vs Docker. The determined Version used in this project is:
```shell
$ det --version
det 0.21.1
```

In [21]:
! git clone https://github.com/determined-ai/determined

Cloning into 'determined'...
remote: Enumerating objects: 126324, done.[K
remote: Counting objects: 100% (1477/1477), done.[K
remote: Compressing objects: 100% (740/740), done.[K
remote: Total 126324 (delta 889), reused 1183 (delta 707), pack-reused 124847[K
Receiving objects: 100% (126324/126324), 121.64 MiB | 29.15 MiB/s, done.
Resolving deltas: 100% (97207/97207), done.


Following we will use **PyTorchTrial** based on [Determined AI PyTorch Trail](https://docs.determined.ai/latest/training/apis-howto/api-pytorch-ug.html) for specifying experiment configuration:

In [37]:
from typing import Any, Dict, Union, Sequence
from determined.pytorch import DataLoader, PyTorchTrial, PyTorchTrialContext
import os

TorchData = Union[Dict[str, torch.Tensor], Sequence[torch.Tensor], torch.Tensor]

class MyTrial(PyTorchTrial):
    def __init__(self, context: PyTorchTrialContext) -> None:
        self.context = context
        self.optimizer = optim.LBFGS([input_img])
        
    def build_model(self):
        model, content_losses, style_losses = get_model_and_losses(
    content_img, style_img, default_content_layers, default_style_layers)
        return model
    
    def build_training_data_loader(self) -> DataLoader:
        traindir = os.path.join(os.getcwd(), 'images')
        self.normalize = Normalization()

        train_dataset = datasets.ImageFolder(
            traindir,
            transforms.Compose([
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                self.normalize,
            ])
        )

        train_loader = determined.pytorch.DataLoader(
            train_dataset,
            batch_size=self.context.get_per_slot_batch_size(),
            shuffle=True,
            num_workers=self.context.get_hparam("workers", pin_memory=True),
        )
        
        return train_loader

    def build_validation_data_loader(self) -> DataLoader:
        return DataLoader()

    def train_batch(self, batch: TorchData, epoch_idx: int, batch_idx: int)  -> Dict[str, Any]:
        return {}

    def evaluate_batch(self, batch: TorchData) -> Dict[str, Any]:
        return {}

In [50]:
def main():
    with det.pytorch.init() as train_context:
        trial = MyTrial(train_context)
        trainer = det.pytorch.Trainer(trial, train_context)
        trainer.fit(
            max_length=pytorch.Epoch(1),
            checkpoint_period=pytorch.Batch(500),
            validation_period=pytorch.Batch(500),
            checkpoint_policy="all",
        )

if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO, format=det.LOG_FORMAT)
    main()

Use **local distributed training**, function initialization becomes: 

In [52]:
def main():
    dist.init_process_group(backend="gloo|nccl")  
    os.environ["USE_TORCH_DISTRIBUTED"] = "true"
    
    with det.pytorch.init(distributed=core.DistributedContext.from_torch_distributed()) as train_context:
        trial = MyTrial(train_context)
        trainer = det.pytorch.Trainer(trial, train_context)
        trainer.fit(
            max_length=pytorch.Epoch(1),
            checkpoint_period=pytorch.Batch(100),
            validation_period=pytorch.Batch(100),
            checkpoint_policy="all",
        )

To train the model, whether it's single instance or distributed, we can start using:

In [53]:
train_model=["name:image_style_transfer",
              "global_batch_size: 32",
               "dense1: 128",
               "name: single",
               "metric: val_accuracy",
               "epochs: 5",
               "entrypoint: model_def:image_style_transfer"]

## 6. References

- [PyTorch](https://pytorch.org/) including [TorchVision](https://pytorch.org/vision/)
- [Basic Setup - Determined AI Documentation](https://docs.determined.ai/latest/cluster-setup-guide/basic.html)
- [Install Determined Using Docker - DEtermined AI Documentation](https://www.googleadservices.com/pagead/aclk?sa=L&ai=DChcSEwjr-L31k7D-AhVFLK0GHcQMDlYYABAAGgJwdg&ohost=www.google.com&cid=CAESauD2a8aozeJBdbuvL_0G3p8oPtY-Z57CWfOI519uSb4QLuOJapvAsKL-hTA_kYiMZENmLoC-ZB4b7Y-QwvCWwrdlsfCE-nW2ATpYJhG8Id7saqRay842BZHpfXkN7T7PpmLaVUa0OYGjn9c&sig=AOD64_0HSgtNM8ymy1FKpSbmmeQorcDIdw&q&adurl&ved=2ahUKEwjWnrb1k7D-AhVQhu4BHciYBWkQ0Qx6BAgHEAE)
- [Image Style Transfer Using Convolutional Neural Networks](https://www.cv-foundation.org/openaccess/content_cvpr_2016/papers/Gatys_Image_Style_Transfer_CVPR_2016_paper.pdf)
- [A Keras Implementation of Image Style Transfer Using Convolutional Neural Networks, Gatys et al](https://github.com/superb20/Image-Style-Transfer-Using-Convolutional-Neural-Networks)
- [a PyTorch implementation of Image Style Transfer Using Convolutional Neural Networks](https://github.com/ali-gtw/ImageStyleTransfer-CNN)
- [Image Style Transfer using CNN](https://github.com/Suvoo/Image-Style-Transfer-Using-CNNs)
- [DL Style Transfer Demos](github.com/SingleZombie/DL-Demos/tree/master/dldemos/StyleTransfer)
- [PyTorch Style Transfer Official Tutorials](https://link.zhihu.com/?target=https%3A//pytorch.org/tutorials/advanced/neural_style_tutorial.html)