In [1]:
import random
import json
import pandas as pd
import gc

import warnings
warnings.filterwarnings("ignore", category=UserWarning) 

In [2]:
import os
import sys
import time
import math
import datetime


import torchvision 
from torchvision import transforms, datasets
from torch.utils.tensorboard import SummaryWriter
import torch
from torch import Tensor
device = "cuda" if torch.cuda.is_available() else "cpu"

import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from sklearn.metrics import confusion_matrix

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
print_freq = 100
save_freq = 100
batch_size = 64
epochs = 400

# Optimization
learning_rate = 0.05
lr_decay_epochs = "100,200,300"
lr_decay_rate = 0.1
weight_decay = 1e-4
momentum = 0.9

# Dataset
dataset = 'faceforensic'

temp = 0.07

#method
method = 'Adv'

cosine = True
warm = True
trial = '0'



In [7]:
model_path = './save/Adv/{}_models'.format(dataset)
tb_path = './save/Adv/{}_tensorboard'.format(dataset)

save_time = f"{datetime.datetime.now().strftime('%m_%d_%Y_%H_%M_%S')}"
print(save_time)
model_name = '{}_{}_{}_lr_{}_epochs_{}_bsz_{}'.\
        format(save_time, method, dataset, learning_rate,
            epochs , batch_size)

tb_folder = os.path.join(tb_path, model_name)
if not os.path.isdir(tb_folder):
    os.makedirs(tb_folder)

save_folder = os.path.join(model_path, model_name)
if not os.path.isdir(save_folder):
    os.makedirs(save_folder)
    


04_24_2023_18_38_12


In [8]:
class TwoCropTransform:
    """Create two crops of the same image"""
    def __init__(self, transform):
        self.transform = transform

    def __call__(self, x):
        return [self.transform(x), self.transform(x)]


class AverageMeter():
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k"""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res


def accuracy_evaluate(output, target, topk=(1,)):
    """accuarcy for evaluation 3+1"""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)
        correct = []
        tn = 0
        fp = 0
        fn = 0
        tp = 0

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        pred = pred.view(-1)
        # print(pred[5])
        # print(target)

        for i in range(batch_size):
            # print(pred[i].is_nonzero.eval())
            # print(target[i].is_nonzero.eval())

            if pred[i] == 0 and target[i] == 0:
                correct.append(True)
                tn += 1
            elif pred[i] != 0 and target[i] != 0:
                correct.append(True)
                tp += 1
            else:
                if pred[i] == 0 and target[i] != 0:
                    fn += 1
                elif pred[i] != 0 and target[i] == 0:
                    fp += 1
                correct.append(False)

        # print(correct)

        res = []
        for k in topk:
            correct_k = sum(bool(x) for x in correct)
            # print(correct_k)
            res.append(correct_k * 100.0 / batch_size)

        return res

def output_score(output, target):
    print(output)



def adjust_learning_rate(optimizer, epoch):
    lr = learning_rate
    if cosine:
        eta_min = lr * (lr_decay_rate ** 3)
        lr = eta_min + (lr - eta_min) * (
                1 + math.cos(math.pi * epoch / epochs)) / 2
    else:
        steps = np.sum(epoch > np.asarray(lr_decay_epochs))
        if steps > 0:
            lr = lr * (lr_decay_rate ** steps)

    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

def set_optimizer(model):
    optimizer = optim.SGD(model.parameters(),
                          lr=learning_rate,
                          momentum=momentum,
                          weight_decay=weight_decay)
#     optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    optimizer = optim.Adamax(model.parameters(),  lr=1e-3, eps=1e-4, weight_decay=1e-4)

    return optimizer


def save_model(model, optimizer, epoch, save_file):
    print('==> Saving...')
    state = {
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'epoch': epoch,
    }
    torch.save(state, save_file)
    del state


class LinearClassifierFeatureFusion(nn.Module):
    """Linear classifier"""
    def __init__(self, num_classes=2):
        super(LinearClassifierFeatureFusion, self).__init__()
        feat_dim = 3328
        self.fc1 = nn.Linear(feat_dim, 256)
        self.dropout = nn.Dropout(0.5)
        self.fc2 = nn.Linear(256, num_classes)

    def forward(self, features):
        features = features.view(features.shape[0], -1)
        x = self.fc1(features)
        x = self.dropout(x)
        x = self.fc2(x)
        return x

In [12]:
# if os.path.exists('artifacts/train_idx.txt') == False:
#     total_idx = list(range(999))
#     random.shuffle(total_idx)
#     train_idx = total_idx[:600]

#     with open('train_idx.txt', 'w') as file:
#         json.dump(train_idx, file)

#     valid_idx = total_idx[600:800]

#     with open('valid_idx.txt', 'w') as file:
#         json.dump(valid_idx, file)


#     test_idx = total_idx[800:]

#     with open('test_idx.txt', 'w') as file:
#         json.dump(test_idx, file)

# else:

with open('../artifacts/train_idx.txt', 'r') as file:
    train_idx = json.load(file)
    
with open('../artifacts/valid_idx.txt', 'r') as file:
    valid_idx = json.load(file)
    
with open('../artifacts/test_idx.txt', 'r') as file:
    test_idx = json.load(file)
        
def idx_to_path(cat, indexes):
 
    root = './dataset/dataset/Dataset/c23'
    data_frame = {'path':[]}
    path = root + '/{}'.format(cat)
    for idx in indexes:
        folder_path = path + '/{}'.format(idx)
        if os.path.exists(folder_path) == True:
            for file_name in os.listdir(folder_path):
                file_path = folder_path + '/{}'.format(file_name)
                data_frame['path'].append(file_path)

    if cat == 'Original':
        labels = [1]*len(data_frame['path'])
    else:
        labels = [0]*len(data_frame['path'])

    data_frame['labels'] = labels
    data_frame = pd.DataFrame(data_frame)
    
    return data_frame


In [13]:
DF_train_data = idx_to_path('Deepfakes', train_idx)
DF_valid_data = idx_to_path('Deepfakes', valid_idx)
DF_test_data = idx_to_path('Deepfakes', test_idx)

F2F_train_data = idx_to_path('Face2Face', train_idx)
F2F_valid_data = idx_to_path('Face2Face', valid_idx)
F2F_test_data = idx_to_path('Face2Face', test_idx)

FS_train_data = idx_to_path('FaceSwap', train_idx)
FS_valid_data = idx_to_path('FaceSwap', valid_idx)
FS_test_data = idx_to_path('FaceSwap', test_idx)

NT_train_data = idx_to_path('NeuralTextures', train_idx)
NT_valid_data = idx_to_path('NeuralTextures', valid_idx)
NT_test_data = idx_to_path('NeuralTextures', test_idx)

OR_train_data = idx_to_path('Original', train_idx)
OR_valid_data = idx_to_path('Original', valid_idx)
OR_test_data = idx_to_path('Original', test_idx)

In [14]:
print(len(DF_train_data))
print(len(F2F_train_data))
print(len(FS_train_data))
print(len(OR_train_data))

17997
17550
18000
18000


In [15]:
print(len(DF_valid_data))
print(len(F2F_valid_data))
print(len(FS_valid_data))
print(len(NT_valid_data))
print(len(OR_valid_data))

6000
5880
6000
6000
6000


In [18]:
print(len(DF_test_data))
print(len(F2F_test_data))
print(len(FS_test_data))
print(len(NT_test_data))
print(len(OR_test_data))

5970
5730
5970
5970
5970


In [17]:
from torch.utils.data import Dataset
from PIL import Image

mean = (0.5, 0.5, 0.5)
std = (0.5, 0.5, 0.5)
normalize = transforms.Normalize(mean=mean, std=std)

    
train_transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.ToTensor(),
        normalize,
    ])

class MyDataset(Dataset):
    def __init__(self, data_frame, cat,transform=None):
        
        self.data = data_frame
        self.cat = cat
        self.transform = transform
    
    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        
        img_path = self.data.loc[idx, 'path']
        label = self.data.loc[idx, 'labels']
        # RGB
        img = Image.open(img_path).convert('RGB')
        img = self.transform(img)

        return img, label

In [19]:
DF_train_dataset = MyDataset(DF_train_data, train_transform)
DF_valid_dataset = MyDataset(DF_valid_data, train_transform)
DF_test_dataset = MyDataset(DF_test_data, train_transform)

F2F_train_dataset = MyDataset(F2F_train_data, train_transform)
F2F_valid_dataset = MyDataset(F2F_valid_data, train_transform)
F2F_test_dataset = MyDataset(F2F_test_data, train_transform)

FS_train_dataset = MyDataset(FS_train_data, train_transform)
FS_valid_dataset = MyDataset(FS_valid_data, train_transform)
FS_test_dataset = MyDataset(FS_test_data, train_transform)

NT_train_dataset = MyDataset(NT_train_data, train_transform)
NT_valid_dataset = MyDataset(NT_valid_data, train_transform)
NT_test_dataset = MyDataset(NT_test_data, train_transform)

OR_train_dataset = MyDataset(OR_train_data, train_transform)
OR_valid_dataset = MyDataset(OR_valid_data, train_transform)
OR_test_dataset = MyDataset(OR_test_data, train_transform)

In [20]:
class SRMConv2d_simple(nn.Module):
    
    def __init__(self, inc=3, learnable=False):
        super(SRMConv2d_simple, self).__init__()
        self.truc = nn.Hardtanh(-3, 3)
        kernel = self._build_kernel(inc)  # (3,3,5,5)
        self.kernel = nn.Parameter(data=kernel, requires_grad=learnable)
        # self.hor_kernel = self._build_kernel().transpose(0,1,3,2)

    def forward(self, x):
        '''
        x: imgs (Batch, H, W, 3)
        '''
        out = F.conv2d(x, self.kernel, stride=1, padding=2)
        out = self.truc(out)

        return out

    def _build_kernel(self, inc):
        # filter1: KB
        filter1 = [[0, 0, 0, 0, 0],
                   [0, -1, 2, -1, 0],
                   [0, 2, -4, 2, 0],
                   [0, -1, 2, -1, 0],
                   [0, 0, 0, 0, 0]]
        # filter2：KV
        filter2 = [[-1, 2, -2, 2, -1],
                   [2, -6, 8, -6, 2],
                   [-2, 8, -12, 8, -2],
                   [2, -6, 8, -6, 2],
                   [-1, 2, -2, 2, -1]]
        # filter3：hor 2rd
        filter3 = [[0, 0, 0, 0, 0],
                  [0, 0, 0, 0, 0],
                  [0, 1, -2, 1, 0],
                  [0, 0, 0, 0, 0],
                  [0, 0, 0, 0, 0]]

        filter1 = np.asarray(filter1, dtype=float) / 4.
        filter2 = np.asarray(filter2, dtype=float) / 12.
        filter3 = np.asarray(filter3, dtype=float) / 2.
        # statck the filters
        filters = [[filter1],#, filter1, filter1],
                   [filter2],#, filter2, filter2],
                   [filter3]]#, filter3, filter3]]  # (3,3,5,5)
        filters = np.array(filters)
        filters = np.repeat(filters, inc, axis=1)
        filters = torch.FloatTensor(filters)    # (3,3,5,5)
        return filters

class SRMConv2d_Separate(nn.Module):
    
    def __init__(self, inc, outc, learnable=False):
        super(SRMConv2d_Separate, self).__init__()
        self.inc = inc
        self.truc = nn.Hardtanh(-3, 3)
        kernel = self._build_kernel(inc)  # (3,3,5,5)
        self.kernel = nn.Parameter(data=kernel, requires_grad=learnable)
        # self.hor_kernel = self._build_kernel().transpose(0,1,3,2)
        self.out_conv = nn.Sequential(
            nn.Conv2d(3*inc, outc, 1, 1, 0, 1, 1, bias=False),
            nn.BatchNorm2d(outc),
            nn.ReLU(inplace=True)
        )

        for ly in self.out_conv.children():
            if isinstance(ly, nn.Conv2d):
                nn.init.kaiming_normal_(ly.weight, a=1)

    def forward(self, x):
        '''
        x: imgs (Batch, H, W, 3)
        '''
        out = F.conv2d(x, self.kernel, stride=1, padding=2, groups=self.inc)
        out = self.truc(out)
        out = self.out_conv(out)

        return out

    def _build_kernel(self, inc):
        # filter1: KB
        filter1 = [[0, 0, 0, 0, 0],
                   [0, -1, 2, -1, 0],
                   [0, 2, -4, 2, 0],
                   [0, -1, 2, -1, 0],
                   [0, 0, 0, 0, 0]]
        # filter2：KV
        filter2 = [[-1, 2, -2, 2, -1],
                   [2, -6, 8, -6, 2],
                   [-2, 8, -12, 8, -2],
                   [2, -6, 8, -6, 2],
                   [-1, 2, -2, 2, -1]]
        # # filter3：hor 2rd
        filter3 = [[0, 0, 0, 0, 0],
                  [0, 0, 0, 0, 0],
                  [0, 1, -2, 1, 0],
                  [0, 0, 0, 0, 0],
                  [0, 0, 0, 0, 0]]

        filter1 = np.asarray(filter1, dtype=float) / 4.
        filter2 = np.asarray(filter2, dtype=float) / 12.
        filter3 = np.asarray(filter3, dtype=float) / 2.
        # statck the filters
        filters = [[filter1],#, filter1, filter1],
                   [filter2],#, filter2, filter2],
                   [filter3]]#, filter3, filter3]]  # (3,3,5,5)
        filters = np.array(filters)
        # filters = np.repeat(filters, inc, axis=1)
        filters = np.repeat(filters, inc, axis=0)
        filters = torch.FloatTensor(filters)    # (3,3,5,5)
        # print(filters.size())
        return filters


In [21]:
class ChannelAttention(nn.Module):
    def __init__(self, in_planes, ratio=8):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        self.sharedMLP = nn.Sequential(
            nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False),
            nn.ReLU(),
            nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False))
        self.sigmoid = nn.Sigmoid()

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_normal_(m.weight.data, gain=0.02)

    def forward(self, x):
        avgout = self.sharedMLP(self.avg_pool(x))
        maxout = self.sharedMLP(self.max_pool(x))
        return self.sigmoid(avgout + maxout)


class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()
        assert kernel_size in (3, 7), "kernel size must be 3 or 7"
        padding = 3 if kernel_size == 7 else 1

        self.conv = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
        self.sigmoid = nn.Sigmoid()

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_normal_(m.weight.data, gain=0.02)

    def forward(self, x):
        avgout = torch.mean(x, dim=1, keepdim=True)
        maxout, _ = torch.max(x, dim=1, keepdim=True)
        x = torch.cat([avgout, maxout], dim=1)
        x = self.conv(x)
        return self.sigmoid(x)


"""
The following modules are modified based on https://github.com/heykeetae/Self-Attention-GAN
"""


class Self_Attn(nn.Module):
    """ Self attention Layer"""

    def __init__(self, in_dim, out_dim=None, add=False, ratio=8):
        super(Self_Attn, self).__init__()
        self.chanel_in = in_dim
        self.add = add
        if out_dim is None:
            out_dim = in_dim
        self.out_dim = out_dim
        # self.activation = activation

        self.query_conv = nn.Conv2d(
            in_channels=in_dim, out_channels=in_dim//ratio, kernel_size=1)
        self.key_conv = nn.Conv2d(
            in_channels=in_dim, out_channels=in_dim//ratio, kernel_size=1)
        self.value_conv = nn.Conv2d(
            in_channels=in_dim, out_channels=out_dim, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))

        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """
            inputs :
                x : input feature maps( B X C X W X H)
            returns :
                out : self attention value + input feature 
                attention: B X N X N (N is Width*Height)
        """
        m_batchsize, C, width, height = x.size()
        proj_query = self.query_conv(x).view(
            m_batchsize, -1, width*height).permute(0, 2, 1)  # B X C X(N)
        proj_key = self.key_conv(x).view(
            m_batchsize, -1, width*height)  # B X C x (*W*H)
        energy = torch.bmm(proj_query, proj_key)  # transpose check
        attention = self.softmax(energy)  # BX (N) X (N)
        proj_value = self.value_conv(x).view(
            m_batchsize, -1, width*height)  # B X C X N

        out = torch.bmm(proj_value, attention.permute(0, 2, 1))
        out = out.view(m_batchsize, self.out_dim, width, height)

        if self.add:
            out = self.gamma*out + x
        else:
            out = self.gamma*out
        return out  # , attention


class CrossModalAttention(nn.Module):
    """ CMA attention Layer"""

    def __init__(self, in_dim, activation=None, ratio=8, cross_value=True):
        super(CrossModalAttention, self).__init__()
        self.chanel_in = in_dim
        self.activation = activation
        self.cross_value = cross_value

        self.query_conv = nn.Conv2d(
            in_channels=in_dim, out_channels=in_dim//ratio, kernel_size=1)
        self.key_conv = nn.Conv2d(
            in_channels=in_dim, out_channels=in_dim//ratio, kernel_size=1)
        self.value_conv = nn.Conv2d(
            in_channels=in_dim, out_channels=in_dim, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))

        self.softmax = nn.Softmax(dim=-1)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_normal_(m.weight.data, gain=0.02)

    def forward(self, x, y):
        """
            inputs :
                x : input feature maps( B X C X W X H)
            returns :
                out : self attention value + input feature 
                attention: B X N X N (N is Width*Height)
        """
        B, C, H, W = x.size()

        proj_query = self.query_conv(x).view(
            B, -1, H*W).permute(0, 2, 1)  # B , HW, C
        proj_key = self.key_conv(y).view(
            B, -1, H*W)  # B X C x (*W*H)
        energy = torch.bmm(proj_query, proj_key)  # B, HW, HW
        attention = self.softmax(energy)  # BX (N) X (N)
        if self.cross_value:
            proj_value = self.value_conv(y).view(
                B, -1, H*W)  # B , C , HW
        else:
            proj_value = self.value_conv(x).view(
                B, -1, H*W)  # B , C , HW

        out = torch.bmm(proj_value, attention.permute(0, 2, 1))
        out = out.view(B, C, H, W)

        out = self.gamma*out + x

        if self.activation is not None:
            out = self.activation(out)

        return out  # , attention


class DualCrossModalAttention(nn.Module):
    """ Dual CMA attention Layer"""

    def __init__(self, in_dim, activation=None, size=16, ratio=8, ret_att=False):
        super(DualCrossModalAttention, self).__init__()
        self.chanel_in = in_dim
        self.activation = activation
        self.ret_att = ret_att

        # query conv
        self.key_conv1 = nn.Conv2d(
            in_channels=in_dim, out_channels=in_dim//ratio, kernel_size=1)
        self.key_conv2 = nn.Conv2d(
            in_channels=in_dim, out_channels=in_dim//ratio, kernel_size=1)
        self.key_conv_share = nn.Conv2d(
            in_channels=in_dim//ratio, out_channels=in_dim//ratio, kernel_size=1)

        self.linear1 = nn.Linear(size*size, size*size)
        self.linear2 = nn.Linear(size*size, size*size)

        # separated value conv
        self.value_conv1 = nn.Conv2d(
            in_channels=in_dim, out_channels=in_dim, kernel_size=1)
        self.gamma1 = nn.Parameter(torch.zeros(1))

        self.value_conv2 = nn.Conv2d(
            in_channels=in_dim, out_channels=in_dim, kernel_size=1)
        self.gamma2 = nn.Parameter(torch.zeros(1))

        self.softmax = nn.Softmax(dim=-1)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_normal_(m.weight.data, gain=0.02)
            if isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight.data, gain=0.02)

    def forward(self, x, y):
        """
            inputs :
                x : input feature maps( B X C X W X H)
            returns :
                out : self attention value + input feature 
                attention: B X N X N (N is Width*Height)
        """
        B, C, H, W = x.size()

        def _get_att(a, b):
            proj_key1 = self.key_conv_share(self.key_conv1(a)).view(
                B, -1, H*W).permute(0, 2, 1)  # B, HW, C
            proj_key2 = self.key_conv_share(self.key_conv2(b)).view(
                B, -1, H*W)  # B X C x (*W*H)
            energy = torch.bmm(proj_key1, proj_key2)  # B, HW, HW

            attention1 = self.softmax(self.linear1(energy))
            attention2 = self.softmax(self.linear2(
                energy.permute(0, 2, 1)))  # BX (N) X (N)

            return attention1, attention2

        att_y_on_x, att_x_on_y = _get_att(x, y)
        proj_value_y_on_x = self.value_conv2(y).view(
            B, -1, H*W)  # B, C, HW
        out_y_on_x = torch.bmm(proj_value_y_on_x, att_y_on_x.permute(0, 2, 1))
        out_y_on_x = out_y_on_x.view(B, C, H, W)
        out_x = self.gamma1*out_y_on_x + x

        proj_value_x_on_y = self.value_conv1(x).view(
            B, -1, H*W)  # B , C , HW
        out_x_on_y = torch.bmm(proj_value_x_on_y, att_x_on_y.permute(0, 2, 1))
        out_x_on_y = out_x_on_y.view(B, C, H, W)
        out_y = self.gamma2*out_x_on_y + y

        if self.ret_att:
            return out_x, out_y, att_y_on_x, att_x_on_y

        return out_x, out_y  # , attention


# if __name__ == "__main__":
#     x = torch.rand(10, 768, 16, 16)
#     y = torch.rand(10, 768, 16, 16)
#     dcma = DualCrossModalAttention(768, ret_att=True)
#     out_x, out_y, att_y_on_x, att_x_on_y = dcma(x, y)
#     print(out_y.size())
#     print(att_x_on_y.size())

In [25]:
"""
Code from https://github.com/ondyari/FaceForensics
Author: Andreas Rössler
"""
import os
import argparse


import torch
# import pretrainedmodels
import torch.nn as nn
import torch.nn.functional as F
# from lib.nets.xception import xception
import math
import torchvision

# import math
# import torch
# import torch.nn as nn
# import torch.nn.functional as F
import torch.utils.model_zoo as model_zoo
from torch.nn import init

pretrained_settings = {
    'xception': {
        'imagenet': {
            'url': 'http://data.lip6.fr/cadene/pretrainedmodels/xception-b5690688.pth',
            'input_space': 'RGB',
            'input_size': [3, 299, 299],
            'input_range': [0, 1],
            'mean': [0.5, 0.5, 0.5],
            'std': [0.5, 0.5, 0.5],
            'num_classes': 1000,
            'scale': 0.8975  # The resize parameter of the validation transform should be 333, and make sure to center crop at 299x299
        }
    }
}

PRETAINED_WEIGHT_PATH = '../src/components/networks/xception-b5690688.pth'

class SeparableConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=False):
        super(SeparableConv2d, self).__init__()

        self.conv1 = nn.Conv2d(in_channels, in_channels, kernel_size,
                               stride, padding, dilation, groups=in_channels, bias=bias)
        self.pointwise = nn.Conv2d(
            in_channels, out_channels, 1, 1, 0, 1, 1, bias=bias)

    def forward(self, x):
        x = self.conv1(x)
        x = self.pointwise(x)
        return x


class Block(nn.Module):
    def __init__(self, in_filters, out_filters, reps, strides=1, start_with_relu=True, grow_first=True):
        super(Block, self).__init__()

        if out_filters != in_filters or strides != 1:
            self.skip = nn.Conv2d(in_filters, out_filters,
                                  1, stride=strides, bias=False)
            self.skipbn = nn.BatchNorm2d(out_filters)
        else:
            self.skip = None

        self.relu = nn.ReLU(inplace=True)
        rep = []

        filters = in_filters
        if grow_first:
            rep.append(self.relu)
            rep.append(SeparableConv2d(in_filters, out_filters,
                                       3, stride=1, padding=1, bias=False))
            rep.append(nn.BatchNorm2d(out_filters))
            filters = out_filters

        for i in range(reps-1):
            rep.append(self.relu)
            rep.append(SeparableConv2d(filters, filters,
                                       3, stride=1, padding=1, bias=False))
            rep.append(nn.BatchNorm2d(filters))

        if not grow_first:
            rep.append(self.relu)
            rep.append(SeparableConv2d(in_filters, out_filters,
                                       3, stride=1, padding=1, bias=False))
            rep.append(nn.BatchNorm2d(out_filters))

        if not start_with_relu:
            rep = rep[1:]
        else:
            rep[0] = nn.ReLU(inplace=False)

        if strides != 1:
            rep.append(nn.MaxPool2d(3, strides, 1))
        self.rep = nn.Sequential(*rep)

    def forward(self, inp):
        x = self.rep(inp)

        if self.skip is not None:
            skip = self.skip(inp)
            skip = self.skipbn(skip)
        else:
            skip = inp

        x += skip
        return x


def add_gaussian_noise(ins, mean=0, stddev=0.2):
    noise = ins.data.new(ins.size()).normal_(mean, stddev)
    return ins + noise


class Xception(nn.Module):
    """
    Xception optimized for the ImageNet dataset, as specified in
    https://arxiv.org/pdf/1610.02357.pdf
    """

    def __init__(self, num_classes=1000, inc=3):
        """ Constructor
        Args:
            num_classes: number of classes
        """
        super(Xception, self).__init__()
        self.num_classes = num_classes

        # Entry flow
        self.conv1 = nn.Conv2d(inc, 32, 3, 2, 0, bias=False)
        self.bn1 = nn.BatchNorm2d(32)
        self.relu = nn.ReLU(inplace=True)

        self.conv2 = nn.Conv2d(32, 64, 3, bias=False)
        self.bn2 = nn.BatchNorm2d(64)
        # do relu here

        self.block1 = Block(
            64, 128, 2, 2, start_with_relu=False, grow_first=True)
        self.block2 = Block(
            128, 256, 2, 2, start_with_relu=True, grow_first=True)
        self.block3 = Block(
            256, 728, 2, 2, start_with_relu=True, grow_first=True)

        # middle flow
        self.block4 = Block(
            728, 728, 3, 1, start_with_relu=True, grow_first=True)
        self.block5 = Block(
            728, 728, 3, 1, start_with_relu=True, grow_first=True)
        self.block6 = Block(
            728, 728, 3, 1, start_with_relu=True, grow_first=True)
        self.block7 = Block(
            728, 728, 3, 1, start_with_relu=True, grow_first=True)

        self.block8 = Block(
            728, 728, 3, 1, start_with_relu=True, grow_first=True)
        self.block9 = Block(
            728, 728, 3, 1, start_with_relu=True, grow_first=True)
        self.block10 = Block(
            728, 728, 3, 1, start_with_relu=True, grow_first=True)
        self.block11 = Block(
            728, 728, 3, 1, start_with_relu=True, grow_first=True)

        # Exit flow
        self.block12 = Block(
            728, 1024, 2, 2, start_with_relu=True, grow_first=False)

        self.conv3 = SeparableConv2d(1024, 1536, 3, 1, 1)
        self.bn3 = nn.BatchNorm2d(1536)

        # do relu here
        self.conv4 = SeparableConv2d(1536, 2048, 3, 1, 1)
        self.bn4 = nn.BatchNorm2d(2048)

        self.fc = nn.Linear(2048, num_classes)

        # #------- init weights --------
        # for m in self.modules():
        #     if isinstance(m, nn.Conv2d):
        #         n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
        #         m.weight.data.normal_(0, math.sqrt(2. / n))
        #     elif isinstance(m, nn.BatchNorm2d):
        #         m.weight.data.fill_(1)
        #         m.bias.data.zero_()
        # #-----------------------------
    def fea_part1_0(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)

        return x

    def fea_part1_1(self, x):

        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu(x)

        return x

    def fea_part1(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)

        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu(x)

        return x

    def fea_part2(self, x):
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)

        return x

    def fea_part3(self, x):
        x = self.block4(x)
        x = self.block5(x)
        x = self.block6(x)
        x = self.block7(x)

        return x

    def fea_part4(self, x):
        x = self.block8(x)
        x = self.block9(x)
        x = self.block10(x)
        x = self.block11(x)

        return x

    def fea_part5(self, x):
        x = self.block12(x)

        x = self.conv3(x)
        x = self.bn3(x)
        x = self.relu(x)

        x = self.conv4(x)
        x = self.bn4(x)

        return x

    def features(self, input):
        x = self.fea_part1(input)

        x = self.fea_part2(x)
        x = self.fea_part3(x)
        x = self.fea_part4(x)

        x = self.fea_part5(x)
        return x

    def classifier(self, features):
        x = self.relu(features)

        x = F.adaptive_avg_pool2d(x, (1, 1))
        x = x.view(x.size(0), -1)
        out = self.last_linear(x)
        return out, x

    def forward(self, input):
        x = self.features(input)
        out, x = self.classifier(x)
        return out, x


def xception(num_classes=1000, pretrained='imagenet', inc=3):
    model = Xception(num_classes=num_classes, inc=inc)
    if pretrained:
        settings = pretrained_settings['xception'][pretrained]
        assert num_classes == settings['num_classes'], \
            "num_classes should be {}, but is {}".format(
                settings['num_classes'], num_classes)

        model = Xception(num_classes=num_classes)
        model.load_state_dict(model_zoo.load_url(settings['url']))

        model.input_space = settings['input_space']
        model.input_size = settings['input_size']
        model.input_range = settings['input_range']
        model.mean = settings['mean']
        model.std = settings['std']

    # TODO: ugly
    model.last_linear = model.fc
    del model.fc
    return model


class TransferModel(nn.Module):
    """
    Simple transfer learning model that takes an imagenet pretrained model with
    a fc layer as base model and retrains a new fc layer for num_out_classes
    """

    def __init__(self, modelchoice, num_out_classes=2, dropout=0.0,
                 weight_norm=False, return_fea=False, inc=3):
        super(TransferModel, self).__init__()
        self.modelchoice = modelchoice
        self.return_fea = return_fea

        if modelchoice == 'xception':

            def return_pytorch04_xception(pretrained=True):
                # Raises warning "src not broadcastable to dst" but thats fine
                model = xception(pretrained=False)
                if pretrained:
                    # Load model in torch 0.4+
                    model.fc = model.last_linear
                    del model.last_linear
                    state_dict = torch.load(
                        PRETAINED_WEIGHT_PATH)
                    for name, weights in state_dict.items():
                        if 'pointwise' in name:
                            state_dict[name] = weights.unsqueeze(
                                -1).unsqueeze(-1)
                    model.load_state_dict(state_dict)
                    model.last_linear = model.fc
                    del model.fc
                return model

            self.model = return_pytorch04_xception()
            # Replace fc
            num_ftrs = self.model.last_linear.in_features
            if not dropout:
                if weight_norm:
                    print('Using Weight_Norm')
                    self.model.last_linear = nn.utils.weight_norm(
                        nn.Linear(num_ftrs, num_out_classes), name='weight')
                self.model.last_linear = nn.Linear(num_ftrs, num_out_classes)
            else:
                print('Using dropout', dropout)
                if weight_norm:
                    print('Using Weight_Norm')
                    self.model.last_linear = nn.Sequential(
                        nn.Dropout(p=dropout),
                        nn.utils.weight_norm(
                            nn.Linear(num_ftrs, num_out_classes), name='weight')
                    )

                self.model.last_linear = nn.Sequential(
                    nn.Dropout(p=dropout),
                    nn.Identity()
                )

            if inc != 3:
                self.model.conv1 = nn.Conv2d(inc, 32, 3, 2, 0, bias=False)
                nn.init.xavier_normal(self.model.conv1.weight.data, gain=0.02)

        elif modelchoice == 'resnet50' or modelchoice == 'resnet18':
            if modelchoice == 'resnet50':
                self.model = torchvision.models.resnet50(pretrained=True)
            if modelchoice == 'resnet18':
                self.model = torchvision.models.resnet18(pretrained=True)
            # Replace fc
            num_ftrs = self.model.fc.in_features
            if not dropout:
                self.model.fc = nn.Linear(num_ftrs, num_out_classes)
            else:
                self.model.fc = nn.Sequential(
                    nn.Dropout(p=dropout),
                    nn.Linear(num_ftrs, num_out_classes)
                )
        else:
            raise Exception('Choose valid model, e.g. resnet50')

    def set_trainable_up_to(self, boolean=False, layername="Conv2d_4a_3x3"):
        """
        Freezes all layers below a specific layer and sets the following layers
        to true if boolean else only the fully connected final layer
        :param boolean:
        :param layername: depends on lib, for inception e.g. Conv2d_4a_3x3
        :return:
        """
        # Stage-1: freeze all the layers
        if layername is None:
            for i, param in self.model.named_parameters():
                param.requires_grad = True
                return
        else:
            for i, param in self.model.named_parameters():
                param.requires_grad = False
        if boolean:
            # Make all layers following the layername layer trainable
            ct = []
            found = False
            for name, child in self.model.named_children():
                if layername in ct:
                    found = True
                    for params in child.parameters():
                        params.requires_grad = True
                ct.append(name)
            if not found:
                raise NotImplementedError('Layer not found, cant finetune!'.format(
                    layername))
        else:
            if self.modelchoice == 'xception':
                # Make fc trainable
                for param in self.model.last_linear.parameters():
                    param.requires_grad = True

            else:
                # Make fc trainable
                for param in self.model.fc.parameters():
                    param.requires_grad = True

    def forward(self, x):
        out, x = self.model(x)
        if self.return_fea:
            return out, x
        else:
            return out

    def features(self, x):
        x = self.model.features(x)
        return x

    def classifier(self, x):
        out, x = self.model.classifier(x)
        return out, x


def model_selection(modelname, num_out_classes,
                    dropout=None):
    """
    :param modelname:
    :return: model, image size, pretraining<yes/no>, input_list
    """
    if modelname == 'xception':
        return TransferModel(modelchoice='xception',
                             num_out_classes=num_out_classes), 299, \
            True, ['image'], None
    elif modelname == 'resnet18':
        return TransferModel(modelchoice='resnet18', dropout=dropout,
                             num_out_classes=num_out_classes), \
            224, True, ['image'], None
    else:
        raise NotImplementedError(modelname)


# if __name__ == '__main__':
#     model = TransferModel('xception', dropout=0.5)
#     print(model)
#     # model = model.cuda()
#     # from torchsummary import summary
#     # input_s = (3, image_size, image_size)
#     # print(summary(model, input_s))
#     dummy = torch.rand(10, 3, 256, 256)
#     out = model(dummy)
#     print(out.size())
#     x = model.features(dummy)
#     out, x = model.classifier(x)
#     print(out.size())
#     print(x.size())

In [26]:
class SRMPixelAttention(nn.Module):
    def __init__(self, in_channels):
        super(SRMPixelAttention, self).__init__()
        self.srm = SRMConv2d_simple()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, 32, 3, 2, 0, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 64, 3, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
        )
        
        self.pa = SpatialAttention()

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, a=1)
                if not m.bias is None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x_srm = self.srm(x)
        fea = self.conv(x_srm)        
        att_map = self.pa(fea)
        
        return att_map


class FeatureFusionModule(nn.Module):
    def __init__(self, in_chan=2048*2, out_chan=2048, *args, **kwargs):
        super(FeatureFusionModule, self).__init__()
        self.convblk = nn.Sequential(
            nn.Conv2d(in_chan, out_chan, 1, 1, 0, bias=False),
            nn.BatchNorm2d(out_chan),
            nn.ReLU()
        )
        self.ca = ChannelAttention(out_chan, ratio=16)
        self.init_weight()

    def forward(self, x, y):
        fuse_fea = self.convblk(torch.cat((x, y), dim=1))
        fuse_fea = fuse_fea + fuse_fea * self.ca(fuse_fea)
        return fuse_fea

    def init_weight(self):
        for ly in self.children():
            if isinstance(ly, nn.Conv2d):
                nn.init.kaiming_normal_(ly.weight, a=1)
                if not ly.bias is None:
                    nn.init.constant_(ly.bias, 0)


class Two_Stream_Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.xception_rgb = TransferModel(
            'xception', dropout=0.5, inc=3, return_fea=True)
        self.xception_srm = TransferModel(
            'xception', dropout=0.5, inc=3, return_fea=True)

        self.srm_conv0 = SRMConv2d_simple(inc=3)
        self.srm_conv1 = SRMConv2d_Separate(32, 32)
        self.srm_conv2 = SRMConv2d_Separate(64, 64)
        self.relu = nn.ReLU(inplace=True)

        self.att_map = None
        self.srm_sa = SRMPixelAttention(3)
        self.srm_sa_post = nn.Sequential(
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True)
        )

        self.dual_cma0 = DualCrossModalAttention(in_dim=728, ret_att=False)
        self.dual_cma1 = DualCrossModalAttention(in_dim=728, ret_att=False)

        self.fusion = FeatureFusionModule()

        self.att_dic = {}

    def features(self, x):
        srm = self.srm_conv0(x)

        x = self.xception_rgb.model.fea_part1_0(x)
        y = self.xception_srm.model.fea_part1_0(srm) \
            + self.srm_conv1(x)
        y = self.relu(y)

        x = self.xception_rgb.model.fea_part1_1(x)
        y = self.xception_srm.model.fea_part1_1(y) \
            + self.srm_conv2(x)
        y = self.relu(y)

        # srm guided spatial attention
        self.att_map = self.srm_sa(srm)
        x = x * self.att_map + x
        x = self.srm_sa_post(x)

        x = self.xception_rgb.model.fea_part2(x)
        y = self.xception_srm.model.fea_part2(y)

        x, y = self.dual_cma0(x, y)


        x = self.xception_rgb.model.fea_part3(x)        
        y = self.xception_srm.model.fea_part3(y)
 

        x, y = self.dual_cma1(x, y)

        x = self.xception_rgb.model.fea_part4(x)
        y = self.xception_srm.model.fea_part4(y)

        x = self.xception_rgb.model.fea_part5(x)
        y = self.xception_srm.model.fea_part5(y)

        fea = self.fusion(x, y)
                

        return fea

    def classifier(self, fea):
        out, fea = self.xception_rgb.classifier(fea)
        return out, fea

    def forward(self, x):
        '''
        x: original rgb
        '''
        out, fea = self.classifier(self.features(x))

        return out

In [27]:
backbone = Two_Stream_Net()
dummy = torch.randn((1,3,256,256))
out = backbone(dummy)
print(out.shape)

Using dropout 0.5
Using dropout 0.5
torch.Size([1, 2048])


In [28]:
class Classify(nn.Module):
  def __init__(self):
    super().__init__()
    self.backbone = backbone
    self.network = nn.Sequential(
                    nn.Linear(2048, 512),
                    nn.Linear(512,64),
                    nn.Linear(64,2)
                )
  
  def forward(self, x):
    x = self.backbone(x)
    # print(x.shape)
    out = self.network(x)

    return out

class Domain_Generalize(nn.Module):
  def __init__(self) -> None:
      super().__init__()
      self.backbone = backbone
      self.network = nn.Sequential(
                    nn.Linear(2048, 3),
                )

  def forward(self, x):
    x = self.backbone(x)
    # print(x.shape)
    out = self.network(x)

    return out

In [29]:
net1 = Classify()
out1 = net1(dummy)
print(out1.shape)

net2 = Domain_Generalize()
out2 = net2(dummy)
print(out2.shape)

torch.Size([1, 2])
torch.Size([1, 3])


In [30]:
def set_model():
    gc.collect()
    torch.cuda.empty_cache()
    cls_model = net1.to(device)
    gen_model = net2.to(device)
    criterion = nn.CrossEntropyLoss()
    criterion = criterion.to(device)

    return cls_model, gen_model, criterion

In [31]:
def set_optimizer(cls_model, gen_model):

    optim1 = optim.Adamax(cls_model.parameters(),  lr=1e-3, eps=1e-4, weight_decay=1e-4)
    optim2 = optim.Adamax(gen_model.parameters(),  lr=1e-3, eps=1e-4, weight_decay=1e-4)
    return optim1, optim2

In [33]:
def train(train_loader, cls_model, gen_model, criterion, optim1, optim2, epoch):
    """one epoch training"""
    

    batch_time = AverageMeter()
    data_time = AverageMeter()
    cls_losses = AverageMeter()
    gen_losses = AverageMeter()

    end = time.time()
    for idx, (images, labels) in enumerate(train_loader):
        data_time.update(time.time() - end)

        images = images.to(device)
        labels = labels.to(device)
        bsz = labels.shape[0]

        # compute classification loss
        out1 = cls_model(images)
        loss1 = criterion(out1, labels)
        cls_losses.update(loss1.item(), bsz)

        optim1.zero_grad()
        loss1.backward()
        optim1.step()

        # compute generalization loss
        fake_images = images[labels == 0]
        out2 = gen_model(fake_images)
        fake_labels = torch.ones_like(out2) / 3
        loss2 = criterion(out2, fake_labels)
        gen_losses.update(loss2.item(), bsz)

        optim2.zero_grad()
        loss2.backward()
        optim2.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # print info
        if (idx + 1) % print_freq == 0:
            print('Train: [{0}][{1}/{2}]\t'
                  'BT {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'DT {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'cls_loss {cls_loss.val:.3f} ({cls_loss.avg:.3f})\t'
                  'gen_loss {gen_loss.val:.3f} ({gen_loss.avg:.3f})\t'.format(
                   epoch, idx + 1, len(train_loader), batch_time=batch_time,
                   data_time=data_time, cls_loss=cls_losses, gen_loss=gen_losses))
            sys.stdout.flush()

    return cls_losses.avg, gen_losses.avg 

In [35]:
def test(test_loader, cls_model, gen_model, criterion, epoch):
    """one epoch training"""
    

    batch_time = AverageMeter()
    data_time = AverageMeter()
    cls_losses = AverageMeter()
    gen_losses = AverageMeter()

    end = time.time()
    for idx, (images, labels) in enumerate(test_loader):
        data_time.update(time.time() - end)

        images = images.to(device)
        labels = labels.to(device)
        bsz = labels.shape[0]

        # compute classification loss
        out1 = cls_model(images)
        loss1 = criterion(out1, labels)
        cls_losses.update(loss1.item(), bsz)

        # compute generalization loss
        fake_images = images[labels == 0]
        out2 = gen_model(fake_images)
        fake_labels = torch.ones_like(out2) / 3
        loss2 = criterion(out2, fake_labels)
        gen_losses.update(loss2.item(), bsz)
        
        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # print info
        if (idx + 1) % print_freq == 0:
            print('Train: [{0}][{1}/{2}]\t'
                  'BT {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'DT {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'cls_loss {cls_loss.val:.3f} ({cls_loss.avg:.3f})\t'
                  'gen_loss {gen_loss.val:.3f} ({gen_loss.avg:.3f})\t'.format(
                   epoch, idx + 1, len(train_loader), batch_time=batch_time,
                   data_time=data_time, cls_loss=cls_losses, gen_loss=gen_losses))
            sys.stdout.flush()

    return cls_losses.avg, gen_losses.avg 

In [36]:
train_dataset = torch.utils.data.ConcatDataset([DF_train_dataset, F2F_train_dataset, FS_train_dataset, OR_train_dataset, OR_train_dataset, OR_train_dataset])
valid_dataset = torch.utils.data.ConcatDataset([DF_valid_dataset, F2F_valid_dataset, FS_valid_dataset, OR_valid_dataset, OR_valid_dataset, OR_valid_dataset])
# build data loader
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=batch_size, shuffle=True)

print(len(train_dataset))

107547


In [None]:
def main():
    
    # build model and criterion
    cls_model, gen_model, criterion = set_model()
    
    total_params = sum(p.numel() for p in cls_model.parameters() if p.requires_grad)
    print(f"Number of Classififcation Network's Parameters: {total_params:,}")
    total_params = sum(p.numel() for p in gen_model.parameters() if p.requires_grad)
    print(f"Number of Classififcation Network's Parameters: {total_params:,}")

    # build optimizer
    optim1, optim2 = set_optimizer(cls_model, gen_model)
    scheduler1 = optim.lr_scheduler.OneCycleLR(optim1, 
                               max_lr=1e-3, 
                               epochs=epochs,
                               steps_per_epoch=len(train_loader),
                               pct_start=16.0/epochs,
                               div_factor=25,
                               final_div_factor=2)
    scheduler2 = optim.lr_scheduler.OneCycleLR(optim2, 
                               max_lr=1e-3, 
                               epochs=epochs,
                               steps_per_epoch=len(train_loader),
                               pct_start=16.0/epochs,
                               div_factor=25,
                               final_div_factor=2)

    # tensorboard
    writer = SummaryWriter(tb_folder)
    print('Start Training')
    # training routine
    for epoch in range(1, epochs):
#         adjust_learning_rate(optimizer, epoch)

        # train for one epoch
        time1 = time.time()
        cls_model.train()
        gen_model.train()
        cls_loss, gen_loss = train(train_loader, cls_model, gen_model, criterion, optim1, optim2, epoch)
        scheduler1.step()
        scheduler2.step()
#         return 0
        time2 = time.time()
        print('epoch {}, total time {:.2f}'.format(epoch, time2 - time1))
        print('epoch {}, cls loss {:.2f}, gen loss {:.2f}'.format(epoch, cls_loss, gen_loss))
        print('*'*15)
        
        # tensorboard logger
        writer.add_scalar('cls loss', cls_loss, epoch)
        writer.add_scalar('gen loss', gen_loss, epoch)
        writer.add_scalar('cls learning_rate', optim1.param_groups[0]['lr'], epoch)
        writer.add_scalar('gen learning_rate', optim2.param_groups[0]['lr'], epoch)

        if epoch % save_freq == 0:
            save_file = os.path.join(
                save_folder, 'cls_ckpt_epoch_{epoch}.pth'.format(epoch=epoch))
            save_model(cls_model, optim1, epoch, save_file)
            save_file = os.path.join(
                save_folder, 'gen_ckpt_epoch_{epoch}.pth'.format(epoch=epoch))
            save_model(gen_model, optim2, epoch, save_file)
    
    writer.flush()
    writer.close()

    # save the last model
    save_file = os.path.join(
        save_folder, 'cls_last.pth')
    save_model(cls_model, optim1, epochs, save_file)
    save_file = os.path.join(
        save_folder, 'gen_last.pth')
    save_model(gen_model, optim2, epochs, save_file)