In [1]:
!pip install efficientnet-pytorch
!pip install torch

Collecting efficientnet-pytorch
  Downloading efficientnet_pytorch-0.7.1.tar.gz (21 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: efficientnet-pytorch
  Building wheel for efficientnet-pytorch (setup.py) ... [?25l[?25hdone
  Created wheel for efficientnet-pytorch: filename=efficientnet_pytorch-0.7.1-py3-none-any.whl size=16425 sha256=57daf84ec5397e474b7f02b37f5eddbaf3e044e13b1648aa5b2e96b359892003
  Stored in directory: /root/.cache/pip/wheels/03/3f/e9/911b1bc46869644912bda90a56bcf7b960f20b5187feea3baf
Successfully built efficientnet-pytorch
Installing collected packages: efficientnet-pytorch
Successfully installed efficientnet-pytorch-0.7.1


In [2]:
import torch
from torch.utils.data import DataLoader, ConcatDataset
from torchvision import datasets, transforms, models
import torch.nn as nn
import torch.optim as optim
from efficientnet_pytorch import EfficientNet
import numpy as np
from scipy import ndimage
import matplotlib.pyplot as plt
import torch.nn.functional as F
from torchvision.datasets import FashionMNIST

In [3]:
# Define Transformation

def get_transform():
    return transforms.Compose([
        transforms.Resize((128, 128)),  # Resize all images to 128x128
        transforms.Grayscale(num_output_channels=3),  # Convert grayscale images to 3 channels
        transforms.ToTensor(),  # Convert images to tensor
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # Normalize images for 3 channels
    ])

In [4]:
# Example for FashionMNIST dataset
from torchvision.datasets import FashionMNIST

fashion_mnist_train = FashionMNIST(root='data', train=True, download=True, transform=get_transform())
fashion_mnist_test = FashionMNIST(root='data', train=False, download=True, transform=get_transform())

# You can apply similar code to the other datasets.

transform = get_transform()

# Load CIFAR-10 dataset
cifar10_data = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
cifar10_test_data = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

# Load FashionMNIST dataset
fashionmnist_data = datasets.FashionMNIST(root='./data', train=True, download=True, transform=transform)
fashionmnist_test_data = datasets.FashionMNIST(root='./data', train=False, download=True, transform=transform)

# Load StegoImages dataset
from torchvision.datasets import ImageFolder
stego_data = ImageFolder(root='/kaggle/input/stegoimagesdataset/train', transform=transform)
stego_test_data = ImageFolder(root='/kaggle/input/stegoimagesdataset/test', transform=transform)


Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to data/FashionMNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 26421880/26421880 [00:01<00:00, 16986858.26it/s]


Extracting data/FashionMNIST/raw/train-images-idx3-ubyte.gz to data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 29515/29515 [00:00<00:00, 272549.88it/s]


Extracting data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 4422102/4422102 [00:00<00:00, 5050277.79it/s]


Extracting data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 5148/5148 [00:00<00:00, 8791643.73it/s]


Extracting data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:02<00:00, 79492943.02it/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


In [5]:
combined_train_data = ConcatDataset([cifar10_data, fashionmnist_data, stego_data])
combined_test_data = ConcatDataset([cifar10_test_data, fashionmnist_test_data, stego_test_data])

train_loader = DataLoader(combined_train_data, batch_size=64, shuffle=True)
test_loader = DataLoader(combined_test_data, batch_size=64, shuffle=False)

In [6]:
import torch
import torch.nn as nn
from efficientnet_pytorch import EfficientNet

class StegoLocationNet(nn.Module):
    def __init__(self, num_classes):
        super(StegoLocationNet, self).__init__()
        self.num_classes = num_classes

        # Load EfficientNet base model
        self.base_model = EfficientNet.from_pretrained('efficientnet-b0')

        # Retrieve the in_features from the original _fc layer
        if hasattr(self.base_model, '_fc') and isinstance(self.base_model._fc, nn.Linear):
            in_features = self.base_model._fc.in_features
        else:
            raise AttributeError("EfficientNet model does not have a valid '_fc' layer.")

        # Replace the classification layer
        self.base_model._fc = nn.Identity()  # Remove the existing fully connected layer

        # Add custom classification head
        self.classifier = nn.Linear(in_features, num_classes)

        # Add location detection layers
        self.location_head = nn.Sequential(
            nn.Conv2d(in_features, 128, kernel_size=1),  # Match input channels to 1280
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(128, 1, kernel_size=1),
            nn.Sigmoid()
        )

    def forward(self, x):
        # Extract features before the final classifier
        features = self.base_model.extract_features(x)

        # Generate location map
        location_map = self.location_head(features)

        # Optional classification (not used for location-only tasks)
        pooled_features = torch.mean(features, dim=[2, 3])  # Global average pooling
        class_output = self.classifier(pooled_features)

        return location_map, class_output


In [7]:
def generate_hiding_spots(images):
    # Example heuristic: higher values in uniform regions
    # Apply Sobel filter to find edges
    sobel_x = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]]).view(1, 1, 3, 3).float().to(images.device)
    sobel_y = torch.tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]]).view(1, 1, 3, 3).float().to(images.device)

    edges_x = F.conv2d(images, sobel_x, padding=1)
    edges_y = F.conv2d(images, sobel_y, padding=1)
    edge_magnitude = torch.sqrt(edges_x**2 + edges_y**2)

    # Normalize and invert edge magnitude to prioritize smooth regions
    suitability = 1 - (edge_magnitude / edge_magnitude.max())
    return suitability


In [8]:
num_classes = 10  # Update based on your dataset
model = StegoLocationNet(num_classes=num_classes)


Downloading: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b0-355c32eb.pth" to /root/.cache/torch/hub/checkpoints/efficientnet-b0-355c32eb.pth
100%|██████████| 20.4M/20.4M [00:00<00:00, 183MB/s]

Loaded pretrained weights for efficientnet-b0





In [9]:

# Get the number of classes from your dataset
num_classes = len(cifar10_data.classes) + len(fashionmnist_data.classes) + len(stego_data.classes)

# Load the EfficientNet model
# Get the number of classes from your dataset
num_classes = len(cifar10_data.classes) + len(fashionmnist_data.classes) + len(stego_data.classes)

# Create model
model = StegoLocationNet(num_classes=num_classes)

# Move model to the appropriate device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)


Loaded pretrained weights for efficientnet-b0


In [10]:
criterion = nn.CrossEntropyLoss()
criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)


In [11]:
# Add this helper function above the training loop
def generate_hiding_spots(images, threshold=0.8):
    """
    Generate target masks indicating good hiding spots
    This is a simple example using edge detection and local variance
    """
    batch_size = images.size(0)
    masks = torch.zeros((batch_size, 1, images.size(2), images.size(3))).to(images.device)
    
    for i in range(batch_size):
        img = images[i].cpu().permute(1, 2, 0).numpy()
        
        # Convert to grayscale
        gray = np.mean(img, axis=2)
        
        # Calculate local variance
        local_var = ndimage.generic_filter(gray, np.var, size=3)
        
        # Normalize and threshold
        local_var = (local_var - local_var.min()) / (local_var.max() - local_var.min())
        masks[i, 0] = torch.from_numpy(local_var > threshold).float().to(images.device)
    
    return masks

In [12]:
# Define the function to generate target masks
def generate_hiding_spots(images, threshold=0.8):
    """
    Generate target masks indicating good hiding spots.
    This is a simple example using edge detection and local variance.
    """
    batch_size = images.size(0)
    masks = torch.zeros((batch_size, 1, images.size(2), images.size(3))).to(images.device)

    # Example implementation: Fill the masks with zeros or apply a heuristic
    for i in range(batch_size):
        img = images[i].cpu().permute(1, 2, 0).numpy()
        # Add your heuristic for detecting hiding spots here
        # For example, based on edge detection, variance, or other methods
        pass

    return masks


In [13]:
# Training loop
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

num_epochs = 2
batch_size = 64  # Define the batch size for DataLoader
for epoch in range(num_epochs):
    model.train()
    epoch_loss = 0.0  # Track the total loss for the epoch

    for i, (images, _) in enumerate(train_loader):  # Process in batches
        images = images.to(device)

        # Generate target masks
        target_masks = generate_hiding_spots(images)

        optimizer.zero_grad()

        # Forward pass
        location_maps, _ = model(images)  # Unpack the tuple to get location_map

        # Upsample location_maps to match target_masks
        location_maps_upsampled = torch.nn.functional.interpolate(
            location_maps, size=target_masks.shape[2:], mode='bilinear', align_corners=False
        )

        # Calculate loss
        loss = criterion(location_maps_upsampled, target_masks)

        # Backward pass
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()  # Accumulate loss for the epoch

    # Log average loss for the epoch
    avg_loss = epoch_loss / len(train_loader)
    print(f'Epoch [{epoch+1}/{num_epochs}], Average Loss: {avg_loss:.4f}')


Epoch [1/2], Average Loss: 0.0090
Epoch [2/2], Average Loss: 0.0001


In [14]:
print(f"Location maps size: {location_maps.shape}")
print(f"Target masks size: {target_masks.shape}")


Location maps size: torch.Size([48, 1, 4, 4])
Target masks size: torch.Size([48, 1, 128, 128])


In [15]:
import torch.nn.functional as F

def evaluate_model(model, test_loader, criterion, device):
    model.eval()
    test_loss = 0.0
    with torch.no_grad():
        for images, _ in test_loader:  # Ignore original labels
            images = images.to(device)

            # Generate target masks for testing
            target_masks = generate_hiding_spots(images)

            # Forward pass
            location_maps, _ = model(images)

            # Resize `location_maps` to match `target_masks` size
            location_maps_resized = F.interpolate(
                location_maps, size=target_masks.shape[2:], mode="bilinear", align_corners=False
            )

            # Compute loss
            loss = criterion(location_maps_resized, target_masks)

            # Accumulate loss
            test_loss += loss.item()

    avg_test_loss = test_loss / len(test_loader)
    print(f"Average Test Loss: {avg_test_loss:.4f}")
    return avg_test_loss

# Perform evaluation
average_test_loss = evaluate_model(model, test_loader, criterion, device)


Average Test Loss: 0.0000


In [17]:
torch.save(model, "model.pth")