## pre-train the channel-grouping-layer

In [None]:
import sys,os
# Make the project root (two directories up) importable so that the
# `networks` and `dataset` packages used below can be found.
ROOT_DIR = os.path.abspath(os.path.join("..", ".."))
sys.path.append(ROOT_DIR)

In [None]:
import time
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torchvision import models
from networks.channel_grouping import channel_grouping_layer, load_backbone_model

In [None]:
# Prefer the second GPU (the original setup). Fall back to the first GPU on
# single-GPU machines and to the CPU when CUDA is unavailable; a hard-coded
# "cuda:1" fails at the first .to(device) on a host with only one GPU.
if torch.cuda.device_count() > 1:
    device = torch.device("cuda:1")
elif torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)
print(os.getpid())

## load data

In [None]:
from dataset.dataset_ImagePrivacy import IPDataset_FromFileList, full_transform
from torch.utils.data import DataLoader

In [None]:
# Which of the 5 cross-validation partitions to use. It is stored as a string
# because it is spliced directly into the data file paths.
partition = str(1)

In [None]:
data_dir = '../../../../data/image_privacy/exp/'
# CSV file listing the training images of the selected partition.
train_images = '{}partition{}/train.csv'.format(data_dir, partition)

In [None]:
# Training split of the selected partition, built from the file list with the project's full_transform.
train_data = IPDataset_FromFileList(train_images, full_transform)

In [None]:
# Shuffled mini-batches of 16 samples (this loader is also passed to validate()).
train_loader = DataLoader(train_data, batch_size=16, shuffle=True)

## load the cluster label for pre-training

In [None]:
# Channel labels: for each of the 2048 channels, the peak coordinates over all n training images,
# formatted as 2048 * [tx1, ty1, tx2, ty2, ..., txn, tyn].
# Clustering these yields a 2048-dimensional indicator vector per part, which is used to supervise the fc layer.

In [None]:
# Candidate numbers of parts: 2, 4, ..., 12. Train with the second one (4).
part_nums = [2 * k for k in range(1, 7)]
part_num = part_nums[1]

In [None]:
# Echo the selected number of parts as this cell's output.
part_num

In [None]:
# Precomputed clustering result for `part_num` parts: one part id per channel (as consumed by index_to_label).
part_index = np.load(file='./grouping_result/channel_cluster_{}.npy'.format(part_num))

In [None]:
def index_to_label(part_indexs, part_num):
    """Expand per-channel part ids into one binary indicator row per part.

    Args:
        part_indexs: iterable of integer part ids, one per channel.
        part_num: number of parts, i.e. the number of rows in the result.

    Returns:
        A list of ``part_num`` lists, each ``len(part_indexs)`` long, where
        ``result[p][c] == 1`` if channel ``c`` was assigned to part ``p`` and
        0 otherwise. Part ids are used as ordinary list indices, so an id
        >= ``part_num`` raises IndexError (and a negative id counts back from
        the last part).
    """
    channel_parts = list(part_indexs)
    rows = [[0] * len(channel_parts) for _ in range(part_num)]
    for channel, part in enumerate(channel_parts):
        rows[part][channel] = 1
    return rows

In [None]:
# Supervision target: one row per part, one 0/1 column entry per channel. It is
# stored as int64 class indices on the training device for CrossEntropyLoss.
cluster_label = torch.as_tensor(
    np.array(index_to_label(part_index, part_num)), dtype=torch.long, device=device
)

## load the models

In [None]:
# backbone model to extract convolutional features.
# the features are flattened and need to be reshaped before next layer.

In [None]:
model_path = '../../models/ResNet4IP.pth'
# Backbone feature extractor, moved to the device and put in eval mode
# (Module.eval() returns the module itself, so the calls can be chained).
conv_model = load_backbone_model(model_path).to(device).eval()
# Trainable channel grouping layer over the backbone's 2048 channels.
cgl = channel_grouping_layer(part_num=part_num, channel_num=2048).to(device)

In [None]:
# pre-trained channel_grouping model

In [None]:
# cgl_path = '../models/channel_grouping_layer.pth'
# cgl.load_state_dict(torch.load(cgl_path))

## experimental setups

In [None]:
# Optimisation hyper-parameters.
epochs, learning_rate = 3, 1e-4

# Only used by the commented-out SGD optimizer; Adam does not take it.
momentum = 0.9
weight_decay = 1e-7

In [None]:
# cluster_label holds one row per part and one 0/1 indicator per channel, and
# it is scored with a class-weighted CrossEntropyLoss (not MSE). Every channel
# is assigned to exactly one part, so on average only 1/part_num of a row is 1.
# The rare "member" class (1) is therefore weighted by 1 - 1/part_num and the
# common "non-member" class (0) by 1/part_num. The weights used to be
# hard-coded as [1/8, 7/8], which only matches part_num == 8.
class_weight = torch.tensor([1 / part_num, 1 - 1 / part_num])

loss_func = nn.CrossEntropyLoss(class_weight).to(device)
# Unweighted alternative:
# loss_func = nn.CrossEntropyLoss().to(device)

optimizer = torch.optim.Adam(cgl.parameters(), lr=learning_rate, weight_decay=weight_decay)
# SGD alternative (this is what `momentum` is for):
# optimizer = torch.optim.SGD(cgl.parameters(), lr=learning_rate, momentum=momentum)

# Divide the learning rate by 10 as soon as the monitored loss fails to improve for one step.
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=0)

In [None]:
def cal_grouping_loss(grouping_result, target, avg = True):
    """Score channel-grouping outputs against the per-part channel labels.

    Each score ``p`` is paired with ``1 - p`` to form a 2-way input, and the
    global ``loss_func`` is applied to it. ``loss_func`` is a CrossEntropyLoss,
    so it treats the pair as logits. The loss is evaluated once per
    (sample, part) slice and summed.

    Args:
        grouping_result: scores indexed as ``[sample, part, channel]``. Each
            ``[sample, part]`` slice becomes a ``(channels, 2)`` input.
        target: 0/1 labels indexed as ``[part, channel]``. Row ``j`` is the
            target for part ``j`` of every sample.
        avg: if True, divide the summed loss by samples * parts.

    Returns:
        A 1-element tensor on ``device`` holding the (averaged) loss.

    NOTE(review): class 0 is scored by ``grouping_result`` and class 1 by
    ``1 - grouping_result``, while a 1 in ``target`` marks channels that
    belong to the part. Members are therefore pushed toward *low*
    ``grouping_result``. Confirm this orientation is intended.
    """
    scores = grouping_result.unsqueeze(-1)
    pair_logits = torch.cat((scores, 1. - scores), dim = -1)

    n_samples, n_parts = pair_logits.shape[0], target.shape[0]
    total = torch.zeros(1).to(device)
    for sample in range(n_samples):
        for part in range(n_parts):
            total += loss_func(pair_logits[sample, part, :, :], target[part])

    if not avg:
        return total
    return total / (pair_logits.shape[0] * pair_logits.shape[1])

## training

In [None]:
def validate(data_loader):
    """Return the mean per-batch grouping loss of ``cgl`` over ``data_loader``.

    Puts ``cgl`` in eval mode and runs the backbone and the grouping layer with
    gradients disabled. Every batch is scored against the same global
    ``cluster_label`` target with ``cal_grouping_loss``. ``cgl`` is left in
    eval mode; the training loop switches it back with ``cgl.train()``.

    Args:
        data_loader: iterable of batches in which ``data[1]`` is the image tensor.

    Returns:
        The average of the per-batch losses, as a Python float.

    Raises:
        ValueError: if ``data_loader`` yields no batches. Previously this case
            crashed with an UnboundLocalError on the loop variable.
    """
    print('validating')
    cgl.eval()
    running_loss, num_batches = 0., 0
    print(time.asctime())

    with torch.no_grad():
        for data in data_loader:
            img = data[1].to(device)

            # The backbone output is flattened; restore the (2048, 14, 14) feature maps.
            conv_features = conv_model(img).reshape(-1, 2048, 14, 14)
            grouping_result = cgl(conv_features)[0]

            loss = cal_grouping_loss(grouping_result, cluster_label)
            running_loss += loss.item()
            num_batches += 1

    if num_batches == 0:
        raise ValueError('validate() received a data_loader that yielded no batches')

    avg_loss = running_loss / num_batches

    print("avg_loss:" + str(avg_loss))

    return avg_loss

In [None]:
checkpoint_dir = '../../models/ImagePrivacy/'

# One checkpoint sub-directory per number of parts, e.g. '../../models/ImagePrivacy/4'.
checkpoint_dir = checkpoint_dir + str(part_num)

# exist_ok makes this a no-op when the directory already exists. It also
# avoids the check-then-create race of os.path.exists() followed by makedirs().
os.makedirs(checkpoint_dir, exist_ok=True)

# To resume from an earlier checkpoint:
# cgl_path = checkpoint_dir + '/channel_grouping_layer(0)_0.313261658.pth'
# cgl.load_state_dict(torch.load(cgl_path))

In [None]:
# Train only the channel grouping layer. The backbone stays frozen: it is in
# eval mode and its parameters are not in the optimizer.
epoch_start = 0
for epoch in range(epoch_start, epoch_start + epochs):
    print('training')
    cgl.train()
    running_loss = 0.
    print(time.asctime())

    print('current learning rate:')
    print(optimizer.param_groups[0]['lr'])

    for i, data in enumerate(train_loader, 0):
        target = cluster_label
        img = data[1].to(device)

        # The frozen backbone runs without autograd. This avoids building a
        # graph through it and accumulating unused gradients in its parameters.
        # Its output is flattened, so reshape it back into feature maps for cgl.
        with torch.no_grad():
            conv_features = conv_model(img).reshape(-1, 2048, 14, 14)
        channel_grouping_res = cgl(conv_features)

        grouping_result, attention_mask = channel_grouping_res[0], channel_grouping_res[1]

        optimizer.zero_grad()
        loss = cal_grouping_loss(grouping_result, target)
        loss.backward()
        optimizer.step()

        # Print the running mean loss every 50 mini-batches.
        running_loss += loss.item()
        if i % 50 == 49:
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / (i + 1)))

    # The end-of-epoch loss (computed on the training loader) drives the LR
    # schedule and is embedded in the checkpoint file name.
    avg_loss = round(validate(train_loader), 9)
    scheduler.step(avg_loss)

    model_path = checkpoint_dir + '/channel_grouping_layer({})_{}.pth'.format(epoch, avg_loss)
    torch.save(cgl.state_dict(), model_path)

print('Finished Training')

In [None]:
# Final evaluation of the trained grouping layer, on the same training loader used during the epochs.
validate(train_loader)