In [None]:
import sys,os
# Project root: two levels above the current working directory (normally the
# folder holding this notebook). Putting it on sys.path lets the project-local
# `dataset` and `networks` packages be imported below.
ROOT_DIR = os.path.abspath("../../")
sys.path.extend([ROOT_DIR])

In [None]:
import time
import torch
import torch.nn as nn
from torchvision import models
from torch.optim.lr_scheduler import ReduceLROnPlateau

## Load Data

In [None]:
from dataset.dataset_ImagePrivacy import IPDataset_FromFileList, full_transform
from torch.utils.data import DataLoader

In [None]:
# Which data partition to use: selects the exp/partition<N>/ split files.
# Stored as a string because it is concatenated into file paths.
PARTITION_ID = 1
partition = str(PARTITION_ID)

In [None]:
# CSV file lists for the train / val / test splits of the chosen partition.
data_dir = '../../../../data/image_privacy/'
partition_dir = data_dir + 'exp/partition' + partition
train_images = partition_dir + '/train.csv'
val_images = partition_dir + '/val.csv'
test_images = partition_dir + '/test.csv'

In [None]:
# Build one dataset per split from its CSV file list; all splits share the same
# `full_transform` preprocessing.
train_data, val_data, test_data = (
    IPDataset_FromFileList(csv_path, full_transform)
    for csv_path in (train_images, val_images, test_images)
)

In [None]:
BATCH_SIZE = 32

# Only the training data needs shuffling. Accuracy / confusion counts in
# `validate` are order-independent, so val/test use a fixed order, which keeps
# evaluation passes reproducible.
train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(dataset=val_data, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(dataset=test_data, batch_size=BATCH_SIZE, shuffle=False)

## Load Model

In [None]:
from networks.channel_grouping import load_cls_model

In [None]:
# `device` was never defined anywhere in the notebook, so a fresh kernel failed
# here. Use a CUDA GPU when one is available, otherwise the CPU; every later
# cell (class weights, training loop, validation) reads this global.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Two-class (private / public) classifier from the project's networks package.
cls_model = load_cls_model(class_num=2, pretrained=True)
cls_model = cls_model.to(device)

## Class weights

In [None]:
# Deal with the unbalanced dataset: weight each class by 1 - (its share of the
# training samples), so the less frequent class gets the larger loss weight.
train_labels = train_data.labels
private_nums = train_labels.count(0)  # label 0: private
public_nums = train_labels.count(1)   # label 1: public
sample_class_count = torch.Tensor([private_nums, public_nums])

class_frequency = sample_class_count.float() / len(train_data)
class_weight = (1. - class_frequency).to(device)
print(class_weight)

## Hyperparameters

In [None]:
# Training schedule.
epochs, learning_rate = 50, 1e-4   # learning_rate is the initial LR; the plateau scheduler may lower it

# Optimizer settings.
momentum = 0.9        # SGD momentum
weight_decay = 1e-7   # NOTE: not passed to the active SGD optimizer; only the commented-out Adam line uses it

In [None]:
# Cross-entropy weighted per class to counteract the private/public imbalance.
loss_func = nn.CrossEntropyLoss(weight=class_weight)

# Plain SGD with momentum (no weight decay is applied here).
optimizer = torch.optim.SGD(
    params=cls_model.parameters(),
    lr=learning_rate,
    momentum=momentum,
)
# Alternative that was tried:
# optimizer = torch.optim.Adam(cls_model.parameters(), lr=learning_rate,weight_decay=weight_decay)

In [None]:
# The training loop calls `scheduler_cls.step(val_acc)` with validation
# *accuracy* (higher is better), so plateaus must be detected in 'max' mode.
# With 'min', every rise in accuracy counted as a bad epoch and triggered LR decay.
scheduler_cls = ReduceLROnPlateau(optimizer, mode='max', factor=0.1, patience=1)

## Training

In [None]:
def _model_device(model):
    """Return the device of ``model``'s first parameter (CPU if it has none)."""
    first_param = next(model.parameters(), None)
    return first_param.device if first_param is not None else torch.device('cpu')


def _print_class_metrics(class_name, precision, recall, f1):
    """Print one class's precision and recall (as percentages) and F1 score."""
    print('%s class metrics:' % class_name)
    print('precision, recall, f1:')
    print('%.3f%%\t%.3f%%\t%.3f' % (precision * 100, recall * 100, f1))
    print('===========================')


def validate(data_loader, model=None, split_name=None):
    """Evaluate the binary private/public classifier on ``data_loader``.

    Prints the overall accuracy and per-class precision / recall / F1, where
    label 0 is the *private* class and label 1 the *public* class. If a metric
    is undefined (division by zero), prints the error and the raw confusion
    counts instead.

    Args:
        data_loader: iterable of batches; element 0 of each batch is the label
            tensor and element 1 the image tensor.
        model: classifier to evaluate. Defaults to the global ``cls_model``.
        split_name: word used in the accuracy line. Defaults to 'testing' when
            ``data_loader`` is the global ``test_loader``, else 'validating'.

    Returns:
        Accuracy in percent (float).
    """
    print('validating')
    if model is None:
        model = cls_model
    if split_name is None:
        split_name = 'testing' if data_loader is test_loader else 'validating'

    model.eval()
    # Inputs must live where the model's weights are.
    run_device = _model_device(model)

    correct, total = 0, 0
    # Confusion-matrix counts with *private* (label 0) as the positive class.
    TP, FP, FN, TN = 0, 0, 0, 0

    with torch.no_grad():
        for data in data_loader:
            target = data[0].to(run_device)
            img = data[1].to(run_device)
            predicted = torch.argmax(model(img), dim=-1)

            total += target.size(0)
            correct += (predicted == target).sum().item()

            TP += ((target == 0) & (predicted == 0)).sum().item()
            # False positive: a public image predicted as private.
            FP += ((target == 1) & (predicted == 0)).sum().item()
            # False negative: a private image predicted as public.
            FN += ((target == 0) & (predicted == 1)).sum().item()
            TN += ((target == 1) & (predicted == 1)).sum().item()

    acc = 100. * correct / total
    print('%s accuracy：%.3f%%' % (split_name, acc))

    try:
        # Private (label 0) metrics.
        p1 = TP / (TP + FP)
        r1 = TP / (TP + FN)
        f1 = (2 * p1 * r1) / (p1 + r1)

        # Public (label 1) metrics: the roles of TP/TN and FP/FN swap.
        p2 = TN / (TN + FN)
        r2 = TN / (TN + FP)
        f2 = (2 * p2 * r2) / (p2 + r2)
    except ZeroDivisionError as e:
        # A class was never predicted or never present: show the raw counts.
        print(e)
        print('TP, FP, TN, FN: ')
        print(TP, FP, TN, FN)
    else:
        print('===========================')
        _print_class_metrics('private', p1, r1, f1)
        _print_class_metrics('public', p2, r2, f2)

    return acc

In [None]:
epoch_start = 0

for epoch in range(epoch_start, epoch_start + epochs):
    # ---- train for one epoch ----
    print('training')
    cls_model.train()
    running_loss = 0.
    print(time.asctime())

    print('current learning rate:')
    print(optimizer.param_groups[0]['lr'])

    for i, data in enumerate(train_loader):
        target = data[0].to(device)
        img = data[1].to(device)
        outputs = cls_model(img)

        optimizer.zero_grad()
        loss = loss_func(outputs, target)
        loss.backward()
        optimizer.step()

        # Report the mean loss over all batches seen so far in this epoch.
        running_loss += loss.item()
        if i % 50 == 49:  # every 50 mini-batches
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / (i + 1)))

    # ---- evaluate ----
    val_acc = validate(val_loader)
    test_acc = validate(test_loader)

    # Only validation accuracy drives the LR schedule; test accuracy is
    # reported and recorded in the checkpoint name but not used for decisions.
    scheduler_cls.step(val_acc)

    val_acc = round(val_acc, 3)
    test_acc = round(test_acc, 3)

    # ---- save a checkpoint named by (0-based) epoch and accuracies ----
    print('saving checkpoints....')
    model_path = '../models/ResNet4IP({})_{}_{}.pth'.format(epoch, val_acc, test_acc)
    # torch.save does not create missing parent directories.
    os.makedirs(os.path.dirname(model_path), exist_ok=True)
    torch.save(cls_model.state_dict(), model_path)

print('Finished Training')

In [None]:
model_path = '../models/ResNet4IP.pth'
cls_model.load_state_dict(torch.load(model_path))

In [None]:
# Evaluate the reloaded model on the validation split (returns accuracy in %).
validate(data_loader=val_loader)