## Imports

In [1]:
import random
import json
import pandas as pd
import gc

import warnings
warnings.filterwarnings("ignore", category=UserWarning) 

In [2]:
import os
import sys
import time
import math
import datetime
from efficientnet_pytorch import EfficientNet

import torchvision 
from torchvision import transforms, datasets
from torch.utils.tensorboard import SummaryWriter
import torch
from torch import Tensor
device = "cuda" if torch.cuda.is_available() else "cpu"

import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import timm

import torch.optim as optim
from sklearn.metrics import confusion_matrix
# from torchvision import transforms, datasets

# from util import AverageMeter
# from util import adjust_learning_rate, warmup_learning_rate
# from util import set_optimizer, save_model
# from networks.resnet_big import SupConResNe
from losses import SupConLoss

  from .autonotebook import tqdm as notebook_tqdm


## Hyperparameters

In [3]:
print_freq = 100
save_freq = 100
batch_size = 64
epochs = 400

# Optimization
learning_rate = 0.05
lr_decay_epochs = "100,200,300"
lr_decay_rate = 0.1
weight_decay = 1e-4
momentum = 0.9

# Dataset
dataset = 'faceforensic'

temp = 0.07

#method
method = 'SupCon'

cosine = True
warm = True
trial = '0'



In [4]:
model_path = './save/SupCon/{}_models'.format(dataset)
tb_path = './save/SupCon/{}_tensorboard'.format(dataset)
lr_decay_epochs = [int(it) for it in lr_decay_epochs.split(',')]


save_time = str(datetime.datetime.now())
model_name = '{}_{}_{}_{}_lr_{}_decay_{}_bsz_{}_temp_{}_trial_{}'.\
        format(save_time, method, dataset, 'resnet50', learning_rate,
               weight_decay, batch_size, temp, trial)

if cosine:
    model_name = '{}_cosine'.format(model_name)


tb_folder = os.path.join(tb_path, model_name)
if not os.path.isdir(tb_folder):
    os.makedirs(tb_folder)

save_folder = os.path.join(model_path, model_name)
if not os.path.isdir(save_folder):
    os.makedirs(save_folder)
    


In [5]:
class TwoCropTransform:
    """Create two crops of the same image"""
    def __init__(self, transform):
        self.transform = transform

    def __call__(self, x):
        return [self.transform(x), self.transform(x)]


class AverageMeter():
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k"""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res


def accuracy_evaluate(output, target, topk=(1,)):
    """accuarcy for evaluation 3+1"""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)
        correct = []
        tn = 0
        fp = 0
        fn = 0
        tp = 0

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        pred = pred.view(-1)
        # print(pred[5])
        # print(target)

        for i in range(batch_size):
            # print(pred[i].is_nonzero.eval())
            # print(target[i].is_nonzero.eval())

            if pred[i] == 0 and target[i] == 0:
                correct.append(True)
                tn += 1
            elif pred[i] != 0 and target[i] != 0:
                correct.append(True)
                tp += 1
            else:
                if pred[i] == 0 and target[i] != 0:
                    fn += 1
                elif pred[i] != 0 and target[i] == 0:
                    fp += 1
                correct.append(False)

        # print(correct)

        res = []
        for k in topk:
            correct_k = sum(bool(x) for x in correct)
            # print(correct_k)
            res.append(correct_k * 100.0 / batch_size)

        return res

def output_score(output, target):
    print(output)



def adjust_learning_rate(optimizer, epoch):
    lr = learning_rate
    if cosine:
        eta_min = lr * (lr_decay_rate ** 3)
        lr = eta_min + (lr - eta_min) * (
                1 + math.cos(math.pi * epoch / epochs)) / 2
    else:
        steps = np.sum(epoch > np.asarray(lr_decay_epochs))
        if steps > 0:
            lr = lr * (lr_decay_rate ** steps)

    for param_group in optimizer.param_groups:
        param_group['lr'] = lr


def warmup_learning_rate(epoch, batch_id, total_batches, optimizer):
    if warm and epoch <= warm_epochs:
        p = (batch_id + (epoch - 1) * total_batches) / \
            (warm_epochs * total_batches)
        lr = warmup_from + p * (warmup_to - warmup_from)

        for param_group in optimizer.param_groups:
            param_group['lr'] = lr


def set_optimizer(model):
    optimizer = optim.SGD(model.parameters(),
                          lr=learning_rate,
                          momentum=momentum,
                          weight_decay=weight_decay)
#     optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    optimizer = optim.Adamax(model.parameters(),  lr=1e-3, eps=1e-4, weight_decay=1e-4)

    return optimizer


def save_model(model, optimizer, epoch, save_file):
    print('==> Saving...')
    state = {
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'epoch': epoch,
    }
    torch.save(state, save_file)
    del state


class LinearClassifierFeatureFusion(nn.Module):
    """Linear classifier"""
    def __init__(self, num_classes=2):
        super(LinearClassifierFeatureFusion, self).__init__()
        feat_dim = 3328
        self.fc1 = nn.Linear(feat_dim, 256)
        self.dropout = nn.Dropout(0.5)
        self.fc2 = nn.Linear(256, num_classes)

    def forward(self, features):
        features = features.view(features.shape[0], -1)
        x = self.fc1(features)
        x = self.dropout(x)
        x = self.fc2(x)
        return x

## Splitting Train, Test, Validation

In [6]:
if os.path.exists('train_idx.txt') == False:
    total_idx = list(range(999))
    random.shuffle(total_idx)
    train_idx = total_idx[:600]

    with open('train_idx.txt', 'w') as file:
        json.dump(train_idx, file)

    valid_idx = total_idx[600:800]

    with open('valid_idx.txt', 'w') as file:
        json.dump(valid_idx, file)


    test_idx = total_idx[800:]

    with open('test_idx.txt', 'w') as file:
        json.dump(test_idx, file)

else:
    
    with open('train_idx.txt', 'r') as file:
        train_idx = json.load(file)
        
    with open('valid_idx.txt', 'r') as file:
        valid_idx = json.load(file)
        
    with open('test_idx.txt', 'r') as file:
        test_idx = json.load(file)
        
def idx_to_path(cat, indexes):
 
    root = './Dataset/c23'
    data_frame = {'path':[]}
    path = root + '/{}'.format(cat)
    for idx in indexes:
        folder_path = path + '/{}'.format(idx)
        if os.path.exists(folder_path) == True:
            for file_name in os.listdir(folder_path):
                file_path = folder_path + '/{}'.format(file_name)
                data_frame['path'].append(file_path)

    if cat == 'Original':
        labels = [1]*len(data_frame['path'])
    else:
        labels = [0]*len(data_frame['path'])

    data_frame['labels'] = labels
    data_frame = pd.DataFrame(data_frame)
    
    return data_frame

    


In [7]:
DF_train_data = idx_to_path('Deepfakes', train_idx)
DF_valid_data = idx_to_path('Deepfakes', valid_idx)
DF_test_data = idx_to_path('Deepfakes', test_idx)

F2F_train_data = idx_to_path('Face2Face', train_idx)
F2F_valid_data = idx_to_path('Face2Face', valid_idx)
F2F_test_data = idx_to_path('Face2Face', test_idx)

FS_train_data = idx_to_path('FaceSwap', train_idx)
FS_valid_data = idx_to_path('FaceSwap', valid_idx)
FS_test_data = idx_to_path('FaceSwap', test_idx)

NT_train_data = idx_to_path('NeuralTextures', train_idx)
NT_valid_data = idx_to_path('NeuralTextures', valid_idx)
NT_test_data = idx_to_path('NeuralTextures', test_idx)

OR_train_data = idx_to_path('Original', train_idx)
OR_valid_data = idx_to_path('Original', valid_idx)
OR_test_data = idx_to_path('Original', test_idx)

In [8]:
print(len(DF_train_data))
print(len(F2F_train_data))
print(len(FS_train_data))
print(len(OR_train_data))

17997
17550
18000
18000


In [9]:
print(len(DF_valid_data))
print(len(F2F_valid_data))
print(len(FS_valid_data))
print(len(NT_valid_data))
print(len(OR_valid_data))

6000
5880
6000
6000
6000


In [10]:
print(len(DF_test_data))
print(len(F2F_test_data))
print(len(FS_test_data))
print(len(NT_test_data))
print(len(OR_test_data))

5970
5730
5970
5970
5970


In [11]:
from torch.utils.data import Dataset
from PIL import Image

mean = (0.5, 0.5, 0.5)
std = (0.5, 0.5, 0.5)
# mean = [0.485, 0.456, 0.406]
# std = [0.229, 0.224, 0.225]
normalize = transforms.Normalize(mean=mean, std=std)

    
train_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
#         transforms.RandomApply([
#             transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)
#         ], p=0.8),
#         transforms.RandomGrayscale(p=0.2),
        transforms.ToTensor(),
        normalize,
    ])

class MyDataset(Dataset):
    def __init__(self, data_frame, transform=None):
        
        self.data = data_frame
        self.transform = transform
    
    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        
        img_path = self.data.loc[idx, 'path']
        label = self.data.loc[idx, 'labels']
        # RGB
        img = Image.open(img_path).convert('RGB')
        
        if self.transform is not None:
            img = self.transform(img)

        return img, label

In [12]:
DF_train_dataset = MyDataset(DF_train_data, TwoCropTransform(train_transform))
DF_valid_dataset = MyDataset(DF_valid_data, TwoCropTransform(train_transform))
DF_test_dataset = MyDataset(DF_test_data, TwoCropTransform(train_transform))

F2F_train_dataset = MyDataset(F2F_train_data, TwoCropTransform(train_transform))
F2F_valid_dataset = MyDataset(F2F_valid_data, TwoCropTransform(train_transform))
F2F_test_dataset = MyDataset(F2F_test_data, TwoCropTransform(train_transform))

FS_train_dataset = MyDataset(FS_train_data, TwoCropTransform(train_transform))
FS_valid_dataset = MyDataset(FS_valid_data, TwoCropTransform(train_transform))
FS_test_dataset = MyDataset(FS_test_data, TwoCropTransform(train_transform))

NT_train_dataset = MyDataset(NT_train_data, TwoCropTransform(train_transform))
NT_valid_dataset = MyDataset(NT_valid_data, TwoCropTransform(train_transform))
NT_test_dataset = MyDataset(NT_test_data, TwoCropTransform(train_transform))

OR_train_dataset = MyDataset(OR_train_data, TwoCropTransform(train_transform))
OR_valid_dataset = MyDataset(OR_valid_data, TwoCropTransform(train_transform))
OR_test_dataset = MyDataset(OR_test_data, TwoCropTransform(train_transform))

In [13]:
# DF_train_loader = torch.utils.data.DataLoader(DF_train_dataset, batch_size=batch_size, shuffle=True)
# DF_valid_loader = torch.utils.data.DataLoader(DF_valid_dataset, batch_size=batch_size, shuffle=True)
# DF_test_loader = torch.utils.data.DataLoader(DF_test_dataset, batch_size=batch_size, shuffle=True)

# F2F_train_loader = torch.utils.data.DataLoader(F2F_train_dataset, batch_size=batch_size, shuffle=True)
# F2F_valid_loader = torch.utils.data.DataLoader(F2F_valid_dataset, batch_size=batch_size, shuffle=True)
# F2F_test_loader = torch.utils.data.DataLoader(F2F_test_dataset, batch_size=batch_size, shuffle=True)

# FS_train_loader = torch.utils.data.DataLoader(FS_train_dataset, batch_size=batch_size, shuffle=True)
# FS_valid_loader = torch.utils.data.DataLoader(FS_valid_dataset, batch_size=batch_size, shuffle=True)
# FS_test_loader = torch.utils.data.DataLoader(FS_test_dataset, batch_size=batch_size, shuffle=True)

# NT_train_loader = torch.utils.data.DataLoader(NT_train_dataset, batch_size=batch_size, shuffle=True)
# NT_valid_loader = torch.utils.data.DataLoader(NT_valid_dataset, batch_size=batch_size, shuffle=True)
# NT_test_loader = torch.utils.data.DataLoader(NT_test_dataset, batch_size=batch_size, shuffle=True)

# OR_train_loader = torch.utils.data.DataLoader(OR_train_dataset, batch_size=batch_size, shuffle=True)
# OR_valid_loader = torch.utils.data.DataLoader(OR_valid_dataset, batch_size=batch_size, shuffle=True)
# OR_test_loader = torch.utils.data.DataLoader(OR_test_dataset, batch_size=batch_size, shuffle=True)

In [14]:
model_dict = {
    'resnet18': 512,
    'resnet34': 512,
    'resnet50': 2048,
    'resnet101': 2048
}
class SupConResNet(nn.Module):
    """backbone + projection head"""
    def __init__(self, name='resnet50', head='mlp', feat_dim=128):
        super(SupConResNet, self).__init__()
        dim_in = model_dict[name]
        img_model = torchvision.models.resnet50(pretrained=True)
        self.encoder = nn.Sequential(*list(img_model.children())[:-1])
        if head == 'linear':
            self.head = nn.Linear(dim_in, feat_dim)
        elif head == 'mlp':
            self.head = nn.Sequential(
                nn.Linear(dim_in, dim_in),
                nn.ReLU(inplace=True),
                nn.Linear(dim_in, feat_dim)
            )
        else:
            raise NotImplementedError(
                'head not supported: {}'.format(head))

    def forward(self, x):
        feat = self.encoder(x)
        feat = feat.view(feat.shape[0], feat.shape[1])
#         print(feat.shape)
        feat = F.normalize(self.head(feat), dim=1)
        return feat

In [15]:
# test_inp = torch.randn((64,3,224,224))
# test_model = SupConResNet()
# test_out = test_model(test_inp)
# print(test_out.shape)

## SRM Filters

In [16]:
class SRMConv2d_simple(nn.Module):
    
    def __init__(self, inc=3, learnable=False):
        super(SRMConv2d_simple, self).__init__()
        self.truc = nn.Hardtanh(-3, 3)
        kernel = self._build_kernel(inc)  # (3,3,5,5)
        self.kernel = nn.Parameter(data=kernel, requires_grad=learnable)
        # self.hor_kernel = self._build_kernel().transpose(0,1,3,2)

    def forward(self, x):
        '''
        x: imgs (Batch, H, W, 3)
        '''
        out = F.conv2d(x, self.kernel, stride=1, padding=2)
        out = self.truc(out)

        return out

    def _build_kernel(self, inc):
        # filter1: KB
        filter1 = [[0, 0, 0, 0, 0],
                   [0, -1, 2, -1, 0],
                   [0, 2, -4, 2, 0],
                   [0, -1, 2, -1, 0],
                   [0, 0, 0, 0, 0]]
        # filter2：KV
        filter2 = [[-1, 2, -2, 2, -1],
                   [2, -6, 8, -6, 2],
                   [-2, 8, -12, 8, -2],
                   [2, -6, 8, -6, 2],
                   [-1, 2, -2, 2, -1]]
        # filter3：hor 2rd
        filter3 = [[0, 0, 0, 0, 0],
                  [0, 0, 0, 0, 0],
                  [0, 1, -2, 1, 0],
                  [0, 0, 0, 0, 0],
                  [0, 0, 0, 0, 0]]

        filter1 = np.asarray(filter1, dtype=float) / 4.
        filter2 = np.asarray(filter2, dtype=float) / 12.
        filter3 = np.asarray(filter3, dtype=float) / 2.
        # statck the filters
        filters = [[filter1],#, filter1, filter1],
                   [filter2],#, filter2, filter2],
                   [filter3]]#, filter3, filter3]]  # (3,3,5,5)
        filters = np.array(filters)
        filters = np.repeat(filters, inc, axis=1)
        filters = torch.FloatTensor(filters)    # (3,3,5,5)
        return filters


In [17]:
class SupConEffNet(nn.Module):
    """backbone + projection head"""
    def __init__(self, head='mlp', feat_dim=128):
        super(SupConEffNet, self).__init__()
        # _, dim_in = model_dict[name]
        self.srm = SRMConv2d_simple()
        dim_in = 1280
        self.encoder = EfficientNet.from_name('efficientnet-b0', num_classes=2, include_top=False)
        if head == 'linear':
            self.head = nn.Linear(dim_in, feat_dim)
        elif head == 'mlp':
            self.head = nn.Sequential(
                nn.Linear(dim_in, dim_in),
                nn.ReLU(inplace=True),
                nn.Linear(dim_in, feat_dim)
            )
        else:
            raise NotImplementedError(
                'head not supported: {}'.format(head))

    def forward(self, x):
        x = self.srm(x)
        x = self.encoder(x)
#         print(feat.shape)
        x = x.view(x.shape[0], -1)
#         print(feat.shape)
        x = F.normalize(self.head(x), dim=1)
#         print(feat.shape)
        return x


class LinearClassifier(nn.Module):
    """Linear classifier"""
    def __init__(self, num_classes=4):
        super(LinearClassifier, self).__init__()
        feat_dim = 1280
        self.fc1 = nn.Linear(feat_dim, 256)
        self.dropout = nn.Dropout(0.5)
        self.fc2 = nn.Linear(256, num_classes)

    def forward(self, features):
        features = features.view(features.shape[0], -1)
        x = self.fc1(features)
        x = self.dropout(x)
        x = self.fc2(x)
        return x

In [18]:
class SupConEffNet(nn.Module):
    """backbone + projection head"""
    def __init__(self, head='mlp', feat_dim=128):
        super(SupConEffNet, self).__init__()
        # _, dim_in = model_dict[name]
#         self.srm = SRMConv2d_simple()
        dim_in = 1280
        self.encoder = timm.create_model('efficientnet_b0', pretrained=True)
        self.encoder.classifier = nn.Identity()
        if head == 'linear':
            self.head = nn.Linear(dim_in, feat_dim)
        elif head == 'mlp':
            self.head = nn.Sequential(
                nn.Linear(dim_in, dim_in),
                nn.ReLU(inplace=True),
                nn.Linear(dim_in, feat_dim)
            )
        else:
            raise NotImplementedError(
                'head not supported: {}'.format(head))

    def forward(self, x):
#         x = self.srm(x)
        x = self.encoder(x)
#         print(feat.shape)
        x = x.view(x.shape[0], -1)
#         print(feat.shape)
        x = F.normalize(self.head(x), dim=1)
#         print(feat.shape)
        return x

In [19]:
class SupConEffNet(nn.Module):
    """backbone + projection head
    + Pre Stem Blocks
    """
    def __init__(self, head='mlp', feat_dim=128):
        super(SupConEffNet, self).__init__()
        
        self.conv1 = nn.Conv2d(3, 6, 3, stride=1, padding=1, bias=False)
        self.conv2 = nn.Conv2d(6, 12, 3, stride=1, padding=1, bias=False)
        self.conv3 = nn.Conv2d(12, 36, 3, stride=1, padding=1, bias=False)
        self.conv4 = nn.Conv2d(36, 36, 3, stride=1, padding=1, bias=False)
        
        self.mybn1 = nn.BatchNorm2d(6)
        self.mybn2 = nn.BatchNorm2d(12)
        self.mybn3 = nn.BatchNorm2d(36)
        self.mybn4 = nn.BatchNorm2d(36)
        
        dim_in = 1280
        self.encoder = timm.create_model('efficientnet_b0', pretrained=True)
        self.encoder.conv_stem.weight = nn.Parameter(self.encoder.conv_stem.weight.repeat(1, 12, 1, 1))
        self.encoder.classifier = nn.Identity()
        
        if head == 'linear':
            self.head = nn.Linear(dim_in, feat_dim)
        elif head == 'mlp':
            self.head = nn.Sequential(
                nn.Linear(dim_in, dim_in),
                nn.ReLU(inplace=True),
                nn.Linear(dim_in, feat_dim)
            )
        else:
            raise NotImplementedError(
                'head not supported: {}'.format(head))

    def forward(self, x):
        
        x = F.relu6(self.mybn1(self.conv1(x)))
        x = F.relu6(self.mybn2(self.conv2(x)))
        x = F.relu6(self.mybn3(self.conv3(x)))
        x = F.relu6(self.mybn4(self.conv4(x)))
        
        x = self.encoder(x)
#         print(feat.shape)
        x = x.view(x.shape[0], -1)
#         print(feat.shape)
        x = F.normalize(self.head(x), dim=1)
#         print(feat.shape)
        return x

In [20]:
class SupConEffNet_post(nn.Module):
    """backbone + projection head
    + Post Stem Blocks
    """
    def __init__(self, head='mlp', feat_dim=128):
        super(SupConEffNet_post, self).__init__()
        
        
        dim_in = 1280
        self.encoder = timm.create_model('efficientnet_b0', pretrained=True)
        self.encoder.conv_stem.stride = (1,1)
        #self.encoder.conv_stem.weight = nn.Parameter(self.encoder.conv_stem.weight.repeat(1, 12, 1, 1))
        self.encoder.classifier = nn.Identity()
        num_channels = 32
        self.post_stem = nn.ModuleList([timm.models.efficientnet_blocks.InvertedResidual(in_chs=num_channels, out_chs=num_channels, noskip=True),
                    timm.models.efficientnet_blocks.InvertedResidual(in_chs=num_channels, out_chs=num_channels),
                    timm.models.efficientnet_blocks.InvertedResidual(in_chs=num_channels, out_chs=num_channels),
                    timm.models.efficientnet_blocks.InvertedResidual(in_chs=num_channels, out_chs=num_channels, stride=2)])
        
        if head == 'linear':
            self.head = nn.Linear(dim_in, feat_dim)
        elif head == 'mlp':
            self.head = nn.Sequential(
                nn.Linear(dim_in, dim_in),
                nn.ReLU(inplace=True),
                nn.Linear(dim_in, feat_dim)
            )
        else:
            raise NotImplementedError(
                'head not supported: {}'.format(head))

    def forward(self, x):
        
        x = self.encoder.conv_stem(x)
        x = self.encoder.bn1(x)
        print(x.shape)
        for idx, block in enumerate(self.post_stem):
            x = block(x)
        print(x.shape)
        x = self.encoder.blocks(x)
        x = self.encoder.conv_head(x)
        x = self.encoder.bn2(x)
        x = self.encoder.global_pool(x)
#         print(feat.shape)
        x = x.view(x.shape[0], -1)
#         print(x.shape)
        x = F.normalize(self.head(x), dim=1)
#         print(feat.shape)
        return x

In [21]:
class ConvBn(nn.Module):
    """Provides utility to create different types of layers."""

    def __init__(self, in_channels: int, out_channels: int) -> None:
        """Constructor.
        Args:
            in_channels (int): no. of input channels.
            out_channels (int): no. of output channels.
        """
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=False,
        )
        self.batch_norm = nn.BatchNorm2d(out_channels)

    def forward(self, inp: Tensor) -> Tensor:
        """Returns Conv2d followed by BatchNorm.

        Returns:
            Tensor: Output of Conv2D -> BN.
        """
        return self.batch_norm(self.conv(inp))


class Type1(nn.Module):
    """Creates type 1 layer of SRNet."""

    def __init__(self, in_channels: int, out_channels: int) -> None:
        super().__init__()
        self.convbn = ConvBn(in_channels, out_channels)
        self.relu = nn.ReLU()

    def forward(self, inp: Tensor) -> Tensor:
        """Returns type 1 layer of SRNet.
        Args:
            inp (Tensor): input tensor.
        Returns:
            Tensor: Output of type 1 layer.
        """
        return self.relu(self.convbn(inp))


class Type2(nn.Module):
    """Creates type 2 layer of SRNet."""

    def __init__(self, in_channels: int, out_channels: int) -> None:
        super().__init__()
        self.type1 = Type1(in_channels, out_channels)
        self.convbn = ConvBn(in_channels, out_channels)

    def forward(self, inp: Tensor) -> Tensor:
        """Returns type 2 layer of SRNet.
        Args:
            inp (Tensor): input tensor.
        Returns:
            Tensor: Output of type 2 layer.
        """
        return inp + self.convbn(self.type1(inp))


class Type3(nn.Module):
    """Creates type 3 layer of SRNet."""

    def __init__(self, in_channels: int, out_channels: int) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=1,
            stride=2,
            padding=0,
            bias=False,
        )
        self.batch_norm = nn.BatchNorm2d(out_channels)
        self.type1 = Type1(in_channels, out_channels)
        self.convbn = ConvBn(out_channels, out_channels)
        self.pool = nn.AvgPool2d(kernel_size=3, stride=2, padding=1)

    def forward(self, inp: Tensor) -> Tensor:
        """Returns type 3 layer of SRNet.
        Args:
            inp (Tensor): input tensor.

        Returns:
            Tensor: Output of type 3 layer.
        """
        out = self.batch_norm(self.conv1(inp))
        out1 = self.pool(self.convbn(self.type1(inp)))
        return out + out1


class Type4(nn.Module):
    """Creates type 4 layer of SRNet."""

    def __init__(self, in_channels: int, out_channels: int) -> None:
        super().__init__()
        self.type1 = Type1(in_channels, out_channels)
        self.convbn = ConvBn(out_channels, out_channels)
        self.gap = nn.AdaptiveAvgPool2d(output_size=1)

    def forward(self, inp: Tensor) -> Tensor:
        """Returns type 4 layer of SRNet.
        Args:
            inp (Tensor): input tensor.
        Returns:
            Tensor: Output of type 4 layer.
        """
        return self.gap(self.convbn(self.type1(inp)))


In [22]:
class Srnet(nn.Module):
    """This is SRNet model class."""

    def __init__(self) -> None:
        """Constructor."""
        super().__init__()
        dim_in = 512
        feat_dim = 128
        self.type1s = nn.Sequential(Type1(3, 64), Type1(64, 16))
        self.type2s = nn.Sequential(
            Type2(16, 16),
            Type2(16, 16),
            Type2(16, 16),
            Type2(16, 16),
            Type2(16, 16),
        )
        self.type3s = nn.Sequential(
            Type3(16, 16),
            Type3(16, 64),
            Type3(64, 128),
            Type3(128, 256),
        )
        self.type4 = Type4(256, 512)
        self.head = nn.Linear(dim_in, feat_dim)
#         self.dense = nn.Linear(512, 2)
#         self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, inp: Tensor) -> Tensor:
        """Returns logits for input images.
        Args:
            inp (Tensor): input image tensor of shape (Batch, 1, 256, 256)
        Returns:
            Tensor: Logits of shape (Batch, 2)
        """
        out = self.type1s(inp)
        out = self.type2s(out)
        out = self.type3s(out)
        out = self.type4(out)
        out = out.view(out.size(0), -1)
        out = F.normalize(self.head(out), dim=1)
#         out = self.dense(out)
        return out

In [23]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from components.attention import ChannelAttention, SpatialAttention, DualCrossModalAttention
from components.srm_conv import SRMConv2d_simple, SRMConv2d_Separate
from networks.xception import TransferModel

In [24]:
class SRMPixelAttention(nn.Module):
    def __init__(self, in_channels):
        super(SRMPixelAttention, self).__init__()
        self.srm = SRMConv2d_simple()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, 32, 3, 2, 0, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 64, 3, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
        )
        
        self.pa = SpatialAttention()

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, a=1)
                if not m.bias is None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x_srm = self.srm(x)
        fea = self.conv(x_srm)        
        att_map = self.pa(fea)
        
        return att_map

In [25]:
class FeatureFusionModule(nn.Module):
    def __init__(self, in_chan=2048*2, out_chan=2048, *args, **kwargs):
        super(FeatureFusionModule, self).__init__()
        self.convblk = nn.Sequential(
            nn.Conv2d(in_chan, out_chan, 1, 1, 0, bias=False),
            nn.BatchNorm2d(out_chan),
            nn.ReLU()
        )
        self.ca = ChannelAttention(out_chan, ratio=16)
        self.init_weight()

    def forward(self, x, y):
        fuse_fea = self.convblk(torch.cat((x, y), dim=1))
        fuse_fea = fuse_fea + fuse_fea * self.ca(fuse_fea)
        return fuse_fea

    def init_weight(self):
        for ly in self.children():
            if isinstance(ly, nn.Conv2d):
                nn.init.kaiming_normal_(ly.weight, a=1)
                if not ly.bias is None:
                    nn.init.constant_(ly.bias, 0)

In [26]:
class Two_Stream_Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.xception_rgb = TransferModel(
            'xception', dropout=0.5, inc=3, return_fea=True)
        self.xception_srm = TransferModel(
            'xception', dropout=0.5, inc=3, return_fea=True)

        self.srm_conv0 = SRMConv2d_simple(inc=3)
        self.srm_conv1 = SRMConv2d_Separate(32, 32)
        self.srm_conv2 = SRMConv2d_Separate(64, 64)
        self.relu = nn.ReLU(inplace=True)

        self.att_map = None
        self.srm_sa = SRMPixelAttention(3)
        self.srm_sa_post = nn.Sequential(
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True)
        )

        self.dual_cma0 = DualCrossModalAttention(in_dim=728, ret_att=False)
        self.dual_cma1 = DualCrossModalAttention(in_dim=728, ret_att=False)

        self.fusion = FeatureFusionModule()
        dim_in = 2048
        feat_dim = 128
        self.head = nn.Linear(dim_in, feat_dim)

        self.att_dic = {}

    def features(self, x):
        srm = self.srm_conv0(x)

        x = self.xception_rgb.model.fea_part1_0(x)
        y = self.xception_srm.model.fea_part1_0(srm) \
            + self.srm_conv1(x)
        y = self.relu(y)

        x = self.xception_rgb.model.fea_part1_1(x)
        y = self.xception_srm.model.fea_part1_1(y) \
            + self.srm_conv2(x)
        y = self.relu(y)

        # srm guided spatial attention
#         print(srm.shape)
        self.att_map = self.srm_sa(srm)
#         print(self.att_map)
        x = x * self.att_map + x
        x = self.srm_sa_post(x)

        x = self.xception_rgb.model.fea_part2(x)
        y = self.xception_srm.model.fea_part2(y)

        x, y = self.dual_cma0(x, y)


        x = self.xception_rgb.model.fea_part3(x)        
        y = self.xception_srm.model.fea_part3(y)
 

        x, y = self.dual_cma1(x, y)

        x = self.xception_rgb.model.fea_part4(x)
        y = self.xception_srm.model.fea_part4(y)

        x = self.xception_rgb.model.fea_part5(x)
        y = self.xception_srm.model.fea_part5(y)

        fea = self.fusion(x, y)
                

        return fea
    def classifier(self, fea):
        out, fea = self.xception_rgb.classifier(fea)
        return out, fea

    def forward(self, x):
        '''
        x: original rgb
        '''
        _ , fea = self.classifier(self.features(x))
        fea = fea.view(x.size(0), -1)
        fea = F.normalize(self.head(fea), dim=1)

        return fea

In [27]:
import transformers
from transformers import ViTForImageClassification

In [28]:
class TransFormer(nn.Module):
    def __init__(self):
        super(TransFormer, self).__init__()
        # labels = ds['train'].features['labels'].names
        model_name_or_path = 'google/vit-base-patch16-224-in21k'
        self.model = ViTForImageClassification.from_pretrained(
        model_name_or_path,
        num_labels=128)
#         self.model.classifier = nn.Linear(768,128)
    
    def forward(self, x):
        
        y = self.model(x)
        fea = y.logits
        fea = fea.view(x.size(0), -1)
#         print(fea.shape)
        fea = torch.nn.functional.normalize(fea, dim=1)
        
        return fea

In [29]:
test_inp = torch.randn((1,3,224,224))
test_model = TransFormer()
test_out = test_model(test_inp)
print(test_out.shape)

Some weights of the model checkpoint at google/vit-base-patch16-224-in21k were not used when initializing ViTForImageClassification: ['pooler.dense.bias', 'pooler.dense.weight']
- This IS expected if you are initializing ViTForImageClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ViTForImageClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


torch.Size([1, 128])


In [30]:
def set_model():
    gc.collect()
    torch.cuda.empty_cache()
#     model = SupConResNet(name='resnet50')
    model = TransFormer()
    criterion = SupConLoss(temperature=temp)

    # enable synchronized Batch Normalization
#     if opt.syncBN:
#         model = apex.parallel.convert_syncbn_model(model)
#         if torch.cuda.device_count() > 1:
#             model.encoder = torch.nn.DataParallel(model.encoder)
    model = model.to(device)
    criterion = criterion.cuda()
#         cudnn.benchmark = True

    return model, criterion

In [31]:
def train(train_loader, model, criterion, optimizer, epoch):
    """one epoch training"""
    

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()

    end = time.time()
    for idx, (images, labels) in enumerate(train_loader):
        data_time.update(time.time() - end)

        images = torch.cat([images[0], images[1]], dim=0)
        images = images.to(device)
        labels = labels.to(device)
        bsz = labels.shape[0]

        # warm-up learning rate
#         warmup_learning_rate(epoch, idx, len(train_loader), optimizer)

        # compute loss
        features = model(images)
        f1, f2 = torch.split(features, [bsz, bsz], dim=0)
        features = torch.cat([f1.unsqueeze(1), f2.unsqueeze(1)], dim=1)
#         print(features.shape, labels.shape)
#         break
        if method == 'SupCon':
            loss = criterion(features, labels)
        elif method == 'SimCLR':
            loss = criterion(features)
        else:
            raise ValueError('contrastive method not supported: {}'.
                             format(method))

        # update metric
        losses.update(loss.item(), bsz)

        # SGD
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # print info
        if (idx + 1) % print_freq == 0:
            print('Train: [{0}][{1}/{2}]\t'
                  'BT {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'DT {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'loss {loss.val:.3f} ({loss.avg:.3f})'.format(
                   epoch, idx + 1, len(train_loader), batch_time=batch_time,
                   data_time=data_time, loss=losses))
            sys.stdout.flush()

    return losses.avg


In [32]:
train_dataset = torch.utils.data.ConcatDataset([DF_train_dataset, F2F_train_dataset, FS_train_dataset, OR_train_dataset, OR_train_dataset, OR_train_dataset])
# build data loader
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
print(len(train_dataset))

107547


In [33]:
def main():
    
    # build model and criterion
    model, criterion = set_model()
    
    total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Number of Network's Parameters: {total_params:,}")

    # build optimizer
    optimizer = set_optimizer(model)
    scheduler = optim.lr_scheduler.OneCycleLR(optimizer, 
                               max_lr=1e-3, 
                               epochs=epochs,
                               steps_per_epoch=len(train_loader),
                               pct_start=16.0/epochs,
                               div_factor=25,
                               final_div_factor=2)
    model.train()

    # tensorboard
    writer = SummaryWriter(f'./runs')
    print('Start Training')
    # training routine
    for epoch in range(1, epochs):
#         adjust_learning_rate(optimizer, epoch)

        # train for one epoch
        time1 = time.time()
        loss = train(train_loader, model, criterion, optimizer, epoch)
        scheduler.step()
#         return 0
        time2 = time.time()
        print('epoch {}, total time {:.2f}'.format(epoch, time2 - time1))
        print('epoch {}, loss {:.2f}'.format(epoch, loss))
        print('*'*15)
        
        # tensorboard logger
        writer.add_scalar('loss', loss, epoch)
        writer.add_scalar('learning_rate', optimizer.param_groups[0]['lr'], epoch)

        if epoch % save_freq == 0:
            save_file = os.path.join(
                save_folder, 'ckpt_epoch_{epoch}.pth'.format(epoch=epoch))
            save_model(model, optimizer, epoch, save_file)
    
    writer.flush()
    writer.close()

    # save the last model
    save_file = os.path.join(
        save_folder, 'last.pth')
    save_model(model, optimizer, epochs, save_file)


In [34]:
main()

Some weights of the model checkpoint at google/vit-base-patch16-224-in21k were not used when initializing ViTForImageClassification: ['pooler.dense.bias', 'pooler.dense.weight']
- This IS expected if you are initializing ViTForImageClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ViTForImageClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Number of Network's Parameters: 85,897,088
Start Training
Train: [1][100/3361]	BT 0.421 (0.436)	DT 0.049 (0.053)	loss 4.144 (4.240)


KeyboardInterrupt: 