In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
from pathlib import Path
import warnings

import lightning as pl
from PIL import Image
import torch
from torch.utils.data import DataLoader

from dataset import WatermarkedDataset
from trainer import DifficultyScheduler, WatermarkRemovalModel

In [None]:
# Silence Pillow's ``DecompressionBombWarning`` (emitted by PIL for very large images) so that
# it does not flood the notebook output.
warnings.filterwarnings('ignore', category=Image.DecompressionBombWarning)

In [3]:
# Set PyTorch's global float32 matrix-multiplication precision to 'high', which permits
# reduced-precision internal formats (e.g. TF32 on supporting GPUs) in exchange for speed.
torch.set_float32_matmul_precision(precision='high')

In [4]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU. This string is later
# passed to the Lightning trainer as its ``accelerator``.
device = 'cuda' if torch.cuda.is_available() else 'cpu'


# Display the selected device as the cell output.
device

'cuda'

In [5]:
def calculate_difficulty_step(
    initial_difficulty: float,
    max_difficulty: float,
    max_epochs: int,
) -> float:
    """
    Calculate the step size for difficulty increase based on epochs.

    The total increase (``max_difficulty - initial_difficulty``) is divided evenly by
    ``max_epochs``, so adding the returned step once per epoch for ``max_epochs`` epochs covers the
    whole gap (up to floating-point rounding).

    Parameters
    ----------
    initial_difficulty: float
        Starting difficulty level between ``0`` and ``1``
    max_difficulty: float
        Target difficulty level between ``0`` and ``1``
    max_epochs: int
        Number of epochs to train for. Must be positive.

    Returns
    -------
    step_size: float
        Amount to increase difficulty by each epoch

    Raises
    ------
    ValueError
        If ``max_epochs`` is zero or negative, since no meaningful per-epoch step exists

    """
    # Guard explicitly: ``0`` would otherwise surface as an unexplained ``ZeroDivisionError`` and a
    # negative value (e.g. a ``-1`` "no limit" sentinel) would silently flip the step's sign.
    if max_epochs <= 0:
        raise ValueError(f'`max_epochs` must be a positive number, got {max_epochs!r}.')

    total_difficulty_increase = max_difficulty - initial_difficulty
    return total_difficulty_increase / max_epochs

In [6]:
# Difficulty curriculum settings: the datasets start at ``initial_difficulty`` and the per-epoch
# step is derived from the gap to ``max_difficulty`` spread over ``max_epochs``.
initial_difficulty = 0.25
max_difficulty = 0.95
# Also used as the trainer's ``max_epochs``.
max_epochs = 75

# Forwarded to ``WatermarkedDataset`` as ``image_size``.
# NOTE(review): presumably the side length in pixels of the training images - confirm in dataset.py.
image_size = 512

In [7]:
# Calculate the per-epoch difficulty increment from the hyperparameters above; the result is
# handed to ``DifficultyScheduler`` as its ``step_size``.
difficulty_step = calculate_difficulty_step(
    initial_difficulty=initial_difficulty,
    max_difficulty=max_difficulty,
    max_epochs=max_epochs,
)

In [8]:
# Root folder of the local ImageNet-1k copy. Set the ``IMAGENET_ROOT`` environment variable to
# run this notebook on another machine; the fallback is the original hard-coded location.
imagenet_root = os.environ.get(
    'IMAGENET_ROOT',
    '/home/nathancooperjones/Desktop/imagenet-1k',
).rstrip('/')

# Settings shared by the train and validation datasets.
common_dataset_kwargs = {
    'difficulty': initial_difficulty,
    'image_size': image_size,
}
# Settings shared by the train and validation dataloaders.
common_dataloader_kwargs = {
    'batch_size': 24,
    # ``os.cpu_count()`` may return ``None`` when the count cannot be determined, which
    # ``DataLoader`` does not accept - fall back to loading in the main process in that case.
    'num_workers': os.cpu_count() or 0,
}

train_dataset = WatermarkedDataset(
    root_dir=f'{imagenet_root}/train_images/',
    **common_dataset_kwargs,
)
val_dataset = WatermarkedDataset(
    root_dir=f'{imagenet_root}/test_images/',
    **common_dataset_kwargs,
)

# Shuffle only the training data; validation keeps a fixed order.
train_dataloader = DataLoader(
    dataset=train_dataset,
    shuffle=True,
    **common_dataloader_kwargs,
)
val_dataloader = DataLoader(
    dataset=val_dataset,
    shuffle=False,
    **common_dataloader_kwargs,
)


# Display the number of batches per epoch for each split as a quick sanity check.
len(train_dataloader), len(val_dataloader)

(1875, 209)

In [9]:
# Instantiate the project's ``WatermarkRemovalModel`` (imported from ``trainer``) with a learning
# rate of 5e-4; every other constructor argument keeps its default.
model = WatermarkRemovalModel(
    learning_rate=5e-4,
)

In [10]:
# Create the Lightning trainer: CSV logging, top-3 checkpointing, early stopping, and the
# project's difficulty curriculum callback, running with 16-bit mixed precision.
#
# NOTE(review): both ``ModelCheckpoint`` and ``EarlyStopping`` monitor ``val_total_loss``, which is
# presumably logged by ``WatermarkRemovalModel`` - confirm the metric name in trainer.py.
# NOTE(review): if ``DifficultyScheduler`` also raises the validation difficulty each epoch, the
# monitored loss may rise for reasons unrelated to overfitting and trip ``patience=3`` early -
# confirm how the callback treats the validation dataset.
trainer = pl.Trainer(
    max_epochs=max_epochs,
    accelerator=device,
    # Metrics are written as CSV files under ``logs/``.
    logger=pl.pytorch.loggers.CSVLogger(save_dir='logs'),
    enable_checkpointing=True,
    # Favor throughput over bit-for-bit reproducibility.
    benchmark=True,
    deterministic=False,
    callbacks=[
        # Keep the three checkpoints with the lowest validation loss under ``checkpoints/``.
        pl.pytorch.callbacks.ModelCheckpoint(
            dirpath='checkpoints',
            filename='watermark-removal-{epoch:02d}-{val_total_loss:.2f}',
            monitor='val_total_loss',
            mode='min',
            save_top_k=3,
        ),
        # Stop training once the validation loss has not improved for three checks in a row.
        pl.pytorch.callbacks.EarlyStopping(
            monitor='val_total_loss',
            patience=3,
            mode='min',
        ),
        # Project callback configured with the curriculum hyperparameters computed above.
        DifficultyScheduler(
            initial_difficulty=initial_difficulty,
            max_difficulty=max_difficulty,
            step_size=difficulty_step,
        ),
    ],
    precision='16-mixed',
)

Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [None]:
# Train the model, validating on ``val_dataloader``; the callbacks configured on the trainer
# (checkpointing, early stopping, difficulty scheduling) run as part of this call.
trainer.fit(
    model=model,
    train_dataloaders=train_dataloader,
    val_dataloaders=val_dataloader,
)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

   | Name       | Type       | Params | Mode 
---------------------------------------------------
0  | enc_1      | DoubleConv | 39.0 K | train
1  | enc_2      | DoubleConv | 221 K  | train
2  | enc_3      | DoubleConv | 886 K  | train
3  | enc_4      | DoubleConv | 3.5 M  | train
4  | enc_5      | DoubleConv | 14.2 M | train
5  | dec_5      | DoubleConv | 7.1 M  | train
6  | dec_4      | DoubleConv | 3.0 M  | train
7  | dec_3      | DoubleConv | 738 K  | train
8  | dec_2      | DoubleConv | 184 K  | train
9  | dec1       | DoubleConv | 110 K  | train
10 | pool       | MaxPool2d  | 0      | train
11 | upsample   | Upsample   | 0      | train
12 | final_conv | Conv2d     | 195    | train
13 | activation | Sigmoid    | 0      | train
14 | vgg        | Sequential | 1.7 M  | eval 
---------------------------------------------------
29.9 M    Trainable params
1.7 M     Non-trainable params
31.7 M    Total params
126.609   Total estimated model para

Sanity Checking: |                                                                                            …

Training: |                                                                                                   …



Validation: |                                                                                                 …



Validation: |                                                                                                 …



----- 