In [1]:
# Import necessary libraries
import os
import sys
import argparse
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms, models
from models import Model
from datasets import get_data_loaders
from train import train
from test import evaluate, load_model
from pathlib import Path
# from tensorboard_utils import setup_tensorboard

In [2]:
def parse_args():
    parser = argparse.ArgumentParser(description='Training script')
    parser.add_argument('--learning_rate', type=float, default=1e-3, help='Learning rate for the optimizer')
    parser.add_argument('--checkpoint_dir', type=str, default='./checkpoints', help='Directory to save checkpoints')
    parser.add_argument('--log_dir', type=str, default='./logs', help='Directory for TensorBoard logs')
    parser.add_argument('--train_dir', type=str, default='D:/sp_cup/dataset/valid', help='Directory for training data')
    parser.add_argument('--test_dir', type=str, default='D:/sp_cup/dataset/valid', help='Directory for testing data')
    parser.add_argument('--batch_size', type=int, default=16, help='Batch size for training')
    parser.add_argument('--epochs', type=int, default=20, help='Number of epochs to train')
    
    # Use parse_known_args to avoid errors
    args, unknown = parser.parse_known_args()
    return args


args = parse_args()

os.makedirs(args.checkpoint_dir, exist_ok=True)
os.makedirs(args.log_dir, exist_ok=True)

print(args)


Namespace(learning_rate=0.001, checkpoint_dir='./checkpoints', log_dir='./logs', train_dir='D:/sp_cup/dataset/valid', test_dir='D:/sp_cup/dataset/valid', batch_size=16, epochs=20)


In [3]:
# Setup TensorBoard
# writer = setup_tensorboard(args.log_dir)

### Initialize Model

In [4]:
# Initialize model
print("Initializing Model")

#model = Model()

model = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
num_features = model.fc.in_features  # Get the number of features from the current fc layer
model.fc = nn.Sequential(
    nn.Linear(num_features, 512),
    nn.ReLU(),
    nn.Dropout(0.5),
    nn.Linear(512, 1), # Output layer for binary classification (Fake/Real)
    nn.Sigmoid()
)

print(model)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"device: {device}")
model = model.to(device)

Initializing Model
ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=

### Loss Function and Optimizer

In [5]:
# Define loss function and optimizer
criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=args.learning_rate)

### Image Preprocessing

In [6]:
# Load datasets
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
    #transforms.Normalize([0.3996, 0.3194, 0.3223], [0.2321, 0.1766, 0.1816])
])
train_loader, valid_loader = get_data_loaders(args.train_dir, args.test_dir, args.batch_size, transform)

### Train Model

In [7]:
# Train the model
print("Start Training")
train(model, train_loader, criterion, optimizer, args.epochs, args.checkpoint_dir)

Start Training


  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/192 [00:00<?, ?it/s]

Epoch [1/20], Batch [1/192], Loss: 0.6974
Epoch [1/20], Batch [51/192], Loss: 0.2284
Epoch [1/20], Batch [101/192], Loss: 0.0915
Epoch [1/20], Batch [151/192], Loss: 0.0430
Epoch [1/20], Loss: 0.1259, Accuracy: 94.99%


  0%|          | 0/192 [00:00<?, ?it/s]

Epoch [2/20], Batch [1/192], Loss: 0.0290
Epoch [2/20], Batch [51/192], Loss: 0.0541
Epoch [2/20], Batch [101/192], Loss: 0.0355
Epoch [2/20], Batch [151/192], Loss: 0.0599
Epoch [2/20], Loss: 0.0580, Accuracy: 97.53%


  0%|          | 0/192 [00:00<?, ?it/s]

Epoch [3/20], Batch [1/192], Loss: 0.0206
Epoch [3/20], Batch [51/192], Loss: 0.0001
Epoch [3/20], Batch [101/192], Loss: 0.0001
Epoch [3/20], Batch [151/192], Loss: 0.0936
Epoch [3/20], Loss: 0.0655, Accuracy: 97.46%


  0%|          | 0/192 [00:00<?, ?it/s]

Epoch [4/20], Batch [1/192], Loss: 0.1076
Epoch [4/20], Batch [51/192], Loss: 0.0397
Epoch [4/20], Batch [101/192], Loss: 0.0009
Epoch [4/20], Batch [151/192], Loss: 0.0971
Epoch [4/20], Loss: 0.0796, Accuracy: 96.35%


  0%|          | 0/192 [00:00<?, ?it/s]

Epoch [5/20], Batch [1/192], Loss: 0.0119
Epoch [5/20], Batch [51/192], Loss: 0.0006
Epoch [5/20], Batch [101/192], Loss: 0.0322
Epoch [5/20], Batch [151/192], Loss: 0.0010
Epoch [5/20], Loss: 0.0609, Accuracy: 97.33%
Checkpoint saved at ./checkpoints\checkpoint_epoch_5.pth


  0%|          | 0/192 [00:00<?, ?it/s]

Epoch [6/20], Batch [1/192], Loss: 0.0001
Epoch [6/20], Batch [51/192], Loss: 0.0498
Epoch [6/20], Batch [101/192], Loss: 0.0057
Epoch [6/20], Batch [151/192], Loss: 0.0300
Epoch [6/20], Loss: 0.0395, Accuracy: 97.85%


  0%|          | 0/192 [00:00<?, ?it/s]

Epoch [7/20], Batch [1/192], Loss: 0.0499
Epoch [7/20], Batch [51/192], Loss: 0.0005
Epoch [7/20], Batch [101/192], Loss: 0.0007
Epoch [7/20], Batch [151/192], Loss: 0.0000
Epoch [7/20], Loss: 0.0339, Accuracy: 97.95%


  0%|          | 0/192 [00:00<?, ?it/s]

Epoch [8/20], Batch [1/192], Loss: 0.0971
Epoch [8/20], Batch [51/192], Loss: 0.0017
Epoch [8/20], Batch [101/192], Loss: 0.0002
Epoch [8/20], Batch [151/192], Loss: 0.0463
Epoch [8/20], Loss: 0.0346, Accuracy: 97.92%


  0%|          | 0/192 [00:00<?, ?it/s]

Epoch [9/20], Batch [1/192], Loss: 0.0491
Epoch [9/20], Batch [51/192], Loss: 0.2336
Epoch [9/20], Batch [101/192], Loss: 0.0292
Epoch [9/20], Batch [151/192], Loss: 0.1633
Epoch [9/20], Loss: 0.1136, Accuracy: 97.04%


  0%|          | 0/192 [00:00<?, ?it/s]

Epoch [10/20], Batch [1/192], Loss: 0.0001
Epoch [10/20], Batch [51/192], Loss: 0.2738
Epoch [10/20], Batch [101/192], Loss: 0.0912
Epoch [10/20], Batch [151/192], Loss: 0.0380
Epoch [10/20], Loss: 0.0520, Accuracy: 97.33%
Checkpoint saved at ./checkpoints\checkpoint_epoch_10.pth


  0%|          | 0/192 [00:00<?, ?it/s]

Epoch [11/20], Batch [1/192], Loss: 0.0511
Epoch [11/20], Batch [51/192], Loss: 0.0002
Epoch [11/20], Batch [101/192], Loss: 0.0421
Epoch [11/20], Batch [151/192], Loss: 0.0383
Epoch [11/20], Loss: 0.0379, Accuracy: 97.95%


  0%|          | 0/192 [00:00<?, ?it/s]

Epoch [12/20], Batch [1/192], Loss: 0.0044
Epoch [12/20], Batch [51/192], Loss: 0.0403
Epoch [12/20], Batch [101/192], Loss: 0.0363
Epoch [12/20], Batch [151/192], Loss: 0.0018
Epoch [12/20], Loss: 0.0422, Accuracy: 97.66%


  0%|          | 0/192 [00:00<?, ?it/s]

Epoch [13/20], Batch [1/192], Loss: 0.1618
Epoch [13/20], Batch [51/192], Loss: 0.0007
Epoch [13/20], Batch [101/192], Loss: 0.0264
Epoch [13/20], Batch [151/192], Loss: 0.0208
Epoch [13/20], Loss: 0.0501, Accuracy: 97.49%


  0%|          | 0/192 [00:00<?, ?it/s]

Epoch [14/20], Batch [1/192], Loss: 0.0391
Epoch [14/20], Batch [51/192], Loss: 0.1288
Epoch [14/20], Batch [101/192], Loss: 0.0001
Epoch [14/20], Batch [151/192], Loss: 0.0002
Epoch [14/20], Loss: 0.0515, Accuracy: 97.04%


  0%|          | 0/192 [00:00<?, ?it/s]

Epoch [15/20], Batch [1/192], Loss: 0.0302
Epoch [15/20], Batch [51/192], Loss: 0.0000
Epoch [15/20], Batch [101/192], Loss: 0.0111
Epoch [15/20], Batch [151/192], Loss: 0.1617
Epoch [15/20], Loss: 0.0422, Accuracy: 97.59%
Checkpoint saved at ./checkpoints\checkpoint_epoch_15.pth


  0%|          | 0/192 [00:00<?, ?it/s]

Epoch [16/20], Batch [1/192], Loss: 0.1003
Epoch [16/20], Batch [51/192], Loss: 0.0534
Epoch [16/20], Batch [101/192], Loss: 0.0449
Epoch [16/20], Batch [151/192], Loss: 0.0000
Epoch [16/20], Loss: 0.0333, Accuracy: 97.85%


  0%|          | 0/192 [00:00<?, ?it/s]

Epoch [17/20], Batch [1/192], Loss: 0.0447
Epoch [17/20], Batch [51/192], Loss: 0.0000
Epoch [17/20], Batch [101/192], Loss: 0.0000
Epoch [17/20], Batch [151/192], Loss: 0.0429
Epoch [17/20], Loss: 0.0327, Accuracy: 97.88%


  0%|          | 0/192 [00:00<?, ?it/s]

Epoch [18/20], Batch [1/192], Loss: 0.0001
Epoch [18/20], Batch [51/192], Loss: 0.0424
Epoch [18/20], Batch [101/192], Loss: 0.0000
Epoch [18/20], Batch [151/192], Loss: 0.0001
Epoch [18/20], Loss: 0.0327, Accuracy: 97.92%


  0%|          | 0/192 [00:00<?, ?it/s]

Epoch [19/20], Batch [1/192], Loss: 0.0029
Epoch [19/20], Batch [51/192], Loss: 0.0352
Epoch [19/20], Batch [101/192], Loss: 0.0018
Epoch [19/20], Batch [151/192], Loss: 0.0408
Epoch [19/20], Loss: 0.0351, Accuracy: 97.75%


  0%|          | 0/192 [00:00<?, ?it/s]

Epoch [20/20], Batch [1/192], Loss: 0.0013
Epoch [20/20], Batch [51/192], Loss: 0.1592
Epoch [20/20], Batch [101/192], Loss: 0.0214
Epoch [20/20], Batch [151/192], Loss: 0.0164
Epoch [20/20], Loss: 0.0823, Accuracy: 97.27%
Checkpoint saved at ./checkpoints\checkpoint_epoch_20.pth
Saving model to: models\model_01_resnet50.pth


### Load Model

In [8]:
loaded_model = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)

# Modify the fully connected layer to match the saved model
num_features = loaded_model.fc.in_features
loaded_model.fc = nn.Sequential(
	nn.Linear(num_features, 512),
	nn.ReLU(),
	nn.Dropout(0.5),
	nn.Linear(512, 1),
	nn.Sigmoid()
)

MODEL_SAVE_PATH = 'models/model_01_resnet50.pth'
state_dict = torch.load(MODEL_SAVE_PATH)

loaded_model.load_state_dict(state_dict)
loaded_model = loaded_model.to(device)
print(f"Model loaded from {MODEL_SAVE_PATH}")

Model loaded from models/model_01_resnet50.pth


  state_dict = torch.load(MODEL_SAVE_PATH)


In [9]:
# Set both models to evaluation mode
model.eval()
loaded_model.eval()

# Compare the state dictionaries
def compare_models(model1, model2):
    model1_dict = model1.state_dict()
    model2_dict = model2.state_dict()
    
    for key in model1_dict:
        if not torch.equal(model1_dict[key], model2_dict[key]):
            return False
    return True

# Check if the parameters are identical
are_identical = compare_models(model, loaded_model)
print(f"Are the trained model and loaded model parameters identical? {are_identical}")

Are the trained model and loaded model parameters identical? True


In [10]:
#checkpoint_path = os.path.join(args.checkpoint_dir, 'final_checkpoint.pth')
#print(f"Loading model from {checkpoint_path}")
#model = load_model(checkpoint_path, model)

evaluate(loaded_model, valid_loader)

Accuracy of the model on the validation images: 49.61%
True Negatives (Real identified as Real): 1524
False Positives (Real identified as Fake): 0
False Negatives (Fake identified as Real): 1548
True Positives (Fake identified as Fake): 0
Confusion Matrix:
[[1524    0]
 [1548    0]]
