## Multi-Colorspace EfficientNet (MC-EffNet)
> Reference: [Distinguishing Natural and Computer-Generated Images using Multi-Colorspace fused EfficientNet](https://arxiv.org/pdf/2110.09428.pdf)

In [1]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, SubsetRandomSampler
from torchvision import transforms
from torchvision.datasets import ImageFolder
from efficientnet_pytorch import EfficientNet
from skimage.color import rgb2lab, lab2lch, rgb2hsv
import numpy as np
from sklearn.metrics import accuracy_score, precision_score, recall_score
import os
from tqdm import tqdm

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
torch.cuda.empty_cache()

cuda:0


In [2]:
NUM_EPOCH = 5
BS = 32
LR = 1e-4
NUM_CLASSES = 2
MODEL = 'efficientnet-b0'
LEAST_NUM_TRAIN_DATA = 6000
LEAST_NUM_VAL_DATA = 6000
INPUT_IMG_SIZE = (128, 128)

# ----- Original Fake2M dataset -----
mdl_pth = f'./model/efficientnet_model_MC_Fake2M.pth'
rst_pth = f'./model/efficientnet_rst_MC_Fake2M.pth'
train_data_pth = './Fake2M/train/'
val_data_pth = './Fake2M/val/'

# ----- Random compressed and cropped Fake2M dataset -----
# mdl_pth = f'./model/efficientnet_model_MC_Fake2M_cc.pth'
# rst_pth = f'./model/efficientnet_rst_MC_Fake2M_cc.pth'
# train_data_pth = './Fake2M_cc/train/'
# val_data_pth = './Fake2M_cc/val/'

In [3]:
def normalize0to1(image):
    modified_img = image
    min_1 = np.min(modified_img[:,:,0])
    max_1 = np.max(modified_img[:,:,0])
    modified_img[:,:,0] = np.minimum((modified_img[:,:,0] - min_1) / (max_1 - min_1 + 1e-8), 1.0)
    min_2 = np.min(modified_img[:,:,1])
    max_2 = np.max(modified_img[:,:,1])
    modified_img[:,:,1] = np.minimum((modified_img[:,:,1] - min_2) / (max_2 - min_2 + 1e-8), 1.0)
    min_3 = np.min(modified_img[:,:,2])
    max_3 = np.max(modified_img[:,:,2])
    modified_img[:,:,2] = np.minimum((modified_img[:,:,2] - min_3) / (max_3 - min_3 + 1e-8), 1.0)
    return modified_img

def rgb_to_hsv(image):
    hsv_image = rgb2hsv(image)
    hsv_image = normalize0to1(hsv_image)
    return hsv_image

def rgb_to_lch(image):
    lab_image = rgb2lab(image)
    lch_image = lab2lch(lab_image)
    lch_image = normalize0to1(lch_image)
    np.nan_to_num(lch_image, copy=False, nan=0.0, posinf=None, neginf=None)
    return lch_image


class CustomDataset(torch.utils.data.Dataset):
    def __init__(self, root, rgb_tranform, hsv_tranform, lch_tranform):
        self.rgb_dataset = ImageFolder(root, transform=rgb_tranform)
        self.hsv_dataset = ImageFolder(root, transform=hsv_tranform)
        self.lch_dataset = ImageFolder(root, transform=lch_tranform)
        

    def __getitem__(self, index):
        rgb_data, labels = self.rgb_dataset[index]
        hsv_data, _ = self.hsv_dataset[index]
        lch_data, _ = self.lch_dataset[index]

        return rgb_data, hsv_data, lch_data, labels

    def __len__(self):
        return len(self.rgb_dataset)

class EfficientNetWrapper(nn.Module):
    def __init__(self, num_classes=1280):
        super(EfficientNetWrapper, self).__init__()
        self.rgb_model = EfficientNet.from_pretrained('efficientnet-b0')
        self.hsv_model = EfficientNet.from_pretrained('efficientnet-b0')
        self.lch_model = EfficientNet.from_pretrained('efficientnet-b0')

        # remove last layer
        self.rgb_model._fc = nn.Identity()
        self.hsv_model._fc = nn.Identity()
        self.lch_model._fc = nn.Identity()

        # concat
        self.fc = nn.Linear(3 * num_classes, 2)

    def forward(self, rgb_input, hsv_input, lch_input):
        rgb_output = self.rgb_model(rgb_input)
        hsv_output = self.hsv_model(hsv_input)
        lch_output = self.lch_model(lch_input)

        # concat output
        concatenated_output = torch.cat([rgb_output, hsv_output, lch_output], dim=1)

        # final output
        final_output = self.fc(concatenated_output)

        return final_output

In [4]:
rgb_transform = transforms.Compose([
    transforms.Resize(INPUT_IMG_SIZE),
    transforms.ToTensor()
])
hsv_transform = transforms.Compose([
    transforms.Resize(INPUT_IMG_SIZE),
    transforms.Lambda(lambda x: rgb_to_hsv(np.array(x))),
    transforms.ToTensor()
])
lch_transform = transforms.Compose([
    transforms.Resize(INPUT_IMG_SIZE),
    transforms.Lambda(lambda x: rgb_to_lch(np.array(x))),
    transforms.ToTensor()
])

train_dataset = CustomDataset(
    train_data_pth,
    rgb_tranform=rgb_transform,
    hsv_tranform=hsv_transform,
    lch_tranform=lch_transform
)
val_dataset = CustomDataset(
    val_data_pth,
    rgb_tranform=rgb_transform,
    hsv_tranform=hsv_transform,
    lch_tranform=lch_transform
)

In [5]:
train_sampler = SubsetRandomSampler(range(0, len(train_dataset), len(train_dataset) // LEAST_NUM_TRAIN_DATA)) 
val_sampler = SubsetRandomSampler(range(0, len(val_dataset), len(val_dataset) // LEAST_NUM_VAL_DATA))

train_dataloader = DataLoader(train_dataset, batch_size=BS, sampler=train_sampler)
val_dataloader = DataLoader(val_dataset, batch_size=BS, sampler=val_sampler)

In [6]:
print(f'Number of training data: {len(train_dataloader) * train_dataloader.batch_size}')
print(f'Number of validation data: {len(val_dataloader) * val_dataloader.batch_size}')

Number of training data: 6944
Number of validation data: 6784


In [7]:
for rgb, hsv, lch, labels in train_dataloader:
    rgb_img = rgb[0]
    hsv_img = hsv[0]
    lch_img = lch[0]
    print(f'RGB range: ({torch.min(rgb_img)}, {torch.max(rgb_img)})')
    print(f'HSV range: ({torch.min(hsv_img)}, {torch.max(hsv_img)})')
    print(f'LCH range: ({torch.min(lch_img)}, {torch.max(lch_img)})')
    break

RGB range: (0.007843137718737125, 0.9960784316062927)
HSV range: (0.0, 0.9999999899354839)
LCH range: (0.0, 0.9999999998899461)


In [8]:
# Get the number of data in each class
check_data_label = False 
if check_data_label:
    train_lbl_0 = 0
    train_lbl_1 = 0
    for _, _, _, labels in tqdm(train_dataloader):
        train_lbl_0 += torch.count_nonzero(1-labels)
        train_lbl_1 += torch.count_nonzero(labels)
        
    print(f'Number of label 0 (fake) training images: {train_lbl_0}')
    print(f'Number of label 1 (real) training images: {train_lbl_1}')

    val_lbl_0 = 0
    val_lbl_1 = 0
    for _, _, _, labels in tqdm(val_dataloader):
        val_lbl_0 += torch.count_nonzero(1-labels)
        val_lbl_1 += torch.count_nonzero(labels)
        
    print(f'Number of label 0 (fake) validation images: {val_lbl_0}')
    print(f'Number of label 1 (real) validation images: {val_lbl_1}')

In [9]:
model = EfficientNetWrapper()
optimizer = torch.optim.Adam(model.parameters(), lr=LR)
criterion = nn.CrossEntropyLoss()
model.to(device)

if os.path.exists(mdl_pth):
    checkpoint = torch.load(mdl_pth)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    print("Trained Model loaded...")
else:
    print("New Model created...")
    
if os.path.exists(rst_pth):
    checkpoint = torch.load(rst_pth)
    epoch_train_loss = checkpoint['epoch_training_loss']
    epoch_train_accuracy = checkpoint['epoch_training_accuracy']
    epoch_train_precision = checkpoint['epoch_training_precision']
    epoch_train_recall = checkpoint['epoch_train_recall']
    epoch_val_loss = checkpoint['epoch_val_loss']
    epoch_val_accuracy = checkpoint['epoch_val_accuracy']
    epoch_val_precision = checkpoint['epoch_val_precision']
    epoch_val_recall = checkpoint['epoch_val_recall']
    best_val_acc = checkpoint['best_val_acc']
    corres_val_precision = checkpoint['corres_val_precision']
    corres_val_recall = checkpoint['corres_val_recall']
    print("History result loaded...")
    print('Best model:')
    print(f'Accuracy = {best_val_acc:.4f}')
    print(f'Precision = {corres_val_precision:.4f}')
    print(f'Recall = {corres_val_recall:.4f}')
else:
    epoch_train_loss = []
    epoch_train_accuracy = []
    epoch_train_precision = []
    epoch_train_recall = []
    epoch_val_loss = []
    epoch_val_accuracy = []
    epoch_val_precision = []
    epoch_val_recall = []
    best_val_acc = 0.0 
    corres_val_precision = 0.0
    corres_val_recall = 0.0
    print("No history result...")

Loaded pretrained weights for efficientnet-b0
Loaded pretrained weights for efficientnet-b0
Loaded pretrained weights for efficientnet-b0
Trained Model loaded...
History result loaded...
Best model:
Accuracy = 0.9098
Precision = 0.9132
Recall = 0.9549


In [10]:
for epoch in range(NUM_EPOCH):
    # ====================== train ===========================
    train_all_labels = []
    train_all_preds = []
    model.train()
    for rgb_data, hsv_data, lch_data, labels in tqdm(train_dataloader):
        rgb_data, hsv_data, lch_data, labels = rgb_data.to(device), hsv_data.to(device), lch_data.to(device), labels.to(device)

        hsv_data = hsv_data.reshape((rgb_data.shape)).float()
        lch_data = lch_data.reshape((rgb_data.shape)).float()

        optimizer.zero_grad()

        outputs = model(rgb_data, hsv_data, lch_data)
        # hsv_output = model.hsv_model(hsv_data)
        # lch_output = model.lch_model(lch_data)
        
        # concatenated_output = torch.cat([rgb_output, hsv_output, lch_output], dim=1)
        
        # outputs = model.fc(concatenated_output)
        
        _, preds = torch.max(outputs, 1)
        
        train_loss = criterion(outputs, labels)
        train_loss.backward()
        optimizer.step()
        
        train_all_labels.extend(labels.cpu().numpy())
        train_all_preds.extend(preds.cpu().numpy())
        
    train_accuracy = accuracy_score(train_all_labels, train_all_preds)
    train_precision = precision_score(train_all_labels, train_all_preds, average='weighted')
    train_recall = recall_score(train_all_labels, train_all_preds)
        
    epoch_train_loss += [train_loss.item()]
    epoch_train_accuracy += [train_accuracy]
    epoch_train_precision += [train_precision]
    epoch_train_recall += [train_recall]

    print(f'Epoch [{epoch+1}/{NUM_EPOCH}]')
    print(f'Trn loss: {train_loss:.4f}')
    print(f'Trn accuracy: {train_accuracy:.4f}')
    print(f'Trn precision: {train_precision:.4f}')
    print(f'Trn recall: {train_recall:.4f}')
    
    # ======================= val ============================
    model.eval()
    val_all_labels = []
    val_all_preds = []
    val_recall = []
    with torch.no_grad():
        for rgb_data, hsv_data, lch_data, labels in tqdm(val_dataloader):
            rgb_data, hsv_data, lch_data, labels = rgb_data.to(device), hsv_data.to(device), lch_data.to(device), labels.to(device)

            hsv_data = hsv_data.reshape((rgb_data.shape)).float()
            lch_data = lch_data.reshape((rgb_data.shape)).float()

            optimizer.zero_grad()

            # rgb_output = model.rgb_model(rgb_data)
            # hsv_output = model.hsv_model(hsv_data)
            # lch_output = model.lch_model(lch_data)

            # concatenated_output = torch.cat([rgb_output, hsv_output, lch_output], dim=1)
            # outputs = model.fc(concatenated_output)
            
            outputs = model(rgb_data, hsv_data, lch_data)          
            
            _, preds = torch.max(outputs, 1)
            val_all_labels.extend(labels.cpu().numpy())
            val_all_preds.extend(preds.cpu().numpy())
            
            val_loss = criterion(outputs, labels)

    val_accuracy = accuracy_score(val_all_labels, val_all_preds)
    val_precision = precision_score(val_all_labels, val_all_preds, average='weighted')
    val_recall = recall_score(val_all_labels, val_all_preds)
    
    epoch_val_loss += [val_loss.item()]
    epoch_val_accuracy += [val_accuracy]
    epoch_val_precision += [val_precision]
    epoch_val_recall += [val_recall]
    
    print(f'Val loss: {val_loss:.4f}')
    print(f'Val accuracy: {val_accuracy:.4f}')
    print(f'Val precision: {val_precision:.4f}')
    print(f'val recall: {val_recall:.4f}')
        
    if best_val_acc < val_accuracy:
        best_val_acc = val_accuracy
        corres_val_precision = val_precision
        corres_val_recall = val_recall
        torch.save({
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
        }, mdl_pth)
        
    torch.save({
        'epoch_training_loss': epoch_train_loss,
        'epoch_training_accuracy': epoch_train_accuracy,
        'epoch_training_precision': epoch_train_precision,
        'epoch_train_recall': epoch_train_recall,
        'epoch_val_loss': epoch_val_loss,
        'epoch_val_accuracy': epoch_val_accuracy,
        'epoch_val_precision': epoch_val_precision,
        'epoch_val_recall': epoch_val_recall,
        'best_val_acc': best_val_acc,
        'corres_val_precision': corres_val_precision,
        'corres_val_recall': corres_val_recall
    }, rst_pth)
    
    print(f'Best model: Accuracy = {best_val_acc:.4f}, Precision = {corres_val_precision:.4f}, Recall = {corres_val_recall:.4f}')
print("=========================================================")
print(f'Best model: Accuracy = {best_val_acc:.4f}, Precision = {corres_val_precision:.4f}, Recall = {corres_val_recall:.4f}')

100%|██████████| 217/217 [06:09<00:00,  1.70s/it]


Epoch [1/5]
Trn loss: 1.0147
Trn accuracy: 0.9926
Trn precision: 0.9926
Trn recall: 0.9915


100%|██████████| 212/212 [05:23<00:00,  1.53s/it]


Val loss: 0.0001
Val accuracy: 0.9196
Val precision: 0.9221
val recall: 0.9579
Best model: Accuracy = 0.9196, Precision = 0.9221, Recall = 0.9579


100%|██████████| 217/217 [05:27<00:00,  1.51s/it]


Epoch [2/5]
Trn loss: 0.8701
Trn accuracy: 0.9952
Trn precision: 0.9952
Trn recall: 0.9952


100%|██████████| 212/212 [04:53<00:00,  1.38s/it]


Val loss: 0.0001
Val accuracy: 0.9235
Val precision: 0.9247
val recall: 0.9501
Best model: Accuracy = 0.9235, Precision = 0.9247, Recall = 0.9501


100%|██████████| 217/217 [05:28<00:00,  1.52s/it]


Epoch [3/5]
Trn loss: 1.1013
Trn accuracy: 0.9945
Trn precision: 0.9945
Trn recall: 0.9943


100%|██████████| 212/212 [04:56<00:00,  1.40s/it]


Val loss: 0.0002
Val accuracy: 0.9236
Val precision: 0.9245
val recall: 0.9463
Best model: Accuracy = 0.9236, Precision = 0.9245, Recall = 0.9463


100%|██████████| 217/217 [05:25<00:00,  1.50s/it]


Epoch [4/5]
Trn loss: 0.2137
Trn accuracy: 0.9965
Trn precision: 0.9965
Trn recall: 0.9969


100%|██████████| 212/212 [04:56<00:00,  1.40s/it]


Val loss: 0.1811
Val accuracy: 0.9248
Val precision: 0.9264
val recall: 0.9555
Best model: Accuracy = 0.9248, Precision = 0.9264, Recall = 0.9555


100%|██████████| 217/217 [05:25<00:00,  1.50s/it]


Epoch [5/5]
Trn loss: 0.1916
Trn accuracy: 0.9967
Trn precision: 0.9967
Trn recall: 0.9960


100%|██████████| 212/212 [04:53<00:00,  1.38s/it]


Val loss: 0.0008
Val accuracy: 0.9313
Val precision: 0.9315
val recall: 0.9430
Best model: Accuracy = 0.9313, Precision = 0.9315, Recall = 0.9430
Best model: Accuracy = 0.9313, Precision = 0.9315, Recall = 0.9430


---