In [40]:
import torch
import torchvision
import torchvision.transforms as transforms
from torch.autograd import Variable
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from cnn_finetune import make_model

from net import AGNet
from loss import AGLoss

from torch.autograd import Variable
from datagen import ListDataset

import os

In [41]:
batch_size = 32
test_batch_size = 64
epochs = 50
learning_rate = 0.01
momentum = 0.9
dropout = 0.3
start_epoch = 0
best_correct = 0
use_cuda = torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')

In [42]:
transform_train = transforms.Compose([
    transforms.CenterCrop(150),
    transforms.RandomCrop(150, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.485,0.456,0.406), (0.229,0.224,0.225)
    )
])
trainset = ListDataset(root='../data/UTKFace/', transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=8)

transform_test = transforms.Compose([
    transforms.CenterCrop(150),
    transforms.ToTensor(),
    transforms.Normalize((0.485,0.456,0.406), (0.229,0.224,0.225))
])
testset = ListDataset(root='../data/test/', transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=128, shuffle=False, num_workers=8)

In [43]:
net = AGNet()
net.cuda()
criterion = AGLoss()
# criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=learning_rate, momentum=momentum, weight_decay=1e-4)

In [44]:
# def accuracy(age_preds, age_targets, gender_preds, gender_targets):
def accuracy(gender_preds, gender_targets):

#     AGE_TOLERANCE = 3
#     age_prob = F.softmax(age_preds)
#     age_expect = torch.sum(Variable(torch.arange(1, 117)).cuda() * age_prob, 1)
    
#     age_correct = ((age_expect - age_targets.float()).abs() < AGE_TOLERANCE).int().sum().cpu().data[0]
    
    gender_preds = F.sigmoid(gender_preds)
    gender_preds = (gender_preds > 0.5).int()
    gender_correct = (gender_preds == gender_targets.int()).int().cpu().sum().data[0]
    return gender_correct
#     return age_correct, gender_correct

In [45]:
def train(epoch):
    net.train()
    train_loss = 0
    total = 0
    age_correct = 0
    gender_correct = 0
    for batch_idx, (inputs, age_targets, gender_targets) in enumerate(trainloader):
        inputs = Variable(inputs.cuda())
        age_targets = Variable(age_targets.cuda())
        gender_targets = Variable(gender_targets.cuda())
        
        # Zero the parameter gradients
        optimizer.zero_grad()
        
        # 
        age_preds, gender_preds = net(inputs)
        loss = criterion(gender_preds, gender_targets)
        loss.backward()
        optimizer.step()
        
        train_loss += loss.data[0]
        
#         age_correct_i, gender_correct_i = accuracy(age_preds,
#                                                   age_targets, gender_preds,
#                                                   gender_targets)
        gender_correct_i = accuracy(gender_preds, gender_targets)
#         age_correct += age_correct_i
        gender_correct += gender_correct_i
        total += len(inputs)
        print('train_loss: %f | avg_loss: %f | gender_precise: %f (%d/%d) [%d/%d]'
             % (loss.data[0], train_loss/(batch_idx+1),
               100.*gender_correct/total, gender_correct, total,
               batch_idx+1, len(trainloader)))
#         print('train_loss: %.3f | avg_loss: %.3f | age_prec: %.3f (%d/%d) | gender_prec: %.3f (%d/%d)  [%d/%d]'  \
#             % (loss.data[0], train_loss/(batch_idx+1),      \
#                100.*age_correct/total, age_correct, total,  \
#                100.*gender_correct/total, gender_correct, total,    \
#                batch_idx+1, len(trainloader)))

In [46]:
# Test
def test(epoch):
    print('\nTest')
    net.eval()
    test_loss = 0
    total = 0
    age_correct = 0
    gender_correct = 0
    for batch_idx, (inputs, age_targets, gender_targets) in enumerate(testloader):
        inputs = Variable(inputs.cuda())
        age_targets = Variable(age_targets.cuda())
        gender_targets = Variable(gender_targets.cuda())

#         age_preds, gender_preds = net(inputs)
        age_preds, gender_preds = net(inputs)
#         loss = criterion(age_preds, age_targets, gender_preds, gender_targets)
#         loss = criterion(gender_preds, gender_targets)

#         test_loss += loss.data[0]
        gender_correct_i = accuracy(gender_preds, gender_targets)
#         age_correct_i, gender_correct_i = accuracy(
#             age_preds, age_targets, gender_preds, gender_targets)
#         age_correct += age_correct_i
        gender_correct += gender_correct_i
        total += len(inputs)
        print('gender_prec: %f (%d/%d) [%d/%d]' % (100.*gender_correct/total, gender_correct, total, batch_idx+1, len(testloader)))
#         print('test_loss: %f | avg_loss: %f | age_prec: %f (%d/%d) | gender_prec: %f (%d/%d)  [%d/%d]' \
#             % (loss.data[0], test_loss/(batch_idx+1),      \
#                100.*age_correct/total, age_correct, total,  \
#                100.*gender_correct/total, gender_correct, total, \
#                batch_idx+1, len(trainloader)))

    # Save checkpoint
    global best_correct
#     if age_correct + gender_correct > best_correct:
    if gender_correct > best_correct:
        print('Saving..')
        best_correct = gender_correct
        state = {
            'net': net.state_dict(),
            'correct': best_correct,
            'epoch': epoch,
        }
        if not os.path.isdir('../weights/checkpoint'):
            os.mkdir('../weights/checkpoint')
        torch.save(state, '..weights/checkpoint/ckpt.pth')

In [47]:
for epoch in range(start_epoch, start_epoch+10):
    print("Number epoch: {}".format(epoch))
    train(epoch)
    test(epoch)

Number epoch: 0


  print("gender_loss: {}".format(gender_loss.data[0]))


gender_loss: 0.6942108273506165


  if sys.path[0] == '':


train_loss: 0.694211 | avg_loss: 0.694211 | gender_precise: 39.000000 (50/128) [1/186]
gender_loss: 0.6932098269462585
train_loss: 0.693210 | avg_loss: 0.693710 | gender_precise: 44.000000 (113/256) [2/186]
gender_loss: 0.6929138898849487
train_loss: 0.692914 | avg_loss: 0.693445 | gender_precise: 47.000000 (181/384) [3/186]
gender_loss: 0.6931552290916443
train_loss: 0.693155 | avg_loss: 0.693372 | gender_precise: 47.000000 (245/512) [4/186]
gender_loss: 0.6931732296943665
train_loss: 0.693173 | avg_loss: 0.693333 | gender_precise: 48.000000 (308/640) [5/186]
gender_loss: 0.6933927536010742
train_loss: 0.693393 | avg_loss: 0.693343 | gender_precise: 47.000000 (364/768) [6/186]
gender_loss: 0.6932075023651123
train_loss: 0.693208 | avg_loss: 0.693323 | gender_precise: 47.000000 (424/896) [7/186]
gender_loss: 0.693169116973877
train_loss: 0.693169 | avg_loss: 0.693304 | gender_precise: 47.000000 (483/1024) [8/186]
gender_loss: 0.693134605884552
train_loss: 0.693135 | avg_loss: 0.693285 

train_loss: 0.694040 | avg_loss: 0.692051 | gender_precise: 52.000000 (4536/8704) [68/186]
gender_loss: 0.6959288120269775
train_loss: 0.695929 | avg_loss: 0.692107 | gender_precise: 52.000000 (4599/8832) [69/186]
gender_loss: 0.6902096271514893
train_loss: 0.690210 | avg_loss: 0.692080 | gender_precise: 52.000000 (4668/8960) [70/186]
gender_loss: 0.6978049278259277
train_loss: 0.697805 | avg_loss: 0.692160 | gender_precise: 52.000000 (4729/9088) [71/186]
gender_loss: 0.6846057772636414
train_loss: 0.684606 | avg_loss: 0.692055 | gender_precise: 52.000000 (4804/9216) [72/186]
gender_loss: 0.6855738759040833
train_loss: 0.685574 | avg_loss: 0.691966 | gender_precise: 52.000000 (4878/9344) [73/186]
gender_loss: 0.6949324011802673
train_loss: 0.694932 | avg_loss: 0.692007 | gender_precise: 52.000000 (4942/9472) [74/186]
gender_loss: 0.6930127143859863
train_loss: 0.693013 | avg_loss: 0.692020 | gender_precise: 52.000000 (5008/9600) [75/186]
gender_loss: 0.695817232131958
train_loss: 0.695

train_loss: 0.693810 | avg_loss: 0.692454 | gender_precise: 51.000000 (8882/17152) [134/186]
gender_loss: 0.6890815496444702
train_loss: 0.689082 | avg_loss: 0.692429 | gender_precise: 51.000000 (8956/17280) [135/186]
gender_loss: 0.6939033269882202
train_loss: 0.693903 | avg_loss: 0.692439 | gender_precise: 51.000000 (9019/17408) [136/186]
gender_loss: 0.6973665952682495
train_loss: 0.697367 | avg_loss: 0.692475 | gender_precise: 51.000000 (9074/17536) [137/186]
gender_loss: 0.6951184272766113
train_loss: 0.695118 | avg_loss: 0.692495 | gender_precise: 51.000000 (9134/17664) [138/186]
gender_loss: 0.6882700324058533
train_loss: 0.688270 | avg_loss: 0.692464 | gender_precise: 51.000000 (9210/17792) [139/186]
gender_loss: 0.6938259601593018
train_loss: 0.693826 | avg_loss: 0.692474 | gender_precise: 51.000000 (9273/17920) [140/186]
gender_loss: 0.6912797689437866
train_loss: 0.691280 | avg_loss: 0.692465 | gender_precise: 51.000000 (9342/18048) [141/186]
gender_loss: 0.6963562369346619


train_loss: 0.691919 | avg_loss: 0.692142 | gender_precise: 52.000000 (868/1664) [13/186]
gender_loss: 0.6894540190696716
train_loss: 0.689454 | avg_loss: 0.691951 | gender_precise: 52.000000 (938/1792) [14/186]
gender_loss: 0.684411883354187
train_loss: 0.684412 | avg_loss: 0.691448 | gender_precise: 52.000000 (1015/1920) [15/186]
gender_loss: 0.6873265504837036
train_loss: 0.687327 | avg_loss: 0.691190 | gender_precise: 53.000000 (1088/2048) [16/186]
gender_loss: 0.6839935779571533
train_loss: 0.683994 | avg_loss: 0.690767 | gender_precise: 53.000000 (1165/2176) [17/186]
gender_loss: 0.693673849105835
train_loss: 0.693674 | avg_loss: 0.690928 | gender_precise: 53.000000 (1230/2304) [18/186]
gender_loss: 0.6894683837890625
train_loss: 0.689468 | avg_loss: 0.690852 | gender_precise: 53.000000 (1300/2432) [19/186]
gender_loss: 0.6927663683891296
train_loss: 0.692766 | avg_loss: 0.690947 | gender_precise: 53.000000 (1366/2560) [20/186]
gender_loss: 0.693443775177002
train_loss: 0.693444 

train_loss: 0.682809 | avg_loss: 0.691224 | gender_precise: 52.000000 (5422/10240) [80/186]
gender_loss: 0.686247706413269
train_loss: 0.686248 | avg_loss: 0.691163 | gender_precise: 52.000000 (5495/10368) [81/186]
gender_loss: 0.6826953887939453
train_loss: 0.682695 | avg_loss: 0.691059 | gender_precise: 53.000000 (5572/10496) [82/186]
gender_loss: 0.6926846504211426
train_loss: 0.692685 | avg_loss: 0.691079 | gender_precise: 53.000000 (5638/10624) [83/186]
gender_loss: 0.6932688355445862
train_loss: 0.693269 | avg_loss: 0.691105 | gender_precise: 53.000000 (5703/10752) [84/186]
gender_loss: 0.6951802372932434
train_loss: 0.695180 | avg_loss: 0.691153 | gender_precise: 52.000000 (5766/10880) [85/186]
gender_loss: 0.6799356937408447
train_loss: 0.679936 | avg_loss: 0.691022 | gender_precise: 53.000000 (5846/11008) [86/186]
gender_loss: 0.6889631152153015
train_loss: 0.688963 | avg_loss: 0.690999 | gender_precise: 53.000000 (5916/11136) [87/186]
gender_loss: 0.6860582232475281
train_los

train_loss: 0.685968 | avg_loss: 0.691177 | gender_precise: 52.000000 (9864/18688) [146/186]
gender_loss: 0.6893224716186523
train_loss: 0.689322 | avg_loss: 0.691164 | gender_precise: 52.000000 (9933/18816) [147/186]
gender_loss: 0.6908865571022034
train_loss: 0.690887 | avg_loss: 0.691162 | gender_precise: 52.000000 (10000/18944) [148/186]
gender_loss: 0.6979867815971375
train_loss: 0.697987 | avg_loss: 0.691208 | gender_precise: 52.000000 (10055/19072) [149/186]
gender_loss: 0.6931210160255432
train_loss: 0.693121 | avg_loss: 0.691221 | gender_precise: 52.000000 (10118/19200) [150/186]
gender_loss: 0.6950445175170898
train_loss: 0.695045 | avg_loss: 0.691246 | gender_precise: 52.000000 (10177/19328) [151/186]
gender_loss: 0.6912137269973755
train_loss: 0.691214 | avg_loss: 0.691246 | gender_precise: 52.000000 (10242/19456) [152/186]
gender_loss: 0.6963440775871277
train_loss: 0.696344 | avg_loss: 0.691279 | gender_precise: 52.000000 (10300/19584) [153/186]
gender_loss: 0.68722736835

train_loss: 0.689675 | avg_loss: 0.687683 | gender_precise: 60.000000 (2015/3328) [26/186]
gender_loss: 0.6835059523582458
train_loss: 0.683506 | avg_loss: 0.687528 | gender_precise: 60.000000 (2090/3456) [27/186]
gender_loss: 0.6833314895629883
train_loss: 0.683331 | avg_loss: 0.687379 | gender_precise: 60.000000 (2164/3584) [28/186]
gender_loss: 0.6872897744178772
train_loss: 0.687290 | avg_loss: 0.687375 | gender_precise: 60.000000 (2235/3712) [29/186]
gender_loss: 0.6788393259048462
train_loss: 0.678839 | avg_loss: 0.687091 | gender_precise: 60.000000 (2319/3840) [30/186]
gender_loss: 0.6916712522506714
train_loss: 0.691671 | avg_loss: 0.687239 | gender_precise: 60.000000 (2387/3968) [31/186]
gender_loss: 0.6816034317016602
train_loss: 0.681603 | avg_loss: 0.687063 | gender_precise: 60.000000 (2461/4096) [32/186]
gender_loss: 0.6876484155654907
train_loss: 0.687648 | avg_loss: 0.687080 | gender_precise: 60.000000 (2538/4224) [33/186]
gender_loss: 0.6844367384910583
train_loss: 0.68

train_loss: 0.603733 | avg_loss: 0.670236 | gender_precise: 62.000000 (7456/11904) [93/186]
gender_loss: 0.6167487502098083
train_loss: 0.616749 | avg_loss: 0.669667 | gender_precise: 62.000000 (7550/12032) [94/186]
gender_loss: 0.6434347629547119
train_loss: 0.643435 | avg_loss: 0.669391 | gender_precise: 62.000000 (7632/12160) [95/186]
gender_loss: 0.624903678894043
train_loss: 0.624904 | avg_loss: 0.668928 | gender_precise: 62.000000 (7722/12288) [96/186]
gender_loss: 0.5962358713150024
train_loss: 0.596236 | avg_loss: 0.668178 | gender_precise: 62.000000 (7818/12416) [97/186]
gender_loss: 0.6144238114356995
train_loss: 0.614424 | avg_loss: 0.667630 | gender_precise: 63.000000 (7905/12544) [98/186]
gender_loss: 0.6244215965270996
train_loss: 0.624422 | avg_loss: 0.667193 | gender_precise: 63.000000 (7989/12672) [99/186]
gender_loss: 0.5553866028785706
train_loss: 0.555387 | avg_loss: 0.666075 | gender_precise: 63.000000 (8089/12800) [100/186]
gender_loss: 0.6009172797203064
train_lo

train_loss: 0.507548 | avg_loss: 0.616467 | gender_precise: 67.000000 (13754/20352) [159/186]
gender_loss: 0.4338093400001526
train_loss: 0.433809 | avg_loss: 0.615325 | gender_precise: 67.000000 (13861/20480) [160/186]
gender_loss: 0.4566800892353058
train_loss: 0.456680 | avg_loss: 0.614340 | gender_precise: 67.000000 (13959/20608) [161/186]
gender_loss: 0.4553655982017517
train_loss: 0.455366 | avg_loss: 0.613359 | gender_precise: 67.000000 (14056/20736) [162/186]
gender_loss: 0.42951643466949463
train_loss: 0.429516 | avg_loss: 0.612231 | gender_precise: 67.000000 (14157/20864) [163/186]
gender_loss: 0.39813148975372314
train_loss: 0.398131 | avg_loss: 0.610925 | gender_precise: 67.000000 (14264/20992) [164/186]
gender_loss: 0.46134722232818604
train_loss: 0.461347 | avg_loss: 0.610019 | gender_precise: 68.000000 (14364/21120) [165/186]
gender_loss: 0.4579216241836548
train_loss: 0.457922 | avg_loss: 0.609103 | gender_precise: 68.000000 (14465/21248) [166/186]
gender_loss: 0.439752

train_loss: 0.375185 | avg_loss: 0.415610 | gender_precise: 81.000000 (4067/4992) [39/186]
gender_loss: 0.3291642367839813
train_loss: 0.329164 | avg_loss: 0.413449 | gender_precise: 81.000000 (4176/5120) [40/186]
gender_loss: 0.342203825712204
train_loss: 0.342204 | avg_loss: 0.411712 | gender_precise: 81.000000 (4283/5248) [41/186]
gender_loss: 0.2798180878162384
train_loss: 0.279818 | avg_loss: 0.408571 | gender_precise: 81.000000 (4398/5376) [42/186]
gender_loss: 0.34360629320144653
train_loss: 0.343606 | avg_loss: 0.407060 | gender_precise: 81.000000 (4507/5504) [43/186]
gender_loss: 0.38674548268318176
train_loss: 0.386745 | avg_loss: 0.406599 | gender_precise: 81.000000 (4614/5632) [44/186]
gender_loss: 0.2982025742530823
train_loss: 0.298203 | avg_loss: 0.404190 | gender_precise: 82.000000 (4725/5760) [45/186]
gender_loss: 0.37577080726623535
train_loss: 0.375771 | avg_loss: 0.403572 | gender_precise: 82.000000 (4830/5888) [46/186]
gender_loss: 0.3592092990875244
train_loss: 0.

train_loss: 0.335300 | avg_loss: 0.370294 | gender_precise: 83.000000 (11342/13568) [106/186]
gender_loss: 0.38431769609451294
train_loss: 0.384318 | avg_loss: 0.370425 | gender_precise: 83.000000 (11447/13696) [107/186]
gender_loss: 0.2924637496471405
train_loss: 0.292464 | avg_loss: 0.369703 | gender_precise: 83.000000 (11557/13824) [108/186]
gender_loss: 0.34408560395240784
train_loss: 0.344086 | avg_loss: 0.369468 | gender_precise: 83.000000 (11665/13952) [109/186]
gender_loss: 0.2321425825357437
train_loss: 0.232143 | avg_loss: 0.368220 | gender_precise: 83.000000 (11781/14080) [110/186]
gender_loss: 0.28845828771591187
train_loss: 0.288458 | avg_loss: 0.367501 | gender_precise: 83.000000 (11891/14208) [111/186]
gender_loss: 0.3432818353176117
train_loss: 0.343282 | avg_loss: 0.367285 | gender_precise: 83.000000 (11998/14336) [112/186]
gender_loss: 0.2802627682685852
train_loss: 0.280263 | avg_loss: 0.366515 | gender_precise: 83.000000 (12108/14464) [113/186]
gender_loss: 0.319698

train_loss: 0.299733 | avg_loss: 0.353268 | gender_precise: 84.000000 (18452/21888) [171/186]
gender_loss: 0.301461786031723
train_loss: 0.301462 | avg_loss: 0.352967 | gender_precise: 84.000000 (18566/22016) [172/186]
gender_loss: 0.3246181905269623
train_loss: 0.324618 | avg_loss: 0.352803 | gender_precise: 84.000000 (18675/22144) [173/186]
gender_loss: 0.35675206780433655
train_loss: 0.356752 | avg_loss: 0.352825 | gender_precise: 84.000000 (18785/22272) [174/186]
gender_loss: 0.25802260637283325
train_loss: 0.258023 | avg_loss: 0.352284 | gender_precise: 84.000000 (18901/22400) [175/186]
gender_loss: 0.35997718572616577
train_loss: 0.359977 | avg_loss: 0.352327 | gender_precise: 84.000000 (19010/22528) [176/186]
gender_loss: 0.3544620871543884
train_loss: 0.354462 | avg_loss: 0.352339 | gender_precise: 84.000000 (19117/22656) [177/186]
gender_loss: 0.3253265917301178
train_loss: 0.325327 | avg_loss: 0.352188 | gender_precise: 84.000000 (19228/22784) [178/186]
gender_loss: 0.3370715

train_loss: 0.380082 | avg_loss: 0.308142 | gender_precise: 86.000000 (5626/6528) [51/186]
gender_loss: 0.23457016050815582
train_loss: 0.234570 | avg_loss: 0.306727 | gender_precise: 86.000000 (5741/6656) [52/186]
gender_loss: 0.3918714225292206
train_loss: 0.391871 | avg_loss: 0.308334 | gender_precise: 86.000000 (5850/6784) [53/186]
gender_loss: 0.3059954047203064
train_loss: 0.305995 | avg_loss: 0.308290 | gender_precise: 86.000000 (5961/6912) [54/186]
gender_loss: 0.25717592239379883
train_loss: 0.257176 | avg_loss: 0.307361 | gender_precise: 86.000000 (6074/7040) [55/186]
gender_loss: 0.27995750308036804
train_loss: 0.279958 | avg_loss: 0.306872 | gender_precise: 86.000000 (6186/7168) [56/186]
gender_loss: 0.2967441976070404
train_loss: 0.296744 | avg_loss: 0.306694 | gender_precise: 86.000000 (6299/7296) [57/186]
gender_loss: 0.289835661649704
train_loss: 0.289836 | avg_loss: 0.306403 | gender_precise: 86.000000 (6410/7424) [58/186]
gender_loss: 0.3865159749984741
train_loss: 0.

train_loss: 0.235792 | avg_loss: 0.296315 | gender_precise: 86.000000 (13014/14976) [117/186]
gender_loss: 0.26432210206985474
train_loss: 0.264322 | avg_loss: 0.296044 | gender_precise: 86.000000 (13126/15104) [118/186]
gender_loss: 0.23833227157592773
train_loss: 0.238332 | avg_loss: 0.295559 | gender_precise: 86.000000 (13240/15232) [119/186]
gender_loss: 0.32182577252388
train_loss: 0.321826 | avg_loss: 0.295778 | gender_precise: 86.000000 (13348/15360) [120/186]
gender_loss: 0.30718323588371277
train_loss: 0.307183 | avg_loss: 0.295872 | gender_precise: 86.000000 (13457/15488) [121/186]
gender_loss: 0.3066820800304413
train_loss: 0.306682 | avg_loss: 0.295961 | gender_precise: 86.000000 (13570/15616) [122/186]
gender_loss: 0.34138813614845276
train_loss: 0.341388 | avg_loss: 0.296330 | gender_precise: 86.000000 (13679/15744) [123/186]
gender_loss: 0.2611689269542694
train_loss: 0.261169 | avg_loss: 0.296047 | gender_precise: 86.000000 (13790/15872) [124/186]
gender_loss: 0.3410722

train_loss: 0.321784 | avg_loss: 0.291573 | gender_precise: 87.000000 (20314/23296) [182/186]
gender_loss: 0.37186306715011597
train_loss: 0.371863 | avg_loss: 0.292012 | gender_precise: 87.000000 (20420/23424) [183/186]
gender_loss: 0.3048299252986908
train_loss: 0.304830 | avg_loss: 0.292081 | gender_precise: 87.000000 (20528/23552) [184/186]
gender_loss: 0.26146838068962097
train_loss: 0.261468 | avg_loss: 0.291916 | gender_precise: 87.000000 (20638/23680) [185/186]
gender_loss: 0.4959561824798584
train_loss: 0.495956 | avg_loss: 0.293013 | gender_precise: 87.000000 (20659/23708) [186/186]

Test
gender_prec: 83.000000 (107/128) [1/2]
gender_prec: 83.000000 (112/134) [2/2]
Saving..
Number epoch: 5
gender_loss: 0.2194591909646988
train_loss: 0.219459 | avg_loss: 0.219459 | gender_precise: 89.000000 (115/128) [1/186]
gender_loss: 0.25508058071136475
train_loss: 0.255081 | avg_loss: 0.237270 | gender_precise: 90.000000 (231/256) [2/186]
gender_loss: 0.2788529098033905
train_loss: 0.2788

train_loss: 0.301173 | avg_loss: 0.272152 | gender_precise: 88.000000 (6991/7936) [62/186]
gender_loss: 0.2842526435852051
train_loss: 0.284253 | avg_loss: 0.272344 | gender_precise: 88.000000 (7100/8064) [63/186]
gender_loss: 0.23186907172203064
train_loss: 0.231869 | avg_loss: 0.271712 | gender_precise: 88.000000 (7217/8192) [64/186]
gender_loss: 0.22669385373592377
train_loss: 0.226694 | avg_loss: 0.271019 | gender_precise: 88.000000 (7331/8320) [65/186]
gender_loss: 0.29553893208503723
train_loss: 0.295539 | avg_loss: 0.271391 | gender_precise: 88.000000 (7443/8448) [66/186]
gender_loss: 0.2562289834022522
train_loss: 0.256229 | avg_loss: 0.271164 | gender_precise: 88.000000 (7557/8576) [67/186]
gender_loss: 0.26393580436706543
train_loss: 0.263936 | avg_loss: 0.271058 | gender_precise: 88.000000 (7670/8704) [68/186]
gender_loss: 0.3027409017086029
train_loss: 0.302741 | avg_loss: 0.271517 | gender_precise: 88.000000 (7780/8832) [69/186]
gender_loss: 0.2534095346927643
train_loss: 

train_loss: 0.220195 | avg_loss: 0.264327 | gender_precise: 88.000000 (14484/16384) [128/186]
gender_loss: 0.260219544172287
train_loss: 0.260220 | avg_loss: 0.264296 | gender_precise: 88.000000 (14597/16512) [129/186]
gender_loss: 0.18274301290512085
train_loss: 0.182743 | avg_loss: 0.263668 | gender_precise: 88.000000 (14717/16640) [130/186]
gender_loss: 0.24205027520656586
train_loss: 0.242050 | avg_loss: 0.263503 | gender_precise: 88.000000 (14834/16768) [131/186]
gender_loss: 0.22829067707061768
train_loss: 0.228291 | avg_loss: 0.263236 | gender_precise: 88.000000 (14952/16896) [132/186]
gender_loss: 0.1544123739004135
train_loss: 0.154412 | avg_loss: 0.262418 | gender_precise: 88.000000 (15071/17024) [133/186]
gender_loss: 0.32840728759765625
train_loss: 0.328407 | avg_loss: 0.262911 | gender_precise: 88.000000 (15180/17152) [134/186]
gender_loss: 0.36451080441474915
train_loss: 0.364511 | avg_loss: 0.263663 | gender_precise: 88.000000 (15285/17280) [135/186]
gender_loss: 0.37313

train_loss: 0.257554 | avg_loss: 0.258857 | gender_precise: 88.000000 (795/896) [7/186]
gender_loss: 0.39671632647514343
train_loss: 0.396716 | avg_loss: 0.276089 | gender_precise: 88.000000 (906/1024) [8/186]
gender_loss: 0.3173665702342987
train_loss: 0.317367 | avg_loss: 0.280676 | gender_precise: 88.000000 (1017/1152) [9/186]
gender_loss: 0.29883673787117004
train_loss: 0.298837 | avg_loss: 0.282492 | gender_precise: 88.000000 (1131/1280) [10/186]
gender_loss: 0.2512652277946472
train_loss: 0.251265 | avg_loss: 0.279653 | gender_precise: 88.000000 (1246/1408) [11/186]
gender_loss: 0.21312062442302704
train_loss: 0.213121 | avg_loss: 0.274109 | gender_precise: 88.000000 (1361/1536) [12/186]
gender_loss: 0.29095640778541565
train_loss: 0.290956 | avg_loss: 0.275405 | gender_precise: 88.000000 (1473/1664) [13/186]
gender_loss: 0.2684505879878998
train_loss: 0.268451 | avg_loss: 0.274908 | gender_precise: 88.000000 (1585/1792) [14/186]
gender_loss: 0.309044748544693
train_loss: 0.30904

train_loss: 0.255599 | avg_loss: 0.251474 | gender_precise: 89.000000 (8447/9472) [74/186]
gender_loss: 0.2262522280216217
train_loss: 0.226252 | avg_loss: 0.251138 | gender_precise: 89.000000 (8558/9600) [75/186]
gender_loss: 0.2440306395292282
train_loss: 0.244031 | avg_loss: 0.251044 | gender_precise: 89.000000 (8672/9728) [76/186]
gender_loss: 0.3024007976055145
train_loss: 0.302401 | avg_loss: 0.251711 | gender_precise: 89.000000 (8779/9856) [77/186]
gender_loss: 0.3301714062690735
train_loss: 0.330171 | avg_loss: 0.252717 | gender_precise: 89.000000 (8891/9984) [78/186]
gender_loss: 0.29227879643440247
train_loss: 0.292279 | avg_loss: 0.253218 | gender_precise: 89.000000 (9003/10112) [79/186]
gender_loss: 0.2532590627670288
train_loss: 0.253259 | avg_loss: 0.253218 | gender_precise: 88.000000 (9113/10240) [80/186]
gender_loss: 0.23357656598091125
train_loss: 0.233577 | avg_loss: 0.252976 | gender_precise: 89.000000 (9229/10368) [81/186]
gender_loss: 0.23660892248153687
train_loss

train_loss: 0.129109 | avg_loss: 0.249067 | gender_precise: 89.000000 (15958/17920) [140/186]
gender_loss: 0.25953051447868347
train_loss: 0.259531 | avg_loss: 0.249141 | gender_precise: 89.000000 (16072/18048) [141/186]
gender_loss: 0.2018292099237442
train_loss: 0.201829 | avg_loss: 0.248808 | gender_precise: 89.000000 (16188/18176) [142/186]
gender_loss: 0.23034578561782837
train_loss: 0.230346 | avg_loss: 0.248679 | gender_precise: 89.000000 (16303/18304) [143/186]
gender_loss: 0.2305193841457367
train_loss: 0.230519 | avg_loss: 0.248552 | gender_precise: 89.000000 (16414/18432) [144/186]
gender_loss: 0.1971980333328247
train_loss: 0.197198 | avg_loss: 0.248198 | gender_precise: 89.000000 (16534/18560) [145/186]
gender_loss: 0.2137671709060669
train_loss: 0.213767 | avg_loss: 0.247962 | gender_precise: 89.000000 (16651/18688) [146/186]
gender_loss: 0.21601779758930206
train_loss: 0.216018 | avg_loss: 0.247745 | gender_precise: 89.000000 (16766/18816) [147/186]
gender_loss: 0.203596

train_loss: 0.174777 | avg_loss: 0.241195 | gender_precise: 89.000000 (2174/2432) [19/186]
gender_loss: 0.2635890543460846
train_loss: 0.263589 | avg_loss: 0.242314 | gender_precise: 89.000000 (2288/2560) [20/186]
gender_loss: 0.22506996989250183
train_loss: 0.225070 | avg_loss: 0.241493 | gender_precise: 89.000000 (2403/2688) [21/186]
gender_loss: 0.17380118370056152
train_loss: 0.173801 | avg_loss: 0.238416 | gender_precise: 89.000000 (2522/2816) [22/186]
gender_loss: 0.20562775433063507
train_loss: 0.205628 | avg_loss: 0.236991 | gender_precise: 89.000000 (2642/2944) [23/186]
gender_loss: 0.18245644867420197
train_loss: 0.182456 | avg_loss: 0.234718 | gender_precise: 89.000000 (2757/3072) [24/186]
gender_loss: 0.2809452414512634
train_loss: 0.280945 | avg_loss: 0.236567 | gender_precise: 89.000000 (2871/3200) [25/186]
gender_loss: 0.2625579237937927
train_loss: 0.262558 | avg_loss: 0.237567 | gender_precise: 89.000000 (2985/3328) [26/186]
gender_loss: 0.2589667737483978
train_loss: 

train_loss: 0.272170 | avg_loss: 0.236502 | gender_precise: 89.000000 (9869/11008) [86/186]
gender_loss: 0.20990726351737976
train_loss: 0.209907 | avg_loss: 0.236196 | gender_precise: 89.000000 (9984/11136) [87/186]
gender_loss: 0.15229357779026031
train_loss: 0.152294 | avg_loss: 0.235243 | gender_precise: 89.000000 (10106/11264) [88/186]
gender_loss: 0.15265733003616333
train_loss: 0.152657 | avg_loss: 0.234315 | gender_precise: 89.000000 (10223/11392) [89/186]
gender_loss: 0.09379303455352783
train_loss: 0.093793 | avg_loss: 0.232753 | gender_precise: 89.000000 (10346/11520) [90/186]
gender_loss: 0.3141777217388153
train_loss: 0.314178 | avg_loss: 0.233648 | gender_precise: 89.000000 (10456/11648) [91/186]
gender_loss: 0.19466377794742584
train_loss: 0.194664 | avg_loss: 0.233224 | gender_precise: 89.000000 (10572/11776) [92/186]
gender_loss: 0.19589197635650635
train_loss: 0.195892 | avg_loss: 0.232823 | gender_precise: 89.000000 (10690/11904) [93/186]
gender_loss: 0.2476167678833

train_loss: 0.216688 | avg_loss: 0.230501 | gender_precise: 89.000000 (17392/19328) [151/186]
gender_loss: 0.23247294127941132
train_loss: 0.232473 | avg_loss: 0.230514 | gender_precise: 89.000000 (17506/19456) [152/186]
gender_loss: 0.18652494251728058
train_loss: 0.186525 | avg_loss: 0.230226 | gender_precise: 89.000000 (17625/19584) [153/186]
gender_loss: 0.19300997257232666
train_loss: 0.193010 | avg_loss: 0.229985 | gender_precise: 89.000000 (17740/19712) [154/186]
gender_loss: 0.24178621172904968
train_loss: 0.241786 | avg_loss: 0.230061 | gender_precise: 90.000000 (17856/19840) [155/186]
gender_loss: 0.15296149253845215
train_loss: 0.152961 | avg_loss: 0.229566 | gender_precise: 90.000000 (17975/19968) [156/186]
gender_loss: 0.22947129607200623
train_loss: 0.229471 | avg_loss: 0.229566 | gender_precise: 90.000000 (18092/20096) [157/186]
gender_loss: 0.24041748046875
train_loss: 0.240417 | avg_loss: 0.229635 | gender_precise: 90.000000 (18205/20224) [158/186]
gender_loss: 0.24890

train_loss: 0.234142 | avg_loss: 0.225589 | gender_precise: 90.000000 (3484/3840) [30/186]
gender_loss: 0.19985570013523102
train_loss: 0.199856 | avg_loss: 0.224758 | gender_precise: 90.000000 (3600/3968) [31/186]
gender_loss: 0.231576070189476
train_loss: 0.231576 | avg_loss: 0.224972 | gender_precise: 90.000000 (3716/4096) [32/186]
gender_loss: 0.1827651560306549
train_loss: 0.182765 | avg_loss: 0.223693 | gender_precise: 90.000000 (3834/4224) [33/186]
gender_loss: 0.32010897994041443
train_loss: 0.320109 | avg_loss: 0.226528 | gender_precise: 90.000000 (3948/4352) [34/186]
gender_loss: 0.1552044302225113
train_loss: 0.155204 | avg_loss: 0.224491 | gender_precise: 90.000000 (4065/4480) [35/186]
gender_loss: 0.23164741694927216
train_loss: 0.231647 | avg_loss: 0.224689 | gender_precise: 90.000000 (4180/4608) [36/186]
gender_loss: 0.22630147635936737
train_loss: 0.226301 | avg_loss: 0.224733 | gender_precise: 90.000000 (4293/4736) [37/186]
gender_loss: 0.2564074397087097
train_loss: 0

train_loss: 0.166579 | avg_loss: 0.221046 | gender_precise: 90.000000 (11256/12416) [97/186]
gender_loss: 0.2835696339607239
train_loss: 0.283570 | avg_loss: 0.221684 | gender_precise: 90.000000 (11368/12544) [98/186]
gender_loss: 0.19875848293304443
train_loss: 0.198758 | avg_loss: 0.221452 | gender_precise: 90.000000 (11483/12672) [99/186]
gender_loss: 0.19460158050060272
train_loss: 0.194602 | avg_loss: 0.221184 | gender_precise: 90.000000 (11599/12800) [100/186]
gender_loss: 0.15435725450515747
train_loss: 0.154357 | avg_loss: 0.220522 | gender_precise: 90.000000 (11722/12928) [101/186]
gender_loss: 0.2818053364753723
train_loss: 0.281805 | avg_loss: 0.221123 | gender_precise: 90.000000 (11835/13056) [102/186]
gender_loss: 0.1925557404756546
train_loss: 0.192556 | avg_loss: 0.220846 | gender_precise: 90.000000 (11952/13184) [103/186]
gender_loss: 0.25864362716674805
train_loss: 0.258644 | avg_loss: 0.221209 | gender_precise: 90.000000 (12062/13312) [104/186]
gender_loss: 0.20496636

train_loss: 0.186569 | avg_loss: 0.220772 | gender_precise: 90.000000 (18765/20736) [162/186]
gender_loss: 0.1370554119348526
train_loss: 0.137055 | avg_loss: 0.220258 | gender_precise: 90.000000 (18884/20864) [163/186]
gender_loss: 0.3188832402229309
train_loss: 0.318883 | avg_loss: 0.220860 | gender_precise: 90.000000 (18994/20992) [164/186]
gender_loss: 0.21327313780784607
train_loss: 0.213273 | avg_loss: 0.220814 | gender_precise: 90.000000 (19110/21120) [165/186]
gender_loss: 0.23837226629257202
train_loss: 0.238372 | avg_loss: 0.220920 | gender_precise: 90.000000 (19223/21248) [166/186]
gender_loss: 0.30415794253349304
train_loss: 0.304158 | avg_loss: 0.221418 | gender_precise: 90.000000 (19333/21376) [167/186]
gender_loss: 0.2750428020954132
train_loss: 0.275043 | avg_loss: 0.221737 | gender_precise: 90.000000 (19441/21504) [168/186]
gender_loss: 0.19452346861362457
train_loss: 0.194523 | avg_loss: 0.221576 | gender_precise: 90.000000 (19557/21632) [169/186]
gender_loss: 0.21937

train_loss: 0.324362 | avg_loss: 0.201695 | gender_precise: 91.000000 (4907/5376) [42/186]
gender_loss: 0.2267947494983673
train_loss: 0.226795 | avg_loss: 0.202278 | gender_precise: 91.000000 (5024/5504) [43/186]
gender_loss: 0.18780189752578735
train_loss: 0.187802 | avg_loss: 0.201949 | gender_precise: 91.000000 (5143/5632) [44/186]
gender_loss: 0.19326958060264587
train_loss: 0.193270 | avg_loss: 0.201757 | gender_precise: 91.000000 (5258/5760) [45/186]
gender_loss: 0.17167869210243225
train_loss: 0.171679 | avg_loss: 0.201103 | gender_precise: 91.000000 (5376/5888) [46/186]
gender_loss: 0.17775772511959076
train_loss: 0.177758 | avg_loss: 0.200606 | gender_precise: 91.000000 (5494/6016) [47/186]
gender_loss: 0.295192688703537
train_loss: 0.295193 | avg_loss: 0.202577 | gender_precise: 91.000000 (5610/6144) [48/186]
gender_loss: 0.2308320701122284
train_loss: 0.230832 | avg_loss: 0.203153 | gender_precise: 91.000000 (5730/6272) [49/186]
gender_loss: 0.13377968966960907
train_loss: 

train_loss: 0.199778 | avg_loss: 0.210106 | gender_precise: 90.000000 (12574/13824) [108/186]
gender_loss: 0.20604848861694336
train_loss: 0.206048 | avg_loss: 0.210068 | gender_precise: 90.000000 (12692/13952) [109/186]
gender_loss: 0.20235905051231384
train_loss: 0.202359 | avg_loss: 0.209998 | gender_precise: 90.000000 (12812/14080) [110/186]
gender_loss: 0.17774313688278198
train_loss: 0.177743 | avg_loss: 0.209708 | gender_precise: 91.000000 (12933/14208) [111/186]
gender_loss: 0.25385263562202454
train_loss: 0.253853 | avg_loss: 0.210102 | gender_precise: 91.000000 (13047/14336) [112/186]
gender_loss: 0.27970972657203674
train_loss: 0.279710 | avg_loss: 0.210718 | gender_precise: 90.000000 (13156/14464) [113/186]
gender_loss: 0.21587315201759338
train_loss: 0.215873 | avg_loss: 0.210763 | gender_precise: 90.000000 (13269/14592) [114/186]
gender_loss: 0.19915196299552917
train_loss: 0.199152 | avg_loss: 0.210662 | gender_precise: 90.000000 (13385/14720) [115/186]
gender_loss: 0.17

train_loss: 0.193419 | avg_loss: 0.211423 | gender_precise: 90.000000 (20137/22144) [173/186]
gender_loss: 0.24388989806175232
train_loss: 0.243890 | avg_loss: 0.211610 | gender_precise: 90.000000 (20251/22272) [174/186]
gender_loss: 0.2728992998600006
train_loss: 0.272899 | avg_loss: 0.211960 | gender_precise: 90.000000 (20362/22400) [175/186]
gender_loss: 0.21734963357448578
train_loss: 0.217350 | avg_loss: 0.211990 | gender_precise: 90.000000 (20481/22528) [176/186]
gender_loss: 0.20179924368858337
train_loss: 0.201799 | avg_loss: 0.211933 | gender_precise: 90.000000 (20596/22656) [177/186]
gender_loss: 0.1646464616060257
train_loss: 0.164646 | avg_loss: 0.211667 | gender_precise: 90.000000 (20713/22784) [178/186]
gender_loss: 0.1549515575170517
train_loss: 0.154952 | avg_loss: 0.211350 | gender_precise: 90.000000 (20832/22912) [179/186]
gender_loss: 0.2360949069261551
train_loss: 0.236095 | avg_loss: 0.211488 | gender_precise: 90.000000 (20947/23040) [180/186]
gender_loss: 0.269897

## Test with a image

In [51]:
import cv2
from PIL import Image

In [50]:
img = cv2.imread("/home/neosai/Downloads/46402_1951-02-15_2009.jpg")

In [117]:
img = Image.open("/home/neosai/Downloads/download.jpeg")

In [118]:
input_img = transform_test(img)

In [120]:
input_img.shape

torch.Size([3, 150, 150])

In [121]:
input_img.unsqueeze_(0).shape

torch.Size([1, 3, 150, 150])

In [122]:
input_img.shape

torch.Size([1, 3, 150, 150])

In [123]:
a, b = net(input_img.cuda())

In [124]:
b

tensor([-3.8809], device='cuda:0', grad_fn=<ViewBackward>)

In [125]:
gender_preds = F.sigmoid(b)
gender_preds = (gender_preds > 0.5).int()



In [126]:
gender_preds

tensor([0], device='cuda:0', dtype=torch.int32)

In [1]:
import numpy as np

In [22]:
b = np.random.rand(128, 116)

In [23]:
b.shape

(128, 116)

In [8]:
a = np.sum(np.arange(1, 117) * b, 1)

In [9]:
a

array([3620.47253613])

In [27]:
[np.argmax(b[i]) for i in range(0, 116)]

[85,
 102,
 68,
 83,
 52,
 13,
 114,
 105,
 74,
 19,
 42,
 1,
 21,
 95,
 109,
 83,
 48,
 104,
 91,
 39,
 64,
 3,
 83,
 9,
 52,
 26,
 106,
 25,
 103,
 34,
 73,
 27,
 82,
 54,
 81,
 34,
 71,
 112,
 24,
 77,
 21,
 39,
 17,
 83,
 102,
 21,
 106,
 101,
 114,
 101,
 69,
 44,
 77,
 103,
 12,
 47,
 101,
 30,
 84,
 48,
 50,
 100,
 38,
 8,
 105,
 97,
 58,
 13,
 91,
 47,
 93,
 97,
 61,
 16,
 41,
 114,
 105,
 102,
 39,
 42,
 87,
 106,
 70,
 13,
 51,
 106,
 51,
 55,
 104,
 48,
 90,
 64,
 40,
 99,
 5,
 0,
 45,
 90,
 5,
 16,
 1,
 39,
 67,
 56,
 62,
 112,
 35,
 42,
 86,
 11,
 41,
 102,
 19,
 78,
 71,
 74]

In [25]:
b[0].argmax

<function ndarray.argmax>

In [26]:
np.argmax(b[0])

85

In [21]:
b[0][55]

0.7162959670689439

array([[0.48304033, 0.77516863, 0.97833413, 0.3871517 , 0.24357855,
        0.15773135, 0.53445515, 0.92273702, 0.13137669, 0.14347272,
        0.39719973, 0.95451412, 0.04828658, 0.97270312, 0.11151566,
        0.59834198, 0.20499689, 0.76103823, 0.11478081, 0.64958808,
        0.67512689, 0.85111405, 0.96028944, 0.50181462, 0.83416862,
        0.99322512, 0.60526135, 0.11691834, 0.27786215, 0.47205074,
        0.77459759, 0.06659621, 0.94212423, 0.27200115, 0.13643705,
        0.42923546, 0.59603223, 0.10297902, 0.3353779 , 0.73987667,
        0.70128138, 0.98503977, 0.99805381, 0.11922939, 0.72293112,
        0.24500441, 0.57040028, 0.03973017, 0.39401589, 0.05505866,
        0.44677099, 0.14949467, 0.79249604, 0.61423846, 0.8821324 ,
        0.99992684, 0.82861075, 0.99426677, 0.36245055, 0.19003802,
        0.91026761, 0.21190709, 0.21237699, 0.68614634, 0.65465492,
        0.9076976 , 0.13159407, 0.80854219, 0.24745792, 0.92374191,
        0.90822876, 0.06291042, 0.6143676 , 0.22

In [28]:
import torch

In [29]:

a = torch.rand(128, 116)

In [30]:
a

tensor([[0.7913, 0.4329, 0.3026,  ..., 0.6049, 0.1280, 0.6655],
        [0.0876, 0.1363, 0.9883,  ..., 0.1942, 0.0160, 0.9075],
        [0.3080, 0.1526, 0.0819,  ..., 0.0956, 0.0059, 0.6564],
        ...,
        [0.2536, 0.8113, 0.0131,  ..., 0.8667, 0.0907, 0.9448],
        [0.2560, 0.2959, 0.4659,  ..., 0.0078, 0.2909, 0.6288],
        [0.0643, 0.0982, 0.2114,  ..., 0.5803, 0.1505, 0.2233]])

In [31]:
b = [torch.argmax(a[i]) for i in range(0, 116)]

In [32]:
len(b)

116

In [33]:
type(b)

list

In [35]:
c = torch.from_numpy(np.array(b))

In [37]:
c.shape

torch.Size([116])