In [None]:
import sys, os, time
import random
from torchvision import models
from torchsummary import summary
import torch
import torch.nn as nn
import numpy as np
from torch.optim.lr_scheduler import ReduceLROnPlateau

In [None]:
from networks.channel_grouping import load_backbone_model, channel_grouping_layer
from networks.dgcn import dgcn_cls
from networks.loss import channel_grouping_loss

In [None]:
# Run on the first CUDA device when one is visible, otherwise on the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print(device)
# The process id is printed so the running kernel can be identified from outside.
print(os.getpid())

In [None]:
def setup_seed(seed):
    """Seed every RNG used by the notebook and put cuDNN in reproducible mode.

    Seeds Python's ``random``, NumPy's legacy global RNG and PyTorch's CPU and
    CUDA generators with the same value.

    Args:
        seed (int): seed shared by all generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # Restrict cuDNN to deterministic kernels ...
    torch.backends.cudnn.deterministic = True
    # ... and disable the auto-tuner: with benchmark mode on, cuDNN may pick a
    # different algorithm from run to run, which defeats the seeding above.
    torch.backends.cudnn.benchmark = False

In [None]:
# Seed the torch, NumPy and Python RNGs (via setup_seed) before any data or model is created.
setup_seed(9)

## load data

In [None]:
from dataset.dataset_ImagePrivacy import IPDataset_FromFileList, full_transform
from torch.utils.data import DataLoader

In [None]:
# Index of the data partition, kept as a string because it is spliced into file paths.
partition = str(1)

In [None]:
# Root of the image-privacy data; each partition has its own train/val/test CSV.
data_dir = '../../data/image_privacy/'
train_images, val_images, test_images = (
    '{}exp/partition{}/{}.csv'.format(data_dir, partition, split)
    for split in ('train', 'val', 'test')
)

In [None]:
# One dataset per split, each built from that split's CSV path and the shared `full_transform`.
train_data, val_data, test_data = (
    IPDataset_FromFileList(file_list, full_transform)
    for file_list in (train_images, val_images, test_images)
)

In [None]:
# Only the training split is shuffled. The metrics computed on the validation
# and test splits do not depend on sample order, so those loaders iterate in a
# fixed order, which keeps per-batch inspection repeatable.
train_loader = DataLoader(dataset=train_data, batch_size=16, shuffle=True)
val_loader = DataLoader(dataset=val_data, batch_size=32, shuffle=False)
test_loader = DataLoader(dataset=test_data, batch_size=32, shuffle=False)

## class weight

In [None]:
# Deal with the unbalanced dataset: each class is weighted by
# 1 - (its share of the training labels), so the rarer class weighs more.

private_nums = train_data.labels.count(0)   # label 0 -> private
public_nums = train_data.labels.count(1)    # label 1 -> public
sample_class_count = torch.Tensor([private_nums, public_nums])

class_weight = (1. - sample_class_count.float() / len(train_data)).to(device)
print(class_weight.shape)

In [None]:
# Display the per-class loss weights (index 0: private, index 1: public).
class_weight

## load models

In [None]:
# Candidate numbers of parts (2, 4, ..., 12); this notebook uses the second one.
part_nums = [2 * k for k in range(1, 7)]
part_num = part_nums[1]

In [None]:
# Weights of the pretrained backbone.
backbone_model_path = './models/ResNet4IP.pth'
# Checkpoints are grouped in one directory per part count.
checkpoint_dir = './models/ImagePrivacy/{}'.format(part_num)
channel_grouping_layer_path = '{}/channel_grouping_layer.pth'.format(checkpoint_dir)

In [None]:
# Pretrained feature extractor.
backbone_model = load_backbone_model(backbone_model_path).to(device)

# Pretrained channel grouping layer. `map_location` remaps the stored tensors
# onto the active device, so a checkpoint saved from a GPU run still loads when
# the notebook falls back to the CPU.
channel_grouping = channel_grouping_layer(part_num=part_num, channel_num=2048).to(device)
channel_grouping.load_state_dict(
    torch.load(channel_grouping_layer_path, map_location=device)
)

# Classifier head; no weights are loaded for it in this cell.
dgcn = dgcn_cls(part_num=part_num).to(device)

## hyperparameters

In [None]:
# Adam learning rates: the backbone and the classifier share one value,
# the channel grouping layer (cgl) uses a larger one.
learning_rate_backbone = learning_rate_cls = 1e-5
learning_rate_cgl = 1e-3

# Number of epochs run by the training loop.
epochs = 15

momentum, weight_decay = 0.9, 1e-7

In [None]:
# Classification loss, weighted per class to counter the label imbalance,
# and the project's channel grouping loss.
cls_loss = nn.CrossEntropyLoss(weight=class_weight)
cgl_loss = channel_grouping_loss()

# One Adam optimizer per sub-network so each keeps its own learning rate.
# (An earlier SGD-with-momentum variant of the cls/cgl optimizers is no longer used.)
backbone_optimizer = torch.optim.Adam(
    backbone_model.parameters(),
    lr=learning_rate_backbone,
    weight_decay=weight_decay,
)
cls_optimizer = torch.optim.Adam(
    dgcn.parameters(),
    lr=learning_rate_cls,
    weight_decay=weight_decay,
)
cgl_optimizer = torch.optim.Adam(
    channel_grouping.parameters(),
    lr=learning_rate_cgl,
    weight_decay=weight_decay,
)

In [None]:
# The backbone and classifier schedulers are stepped with the validation
# *accuracy* in the training loop, so they must watch for a plateau of a
# metric that should increase (mode='max'). With mode='min' the learning rate
# would be cut whenever the accuracy stopped decreasing.
scheduler_backbone = ReduceLROnPlateau(backbone_optimizer, mode='max', factor=0.1, patience=1)
scheduler_cls = ReduceLROnPlateau(cls_optimizer, mode='max', factor=0.1, patience=1)
# The channel grouping scheduler is stepped with a loss value, hence mode='min'.
scheduler_cgl = ReduceLROnPlateau(cgl_optimizer, mode='min', factor=0.1, patience=3)

## training

In [None]:
def validate(data_loader):
    """Evaluate the backbone + channel grouping + DGCN pipeline on a data loader.

    Puts the three networks in eval mode, runs them without gradients over
    every batch and prints the overall accuracy followed by precision, recall
    and F1 for the private class (label 0) and the public class (label 1).

    Args:
        data_loader: iterable of batches where ``data[0]`` holds the labels and
            ``data[1]`` the images.

    Returns:
        float: accuracy in percent.

    Raises:
        ZeroDivisionError: if the loader yields no samples.
    """

    def ratio(numerator, denominator):
        # A class that never occurs / is never predicted yields 0 instead of
        # aborting the whole report with a ZeroDivisionError.
        return numerator / denominator if denominator else 0.

    # validating
    print('validating')
    print(time.asctime())

    backbone_model.eval()
    dgcn.eval()
    channel_grouping.eval()

    correct = 0
    total = 0

    # Confusion counts with "private" (label 0) as the positive class.
    TP, FP, FN, TN = 0, 0, 0, 0

    with torch.no_grad():
        for data in data_loader:

            target = data[0].to(device)
            img = data[1].to(device)

            feature = backbone_model(img).reshape(-1, 2048, 14, 14)
            grouping_result, weighted_feature = channel_grouping(feature)
            cls_res = dgcn(feature, weighted_feature)

            predicted = torch.argmax(cls_res, -1)

            total += target.size(0)
            correct += (predicted == target).sum().item()

            # private predicted private
            TP += ((target == 0) & (predicted == 0)).sum().item()
            # private predicted public -> a missed private image
            FN += ((target == 0) & (predicted == 1)).sum().item()
            # public predicted private -> a false alarm
            FP += ((target == 1) & (predicted == 0)).sum().item()
            # public predicted public
            TN += ((target == 1) & (predicted == 1)).sum().item()

            del cls_res
            del predicted

    acc = 100. * correct / total

    # Identity check: the label only depends on which loader object was passed.
    if data_loader is test_loader:
        print('testing accuracy：%.3f%%' % (acc))
    else:
        print('validating accuracy：%.3f%%' % (acc))

    # private metrics
    p1 = ratio(TP, TP + FP)
    r1 = ratio(TP, TP + FN)
    f1 = ratio(2 * p1 * r1, p1 + r1)

    # public metrics: TN plays the role of the true positives, and the roles
    # of FP and FN are exchanged.
    p2 = ratio(TN, TN + FN)
    r2 = ratio(TN, TN + FP)
    f2 = ratio(2 * p2 * r2, p2 + r2)

    print('===========================')

    print('private class metrics:')

    print('precision, recall, f1:')
    print('%.3f%%\t%.3f%%\t%.3f' % (p1 * 100, r1 * 100, f1))

    print('===========================')

    print('public class metrics:')

    print('precision, recall, f1:')
    print('%.3f%%\t%.3f%%\t%.3f' % (p2 * 100, r2 * 100, f2))

    print('===========================')

    return acc

In [None]:
# checkpoint dir

# Checkpoints of this experiment live in a dedicated sub-directory.
# The suffix is appended only once so that re-running this cell does not keep
# nesting '/wo_cgl_finetune/wo_cgl_finetune/...'.
finetune_subdir = '/wo_cgl_finetune'
if not checkpoint_dir.endswith(finetune_subdir):
    checkpoint_dir = checkpoint_dir + finetune_subdir

# exist_ok avoids the check-then-create race of a separate os.path.exists test.
os.makedirs(checkpoint_dir, exist_ok=True)

In [None]:
# Resume from previously saved checkpoints of this experiment.
# `map_location=device` remaps the stored tensors onto the active device, so
# checkpoints written during a GPU run also load when running on the CPU.
cgl_checkpoint = checkpoint_dir + '/CGL_IP(0)_36.116_36.106.pth'
channel_grouping.load_state_dict(torch.load(cgl_checkpoint, map_location=device))

dgcn_checkpoint = checkpoint_dir + '/DGCN_IP(9)_86.502_86.699.pth'
dgcn.load_state_dict(torch.load(dgcn_checkpoint, map_location=device))

backbone_checkpoint = checkpoint_dir + '/ResNet_IP(9)_86.502_86.699.pth'
backbone_model.load_state_dict(torch.load(backbone_checkpoint, map_location=device))

In [None]:
# Report the learning rate of the first parameter group of each optimizer,
# one item per line.
print(
    'current learning rate:',
    'cls:',
    cls_optimizer.param_groups[0]['lr'],
    'backbone:',
    backbone_optimizer.param_groups[0]['lr'],
    sep='\n',
)

In [None]:
epoch_start = 1

# Epochs whose index is a multiple of `interval` fine-tune the channel grouping
# layer; every other epoch trains the classifier (and, later on, the backbone).
interval = 20

# The backbone optimizer is only stepped, and the backbone only checkpointed
# and scheduled, once the epoch index exceeds this value.
backbone_freeze_epochs = 5

for epoch in range(epoch_start, epoch_start + epochs):
    running_loss_cls = 0.
    running_loss_dis, running_loss_div = 0., 0.

    # True  -> classification phase (dgcn + backbone in train mode)
    # False -> channel grouping phase (only the grouping layer in train mode)
    train_cls = (epoch % interval != 0)
    step_backbone = epoch > backbone_freeze_epochs

    print('training')
    print(time.asctime())

    print('current learning rate:')
    print('cls:')
    print(cls_optimizer.param_groups[0]['lr'])
    print('backbone:')
    print(backbone_optimizer.param_groups[0]['lr'])

    # validate() leaves all three networks in eval mode, so the modes are
    # (re)set at the start of every epoch.
    if train_cls:
        dgcn.train()
        backbone_model.train()
        channel_grouping.eval()
    else:
        dgcn.eval()
        backbone_model.eval()
        channel_grouping.train()

    for i, data in enumerate(train_loader):
        target = data[0].to(device)
        img = data[1].to(device)

        feature = backbone_model(img).reshape(-1, 2048, 14, 14)
        grouping_result, weighted_feature = channel_grouping(feature)
        cls_res = dgcn(feature, weighted_feature)

        if train_cls:
            backbone_optimizer.zero_grad()
            cls_optimizer.zero_grad()

            loss = cls_loss(cls_res, target)
            running_loss_cls += loss.item()

            loss.backward()
            cls_optimizer.step()

            if step_backbone:
                backbone_optimizer.step()

        else:
            cgl_optimizer.zero_grad()

            loss1 = cgl_loss(weighted_feature)    # [dis_loss, div_loss]
            loss2 = cls_loss(cls_res, target)    # classification loss

            running_loss_dis += loss1[0].item()
            running_loss_div += loss1[1].item()

            loss = loss1[0] + loss1[1] + loss2

            loss.backward()
            cgl_optimizer.step()

        # print statistics every 50 batches; the epoch shown is the same index
        # that is used in the checkpoint file names below
        if i % 50 == 49:
            if train_cls:
                print('[%d, %5d] loss: %.3f' %
                      (epoch, i + 1, running_loss_cls / (i + 1)))
            else:
                print('[%d, %5d] loss: %.3f' %
                      (epoch, i + 1, loss2.item()))
                print('[%d, %5d] dis/div loss: %.8f, %.8f' %
                      (epoch, i + 1, running_loss_dis / (i + 1), running_loss_div / (i + 1)))

                # the grouping scheduler follows the running mean of dis + div loss
                running_loss_cgl = running_loss_dis / (i + 1) + running_loss_div / (i + 1)
                scheduler_cgl.step(running_loss_cgl)

    val_acc = validate(val_loader)
    test_acc = validate(test_loader)

    val_acc = round(val_acc, 3)
    test_acc = round(test_acc, 3)

    if train_cls:
        print('saving cls checkpoints....')

        # both schedulers are driven by the validation accuracy
        scheduler_cls.step(val_acc)

        if step_backbone:
            scheduler_backbone.step(val_acc)
            backbone_path = checkpoint_dir + '/ResNet_IP({})_{}_{}.pth'.format(epoch, val_acc, test_acc)
            torch.save(backbone_model.state_dict(), backbone_path)

        dgcn_path = checkpoint_dir + '/DGCN_IP({})_{}_{}.pth'.format(epoch, val_acc, test_acc)
        torch.save(dgcn.state_dict(), dgcn_path)

    else:
        print('saving cgl checkpoints....')

        cgl_path = checkpoint_dir + '/CGL_IP({})_{}_{}.pth'.format(epoch, val_acc, test_acc)
        torch.save(channel_grouping.state_dict(), cgl_path)


print('Finished Training')

## test

In [None]:
cgl_checkpoint = './models/channel_grouping_layer.pth'

# map_location remaps the stored tensors onto the active device, so a
# checkpoint saved from a GPU run also loads on a CPU-only machine.
channel_grouping.load_state_dict(torch.load(cgl_checkpoint, map_location=device))

In [None]:
dgcn_checkpoint = './models/DGCN_IP(1)_90.625_87.5.pth'
# map_location remaps the stored tensors onto the active device, so a
# checkpoint saved from a GPU run also loads on a CPU-only machine.
dgcn.load_state_dict(torch.load(dgcn_checkpoint, map_location=device))

In [None]:
# Evaluate the currently loaded weights on the test split; the returned accuracy (in percent) is the cell output.
validate(test_loader)