## train.ipynb


In [None]:
# Mount Google Drive; every dataset/model path used below lives under /content/drive/My Drive.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Show which GPU (if any) this Colab runtime was assigned.
!nvidia-smi

Fri Jun 17 09:01:47 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.32.03    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   46C    P8     9W /  70W |      0MiB / 15109MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [None]:
import easydict
import os
import sys
from PIL import Image
import tqdm
import shutil
import cv2

import torch
import numpy as np
import random
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import lr_scheduler
import torchvision.transforms as transforms
import torch.utils.data as data
import torch.backends.cudnn as cudnn
from torchvision import transforms
from google.colab.patches import cv2_imshow

In [None]:
import torch
import torch.nn as nn
from torch.autograd import Function
import torch.nn.functional as F
import random
from torchvision.models import VGG
from torchvision.models.vgg import make_layers

In [None]:
seed = 719

# Python-level RNGs.
random.seed(seed)
np.random.seed(seed)
# NOTE: setting PYTHONHASHSEED at runtime does not re-seed str hashing in the
# already-running kernel; it only reaches processes started afterwards.
os.environ['PYTHONHASHSEED'] = str(seed)

# PyTorch RNGs: the CPU generator and the current CUDA device's generator.
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

# cuDNN: prefer reproducible kernels over autotuned (possibly non-deterministic) ones.
cudnn.deterministic = True
cudnn.benchmark = False

## 공격함수

In [None]:
# Source image folders (original, un-attacked images) for each class/split.
(real_train_path, fake_train_path,
 real_valid_path, fake_valid_path) = (
    "/content/drive/My Drive/deepfake/real/train/",
    "/content/drive/My Drive/deepfake/fake/train/",
    "/content/drive/My Drive/deepfake/real/valid/",
    "/content/drive/My Drive/deepfake/fake/valid/",
)

In [None]:
# Output folders for the attacked ("noise") copies of each class/split.
save_realnoise_train_path = "/content/drive/My Drive/deepfake/new/train/real/"
save_fakenoise_train_path = "/content/drive/My Drive/deepfake/new/train/fake/"
save_realnoise_valid_path = "/content/drive/My Drive/deepfake/new/valid/real/"
save_fakenoise_valid_path = "/content/drive/My Drive/deepfake/new/valid/fake/"

# cv2.imwrite does not create directories (it just returns False), so make sure
# every output folder exists before the attack step writes into it.
for _noise_dir in (save_realnoise_train_path, save_fakenoise_train_path,
                   save_realnoise_valid_path, save_fakenoise_valid_path):
    os.makedirs(_noise_dir, exist_ok=True)
del _noise_dir

In [None]:
# Sorted listings: os.listdir order depends on the filesystem, and the seeded
# random attack parameters are drawn per image in list order, so sorting makes
# the generated dataset reproducible.
real_train_list = sorted(os.listdir(real_train_path))
fake_train_list = sorted(os.listdir(fake_train_path))
real_valid_list = sorted(os.listdir(real_valid_path))
fake_valid_list = sorted(os.listdir(fake_valid_path))

In [None]:
# Sanity check: number of source images per class/split, one per line.
print(*map(len, (real_train_list, fake_train_list, real_valid_list, fake_valid_list)), sep='\n')

2500
2500
900
900


In [None]:
# Number of attacked images already present in each output folder.
for _out_dir in (save_realnoise_train_path, save_fakenoise_train_path,
                 save_realnoise_valid_path, save_fakenoise_valid_path):
    print(len(os.listdir(_out_dir)))
del _out_dir

0
0
0
0


In [None]:
# def get_attack_params(attack_name):
#     attack_dict = {
#         'gaussian_noise': {'noise_std': (30, 300)},
#         'sharpening' : {'center' : (1, 5)},
#         'jpeg' : {'quality' : (5, 95)},
#         'gaussian_blur' : {'ksize' : (2, 20)},
#         # 'unsharp_mask' : {'amount' : (2, 20)},
#         'universal_perturbation' : {'level' : (0.05, 0.5)}
#     }

#     return attack_dict[attack_name.lower()]

In [None]:
def get_attack_params(attack_name):
    """Look up the random-parameter range for one attack.

    Args:
        attack_name: attack identifier, matched case-insensitively.

    Returns:
        A new dict mapping the attack's parameter name to a ``(low, high)``
        range; callers sample from it with ``np.random.uniform``.

    Raises:
        KeyError: if ``attack_name`` is not a known attack.
    """
    ranges = dict(
        gaussian_noise={'noise_std': (30, 210)},
        sharpening={'center': (0.5, 3.5)},
        jpeg={'quality': (35, 90)},
        median_blur={'ksize': (2, 15)},
        universal_perturbation={'level': (0.05, 0.35)},
    )
    key = attack_name.lower()
    return ranges[key]

In [None]:
class Sharpening(Function):
    """Sharpen an image with a randomly weighted 3x3 Laplacian-style kernel.

    Used on NumPy images via ``Sharpening.apply(image)``; no autograd graph is
    built. The kernel weight ``i`` is drawn from
    ``get_attack_params('sharpening')['center']``. The kernel entries sum to 1,
    so flat regions keep their intensity.
    """

    @staticmethod
    def forward(ctx, image):
        """Return ``image`` filtered with the random sharpening kernel."""
        (start, end) = get_attack_params('sharpening')['center']
        i = np.random.uniform(start, end)
        sharpening_arr = np.array([[0, -i, 0],
                                   [-i, 4*i+1, -i],
                                   [0, -i, 0]])
        # ddepth=-1: output has the same depth as the input (uint8 in -> uint8 out).
        output = cv2.filter2D(image, -1, sharpening_arr)
        return output

In [None]:
class GaussianNoise(Function):
    """Add zero-mean Gaussian noise of random strength to an image.

    Used on NumPy images via ``GaussianNoise.apply(image)``; no autograd graph
    is built.
    """

    @staticmethod
    def forward(ctx, image):
        """Return ``image`` plus Gaussian noise, clipped to [0, 255] as uint8."""
        (start, end) = get_attack_params('gaussian_noise')['noise_std']
        noise_std = np.random.uniform(start, end)
        mean = 0
        # NOTE: despite its name, the sampled `noise_std` acts as a variance;
        # the standard deviation actually used is its square root.
        sigma = noise_std ** 0.5
        gauss = np.random.normal(mean, sigma, image.shape)
        res = image + gauss
        noisy = np.clip(res, 0, 255).astype(np.uint8)
        return noisy

In [None]:
# class GaussianBlur(Function):

#     def forward(ctx, image):
#         (start, end) = get_attack_params('gaussian_blur')['ksize']
#         ksize = np.random.uniform(start, end)
#         ksize = int(ksize)
#         if ksize % 2 == 0:
#             i = ksize + 1
#         else:
#             i = ksize
#         blur_img = cv2.GaussianBlur(image, (i, i) , 0)
#         return blur_img

In [None]:
class MedianBlur(Function):
    """Median-blur an image with a random odd aperture size.

    Used on NumPy images via ``MedianBlur.apply(image)``; no autograd graph is
    built.
    """

    @staticmethod
    def forward(ctx, image):
        """Return ``cv2.medianBlur(image, k)`` for a random odd ``k``.

        ``k`` is drawn from ``get_attack_params('median_blur')['ksize']``,
        truncated to int, and bumped to the next odd number when even.
        """
        (start, end) = get_attack_params('median_blur')['ksize']
        ksize = int(np.random.uniform(start, end))
        # medianBlur requires an odd aperture.
        i = ksize + 1 if ksize % 2 == 0 else ksize
        # OpenCV's medianBlur only accepts 8-bit images for apertures > 5.
        if image.dtype != np.uint8:
            image = np.clip(image, 0, 255).astype(np.uint8)
        blur_img = cv2.medianBlur(image, i)
        return blur_img

In [None]:
# class unsharp_mask(Function):

#     def forward(ctx, image):
#         (start, end) = get_attack_params('unsharp_mask')['amount']
#         amount = np.random.uniform(start, end)
#         result = unsharp_mask(image, radius=3, amount=amount)
#         result = cv2.convertScaleAbs(result, alpha=(255.0))
#         return result

In [None]:
class UniversalPerturbation(Function):
    """Add a pre-computed universal perturbation at a random strength.

    The perturbation stored at ``PERT_PATH`` is resized to the input image,
    scaled by ``level`` drawn from
    ``get_attack_params('universal_perturbation')['level']``, added to the
    image and clipped to [0, 255]. Used via ``UniversalPerturbation.apply``.
    """

    PERT_PATH = '/content/drive/My Drive/deepfake/perturbation_classification/googlenet_no_data.npy'
    # Resized perturbations keyed by (width, height), so the .npy is read once
    # rather than once per image.
    _pert_cache = {}

    @staticmethod
    def _perturbation(width, height):
        """Return the perturbation as an array resized to ``width`` x ``height``."""
        cache = UniversalPerturbation._pert_cache
        key = (width, height)
        if key not in cache:
            pert_ndarr = np.squeeze(np.load(UniversalPerturbation.PERT_PATH))
            # NOTE(review): assumes the .npy holds an HxWx3 uint8 array, which is
            # what Image.fromarray(..., 'RGB') expects -- confirm.
            pert_img = Image.fromarray(pert_ndarr, 'RGB').resize((width, height))
            cache[key] = np.asarray(pert_img)
        return cache[key]

    @staticmethod
    def forward(ctx, image):
        """Return ``image`` + scaled perturbation as a uint8 array."""
        np_human = np.asarray(image)
        # Match the actual image size instead of a fixed 960x540, which made
        # broadcasting fail for any other resolution.
        height, width = np_human.shape[:2]
        np_pert = UniversalPerturbation._perturbation(width, height)

        # NOTE(review): cv2.imread yields BGR while the perturbation image is
        # built as RGB, so its channels are applied in swapped order -- confirm intent.
        (start, end) = get_attack_params('universal_perturbation')['level']
        level = np.random.uniform(start, end)

        np_pert = np_pert * level
        np_pert = np_pert.astype(int)
        np_pert = np.clip(np_pert, 0, 255)

        np_plus = np_pert + np_human.astype(int)
        np_plus = np.clip(np_plus, 0, 255)

        # Values are already in [0, 255]; return uint8 like the other attacks.
        # An int64 array reaches OpenCV as 32-bit int, which cv2.filter2D
        # (Sharpening) rejects when the shuffled order runs it after this attack.
        return np_plus.astype(np.uint8)

In [None]:
def fetch_attack():
    """Return the attack pipeline: MedianBlur first, then the other three in random order.

    The order of GaussianNoise / Sharpening / UniversalPerturbation is shuffled
    with Python's ``random`` module, so it follows ``random.seed``.
    """
    shuffled = [GaussianNoise.apply, Sharpening.apply, UniversalPerturbation.apply]
    random.shuffle(shuffled)
    return [MedianBlur.apply, *shuffled]

In [None]:
def _augment_split(file_names, src_dir, dst_dir, attacks):
    """Attack every image of one split and save it with a random JPEG quality.

    For each name: read ``src_dir + name`` with OpenCV, apply ``attacks`` in
    order, then write the result to ``dst_dir + name`` with a quality drawn
    from ``get_attack_params('jpeg')['quality']``.

    Returns:
        The last processed image, or None if ``file_names`` is empty.

    Raises:
        FileNotFoundError: if ``cv2.imread`` cannot decode an input file.
        IOError: if ``cv2.imwrite`` reports failure (e.g. missing output folder).
    """
    (start, end) = get_attack_params('jpeg')['quality']
    image = None
    for name in file_names:
        src = src_dir + name
        image = cv2.imread(src)
        # imread returns None instead of raising for unreadable/non-image files.
        if image is None:
            raise FileNotFoundError('cv2.imread could not read {}'.format(src))
        for attack in attacks:
            image = attack(image)
        # JPEG compression step: OpenCV expects an integer quality.
        # NOTE(review): the quality flag only applies when `name` has a JPEG
        # extension; other extensions are written in their own format.
        quality = int(np.random.uniform(start, end))
        dst = dst_dir + name
        if not cv2.imwrite(dst, image, [cv2.IMWRITE_JPEG_QUALITY, quality]):
            raise IOError('cv2.imwrite failed for {}'.format(dst))
    return image


def go(real_train_list, fake_train_list, real_valid_list, fake_valid_list):
    """Create attacked copies of the real/fake train/valid images.

    A single attack order from ``fetch_attack()`` is shared by all four splits;
    each image still gets its own random attack strengths and JPEG quality.
    The last processed image of each split is left in the globals
    ``train_realnoise``, ``train_fakenoise``, ``valid_realnoise``, ``valid_fakenoise``.
    """
    global train_realnoise, train_fakenoise, valid_realnoise, valid_fakenoise

    attacks = fetch_attack()
    train_realnoise = _augment_split(real_train_list, real_train_path,
                                     save_realnoise_train_path, attacks)
    train_fakenoise = _augment_split(fake_train_list, fake_train_path,
                                     save_fakenoise_train_path, attacks)
    valid_realnoise = _augment_split(real_valid_list, real_valid_path,
                                     save_realnoise_valid_path, attacks)
    valid_fakenoise = _augment_split(fake_valid_list, fake_valid_path,
                                     save_fakenoise_valid_path, attacks)

In [None]:
# Generate the attacked copies for all four splits (writes into the save_* folders).
go(real_train_list=real_train_list,
   fake_train_list=fake_train_list,
   real_valid_list=real_valid_list,
   fake_valid_list=fake_valid_list)

In [None]:
# Number of attacked images now present in each output folder.
for _out_dir in (save_realnoise_train_path, save_fakenoise_train_path,
                 save_realnoise_valid_path, save_fakenoise_valid_path):
    print(len(os.listdir(_out_dir)))
del _out_dir

2500
2500
900
900


## txt 파일 만들기

In [None]:
# Training list file: one "<path relative to deepfake/> <label>" line per image (0 = real, 1 = fake).
train_list_txt = "/content/drive/My Drive/deepfake/new_train_list.txt"

In [None]:
# Original real training images -> label 0. Mode 'w' starts a fresh list file.
a = sorted(os.listdir(real_train_path))
# One "<path> <label>" line per file; an empty folder contributes no lines
# (a "prefix + join + suffix" string would emit a bogus "real/train/ 0" line).
str_a = "".join("real/train/{} 0\n".format(name) for name in a)
with open(train_list_txt, 'w') as f:
    f.write(str_a)

In [None]:
# Original fake training images -> label 1, appended to the list file.
b = sorted(os.listdir(fake_train_path))
# One "<path> <label>" line per file; an empty folder contributes no lines.
str_b = "".join("fake/train/{} 1\n".format(name) for name in b)
with open(train_list_txt, 'a') as f:
    f.write(str_b)

In [None]:
# Attacked real training images -> label 0, appended to the list file.
c = sorted(os.listdir(save_realnoise_train_path))
# One "<path> <label>" line per file; an empty folder contributes no lines.
str_c = "".join("new/train/real/{} 0\n".format(name) for name in c)
with open(train_list_txt, 'a') as f:
    f.write(str_c)

In [None]:
# Attacked fake training images -> label 1, appended to the list file.
d = sorted(os.listdir(save_fakenoise_train_path))
# One "<path> <label>" line per file; an empty folder contributes no lines.
str_d = "".join("new/train/fake/{} 1\n".format(name) for name in d)
with open(train_list_txt, 'a') as f:
    f.write(str_d)

In [None]:
# Validation list file: one "<path relative to deepfake/> <label>" line per image (0 = real, 1 = fake).
valid_list_txt = "/content/drive/My Drive/deepfake/new_valid_list.txt"

In [None]:
# Original real validation images -> label 0. Mode 'w' starts a fresh list file.
a = sorted(os.listdir(real_valid_path))
# One "<path> <label>" line per file; an empty folder contributes no lines.
str_a = "".join("real/valid/{} 0\n".format(name) for name in a)
with open(valid_list_txt, 'w') as f:
    f.write(str_a)

In [None]:
# Original fake validation images -> label 1, appended to the list file.
b = sorted(os.listdir(fake_valid_path))
# One "<path> <label>" line per file; an empty folder contributes no lines.
str_b = "".join("fake/valid/{} 1\n".format(name) for name in b)
with open(valid_list_txt, 'a') as f:
    f.write(str_b)

In [None]:
# Attacked real validation images -> label 0, appended to the list file.
c = sorted(os.listdir(save_realnoise_valid_path))
# One "<path> <label>" line per file; an empty folder contributes no lines.
str_c = "".join("new/valid/real/{} 0\n".format(name) for name in c)
with open(valid_list_txt, 'a') as f:
    f.write(str_c)

In [None]:
# Attacked fake validation images -> label 1, appended to the list file.
d = sorted(os.listdir(save_fakenoise_valid_path))
# One "<path> <label>" line per file; an empty folder contributes no lines.
str_d = "".join("new/valid/fake/{} 1\n".format(name) for name in d)
with open(valid_list_txt, 'a') as f:
    f.write(str_d)

## 1. 기본 설정

In [None]:
# Keep cuDNN autotuning off: the seeding cell sets cudnn.deterministic = True and
# cudnn.benchmark = False for reproducibility, and enabling benchmark mode here
# would let cuDNN pick kernels non-deterministically again.
cudnn.benchmark = False

args = easydict.EasyDict({
    "gpu": 0,
    # Never request more loader workers than this machine has CPUs; with 32 the
    # DataLoader cell prints a worker-count (cpuset) warning on Colab.
    "num_workers": min(32, os.cpu_count() or 1),

    "root": "/content/drive/My Drive/deepfake",
    "train_list": "/content/drive/My Drive/deepfake/new_train_list.txt",
    "valid_list": "/content/drive/My Drive/deepfake/new_valid_list.txt",

    "learning_rate": 0.001,
    "num_epochs": 10,
    "batch_size": 32,

    "save_fn": "/content/drive/My Drive/deepfake/model/xception_new_model_result.pth.tar",
})

assert os.path.isfile(args.train_list), 'wrong path'
assert os.path.isfile(args.valid_list), 'wrong path'

## 2. 모델
- 참고문헌\[1]\[2]에 따르면 Xception\[3] 모델이 변조 영상 탐지에 가장 좋은 성능을 보여주어 해당 모델을 기본 모델로 선정

\[1] FaceForensics++: Learning to Detect Manipulated Facial Images, ICCV 2019.  
\[2] Celeb-DF: A Large-scale Challenging Dataset for DeepFake Forensics, CVPR 2020.  
\[3] Xception: Deep Learning with Depthwise Separable Convolutions, CVPR 2017.

### 2-1 Xception 구현


In [None]:
"""
Author: Andreas Rössler,
Implemented in https://github.com/ondyari/FaceForensics under MIT license
"""


class SeparableConv2d(nn.Module):
    """Depthwise-separable convolution: per-channel spatial conv, then a 1x1 pointwise conv.

    The submodule names ``conv1`` and ``pointwise`` determine the state_dict
    keys, so renaming them would break loading saved weights.
    """

    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1,
                 padding=0, dilation=1, bias=False):
        super().__init__()
        # Depthwise: groups == in_channels gives one spatial filter per channel.
        self.conv1 = nn.Conv2d(
            in_channels, in_channels, kernel_size,
            stride=stride, padding=padding, dilation=dilation,
            groups=in_channels, bias=bias,
        )
        # Pointwise: 1x1 conv mixes channels and sets the output width.
        self.pointwise = nn.Conv2d(
            in_channels, out_channels, kernel_size=1,
            stride=1, padding=0, dilation=1, groups=1, bias=bias,
        )

    def forward(self, x):
        return self.pointwise(self.conv1(x))


class Block(nn.Module):
    """Xception residual block: a main path of ReLU/SeparableConv2d/BatchNorm triplets plus a shortcut.

    Args:
        in_filters: number of input channels.
        out_filters: number of output channels.
        reps: number of ReLU -> SeparableConv2d -> BatchNorm2d triplets in the main path.
        strides: when != 1, a MaxPool2d(3, strides, 1) ends the main path and
            the shortcut becomes a strided 1x1 conv.
        start_with_relu: when False, the leading ReLU of the main path is dropped.
        grow_first: when True the first separable conv changes the channel
            count (in -> out); otherwise the last one does.
    """
    def __init__(self,in_filters,out_filters,reps,strides=1,start_with_relu=True,grow_first=True):
        super(Block, self).__init__()

        # Projection shortcut (1x1 conv + BN) whenever channels or resolution
        # change; otherwise the shortcut is the identity.
        if out_filters != in_filters or strides!=1:
            self.skip = nn.Conv2d(in_filters,out_filters,1,stride=strides, bias=False)
            self.skipbn = nn.BatchNorm2d(out_filters)
        else:
            self.skip=None

        # One in-place ReLU instance shared by every triplet below.
        self.relu = nn.ReLU(inplace=True)
        rep=[]

        filters=in_filters
        if grow_first:
            rep.append(self.relu)
            rep.append(SeparableConv2d(in_filters,out_filters,3,stride=1,padding=1,bias=False))
            rep.append(nn.BatchNorm2d(out_filters))
            filters = out_filters

        for i in range(reps-1):
            rep.append(self.relu)
            rep.append(SeparableConv2d(filters,filters,3,stride=1,padding=1,bias=False))
            rep.append(nn.BatchNorm2d(filters))

        if not grow_first:
            rep.append(self.relu)
            rep.append(SeparableConv2d(in_filters,out_filters,3,stride=1,padding=1,bias=False))
            rep.append(nn.BatchNorm2d(out_filters))

        # rep[0] is always a ReLU at this point. Note that the list positions
        # become the nn.Sequential child indices, i.e. the state_dict keys.
        if not start_with_relu:
            rep = rep[1:]
        else:
            # Use a non-in-place ReLU at the head so `inp`, which the shortcut
            # still reads in forward(), is not overwritten.
            rep[0] = nn.ReLU(inplace=False)

        if strides != 1:
            rep.append(nn.MaxPool2d(3,strides,1))
        self.rep = nn.Sequential(*rep)

    def forward(self,inp):
        """Return main-path output + shortcut (added in place into the main-path tensor)."""
        x = self.rep(inp)

        if self.skip is not None:
            skip = self.skip(inp)
            skip = self.skipbn(skip)
        else:
            skip = inp

        x+=skip
        return x


class Xception(nn.Module):
    """Xception backbone and classifier (FaceForensics++ implementation).

    ``features`` maps an image batch to a 2048-channel feature map; ``logits``
    global-average-pools it and applies the classification head.

    Args:
        num_classes: output size of the ``fc`` classification layer.
    """
    def __init__(self, num_classes=1000):
        super(Xception, self).__init__()
        self.num_classes = num_classes

        # Stem: two plain convolutions (the first with stride 2, no padding).
        self.conv1 = nn.Conv2d(3,32,3,2,0,bias=False)
        self.bn1 = nn.BatchNorm2d(32)
        self.relu = nn.ReLU(inplace=True)

        self.conv2 = nn.Conv2d(32,64,3,bias=False)
        self.bn2 = nn.BatchNorm2d(64)

        # Downsampling blocks (stride 2).
        self.block1=Block(64,128,2,2,start_with_relu=False,grow_first=True)
        self.block2=Block(128,256,2,2,start_with_relu=True,grow_first=True)
        self.block3=Block(256,728,2,2,start_with_relu=True,grow_first=True)

        # Eight shape-preserving 728-channel blocks.
        self.block4=Block(728,728,3,1,start_with_relu=True,grow_first=True)
        self.block5=Block(728,728,3,1,start_with_relu=True,grow_first=True)
        self.block6=Block(728,728,3,1,start_with_relu=True,grow_first=True)
        self.block7=Block(728,728,3,1,start_with_relu=True,grow_first=True)

        self.block8=Block(728,728,3,1,start_with_relu=True,grow_first=True)
        self.block9=Block(728,728,3,1,start_with_relu=True,grow_first=True)
        self.block10=Block(728,728,3,1,start_with_relu=True,grow_first=True)
        self.block11=Block(728,728,3,1,start_with_relu=True,grow_first=True)

        self.block12=Block(728,1024,2,2,start_with_relu=True,grow_first=False)

        self.conv3 = SeparableConv2d(1024,1536,3,1,1)
        self.bn3 = nn.BatchNorm2d(1536)

        self.conv4 = SeparableConv2d(1536,2048,3,1,1)
        self.bn4 = nn.BatchNorm2d(2048)

        self.fc = nn.Linear(2048, num_classes)

    def features(self, input):
        """Return the 2048-channel feature map (before the final ReLU)."""
        x = self.conv1(input)
        x = self.bn1(x)
        x = self.relu(x)

        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu(x)

        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        x = self.block4(x)
        x = self.block5(x)
        x = self.block6(x)
        x = self.block7(x)
        x = self.block8(x)
        x = self.block9(x)
        x = self.block10(x)
        x = self.block11(x)
        x = self.block12(x)

        x = self.conv3(x)
        x = self.bn3(x)
        x = self.relu(x)

        x = self.conv4(x)
        x = self.bn4(x)
        return x

    def logits(self, features):
        """Pool ``features`` globally and apply the classifier head.

        The head is ``last_linear`` when present (the ``xception`` wrapper
        renames ``fc`` to it); otherwise ``fc`` is used, so a bare
        ``Xception`` also works. ``self.relu`` is in-place, so ``features``
        itself is modified.
        """
        x = self.relu(features)

        x = F.adaptive_avg_pool2d(x, (1, 1))
        x = x.view(x.size(0), -1)
        head = self.last_linear if hasattr(self, 'last_linear') else self.fc
        x = head(x)
        return x

    def forward(self, input):
        x = self.features(input)
        x = self.logits(x)
        return x


# The original Xception with only a Dropout layer added in front of the classifier.
class xception(nn.Module):
    """Deepfake classifier wrapping :class:`Xception` with an optional Dropout head.

    The backbone's ``fc`` is renamed to ``last_linear`` and then replaced by a
    fresh head: a plain ``Linear`` when ``dropout`` is falsy, otherwise
    ``Dropout(p=dropout) -> Linear``. Parameter names (``model.*``,
    ``model.last_linear.*``) must stay stable for saved state_dicts to load.

    Args:
        num_out_classes: number of output logits.
        dropout: dropout probability before the final Linear; 0/None disables it.
    """

    def __init__(self, num_out_classes=2, dropout=0.5):
        super().__init__()

        backbone = Xception(num_classes=num_out_classes)
        # Expose `fc` under the name Xception.logits prefers, and drop `fc`.
        backbone.last_linear = backbone.fc
        del backbone.fc
        in_features = backbone.last_linear.in_features

        head = nn.Linear(in_features, num_out_classes)
        if dropout:
            head = nn.Sequential(nn.Dropout(p=dropout), head)
        backbone.last_linear = head

        self.model = backbone

    def forward(self, x):
        return self.model(x)

### 2-2. Pretrained Weight
- FaceForensics++ 데이터로 학습된 pre-trained model weight 다운\[4]

\[4] https://github.com/HongguLiu/Deepfake-Detection

In [None]:
# Download the FaceForensics++-trained Xception weights [4] as deepfake_c0_xception.pkl.
# NOTE(review): this file is not loaded anywhere in this notebook; the model cell
# loads "origin_realfake.pth.tar" from Drive instead -- confirm which is intended.
!wget -O deepfake_c0_xception.pkl --no-check-certificate 'https://docs.google.com/uc?export=download&id=1eHRN117X0loEff7EBk1mGMJeGbGKsd7m'

--2022-06-17 09:02:12--  https://docs.google.com/uc?export=download&id=1eHRN117X0loEff7EBk1mGMJeGbGKsd7m
Resolving docs.google.com (docs.google.com)... 142.250.148.100, 142.250.148.102, 142.250.148.138, ...
Connecting to docs.google.com (docs.google.com)|142.250.148.100|:443... connected.
HTTP request sent, awaiting response... 303 See Other
Location: https://doc-0c-2k-docs.googleusercontent.com/docs/securesc/ha0ro937gcuc7l7deffksulhg5h7mbp1/dfop74h5unqqe80pl1nketg7if6grnsu/1655456475000/05567444099345578170/*/1eHRN117X0loEff7EBk1mGMJeGbGKsd7m?e=download [following]
--2022-06-17 09:02:17--  https://doc-0c-2k-docs.googleusercontent.com/docs/securesc/ha0ro937gcuc7l7deffksulhg5h7mbp1/dfop74h5unqqe80pl1nketg7if6grnsu/1655456475000/05567444099345578170/*/1eHRN117X0loEff7EBk1mGMJeGbGKsd7m?e=download
Resolving doc-0c-2k-docs.googleusercontent.com (doc-0c-2k-docs.googleusercontent.com)... 74.125.126.132, 2607:f8b0:4001:c1d::84
Connecting to doc-0c-2k-docs.googleusercontent.com (doc-0c-2k-docs.

## 3. 훈련

### 3-1 전처리


In [None]:
# Per-split preprocessing: 299x299 center crop, conversion to a [0, 1] tensor,
# then per-channel (x - 0.5) / 0.5 normalisation to [-1, 1]. Training also
# applies a random horizontal flip (on the tensor).
xception_default = dict(
    train=transforms.Compose([
        transforms.CenterCrop((299, 299)),
        transforms.ToTensor(),
        transforms.RandomHorizontalFlip(),
        transforms.Normalize([0.5] * 3, [0.5] * 3),
    ]),
    valid=transforms.Compose([
        transforms.CenterCrop((299, 299)),
        transforms.ToTensor(),
        transforms.Normalize([0.5] * 3, [0.5] * 3),
    ]),
    test=transforms.Compose([
        transforms.CenterCrop((299, 299)),
        transforms.ToTensor(),
        transforms.Normalize([0.5] * 3, [0.5] * 3),
    ]),
)

### 3-2 Train/Validate/Dataset 함수


In [None]:
# util

def save_checkpoint(state, is_best, filename='checkpoint.pth.tar', best_filename='model_best.pth.tar'):
    """Save ``state`` to ``filename``; if ``is_best``, also copy it to ``best_filename``.

    Args:
        state: object passed to ``torch.save`` (typically a dict of model /
            optimizer state).
        is_best: whether to also keep this checkpoint as the best one.
        filename: destination path of the checkpoint.
        best_filename: destination of the copy made when ``is_best`` is true
            (defaults to the previous hard-coded name in the working directory).
    """
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, best_filename)


def adjust_learning_rate(optimizer, epoch, args):
    """Set every param group's LR to the base LR decayed by 10x every 30 epochs.

    The base LR is ``args.lr`` when present, otherwise ``args.learning_rate``
    (the key this notebook's ``args`` defines; reading ``args.lr`` alone raised
    AttributeError).
    """
    base_lr = getattr(args, 'lr', None)
    if base_lr is None:
        base_lr = args.learning_rate
    lr = base_lr * (0.1 ** (epoch // 30))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

In [None]:
# custom dataset

class ImageRecord(object):
    """One list-file entry; ``row`` is ``[relative_path, label]``."""

    def __init__(self, row):
        self._data = row

    @property
    def path(self):
        """Image path relative to the dataset root."""
        return self._data[0]

    @property
    def label(self):
        """Integer class label (the list files built above use 0 = real, 1 = fake)."""
        return int(self._data[1])


class DFDCDatatset(data.Dataset):
    """Image dataset driven by a list file of ``<relative path> <int label>`` lines.

    Args:
        root_path: directory the relative paths are joined to.
        list_file: text file with one ``<relative path> <label>`` entry per line.
        transform: optional callable applied to each loaded RGB PIL image.
    """

    def __init__(self, root_path, list_file, transform=None):
        self.root_path = root_path
        self.list_file = list_file
        self.transform = transform

        self._parse_list()

    def _load_image(self, image_path):
        # Close the file right away; convert() returns a new, fully loaded image.
        with Image.open(image_path) as img:
            return img.convert('RGB')

    def _parse_list(self):
        # Split on the LAST space so paths containing spaces stay intact, skip
        # blank lines, and close the list file once read.
        with open(self.list_file) as fh:
            self.image_list = [ImageRecord(line.strip().rsplit(' ', 1))
                               for line in fh if line.strip()]

    def __getitem__(self, index):
        """Return ``(image, label)`` for entry ``index`` (image transformed if set)."""
        record = self.image_list[index]
        image_name = os.path.join(self.root_path, record.path)
        image = self._load_image(image_name)

        if self.transform is not None:
            image = self.transform(image)

        return image, record.label

    def __len__(self):
        return len(self.image_list)

In [None]:
# train / validate

def train(train_loader, model, criterion, optimizer, epoch):
    """Run one training epoch, then step the LR scheduler once.

    Args:
        train_loader: iterable of ``(images, target)`` batches.
        model: network to train (switched to train mode).
        criterion: loss function taking ``(outputs, target)``.
        optimizer: optimizer over ``model``'s parameters.
        epoch: accepted for API compatibility; not used.

    Reads the notebook globals ``args`` (``args.gpu`` selects the CUDA device;
    ``None`` leaves tensors where they are) and ``scheduler``.
    """
    n = 0
    running_loss = 0.0
    running_corrects = 0

    model.train()

    with tqdm.tqdm(train_loader, total=len(train_loader), desc="Train", file=sys.stdout) as iterator:
        for images, target in iterator:
            if args.gpu is not None:
                images = images.cuda(args.gpu, non_blocking=True)
                target = target.cuda(args.gpu, non_blocking=True)

            outputs = model(images)
            loss = criterion(outputs, target)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Accumulate plain Python numbers instead of GPU tensors.
            batch_size = images.size(0)
            n += batch_size
            running_loss += loss.item() * batch_size
            running_corrects += (outputs.argmax(dim=1) == target).sum().item()

            epoch_loss = running_loss / n
            epoch_acc = running_corrects / n

            log = 'loss - {:.4f}, acc - {:.3f}'.format(epoch_loss, epoch_acc)
            iterator.set_postfix_str(log)

    # Per-epoch learning-rate schedule (global `scheduler`).
    scheduler.step()


def validate(test_loader, model, criterion):
    """Evaluate ``model`` on ``test_loader`` and return the accuracy as a float.

    Args:
        test_loader: iterable of ``(images, target)`` batches to evaluate.
        model: network to evaluate (switched to eval mode).
        criterion: loss function taking ``(output, target)``; only reported.

    Returns:
        Fraction of correctly classified samples (0.0 for an empty loader).

    Reads the notebook global ``args`` (``args.gpu`` selects the CUDA device;
    ``None`` leaves tensors where they are).
    """
    n = 0
    running_loss = 0.0
    running_corrects = 0
    # Defined up front so an empty loader returns 0.0 instead of raising.
    epoch_acc = 0.0

    model.eval()

    # Iterate the loader that was passed in, with autograd disabled for the
    # whole loop (including the loss).
    with torch.no_grad(), tqdm.tqdm(test_loader, total=len(test_loader), desc="Valid", file=sys.stdout) as iterator:
        for images, target in iterator:
            if args.gpu is not None:
                images = images.cuda(args.gpu, non_blocking=True)
                target = target.cuda(args.gpu, non_blocking=True)

            output = model(images)
            loss = criterion(output, target)
            pred = output.argmax(dim=1)

            batch_size = images.size(0)
            n += batch_size
            running_loss += loss.item() * batch_size
            running_corrects += (pred == target).sum().item()

            epoch_loss = running_loss / n
            epoch_acc = running_corrects / n

            log = 'loss - {:.4f}, acc - {:.3f}'.format(epoch_loss, epoch_acc)
            iterator.set_postfix_str(log)

    return epoch_acc

### 3-3. 훈련



In [None]:
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim

import torchvision
from torchvision import transforms
from torchvision import  models


import os
import random


In [None]:
# Use the first GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print("Use Cuda" if device.type == 'cuda' else "Use CPU")

Use Cuda


In [None]:
model = xception(num_out_classes=2, dropout=0.5)
print("=> creating model '{}'".format('xception'))

# map_location lets a checkpoint saved on GPU load on the CPU fallback too.
best_model = torch.load("/content/drive/My Drive/deepfake/model/origin_realfake.pth.tar",
                        map_location=device)
model.load_state_dict(best_model['state_dict'])
print("=> model weight best_model is loaded")

model.to(device)

# Put the loss on the selected device; a hard-coded .cuda() fails on CPU.
criterion = nn.CrossEntropyLoss().to(device)
optimizer = optim.Adam(model.parameters(), lr=args.learning_rate, betas=(0.9, 0.999), eps=1e-08)
# Halve the learning rate every 5 scheduler steps (train() steps it once per epoch).
scheduler = lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)

=> creating model 'xception'
=> model weight best_model is loaded


In [None]:
# Train/valid datasets from the generated list files, each with its own preprocessing.
train_dataset = DFDCDatatset(root_path=args.root,
                             list_file=args.train_list,
                             transform=xception_default["train"])
valid_dataset = DFDCDatatset(root_path=args.root,
                             list_file=args.valid_list,
                             transform=xception_default["valid"])

In [None]:
# Training batches are reshuffled every epoch; validation order stays fixed.
train_loader = data.DataLoader(
    train_dataset,
    batch_size=args.batch_size,
    shuffle=True,
    num_workers=args.num_workers,
    pin_memory=True,
)
valid_loader = data.DataLoader(
    valid_dataset,
    batch_size=args.batch_size,
    shuffle=False,
    num_workers=args.num_workers,
    pin_memory=False,
)

  cpuset_checked))


In [None]:

# Train for args.num_epochs epochs; after each epoch, validate and keep the
# checkpoint with the best validation accuracy seen so far.
print('-' * 50)

max_acc = 0

for epoch in range(args.num_epochs):
    print('Epoch {}/{}'.format(epoch + 1, args.num_epochs))
    train(train_loader, model, criterion, optimizer, epoch)
    # float() accepts either a Python float or a 0-dim tensor from validate().
    acc = float(validate(valid_loader, model, criterion))

    # Checked every epoch so the best epoch, not just the last, is saved.
    if acc >= max_acc:
        max_acc = acc
        save_checkpoint(state={'epoch': epoch + 1,
                               'state_dict': model.state_dict(),
                               'best_acc1': acc,
                               'optimizer': optimizer.state_dict(),},
                        is_best=False,
                        filename=args.save_fn,
                        )