In [1]:
import os

import cv2
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms

from dex_age_clf import VGG as age_clf_net

In [3]:
# Path to the pretrained DEX age-classifier weights.
# (Previously misspelled as CHPKPT_PATH while CHKPT_PATH was read -> NameError.)
CHKPT_PATH = '../pretrained_models/dex_age_classifier.pth'
# NOTE: torch.load unpickles the file; only load trusted checkpoints.
checkpoint = torch.load(CHKPT_PATH, map_location='cpu')['state_dict']
# Remap parameter names: '-' in the stored keys becomes '_' so they match the
# module's attribute names (presumably the checkpoint was converted from a
# framework that allows '-' in layer names — confirm).
checkpoint = {k.replace('-', '_'): v for k, v in checkpoint.items()}

In [None]:
# Build the DEX VGG age classifier, load the remapped weights
# (strict by default: any missing/unexpected key raises), and switch to
# inference mode. eval() returns the module itself, so rebinding keeps the same
# object while keeping the module repr out of the cell output.
age_net = age_clf_net()
age_net.load_state_dict(checkpoint)
age_net = age_net.eval()

In [None]:
# Preprocessing for DEX: image array -> CHW float tensor in [0, 1] (ToTensor),
# resized to 256x256, then mapped to [-1, 1] via (x - 0.5) / 0.5 per channel.
dex_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((256, 256)),
    transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
])

In [7]:
@torch.no_grad()
def predict(x, net=None):
    """Estimate the apparent age of a single image with the DEX classifier.

    The image is resized to 224x224 and fed to the network. The result is the
    expected bin index under the softmax of the ``'fc8'`` logits,
    ``sum_j j * p_j``, so bin ``j`` is read as age ``j``.

    Args:
        x: image batch tensor, presumably (1, 3, H, W) as produced by
            ``dex_transforms(img).unsqueeze(0)`` — TODO confirm for other callers.
        net: model to run; defaults to the module-level ``age_net``. It must
            return a mapping whose ``'fc8'`` entry holds (batch, n_bins) logits.

    Returns:
        float: the expected age for the (single) image in the batch.

    Raises:
        An error from ``.item()`` if the batch contains more than one image.
    """
    if net is None:
        net = age_net
    # align_corners=False is the current default; passing it explicitly
    # silences the warning some torch versions emit for bilinear mode.
    x = F.interpolate(x, size=(224, 224), mode='bilinear', align_corners=False)
    logits = net(x)['fc8']
    # Explicit dim: the implicit-dim softmax is deprecated. For 2-D logits the
    # implicit form resolved to dim=1, so results are unchanged.
    probs = F.softmax(logits, dim=1)
    # Expected age = sum_j j * p_j, vectorized instead of a Python double loop.
    bins = torch.arange(probs.size(1), dtype=probs.dtype, device=probs.device)
    predict_age = (probs * bins).sum(dim=1)
    return predict_age.item()


In [8]:
# Existing per-image age predictions (one row per image, keyed by `filename`).
# NOTE(review): the displayed frame below shows an 'Unnamed: 0' column, which
# suggests this CSV was saved with its index — consider index_col=0; confirm.
df = pd.read_csv(filepath_or_buffer='age_results.csv')

In [9]:
# Create one placeholder column per target age (dex_20, dex_25, ..., dex_70);
# the inference loop below overwrites each with the DEX estimates.
for target_age in range(20, 75, 5):
    df[f"dex_{target_age}"] = None

In [11]:
# For each target age, run DEX on every image in inference_results/<age>/ and
# store the estimated ages in the `dex_<age>` column, aligned with df's rows.
files = df.filename.tolist()
for age_gt in range(20, 75, 5):
    print(f"dex age_{age_gt} starting...")
    age_results = []
    for i, file_ in enumerate(files):
        path = os.path.join('inference_results', str(age_gt), file_)
        bgr = cv2.imread(path)
        # cv2.imread returns None (rather than raising) on a missing or
        # unreadable file; fail loudly with the offending path.
        if bgr is None:
            raise FileNotFoundError(f"could not read image: {path}")
        # OpenCV loads BGR; convert to RGB before the torchvision transforms.
        img = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        img = dex_transforms(img).unsqueeze(0)
        age_results.append(predict(img))
        # 1-based counter so the final line reads N/N.
        print(f"\t{i + 1}/{len(files)} done", end='\r')
    df[f"dex_{age_gt}"] = age_results
    print(f"dex age_{age_gt} done")

dex age_20 starting...


  predict_age_pb = F.softmax(age_pb)


dex age_20 done
dex age_25 starting...
dex age_25 done
dex age_30 starting...
dex age_30 done
dex age_35 starting...
dex age_35 done
dex age_40 starting...
dex age_40 done
dex age_45 starting...
dex age_45 done
dex age_50 starting...
dex age_50 done
dex age_55 starting...
dex age_55 done
dex age_60 starting...
dex age_60 done
dex age_65 starting...
dex age_65 done
dex age_70 starting...
dex age_70 done


In [12]:
# Persist the original predictions together with the new dex_<age> columns.
# (The former bare df.head() here was dead code: only a cell's last
# expression is displayed.)
df.to_csv('age_with_dex.csv', index=False)

In [13]:
# Preview the first rows, including the new dex_<age> columns.
df.head(n=5)

Unnamed: 0,filename,age_20,age_25,age_30,age_35,age_40,age_45,age_50,age_55,age_60,...,dex_25,dex_30,dex_35,dex_40,dex_45,dex_50,dex_55,dex_60,dex_65,dex_70
0,image_01208.jpg,20,20,23,25,34,43,46,49,51,...,23.807871,29.918303,39.516712,46.524509,54.701405,59.085262,60.142437,60.543205,61.103607,69.948624
1,image_01381.jpg,20,24,24,24,31,31,31,34,35,...,24.669273,26.171074,30.830036,35.25489,39.464714,45.217274,58.49052,60.009022,60.419468,67.503426
2,image_01795.jpg,20,20,23,25,34,36,41,42,44,...,24.773066,35.375366,42.480625,46.251328,52.696869,57.490555,58.913834,59.602692,59.970028,60.771751
3,image_01993.jpg,23,23,24,36,31,35,39,42,46,...,23.568668,26.55361,35.105816,42.876495,46.615517,54.231213,58.883698,59.67828,60.001141,60.563381
4,image_02218.jpg,19,19,20,20,23,27,35,41,46,...,24.603367,25.824903,31.084902,35.246357,40.052582,44.157139,52.303165,59.44154,60.396782,70.090485
