In [1]:
import os
import random
import torch
from torchvision import transforms, datasets
from torch.utils.data import DataLoader, Dataset
from PIL import Image
import torch.nn as nn
import torchvision.models as models

In [2]:
# Define transformation
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])

In [3]:
class PatchDataset(Dataset):
    def __init__(self, root_dir, transform=None, num_patches=40):
        self.root_dir = root_dir
        self.transform = transform
        self.num_patches = num_patches
        self.image_paths = [os.path.join(root_dir, f) for f in os.listdir(root_dir) if f.endswith('.jpg')]
        self.label_mapping = {label: idx for idx, label in enumerate(sorted(set(p.split('/')[0] for p in self.image_paths)))}

    def __len__(self):
        return len(self.image_paths) * self.num_patches

    def __getitem__(self, idx):
        img_idx = idx // self.num_patches
        patch_idx = idx % self.num_patches
        img_path = self.image_paths[img_idx]
        image = Image.open(img_path).convert('RGB')
        if self.transform:
            image = self.transform(image)
        i = random.randint(0, image.size(1) - 224)
        j = random.randint(0, image.size(2) - 224)
        patch = image[:, i:i+224, j:j+224]
        
        label = img_path.split('/')[0]  # Extract label from image path
        label_tensor = torch.tensor(self.label_mapping[label], dtype=torch.long)  # Convert label to tensor

        return patch, label_tensor

In [4]:
# Define path
dataset_path = 'D:\\image_folder\\Images'
# Load dataset
dataset = PatchDataset(dataset_path, transform)

In [5]:
# Create data loaders
batch_size = 1
train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

In [6]:
# Define the feature extractor (ResNet-18)
class FeatureExtractor(nn.Module):
    def __init__(self):
        super(FeatureExtractor, self).__init__()
        resnet18 = models.resnet18(pretrained=True)
        self.features = nn.Sequential(*list(resnet18.children())[:-1])
        self.fc = nn.Linear(resnet18.fc.in_features, 512)

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

In [7]:
# Define the PDN (Projection Discriminator Network)
class PDN(nn.Module):
    def __init__(self):
        super(PDN, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(32, 3, kernel_size=2, stride=2),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x


In [8]:
# Instantiate models
feature_extractor = FeatureExtractor()
pdn = PDN()

# Training parameters
lr = 0.001
num_epochs = 20
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(list(feature_extractor.parameters()) + list(pdn.parameters()), lr=lr)



In [9]:
# Training loop
for epoch in range(num_epochs):
    feature_extractor.train()
    pdn.train()
    running_loss = 0.0

    for inputs, labels in train_loader:
        optimizer.zero_grad()

        # Forward pass through feature extractor
        features = feature_extractor(inputs)
        
        # Forward pass through PDN
        y_pred = pdn(inputs)
        

        # # Ensure labels are 1D
        # if len(labels.shape) > 1:
        #     labels = labels.squeeze()  # Remove extra dimensions if necessary
        #     print(labels)


        # Compute loss
        loss = loss_fn(y_pred, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(train_loader):.4f}")

print("Training complete.")


RuntimeError: only batches of spatial targets supported (3D tensors) but got targets of dimension: 1