In [1]:
import torchvision
torchvision.disable_beta_transforms_warning()
from torchvision.transforms import v2
from torchvision import datasets, transforms

import timm
import torch.nn as nn
import torch
from torch.utils.data import DataLoader
import cv2
import numpy as np
from sklearn import metrics

import fnmatch
from PIL import Image, ImageChops, ImageEnhance

#For texture extraction
from skimage import feature
import os

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def extract_texture_features(image):
    gray_image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
    gray_image = gray_image.astype(np.uint8)  # Convert to integer type
    radius = 3
    n_points = 8 * radius
    lbp = feature.local_binary_pattern(gray_image, n_points, radius, method='uniform')

    lbp = lbp / lbp.max()  # Normalize the LBP values to the range [0, 1]
    
    return lbp

def extract_color_features(image, quality=95, enhance_factor=10):
    image_bytes = image.tobytes()
    recompressed = Image.frombytes("RGB", image.size, image_bytes)
    ela_image = ImageChops.difference(image, recompressed)

    enhancer = ImageEnhance.Brightness(ela_image)
    enhanced_ela = enhancer.enhance(enhance_factor)

    resized_ela = enhanced_ela.resize((224, 224)).convert("L")
    feature_array = np.array(resized_ela).astype(np.float32) / 255.0
    return feature_array


def extract_shape_features(image):
    kernel_size = 5
    transform_iteration = 5

    # Define the kernel
    kernel = np.ones((kernel_size, kernel_size), np.uint8)

    image = cv2.resize(image, (224, 224))  # Resize to (224, 224)

    image_dict = {}
    image_dict["original_image"] = image
    image_dict["eroded_image"] = cv2.erode(image_dict["original_image"], kernel, iterations=transform_iteration)
    image_dict["dilated_image"] = cv2.dilate(image_dict["original_image"], kernel, iterations=transform_iteration)
    image_dict["opened_image"] = cv2.dilate(image_dict['eroded_image'], kernel, iterations=transform_iteration)
    image_dict["closed_image"] = cv2.erode(image_dict['dilated_image'], kernel, iterations=transform_iteration)

    opened_image_resized = cv2.cvtColor(image_dict["opened_image"], cv2.COLOR_RGB2GRAY)

    return opened_image_resized  # Shape: (224, 224)

In [3]:
class ImageFolderDataset(datasets.ImageFolder):
    def __init__(self, img_path, transform=None):
        self.img_path = img_path
        self.transform = transform
        self.classes = ['0_real', '1_fake']
        self.img_paths = self._get_image_paths()

    def _get_image_paths(self):
        img_paths = []
        for root, dirs, files in os.walk(self.img_path):
            for name in files:
                if fnmatch.fnmatch(name, "*.jpg"):
                    label = 0 if '0_real' in root else 1
                    img_paths.append((os.path.join(root, name), label))
        return img_paths

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, index):
        img_path, label = self.img_paths[index]
        img = Image.open(img_path)
        img = img.convert("RGB")
        if self.transform:
            img_tensor = self.transform(img)
        
        # Ensure the image is now a tensor
        if not isinstance(img_tensor, torch.Tensor):
            msg = f"Expected image to be a tensor, but got {type(img_tensor)}."
            raise TypeError(msg)
        
        # Convert tensor to numpy array for feature extraction
        img_np = img_tensor.numpy().transpose(1, 2, 0)
        
        # Extract features
        texture_features = extract_texture_features(img_np)
        color_features = extract_color_features(img)
        shape_features = extract_shape_features(img_np)

        img.close()

        features = np.stack([texture_features, color_features, shape_features], axis=0)
        features = torch.tensor(features).float()
        return features, label

In [4]:
def load_data(test_dir, batch_size, image_size):
    transform = transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    

    test_dataset = ImageFolderDataset(test_dir, transform=transform)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, drop_last=True)
    
    return test_loader

In [5]:
class ModifiedSEResNeXt(nn.Module):
    def __init__(self, num_classes=2):
        super(ModifiedSEResNeXt, self).__init__()
        original_model = timm.create_model('seresnext101_32x4d', pretrained=True)

        original_conv1 = original_model.conv1
        self.conv1 = nn.Conv2d(
            in_channels=3,  # 3 channels for texture, color, and shape features
            out_channels=original_conv1.out_channels,
            kernel_size=original_conv1.kernel_size,
            stride=original_conv1.stride,
            padding=original_conv1.padding,
            bias=False
        )
        
        with torch.no_grad():
            self.conv1.weight = nn.Parameter(original_conv1.weight.mean(dim=1, keepdim=True).repeat(1, 3, 1, 1))
        
        self.bn1 = original_model.bn1
        self.act1 = original_model.act1
        self.maxpool = original_model.maxpool
        self.layer1 = original_model.layer1
        self.layer2 = original_model.layer2
        self.layer3 = original_model.layer3
        self.layer4 = original_model.layer4
        self.avg_pool = original_model.global_pool
        
        num_features = original_model.fc.in_features
        self.fc = nn.Sequential(
            nn.Linear(num_features, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, num_classes)
        )
        
    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.act1(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avg_pool(x)
        x = x.flatten(1)
        x = self.fc(x)
        
        return x

def get_model(filename='seresnext_finetuned.pth', force_new=False):
    file_path = os.path.join(os.getcwd(), filename)

    # Use pre-existing weights
    if os.path.exists(file_path) and not force_new:
        model = ModifiedSEResNeXt()
        model.load_state_dict(torch.load(file_path, map_location=DEVICE, weights_only=True))
        print(f"Loaded model weights from {file_path}")
        return model
    
    else:   # Create a new model
        # model = create_custom_model(pretrained=True)
        model = ModifiedSEResNeXt()
        return model


In [6]:
def evaluate(model, test_loader):
    # Validation phase
    model.eval()
    
    y_true = []
    y_pred = []
    
    for img, label in test_loader:
        # Please make sure that the "pred" is binary result
        output = model(img.to(DEVICE))
        pred = np.argmax(output.detach().to('cpu'), axis=1).numpy()
        
        y_true.extend(label.numpy())
        y_pred.extend(pred)

    y_true = np.array(y_true)
    y_pred = np.array(y_pred)

    accuracy = metrics.accuracy_score(y_true, y_pred)
    return accuracy

In [7]:
def validate_model(model, val_loader, criterion):
    model.eval()
    val_loss = 0
    correct = 0
    total = 0
    
    batch = 1
    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            outputs = model(images)
            loss = criterion(outputs, labels).item()
            val_loss += loss
            predicted = torch.argmax(outputs, dim=1)
            total += labels.size(0)
            correct_no = (predicted == labels).sum().item()
            correct += correct_no
            batch += 1
    val_loss /= len(val_loader)
    accuracy = correct / len(val_loader.dataset)
    return val_loss, accuracy

In [8]:
if __name__ == '__main__':
    DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
    
    model_weight_filename = r'..\some_result\seresnent_e1.pth'
    
    model = get_model(model_weight_filename, force_new=False)
    model = model.to(DEVICE)
    
    # Load the data
    test_dir = r'..\AIGC-Detection-Dataset\AIGC-Detection-Dataset\val'
    batch_size = 16
    image_size = 224
    test_loader= load_data(test_dir, batch_size, image_size)
    
    # Evaluate the model
    accuracy = evaluate(model, test_loader)
    print(f"Accuracy: {accuracy}")
    
    criterion = nn.CrossEntropyLoss().to(DEVICE)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
    loss, acc = validate_model(model, test_loader, criterion)
    print(f"Loss: {loss}, Accuracy: {acc}")
    

  return self.fget.__get__(instance, owner)()


Loaded model weights from c:\CS4487\nn\..\some_result\seresnent_e1.pth
Accuracy: 0.8960336538461539
Loss: 0.2596701517707119, Accuracy: 0.8946
