In [152]:
import torch
import torchvision
import torch.nn as nn
import torchvision.models as models
from torch.optim.lr_scheduler import ExponentialLR
import torch.optim as optim
import torch.nn.functional as F
from tqdm.auto import tqdm
import os
from torchvision import transforms
import numpy as np
import matplotlib.pyplot as plt
from torchvision import datasets, transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import random
from collections import Counter
import torchattacks
# Set device
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")


Using device: cpu


In [124]:
def load_images_into_dataloader(data_dir, batch_size, transform=None):


  if transform is None:
      transform = transforms.Compose([
          transforms.ToTensor(),
          transforms.Resize((224,224)),
          # Add other transformations as needed
      ])

  dataset = torchvision.datasets.ImageFolder(root=data_dir, transform=transform)
  dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

  return dataloader


In [125]:

# Example usage
# data_dir = 'D:\Forchheim - Copy'  # Ensure this path is correct
data_dir = 'D:\image_folder'  # Ensure this path is correct
dataloader = load_images_into_dataloader(data_dir,1)


In [126]:
len(dataloader)

10

In [127]:
def print_first_images_with_labels(dataloader):
  """Prints the first image of each class with its label as the title in a 7x7 grid.

  Args:
      dataloader (torch.utils.data.DataLoader): The DataLoader containing images and labels.
  """

  class_to_idx = dataloader.dataset.class_to_idx
  idx_to_class = {v: k for k, v in class_to_idx.items()}

  # Create a dictionary to store the first image of each class
  first_images = {}

  for images, labels in dataloader:
      for image, label in zip(images, labels):
          label = label.item()  # Convert tensor to integer

          # Efficiently track seen classes (optional for large datasets)
          if label not in first_images:
              first_images[label] = image
              # Alternatively, for efficiency with large datasets:
              # break  # Uncomment this to stop after seeing all classes once

  # Create a 7x7 grid of images
  num_images = len(first_images)
  rows, cols = (num_images // 5) + (1 if num_images % 5 > 0 else 0), min(5, num_images)
  fig, axes = plt.subplots(rows, cols, figsize=(cols * 3, rows * 3))  # Adjust figsize as needed

  fig.tight_layout()

  row, col = 0, 0
  for label, image in first_images.items():
      axes[row, col].imshow(image.permute(1, 2, 0))  # Assuming CHW format
      axes[row, col].set_title(idx_to_class[label])
      axes[row, col].axis('off')
      col += 1
      if col == cols:
          row += 1
          col = 0

  plt.show()

In [128]:
# print_first_images_with_labels(dataloader)

In [163]:
def create_patch_dataloader(data_dir, batch_size=1, patch_size=64, stride=32, transform=None):
    if transform is None:
        transform = transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.ToTensor(),
        ])

    dataset = torchvision.datasets.ImageFolder(root=data_dir, transform=transform)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    return dataloader

In [164]:
# Extract patches
def extract_patches(image, patch_size=64, stride=32):
    patches = []
    _, h, w = image.shape
    for i in range(0, h - patch_size + 1, stride):
        for j in range(0, w - patch_size + 1, stride):
            patch = image[:, i:i+patch_size, j:j+patch_size]
            patches.append(patch)
    return torch.stack(patches)

In [165]:
patch_loader = create_patch_dataloader(data_dir)

In [169]:
len(patch_loader)

10

In [146]:
# Define the PDN (Patch Discriminator Network)
class PDN(nn.Module):
    def __init__(self):
        super(PDN, self).__init__()
        
        # Encoder
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),
            
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),
            
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),
            
            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2)
        )
        
        # Decoder
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            
            nn.ConvTranspose2d(32, 3, kernel_size=4, stride=2, padding=1),
            nn.Sigmoid()
        )
        
    def forward(self, x):
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return decoded

In [147]:
# Define the Feature Extractor using ResNet-18
class FeatureExtractor(nn.Module):
    def __init__(self):
        super(FeatureExtractor, self).__init__()
        resnet = models.resnet18(pretrained=True)
        self.features = nn.Sequential(*list(resnet.children())[:-1])
        
    def forward(self, x):
        return self.features(x).squeeze()

In [148]:
# Initialize models
feature_extractor = FeatureExtractor()
PDN_model = PDN()



In [153]:
attack = torchattacks.PGD(PDN_model, eps=1e-3, alpha=1/255, steps=10, random_start=True)

In [219]:
def train_feature_extractor_and_pdn(feature_extractor, pdn_model, dataloader, patch_loader, attack,device):
    feature_extractor.to(device)
    pdn_model.to(device)
    
    # Optimizers and loss functions
    feature_optimizer = optim.Adam(feature_extractor.parameters(), lr=0.001)
    pdn_optimizer = optim.Adam(pdn_model.parameters(), lr=0.001)
    scheduler = optim.lr_scheduler.ExponentialLR(feature_optimizer, gamma=0.97)
    # loss_fn = nn.KLDivLoss(reduction='batchmean')
    loss_fn = nn.MSELoss()

    # Training feature extractor
    # for epoch in tqdm(range(20)):
    #     feature_extractor.train()
    #     total_loss = 0.0
        
    #     for images in dataloader:
    #         images = images[0].to(device)
    #         features = feature_extractor(images)
            
    #         # Forward pass through PDN model
    #         reconstructed_images = pdn_model(images)
            
    #         # Compute loss
    #         # print(f"{reconstructed_images.shape}===={images.shape}")
    #         loss = loss_fn(reconstructed_images, images)
    #         total_loss += loss.item()
            
    #         # Backpropagation
    #         feature_optimizer.zero_grad()
    #         loss.backward()
    #         feature_optimizer.step()
        
    #     scheduler.step()
    #     print(f"Feature Extractor Epoch [{epoch + 1}/20], Loss: {total_loss / len(dataloader):.4f}")

    # Training PDN model
    for epoch in tqdm(range(10)):
        pdn_model.train()
        total_loss = 0.0
        
        for patches, labels in patch_loader:
            patches = patches.to(device)
            labels=labels.to(device)

            print(f"{patches.shape}===={labels.shape}")
            # Generate adversarial examples
            # adv_patches = attack(patches, labels=torch.zeros(patches.size(0), dtype=torch.long).to(device))
            adv_patches = attack(patches, labels)
            
            # Forward pass through PDN model
            reconstructed_patches = pdn_model(patches)
            adv_reconstructed_patches = pdn_model(adv_patches)
            
            # Compute loss on clean and adversarial examples
            loss = (loss_fn(reconstructed_patches, patches) + loss_fn(adv_reconstructed_patches, adv_patches)) / 2
            total_loss += loss.item()
            
            # Backpropagation
            pdn_optimizer.zero_grad()
            loss.backward()
            pdn_optimizer.step()
        
        print(f"PDN Model Epoch [{epoch + 1}/10], Loss: {total_loss / len(patch_loader):.4f}")

In [218]:
train_feature_extractor_and_pdn(feature_extractor, PDN_model,dataloader,patch_loader, attack,device)

  0%|          | 0/10 [00:00<?, ?it/s]

torch.Size([1, 3, 256, 256])====torch.Size([1, 1])





RuntimeError: only batches of spatial targets supported (3D tensors) but got targets of dimension: 2

In [None]:
def major_voting(patches, model, device):
    model.eval()
    with torch.no_grad():
        patch_labels = []
        for patch in patches:
            patch = patch.unsqueeze(0).to(device)
            features = model(patch)
            # Assuming a classifier is used to predict the class from features
            label = torch.argmax(features, dim=1).item()
            patch_labels.append(label)
        
        # Aggregate results using majority voting
        most_common_label, _ = Counter(patch_labels).most_common(1)[0]
        return most_common_label

In [None]:
def evaluate_model(patches_per_image, model, device):
    image_level_labels = []
    for patches in patches_per_image:
        predicted_label = major_voting(patches, model, device)
        image_level_labels.append(predicted_label)
    
    # Evaluate accuracy
    # Assuming `true_labels` is a list of true labels for each image
    accuracy = np.mean([pred == true_label for pred, true_label in zip(image_level_labels, true_labels)])
    print(f"Image Level Accuracy: {accuracy:.4f}")
