## Single Colorspace EfficientNet (MC-EffNet)
> Reference: [Distinguishing Natural and Computer-Generated Images using Multi-Colorspace fused EfficientNet](https://arxiv.org/pdf/2110.09428.pdf)

In [39]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, SubsetRandomSampler
from torchvision import transforms
from torchvision.datasets import ImageFolder
from efficientnet_pytorch import EfficientNet
from sklearn.metrics import accuracy_score, precision_score, recall_score
import os
from tqdm import tqdm

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
torch.cuda.empty_cache()

cuda:0


In [40]:
NUM_EPOCH = 10
BS = 32
LR = 1e-4
NUM_CLASSES = 2
MODEL = 'efficientnet-b0'
LEAST_NUM_TRAIN_DATA = 6000
LEAST_NUM_VAL_DATA = 6000
INPUT_IMG_SIZE = (128, 128)

# ----- Original Fake2M dataset config-----
mdl_pth = f'./model/efficientnet_model_SC_Fake2M.pth'
rst_pth = f'./model/efficientnet_rst_SC_Fake2M.pth'
train_data_pth = './Fake2M/train/'
val_data_pth = './Fake2M/val/'

# ----- Random compressed and cropped Fake2M dataset config-----
# mdl_pth = f'./model/efficientnet_model_SC_Fake2M_cc.pth'
# rst_pth = f'./model/efficientnet_rst_SC_Fake2M_cc.pth'
# train_data_pth = './Fake2M_cc/train/'
# val_data_pth = './Fake2M_cc/val/'

In [41]:
transform = transforms.Compose([
    transforms.Resize(INPUT_IMG_SIZE),
    transforms.ToTensor(),
])

train_dataset = ImageFolder(root=train_data_pth, transform=transform)
val_dataset = ImageFolder(root=val_data_pth, transform=transform)

train_sampler = SubsetRandomSampler(range(0, len(train_dataset), len(train_dataset) // LEAST_NUM_TRAIN_DATA)) 
val_sampler = SubsetRandomSampler(range(0, len(val_dataset), len(val_dataset) // LEAST_NUM_VAL_DATA))

train_dataloader = DataLoader(train_dataset, batch_size=BS, sampler=train_sampler)
val_dataloader = DataLoader(val_dataset, batch_size=BS, sampler=val_sampler)

In [42]:
print(f'Number of training data: {len(train_dataloader) * train_dataloader.batch_size}')
print(f'Number of validation data: {len(val_dataloader) * val_dataloader.batch_size}')

Number of training data: 6944
Number of validation data: 6784


In [43]:
# Get the number of data in each class
check_data_label = False
if check_data_label:
    train_lbl_0 = 0
    train_lbl_1 = 0
    for _, labels in tqdm(train_dataloader):
        train_lbl_0 += torch.count_nonzero(1-labels)
        train_lbl_1 += torch.count_nonzero(labels)
        
    print(f'Number of label 0 (fake) training images: {train_lbl_0}')
    print(f'Number of label 1 (real) training images: {train_lbl_1}')

    val_lbl_0 = 0
    val_lbl_1 = 0
    for _, labels in tqdm(val_dataloader):
        val_lbl_0 += torch.count_nonzero(1-labels)
        val_lbl_1 += torch.count_nonzero(labels)
        
    print(f'Number of label 0 (fake) validation images: {val_lbl_0}')
    print(f'Number of label 1 (real) validation images: {val_lbl_1}')

In [44]:
model = EfficientNet.from_pretrained(MODEL, num_classes=NUM_CLASSES)
optimizer = torch.optim.Adam(model.parameters(), lr=LR)
criterion = nn.CrossEntropyLoss()
model.to(device)

if os.path.exists(mdl_pth):
    checkpoint = torch.load(mdl_pth)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    print("Trained Model loaded...")
else:
    print("New Model created...")
    
if os.path.exists(rst_pth):
    checkpoint = torch.load(rst_pth)
    epoch_train_loss = checkpoint['epoch_training_loss']
    epoch_train_accuracy = checkpoint['epoch_training_accuracy']
    epoch_train_precision = checkpoint['epoch_training_precision']
    epoch_train_recall = checkpoint['epoch_train_recall']
    epoch_val_loss = checkpoint['epoch_val_loss']
    epoch_val_accuracy = checkpoint['epoch_val_accuracy']
    epoch_val_precision = checkpoint['epoch_val_precision']
    epoch_val_recall = checkpoint['epoch_val_recall']
    best_val_acc = checkpoint['best_val_acc']
    corres_val_precision = checkpoint['corres_val_precision']
    corres_val_recall = checkpoint['corres_val_recall']
    print("History result loaded...")
    print('Best model:')
    print(f'Accuracy = {best_val_acc:.4f}')
    print(f'Precision = {corres_val_precision:.4f}')
    print(f'Recall = {corres_val_recall:.4f}')
else:
    epoch_train_loss = []
    epoch_train_accuracy = []
    epoch_train_precision = []
    epoch_train_recall = []
    epoch_val_loss = []
    epoch_val_accuracy = []
    epoch_val_precision = []
    epoch_val_recall = []
    best_val_acc = 0.0 
    corres_val_precision = 0.0
    corres_val_recall = 0.0
    print("No history result...")

Loaded pretrained weights for efficientnet-b0
New Model created...
No history result...


In [45]:
for epoch in range(NUM_EPOCH):
    # ====================== train ===========================
    train_all_labels = []
    train_all_preds = []
    model.train()
    for inputs, labels in tqdm(train_dataloader):
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        _, preds = torch.max(outputs, 1)
        
        train_loss = criterion(outputs, labels)
        train_loss.backward()
        optimizer.step()
        
        train_all_labels.extend(labels.cpu().numpy())
        train_all_preds.extend(preds.cpu().numpy())

    train_accuracy = accuracy_score(train_all_labels, train_all_preds)
    train_precision = precision_score(train_all_labels, train_all_preds, average='weighted')
    train_recall = recall_score(train_all_labels, train_all_preds)
    
    epoch_train_loss += [train_loss.item()]
    epoch_train_accuracy += [train_accuracy]
    epoch_train_precision += [train_precision]
    epoch_train_recall += [train_recall]

    print(f'Epoch [{epoch+1}/{NUM_EPOCH}]')
    print(f'Trn loss: {train_loss:.4f}')
    print(f'Trn accuracy: {train_accuracy:.4f}')
    print(f'Trn precision: {train_precision:.4f}')
    print(f'Trn recall: {train_recall:.4f}')
    
    # ======================= val ============================
    model.eval()
    val_all_labels = []
    val_all_preds = []
    with torch.no_grad():
        for inputs, labels in tqdm(val_dataloader):
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            
            val_all_labels.extend(labels.cpu().numpy())
            val_all_preds.extend(preds.cpu().numpy())
            val_loss = criterion(outputs, labels)
            
    val_accuracy = accuracy_score(val_all_labels, val_all_preds)
    val_precision = precision_score(val_all_labels, val_all_preds, average='weighted')
    val_recall = recall_score(val_all_labels, val_all_preds)
    
    epoch_val_loss += [val_loss.item()]
    epoch_val_accuracy += [val_accuracy]
    epoch_val_precision += [val_precision]
    epoch_val_recall += [val_recall]
    
    print(f'Val loss: {val_loss:.4f}')
    print(f'Val accuracy: {val_accuracy:.4f}')
    print(f'Val precision: {val_precision:.4f}')
    print(f'val recall: {val_recall:.4f}')
        
    if best_val_acc < val_accuracy:
        best_val_acc = val_accuracy
        corres_val_precision = val_precision
        corres_val_recall = val_recall
        torch.save({
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
        }, mdl_pth)
        
    torch.save({
        'epoch_training_loss': epoch_train_loss,
        'epoch_training_accuracy': epoch_train_accuracy,
        'epoch_training_precision': epoch_train_precision,
        'epoch_train_recall': epoch_train_recall,
        'epoch_val_loss': epoch_val_loss,
        'epoch_val_accuracy': epoch_val_accuracy,
        'epoch_val_precision': epoch_val_precision,
        'epoch_val_recall': epoch_val_recall,
        'best_val_acc': best_val_acc,
        'corres_val_precision': corres_val_precision,
        'corres_val_recall': corres_val_recall
    }, rst_pth)
        
print("=========================================================")
print(f'Best model: Accuracy = {best_val_acc:.4f}, Precision = {corres_val_precision:.4f}, Recall = {corres_val_recall:.4f}')

100%|██████████| 217/217 [01:30<00:00,  2.40it/s]


Epoch [1/10]
Trn loss: 0.5986
Trn accuracy: 0.7388
Trn precision: 0.7415
Trn recall: 0.8009


100%|██████████| 212/212 [01:24<00:00,  2.50it/s]


Val loss: 0.0157
Val accuracy: 0.6950
Val precision: 0.7682
val recall: 0.9558


100%|██████████| 217/217 [01:27<00:00,  2.47it/s]


Epoch [2/10]
Trn loss: 0.7123
Trn accuracy: 0.8737
Trn precision: 0.8739
Trn recall: 0.8654


100%|██████████| 212/212 [01:22<00:00,  2.57it/s]


Val loss: 0.3210
Val accuracy: 0.7797
Val precision: 0.7994
val recall: 0.9074


100%|██████████| 217/217 [01:28<00:00,  2.45it/s]


Epoch [3/10]
Trn loss: 0.4956
Trn accuracy: 0.9278
Trn precision: 0.9279
Trn recall: 0.9247


100%|██████████| 212/212 [01:23<00:00,  2.55it/s]


Val loss: 0.0007
Val accuracy: 0.7930
Val precision: 0.7981
val recall: 0.8579


100%|██████████| 217/217 [01:33<00:00,  2.31it/s]


Epoch [4/10]
Trn loss: 0.8810
Trn accuracy: 0.9592
Trn precision: 0.9592
Trn recall: 0.9585


100%|██████████| 212/212 [01:23<00:00,  2.52it/s]


Val loss: 0.0394
Val accuracy: 0.7828
Val precision: 0.7930
val recall: 0.8754


100%|██████████| 217/217 [01:31<00:00,  2.36it/s]


Epoch [5/10]
Trn loss: 0.8173
Trn accuracy: 0.9725
Trn precision: 0.9725
Trn recall: 0.9696


100%|██████████| 212/212 [01:20<00:00,  2.62it/s]


Val loss: 0.0396
Val accuracy: 0.7733
Val precision: 0.7991
val recall: 0.9196


100%|██████████| 217/217 [01:32<00:00,  2.34it/s]


Epoch [6/10]
Trn loss: 0.6217
Trn accuracy: 0.9842
Trn precision: 0.9842
Trn recall: 0.9824


100%|██████████| 212/212 [01:23<00:00,  2.53it/s]


Val loss: 0.1209
Val accuracy: 0.7908
Val precision: 0.8046
val recall: 0.8967


100%|██████████| 217/217 [01:27<00:00,  2.48it/s]


Epoch [7/10]
Trn loss: 0.6830
Trn accuracy: 0.9842
Trn precision: 0.9842
Trn recall: 0.9835


100%|██████████| 212/212 [01:21<00:00,  2.59it/s]


Val loss: 0.0291
Val accuracy: 0.7831
Val precision: 0.7964
val recall: 0.8884


100%|██████████| 217/217 [01:29<00:00,  2.42it/s]


Epoch [8/10]
Trn loss: 0.8723
Trn accuracy: 0.9900
Trn precision: 0.9900
Trn recall: 0.9915


100%|██████████| 212/212 [01:22<00:00,  2.57it/s]


Val loss: 0.0112
Val accuracy: 0.7995
Val precision: 0.8085
val recall: 0.8843


100%|██████████| 217/217 [01:37<00:00,  2.23it/s]


Epoch [9/10]
Trn loss: 1.2859
Trn accuracy: 0.9883
Trn precision: 0.9883
Trn recall: 0.9884


100%|██████████| 212/212 [01:27<00:00,  2.42it/s]


Val loss: 2.7068
Val accuracy: 0.7844
Val precision: 0.7886
val recall: 0.8436


100%|██████████| 217/217 [01:33<00:00,  2.33it/s]


Epoch [10/10]
Trn loss: 0.6636
Trn accuracy: 0.9894
Trn precision: 0.9894
Trn recall: 0.9898


100%|██████████| 212/212 [01:23<00:00,  2.54it/s]

Val loss: 0.0006
Val accuracy: 0.7903
Val precision: 0.8031
val recall: 0.8923
Best model: Accuracy = 0.7995, Precision = 0.8085, Recall = 0.8843





---