In [1]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import cv2
from sklearn.model_selection import train_test_split

from PIL import Image
import PIL
import torchvision

import os

In [11]:
class Meso4(nn.Module):
    def __init__(self):
        super(Meso4, self).__init__()

        self.resnet = torchvision.models.resnet.resnet34()
        
        self.conv1 = nn.Conv2d(3, 8, 3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(8)
        self.maxpooling1 = nn.MaxPool2d(kernel_size=(2, 2))
        
        self.conv2 = nn.Conv2d(8, 3, 5, padding=2, bias=False)
        self.bn2 = nn.BatchNorm2d(3)
        
        self.conv3 = nn.Conv2d(8, 16, 5, padding=2, bias=False)
        self.bn3 = nn.BatchNorm2d(16)
        
        self.conv4 = nn.Conv2d(16, 16, 5, padding=2, bias=False)
        self.maxpooling4 = nn.MaxPool2d(kernel_size=(4, 4))
        
        self.dropout = nn.Dropout2d(0.5) # dropout or dropout2d?
        self.fc1 = nn.Linear(1000, 16)
        self.fc2 = nn.Linear(16, 2)
        
        self.relu = nn.ReLU(inplace=True) # is inplace needed?
        self.leakyrelu = nn.LeakyReLU(0.1)
        
    def forward(self, input):
        # input should be (3, 256, 256)
        
        x = self.conv1(input)
        x = self.relu(x)
        x = self.bn1(x)
        x = self.maxpooling1(x)
        
        x = self.conv2(x)
        x = self.relu(x)
        x = self.bn2(x)
        x = self.maxpooling1(x)
        
        x = self.resnet(x)
                
        x = self.dropout(x)
        x = self.fc1(x)
        x = self.leakyrelu(x)
        x = self.dropout(x)
        x = self.fc2(x)
        
        return x
        

def get_file_name(base):
    for path, dir, file in os.walk(base):
        for f in file:
            if ".png" in f:
                yield(os.path.join(path, f))




In [None]:
TRANSFORM = torchvision.transforms.Compose([
    torchvision.transforms.Resize((256, 256)),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize([0.5]*3, [0.5]*3)
])


img_data = [] # 12000 * (3, 299, 299)
img_label = [] # 1D array

# fake - 0
# real - 1

for idx, file_path in enumerate(get_file_name('./data/')):
    img = cv2.imread(file_path)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = Image.fromarray(img)
    
    img = TRANSFORM(img)
    img_data.append(img)
    
    if 'manipulated' in file_path: 
        img_label.append(0)
    else:
        img_label.append(1)
        

img_label = np.asarray(img_label)


In [3]:
X_train, X_test, y_train, y_test = train_test_split(img_data, img_label, test_size = 0.2, random_state = 4487)

train_batchsize = 64
test_batchsize = 1

X_train, y_train = torch.stack(X_train), torch.Tensor(y_train)
X_test, y_test = torch.stack(X_test), torch.Tensor(y_test)
train_dataset = torch.utils.data.TensorDataset(X_train, y_train)
test_dataset = torch.utils.data.TensorDataset(X_test, y_test)

train_loader = torch.utils.data.DataLoader(dataset = train_dataset, batch_size = train_batchsize, shuffle = True)
test_loader = torch.utils.data.DataLoader(dataset = test_dataset, batch_size = test_batchsize, shuffle = True)


In [13]:
echos = 50
learning_rate = 0.001
model = Meso4()
model_adam = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay = 1e-8);
loss_function = nn.CrossEntropyLoss()

has_cuda = torch.cuda.is_available()
if has_cuda:
    model = model.cuda()

# train
for i in range(echos):
    print(i)
    model.train()
    for batch_idx,(image, target) in enumerate(train_loader):
        if has_cuda:
            image = image.cuda()
            target = target.cuda()

        model_adam.zero_grad()
        output = model(image)
        _, preds = torch.max(output.data, 1)
        loss = loss_function(output, target.long())
        
        if batch_idx == 1:
            print("loss:", loss)
            print("accuracy:", torch.sum(preds == target).data/train_batchsize)
        
        loss.backward()
        model_adam.step()

# torch.save(model.state_dict(), "output/meso4_" + str(echos) + ".pkl")


# test
correct = 0
with torch.no_grad():
    for batch_idx,(image, target) in enumerate(test_loader):
        model.eval()
        if has_cuda:
                image = image.cuda()
                target = target.cuda()

        output = model(image)
        _, preds = torch.max(output.data, 1)
        loss = loss_function(output, target.long())
        
        correct += (preds == target)
        print("loss:", loss)
        # print("accuracy:", torch.sum(preds == target).data/train_batchsize)

    print(correct / test_loader.__len__())
            

0
loss: tensor(0.9440, grad_fn=<NllLossBackward0>)
accuracy: tensor(0.4062)
1
loss: tensor(0.6579, grad_fn=<NllLossBackward0>)
accuracy: tensor(0.6250)
2
loss: tensor(0.6443, grad_fn=<NllLossBackward0>)
accuracy: tensor(0.6719)
3
loss: tensor(0.6602, grad_fn=<NllLossBackward0>)
accuracy: tensor(0.6406)
4
loss: tensor(0.6630, grad_fn=<NllLossBackward0>)
accuracy: tensor(0.6094)
5
loss: tensor(0.5879, grad_fn=<NllLossBackward0>)
accuracy: tensor(0.7344)
6
loss: tensor(0.6760, grad_fn=<NllLossBackward0>)
accuracy: tensor(0.6094)
7
loss: tensor(0.5764, grad_fn=<NllLossBackward0>)
accuracy: tensor(0.7500)
8
loss: tensor(0.6143, grad_fn=<NllLossBackward0>)
accuracy: tensor(0.7031)
9
loss: tensor(0.6096, grad_fn=<NllLossBackward0>)
accuracy: tensor(0.7031)
10
loss: tensor(0.6231, grad_fn=<NllLossBackward0>)
accuracy: tensor(0.6719)
11
loss: tensor(0.6545, grad_fn=<NllLossBackward0>)
accuracy: tensor(0.6406)
12
loss: tensor(0.6234, grad_fn=<NllLossBackward0>)
accuracy: tensor(0.6875)
13
loss: 

loss: tensor(0.4038)
loss: tensor(0.4038)
loss: tensor(0.4038)
loss: tensor(0.4038)
loss: tensor(1.1020)
loss: tensor(0.4038)
loss: tensor(0.4037)
loss: tensor(0.4038)
loss: tensor(0.4037)
loss: tensor(1.1020)
loss: tensor(1.1020)
loss: tensor(0.4038)
loss: tensor(1.1020)
loss: tensor(1.1020)
loss: tensor(0.4037)
loss: tensor(1.1020)
loss: tensor(1.1020)
loss: tensor(0.4038)
loss: tensor(1.1020)
loss: tensor(1.1020)
loss: tensor(1.1020)
loss: tensor(0.4038)
loss: tensor(0.4038)
loss: tensor(0.4038)
loss: tensor(0.4038)
loss: tensor(1.1020)
loss: tensor(0.4038)
loss: tensor(0.4038)
loss: tensor(0.4037)
loss: tensor(1.1020)
loss: tensor(1.1020)
loss: tensor(1.1020)
loss: tensor(0.4038)
loss: tensor(0.4037)
loss: tensor(1.1021)
loss: tensor(1.1020)
loss: tensor(1.1026)
loss: tensor(0.4038)
loss: tensor(0.4038)
loss: tensor(0.4038)
loss: tensor(0.4038)
loss: tensor(1.1020)
loss: tensor(0.4038)
loss: tensor(0.4030)
loss: tensor(0.4037)
loss: tensor(0.4038)
loss: tensor(0.4038)
loss: tensor(

loss: tensor(0.4038)
loss: tensor(0.4038)
loss: tensor(1.1020)
loss: tensor(0.4038)
loss: tensor(1.1020)
loss: tensor(0.4038)
loss: tensor(0.4037)
loss: tensor(1.1021)
loss: tensor(0.4038)
loss: tensor(0.4038)
loss: tensor(0.4038)
loss: tensor(1.1020)
loss: tensor(0.4038)
loss: tensor(0.4038)
loss: tensor(0.4066)
loss: tensor(1.1021)
loss: tensor(0.4038)
loss: tensor(0.4038)
loss: tensor(1.1020)
loss: tensor(0.4038)
loss: tensor(0.4037)
loss: tensor(0.4038)
loss: tensor(1.1020)
loss: tensor(1.1020)
loss: tensor(1.1020)
loss: tensor(1.1020)
loss: tensor(1.1020)
loss: tensor(0.4038)
loss: tensor(1.1020)
loss: tensor(1.1031)
loss: tensor(1.1020)
loss: tensor(0.4038)
loss: tensor(0.4037)
loss: tensor(0.4038)
loss: tensor(0.4038)
loss: tensor(0.4038)
loss: tensor(0.4038)
loss: tensor(1.1031)
loss: tensor(1.1020)
loss: tensor(0.4038)
loss: tensor(0.4038)
loss: tensor(0.4037)
loss: tensor(0.4037)
loss: tensor(0.4038)
loss: tensor(1.1021)
loss: tensor(0.4038)
loss: tensor(1.1020)
loss: tensor(

loss: tensor(0.4038)
loss: tensor(0.4037)
loss: tensor(0.4038)
loss: tensor(0.4038)
loss: tensor(1.1020)
loss: tensor(0.4038)
loss: tensor(0.4038)
loss: tensor(0.4038)
loss: tensor(1.1021)
loss: tensor(1.1020)
loss: tensor(0.4038)
loss: tensor(1.1020)
loss: tensor(0.4038)
loss: tensor(0.4035)
loss: tensor(0.4038)
loss: tensor(1.1020)
loss: tensor(0.4028)
loss: tensor(1.1021)
loss: tensor(0.4032)
loss: tensor(0.4038)
loss: tensor(0.4038)
loss: tensor(1.1021)
loss: tensor(1.1020)
loss: tensor(1.1020)
loss: tensor(1.1020)
loss: tensor(0.4038)
loss: tensor(0.4038)
loss: tensor(0.4038)
loss: tensor(0.4038)
loss: tensor(1.1037)
loss: tensor(1.1020)
loss: tensor(1.1020)
loss: tensor(0.4038)
loss: tensor(0.4036)
loss: tensor(0.4038)
loss: tensor(1.1020)
loss: tensor(0.4038)
loss: tensor(0.4038)
loss: tensor(1.1020)
loss: tensor(0.4038)
loss: tensor(1.1020)
loss: tensor(0.4038)
loss: tensor(0.4038)
loss: tensor(0.4038)
loss: tensor(0.4038)
loss: tensor(1.1020)
loss: tensor(1.1020)
loss: tensor(

loss: tensor(0.4038)
loss: tensor(1.1020)
loss: tensor(0.4038)
loss: tensor(0.4038)
loss: tensor(0.4038)
loss: tensor(0.4038)
loss: tensor(0.4038)
loss: tensor(1.1020)
loss: tensor(0.4038)
loss: tensor(1.1020)
loss: tensor(0.4084)
loss: tensor(0.4037)
loss: tensor(0.4038)
loss: tensor(0.4038)
loss: tensor(1.1020)
loss: tensor(0.4038)
loss: tensor(0.4038)
loss: tensor(0.4038)
loss: tensor(1.0998)
loss: tensor(0.4032)
loss: tensor(0.4038)
loss: tensor(0.4037)
loss: tensor(0.4038)
loss: tensor(1.1021)
loss: tensor(0.4037)
loss: tensor(0.4038)
loss: tensor(0.4038)
loss: tensor(0.4038)
loss: tensor(0.4038)
loss: tensor(0.4037)
loss: tensor(0.4038)
loss: tensor(0.4038)
loss: tensor(1.1020)
loss: tensor(0.4038)
loss: tensor(1.1020)
loss: tensor(0.4038)
loss: tensor(1.1020)
loss: tensor(0.4037)
loss: tensor(0.4038)
loss: tensor(0.4037)
loss: tensor(0.4038)
loss: tensor(0.4038)
loss: tensor(0.4038)
loss: tensor(1.1020)
loss: tensor(0.4037)
loss: tensor(0.4038)
loss: tensor(0.4038)
loss: tensor(

loss: tensor(0.4038)
loss: tensor(0.4037)
loss: tensor(1.1021)
loss: tensor(1.1020)
loss: tensor(0.4038)
loss: tensor(1.1020)
loss: tensor(0.4038)
loss: tensor(1.1020)
loss: tensor(0.4038)
loss: tensor(0.4038)
loss: tensor(0.4038)
loss: tensor(0.4038)
loss: tensor(0.4038)
loss: tensor(1.1020)
loss: tensor(0.4038)
loss: tensor(0.4038)
loss: tensor(0.4038)
loss: tensor(0.4038)
loss: tensor(0.4038)
loss: tensor(0.4038)
loss: tensor(0.4038)
loss: tensor(1.1020)
loss: tensor(1.1020)
loss: tensor(0.4038)
loss: tensor(0.4038)
loss: tensor(0.4038)
loss: tensor(0.4035)
loss: tensor(0.4038)
loss: tensor(0.4037)
loss: tensor(0.4037)
loss: tensor(0.4037)
loss: tensor(0.4038)
loss: tensor(1.1020)
loss: tensor(0.4038)
loss: tensor(0.4038)
loss: tensor(0.4035)
loss: tensor(0.4037)
loss: tensor(0.4038)
loss: tensor(1.1040)
loss: tensor(1.1021)
loss: tensor(0.4037)
loss: tensor(0.4038)
loss: tensor(0.4037)
loss: tensor(0.4038)
loss: tensor(0.4038)
loss: tensor(0.4038)
loss: tensor(0.4038)
loss: tensor(

loss: tensor(0.4038)
loss: tensor(0.4038)
loss: tensor(0.4038)
loss: tensor(0.4037)
loss: tensor(1.1020)
loss: tensor(0.4038)
loss: tensor(1.1020)
loss: tensor(0.4038)
loss: tensor(1.1020)
loss: tensor(1.1020)
loss: tensor(1.1020)
loss: tensor(0.4038)
loss: tensor(0.4035)
loss: tensor(0.4037)
loss: tensor(1.1020)
loss: tensor(0.4038)
loss: tensor(0.4038)
loss: tensor(1.1020)
loss: tensor(0.4035)
loss: tensor(0.4038)
loss: tensor(1.1020)
loss: tensor(0.4038)
loss: tensor(0.4029)
loss: tensor(0.4038)
loss: tensor(0.4038)
loss: tensor(0.4038)
loss: tensor(0.4038)
loss: tensor(0.4038)
loss: tensor(0.4038)
loss: tensor(0.4038)
loss: tensor(0.4037)
loss: tensor(0.4038)
loss: tensor(1.1021)
loss: tensor(1.1020)
loss: tensor(0.4038)
loss: tensor(0.4038)
loss: tensor(1.1025)
loss: tensor(1.1021)
loss: tensor(0.4038)
loss: tensor(0.4061)
loss: tensor(0.4038)
loss: tensor(0.4032)
loss: tensor(0.4037)
loss: tensor(1.1020)
loss: tensor(0.4038)
loss: tensor(0.4031)
loss: tensor(0.4038)
loss: tensor(