In [1]:
import torch
import os
from PIL import Image, ImageDraw, ImageFont
import random
import shutil
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch.cuda.amp import autocast, GradScaler
import matplotlib.pyplot as plt

print(f"CUDA is available: {torch.cuda.is_available()}")
print(f"Amount of CUDA devices available: {torch.cuda.device_count()}")
print(f"Index of current CUDA device: {torch.cuda.current_device()}")
print(f"Name of current CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
print(f"Amount of CPU cores available: {os.cpu_count()}")

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
cores = os.cpu_count()

  from .autonotebook import tqdm as notebook_tqdm


CUDA is available: True
Amount of CUDA devices available: 1
Index of current CUDA device: 0
Name of current CUDA device: Tesla V100-SXM3-32GB
Amount of CPU cores available: 96


In [2]:
def add_date_stamp_to_images(input_dir, output_dir):
    try:
        shutil.rmtree(output_dir)
        print(f"Directory '{output_dir}' has been removed successfully.")
    except OSError as error:
        print(error)

    # Create directories for clean and noisy (stamped) images
    clean_dir = os.path.join(output_dir, "clean")
    stamped_dir = os.path.join(output_dir, "stamped")
    os.makedirs(clean_dir, exist_ok=True)
    os.makedirs(stamped_dir, exist_ok=True)

    for subdir in ['training', 'validation', 'evaluation']:
        input_subdir = os.path.join(input_dir, subdir)
        clean_subdir = os.path.join(clean_dir, subdir)
        stamped_subdir = os.path.join(stamped_dir, subdir)
        os.makedirs(clean_subdir, exist_ok=True)
        os.makedirs(stamped_subdir, exist_ok=True)

        i = 0

        for image_name in os.listdir(input_subdir):
            if i >= 1025:
                break
                # break or pass to limit or not

            # Load image
            image_path = os.path.join(input_subdir, image_name)
            image = Image.open(image_path)

            # Save clean version
            clean_output_path = os.path.join(clean_subdir, image_name)
            image.save(clean_output_path)

            # Add date stamp to a copy
            draw = ImageDraw.Draw(image)
            date_text = f"{random.choice(months)} {random.randint(1, 28)}, 2024"
            font = ImageFont.load_default(32)  # Use default font
            text_width = draw.textlength(date_text, font=font)
            x, y = image.width - text_width - 10, image.height - 50
            draw.text((x, y), date_text, fill="white", font=font)

            # Save stamped version
            stamped_output_path = os.path.join(stamped_subdir, image_name)
            image.save(stamped_output_path)

            if i % 500 == 0:
                print(f"Date stamped: {stamped_output_path}")
            i += 1

    print(f"Paired dataset created at '{output_dir}'")


months = ['January', 'February', 'March', 'April', 'May', 'June', 'July', 'August', 'September', 'October', 'November',
          'December']

add_date_stamp_to_images("dataset", "dataset_with_paired_images")

Directory 'dataset_with_paired_images' has been removed successfully.
Date stamped: dataset_with_paired_images/stamped/training/2_394.jpg
Date stamped: dataset_with_paired_images/stamped/training/9_321.jpg
Date stamped: dataset_with_paired_images/stamped/training/7_107.jpg
Date stamped: dataset_with_paired_images/stamped/training/1_84.jpg
Date stamped: dataset_with_paired_images/stamped/training/9_386.jpg
Date stamped: dataset_with_paired_images/stamped/training/9_1018.jpg
Date stamped: dataset_with_paired_images/stamped/training/5_641.jpg
Date stamped: dataset_with_paired_images/stamped/training/1_139.jpg
Date stamped: dataset_with_paired_images/stamped/training/0_138.jpg
Date stamped: dataset_with_paired_images/stamped/training/8_680.jpg
Date stamped: dataset_with_paired_images/stamped/training/10_318.jpg
Date stamped: dataset_with_paired_images/stamped/training/9_555.jpg
Date stamped: dataset_with_paired_images/stamped/training/2_1186.jpg
Date stamped: dataset_with_paired_images/sta

In [None]:
class PreloadedDateStampedDataset(Dataset):
    def __init__(self, stamped_dir, clean_dir, transform):

        self.stamped_images = []
        self.clean_images = []

        file_count = sum([len(files) for _, _, files in os.walk("dataset_with_paired_images/stamped/")])
        print(f"Total files in dataset_with_paired_images/stamped/: {file_count}")

        file_count = sum([len(files) for _, _, files in os.walk("dataset_with_paired_images/clean/")])
        print(f"Total files in dataset_with_paired_images/clean/: {file_count}")

        i = 0

        # Apply transformations and load stamped images
        for stamped_file in sorted(os.listdir(stamped_dir)):
            stamped_path = os.path.join(stamped_dir, stamped_file)
            stamped_image = Image.open(stamped_path).convert("RGB")
            if transform:
                stamped_image = transform(stamped_image)
            self.stamped_images.append(stamped_image)
            i += 1
            if i % 1000 == 0:
                print(f"Loaded {i} stamped images into memory.")

        i = 0
        # Apply transformations and load clean images
        for clean_file in sorted(os.listdir(clean_dir)):
            clean_path = os.path.join(clean_dir, clean_file)
            clean_image = Image.open(clean_path).convert("RGB")
            if transform:
                clean_image = transform(clean_image)
            self.clean_images.append(clean_image)
            i += 1
            if i % 1000 == 0:
                print(f"Loaded {i} clean images into memory.")

        # Ensure equal number of stamped and clean images
        assert len(self.stamped_images) == len(self.clean_images), "Mismatch in paired dataset size"

    def __len__(self):
        return len(self.stamped_images)

    def __getitem__(self, idx):
        # Preprocessed images are already in memory
        stamped_image = self.stamped_images[idx]
        clean_image = self.clean_images[idx]
        return stamped_image, clean_image


# Transformations for input images
transform = transforms.Compose([
    transforms.Resize((448, 448)),
    transforms.ToTensor(),
])

# Initialize DataLoaders
train_loader = DataLoader(
    PreloadedDateStampedDataset(
        stamped_dir='dataset_with_paired_images/stamped/training',
        clean_dir='dataset_with_paired_images/clean/training',
        transform=transform
    ),
    batch_size=256, shuffle=True, num_workers=0, pin_memory=True
)

val_loader = DataLoader(
    PreloadedDateStampedDataset(
        stamped_dir='dataset_with_paired_images/stamped/validation',
        clean_dir='dataset_with_paired_images/clean/validation',
        transform=transform
    ),
    batch_size=256, shuffle=False, num_workers=0, pin_memory=True
)

eval_loader = DataLoader(
    PreloadedDateStampedDataset(
        stamped_dir='dataset_with_paired_images/stamped/evaluation',
        clean_dir='dataset_with_paired_images/clean/evaluation',
        transform=transform
    ),
    batch_size=256, shuffle=False, num_workers=0, pin_memory=True
)

for stamped_batch, clean_batch in train_loader:
    print(stamped_batch.shape, clean_batch.shape)
    break

Total files in dataset_with_paired_images/stamped/: 16643
Total files in dataset_with_paired_images/clean/: 16643
Loaded 1000 stamped images into memory.
Loaded 2000 stamped images into memory.
Loaded 3000 stamped images into memory.
Loaded 4000 stamped images into memory.
Loaded 5000 stamped images into memory.
Loaded 6000 stamped images into memory.
Loaded 7000 stamped images into memory.
Loaded 8000 stamped images into memory.
Loaded 9000 stamped images into memory.
Loaded 1000 clean images into memory.
Loaded 2000 clean images into memory.
Loaded 3000 clean images into memory.
Loaded 4000 clean images into memory.
Loaded 5000 clean images into memory.
Loaded 6000 clean images into memory.
Loaded 7000 clean images into memory.
Loaded 8000 clean images into memory.
Loaded 9000 clean images into memory.
Total files in dataset_with_paired_images/stamped/: 16643
Total files in dataset_with_paired_images/clean/: 16643
Loaded 1000 stamped images into memory.
Loaded 2000 stamped images int

In [None]:
class Autoencoder(nn.Module):
    def __init__(self):
        super(Autoencoder, self).__init__()
        # Encoder
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1),  # 512x512 -> 256x256
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),  # 256x256 -> 128x128
            nn.ReLU(),
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),  # 128x128 -> 64x64
            nn.ReLU(),
        )
        # Decoder
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),  # 64x64 -> 128x128
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),  # 128x128 -> 256x256
            nn.ReLU(),
            nn.ConvTranspose2d(64, 3, kernel_size=4, stride=2, padding=1),  # 256x256 -> 512x512
            nn.Sigmoid(),  # Output in range [0, 1]
        )

    def forward(self, x):
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return decoded


# Initialize Autoencoder
autoencoder = Autoencoder().to(device, non_blocking=True)


In [None]:
criterion = nn.MSELoss()
optimizer = optim.Adam(autoencoder.parameters(), lr=0.001)

patience = 2
best_val_loss = np.inf
early_stop_counter = 0
num_epochs = 300

scaler = GradScaler()

In [None]:
for epoch in range(num_epochs):
    autoencoder.train()
    train_loss = 0.0

    for stamped_images, clean_images in train_loader:
        stamped_images, clean_images = stamped_images.to(device, non_blocking=True), clean_images.to(device, non_blocking=True)

        # Forward pass
        optimizer.zero_grad()
        with autocast():
            outputs = autoencoder(stamped_images)
            loss = criterion(outputs, clean_images)  # Compute loss against clean images

        # Scale loss for FP16 stability
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        train_loss += loss.item()

    avg_train_loss = train_loss / len(train_loader)

    # Validation phase
    autoencoder.eval()
    val_loss = 0.0
    with torch.no_grad():
        for stamped_images, clean_images in val_loader:
            # Move data to device
            stamped_images, clean_images = stamped_images.to(device, non_blocking=True), clean_images.to(device, non_blocking=True)

            outputs = autoencoder(stamped_images)
            loss = criterion(outputs, clean_images)
            val_loss += loss.item()

    avg_val_loss = val_loss / len(val_loader)
    print(f"Epoch [{epoch + 1}/{num_epochs}], Train Loss: {avg_train_loss:.6f}, Val Loss: {avg_val_loss:.6f}")

    # Early stopping check
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        early_stop_counter = 0

        # Save the best model
        torch.save(autoencoder.state_dict(), "good_autoencoder.pth")
    else:
        early_stop_counter += 1
        print(f"Early stopping patience counter: {early_stop_counter}/{patience}")
        if early_stop_counter >= patience:
            print("Early stopping triggered!")
            break


In [None]:
def show_images(input_images, reconstructed, n=4):
    plt.figure(figsize=(12, 4))  # Adjusted figure size to reduce spacing
    for i in range(n):
        # Stamped (Input) images
        plt.subplot(2, n, i + 1)
        input_img = input_images[i].permute(1, 2, 0).cpu().numpy()  # CHW -> HWC
        input_img = np.clip(input_img, 0, 1)  # Ensure pixel values are in [0, 1]
        plt.imshow(input_img)
        plt.axis('off')
        plt.title("Stamped (Input)")

        # Reconstructed (Output) images
        plt.subplot(2, n, i + 1 + n)
        reconstructed_img = reconstructed[i].permute(1, 2, 0).cpu().numpy()  # CHW -> HWC
        reconstructed_img = np.clip(reconstructed_img, 0, 1)  # Ensure pixel values are in [0, 1]
        plt.imshow(reconstructed_img)
        plt.axis('off')
        plt.title("Reconstructed (Output)")

    plt.tight_layout()
    plt.show()


# Evaluate on the evaluation dataset
stamped_images, reconstructed_images = [], []

with torch.no_grad():
    for stamped, clean in eval_loader:
        stamped = stamped.to(device, non_blocking=True)
        outputs = autoencoder(stamped)
        stamped_images.extend(stamped.cpu())
        reconstructed_images.extend(outputs.cpu())
        if len(stamped_images) >= 4:  # Limit to 4 examples
            break

# Visualize
show_images(stamped_images, reconstructed_images)