In [None]:
import sys
sys.path.append("..")

import random
from glob import glob 
from PIL import Image
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt

import torch
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F

from arch import utils, ela, net

In [None]:
# Ask the project helper for the device; later cells move the model and every
# batch to it with .to(DEVICE).
DEVICE = utils.device_mapper()
device_name = str(DEVICE).upper()
print(f"Device: {device_name}")

---

# DATASET

In [None]:
# Root directory of the image dataset; the cells below glob image files beneath it.
# The commented line is an alternative dataset root kept for reference.
# DATASET = "/mnt/artemis/library/datasets/image-forgery"
DATASET = "/mnt/artemis/library/datasets/casia"

In [None]:
# Inventory the dataset: list the entries below DATASET once, then tally them
# by extension from that single listing (no extra filesystem scan per type).
# NOTE: glob is called without recursive=True, so '**' behaves like '*' and
# only entries exactly two levels down (<DATASET>/<folder>/<entry>) match.
total_files = glob(DATASET + '/**/*')
print(f"Total count: {len(total_files)}")

# Read the extension from the entry's own name, so a dot inside a directory
# name is never mistaken for one; entries without a dot are grouped under ''.
type_counts = {}
for file in total_files:
    name = file.replace('\\', '/').rsplit('/', 1)[-1]
    ext = name.rsplit('.', 1)[-1] if '.' in name else ''
    type_counts[ext] = type_counts.get(ext, 0) + 1

types = set(type_counts)
print(f"Types: {types}")

for i in types:
    print(f"{i}: {type_counts[i]}")

In [None]:
# IMAGE FORGERY DATASET

# Au = glob(DATASET + "/Original/*.jpg") + glob(DATASET + "/Original/*.JPG") + glob(DATASET + "/Original/*.tif") # No PNG BMP
# Tp = glob(DATASET + "/Forged/*.jpg") + glob(DATASET + "/Forged/*.png") + glob(DATASET + "/Forged/*.tif") #  # No JPG 

# print(f"Au files: {len(Au)}")
# print(f"Tp files: {len(Tp)}")

# Ds = Au + Tp
# print(f"Ds files: {len(Ds)}")

In [None]:
# CASIA DATASET
# Au/ holds the original images (JPEG only here); Tp/ holds the forged ones,
# collected as JPEG first and TIFF second.
Au = glob(f"{DATASET}/Au/Au*.jpg")
Tp = [path for ext in ("jpg", "tif") for path in glob(f"{DATASET}/Tp/Tp*.{ext}")]

print(f"Au files: {len(Au)}")
print(f"Tp files: {len(Tp)}")

# Working set: the first 5123 original paths followed by every forged path.
# NOTE(review): 5123 presumably matches the number of forged files to balance
# the classes - confirm against the "Tp files" count printed above.
Ds = [*Au[:5123], *Tp]
print(f"Ds files: {len(Ds)}")

In [None]:
# Preview a random r x c grid of 64x64 thumbnails drawn from the working set.
r = 5
c = 25

sampled_paths = random.sample(Ds, r*c)

# Force RGB so grayscale, palette or RGBA files produce the same (64, 64, 3)
# array as everything else; mixed shapes would make np.concatenate fail.
# The context manager closes each file as soon as the thumbnail is built.
sample = []
for path in sampled_paths:
    with Image.open(path) as img:
        thumb = img.convert('RGB').resize((64, 64))
    sample.append(np.array(thumb, dtype=np.float32) / 255.0)

# Join c thumbnails side by side per row, then stack the r rows vertically.
result = np.concatenate([np.concatenate(sample[i*c:(i+1)*c], axis=1) for i in range(r)])

plt.figure(figsize=(c*2, r*2))
plt.imshow(result)
plt.axis(False)
plt.tight_layout();

In [None]:
def split_data(x, p, shuffle=True):
    """Split a sequence into two lists, the first holding ``p`` percent of the items.

    Args:
        x: Sequence of items to split (image paths in this notebook). It is
            copied, so the caller's sequence is never reordered.
        p: Percentage of items that go into the first part.
        shuffle: When true, the copy is randomly reordered with the ``random``
            module before it is cut.

    Returns:
        Tuple ``(first, rest)`` of lists; ``first`` has ``int(len(x) * p / 100)``
        items and ``rest`` has the remainder.
    """
    items = list(x)
    # Shuffle the data that was passed in, not a notebook-level global.
    if shuffle: random.shuffle(items)
    # Multiply before dividing: (len / 100) * p can land just below an integer
    # (29 / 100 * 100 == 28.999...) and silently drop one item.
    bound = int(len(items) * p / 100)
    return items[:bound], items[bound:]

In [None]:
class CustomDataset(Dataset):
    """Dataset that yields ``(sample, label)`` pairs from a list of image paths.

    The label comes from the path: 0 when it contains ``Au/Au``, otherwise 1.

    Args:
        image_paths: Sequence of image file paths.
        transform: Callable applied to the PIL image to produce the sample.
        filter: Optional preprocessing name. ``'ela'`` passes the image through
            ``ela.compute(image, 90)`` and ``utils.array2image`` before the
            transform; any other value applies the transform directly.
        augment: When true (the default, matching the previous behaviour) each
            image is mirrored horizontally with probability 0.5. Pass ``False``
            for deterministic evaluation.
    """
    def __init__(self, image_paths, transform, filter=None, augment=True):
        super().__init__()
        self.paths = image_paths
        self.len = len(self.paths)
        self.filter = filter
        self.augment = augment

        self.transform = transform

    def __len__(self): return self.len

    def __getitem__(self, idx):
        """Load, optionally flip and filter, then transform the image at ``idx``."""
        path = self.paths[idx]
        # Close the file handle right away; convert() returns a loaded copy
        # that no longer depends on the open file.
        with Image.open(path) as img:
            x = img.convert('RGB')
        if self.augment and random.random() > 0.5: x = x.transpose(Image.FLIP_LEFT_RIGHT)

        if self.filter == 'ela':
            x = ela.compute(x, 90)
            x = self.transform(utils.array2image(x))
        else:
            x = self.transform(x)

        # y = 0 if 'Original' in path else 1
        # Normalise separators so the 'Au/Au' test also matches Windows paths.
        y = 0 if 'Au/Au' in str(path).replace('\\', '/') else 1

        return (x, y)

In [None]:
def denormalize(x, mean=(0.5,), std=(0.5,)):
    """Undo a per-channel normalization by computing ``x * std + mean``.

    Args:
        x: Tensor whose channel axis is third from last, so the
            ``(-1, 1, 1)``-shaped statistics broadcast over it.
        mean: Per-channel means, or a single value shared by all channels.
        std: Per-channel standard deviations, or a single shared value.

    Returns:
        Tensor with the same shape as ``x``.
    """
    # Build the statistics on x's own device so the function also works on
    # tensors that do not live on the CPU.
    mean = torch.tensor(mean, device=x.device).view(-1, 1, 1)
    std = torch.tensor(std, device=x.device).view(-1, 1, 1)
    return x * std + mean

In [None]:
# 95% of the paths go to the training split, the remainder to the test split.
train, test = split_data(Ds, 95)

# Au_count = sum('Original' in path for path in train)
# Tp_count = sum('Forged' in path for path in train)

# Class balance of the training split, judged by a substring match on each path.
Au_count = len([path for path in train if 'Au' in path])
Tp_count = len([path for path in train if 'Tp' in path])

print(f"Original: {Au_count} | Forged: {Tp_count}")

In [None]:
# Preprocessing shared by every dataset in this notebook: resize to 224x224,
# convert to a tensor, then normalize with mean 0.5 and std 0.5.
preprocess_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
]
transform = transforms.Compose(preprocess_steps)

In [None]:
# Training split with the ELA filter applied to every image.
train_ds = CustomDataset(train, transform, filter='ela')
# shuffle=True draws a fresh random order each epoch; without it every epoch
# replays exactly the same batches in the same order.
train_dl = DataLoader(train_ds, batch_size=128, shuffle=True)
print(f"(Train) Images: {len(train_ds)} | Batches: {len(train_dl)}")

In [None]:
# Pull one batch and show its first five images (denormalized) with their labels.
x, y = next(iter(train_dl))
print(f"Image batch shape: {x.shape}")
# print(f"Batch of labels: {y}")

plt.figure(figsize=(25, 5))
for i in range(5):
    label = y[i]
    # Channels-last layout is what imshow expects.
    pixels = denormalize(x[i]).permute(1, 2, 0).numpy()
    plt.subplot(1, 5, i + 1)
    plt.imshow(pixels)
    plt.title(f"Real [{label}]" if label.numpy() == 0 else f"Edited [{label}]")
    plt.axis(False), plt.tight_layout()

---

# MODEL

In [None]:
# Build the classifier and move it to the selected device.
model = net.CNN()
# model = net.ResNet(img_channels=3, num_layers=18, block=BasicBlock, num_classes=2) # 18, 34, 50, 101, 152
model.to(DEVICE)

# Count parameters straight from a fresh iterator instead of keeping an
# exhausted generator around in a notebook-level variable.
print("Nparams:", sum(p.nelement() for p in model.parameters()))

# Display the architecture without calling model.eval(): that call would switch
# the model to evaluation mode right before the training loop runs.
model

In [None]:
optimizer = torch.optim.AdamW(model.parameters(), lr = 0.001)

epochs = 5
# Per-batch history, plotted after training.
lossi, accui = [], []

# Put the model in training mode explicitly: an earlier cell may have left it
# in eval mode, which changes how layers such as dropout and batch-norm behave.
model.train()
for e in range(epochs):
    with tqdm(train_dl, unit='batch') as tepoch:
        # The description is constant for the whole epoch, so set it once.
        tepoch.set_description(f"Epoch {e+1}")
        for xb, yb in tepoch:
            xb, yb = xb.to(DEVICE), yb.to(DEVICE)
            logits = model(xb)
            loss = F.cross_entropy(logits, yb)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Accuracy of this batch, from the logits computed before the step.
            accuracy = ((logits.argmax(dim=1) == yb).float().mean())

            lossi.append(loss.item())
            accui.append(accuracy.item())
            tepoch.set_postfix(loss=loss.item(), accuracy=accuracy.item())

In [None]:
from datetime import datetime

# Save only the weights (state_dict) under a timestamped name in the working
# directory; model_path is reused by the test section to reload them.
timestamp = datetime.now().strftime('%Y%m%d%H%M%S')
model_path = f"bce_{timestamp}.pth"
torch.save(model.state_dict(), model_path)

In [None]:
# Per-batch training loss and accuracy drawn on one shared axis.
plt.figure(figsize=(25, 5))
for series, name in ((lossi, 'loss'), (accui, 'accuracy')):
    plt.plot(series, label=name, lw=0.5)
plt.xlim([0, len(lossi)])
plt.grid(alpha=0.25)
plt.legend(), plt.tight_layout()

---

# TEST / INFER

In [None]:
# Rebuild the architecture and restore the weights saved after training.
model = net.CNN()
# model = net.ResNet(img_channels=3, num_layers=18, block=BasicBlock, num_classes=2)
# map_location remaps the stored tensors onto the current device, so a
# checkpoint written on a GPU also loads on a machine that has none.
model.load_state_dict(torch.load(model_path, map_location=DEVICE))
model.to(DEVICE)
# eval() switches the model to evaluation mode and returns it for printing.
print(model.eval())

In [None]:
# Held-out split, preprocessed the same way as the training data (ELA + transform).
test_ds = CustomDataset(test, transform=transform, filter='ela')
test_dl = DataLoader(test_ds, batch_size=128)
n_test_images, n_test_batches = len(test_ds), len(test_dl)
print(f"(Test) Images: {n_test_images} | Batches: {n_test_batches}")

In [None]:
# Evaluate on the held-out split. Loss and accuracy are accumulated per sample:
# averaging per-batch means would over-weight a smaller final batch.
model.eval()
loss_sum, correct, seen = 0.0, 0, 0
with torch.no_grad():
    with tqdm(test_dl, unit='batch') as tepoch:
        tepoch.set_description(f"Testing..")
        for xt, yt in tepoch:
            xt, yt = xt.to(DEVICE), yt.to(DEVICE)
            logits = model(xt)

            # reduction='sum' gives the total loss of the batch, not its mean.
            loss_sum += F.cross_entropy(logits, yt, reduction='sum').item()
            correct += (logits.argmax(dim=1) == yt).sum().item()
            seen += yt.size(0)

# max(seen, 1) avoids a ZeroDivisionError when the test split is empty.
loss = loss_sum / max(seen, 1)
accuracy = correct / max(seen, 1)

print(f"Test loss: {loss:.4f}, Test accuracy: {accuracy * 100:.2f}%")

In [None]:
# Qualitative check: classify up to ten random test images and show each source
# image (top row) above the tensor the model actually received (bottom row).
sample = random.sample(test, min(10, len(test)))
sample_ds = CustomDataset(sample, transform, filter='ela')
sample_dl = DataLoader(sample_ds, batch_size=len(sample_ds))

xt, yt = next(iter(sample_dl))
xt, yt = xt.to(DEVICE), yt.to(DEVICE)
# Inference only: skip building the autograd graph.
with torch.no_grad():
    logits = model(xt)
# Probability assigned to class 1 (forged) for each image.
preds = F.softmax(logits, dim=1)[:, 1].tolist()

# squeeze=False keeps axs two-dimensional even when only one image is shown.
fig, axs = plt.subplots(2, len(sample_ds), figsize=(len(sample_ds) * 5, 12), squeeze=False)
for idx, model_input in enumerate(xt):
    pred = preds[idx]
    true_label = yt[idx].item()
    pred_label = int(pred > 0.5)
    label = "Forged" if pred_label == 1 else "Original"
    # The top row is re-read from disk, so it never shows the dataset's random
    # horizontal flip that the bottom row may contain.
    with Image.open(sample[idx]) as src:
        img = src.convert('RGB').resize((224, 224))
    axs[0, idx].imshow(img)
    # Red title marks a misclassification (plain ints compared, no tensors).
    title_color = "red" if pred_label != true_label else "black"
    axs[0, idx].set_title(f"{true_label} | {pred:.5f} | {label}", color=title_color)
    axs[0, idx].axis(False)
    shown = denormalize(model_input.cpu())
    axs[1, idx].imshow(shown.permute(1, 2, 0).numpy())
    axs[1, idx].axis(False)
plt.tight_layout()