In [None]:
import torch
import torchvision

In [None]:
import os
from tqdm import tqdm

# Dataset

In [None]:
from torch.utils.data import Dataset, DataLoader
import json

from torchvision import transforms 
import torchvision.transforms.functional as TF

import random


class ShabbyPagesDataset(Dataset):
    """Paired dataset of degraded ("shabby") page images and their clean targets.

    Inputs are read from ``{split}/{split}/{split}_shabby/`` and targets from
    ``{split}/{split}/{split}_cleaned/``; both images of a pair share one file name.

    Constructing the dataset also writes the index -> file-name mapping to
    ``{split}_image_dict.json`` in the current working directory.

    Args:
        split: Name of the split, used to build the folder and JSON file names.
        augment: When True, ``__getitem__`` applies the same random crop, flips
            and rotation to the input and the target.
        crop_ratio: Side of the random square crop, as a fraction of 400 pixels.
        rotation_angle: Largest rotation magnitude in degrees; the angle is an
            integer drawn uniformly from ``[-rotation_angle, +rotation_angle]``.
        flip_probability: Probability of each of the horizontal and vertical flips.
    """

    def __init__(self, split="train", augment=False, crop_ratio=0.8, rotation_angle=90, flip_probability=0.5):
        super().__init__()

        self.augment = augment
        self.crop_ratio = crop_ratio
        self.rotation_angle = rotation_angle
        self.flip_probability = flip_probability

        self.folder_shabby = f"{split}/{split}/{split}_shabby/"
        self.folder_clean = f"{split}/{split}/{split}_cleaned/"

        # os.listdir returns entries in arbitrary order; sorting keeps the
        # index -> file mapping (and the JSON dumped below) stable across runs.
        self.image_dict = dict(enumerate(sorted(os.listdir(self.folder_shabby))))

        with open(f"{split}_image_dict.json", "w") as f:
            json.dump(self.image_dict, f)

    def transform(self, input_img, target_img):
        """Apply one identical random augmentation to an input/target pair.

        Returns:
            Tuple ``(input_img, target_img)`` after crop, flips and rotation.
        """
        # Random crop, with the same window for both images.
        # NOTE(review): 400 is presumably the side length of the source images - confirm.
        crop_size = int(400 * self.crop_ratio)
        top, left, height, width = transforms.RandomCrop.get_params(
            input_img, output_size=(crop_size, crop_size))
        input_img = TF.crop(input_img, top, left, height, width)
        target_img = TF.crop(target_img, top, left, height, width)

        # Random horizontal flipping
        if random.random() < self.flip_probability:
            input_img = TF.hflip(input_img)
            target_img = TF.hflip(target_img)

        # Random vertical flipping
        if random.random() < self.flip_probability:
            input_img = TF.vflip(input_img)
            target_img = TF.vflip(target_img)

        # Random rotation, bounded by the configured rotation_angle
        # (previously hard-coded to +/-90 regardless of the argument).
        max_angle = abs(int(self.rotation_angle))
        angle = random.randint(-max_angle, max_angle)
        input_img = TF.rotate(input_img, angle)
        target_img = TF.rotate(target_img, angle)

        return input_img, target_img

    def __len__(self):
        """Return the number of files found in the shabby folder."""
        return len(self.image_dict)

    def __getitem__(self, idx):
        """Load pair ``idx`` as float32 tensors, augmented when ``self.augment`` is set."""
        file_name = self.image_dict[idx]

        input_sample = torchvision.io.read_image(
            os.path.join(self.folder_shabby, file_name)
        ).to(dtype=torch.float32)
        target_sample = torchvision.io.read_image(
            os.path.join(self.folder_clean, file_name)
        ).to(dtype=torch.float32)

        if self.augment:
            return self.transform(input_sample, target_sample)
        return input_sample, target_sample

In [None]:
# Training data: augmented, and reshuffled every epoch.
train_dataset = ShabbyPagesDataset(split="train", augment=True)
train_dataloader = DataLoader(
    dataset=train_dataset, batch_size=16, shuffle=True, num_workers=8
)

# Evaluation data must come from held-out splits. Previously both datasets below
# were built from split="train", so "mse_val" (and therefore early stopping and
# checkpoint selection) was measured on the training images themselves.
# NOTE(review): assumes test/test/test_shabby/ and validate/validate/validate_shabby/
# exist next to the train folders, following the same path pattern - confirm.
# NOTE(review): iterating test_dataloader also needs test/test/test_cleaned/ - confirm.
# Shuffling is disabled because evaluation order does not need to be random.
test_dataset = ShabbyPagesDataset(split="test")
test_dataloader = DataLoader(
    dataset=test_dataset, batch_size=16, shuffle=False, num_workers=8
)


validate_dataset = ShabbyPagesDataset(split="validate")
validate_dataloader = DataLoader(
    dataset=validate_dataset, batch_size=16, shuffle=False, num_workers=8
)

# Deep learning model

In [None]:
from typing import Any
from torch.nn import (
    Conv2d,
    ConvTranspose2d,
    MaxPool2d,
    MaxUnpool2d,
    BatchNorm2d,
    MSELoss,
    ReLU,
    Sigmoid,
    Dropout2d
)

import lightning


class DenoisingNet(lightning.LightningModule):
    """Small convolutional encoder/decoder that maps a noisy page to a clean one.

    The network takes single-channel images with pixel values on a 0-255 scale
    and returns an image of the same spatial size on the same scale: every
    convolution uses kernel size 5 with padding 2, and the final sigmoid output
    is multiplied by 255.

    Args:
        learning_rate: Learning rate handed to the Adam optimizer.
        dropout: Channel-dropout probability used after the hidden convolutions.
    """

    def __init__(self, learning_rate=0.005, dropout=0.3):
        super().__init__()

        self.learning_rate = learning_rate

        # The layer order and count are kept fixed so that the state_dict keys
        # (encoder.0 ... encoder.8, decoder.0 ... decoder.6) stay stable.
        self.encoder = torch.nn.Sequential(
            BatchNorm2d(num_features=1),
            Conv2d(in_channels=1, out_channels=8, kernel_size=5, padding=2),
            BatchNorm2d(num_features=8),
            Dropout2d(p=dropout),
            ReLU(),
            Conv2d(in_channels=8, out_channels=16, kernel_size=5, padding=2),
            BatchNorm2d(num_features=16),
            Dropout2d(p=dropout),
            ReLU(),
        )

        self.decoder = torch.nn.Sequential(
            ConvTranspose2d(in_channels=16, out_channels=8, kernel_size=5, padding=2),
            BatchNorm2d(num_features=8),
            Dropout2d(p=dropout),
            ReLU(),
            ConvTranspose2d(in_channels=8, out_channels=1, kernel_size=5, padding=2),
            BatchNorm2d(num_features=1),
            Sigmoid(),
        )

        self.mse_loss = MSELoss()

    def forward(self, x):
        """Denoise a batch: rescale to 0-1, encode, decode, rescale back to 0-255."""
        x = x / 255.0
        x = self.encoder(x)
        x = self.decoder(x)
        x = x * 255.0

        return x

    def _shared_step(self, batch, metric_name):
        """Compute the MSE for one (input, target) batch and log it per epoch."""
        x, y = batch
        # Call the module rather than .forward() so registered hooks still run.
        y_hat = self(x)
        loss = self.mse_loss(y_hat, y)
        self.log(metric_name, loss, prog_bar=True, on_epoch=True, on_step=False)
        return loss

    def training_step(self, batch, batch_idx):
        """Return the training loss for one batch; logged as "mse_train"."""
        return self._shared_step(batch, "mse_train")

    def validation_step(self, batch, batch_idx):
        """Log the validation loss for one batch as "mse_val"."""
        self._shared_step(batch, "mse_val")

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters, using ``self.learning_rate``."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer


# Instantiate the denoising network; this is the instance that the later
# cells run a batch through and pass to trainer.fit.
model = DenoisingNet()

In [None]:
# Smoke test: push one training batch through the untrained model and check
# that the output has the same shape as the input.
sample_inputs, _ = next(iter(train_dataloader))

# eval() keeps this dry run from updating the BatchNorm running statistics,
# and no_grad() avoids building an autograd graph that is never used.
model.eval()
with torch.no_grad():
    sample_outputs = model(sample_inputs)
model.train()  # restore the default training mode

assert sample_outputs.shape == sample_inputs.shape, (
    f"expected output shape {tuple(sample_inputs.shape)}, "
    f"got {tuple(sample_outputs.shape)}"
)

In [None]:
from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping
from lightning.pytorch.loggers import WandbLogger


# Early stopping watches the epoch-level "mse_val" metric (lower is better)
# with a patience of 25.
early_stopping = EarlyStopping(monitor="mse_val", mode="min", patience=25)

# Checkpointing tracks the same metric and writes under saved_models/.
checkpoint_callback = ModelCheckpoint(
    monitor="mse_val",
    mode="min",
    dirpath="saved_models/",
    filename="best_model",
)

# Metrics are logged to the "shabby-pages" Weights & Biases project.
wandb_logger = WandbLogger(project="shabby-pages")

# Mixed 16-bit precision, capped at 250 epochs.
trainer = lightning.Trainer(
    logger=wandb_logger,
    callbacks=[early_stopping, checkpoint_callback],
    precision="16-mixed",
    max_epochs=250,
)

In [None]:


# Train the model: optimise on train_dataloader and run validation on
# validate_dataloader, using the trainer (and its callbacks) configured above.

trainer.fit(model=model, train_dataloaders=train_dataloader, val_dataloaders=validate_dataloader)