In [40]:
import torch
import torchvision
import torchvision.transforms as transforms
from torch.autograd import Variable
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from cnn_finetune import make_model

from net import AGNet
from loss import AGLoss

from torch.autograd import Variable
from datagen import ListDataset

import os

In [41]:
batch_size = 32
test_batch_size = 64
epochs = 50
learning_rate = 0.01
momentum = 0.9
dropout = 0.2
start_epoch = 0
best_correct = 0
use_cuda = torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')

In [42]:
transform_train = transforms.Compose([
    transforms.CenterCrop(150),
    transforms.RandomCrop(150, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.485,0.456,0.406), (0.229,0.224,0.225)
    )
])
trainset = ListDataset(root='../data/UTKFace/', transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=8)

transform_test = transforms.Compose([
    transforms.CenterCrop(150),
    transforms.ToTensor(),
    transforms.Normalize((0.485,0.456,0.406), (0.229,0.224,0.225))
])
testset = ListDataset(root='../data/test/', transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=128, shuffle=False, num_workers=8)

In [43]:
net = AGNet()
net.cuda()
criterion = AGLoss()
# criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=learning_rate, momentum=momentum, weight_decay=1e-4)

In [44]:
# def accuracy(age_preds, age_targets, gender_preds, gender_targets):
def accuracy(gender_preds, gender_targets):

#     AGE_TOLERANCE = 3
#     age_prob = F.softmax(age_preds)
#     age_expect = torch.sum(Variable(torch.arange(1, 117)).cuda() * age_prob, 1)
    
#     age_correct = ((age_expect - age_targets.float()).abs() < AGE_TOLERANCE).int().sum().cpu().data[0]
    
    gender_preds = F.sigmoid(gender_preds)
    gender_preds = (gender_preds > 0.5).int()
    gender_correct = (gender_preds == gender_targets.int()).int().cpu().sum().data[0]
    return gender_correct
#     return age_correct, gender_correct

In [45]:
def train(epoch):
    net.train()
    train_loss = 0
    total = 0
    age_correct = 0
    gender_correct = 0
    for batch_idx, (inputs, age_targets, gender_targets) in enumerate(trainloader):
        inputs = Variable(inputs.cuda())
        age_targets = Variable(age_targets.cuda())
        gender_targets = Variable(gender_targets.cuda())
        
        # Zero the parameter gradients
        optimizer.zero_grad()
        
        # 
        age_preds, gender_preds = net(inputs)
        loss = criterion(gender_preds, gender_targets)
        loss.backward()
        optimizer.step()
        
        train_loss += loss.data[0]
        
#         age_correct_i, gender_correct_i = accuracy(age_preds,
#                                                   age_targets, gender_preds,
#                                                   gender_targets)
        gender_correct_i = accuracy(gender_preds, gender_targets)
#         age_correct += age_correct_i
        gender_correct += gender_correct_i
        total += len(inputs)
        print('train_loss: %f | avg_loss: %f | gender_precise: %f (%d/%d) [%d/%d]'
             % (loss.data[0], train_loss/(batch_idx+1),
               100.*gender_correct/total, gender_correct, total,
               batch_idx+1, len(trainloader)))
#         print('train_loss: %.3f | avg_loss: %.3f | age_prec: %.3f (%d/%d) | gender_prec: %.3f (%d/%d)  [%d/%d]'  \
#             % (loss.data[0], train_loss/(batch_idx+1),      \
#                100.*age_correct/total, age_correct, total,  \
#                100.*gender_correct/total, gender_correct, total,    \
#                batch_idx+1, len(trainloader)))

In [46]:
# Test
def test(epoch):
    print('\nTest')
    net.eval()
    test_loss = 0
    total = 0
    age_correct = 0
    gender_correct = 0
    for batch_idx, (inputs, age_targets, gender_targets) in enumerate(testloader):
        inputs = Variable(inputs.cuda())
        age_targets = Variable(age_targets.cuda())
        gender_targets = Variable(gender_targets.cuda())

#         age_preds, gender_preds = net(inputs)
        age_preds, gender_preds = net(inputs)
#         loss = criterion(age_preds, age_targets, gender_preds, gender_targets)
#         loss = criterion(gender_preds, gender_targets)

#         test_loss += loss.data[0]
        gender_correct_i = accuracy(gender_preds, gender_targets)
#         age_correct_i, gender_correct_i = accuracy(
#             age_preds, age_targets, gender_preds, gender_targets)
#         age_correct += age_correct_i
        gender_correct += gender_correct_i
        total += len(inputs)
        print('gender_prec: %f (%d/%d) [%d/%d]' % (100.*gender_correct/total, gender_correct, total, batch_idx+1, len(testloader)))
#         print('test_loss: %f | avg_loss: %f | age_prec: %f (%d/%d) | gender_prec: %f (%d/%d)  [%d/%d]' \
#             % (loss.data[0], test_loss/(batch_idx+1),      \
#                100.*age_correct/total, age_correct, total,  \
#                100.*gender_correct/total, gender_correct, total, \
#                batch_idx+1, len(trainloader)))

    # Save checkpoint
    global best_correct
#     if age_correct + gender_correct > best_correct:
    if gender_correct > best_correct:
        print('Saving..')
        best_correct = gender_correct
        state = {
            'net': net.state_dict(),
            'correct': best_correct,
            'epoch': epoch,
        }
        if not os.path.isdir('../weights/checkpoint'):
            os.mkdir('../weights/checkpoint')
        torch.save(state, '..weights/checkpoint/ckpt.pth')

In [None]:
for epoch in range(start_epoch, start_epoch+10):
    print("Number epoch: {}".format(epoch))
    train(epoch)
    test(epoch)

Number epoch: 0


  print("gender_loss: {}".format(gender_loss.data[0]))


gender_loss: 0.6942108273506165


  if sys.path[0] == '':


train_loss: 0.694211 | avg_loss: 0.694211 | gender_precise: 39.000000 (50/128) [1/186]
gender_loss: 0.6932098269462585
train_loss: 0.693210 | avg_loss: 0.693710 | gender_precise: 44.000000 (113/256) [2/186]
gender_loss: 0.6929138898849487
train_loss: 0.692914 | avg_loss: 0.693445 | gender_precise: 47.000000 (181/384) [3/186]
gender_loss: 0.6931552290916443
train_loss: 0.693155 | avg_loss: 0.693372 | gender_precise: 47.000000 (245/512) [4/186]
gender_loss: 0.6931732296943665
train_loss: 0.693173 | avg_loss: 0.693333 | gender_precise: 48.000000 (308/640) [5/186]
gender_loss: 0.6933927536010742
train_loss: 0.693393 | avg_loss: 0.693343 | gender_precise: 47.000000 (364/768) [6/186]
gender_loss: 0.6932075023651123
train_loss: 0.693208 | avg_loss: 0.693323 | gender_precise: 47.000000 (424/896) [7/186]
gender_loss: 0.693169116973877
train_loss: 0.693169 | avg_loss: 0.693304 | gender_precise: 47.000000 (483/1024) [8/186]
gender_loss: 0.693134605884552
train_loss: 0.693135 | avg_loss: 0.693285 

train_loss: 0.694040 | avg_loss: 0.692051 | gender_precise: 52.000000 (4536/8704) [68/186]
gender_loss: 0.6959288120269775
train_loss: 0.695929 | avg_loss: 0.692107 | gender_precise: 52.000000 (4599/8832) [69/186]
gender_loss: 0.6902096271514893
train_loss: 0.690210 | avg_loss: 0.692080 | gender_precise: 52.000000 (4668/8960) [70/186]
gender_loss: 0.6978049278259277
train_loss: 0.697805 | avg_loss: 0.692160 | gender_precise: 52.000000 (4729/9088) [71/186]
gender_loss: 0.6846057772636414
train_loss: 0.684606 | avg_loss: 0.692055 | gender_precise: 52.000000 (4804/9216) [72/186]
gender_loss: 0.6855738759040833
train_loss: 0.685574 | avg_loss: 0.691966 | gender_precise: 52.000000 (4878/9344) [73/186]
gender_loss: 0.6949324011802673
train_loss: 0.694932 | avg_loss: 0.692007 | gender_precise: 52.000000 (4942/9472) [74/186]
gender_loss: 0.6930127143859863
train_loss: 0.693013 | avg_loss: 0.692020 | gender_precise: 52.000000 (5008/9600) [75/186]
gender_loss: 0.695817232131958
train_loss: 0.695

train_loss: 0.693810 | avg_loss: 0.692454 | gender_precise: 51.000000 (8882/17152) [134/186]
gender_loss: 0.6890815496444702
train_loss: 0.689082 | avg_loss: 0.692429 | gender_precise: 51.000000 (8956/17280) [135/186]
gender_loss: 0.6939033269882202
train_loss: 0.693903 | avg_loss: 0.692439 | gender_precise: 51.000000 (9019/17408) [136/186]
gender_loss: 0.6973665952682495
train_loss: 0.697367 | avg_loss: 0.692475 | gender_precise: 51.000000 (9074/17536) [137/186]
gender_loss: 0.6951184272766113
train_loss: 0.695118 | avg_loss: 0.692495 | gender_precise: 51.000000 (9134/17664) [138/186]
gender_loss: 0.6882700324058533
train_loss: 0.688270 | avg_loss: 0.692464 | gender_precise: 51.000000 (9210/17792) [139/186]
gender_loss: 0.6938259601593018
train_loss: 0.693826 | avg_loss: 0.692474 | gender_precise: 51.000000 (9273/17920) [140/186]
gender_loss: 0.6912797689437866
train_loss: 0.691280 | avg_loss: 0.692465 | gender_precise: 51.000000 (9342/18048) [141/186]
gender_loss: 0.6963562369346619


train_loss: 0.691919 | avg_loss: 0.692142 | gender_precise: 52.000000 (868/1664) [13/186]
gender_loss: 0.6894540190696716
train_loss: 0.689454 | avg_loss: 0.691951 | gender_precise: 52.000000 (938/1792) [14/186]
gender_loss: 0.684411883354187
train_loss: 0.684412 | avg_loss: 0.691448 | gender_precise: 52.000000 (1015/1920) [15/186]
gender_loss: 0.6873265504837036
train_loss: 0.687327 | avg_loss: 0.691190 | gender_precise: 53.000000 (1088/2048) [16/186]
gender_loss: 0.6839935779571533
train_loss: 0.683994 | avg_loss: 0.690767 | gender_precise: 53.000000 (1165/2176) [17/186]
gender_loss: 0.693673849105835
train_loss: 0.693674 | avg_loss: 0.690928 | gender_precise: 53.000000 (1230/2304) [18/186]
gender_loss: 0.6894683837890625
train_loss: 0.689468 | avg_loss: 0.690852 | gender_precise: 53.000000 (1300/2432) [19/186]
gender_loss: 0.6927663683891296
train_loss: 0.692766 | avg_loss: 0.690947 | gender_precise: 53.000000 (1366/2560) [20/186]
gender_loss: 0.693443775177002
train_loss: 0.693444 

train_loss: 0.682809 | avg_loss: 0.691224 | gender_precise: 52.000000 (5422/10240) [80/186]
gender_loss: 0.686247706413269
train_loss: 0.686248 | avg_loss: 0.691163 | gender_precise: 52.000000 (5495/10368) [81/186]
gender_loss: 0.6826953887939453
train_loss: 0.682695 | avg_loss: 0.691059 | gender_precise: 53.000000 (5572/10496) [82/186]
gender_loss: 0.6926846504211426
train_loss: 0.692685 | avg_loss: 0.691079 | gender_precise: 53.000000 (5638/10624) [83/186]
gender_loss: 0.6932688355445862
train_loss: 0.693269 | avg_loss: 0.691105 | gender_precise: 53.000000 (5703/10752) [84/186]
gender_loss: 0.6951802372932434
train_loss: 0.695180 | avg_loss: 0.691153 | gender_precise: 52.000000 (5766/10880) [85/186]
gender_loss: 0.6799356937408447
train_loss: 0.679936 | avg_loss: 0.691022 | gender_precise: 53.000000 (5846/11008) [86/186]
gender_loss: 0.6889631152153015
train_loss: 0.688963 | avg_loss: 0.690999 | gender_precise: 53.000000 (5916/11136) [87/186]
gender_loss: 0.6860582232475281
train_los

train_loss: 0.685968 | avg_loss: 0.691177 | gender_precise: 52.000000 (9864/18688) [146/186]
gender_loss: 0.6893224716186523
train_loss: 0.689322 | avg_loss: 0.691164 | gender_precise: 52.000000 (9933/18816) [147/186]
gender_loss: 0.6908865571022034
train_loss: 0.690887 | avg_loss: 0.691162 | gender_precise: 52.000000 (10000/18944) [148/186]
gender_loss: 0.6979867815971375
train_loss: 0.697987 | avg_loss: 0.691208 | gender_precise: 52.000000 (10055/19072) [149/186]
gender_loss: 0.6931210160255432
train_loss: 0.693121 | avg_loss: 0.691221 | gender_precise: 52.000000 (10118/19200) [150/186]
gender_loss: 0.6950445175170898
train_loss: 0.695045 | avg_loss: 0.691246 | gender_precise: 52.000000 (10177/19328) [151/186]
gender_loss: 0.6912137269973755
train_loss: 0.691214 | avg_loss: 0.691246 | gender_precise: 52.000000 (10242/19456) [152/186]
gender_loss: 0.6963440775871277
train_loss: 0.696344 | avg_loss: 0.691279 | gender_precise: 52.000000 (10300/19584) [153/186]
gender_loss: 0.68722736835

train_loss: 0.689675 | avg_loss: 0.687683 | gender_precise: 60.000000 (2015/3328) [26/186]
gender_loss: 0.6835059523582458
train_loss: 0.683506 | avg_loss: 0.687528 | gender_precise: 60.000000 (2090/3456) [27/186]
gender_loss: 0.6833314895629883
train_loss: 0.683331 | avg_loss: 0.687379 | gender_precise: 60.000000 (2164/3584) [28/186]
gender_loss: 0.6872897744178772
train_loss: 0.687290 | avg_loss: 0.687375 | gender_precise: 60.000000 (2235/3712) [29/186]
gender_loss: 0.6788393259048462
train_loss: 0.678839 | avg_loss: 0.687091 | gender_precise: 60.000000 (2319/3840) [30/186]
gender_loss: 0.6916712522506714
train_loss: 0.691671 | avg_loss: 0.687239 | gender_precise: 60.000000 (2387/3968) [31/186]
gender_loss: 0.6816034317016602
train_loss: 0.681603 | avg_loss: 0.687063 | gender_precise: 60.000000 (2461/4096) [32/186]
gender_loss: 0.6876484155654907
train_loss: 0.687648 | avg_loss: 0.687080 | gender_precise: 60.000000 (2538/4224) [33/186]
gender_loss: 0.6844367384910583
train_loss: 0.68

train_loss: 0.603733 | avg_loss: 0.670236 | gender_precise: 62.000000 (7456/11904) [93/186]
gender_loss: 0.6167487502098083
train_loss: 0.616749 | avg_loss: 0.669667 | gender_precise: 62.000000 (7550/12032) [94/186]
gender_loss: 0.6434347629547119
train_loss: 0.643435 | avg_loss: 0.669391 | gender_precise: 62.000000 (7632/12160) [95/186]
gender_loss: 0.624903678894043
train_loss: 0.624904 | avg_loss: 0.668928 | gender_precise: 62.000000 (7722/12288) [96/186]
gender_loss: 0.5962358713150024
train_loss: 0.596236 | avg_loss: 0.668178 | gender_precise: 62.000000 (7818/12416) [97/186]
gender_loss: 0.6144238114356995
train_loss: 0.614424 | avg_loss: 0.667630 | gender_precise: 63.000000 (7905/12544) [98/186]
gender_loss: 0.6244215965270996
train_loss: 0.624422 | avg_loss: 0.667193 | gender_precise: 63.000000 (7989/12672) [99/186]
gender_loss: 0.5553866028785706
train_loss: 0.555387 | avg_loss: 0.666075 | gender_precise: 63.000000 (8089/12800) [100/186]
gender_loss: 0.6009172797203064
train_lo

train_loss: 0.507548 | avg_loss: 0.616467 | gender_precise: 67.000000 (13754/20352) [159/186]
gender_loss: 0.4338093400001526
train_loss: 0.433809 | avg_loss: 0.615325 | gender_precise: 67.000000 (13861/20480) [160/186]
gender_loss: 0.4566800892353058
train_loss: 0.456680 | avg_loss: 0.614340 | gender_precise: 67.000000 (13959/20608) [161/186]
gender_loss: 0.4553655982017517
train_loss: 0.455366 | avg_loss: 0.613359 | gender_precise: 67.000000 (14056/20736) [162/186]
gender_loss: 0.42951643466949463
train_loss: 0.429516 | avg_loss: 0.612231 | gender_precise: 67.000000 (14157/20864) [163/186]
gender_loss: 0.39813148975372314
train_loss: 0.398131 | avg_loss: 0.610925 | gender_precise: 67.000000 (14264/20992) [164/186]
gender_loss: 0.46134722232818604
train_loss: 0.461347 | avg_loss: 0.610019 | gender_precise: 68.000000 (14364/21120) [165/186]
gender_loss: 0.4579216241836548
train_loss: 0.457922 | avg_loss: 0.609103 | gender_precise: 68.000000 (14465/21248) [166/186]
gender_loss: 0.439752

train_loss: 0.375185 | avg_loss: 0.415610 | gender_precise: 81.000000 (4067/4992) [39/186]
gender_loss: 0.3291642367839813
train_loss: 0.329164 | avg_loss: 0.413449 | gender_precise: 81.000000 (4176/5120) [40/186]
gender_loss: 0.342203825712204
train_loss: 0.342204 | avg_loss: 0.411712 | gender_precise: 81.000000 (4283/5248) [41/186]
gender_loss: 0.2798180878162384
train_loss: 0.279818 | avg_loss: 0.408571 | gender_precise: 81.000000 (4398/5376) [42/186]
gender_loss: 0.34360629320144653
train_loss: 0.343606 | avg_loss: 0.407060 | gender_precise: 81.000000 (4507/5504) [43/186]
gender_loss: 0.38674548268318176
train_loss: 0.386745 | avg_loss: 0.406599 | gender_precise: 81.000000 (4614/5632) [44/186]
gender_loss: 0.2982025742530823
train_loss: 0.298203 | avg_loss: 0.404190 | gender_precise: 82.000000 (4725/5760) [45/186]
gender_loss: 0.37577080726623535
train_loss: 0.375771 | avg_loss: 0.403572 | gender_precise: 82.000000 (4830/5888) [46/186]
gender_loss: 0.3592092990875244
train_loss: 0.

train_loss: 0.335300 | avg_loss: 0.370294 | gender_precise: 83.000000 (11342/13568) [106/186]
gender_loss: 0.38431769609451294
train_loss: 0.384318 | avg_loss: 0.370425 | gender_precise: 83.000000 (11447/13696) [107/186]
gender_loss: 0.2924637496471405
train_loss: 0.292464 | avg_loss: 0.369703 | gender_precise: 83.000000 (11557/13824) [108/186]
gender_loss: 0.34408560395240784
train_loss: 0.344086 | avg_loss: 0.369468 | gender_precise: 83.000000 (11665/13952) [109/186]
gender_loss: 0.2321425825357437
train_loss: 0.232143 | avg_loss: 0.368220 | gender_precise: 83.000000 (11781/14080) [110/186]
gender_loss: 0.28845828771591187
train_loss: 0.288458 | avg_loss: 0.367501 | gender_precise: 83.000000 (11891/14208) [111/186]
gender_loss: 0.3432818353176117
train_loss: 0.343282 | avg_loss: 0.367285 | gender_precise: 83.000000 (11998/14336) [112/186]
gender_loss: 0.2802627682685852
train_loss: 0.280263 | avg_loss: 0.366515 | gender_precise: 83.000000 (12108/14464) [113/186]
gender_loss: 0.319698

train_loss: 0.299733 | avg_loss: 0.353268 | gender_precise: 84.000000 (18452/21888) [171/186]
gender_loss: 0.301461786031723
train_loss: 0.301462 | avg_loss: 0.352967 | gender_precise: 84.000000 (18566/22016) [172/186]
gender_loss: 0.3246181905269623
train_loss: 0.324618 | avg_loss: 0.352803 | gender_precise: 84.000000 (18675/22144) [173/186]
gender_loss: 0.35675206780433655
train_loss: 0.356752 | avg_loss: 0.352825 | gender_precise: 84.000000 (18785/22272) [174/186]
gender_loss: 0.25802260637283325
train_loss: 0.258023 | avg_loss: 0.352284 | gender_precise: 84.000000 (18901/22400) [175/186]
gender_loss: 0.35997718572616577
train_loss: 0.359977 | avg_loss: 0.352327 | gender_precise: 84.000000 (19010/22528) [176/186]
gender_loss: 0.3544620871543884
train_loss: 0.354462 | avg_loss: 0.352339 | gender_precise: 84.000000 (19117/22656) [177/186]
gender_loss: 0.3253265917301178
train_loss: 0.325327 | avg_loss: 0.352188 | gender_precise: 84.000000 (19228/22784) [178/186]
gender_loss: 0.3370715

train_loss: 0.380082 | avg_loss: 0.308142 | gender_precise: 86.000000 (5626/6528) [51/186]
gender_loss: 0.23457016050815582
train_loss: 0.234570 | avg_loss: 0.306727 | gender_precise: 86.000000 (5741/6656) [52/186]
gender_loss: 0.3918714225292206
train_loss: 0.391871 | avg_loss: 0.308334 | gender_precise: 86.000000 (5850/6784) [53/186]
gender_loss: 0.3059954047203064
train_loss: 0.305995 | avg_loss: 0.308290 | gender_precise: 86.000000 (5961/6912) [54/186]
gender_loss: 0.25717592239379883
train_loss: 0.257176 | avg_loss: 0.307361 | gender_precise: 86.000000 (6074/7040) [55/186]
gender_loss: 0.27995750308036804
train_loss: 0.279958 | avg_loss: 0.306872 | gender_precise: 86.000000 (6186/7168) [56/186]
gender_loss: 0.2967441976070404
train_loss: 0.296744 | avg_loss: 0.306694 | gender_precise: 86.000000 (6299/7296) [57/186]
gender_loss: 0.289835661649704
train_loss: 0.289836 | avg_loss: 0.306403 | gender_precise: 86.000000 (6410/7424) [58/186]
gender_loss: 0.3865159749984741
train_loss: 0.

train_loss: 0.235792 | avg_loss: 0.296315 | gender_precise: 86.000000 (13014/14976) [117/186]
gender_loss: 0.26432210206985474
train_loss: 0.264322 | avg_loss: 0.296044 | gender_precise: 86.000000 (13126/15104) [118/186]
gender_loss: 0.23833227157592773
train_loss: 0.238332 | avg_loss: 0.295559 | gender_precise: 86.000000 (13240/15232) [119/186]
gender_loss: 0.32182577252388
train_loss: 0.321826 | avg_loss: 0.295778 | gender_precise: 86.000000 (13348/15360) [120/186]
gender_loss: 0.30718323588371277
train_loss: 0.307183 | avg_loss: 0.295872 | gender_precise: 86.000000 (13457/15488) [121/186]
gender_loss: 0.3066820800304413
train_loss: 0.306682 | avg_loss: 0.295961 | gender_precise: 86.000000 (13570/15616) [122/186]
gender_loss: 0.34138813614845276
train_loss: 0.341388 | avg_loss: 0.296330 | gender_precise: 86.000000 (13679/15744) [123/186]
gender_loss: 0.2611689269542694
train_loss: 0.261169 | avg_loss: 0.296047 | gender_precise: 86.000000 (13790/15872) [124/186]
gender_loss: 0.3410722