# Train `super-image` Models

---

[Github](https://github.com/eugenesiow/super-image) | All Models @ [huggingface.co](https://huggingface.co/models?filter=super-image) | All Datasets @ [huggingface datasets](https://huggingface.co/datasets?filter=task_ids:other-other-image-super-resolution)

---

This notebook trains `super-image` models for image super-resolution tasks.

The notebook is structured as follows:
* Setting up the Environment
* Loading and Augmenting the Dataset
* Training the Model

## Setting up the Environment

#### Ensure we have a GPU runtime

If you're running this notebook in Google Colab, select `Runtime` > `Change Runtime Type` from the menu bar. Ensure that `GPU` is selected as the `Hardware accelerator`. This lets us train the model on the GPU later on.
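To confirm from inside the notebook that the runtime change took effect, a quick check like the one below can help. This is a minimal sketch assuming training runs on PyTorch (which `super-image` builds on); `pick_device` is a hypothetical helper, not part of `super-image`.

```python
import importlib.util


def pick_device() -> str:
    """Hypothetical helper: report which device PyTorch would train on."""
    # If PyTorch is not installed yet, there is no GPU to query.
    if importlib.util.find_spec("torch") is None:
        return "cpu"
    import torch

    # torch.cuda.is_available() is True only when a CUDA GPU is visible.
    return "cuda" if torch.cuda.is_available() else "cpu"


print(pick_device())  # should print "cuda" on a GPU runtime
```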

#### Install the library

We will install the `super-image` and huggingface `datasets` libraries using `pip install`.

```python
!pip install -qq datasets super-image
```

## Loading and Augmenting the Dataset

We download the [`Div2k`](https://huggingface.co/datasets/eugenesiow/Div2k) dataset using the huggingface `datasets` library. You can explore more super-resolution datasets [here](https://huggingface.co/datasets?filter=task_ids:other-other-image-super-resolution).

We then follow the pre-processing and augmentation method of [Wang et al. (2021)](https://arxiv.org/abs/2104.07566). This will take a while.
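The general idea behind five-crop augmentation is to cut each image into four corner crops and one center crop; for super-resolution, the LR and HR crops must stay aligned, so the HR window is the LR window scaled by the upscaling factor. Below is a minimal NumPy sketch of that idea, not `super-image`'s `augment_five_crop` implementation: `five_crop_pair` is a hypothetical name, and using crops of half the LR height and width is an assumption made for illustration.

```python
import numpy as np


def five_crop_pair(lr, hr, scale):
    """Hypothetical sketch: cut an aligned LR/HR pair into 4 corner crops + 1 center crop.

    Assumes lr has shape (h, w, c) and hr has shape (h*scale, w*scale, c).
    Each crop is half the LR height/width (an assumption, not the library's setting).
    """
    h, w = lr.shape[:2]
    ch, cw = h // 2, w // 2
    # top-left corner of each crop, in LR pixel coordinates
    origins = [(0, 0), (0, w - cw), (h - ch, 0), (h - ch, w - cw),
               ((h - ch) // 2, (w - cw) // 2)]
    pairs = []
    for y, x in origins:
        lr_crop = lr[y:y + ch, x:x + cw]
        # the matching HR region is the same window scaled up by `scale`
        hr_crop = hr[y * scale:(y + ch) * scale, x * scale:(x + cw) * scale]
        pairs.append((lr_crop, hr_crop))
    return pairs


lr = np.zeros((48, 64, 3), dtype=np.uint8)
hr = np.zeros((192, 256, 3), dtype=np.uint8)
pairs = five_crop_pair(lr, hr, scale=4)
print(len(pairs), pairs[0][0].shape, pairs[0][1].shape)  # 5 (24, 32, 3) (96, 128, 3)
```

Each source pair yields five training pairs, which is why augmentation multiplies the size of the training set.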

```python
from datasets import load_dataset
from super_image.data import EvalDataset, TrainDataset, augment_five_crop

# download the training split and augment it with the five_crop method
augmented_dataset = load_dataset('eugenesiow/Div2k', 'bicubic_x4', split='train')\
    .map(augment_five_crop, batched=True, desc="Augmenting Dataset")
# prepare the train dataset for loading with a PyTorch DataLoader
train_dataset = TrainDataset(augmented_dataset)
# prepare the eval dataset (validation split) for the PyTorch DataLoader
eval_dataset = EvalDataset(load_dataset('eugenesiow/Div2k', 'bicubic_x4', split='validation'))
```

## Training the Model

We then train the model on the GPU.

**NOTE:** For a full training run, remember to set `num_train_epochs` to 1000 (or more). Here we set `num_train_epochs=5` for quick testing in the notebook.
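To get a feel for the difference in cost: in the original run of this notebook, each epoch was 4000 training iterations (as reported by the progress bars, with default training arguments). The small calculation below just multiplies that out; the 4000 figure depends on the dataset size and batch size, so treat it as an example value rather than a constant.

```python
STEPS_PER_EPOCH = 4000  # observed in this notebook's run; depends on batch size


def total_steps(num_train_epochs: int, steps_per_epoch: int = STEPS_PER_EPOCH) -> int:
    """Number of training iterations for a run of `num_train_epochs` epochs."""
    return num_train_epochs * steps_per_epoch


print(total_steps(5))     # 20000 (quick test run in this notebook)
print(total_steps(1000))  # 4000000 (recommended full run)
```

A full run is therefore 200 times longer than the quick test, so plan GPU time accordingly.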

```python
from super_image import Trainer, TrainingArguments, EdsrModel, EdsrConfig

training_args = TrainingArguments(
    output_dir='./results',              # output directory
    num_train_epochs=5,                  # total number of training epochs
)

config = EdsrConfig(
    scale=4,                             # upscale 4x, matching the 'bicubic_x4' dataset
)
model = EdsrModel(config)

trainer = Trainer(
    model=model,                         # the instantiated model to be trained
    args=training_args,                  # training arguments, defined above
    train_dataset=train_dataset,         # training dataset
    eval_dataset=eval_dataset            # evaluation dataset
)

trainer.train()
```

```
epoch: psnr: 27.624657   ssim: 0.7542
best epoch: 0, psnr: 27.624657, ssim: 0.754153
epoch: psnr: 28.291857   ssim: 0.7798
best epoch: 1, psnr: 28.291857, ssim: 0.779772
epoch: psnr: 28.557993   ssim: 0.7866
best epoch: 2, psnr: 28.557993, ssim: 0.786642
epoch: psnr: 28.008018   ssim: 0.7932
epoch: psnr: 28.679325   ssim: 0.7933
best epoch: 4, psnr: 28.679325, ssim: 0.793304
```


After training, the trainer reports the PSNR and SSIM scores of the best model on the validation set. After just 5 epochs, the best model reaches a PSNR of **28.68** and an SSIM of **0.7933**.
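The log also hints at how the best model is chosen: epoch 3 reached a higher SSIM (0.7932) than epoch 2 but a lower PSNR, and it was not marked as a best epoch. This is consistent with keeping the checkpoint with the highest validation PSNR. The sketch below reproduces that selection rule on the logged numbers; it is inferred from the log, not taken from `super-image`'s code.

```python
# (psnr, ssim) per epoch, copied from the EDSR training log above
history = [
    (27.624657, 0.754153),
    (28.291857, 0.779772),
    (28.557993, 0.786642),
    (28.008018, 0.7932),   # only 4 decimals of SSIM were logged for this epoch
    (28.679325, 0.793304),
]


def improving_epochs(history):
    """Epochs at which validation PSNR strictly improves on all earlier epochs."""
    best_psnr = float("-inf")
    marked = []
    for epoch, (psnr, _ssim) in enumerate(history):
        if psnr > best_psnr:
            best_psnr = psnr
            marked.append(epoch)
    return marked


def best_epoch_by_psnr(history):
    """Return (epoch, psnr, ssim) for the epoch with the highest PSNR."""
    epoch = max(range(len(history)), key=lambda i: history[i][0])
    return (epoch, *history[epoch])


print(improving_epochs(history))    # [0, 1, 2, 4]
print(best_epoch_by_psnr(history))  # (4, 28.679325, 0.793304)
```

The epochs `[0, 1, 2, 4]` match the `best epoch` lines in the log, which is what we would expect if PSNR is the selection metric.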

We can also train other types of [models](https://eugenesiow.github.io/super-image/models/msrn). For example, here we train an MSRN model with balanced attention enabled.

```python
from super_image import Trainer, TrainingArguments, MsrnModel, MsrnConfig

training_args = TrainingArguments(
    output_dir='./results_msrn',         # output directory
    num_train_epochs=2,                  # total number of training epochs
)

config = MsrnConfig(
    scale=4,                             # train a model to upscale 4x
    bam=True,                            # use balanced attention
)
model = MsrnModel(config)

trainer = Trainer(
    model=model,                         # the instantiated model to be trained
    args=training_args,                  # training arguments, defined above
    train_dataset=train_dataset,         # training dataset
    eval_dataset=eval_dataset            # evaluation dataset
)

trainer.train()
```

```
epoch: psnr: 26.820234   ssim: 0.7262
best epoch: 0, psnr: 26.820234, ssim: 0.726203
epoch: psnr: 27.123207   ssim: 0.7340
best epoch: 1, psnr: 27.123207, ssim: 0.734009
```


After training for just 2 epochs, the best model reaches a PSNR of **27.12** and an SSIM of **0.7340**.
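For reference, PSNR compares a super-resolved image with its ground-truth HR image through the mean squared error: PSNR = 10 · log10(MAX² / MSE), where MAX is the largest possible pixel value (255 for 8-bit images). The NumPy sketch below computes it over whole arrays. `super-image`'s own evaluation may differ in details (for example, computing it on the luminance channel or cropping image borders), so this sketch's numbers are not directly comparable with the logs above.

```python
import numpy as np


def psnr(sr: np.ndarray, hr: np.ndarray, max_value: float = 255.0) -> float:
    """Peak signal-to-noise ratio in dB between two images of the same shape."""
    # cast to float so uint8 subtraction cannot wrap around
    diff = sr.astype(np.float64) - hr.astype(np.float64)
    mse = np.mean(diff ** 2)
    if mse == 0:
        return float("inf")  # identical images
    return float(10 * np.log10(max_value ** 2 / mse))


hr = np.zeros((8, 8, 3), dtype=np.uint8)
sr = np.full((8, 8, 3), 10, dtype=np.uint8)  # every pixel off by 10, so MSE = 100
print(round(psnr(sr, hr), 2))  # 28.13
```

Higher is better: halving the MSE adds about 3 dB, so differences of a few tenths of a dB between epochs are meaningful improvements.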