In [1]:
import torch
import torchvision
import os
import numpy as np
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
import cv2
import copy

In [2]:
from sklearn.metrics import accuracy_score, precision_score, recall_score
device = torch.device("cuda:3" if torch.cuda.is_available() else "cpu")

In [3]:
# split dataset
posA = "dataset/celeA"
posB = "dataset/flickr"
negA = "dataset/stylegan"
negB = "dataset/fake1000"

In [4]:
traindata = [os.path.join(posA,filename)for filename in os.listdir(posA)][:800] + \
            [os.path.join(posB,filename)for filename in os.listdir(posB)][:800] + \
            [os.path.join(negA,filename)for filename in os.listdir(negA)][:800] + \
            [os.path.join(negB,filename)for filename in os.listdir(negB)][:800]
trainlabel = [1]*1600 + [0]*1600

validdata = [os.path.join(posA,filename)for filename in os.listdir(posA)][800:900] + \
            [os.path.join(posB,filename)for filename in os.listdir(posB)][800:900] + \
            [os.path.join(negA,filename)for filename in os.listdir(negA)][800:900] + \
            [os.path.join(negB,filename)for filename in os.listdir(negB)][800:900]
validlabel = [1]*200 + [0]*200

testdata  = [os.path.join(posA,filename)for filename in os.listdir(posA)][900:] + \
            [os.path.join(posB,filename)for filename in os.listdir(posB)][900:] + \
            [os.path.join(negA,filename)for filename in os.listdir(negA)][900:] + \
            [os.path.join(negB,filename)for filename in os.listdir(negB)][900:]
testlabel = [1]*200 + [0]*200

len(traindata),len(validdata),len(testdata),len(trainlabel),len(validlabel),len(testlabel)

(3200, 400, 400, 3200, 400, 400)

In [5]:
def SRMConv(img):
    filter1 = np.array([[0,0,0,0,0],
                        [0,-1,2,-1,0],
                        [0,2,-4,2,0],
                        [0,-1,2,-1,0],
                        [0,0,0,0,0]]) /4
    filter2 = np.array([[-1,2,-2,2,-1],
                        [2,-6,8,-6,2],
                        [-2,8,-12,8,-2],
                        [2,-6,8,-6,2],
                        [-1,2,-2,2,-1]]) / 12
    filter3 = np.array([[0,0,0,0,0],
                        [0,0,0,0,0],
                        [0,1,-2,1,0],
                        [0,0,0,0,0],
                        [0,0,0,0,0]]) / 2
    dst1 = np.sum(cv2.filter2D(img, cv2.CV_32F, filter1),axis=2)
    dst2 = np.sum(cv2.filter2D(img, cv2.CV_32F, filter2),axis=2)
    dst3 = np.sum(cv2.filter2D(img, cv2.CV_32F, filter3),axis=2)
    
    #noise = np.concatenate((dst1,dst2,dst3),axis=2)
    noise = np.dstack((dst1,dst2,dst3))
    return noise

In [6]:
class DeepfakeData(Dataset):
    def __init__(self, pathList, labelList, transform=None):
        self.pathList = pathList
        self.labelList = labelList
        self.transform = transform

    def __len__(self):
        return len(self.pathList)
    
    def __getitem__(self,index):
        
        img = cv2.imread(self.pathList[index])
        raw_img = img
        img = SRMConv(img)
        label = self.labelList[index]
        
        if self.transform:
            img = self.transform(img)
            raw_img = self.transform(raw_img)
            
        return (raw_img, img), label


In [7]:
transform = transforms.Compose([transforms.ToTensor()])
traindataset = DeepfakeData(traindata, trainlabel, transform)
validdataset = DeepfakeData(validdata, validlabel, transform)
testdataset = DeepfakeData(testdata, testlabel, transform)

In [8]:
batchsize = 4
trainloader = DataLoader(traindataset, batch_size=batchsize, shuffle=True, num_workers=4)
validloader = DataLoader(validdataset)
testloader = DataLoader(testdataset)

### model

In [9]:
import resnet_cam
class twoStreamCNN(torch.nn.Module):
    def __init__(self):
        super(twoStreamCNN,self).__init__()
        self.model1 = resnet_cam.resnet18(pretrained=True)
        self.model1.fc = torch.nn.Linear(in_features=512, out_features=2, bias=True)
        self.model2 = resnet_cam.resnet18(pretrained=True)
        self.model2.fc = torch.nn.Linear(in_features=512, out_features=2, bias=True)
        self.fc = torch.nn.Linear(1024,2)
        self.avgpool = torch.nn.AdaptiveAvgPool2d((1, 1))
    
    def forward(self,x1,x2):
        _,x1 = self.model1(x1)
        _,x2 = self.model2(x2)
        x = torch.cat((x1,x2),1)
        x = self.avgpool(x)
        x = x.reshape(x.size(0), -1)
        x.size()
        return x

In [10]:
net = twoStreamCNN()
#net.cuda()
net.to(device)
epoch_max = 50
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

In [12]:
acc_max = 0
for epoch in range(epoch_max):
    loss_epoch = 0
    loss_count = 0
    net.train()
    for i, ((img1,img2),label) in enumerate(trainloader):
        img1 = img1.to(device)
        img2 = img2.to(device)
        label = torch.tensor(label,dtype=torch.int64).to(device)

        optimizer.zero_grad()
        output = net(img1,img2)
        loss = criterion(output, label)
        loss.backward()
        optimizer.step()
        
        loss_epoch += loss
        loss_count += 1
        
        if i*batchsize%400 == 0:
            print("Epoch:{} Loss:{} iter:{}".format(epoch, loss_epoch/loss_count, i*batchsize))
            loss_epoch = 0
            loss_count = 0
            
    
            
    loss_epoch = 0
    loss_count = 0
    predict = []
    expect = []
    net.eval()
    for i, ((img1,img2),label) in enumerate(validloader):
        img1 = img1.to(device)
        img2 = img2.to(device)
        label = torch.tensor(label,dtype=torch.int64).to(device)
        with torch.no_grad():
            output = net(img1,img2)
        loss = criterion(output, label)
        
        loss_epoch += loss
        loss_count += 1
        
        predict.append(np.argmax(output.cpu().numpy()))
        expect.append(label.cpu().numpy().item())
    acc = accuracy_score(expect, predict)
    
    if acc_max < acc:
        print("!")
        acc_max = acc
        best_model_wts = copy.deepcopy(net.state_dict())
        best_epoch = epoch
    
    print("EVALUATION\nEpoch:{} Loss:{} acc:{}".format(epoch, loss_epoch/loss_count, acc))
    print("##########################################################################")

  if __name__ == '__main__':


Epoch:0 Loss:0.9625416994094849 iter:0
Epoch:0 Loss:0.6921111941337585 iter:400
Epoch:0 Loss:0.6977041959762573 iter:800
Epoch:0 Loss:0.45439285039901733 iter:1200
Epoch:0 Loss:0.4538036584854126 iter:1600
Epoch:0 Loss:0.37040990591049194 iter:2000
Epoch:0 Loss:0.4453166127204895 iter:2400
Epoch:0 Loss:0.43949562311172485 iter:2800




!
EVALUATION
Epoch:0 Loss:0.21836665272712708 acc:0.9625
##########################################################################
Epoch:1 Loss:0.19982028007507324 iter:0
Epoch:1 Loss:0.31860360503196716 iter:400
Epoch:1 Loss:0.2967146039009094 iter:800
Epoch:1 Loss:0.31556326150894165 iter:1200
Epoch:1 Loss:0.25907251238822937 iter:1600
Epoch:1 Loss:0.2776937484741211 iter:2000
Epoch:1 Loss:0.26153188943862915 iter:2400
Epoch:1 Loss:0.23635952174663544 iter:2800
!
EVALUATION
Epoch:1 Loss:0.10613177716732025 acc:0.9725
##########################################################################
Epoch:2 Loss:0.04279494285583496 iter:0
Epoch:2 Loss:0.2767707407474518 iter:400
Epoch:2 Loss:0.2682432234287262 iter:800
Epoch:2 Loss:0.22415092587471008 iter:1200
Epoch:2 Loss:0.2369365692138672 iter:1600
Epoch:2 Loss:0.18271483480930328 iter:2000
Epoch:2 Loss:0.22947832942008972 iter:2400
Epoch:2 Loss:0.19693158566951752 iter:2800
!
EVALUATION
Epoch:2 Loss:0.09857143461704254 acc:0.975
#######

Epoch:18 Loss:0.0351397879421711 iter:400
Epoch:18 Loss:0.048843372613191605 iter:800
Epoch:18 Loss:0.03718214109539986 iter:1200
Epoch:18 Loss:0.025898341089487076 iter:1600
Epoch:18 Loss:0.03994334861636162 iter:2000
Epoch:18 Loss:0.021865949034690857 iter:2400
Epoch:18 Loss:0.038187041878700256 iter:2800
EVALUATION
Epoch:18 Loss:0.01715449057519436 acc:0.995
##########################################################################
Epoch:19 Loss:0.23639488220214844 iter:0
Epoch:19 Loss:0.03324906900525093 iter:400
Epoch:19 Loss:0.035276900976896286 iter:800
Epoch:19 Loss:0.04554973542690277 iter:1200
Epoch:19 Loss:0.029206829145550728 iter:1600
Epoch:19 Loss:0.03499673679471016 iter:2000
Epoch:19 Loss:0.038566794246435165 iter:2400
Epoch:19 Loss:0.02994343638420105 iter:2800
EVALUATION
Epoch:19 Loss:0.01659964956343174 acc:0.9975
##########################################################################
Epoch:20 Loss:0.002647876739501953 iter:0
Epoch:20 Loss:0.04231509566307068 iter

Epoch:35 Loss:0.028156591579318047 iter:400
Epoch:35 Loss:0.01650088094174862 iter:800
Epoch:35 Loss:0.03002168983221054 iter:1200
Epoch:35 Loss:0.016120582818984985 iter:1600
Epoch:35 Loss:0.014049186371266842 iter:2000
Epoch:35 Loss:0.015317440032958984 iter:2400
Epoch:35 Loss:0.027850951999425888 iter:2800
EVALUATION
Epoch:35 Loss:0.012398185208439827 acc:0.995
##########################################################################
Epoch:36 Loss:0.0002503395080566406 iter:0
Epoch:36 Loss:0.04084049165248871 iter:400
Epoch:36 Loss:0.018076743930578232 iter:800
Epoch:36 Loss:0.02652854472398758 iter:1200
Epoch:36 Loss:0.0220942422747612 iter:1600
Epoch:36 Loss:0.02446967363357544 iter:2000
Epoch:36 Loss:0.038156960159540176 iter:2400
Epoch:36 Loss:0.026872029528021812 iter:2800
EVALUATION
Epoch:36 Loss:0.006831629201769829 acc:1.0
##########################################################################
Epoch:37 Loss:0.0015168190002441406 iter:0
Epoch:37 Loss:0.022399842739105225 

In [13]:
torch.save(best_model_wts,"twostream_Epoch{}_validAcc{}".format(best_epoch,acc_max))