In [None]:
!pip install webdataset

Collecting webdataset
  Downloading webdataset-1.0.2-py3-none-any.whl.metadata (12 kB)
Collecting braceexpand (from webdataset)
  Downloading braceexpand-0.1.7-py2.py3-none-any.whl.metadata (3.0 kB)
Downloading webdataset-1.0.2-py3-none-any.whl (74 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m75.0/75.0 kB[0m [31m2.5 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading braceexpand-0.1.7-py2.py3-none-any.whl (5.9 kB)


In [None]:
import kagglehub
import os
import random
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torchvision import datasets, models
from torch.utils.data import DataLoader
from sklearn.metrics import accuracy_score
from PIL import Image
import torch.nn.functional as F
import random
from huggingface_hub import hf_hub_url, HfFileSystem
import webdataset as wds
import io


In [3]:
# Use GPU if available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

Using device: cuda


In [38]:
# 6. Load pretrained ResNet-50 and modify final layer
model = models.resnet50(pretrained=True)

# Freeze all layers initially
for param in model.parameters():
    param.requires_grad = False

# Replace final fully connected layer (for 2 classes)
model.fc = nn.Linear(model.fc.in_features, 2)

# Only the final layer's parameters are trainable for now
for param in model.fc.parameters():
    param.requires_grad = True

model = model.to(device)

# 7. Loss and optimizer (only parameters with requires_grad=True are updated)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-4)



In [6]:
# 44444
def openImgaes(images):
    pil_images = []
    for img_bytes in images:
        if isinstance(img_bytes, bytes):
            # Convert bytes to PIL Image
            pil_img = Image.open(io.BytesIO(img_bytes)).convert("RGB")
            pil_images.append(pil_img)
        else:
            # If already PIL image, just append
            pil_images.append(img_bytes)

    # Now apply your transforms
    input_tensors = torch.stack([predict_transform(img) for img in pil_images])


    return input_tensors

In [5]:
# 444
predict_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.5]*3, [0.5]*3)
])


In [8]:
# 8. Training function
def train_model(model, train_dataset, val_dataset, epochs=5, device='cuda'):
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters())
    criterion = torch.nn.CrossEntropyLoss()

    for epoch in range(epochs):
        train_loader = DataLoader(train_dataset, batch_size=32)

        model.train()
        total_loss = 0
        batch_count = 0
        for inputs, labels in train_loader:
            newInputs = openImgaes(inputs)
            newInputs = newInputs.to(device)
            labels_tensor = labels_to_tensor(labels).to(device)

            optimizer.zero_grad()
            outputs = model(newInputs)
            loss = criterion(outputs, labels_tensor)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            batch_count += 1

        avg_loss = total_loss / batch_count
        print(f"Epoch [{epoch+1}/{epochs}], Loss: {avg_loss:.4f}")

        # Validation
        model.eval()
        all_preds = []
        all_labels = []
        val_loader = DataLoader(val_dataset, batch_size=32)
        with torch.no_grad():
            for inputs, labels in val_loader:
                newInputs = openImgaes(inputs)
                newInputs = newInputs.to(device)
                labels_tensor = labels_to_tensor(labels).to(device)
                outputs = model(newInputs)
                _, preds = torch.max(outputs, 1)
                all_preds.extend(preds.cpu().numpy())
                all_labels.extend(labels_tensor.cpu().numpy())

        acc = accuracy_score(all_labels, all_preds)
        print(f"Validation Accuracy: {acc*100:.2f}%\n")


In [7]:
# 444
def labels_to_tensor(labels):
    label_to_idx = {
      "fake": 0,
       "real": 1
    }
    # labels is iterable of strings
    numeric_labels = [label_to_idx[l] for l in labels]
    return torch.tensor(numeric_labels, dtype=torch.long)

In [10]:
def unfreeze_model(model, unfreeze_from_layer=6):
    # ResNet layers: layer1, layer2, layer3, layer4
    # unfreeze_from_layer: number between 1 and 4 to unfreeze from that layer onwards
    layers = [model.layer1, model.layer2, model.layer3, model.layer4]

    for param in model.parameters():
        param.requires_grad = False  # Freeze all first

    for param in model.fc.parameters():
        param.requires_grad = True  # Always train final fc layer

    # Unfreeze from specified layer onwards
    for i in range(unfreeze_from_layer - 1, len(layers)):
        for param in layers[i].parameters():
            param.requires_grad = True

    print(f"Unfroze layers from layer{unfreeze_from_layer} onwards")

In [11]:
# 13. unfreeze deeper layers to fine-tune
unfreeze_model(model, unfreeze_from_layer=3)

Unfroze layers from layer3 onwards


In [13]:
# 444
from torch.utils.data import DataLoader

def get_label_from_key(key_str):

    if 'fake' in key_str:
        return 'fake'
    elif 'real' in key_str:
        return 'real'
    else:
        return 'unknown'

def preprocess(sample):
    image = sample.get('png') or sample.get('jpg') or sample.get('tiff')
    key = sample.get('__key__', '')
    label = get_label_from_key(key)
    return image, label


In [14]:
# Replace "your_own_huggingface_token" with your actual Hugging Face access token
# Get one at: https://huggingface.co/settings/tokens
myTtoken = "your_own_huggingface_token"

os.environ["HF_TOKEN"] = myTtoken   
    "train_fake": "**/fake_train/*.tar.gz",
    "train_real": "**/real_train/*.tar.gz",
    "test_fake":  "**/fake_test/*.tar.gz",
    "test_real":  "**/real_test/*.tar.gz"
}

def get_urls(split_pattern):
    fs = HfFileSystem()
    files = [fs.resolve_path(path) for path in fs.glob("hf://datasets/xingjunm/WildDeepfake/" + split_pattern)]
    return [hf_hub_url(file.repo_id, file.path_in_repo, repo_type="dataset") for file in files]

def make_ds(urls):
    urls_pipe = f"pipe: curl -s -L -H 'Authorization: Bearer {myTtoken}' {'::'.join(urls)}"
    return wds.WebDataset(urls_pipe, shardshuffle=False).decode()


In [16]:
train_fake_urls = get_urls(splits["train_fake"])
train_real_urls = get_urls(splits["train_real"])
test_fake_urls  = get_urls(splits["test_fake"])
test_real_urls  = get_urls(splits["test_real"])

random.seed(42)  # For reproducibility
random.shuffle(train_fake_urls)
random.shuffle(train_real_urls)
random.shuffle(test_fake_urls)
random.shuffle(test_real_urls)

In [None]:
print(len(train_fake_urls), len(train_real_urls), len(test_fake_urls), len(test_real_urls), sep='\n')

592
371
115
42


In [17]:
max_samples_training =4000
max_smaples_test= 2000

In [29]:
fake_data_set = []
max_samples_per_url = 70
for url in train_fake_urls:
  if len(fake_data_set) >= max_samples_training:
    break
  train_fake =make_ds([url])
  train_fake_processed = list(train_fake.map(preprocess))
  random.shuffle(train_fake_processed)
  for image,label in train_fake_processed:
    if len(fake_data_set) >= max_samples_training:
      break
    fake_data_set.append((image,label))

fake_train = fake_data_set
random.shuffle(fake_train)


In [30]:
real_data_set = []
max_samples_per_url = 140
for url in train_real_urls:
  if len(real_data_set) >= max_samples_training:
    break
  train_real =make_ds([url])
  train_real_processed = list(train_real.map(preprocess))
  random.shuffle(train_real_processed)
  for image,label in train_real_processed:
    if len(real_data_set) >= max_samples_training:
      break
    real_data_set.append((image,label))


real_train = real_data_set
random.seed(5)
random.shuffle(real_train)


In [28]:
_test_fake_data_set = []
max_samples_per_url = 180
for url in test_fake_urls:
  if len(_test_fake_data_set) >= max_smaples_test:
    break
  test_fake =make_ds([url])
  test_fake_processed = list(test_fake.map(preprocess))
  random.shuffle(test_fake_processed)
  for image,label in test_fake_processed:
    if len(_test_fake_data_set) >= max_smaples_test:
      break
    _test_fake_data_set.append((image,label))



random.seed(13)
random.shuffle(_test_fake_data_set)

test_fake_data_set,fake_val = _test_fake_data_set[:int(len(_test_fake_data_set)*0.5)], _test_fake_data_set[int(len(_test_fake_data_set)*0.5):]

random.shuffle(test_fake_data_set)
random.seed(100)
random.shuffle(fake_val)

In [32]:
train_set = fake_train + real_train
test_set = test_real_data_set + test_fake_data_set
val_set = fake_val + real_val

random.seed(58)
random.shuffle(train_set)
random.seed(1)
random.shuffle(test_set)
random.seed(99)
random.shuffle(val_set)

In [31]:
_test_real_data_set = []
max_samples_per_url = 500
for url in test_real_urls:
  if len(_test_real_data_set) >= max_smaples_test:
    break
  test_real =make_ds([url])
  test_real_processed = list(test_real.map(preprocess))
  random.shuffle(test_real_processed)
  for image,label in test_real_processed:
    if len(_test_real_data_set) >= max_smaples_test:
      break
    _test_real_data_set.append((image,label))


random.seed(67)
random.shuffle(_test_real_data_set)

test_real_data_set,real_val = _test_real_data_set[:int(len(_test_real_data_set)*0.5)], _test_real_data_set[int(len(_test_real_data_set)*0.5):]
random.shuffle(test_real_data_set)

random.seed(87)
random.shuffle(real_val)

In [39]:
train_model(model, train_set, val_set,epochs=4)

Epoch [1/4], Loss: 0.1974
Validation Accuracy: 86.80%

Epoch [2/4], Loss: 0.0695
Validation Accuracy: 81.60%

Epoch [3/4], Loss: 0.0493
Validation Accuracy: 81.20%

Epoch [4/4], Loss: 0.0399
Validation Accuracy: 80.40%



In [40]:
model.eval()
all_preds = []
all_labels = []
val_loader = DataLoader(test_set, batch_size=32)
with torch.no_grad():
    for inputs, labels in val_loader:
        newInputs = openImgaes(inputs)
        newInputs = newInputs.to(device)
        labels_tensor = labels_to_tensor(labels).to(device)
        outputs = model(newInputs)
        _, preds = torch.max(outputs, 1)
        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels_tensor.cpu().numpy())

acc = accuracy_score(all_labels, all_preds)
print(f"test Accuracy: {acc*100:.2f}%\n")

test Accuracy: 80.80%



In [41]:
print(model)

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 