In [1]:
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import ExponentialLR
from tqdm.auto import tqdm
import torch
from timeit import default_timer as timer
import requests
from pathlib import Path
import os
from PIL import Image
import torchvision.models as models

# helper_functions.py is fetched once and then reused from the working directory.
# NOTE(review): this imports code downloaded from the moving `main` branch of a
# remote repository; consider pinning the URL to a commit so the code cannot change
# underneath the notebook.
if Path("helper_functions.py").is_file():
    print("helper_functions.py already exists, skipping download")
else:
    print("Downloading helper_functions.py")
    request = requests.get(
        "https://raw.githubusercontent.com/mrdbourke/pytorch-deep-learning/main/helper_functions.py",
        timeout=30,  # do not hang the kernel forever on a stalled connection
    )
    # Fail before writing anything: otherwise an HTTP error page would be saved as
    # helper_functions.py and every later run would skip the download and import it.
    request.raise_for_status()
    with open("helper_functions.py", "wb") as f:
        f.write(request.content)
from helper_functions import accuracy_fn

# Set device
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

  from .autonotebook import tqdm as notebook_tqdm


helper_functions.py already exists, skipping download
Using device: cpu


In [2]:
def load_images_with_labels(data_dir, batch_size, transform=None, shuffle=True):
    """Loads images from a folder into a DataLoader with image names as labels.

    The directory tree below ``data_dir`` is walked recursively. Every file whose
    extension is ``.jpg``, ``.jpeg`` or ``.png`` (matched case-insensitively)
    becomes one sample. Paths are sorted so the dataset index order does not
    depend on the order in which the file system lists entries.

    Args:
        data_dir (str): Path to the data directory.
        batch_size (int): Batch size for the DataLoader.
        transform (callable, optional): Applied to each RGB PIL image. Defaults to
            resizing to 224x224 followed by ``ToTensor``.
        shuffle (bool, optional): Whether the DataLoader reshuffles on every pass.
            Defaults to True. Pass False when several passes over the loader must
            see the samples in the same order.

    Returns:
        torch.utils.data.DataLoader: Yields ``(image, filename)`` batches.
    """

    if transform is None:
        transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor()
        ])

    image_extensions = ('.jpg', '.jpeg', '.png')
    image_paths = sorted(
        os.path.join(root, file)
        for root, _, files in os.walk(data_dir)
        for file in files
        if file.lower().endswith(image_extensions)
    )

    class ImageDataset(Dataset):
        """Dataset returning ``(transformed RGB image, file name)`` per path."""

        def __init__(self, image_paths, transform=None):
            self.image_paths = image_paths
            self.transform = transform

        def __getitem__(self, index):
            image_path = self.image_paths[index]
            # convert() returns a fully loaded copy, so the file handle opened by
            # Image.open can be released immediately.
            with Image.open(image_path) as raw_image:
                image = raw_image.convert("RGB")
            if self.transform:
                image = self.transform(image)
            # os.path.basename honours the platform separator; splitting on '/'
            # returned the whole path for Windows-style paths.
            return image, os.path.basename(image_path)  # Use filename as label

        def __len__(self):
            return len(self.image_paths)

    dataset = ImageDataset(image_paths, transform=transform)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)

    return dataloader


In [3]:
# Load images
path = 'D:\\image_folder\\Images'
# By default load_images_with_labels returns a shuffling DataLoader, so every pass
# over it yields a new random order. Labels and patches are collected below in two
# separate passes, so the dataset is re-wrapped in a non-shuffling loader to keep
# label i aligned with the patches of image i.
image_list = DataLoader(
    load_images_with_labels(path, batch_size=1).dataset,
    batch_size=1,
    shuffle=False,
)
# Count samples from the dataset: len() of a DataLoader is the number of batches.
print(f"Loaded {len(image_list.dataset)} images.")


Loaded 10 images.


In [4]:
def extract_labels_from_dataloader(dataloader):
    """Extracts labels from a DataLoader and creates a flat list of labels.

    Each iteration of ``dataloader`` is expected to yield a ``(data, label)`` pair.
    When ``label`` is a list or tuple (a collated batch of per-sample labels), its
    elements are added individually, so the result has one entry per sample rather
    than one entry per batch. Any other label object (e.g. a single string or a
    tensor) is appended unchanged.

    Args:
        dataloader (torch.utils.data.DataLoader): The DataLoader containing data and labels.

    Returns:
        list: A list of labels extracted from the DataLoader, in iteration order.
    """

    labels = []
    for _, label in dataloader:
        if isinstance(label, (list, tuple)):
            # A batch of labels: flatten so len(labels) counts samples, not batches.
            labels.extend(label)
        else:
            labels.append(label)
    return labels


In [5]:
# One full pass over image_list to collect the labels (file names). The order of
# labels_list is the iteration order of this pass; it has to match the pass used
# for patch extraction for labels and patches to line up.
labels_list = extract_labels_from_dataloader(image_list)


In [7]:
def extract_patches_from_dataloader(dataloader, patch_size):
    """Cut every image delivered by a DataLoader into square, non-overlapping patches.

    Patches are taken on a grid with stride ``patch_size`` starting at the top-left
    corner, row by row. Only complete ``patch_size`` x ``patch_size`` patches are
    kept; a remainder at the right or bottom edge is dropped.

    Args:
        dataloader (torch.utils.data.DataLoader): Yields ``(images, labels)`` where
            iterating ``images`` gives 3-D tensors (channels first, then height, width).
        patch_size (int): Side length of each patch.

    Returns:
        list: All patches as tensors, in image order and then row-major grid order.
    """

    collected = []
    for batch, _ in dataloader:
        for img in batch:
            _, height, width = img.shape
            # Stopping at `dim - patch_size` yields only origins with a full patch.
            row_starts = range(0, height - patch_size + 1, patch_size)
            col_starts = range(0, width - patch_size + 1, patch_size)
            collected.extend(
                img[:, top:top + patch_size, left:left + patch_size]
                for top in row_starts
                for left in col_starts
            )
    return collected


In [8]:
# Second pass over image_list. With the default 224x224 resize and patch_size=32,
# each image contributes 7 x 7 = 49 non-overlapping patches.
patch_list = extract_patches_from_dataloader(image_list, patch_size=32)
print(f"Extracted {len(patch_list)} patches.")


Extracted 490 patches.


In [9]:
def create_patch_dataloader_with_labels(patches, labels, batch_size=49):
    """Group consecutive patches into per-label stacks and serve them via a DataLoader.

    Item ``i`` of the underlying dataset is the stack of
    ``patches[i * batch_size:(i + 1) * batch_size]`` paired with ``labels[i]``.
    The returned DataLoader uses a loader batch size of 1 and no shuffling, so each
    iteration yields one such group with an extra leading axis of length 1.

    Args:
        patches (list): List of image patches (tensors of identical shape).
        labels (list): One label per group of ``batch_size`` patches.
        batch_size (int): Number of patches per label (default: 49).

    Returns:
        torch.utils.data.DataLoader: A DataLoader for image patches with labels.
    """

    assert len(patches) % batch_size == 0, "Number of patches must be divisible by batch size"
    assert len(labels) == len(patches) // batch_size, "Number of labels must match number of patches"

    class PatchDatasetWithLabels(Dataset):
        """Maps a label index to its stacked group of patches."""

        def __init__(self, patches, labels):
            self.patches = patches
            self.labels = labels

        def __len__(self):
            return len(self.labels)

        def __getitem__(self, index):
            first = index * batch_size
            group = self.patches[first:first + batch_size]
            target = self.labels[index]
            return torch.stack(group), target

    return DataLoader(PatchDatasetWithLabels(patches, labels), batch_size=1, shuffle=False)


In [10]:
# Create patch DataLoader. Relies on the function's default batch_size=49, i.e. the
# first 49 entries of patch_list go with labels_list[0], the next 49 with
# labels_list[1], and so on.
patch_loader = create_patch_dataloader_with_labels(patch_list, labels_list)
print(f"Created patch DataLoader with {len(patch_loader)} batches.")


Created patch DataLoader with 10 batches.


In [11]:
# Peek at the first batch only, to confirm the tensor layout the model will receive.
for (X, y) in patch_loader:
    # The value printed is the shape of X, so label it as such (it was tagged "y").
    print(f"X shape: {X.shape}")
    break


y====torch.Size([1, 49, 3, 32, 32])
y====torch.Size([1, 49, 3, 32, 32])
y====torch.Size([1, 49, 3, 32, 32])
y====torch.Size([1, 49, 3, 32, 32])
y====torch.Size([1, 49, 3, 32, 32])
y====torch.Size([1, 49, 3, 32, 32])
y====torch.Size([1, 49, 3, 32, 32])
y====torch.Size([1, 49, 3, 32, 32])
y====torch.Size([1, 49, 3, 32, 32])
y====torch.Size([1, 49, 3, 32, 32])


In [12]:
# # Define the PDN model
# class PDN(nn.Module):
#     def __init__(self):
#         super(PDN, self).__init__()

#         # Encoder layers
#         self.encoder = nn.Sequential(
#             nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1),  # Input channels: 3 (RGB), output: 32 channels
#             nn.BatchNorm2d(32),
#             nn.ReLU(inplace=True),

#             nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
#             nn.BatchNorm2d(64),
#             nn.ReLU(inplace=True),
#             nn.MaxPool2d(2, 2),

#             nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
#             nn.BatchNorm2d(128),
#             nn.ReLU(inplace=True),
#             nn.MaxPool2d(2, 2),

#             nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
#             nn.BatchNorm2d(256),
#             nn.ReLU(inplace=True),
#             nn.MaxPool2d(2, 2),
#         )

#         # Decoder layers (use transposed convolutions for upsampling)
#         self.decoder = nn.Sequential(
#             nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
#             nn.BatchNorm2d(128),
#             nn.ReLU(inplace=True),

#             nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
#             nn.BatchNorm2d(64),
#             nn.ReLU(inplace=True),

#             nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),
#             nn.BatchNorm2d(32),
#             nn.ReLU(inplace=True),

#             nn.ConvTranspose2d(32, 3, kernel_size=4, stride=2, padding=1),
#             nn.Sigmoid()  # Output layer between 0 and 1 for image reconstruction
#         )

#     def forward(self, x):
#         encoded = self.encoder(x)
#         decoded = self.decoder(encoded)
#         return decoded


In [13]:
# Define the PDN (called "Patch Discriminator Network" here): a convolutional encoder-decoder whose output has the same shape as its input patch
class PDN(nn.Module):
    """Convolutional encoder-decoder operating on 3-channel image tensors.

    Encoder: four stages of Conv2d(3x3, padding 1) -> BatchNorm2d -> ReLU ->
    MaxPool2d(2), with channel widths 3 -> 32 -> 64 -> 128 -> 256. Each pooling
    step halves height and width (rounding down).

    Decoder: four ConvTranspose2d(4x4, stride 2, padding 1) layers with widths
    256 -> 128 -> 64 -> 32 -> 3, each doubling height and width. The first three
    are followed by BatchNorm2d and ReLU, the last by Sigmoid, so outputs lie in
    [0, 1].
    """

    def __init__(self):
        super().__init__()

        # Encoder: layers are created in stage order so the Sequential indices
        # (and therefore the state-dict keys) run 0..15.
        down_widths = [3, 32, 64, 128, 256]
        down_layers = []
        for c_in, c_out in zip(down_widths, down_widths[1:]):
            down_layers.extend([
                nn.Conv2d(c_in, c_out, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(2, 2),
            ])
        self.encoder = nn.Sequential(*down_layers)

        # Decoder: three normalised upsampling stages, then the output stage.
        up_widths = [256, 128, 64, 32]
        up_layers = []
        for c_in, c_out in zip(up_widths, up_widths[1:]):
            up_layers.extend([
                nn.ConvTranspose2d(c_in, c_out, kernel_size=4, stride=2, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
            ])
        up_layers.append(nn.ConvTranspose2d(32, 3, kernel_size=4, stride=2, padding=1))
        up_layers.append(nn.Sigmoid())
        self.decoder = nn.Sequential(*up_layers)

    def forward(self, x):
        """Run ``x`` through the encoder and then the decoder; return the decoder output."""
        return self.decoder(self.encoder(x))


In [14]:
# Feature Extractor using ResNet-18
class FeatureExtractor(nn.Module):
    """Pretrained ResNet-18 with its last child module removed.

    ``features`` holds every child of torchvision's ResNet-18 except the final one
    (the classification layer in torchvision's ResNet), so the forward pass ends
    after the network's pooling stage. The output is flattened to
    ``(batch, features)``.
    """

    def __init__(self):
        super(FeatureExtractor, self).__init__()
        # Newer torchvision releases deprecate `pretrained=True` in favour of the
        # `weights=` enum; fall back to the old keyword where the enum is missing.
        weights_enum = getattr(models, "ResNet18_Weights", None)
        if weights_enum is not None:
            resnet = models.resnet18(weights=weights_enum.IMAGENET1K_V1)
        else:
            resnet = models.resnet18(pretrained=True)
        self.features = nn.Sequential(*list(resnet.children())[:-1])

    def forward(self, x):
        """Return the backbone output for ``x`` flattened to ``(batch, features)``."""
        # flatten(start_dim=1) keeps the batch axis. A bare squeeze() also removed
        # it whenever the batch held a single sample, returning a 1-D tensor.
        return torch.flatten(self.features(x), 1)


In [15]:
# # Set random seed
# torch.manual_seed(42)

# # Initialize model, optimizer, and loss function
# PDN_model = PDN().to(device)

# feature_extrection_epoches = 20
# PDN_network_epoches = 10

# optimizer = optim.Adam(PDN_model.parameters(), lr=0.001)
# scheduler = ExponentialLR(optimizer, gamma=0.97)

# loss_fn = nn.MSELoss()  # Use MSELoss for reconstruction tasks


In [16]:
# Seed before the models are built so the PDN's random initialisation (and hence
# the loss curve) is the same on every Restart & Run All.
torch.manual_seed(42)

# Training Parameters
learning_rate = 0.001
feature_extrection_epoches = 20  # number of epochs run by train_model
# NOTE(review): PDN_network_epoches is not referenced by any code visible in this
# notebook; confirm whether a second training stage was intended.
PDN_network_epoches = 10

# Initialize models
feature_extractor = FeatureExtractor().to(device)
PDN_model = PDN().to(device)

# Define optimizer and loss function
# NOTE(review): the optimizer also receives the feature extractor's parameters, but
# the loss computed in train_model depends only on the PDN output, so those
# parameters never get gradients. Confirm whether joint training was intended.
optimizer = optim.Adam(list(feature_extractor.parameters()) + list(PDN_model.parameters()), lr=learning_rate)
# Learning rate is multiplied by gamma each time scheduler.step() is called.
scheduler = ExponentialLR(optimizer, gamma=0.97)
loss_fn = nn.MSELoss()



In [17]:
# # Define train step function
# def train_step(model: torch.nn.Module,
#                data_loader: torch.utils.data.DataLoader,
#                loss_fn: torch.nn.Module,
#                optimizer: torch.optim.Optimizer,
#                device: torch.device = device):
#     train_loss = 0
#     model.train()
#     for batch, (X, _) in enumerate(data_loader):
#         X = X.squeeze(0)  # Remove batch dimension
#         X = X.to(device)

#         # Forward pass
#         y_pred = model(X)

#         # Calculate loss
#         if y_pred.shape != X.shape:
#             y_pred = nn.functional.interpolate(y_pred, size=(X.shape[2], X.shape[3]), mode='bilinear', align_corners=False)

#         loss = loss_fn(y_pred, X)
#         train_loss += loss.item()

#         # Optimizer zero grad
#         optimizer.zero_grad()

#         # Loss backward
#         loss.backward()

#         # Optimizer step
#         optimizer.step()

#     train_loss /= len(data_loader)
#     print(f"Train loss: {train_loss:.5f}")


In [18]:
def train_model(feature_extractor, pdn_model, train_loader, optimizer, loss_fn, device,
                scheduler=None, num_epochs=None):
    """Train ``pdn_model`` to reproduce its input patches.

    For every batch the PDN output is compared with the input via ``loss_fn`` and
    the optimizer takes one step. After each epoch the learning-rate scheduler (if
    any) is stepped and the mean batch loss is printed.

    Args:
        feature_extractor (nn.Module): Put in train mode and run on every batch;
            its output is currently not part of the loss.
        pdn_model (nn.Module): Model being optimised.
        train_loader: Iterable of ``(images, label)`` pairs. ``images`` may carry
            extra leading axes, e.g. ``(1, n_patches, C, H, W)``; they are merged
            into a single batch axis.
        optimizer (torch.optim.Optimizer): Optimizer stepped once per batch.
        loss_fn (callable): Called as ``loss_fn(reconstruction, images)``.
        device: Device the image batches are moved to.
        scheduler (optional): LR scheduler stepped once per epoch. When omitted, the
            notebook-level ``scheduler`` variable is used if it exists.
        num_epochs (int, optional): Number of epochs. When omitted, the
            notebook-level ``feature_extrection_epoches`` is used.

    Returns:
        None
    """
    # Fall back to the notebook globals so existing calls keep working.
    if num_epochs is None:
        num_epochs = feature_extrection_epoches
    if scheduler is None:
        scheduler = globals().get("scheduler")

    feature_extractor.train()
    pdn_model.train()

    for epoch in range(num_epochs):
        total_loss = 0.0

        for images, _ in tqdm(train_loader):
            images = images.to(device)
            # Merge all leading axes into one batch axis: (1, P, C, H, W) -> (P, C, H, W).
            # A bare squeeze() would also drop P (or C/H/W) whenever it equals 1.
            images = images.reshape(-1, *images.shape[-3:])

            # Feature extraction
            # NOTE(review): `features` does not enter the loss below, so no gradient
            # reaches the extractor; no_grad avoids building an unused autograd
            # graph. Confirm whether the features were meant to be a training target.
            with torch.no_grad():
                features = feature_extractor(images)

            # PDN forward pass
            reconstructed_images = pdn_model(images)

            # Compute loss
            loss = loss_fn(reconstructed_images, images)
            total_loss += loss.item()

            # Backpropagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        if scheduler is not None:
            scheduler.step()
        # max(..., 1) avoids a ZeroDivisionError when the loader is empty.
        print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {total_loss / max(len(train_loader), 1):.4f}")


In [19]:
# # Training loop
# train_time_start_on_cpu = timer()

# for epoch in tqdm(range(feature_extrection_epoches)):
#     print(f"Epoch number: {epoch+1}")
#     train_step(model=PDN_model, data_loader=patch_loader, optimizer=optimizer, loss_fn=loss_fn)


In [20]:
# Run training on the per-image patch groups; train_model prints one loss line per epoch.
train_model(feature_extractor, PDN_model, patch_loader, optimizer, loss_fn, device)

100%|██████████| 10/10 [00:01<00:00,  7.89it/s]


Epoch [1/20], Loss: 0.0861


100%|██████████| 10/10 [00:01<00:00,  8.54it/s]


Epoch [2/20], Loss: 0.0686


100%|██████████| 10/10 [00:01<00:00,  9.50it/s]


Epoch [3/20], Loss: 0.0557


100%|██████████| 10/10 [00:00<00:00, 10.17it/s]


Epoch [4/20], Loss: 0.0490


100%|██████████| 10/10 [00:00<00:00, 10.21it/s]


Epoch [5/20], Loss: 0.0451


100%|██████████| 10/10 [00:00<00:00, 10.14it/s]


Epoch [6/20], Loss: 0.0422


100%|██████████| 10/10 [00:00<00:00, 10.26it/s]


Epoch [7/20], Loss: 0.0429


100%|██████████| 10/10 [00:00<00:00, 10.21it/s]


Epoch [8/20], Loss: 0.0429


100%|██████████| 10/10 [00:00<00:00, 10.23it/s]


Epoch [9/20], Loss: 0.0410


100%|██████████| 10/10 [00:00<00:00, 10.26it/s]


Epoch [10/20], Loss: 0.0393


100%|██████████| 10/10 [00:00<00:00, 10.17it/s]


Epoch [11/20], Loss: 0.0379


100%|██████████| 10/10 [00:00<00:00, 10.10it/s]


Epoch [12/20], Loss: 0.0359


100%|██████████| 10/10 [00:01<00:00,  9.90it/s]


Epoch [13/20], Loss: 0.0334


100%|██████████| 10/10 [00:00<00:00, 10.25it/s]


Epoch [14/20], Loss: 0.0320


100%|██████████| 10/10 [00:01<00:00,  9.42it/s]


Epoch [15/20], Loss: 0.0308


100%|██████████| 10/10 [00:01<00:00,  9.32it/s]


Epoch [16/20], Loss: 0.0306


100%|██████████| 10/10 [00:01<00:00,  9.18it/s]


Epoch [17/20], Loss: 0.0289


100%|██████████| 10/10 [00:01<00:00,  8.64it/s]


Epoch [18/20], Loss: 0.0277


100%|██████████| 10/10 [00:01<00:00,  9.75it/s]


Epoch [19/20], Loss: 0.0268


100%|██████████| 10/10 [00:01<00:00,  9.83it/s]

Epoch [20/20], Loss: 0.0257



