In [9]:
from comet_ml import Experiment
import argparse
import time
from tqdm import tqdm
import os
import torch
import torch.nn as nn
import torch.optim as optim
import yaml
from head.metrics import CosFace
from loss.focal import FocalLoss
from utils.utils import separate_resnet_bn_paras, warm_up_lr, load_checkpoint, \
    schedule_lr, AverageMeter, accuracy
from utils.fairness_utils import evaluate
from utils.data_utils_balanced import prepare_data
from utils.utils_train import Network
import numpy as np
import pandas as pd
import random
import timm
from utils.utils import save_output_from_dict
from utils.utils_train import Network, get_head
from utils.fairness_utils import evaluate, add_column_to_file
from timm.optim import create_optimizer_v2, optimizer_kwargs
from timm.scheduler import create_scheduler
from timm.utils.model_ema import ModelEmaV2
from utils.fairness_utils import *
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
#device_ids=range(torch.cuda.device_count())
torch.manual_seed(222)
torch.cuda.manual_seed_all(222)
np.random.seed(222)
random.seed(222)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False


default_test_root = '/cmlscratch/sdooley1/data/CelebA/Img/img_align_celeba_splits/test/'
default_train_root = '/cmlscratch/sdooley1/data/CelebA/Img/img_align_celeba_splits/train/'


In [2]:
parser = argparse.ArgumentParser()

parser.add_argument('--default_test_root', default=default_test_root)
parser.add_argument('--default_train_root', default=default_train_root)
parser.add_argument('--demographics_file', default= '/cmlscratch/sdooley1/data/CelebA/CelebA_demographics.txt')
parser.add_argument('--backbone_name', default='mobilenetv3_large_100')
parser.add_argument('--backbone', default='mobilenetv3_large_100')
parser.add_argument('--pretrained', default=False)
parser.add_argument('--project_name', default="from-scratch_no-resampling_adam")
parser.add_argument('--head', default="CosFace")
parser.add_argument('--opt', default="AdamW")
parser.add_argument('--epochs', default=100)
parser.add_argument('--sched', default='cosine')
parser.add_argument('--min_lr', default=0.01)

parser.add_argument('--checkpoints_root', default='/cmlscratch/sdooley1/merge_timm/FR-NAS/Checkpoints/Phase1B/')
parser.add_argument('--head_name', default='CosFace')
parser.add_argument('--train_loss', default='Focal', type=str)

parser.add_argument('--groups_to_modify', default= ['male', 'female'], type=str, nargs='+')
parser.add_argument('--p_identities', default=[1.0, 1.0], type=float, nargs='+')
parser.add_argument('--p_images', default=[1.0, 1.0], type=float, nargs='+')
parser.add_argument('--min_num_images', default=3, type=int)

parser.add_argument('--batch_size', default=250, type=int)
parser.add_argument('--input_size', default=112, type=int)
parser.add_argument('--weight_decay', default=5e-4, type=float)
parser.add_argument('--momentum', default=0.9, type=float)
parser.add_argument('--mean', default=[0.5, 0.5, 0.5], type=int)
parser.add_argument('--std', default=[0.5, 0.5, 0.5], type=int)
parser.add_argument('--stages', default=[35, 65, 95], type=int)
parser.add_argument('--num_workers', default=4, type=int)

parser.add_argument('--lr', default=0.001, type=float)
parser.add_argument('--num_epoch', default=3, type=int)
parser.add_argument('--gpu_id', default=[0], type=int, nargs='+', help='gpu id')
parser.add_argument('--name', default='CelebA', type=str)
parser.add_argument('--dataset', default='CelebA', type=str)
parser.add_argument('--file_name', default='timm_from-scratch.csv', type=str)
parser.add_argument('--seed', default=222, type=int)

args = parser.parse_args('')

p_images = {args.groups_to_modify[i]:args.p_images[i] for i in range(len(args.groups_to_modify))}
p_identities = {args.groups_to_modify[i]:args.p_identities[i] for i in range(len(args.groups_to_modify))}
args.p_images = p_images
args.p_identities = p_identities

print("P identities: {}".format(args.p_identities))
print("P images: {}".format(args.p_images))




P identities: {'male': 1.0, 'female': 1.0}
P images: {'male': 1.0, 'female': 1.0}


In [3]:
dataloaders, num_class, demographic_to_labels_train, demographic_to_labels_test = prepare_data(args)
args.num_class = num_class


PREPARING TRAIN DATASET
Overall # of images for male available is 67562
# images selected for male is 67562
Overall # of images for female available is 76524
# images selected for female is 67562
Number of idx for male is 3529
Number of idx for female is 3529
PREPARING TEST DATASET
Overall # of images for male available is 7644
# images selected for male is 7636
Overall # of images for female available is 8851
# images selected for female is 7636
Number of idx for male is 406
Number of idx for female is 406
Len of train dataloader is 540
Len of test dataloader is 62


In [4]:

''' Model '''
backbone = timm.create_model(args.backbone_name, 
                             num_classes=0,
                             pretrained=args.pretrained).to(device)
config = timm.data.resolve_data_config({}, model=backbone)
model_input_size = config['input_size']

# get model's embedding size
meta = pd.read_csv('/cmlscratch/sdooley1/timm_model_metadata.csv')
embedding_size = int(
    meta[meta.model_name == args.backbone].feature_dim)
args.embedding_size= embedding_size




head = get_head(args)
train_criterion = FocalLoss(elementwise=True)
head,backbone= head.to(device), backbone.to(device)
backbone = nn.DataParallel(backbone)
####################################################################################################################
# ======= argsimizer =======#
model = Network(backbone, head)

optimizer = create_optimizer_v2(model, **optimizer_kwargs(cfg=args))

model_ema = None
model, model_ema, optimizer, epoch, batch, checkpoints_model_root = load_checkpoint(
    args, model, model_ema, optimizer, dataloaders["train"], p_identities,
    p_images)
#model = nn.DataParallel(model)
model = model.to(device)

Checkpoint_Head_CosFace_Backbone_mobilenetv3_large_100_Opt_AdamW_Dataset_CelebA_Epoch_
Found checkpoints for this model: ['Checkpoint_Head_CosFace_Backbone_mobilenetv3_large_100_Opt_AdamW_Dataset_CelebA_Epoch_80.pth', 'Checkpoint_Head_CosFace_Backbone_mobilenetv3_large_100_Opt_AdamW_Dataset_CelebA_Epoch_20.pth', 'Checkpoint_Head_CosFace_Backbone_mobilenetv3_large_100_Opt_AdamW_Dataset_CelebA_Epoch_101.pth', 'Checkpoint_Head_CosFace_Backbone_mobilenetv3_large_100_Opt_AdamW_Dataset_CelebA_Epoch_60.pth', 'Checkpoint_Head_CosFace_Backbone_mobilenetv3_large_100_Opt_AdamW_Dataset_CelebA_Epoch_40.pth', 'Checkpoint_Head_CosFace_Backbone_mobilenetv3_large_100_Opt_AdamW_Dataset_CelebA_Epoch_100.pth']
Loading Checkpoint '/cmlscratch/sdooley1/merge_timm/FR-NAS/Checkpoints/Phase1B/mobilenetv3_large_100_CosFace_AdamW/Checkpoint_Head_CosFace_Backbone_mobilenetv3_large_100_Opt_AdamW_Dataset_CelebA_Epoch_101.pth'


In [5]:
demographic_to_labels = demographic_to_labels_train
loss = {k:torch.tensor(0.0) for k in demographic_to_labels.keys()}
acc = {k:torch.tensor(0.0) for k in demographic_to_labels.keys()}
count = {k:torch.tensor(0.0) for k in demographic_to_labels.keys()}
acc_k = {k:torch.tensor(0.0) for k in demographic_to_labels.keys()}
intra = {k:torch.tensor(0.0) for k in demographic_to_labels.keys()}
inter = {k:torch.tensor(0.0) for k in demographic_to_labels.keys()}
angles_intra, angles_inter, correct = 0, 0, 0

#backbone.eval()
#if multilabel_accuracy:
#    head.eval()
model.eval()
# figure out embedding size
emb_size = embedding_size
dataloader = dataloaders['test']
if emb_size is None:
    inputs, _, _ = next(iter(dataloader))
    x = torch.randn(inputs.shape).to(device)
    emb_size = backbone(x).shape[1]


feature_matrix = torch.empty(0, emb_size)
labels_all = []
indices_all = []
demographic_all = []
predicted_all = []

for inputs, labels, sens_attr, indices in tqdm(iter(dataloader)):
    inputs = inputs.to(device)
    labels = labels.to(device).long()
    labels_all = labels_all + labels.cpu().tolist()
    indices_all = indices_all + indices.cpu().tolist()
    sens_attr = np.array(sens_attr)
    with torch.no_grad():

        if True:
            #need to build feature matrix
            inputs_flipped = torch.flip(inputs, [3])
            try:
                embed = model.module.backbone(inputs) + model.module.backbone(inputs_flipped)
            except AttributeError:
                embed = model.backbone(inputs) + model.backbone(inputs_flipped)
            features_batch = l2_norm(embed)
            feature_matrix = torch.cat((feature_matrix, features_batch.detach().cpu()), dim = 0)

            demographic_all = demographic_all + sens_attr.tolist()


100%|██████████| 62/62 [00:11<00:00,  5.48it/s]


In [6]:
feature_matrix, labels, demographic_to_labels, test_features, test_labels, test_demographic = feature_matrix, torch.tensor(labels_all), demographic_to_labels, feature_matrix, torch.tensor(labels_all), np.array(demographic_all)



In [7]:
dist_matrix =  l2_dist(feature_matrix, feature_matrix)

In [11]:
acc_k = {k:0 for k in demographic_to_labels.keys()}
nearest_neighbors = torch.topk(dist_matrix, dim=1, k = 2, largest = False)[1][:,1]
n_images = dist_matrix.shape[0]
correct = torch.zeros(test_labels.shape)
nearest_id = torch.zeros(test_labels.shape)

t = time.time()
for img in range(n_images):
    nearest_label = labels[nearest_neighbors[img]].item()
    nearest_id[img] = nearest_label
    label_img = test_labels[img].item()
    if label_img == nearest_label:
        correct[img] = 1
print(time.time()-t)
for k in acc_k.keys():
    acc_k[k] = (correct[test_demographic == k]).mean()

# acc_k, 
# correct = torch.tensor(just_one + 1)
# nearest_id = torch.tensor(df[1].apply(lambda x: labels_np[x]))

0.2038564682006836


In [12]:
nearest_id

tensor([ 53.,  53.,  53.,  ..., 576., 662., 406.])

In [None]:
labels_all

In [35]:
def process_row(row):
    base_label = labels_np[row[0]]
    i = 1
    while i < row.shape[0]:
        if labels_np[row[i]] == base_label:
            return i-1
        i+=1
    return -1

labels_np = labels.numpy()

n_nearestneighbors = dist_matrix.shape[0]
desc_dist = torch.topk(dist_matrix, dim=1, k = n_nearestneighbors, largest = False)[1]
entire = torch.tensor([process_row(row) for row in desc_dist])

n_nearestneighbors = 2
inc_dist = torch.topk(dist_matrix, dim=1, k = n_nearestneighbors, largest = False)[1]
just_one = torch.tensor([process_row(row) for row in inc_dist])

nearest_id = desc_dist[:,1].apply_(lambda x: labels_np[x])

In [46]:
sum((entire == 0).long() == just_one + 1)

tensor(15272)

In [47]:
entire

tensor([0, 0, 0,  ..., 0, 0, 0])

In [None]:
entire[entire > 0]

In [None]:
just_one[just_one == -1]

In [None]:
df = pd.DataFrame(desc_dist.numpy())

In [None]:
df.apply(lambda x: labels_np[x])

In [None]:
def process_row(row):
    base_label = labels_np[row[0]]
    i = 1
    while i < row.shape[0]:
        if labels_np[row[i]] == base_label:
            return i-1
        i+=1
    return i-1
df.apply(lambda row : process_row(row), axis = 1)

In [None]:
labels

In [None]:
desc_dist.shape

In [None]:
desc_dist.apply_(lambda x: labels[x])

In [None]:
import time
t = time.time()
torch.topk(dist_matrix, dim=1, k = 2, largest = False)[1]
print(time.time() - t)