In [4]:
import torchvision.transforms as transforms
import torch.utils.data as data
import torch.optim as optim
import torchvision.models as models
import torch
import torch.nn as nn
import os
import numpy as np
import cv2
import numpy as np
from torch.utils.data import Dataset

In [8]:
class RealFakeDataset(Dataset):
    def __init__(self, data_folder):
        self.data_folder = data_folder
        self.classes = {'nature': 0, 'ai': 1}
        self.image_files = []
        self.labels = []
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Resize((224, 224)),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

        for class_name, class_label in self.classes.items():
            class_folder = os.path.join(self.data_folder, class_name)
            for image_file in os.listdir(class_folder):
                self.image_files.append(os.path.join(class_folder, image_file))
                self.labels.append(class_label)

        self.labels = np.array(self.labels)

    def __len__(self):
        return len(self.image_files)
    
    def __getitem__(self, idx):
        image_path = self.image_files[idx]
        image = cv2.imread(image_path)
        image = self.transform(image)

        return image, self.labels[idx]

In [23]:
class SwinVision(nn.Module):
    def __init__(self, num_classes):
        super(SwinVision, self).__init__()
        self.num_classes = num_classes
        self.swin = models.swin_v2_t(pretrained=True)
        for param in self.swin.parameters():
            param.requires_grad = False

        self.swin = nn.Sequential(self.swin, nn.Linear(1000, num_classes))

    def forward(self, x):
        return self.swin(x)

In [None]:
import torch
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt

num_classes = 2  # Replace with the actual number of classes
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = SwinVision(num_classes=2).to(device) # Load the ResNet model
num_epochs = 5
batch_size = 16

# Define the loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=1e-3, momentum=0.9, weight_decay=1e-4)

# Define the data loaders
train_dataset = RealFakeDataset(data_folder='imagenet_ai_small/ADM/train')
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

val_dataset = RealFakeDataset(data_folder='imagenet_ai_small/ADM/val')
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

# Training loop
losses = []
accuracies = []
val_losses = []
val_accuracies = []
for epoch in range(num_epochs):
    running_loss = 0.0
    accuracy = 0.0
    for i, data in enumerate(train_loader):
        # Get the inputs and labels from the data loader
        inputs, labels = data
        labels = labels.long()
        inputs = inputs.to(device)
        labels = labels.to(device)

        # Zero the parameter gradients
        optimizer.zero_grad()

        # Forward pass
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        accuracy += (outputs.argmax(1) == labels).sum()

        # Backward pass and optimization
        loss.backward()
        optimizer.step()

        # Print statistics
        running_loss += loss.item()
        if i % 100 == 99:    # Print every 100 mini-batches
            accuracy = accuracy/(batch_size*100)
            print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(train_loader)}], Loss: {running_loss/100:.4f}, Accuracy: {accuracy:.4f}')
            losses.append(running_loss/100)
            accuracies.append(accuracy)
            running_loss = 0.0
            # Evaluate the model on the validation set
            model.eval()

            val_loss = 0.0
            val_accuracy = 0.0
            with torch.no_grad():
                for data in val_loader:
                    inputs, labels = data
                    labels = labels.long()
                    inputs = inputs.to(device)
                    labels = labels.to(device)

                    outputs = model(inputs)
                    val_loss += criterion(outputs, labels).item()
                    val_accuracy += (outputs.argmax(1) == labels).sum().item()

            val_loss /= len(val_loader)
            val_accuracy /= len(val_dataset)
            val_losses.append(val_loss)
            val_accuracies.append(val_accuracy)

            print(f'Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_accuracy:.4f}')

# Plot and save the training loss and accuracy
plt.figure(figsize=(10, 5))
plt.subplot(1, 2, 1)
plt.plot(losses)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training Loss')

plt.subplot(1, 2, 2)
plt.plot(accuracies)
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.title('Training Accuracy')

plt.tight_layout()
plt.savefig('training_plot.png')

# Plot and save the validation loss and accuracy
plt.figure(figsize=(10, 5))
plt.subplot(1, 2, 1)
plt.plot(val_losses)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Validation Loss')

plt.subplot(1, 2, 2)
plt.plot(val_accuracies)
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.title('Validation Accuracy')

plt.tight_layout()
plt.savefig('validation_plot.png')

model_path = 'resnet_model.pth'
torch.save(model.state_dict(), model_path)
print('Training finished.')

Epoch [1/5], Step [100/1864], Loss: 0.6618, Accuracy: 0.6106
Validation Loss: 0.5654, Validation Accuracy: 0.7093
Epoch [1/5], Step [200/1864], Loss: 0.5547, Accuracy: 0.7148
Validation Loss: 0.5227, Validation Accuracy: 0.7576
Epoch [1/5], Step [300/1864], Loss: 0.5441, Accuracy: 0.7354
Validation Loss: 0.4927, Validation Accuracy: 0.7719
Epoch [1/5], Step [400/1864], Loss: 0.5078, Accuracy: 0.7573
Validation Loss: 0.4754, Validation Accuracy: 0.7880
Epoch [1/5], Step [500/1864], Loss: 0.4690, Accuracy: 0.7811
Validation Loss: 0.4609, Validation Accuracy: 0.7907
Epoch [1/5], Step [600/1864], Loss: 0.4921, Accuracy: 0.7730
Validation Loss: 0.4591, Validation Accuracy: 0.7907
Epoch [1/5], Step [700/1864], Loss: 0.4499, Accuracy: 0.7805
Validation Loss: 0.4445, Validation Accuracy: 0.8005
Epoch [1/5], Step [800/1864], Loss: 0.4631, Accuracy: 0.7824
Validation Loss: 0.4262, Validation Accuracy: 0.8077
Epoch [1/5], Step [900/1864], Loss: 0.4625, Accuracy: 0.7805
Validation Loss: 0.4254, Va