In [None]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, ConcatDataset, Subset
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
from PIL import Image, ImageEnhance, ImageChops
import numpy as np
from tqdm.notebook import tqdm
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
from glob import glob
import re
from pytorch_msssim import SSIM
import random

In [None]:
# Training hyperparameters
NUM_EPOCHS = 300        # upper bound on epochs; training resumes after the last saved checkpoint
BATCH_SIZE = 2          # images per optimizer step
LEARNING_RATE = 1e-3    # Adam learning rate
print_num = 500         # show sample images every `print_num` training batches

In [None]:
# Functions
def setup_paths():
    """Return the project's path configuration and make sure output folders exist.

    The returned dict holds the fixed folder names used throughout the notebook
    plus two lists, 'original_watermarked_paths' and 'original_unwatermarked_paths',
    containing every entry of the corresponding data folder (joined with the folder
    path, in os.listdir order, with no filtering).

    <Image_Folder>/<SubFolder>, <Model_Folder>/<SubFolder> and
    <Result_Folder>/<SubFolder> are created if they are missing.

    Raises:
        FileNotFoundError: if either data folder does not exist (raised by
            os.listdir before any output folder is created).
    """
    watermarked_dir = 'data/Actual_Images/Low_Res_Watermark'
    unwatermarked_dir = 'data/Actual_Images/Low_Res_Shifted'

    def list_entries(directory):
        # Relative path of every entry in `directory`, in os.listdir order.
        return [os.path.join(directory, entry) for entry in os.listdir(directory)]

    paths = dict(
        Image_Folder='Images',
        Model_Folder='models',
        Result_Folder='Results',
        Data_Folder='data',
        Data_SubFolder='Church_Music_Images_Downscaled',
        SubFolder="Watermark_Removal",
        original_watermarked_path=watermarked_dir,
        original_unwatermarked_path=unwatermarked_dir,
    )
    paths['original_watermarked_paths'] = list_entries(watermarked_dir)
    paths['original_unwatermarked_paths'] = list_entries(unwatermarked_dir)

    output_roots = ('Image_Folder', 'Model_Folder', 'Result_Folder')
    for root_key in output_roots:
        os.makedirs(os.path.join(paths[root_key], paths['SubFolder']), exist_ok=True)

    return paths

def prepare_datasets(dirs, test_size=0.1, random_state=42):
    """Pair clean/watermarked images and split them into train/validation datasets.

    Both path lists are sorted by file basename and paired position by position,
    so the i-th clean image (by name) is matched with the i-th watermarked image.

    Args:
        dirs: dict with 'original_unwatermarked_paths' and
            'original_watermarked_paths' lists (as produced by setup_paths()).
        test_size: fraction (float) or absolute number (int) of pairs held out for
            validation, forwarded to sklearn's train_test_split.
        random_state: seed for the split. It is fixed by default so that the split
            is the same on every run. Without it, a run that resumes from a saved
            checkpoint would re-split the data and train on images that were
            previously used for validation. Pass None for an unseeded split.

    Returns:
        (train_dataset, val_dataset) as OriginalWatermarkDataset instances.

    Raises:
        ValueError: if the two folders contain a different number of files. zip()
            would otherwise silently drop the surplus and misalign the pairs.
    """
    original_unwatermarked_paths = sorted(dirs['original_unwatermarked_paths'], key=os.path.basename)
    original_watermarked_paths = sorted(dirs['original_watermarked_paths'], key=os.path.basename)

    if len(original_unwatermarked_paths) != len(original_watermarked_paths):
        raise ValueError(
            f"Found {len(original_unwatermarked_paths)} unwatermarked images but "
            f"{len(original_watermarked_paths)} watermarked images; they must pair one-to-one."
        )

    pairs = list(zip(original_unwatermarked_paths, original_watermarked_paths))
    train_pairs, val_pairs = train_test_split(pairs, test_size=test_size, random_state=random_state)

    # zip(*pairs) transposes [(clean, marked), ...] into (clean_paths, marked_paths).
    train_dataset = OriginalWatermarkDataset(*zip(*train_pairs))
    val_dataset = OriginalWatermarkDataset(*zip(*val_pairs))

    return train_dataset, val_dataset

def create_datasets(train_dataset, val_dataset, batch_size):
    """Wrap the train and validation datasets in DataLoaders.

    Both loaders use the same batch size and both shuffle. Because the validation
    loader shuffles, display_images (which takes the first batch of the loader it
    receives) shows a different validation sample each time it is called.

    Returns:
        (train_dataloader, val_dataloader)
    """
    loader_options = {'batch_size': batch_size, 'shuffle': True}
    loaders = tuple(DataLoader(dataset, **loader_options) for dataset in (train_dataset, val_dataset))
    return loaders

def PIL_to_tensor(path):
    """Load an image file as a single-channel float tensor.

    The image is converted to 8-bit grayscale ('L') and passed through
    torchvision's ToTensor, giving a (1, H, W) tensor scaled to [0, 1].
    """
    with Image.open(path) as source_image:
        grayscale = source_image.convert('L')
    return transforms.ToTensor()(grayscale)

def find_last_saved_epoch(paths):
    """Locate the checkpoint with the highest epoch number.

    Looks in <Model_Folder>/<SubFolder> for files named exactly
    ``model_epoch_<N>.pth``. The newest checkpoint is chosen by its epoch number N,
    not by file timestamps: ctime changes when files are copied, restored or
    touched, which could resume training from the wrong epoch. Files that match the
    glob but do not carry a numeric epoch (e.g. ``model_epoch_old.pth``) are ignored.

    Returns:
        (epoch_number, checkpoint_path) for the highest epoch found, or (-1, None)
        if there is no checkpoint, so that callers start at epoch -1 + 1 = 0.
    """
    checkpoint_dir = os.path.join(paths['Model_Folder'], paths['SubFolder'])
    epoch_pattern = re.compile(r'model_epoch_(\d+)\.pth')

    candidates = []
    for checkpoint_path in glob(os.path.join(checkpoint_dir, 'model_epoch_*.pth')):
        # Match against the basename only, so directory names cannot interfere.
        match = epoch_pattern.fullmatch(os.path.basename(checkpoint_path))
        if match:
            candidates.append((int(match.group(1)), checkpoint_path))

    if not candidates:
        return -1, None
    return max(candidates)

def display_images(original, watermarked, outputs, dataloader, model, device):
    """Plot a training sample next to a freshly processed validation sample.

    Draws six panels: Train / Train Watermarked / Train Processed, taken from
    sample 0 of the current training batch, and Test / Test Watermarked /
    Test Processed, taken from sample 0 of the first batch of `dataloader`. For the
    test sample the model is run in eval + inference mode. Every panel shows the
    first channel of its tensor on a fixed [0, 1] gray scale, so the panels are
    directly comparable.

    Args:
        original, watermarked, outputs: training tensors of shape (B, C, H, W).
            train_model passes `original`/`outputs` already repeated to 3 channels.
            Only the first channel is shown, which is identical in that case.
        dataloader: yields (unwatermarked, watermarked) batches.
        model: the network. It is temporarily switched to eval mode and then
            restored to whatever mode it was in before.
        device: device the model lives on.
    """
    def to_2d(image):
        # (C, H, W) -> (H, W) numpy array for imshow, using the first channel.
        return image[0].detach().cpu().numpy()

    test_clean_batch, test_marked_batch = next(iter(dataloader))
    # Keep the batch axis: (1, C, H, W). The samples are fed to the model with the
    # channel count the dataset produces (no RGB replication), exactly like the
    # training loop does. This also works when the batch holds a single image.
    test_clean = test_clean_batch[:1].to(device)
    test_marked = test_marked_batch[:1].to(device)

    was_training = model.training
    model.eval()
    try:
        with torch.inference_mode():
            test_output = model(test_marked)
    finally:
        model.train(was_training)

    panels = [
        ("Train", original[0]),
        ("Train Watermarked", watermarked[0]),
        ("Train Processed", outputs[0]),
        ("Test", test_clean[0]),
        ("Test Watermarked", test_marked[0]),
        ("Test Processed", test_output[0]),
    ]
    fig, axes = plt.subplots(1, len(panels), figsize=(50, 30))
    for ax, (title, image) in zip(axes, panels):
        ax.imshow(to_2d(image), cmap='gray', vmin=0, vmax=1)
        ax.set_title(title)
    plt.show()
    # Release the figure; this function is called repeatedly during training.
    plt.close(fig)

def _match_rgb(tensor):
    """Repeat a (B, 1, H, W) batch to (B, 3, H, W); other channel counts pass through unchanged."""
    if tensor.shape[1] == 1:
        return tensor.repeat(1, 3, 1, 1)
    return tensor


def train_model(train_dataloader, val_dataloader, model, optimizer, loss_fn, num_epochs, batch_size, paths, device, print_num):
    """Train `model`, validating and checkpointing once per epoch.

    Training resumes from the highest-numbered checkpoint returned by
    find_last_saved_epoch (if any) and runs epochs start+1 .. num_epochs-1. After
    each epoch the model's state_dict is saved to
    <Model_Folder>/<SubFolder>/model_epoch_{epoch}.pth.

    The loss is computed on 3-channel tensors (single-channel outputs/targets are
    repeated via _match_rgb) because CombinedLoss feeds them through VGG19.

    TensorBoard scalars:
        Loss/train, Loss/validation               per-batch loss, step = global batch index
        Loss/train_epoch, Loss/validation_epoch   mean loss per epoch, step = epoch
    The per-epoch means get their own tags. If they shared a tag with the per-batch
    series, their step axis (the epoch number) would collide with the batch-index
    axis and corrupt the chart.

    Args:
        train_dataloader, val_dataloader: yield (unwatermarked, watermarked) batches.
        model: network mapping watermarked -> unwatermarked images.
        optimizer: optimizer over model.parameters().
        loss_fn: callable(outputs, target) -> scalar loss tensor.
        num_epochs: exclusive upper bound on the epoch index.
        batch_size: unused; kept for interface compatibility.
        paths: dict with 'Model_Folder' and 'SubFolder' (see setup_paths).
        device: device to move batches to.
        print_num: call display_images every `print_num` training batches.
    """
    writer = SummaryWriter()
    try:
        start_epoch, last_saved_model = find_last_saved_epoch(paths)
        if last_saved_model:
            # map_location lets a checkpoint written during a CUDA run be resumed on a
            # CPU-only machine (and vice versa).
            # NOTE(review): only model weights are checkpointed, so optimizer state
            # (e.g. Adam moment estimates) restarts from scratch on resume.
            model.load_state_dict(torch.load(last_saved_model, map_location=device))
            print(f"Resuming training from epoch {start_epoch + 1} using {last_saved_model}")

        train_loss = 0
        val_loss = 0

        for epoch in tqdm(range(start_epoch + 1, num_epochs)):

            # Training
            model.train()
            train_loss = 0
            for batch_idx, (original, watermarked) in tqdm(enumerate(train_dataloader), total=len(train_dataloader)):
                original, watermarked = original.to(device), watermarked.to(device)
                outputs = model(watermarked)
                original = _match_rgb(original)
                outputs = _match_rgb(outputs)
                loss = loss_fn(outputs, original)
                batch_loss = loss.item()
                train_loss += batch_loss
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                writer.add_scalar('Loss/train', batch_loss, epoch * len(train_dataloader) + batch_idx)

                print(f"\r Epoch: {epoch} | Batch: {batch_idx} | Train Batch loss: {batch_loss:.8f}", end="")
                if batch_idx % print_num == 0:
                    display_images(original, watermarked, outputs, val_dataloader, model, device)

            train_loss /= len(train_dataloader)
            print(f"\n Epoch: {epoch} | Train loss: {train_loss:.8f}")
            writer.add_scalar('Loss/train_epoch', train_loss, epoch)

            # Validation
            model.eval()
            val_loss = 0
            with torch.inference_mode():
                for batch_idx, (original, watermarked) in tqdm(enumerate(val_dataloader), total=len(val_dataloader)):
                    original, watermarked = original.to(device), watermarked.to(device)
                    outputs = model(watermarked)
                    loss = loss_fn(_match_rgb(outputs), _match_rgb(original))
                    batch_loss = loss.item()
                    val_loss += batch_loss
                    writer.add_scalar('Loss/validation', batch_loss, epoch * len(val_dataloader) + batch_idx)
                    print(f"\r Epoch: {epoch} | Batch: {batch_idx} | Val Batch loss: {batch_loss:.8f}", end="")

            val_loss /= len(val_dataloader)
            print(f"\n Epoch: {epoch} | Val loss: {val_loss:.8f}")
            writer.add_scalar('Loss/validation_epoch', val_loss, epoch)

            model_path = os.path.join(paths['Model_Folder'], paths['SubFolder'], f'model_epoch_{epoch}.pth')
            torch.save(model.state_dict(), model_path)

            print(f"\n Epoch: {epoch} | Train loss: {train_loss:.8f} | Val loss: {val_loss:.8f}")

        print(f"Training Complete: Train loss: {train_loss:.8f} | Val loss: {val_loss:.8f}\n")
    finally:
        # Flush and release the event file even if training is interrupted or crashes.
        writer.close()

In [None]:
# Classes
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import vgg19
# from torchvision.models.vgg import VGG19Weights  # Import the weights enum

class UNet(nn.Module):
    """Five-level U-Net that maps an image to a same-sized image clamped to [0, 1].

    Encoder: five double-conv blocks (32 -> 512 channels), each after a 2x2 max-pool
    (except the first). The bottleneck double-conv (1024 channels) is followed by a
    stride-2 transposed conv back to 512 channels. That result is added to the
    deepest encoder features and then also concatenated with them. Decoder: five
    double-conv blocks. Each one nearest-neighbour interpolates the previous stage
    to the matching encoder resolution and concatenates the encoder features. A
    final 1x1 conv produces `out_channels` maps.

    Args:
        in_channels: channels of the input image (default 1, grayscale).
        out_channels: channels of the output image (default 1, grayscale).
    """

    def __init__(self, in_channels=1, out_channels=1):
        super(UNet, self).__init__()

        # Encoder
        self.enc1 = self.conv_block(in_channels, 32)
        self.enc2 = self.conv_block(32, 64)
        self.enc3 = self.conv_block(64, 128)
        self.enc4 = self.conv_block(128, 256)
        self.enc5 = self.conv_block(256, 512)

        # Middle (bottleneck + 2x upsampling back to enc5's channel count)
        self.middle = nn.Sequential(
            self.conv_block(512, 1024),
            nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
        )

        # Decoder: input channels = upsampled features + skip features
        self.dec5 = self.conv_block(512 + 512, 512)
        self.dec4 = self.conv_block(512 + 256, 256)
        self.dec3 = self.conv_block(256 + 128, 128)
        self.dec2 = self.conv_block(128 + 64, 64)
        self.dec1 = self.conv_block(64 + 32, 32)

        # Final Layer
        self.final_conv = nn.Conv2d(32, out_channels, kernel_size=1)

    def conv_block(self, in_channels, out_channels):
        """Two (3x3 conv, padding 1 -> BatchNorm -> ReLU) layers; spatial size is preserved."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Run the network on a (B, in_channels, H, W) batch; returns (B, out_channels, H, W) in [0, 1]."""
        # Encoder
        enc1 = self.enc1(x)
        enc2 = self.enc2(F.max_pool2d(enc1, 2))
        enc3 = self.enc3(F.max_pool2d(enc2, 2))
        enc4 = self.enc4(F.max_pool2d(enc3, 2))
        enc5 = self.enc5(F.max_pool2d(enc4, 2))

        # Middle
        middle = self.middle(F.max_pool2d(enc5, 2))
        # Pooling then a stride-2 transposed conv gives 2 * floor(s / 2) <= s along each
        # spatial axis. When enc5's height and/or width is odd, the bottleneck is one
        # pixel smaller on that axis, so crop enc5 on BOTH axes before the additive skip.
        enc5 = enc5[:, :, :middle.shape[2], :middle.shape[3]]
        middle = middle + enc5  # Skip connection

        # Decoder
        dec5 = self.dec5(torch.cat([F.interpolate(middle, size=enc5.shape[2:]), enc5], dim=1))
        dec4 = self.dec4(torch.cat([F.interpolate(dec5, size=enc4.shape[2:]), enc4], dim=1))
        dec3 = self.dec3(torch.cat([F.interpolate(dec4, size=enc3.shape[2:]), enc3], dim=1))
        dec2 = self.dec2(torch.cat([F.interpolate(dec3, size=enc2.shape[2:]), enc2], dim=1))
        dec1 = self.dec1(torch.cat([F.interpolate(dec2, size=enc1.shape[2:]), enc1], dim=1))

        # Final Layer
        final_output = self.final_conv(dec1)
        return final_output.clamp(0, 1)

class OriginalWatermarkDataset(Dataset):
    """Paired (unwatermarked, watermarked) image dataset.

    Item `idx` loads unwatermarked_images[idx] and watermarked_images[idx]. Each is
    resized to `size`, converted to single-channel grayscale and turned into a
    tensor with torchvision's ToTensor.

    Args:
        unwatermarked_images: sequence of clean image paths.
        watermarked_images: sequence of watermarked image paths, aligned
            index-by-index with `unwatermarked_images`.
        size: (height, width) passed to transforms.Resize; defaults to (792, 612).

    Raises:
        ValueError: if the two sequences have different lengths. Otherwise
            __getitem__ would fail later with an IndexError, or pairs would be
            silently ignored.
    """

    def __init__(self, unwatermarked_images, watermarked_images, size=(792, 612)):
        if len(unwatermarked_images) != len(watermarked_images):
            raise ValueError(
                f"Got {len(unwatermarked_images)} unwatermarked and "
                f"{len(watermarked_images)} watermarked images; lengths must match."
            )
        self.unwatermarked_images = unwatermarked_images
        self.watermarked_images = watermarked_images

        self.transform = transforms.Compose([
            transforms.Resize(size),
            transforms.Grayscale(),
            transforms.ToTensor()
        ])

    def __len__(self):
        return len(self.unwatermarked_images)

    def _load(self, path):
        """Open `path`, apply the transform and close the file handle deterministically."""
        with Image.open(path) as img:
            return self.transform(img)

    def __getitem__(self, idx):
        """Return (unwatermarked_tensor, watermarked_tensor) for pair `idx`."""
        unwatermarked_img_tensor = self._load(self.unwatermarked_images[idx])
        watermarked_img_tensor = self._load(self.watermarked_images[idx])
        return unwatermarked_img_tensor, watermarked_img_tensor
    
class CombinedLoss(nn.Module):
    """Weighted sum of a pixel L1 loss and a VGG19 perceptual (feature-space L1) loss.

        loss = alpha * L1(input, target) + beta * L1(vgg(norm(input)), vgg(norm(target)))

    `vgg` is the frozen `features` stack of an ImageNet-pretrained VGG19. Those
    weights were trained on ImageNet-normalized RGB images, so before the VGG pass
    the inputs are:
      * repeated to 3 channels if they have a single channel, and
      * normalized with the ImageNet mean/std.
    Inputs are assumed to be images in [0, 1] (e.g. from ToTensor or a clamped
    model output). The pixel L1 term uses the tensors exactly as given.

    Args:
        alpha: weight of the pixel L1 term.
        beta: weight of the perceptual term.
    """

    def __init__(self, alpha=1.0, beta=0.5):
        super(CombinedLoss, self).__init__()
        self.l1_loss = nn.L1Loss()
        self.vgg = self._load_pretrained_vgg_features()
        for param in self.vgg.parameters():
            param.requires_grad = False
        self.alpha = alpha
        self.beta = beta
        # ImageNet channel statistics used by torchvision's pretrained weights.
        # Non-persistent buffers follow .to(device) but stay out of state_dict().
        self.register_buffer('mean', torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1), persistent=False)
        self.register_buffer('std', torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1), persistent=False)

    @staticmethod
    def _load_pretrained_vgg_features():
        """Return VGG19 `features` with ImageNet weights, via the weights API when available."""
        try:
            from torchvision.models import VGG19_Weights
        except ImportError:  # torchvision < 0.13 only supports the legacy `pretrained` flag
            return vgg19(pretrained=True).features
        return vgg19(weights=VGG19_Weights.IMAGENET1K_V1).features

    def _prepare_for_vgg(self, x):
        """Repeat single-channel batches to RGB and apply ImageNet normalization."""
        if x.shape[1] == 1:
            x = x.repeat(1, 3, 1, 1)
        return (x - self.mean) / self.std

    def forward(self, input, target):
        l1 = self.l1_loss(input, target)
        vgg_input = self.vgg(self._prepare_for_vgg(input))
        vgg_target = self.vgg(self._prepare_for_vgg(target))
        perceptual = self.l1_loss(vgg_input, vgg_target)

        return self.alpha * l1 + self.beta * perceptual

In [None]:
def main(seed=42):
    """Build all training components and run train_model.

    Args:
        seed: if not None, seeds Python's `random`, NumPy's global RNG and PyTorch
            before anything stochastic happens. This makes weight initialisation,
            DataLoader shuffling and the train/validation split reproducible. The
            split is covered because sklearn's train_test_split falls back to NumPy's
            global RNG when it gets no random_state. Reproducibility matters here
            because train_model resumes from saved checkpoints: an unseeded re-run
            would re-split the data and leak validation images into training.
            Pass None for the previous unseeded behaviour.
    """
    if seed is not None:
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)

    paths = setup_paths()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = UNet().to(device)
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

    loss_fn = CombinedLoss(alpha=0.5).to(device)

    train_combined_dataset, val_combined_dataset = prepare_datasets(paths)

    train_dataloader, val_dataloader = create_datasets(train_combined_dataset, val_combined_dataset, BATCH_SIZE)

    train_model(train_dataloader,
                val_dataloader,
                model,
                optimizer,
                loss_fn,
                NUM_EPOCHS,
                BATCH_SIZE,
                paths,
                device,
                print_num)

In [None]:
# Inside a Jupyter kernel __name__ is "__main__", so running this cell starts training.
if __name__ == "__main__":
    main()