# Requirements

In [33]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.backends import cudnn
from torch import optim
from torch.utils.tensorboard import SummaryWriter

from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
from torchvision import models
from torchvision.models.feature_extraction import create_feature_extractor

import pandas as pd
import os
import random
import numpy as np
import matplotlib.pyplot as plt
import math


from skimage import io, transform
from PIL import Image


# Path

In [2]:
# Path of the CSV annotation file; main() hands it to the dataset class
# through set_contrastive_loader().
CSV_ADD = 'train_add.csv'

# Dataset

In [3]:
class FaceForensic(Dataset):
    """Face Forensic Dataset.

    Map-style dataset that reads image paths from a CSV file and returns the
    (optionally transformed) image for a given row index.
    """

    def __init__(self, csv_file, transform=None):
        """
        Args:
            csv_file (string): Path to the csv file with annotations. The
                image path is read from the second column (position 1).
            transform (dict or callable, optional): Either a dict with the
                keys ``'transform1'`` and ``'transform2'`` (rows whose index
                modulo 4 is 0 or 2 get ``'transform1'``, all other rows get
                ``'transform2'``), or a single callable applied to every
                sample.
        """
        self.face_add = pd.read_csv(csv_file)
        self.transform = transform

    def __len__(self):
        """Return the number of rows in the annotation file."""
        return len(self.face_add)

    def __getitem__(self, idx):
        """Load the image of row ``idx`` and apply the configured transform.

        Only scalar indices are supported: the transform selection relies on
        ``idx % 4``.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        img_add = self.face_add.iloc[idx, 1]
        # The context manager releases the file handle right away, and
        # convert('RGB') guarantees three channels so that the 3-channel
        # Normalize (and the network) also work for grayscale/RGBA files.
        with Image.open(img_add) as img:
            image = img.convert('RGB')

        if not self.transform:
            return image

        if isinstance(self.transform, dict):
            # Positions 0 and 2 of every group of four consecutive rows get
            # 'transform1'; positions 1 and 3 get 'transform2'.
            key = 'transform1' if idx % 4 in (0, 2) else 'transform2'
            return self.transform[key](image)

        # A plain callable, as promised by the constructor documentation.
        return self.transform(image)

## Transforms

In [4]:
# Spatial size every image is resized to before it is fed to the network
inp_size = (224, 224)

# Per-channel mean / std used to normalise the input tensors
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

# Gaussian blur: kernel size and the range sigma is sampled from
kernel_size = (5, 9)
sigma = (0.1, 5)

# Probability of applying the vertical / horizontal flip
p_verflip = 0.5
p_horflip = 0.5


# Plain view: resize + tensor conversion + normalisation only
transform1 = transforms.Compose([
    transforms.Resize(inp_size),
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
])

# Augmented view: random flips and Gaussian blur before the plain pipeline
transform2 = transforms.Compose([
    transforms.RandomVerticalFlip(p=p_verflip),
    transforms.RandomHorizontalFlip(p=p_horflip),
    transforms.GaussianBlur(kernel_size=kernel_size, sigma=sigma),
    transforms.Resize(inp_size),
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
])

# Always-mirrored view; reuses the shared mean/std instead of repeating the
# literals. Not part of `trfms` below.
transform3 = transforms.Compose([
    transforms.RandomHorizontalFlip(p=1),
    transforms.Resize(inp_size),
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
])
# Misspelled legacy name kept as an alias so existing references keep working.
trasnform3 = transform3

# Mapping consumed by FaceForensic.__getitem__ (keys 'transform1'/'transform2')
trfms = {'transform1': transform1, 'transform2': transform2}

# Dataloader

In [5]:
def set_contrastive_loader(dataset, transforms, csv_add, batch_size=128, shuffle=False):
    """Instantiate ``dataset`` and wrap it in a ``DataLoader``.

    :param dataset: callable invoked as ``dataset(csv_add, transforms)``
    :param transforms: transform object forwarded to ``dataset``
    :param csv_add: annotation CSV path forwarded to ``dataset``
    :param batch_size: number of samples per batch
    :param shuffle: whether the loader reshuffles the samples
    :return: torch.utils.data.DataLoader over the created dataset
    """
    return DataLoader(
        dataset(csv_add, transforms),
        batch_size=batch_size,
        shuffle=shuffle,
    )

# Contrastive Model

In [6]:
class ConEfficient(nn.Module):
    """EfficientNet-B0 encoder followed by a projection head.

    The encoder output is taken at the ``avgpool`` node, reshaped to
    ``(-1, dim_in)``, projected by the head and L2-normalised along dim 1.
    """

    def __init__(self, name='Efficient-B0', pretrained=True, head='mlp', dim_in=1280, feat_dim=128):
        """
        :param name: accepted but not read by this class
        :param pretrained: forwarded to ``models.efficientnet_b0``
        :param head: ``'mlp'`` (Linear-ReLU-Linear) or ``'linear'``
        :param dim_in: width of the encoder features fed to the head
        :param feat_dim: width of the projected output
        :raises NotImplementedError: for any other ``head`` value
        """
        super().__init__()

        # NOTE(review): newer torchvision releases prefer `weights=` over
        # `pretrained=` -- confirm against the installed version.
        backbone = models.efficientnet_b0(pretrained=pretrained)
        self.encoder = create_feature_extractor(
            backbone, return_nodes={"avgpool": "represent"}
        )
        self.dim_in = dim_in

        if head == 'mlp':
            self.head = nn.Sequential(
                nn.Linear(dim_in, dim_in),
                nn.ReLU(inplace=True),
                nn.Linear(dim_in, feat_dim),
            )
        elif head == 'linear':
            self.head = nn.Linear(dim_in, feat_dim)
        else:
            raise NotImplementedError(
                'head not supported: {}'.format(head))

    def forward(self, x):
        """Return the unit-norm projection of ``x``."""
        pooled = self.encoder(x)['represent']
        flat = pooled.view(-1, self.dim_in)
        projected = self.head(flat)
        return F.normalize(projected, dim=1)

# Adjust Learning Rate

In [17]:
def adjust_learning_rate(optimizer, epoch, mode, args):
    """Set the learning rate of every parameter group for ``epoch``.

    :param optimizer: object exposing ``param_groups`` (list of dicts)
    :param epoch: int
    :param mode: str, ``"contrastive"`` or ``"cross_entropy"``
    :param args: mapping read with ``args[key]``; uses ``lr_<mode>``,
        ``n_epochs_<mode>``, ``cosine``, ``lr_decay_rate`` and, for the
        step schedule, ``lr_decay_epochs``
    :return: None
    :raises ValueError: if ``mode`` is not one of the two known modes
    """
    if mode not in ("contrastive", "cross_entropy"):
        raise ValueError("Mode {} unknown".format(mode))

    if mode == "contrastive":
        lr_key, epochs_key = 'lr_contrastive', 'n_epochs_contrastive'
    else:
        lr_key, epochs_key = 'lr_cross_entropy', 'n_epochs_cross_entropy'
    new_lr = args[lr_key]
    total_epochs = args[epochs_key]

    if args['cosine']:
        # Cosine annealing from the base rate down to base * decay_rate ** 3.
        floor_lr = new_lr * (args['lr_decay_rate'] ** 3)
        new_lr = floor_lr + (new_lr - floor_lr) * (1 + math.cos(math.pi * epoch / total_epochs)) / 2
    else:
        # Step decay: one multiplication per milestone strictly below `epoch`.
        milestones = np.asarray(args['lr_decay_epochs'])
        n_passed = np.sum(epoch > milestones)
        if n_passed > 0:
            new_lr = new_lr * (args['lr_decay_rate'] ** n_passed)

    for group in optimizer.param_groups:
        group["lr"] = new_lr

# Contrastive Loss

In [26]:
class ConLoss(nn.Module):
    """Contrastive loss over groups of four consecutive samples.

    The batch is split into groups of four feature vectors. Inside a group,
    samples (0, 1) and (2, 3) are positives of each other and the other
    members of the group are negatives. Adapted from Supervised Contrastive
    Learning (https://arxiv.org/pdf/2004.11362.pdf) / SimCLR
    (https://arxiv.org/pdf/2002.05709.pdf).
    """

    def __init__(self, temperature=0.07, contrast_mode='all',
                 base_temperature=0.07, device="cpu"):
        """
        :param temperature: divisor applied to the dot-product similarities
        :param contrast_mode: stored, not read by this class
        :param base_temperature: stored, not read by this class
        :param device: stored for backward compatibility; the masks are
            created on the device of the incoming tensor
        """
        super(ConLoss, self).__init__()
        self.temperature = temperature
        self.contrast_mode = contrast_mode
        self.base_temperature = base_temperature
        self.device = device

    def calc_loss(self, sim_mat):
        """Sum of the per-anchor losses for similarity matrices.

        :param sim_mat: floating-point tensor of shape ``(..., 4, 4)`` holding
            the temperature-scaled similarities of one or more groups
        :return: scalar tensor, ``sum_i [log sum_{j != i} exp(s_ij) - s_i,pos(i)]``
        """
        # 1 where column j is the positive partner of anchor row i.
        pos_mask = torch.tensor(
            [[0, 1, 0, 0], [1, 0, 0, 0], [0, 0, 0, 1], [0, 0, 1, 0]],
            dtype=sim_mat.dtype, device=sim_mat.device)
        # The anchor itself is excluded from the denominator.
        self_mask = torch.eye(4, dtype=torch.bool, device=sim_mat.device)

        pos_logit = (sim_mat * pos_mask).sum(dim=-1)
        # logsumexp keeps the denominator stable for large logits, which
        # occur with small temperatures.
        log_denom = torch.logsumexp(
            sim_mat.masked_fill(self_mask, float('-inf')), dim=-1)

        return (log_denom - pos_logit).sum()

    def forward(self, features, labels=None, mask=None):
        """Compute the loss averaged over all anchors in the batch.

        Args:
            features: tensor of shape [bsz, feat_dim]; bsz must be a
                non-zero multiple of 4 and rows are grouped four at a time
                in batch order.
            labels: unused, kept for interface compatibility.
            mask: unused, kept for interface compatibility.
        Returns:
            A loss scalar (mean over the bsz anchors).
        Raises:
            ValueError: if ``features`` is not 2-D or bsz is not a non-zero
                multiple of 4.
        """
        if features.dim() != 2 or features.size(0) == 0 or features.size(0) % 4 != 0:
            raise ValueError(
                "ConLoss expects features of shape [bsz, feat_dim] with bsz a "
                "non-zero multiple of 4, got {}".format(tuple(features.shape)))

        groups = features.reshape(-1, 4, features.size(1))
        # Temperature-scaled pairwise similarities inside each group: (G, 4, 4)
        logits = torch.div(
            torch.matmul(groups, groups.transpose(1, 2)),
            self.temperature)

        total_loss = self.calc_loss(logits)
        # Average over every anchor (groups * 4).
        return total_loss / (groups.size(0) * groups.size(1))

# Train Contrastive

In [28]:
def train_contrastive(model, train_loader, criterion, optimizer, writer, args):
    """Run the contrastive training loop.

    :param model: torch.nn.Module Model
    :param train_loader: torch.utils.data.DataLoader yielding input batches
    :param criterion: torch.nn.Module Loss, called as ``criterion(projections)``
    :param optimizer: torch.optim
    :param writer: torch.utils.tensorboard.SummaryWriter
    :param args: mapping read with ``args[key]``; uses
        ``n_epochs_contrastive`` and ``device`` here and is forwarded to
        ``adjust_learning_rate``
    :return: None
    :raises ValueError: if ``train_loader`` yields no batches
    """
    model.train()
    best_loss = float("inf")

    for epoch in range(args['n_epochs_contrastive']):
        print("Epoch [{}/{}]".format(epoch + 1, args['n_epochs_contrastive']))
        train_loss = 0
        n_batches = 0

        for batch_idx, (inputs) in enumerate(train_loader):

            inputs = inputs.to(args['device'])
            optimizer.zero_grad()

            projections = model(inputs)
            loss = criterion(projections)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            n_batches = batch_idx + 1
            writer.add_scalar(
                "Loss train | Supervised Contrastive",
                loss.item(),
                epoch * len(train_loader) + batch_idx,
            )
            progress_bar(
                batch_idx,
                len(train_loader),
                "Loss: {:.3f} ".format(train_loss / n_batches),
            )

        # Fail with a clear message instead of a NameError further down.
        if n_batches == 0:
            raise ValueError("train_loader yielded no batches")

        avg_loss = train_loss / n_batches

        # Only check every 10 epochs otherwise you will always save
        if epoch % 10 == 0 and avg_loss < best_loss:
            print("Saving..")
            state = {
                "net": model.state_dict(),
                "avg_loss": avg_loss,
                "epoch": epoch,
            }
            # exist_ok avoids the check-then-create race of isdir + mkdir.
            os.makedirs("checkpoint", exist_ok=True)
            torch.save(state, "./checkpoint/ckpt_contrastive.pth")
            best_loss = avg_loss

        adjust_learning_rate(optimizer, epoch, mode="contrastive", args=args)

# Train Crossentropy

In [22]:
def train_cross_entropy(model, train_loader, test_loader, criterion, optimizer, writer, args):
    """Train with a supervised criterion and validate after every epoch.

    :param model: torch.nn.Module Model
    :param train_loader: torch.utils.data.DataLoader yielding (inputs, targets)
    :param test_loader: torch.utils.data.DataLoader passed to ``validation``
    :param criterion: torch.nn.Module Loss, called as ``criterion(outputs, targets)``
    :param optimizer: torch.optim
    :param writer: torch.utils.tensorboard.SummaryWriter
    :param args: mapping read with ``args[key]``; uses
        ``n_epochs_cross_entropy`` and ``device`` here and is forwarded to
        ``validation`` and ``adjust_learning_rate``
    :return: None
    """

    for epoch in range(args['n_epochs_cross_entropy']):  # loop over the dataset multiple times
        print("Epoch [{}/{}]".format(epoch + 1, args['n_epochs_cross_entropy']))

        # validation() switches the model to eval mode, so re-enable training
        # mode at the start of every epoch.
        model.train()
        train_loss = 0
        correct = 0
        total = 0
        for batch_idx, (inputs, targets) in enumerate(train_loader):
            inputs, targets = inputs.to(args['device']), targets.to(args['device'])
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            _, predicted = outputs.max(1)

            total_batch = targets.size(0)
            correct_batch = predicted.eq(targets).sum().item()
            total += total_batch
            correct += correct_batch

            global_step = epoch * len(train_loader) + batch_idx
            writer.add_scalar(
                "Loss train | Cross Entropy",
                loss.item(),
                global_step,
            )
            writer.add_scalar(
                "Accuracy train | Cross Entropy",
                correct_batch / total_batch,
                global_step,
            )
            progress_bar(
                batch_idx,
                len(train_loader),
                # A single '%': str.format does not treat '%' specially, so
                # '%%' would be printed literally.
                "Loss: {:.3f} | Acc: {:.3f}% ({}/{})".format(
                    train_loss / (batch_idx + 1),
                    100.0 * correct / total,
                    correct,
                    total,
                ),
            )

        validation(epoch, model, test_loader, criterion, writer, args)

        adjust_learning_rate(optimizer, epoch, mode='cross_entropy', args=args)

    print("Finished Training")
              
              
def validation(epoch, model, test_loader, criterion, writer, args):
    """Evaluate the model and checkpoint it when the accuracy improves.

    :param epoch: int
    :param model: torch.nn.Module, Model
    :param test_loader: torch.utils.data.DataLoader yielding (inputs, targets)
    :param criterion: torch.nn.Module, Loss
    :param writer: torch.utils.tensorboard.SummaryWriter
    :param args: mapping read with ``args[key]``; uses ``device`` and the
        optional ``best_acc`` entry (treated as 0.0 when absent), which is
        updated in place whenever a new best accuracy is reached
    :return: None
    """

    model.eval()
    test_loss = 0
    correct = 0
    total = 0

    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(test_loader):
            inputs, targets = inputs.to(args['device']), targets.to(args['device'])
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

            progress_bar(
                batch_idx,
                len(test_loader),
                # A single '%': str.format does not treat '%' specially, so
                # '%%' would be printed literally.
                "Loss: {:.3f} | Acc: {:.3f}% ({}/{})".format(
                    test_loss / (batch_idx + 1),
                    100.0 * correct / total,
                    correct,
                    total,
                ),
            )

    # Save checkpoint.
    # An empty loader yields 0% instead of a ZeroDivisionError.
    acc = 100.0 * correct / total if total else 0.0
    writer.add_scalar("Accuracy validation | Cross Entropy", acc, epoch)

    # `args` is subscripted like a dict above, so the running best is kept
    # under a key as well; attribute access would raise AttributeError.
    best_acc = args.get('best_acc', 0.0)
    if acc > best_acc:
        print("Saving..")
        state = {
            "net": model.state_dict(),
            "acc": acc,
            "epoch": epoch,
        }
        # exist_ok avoids the check-then-create race of isdir + mkdir.
        os.makedirs("checkpoint", exist_ok=True)
        torch.save(state, "./checkpoint/ckpt_cross_entropy.pth")
        args['best_acc'] = acc

# Utils

In [30]:
import shutil
import sys
import time


# shutil.get_terminal_size() returns (columns, lines); the progress bar pads
# against the terminal *width*, so take the columns field explicitly.
term_width = shutil.get_terminal_size().columns

# Number of character cells the bar itself occupies.
TOTAL_BAR_LENGTH = 65.
# Timestamps shared with progress_bar(): time of the previous call and time
# the current bar was started.
last_time = time.time()
begin_time = last_time


def progress_bar(current, total, msg=None):
    """Write a one-line text progress bar to stdout.

    :param current: zero-based index of the step just completed
    :param total: total number of steps
    :param msg: optional text appended after the timing information
    :return: None

    Uses the module-level ``term_width`` and ``TOTAL_BAR_LENGTH`` for layout
    and updates the module-level ``last_time`` / ``begin_time`` timestamps.
    """
    global last_time, begin_time
    if current == 0:
        begin_time = time.time()  # Reset for new bar.

    filled = int(TOTAL_BAR_LENGTH * current / total)
    remaining = int(TOTAL_BAR_LENGTH - filled) - 1

    # A non-positive repeat count yields an empty string, so no guard is needed.
    sys.stdout.write(' [' + '=' * filled + '>' + '.' * remaining + ']')

    now = time.time()
    step_time = now - last_time
    last_time = now
    tot_time = now - begin_time

    pieces = [
        '  Step: %s' % format_time(step_time),
        ' | Tot: %s' % format_time(tot_time),
    ]
    if msg:
        pieces.append(' | ' + msg)
    status = ''.join(pieces)

    sys.stdout.write(status)
    # Pad the rest of the line with spaces.
    sys.stdout.write(' ' * (term_width - int(TOTAL_BAR_LENGTH) - len(status) - 3))

    # Go back to the center of the bar.
    sys.stdout.write('\b' * (term_width - int(TOTAL_BAR_LENGTH / 2) + 2))
    sys.stdout.write(' %d/%d ' % (current + 1, total))

    # Stay on the same line until the final step.
    sys.stdout.write('\r' if current < total - 1 else '\n')
    sys.stdout.flush()


def format_time(seconds):
    """Format a duration as a compact string of at most two units.

    The duration is split into days (D), hours (h), minutes (m), seconds (s)
    and milliseconds (ms); the first two strictly positive components, from
    largest to smallest, are concatenated. Returns ``'0ms'`` when no
    component is positive.

    :param seconds: duration in seconds (int or float)
    :return: str
    """
    days = int(seconds / 3600/24)
    seconds = seconds - days*3600*24
    hours = int(seconds / 3600)
    seconds = seconds - hours*3600
    minutes = int(seconds / 60)
    seconds = seconds - minutes*60
    whole_seconds = int(seconds)
    seconds = seconds - whole_seconds
    millis = int(seconds*1000)

    components = (
        (days, 'D'),
        (hours, 'h'),
        (minutes, 'm'),
        (whole_seconds, 's'),
        (millis, 'ms'),
    )
    shown = [str(value) + unit for value, unit in components if value > 0][:2]
    return ''.join(shown) if shown else '0ms'

In [24]:
def main(args):
    """Build the contrastive model and run the configured training mode.

    :param args: mapping read with ``args[key]``; uses ``device``,
        ``training_mode``, ``batch_size``, ``lr_contrastive``, ``momentum``
        and ``weight_decay`` here and is forwarded to ``train_contrastive``
    :return: None

    Only ``training_mode == 'contrastive'`` triggers training; any other
    value leaves the model untrained.
    """
    # Contrastive Model
    model = ConEfficient()
    model = model.to(args['device'])

    cudnn.benchmark = True

    # exist_ok avoids the check-then-create race of isdir + mkdir.
    os.makedirs("logs", exist_ok=True)
    writer = SummaryWriter("logs")

    try:
        if args['training_mode'] == 'contrastive':
            # shuffle stays False: ConLoss groups every four consecutive
            # samples of a batch, so the row order of the CSV must be kept.
            train_loader_contrastive = set_contrastive_loader(
                FaceForensic, trfms, CSV_ADD,
                batch_size=args['batch_size'], shuffle=False)

            # define optimizer
            optimizer = optim.SGD(
                model.parameters(),
                lr=args['lr_contrastive'],
                momentum=args['momentum'],
                weight_decay=args['weight_decay'],
            )

            criterion = ConLoss(device=args['device'])
            criterion.to(args['device'])
            train_contrastive(model, train_loader_contrastive, criterion, optimizer, writer, args)
    finally:
        # Flush and release the event file even when training is interrupted.
        writer.close()

In [34]:
# Run configuration. The keys are read with args[...] by main() and the
# training helpers.
args = {
    'training_mode': 'contrastive',
    # ConLoss regroups each batch into sets of four samples, so the batch
    # size has to be a multiple of 4.
    'batch_size': 4,
    'num_workers': 1,
    'temprature': 0.07,
    'cosine': True,
    'n_epochs_contrastive': 500,
    'lr_contrastive': 1e-1,
    'n_epochs_cross_entropy': 100,
    'lr_cross_entropy': 5e-2,
    'momentum': 0.9,
    'lr_decay_rate': 0.1,
    'lr_decay_epochs': [150, 300, 500],
    'weight_decay': 1e-4,
}

# Use the GPU when one is available.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
args['device'] = device

main(args)



Epoch [1/500]
Saving..
Epoch [2/500]
Epoch [3/500]
Epoch [4/500]
Epoch [5/500]
Epoch [6/500]
Epoch [7/500]
Epoch [8/500]
Epoch [9/500]

KeyboardInterrupt: 