In [1]:
import torch
import torchvision
import torch.nn as nn
import torchvision.models as models
from torch.optim.lr_scheduler import ExponentialLR
import torch.optim as optim
import torch.nn.functional as F
from tqdm.auto import tqdm
import os
from torchvision import transforms
import numpy as np
from torchvision import datasets, transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import random
from collections import Counter
import matplotlib.pyplot as plt

# Set device
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Function to extract patches from images
def extract_patches(image, patch_size=64, stride=32):
    patches = []
    _, h, w = image.shape
    for i in range(0, h - patch_size + 1, stride):
        for j in range(0, w - patch_size + 1, stride):
            patch = image[:, i:i+patch_size, j:j+patch_size]
            patches.append(patch)
    return torch.stack(patches)

# Custom Dataset to load patches
class PatchDataset(Dataset):
    def __init__(self, root_dir, transform=None, patch_size=64, stride=32):
        self.dataset = datasets.ImageFolder(root=root_dir, transform=transform)
        self.patch_size = patch_size
        self.stride = stride

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        image, label = self.dataset[idx]
        patches = extract_patches(image, patch_size=self.patch_size, stride=self.stride)
        return patches, label

# Function to create DataLoader
def create_patch_dataloader(data_dir, batch_size=32, num_workers=4, patch_size=64, stride=32):
    transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
    ])

    dataset = PatchDataset(root_dir=data_dir, transform=transform, patch_size=patch_size, stride=stride)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, collate_fn=lambda x: (torch.cat([item[0] for item in x], dim=0), [item[1] for item in x]))
    
    return dataloader

# Example usage
data_dir = 'D:/Forchheim - Copy'  # Ensure this path is correct
dataloader = create_patch_dataloader(data_dir, batch_size=1)

# Print number of images loaded
print(f"Number of batches: {len(dataloader)}")

# Function for PGD attack
def pgd_attack(model, images, labels, eps=0.3, alpha=2/255, iters=40):
    adv_images = images.clone().detach().requires_grad_(True)
    for _ in range(iters):
        outputs = model(adv_images)
        loss = nn.CrossEntropyLoss()(outputs, labels)
        model.zero_grad()
        loss.backward()
        grad = adv_images.grad.data
        adv_images = adv_images + alpha * grad.sign()
        adv_images = torch.clamp(adv_images, images - eps, images + eps)
        adv_images = torch.clamp(adv_images, 0, 1)
        adv_images = adv_images.detach().requires_grad_(True)
    return adv_images

# Define the PDN (Patch Discriminator Network)
class PDN(nn.Module):
    def __init__(self):
        super(PDN, self).__init__()
        
        # Encoder
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),
            
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),
            
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),
            
            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2)
        )
        
        # Decoder
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            
            nn.ConvTranspose2d(32, 3, kernel_size=4, stride=2, padding=1),
            nn.Sigmoid()
        )
        
    def forward(self, x):
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return decoded

# Define the Feature Extractor using ResNet-18
class FeatureExtractor(nn.Module):
    def __init__(self):
        super(FeatureExtractor, self).__init__()
        resnet = models.resnet18(pretrained=True)
        self.features = nn.Sequential(*list(resnet.children())[:-1])
        
    def forward(self, x):
        return self.features(x).squeeze()

# Initialize models
feature_extractor = FeatureExtractor()
PDN_model = PDN()

# Train the Feature Extractor and PDN with patches
def train_feature_extractor_and_pdn(feature_extractor, pdn_model, patch_dataloader, device):
    feature_extractor.to(device)
    pdn_model.to(device)
    
    # Optimizers and loss functions
    feature_optimizer = optim.Adam(feature_extractor.parameters(), lr=0.001)
    pdn_optimizer = optim.Adam(pdn_model.parameters(), lr=0.001)
    scheduler = optim.lr_scheduler.ExponentialLR(feature_optimizer, gamma=0.97)
    loss_fn = nn.KLDivLoss(reduction='batchmean')

    # Training feature extractor
    for epoch in range(20):
        feature_extractor.train()
        total_loss = 0.0
        
        for patches, _ in patch_dataloader:
            patches = patches.to(device)
            features = feature_extractor(patches)
            
            # Forward pass through PDN model
            reconstructed_patches = pdn_model(patches)
            
            # Compute loss
            loss = loss_fn(reconstructed_patches, patches)
            total_loss += loss.item()
            
            # Backpropagation
            feature_optimizer.zero_grad()
            loss.backward()
            feature_optimizer.step()
        
        scheduler.step()
        print(f"Feature Extractor Epoch [{epoch + 1}/20], Loss: {total_loss / len(patch_dataloader):.4f}")

    # Training PDN model
    for epoch in range(10):
        pdn_model.train()
        total_loss = 0.0
        
        for patches, _ in patch_dataloader:
            patches = patches.to(device)
            
            # Generate adversarial examples
            adv_patches = pgd_attack(pdn_model, patches, labels=torch.zeros(patches.size(0), dtype=torch.long).to(device), eps=0.3, alpha=2/255, iters=40)
            
            # Forward pass through PDN model
            reconstructed_patches = pdn_model(patches)
            adv_reconstructed_patches = pdn_model(adv_patches)
            
            # Compute loss on clean and adversarial examples
            loss = (loss_fn(reconstructed_patches, patches) + loss_fn(adv_reconstructed_patches, adv_patches)) / 2
            total_loss += loss.item()
            
            # Backpropagation
            pdn_optimizer.zero_grad()
            loss.backward()
            pdn_optimizer.step()
        
        print(f"PDN Model Epoch [{epoch + 1}/10], Loss: {total_loss / len(patch_dataloader):.4f}")

# Function for major voting
def major_voting(patches, model, device):
    model.eval()
    with torch.no_grad():
        patch_labels = []
        for patch in patches:
            patch = patch.unsqueeze(0).to(device)
            features = model(patch)
            # Assuming a classifier is used to predict the class from features
            label = torch.argmax(features, dim=1).item()
            patch_labels.append(label)
        
        # Aggregate results using majority voting
        most_common_label, _ = Counter(patch_labels).most_common(1)[0]
        return most_common_label

# Function to evaluate model
def evaluate_model(patches_per_image, model, device, true_labels):
    image_level_labels = []
    for patches in patches_per_image:
        predicted_label = major_voting(patches, model, device)
        image_level_labels.append(predicted_label)
    
    # Evaluate accuracy
    accuracy = np.mean([pred == true_label for pred, true_label in zip(image_level_labels, true_labels)])
    print(f"Image Level Accuracy: {accuracy:.4f}")

# Function to print images in a 7x7 grid
def print_images_in_grid(dataloader, num_images=49):
    images, _ = next(iter(dataloader))
    images = images[0:num_images]
    images = images.permute(0, 2, 3, 1)
    images = (images * 255).byte().cpu().numpy()
    
    fig, axes = plt.subplots(7, 7, figsize=(10, 10))
    for i, ax in enumerate(axes.flat):
        ax.imshow(images[i])
        ax.axis('off')
    plt.show()

# Example usage of the print_images_in_grid function
print_images_in_grid(dataloader)

# Train the models
train_feature_extractor_and_pdn(feature_extractor, PDN_model, dataloader, device)


  from .autonotebook import tqdm as notebook_tqdm


Using device: cpu
Number of batches: 1440




AttributeError: Can't pickle local object 'create_patch_dataloader.<locals>.<lambda>'