In [3]:
import argparse
import better_exceptions
from pathlib import Path
from collections import OrderedDict
from tqdm import tqdm
import numpy as np
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
from torch.optim.lr_scheduler import StepLR
import torch.utils.data
from torch.utils.data import DataLoader
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
import pretrainedmodels
import pretrainedmodels.utils
from model import get_model
from dataset import FaceDataset
from defaults import _C as cfg
from train import train, AverageMeter, validate

In [7]:
# --- Run configuration -------------------------------------------------------
# Everything a reader may want to tune for a run lives in this cell.
model = get_model()
data_dir = 'appa-real-release'
start_epoch = 0                      # first epoch index the training loop will run
checkpoint_dir = Path('checkpoint')  # improved checkpoints are written here
checkpoint_dir.mkdir(parents=True, exist_ok=True)
resume_path = None                   # path to a .pth checkpoint to resume from, or None
tensorboard_dir = None               # TensorBoard log root, or None to disable logging
opts = []                            # extra option strings; joined into the log-dir prefix and printed at the end
multi_gpu = False

# --- Optional resume ---------------------------------------------------------
# `checkpoint` is always bound (None when not resuming) so later cells can test
# it without risking a NameError.
checkpoint = None

if resume_path:
    if Path(resume_path).is_file():
        print("=> loading checkpoint '{}'".format(resume_path))
        # NOTE: torch.load unpickles the file -- only resume from checkpoints you trust.
        # Loaded onto the CPU; the model is moved to its target device later.
        checkpoint = torch.load(resume_path, map_location="cpu")
        start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        print("=> loaded checkpoint '{}' (epoch {})"
              .format(resume_path, checkpoint['epoch']))
        # The optimizer has not been constructed at this point of the notebook,
        # so its state cannot be restored here (doing so raised a NameError on a
        # fresh kernel). checkpoint['optimizer_state_dict'] stays available via
        # `checkpoint` for use once the optimizer exists.
    else:
        print("=> no checkpoint found at '{}'".format(resume_path))

# --- Hyper-parameters --------------------------------------------------------
img_size = 56          # passed to FaceDataset as img_size
age_stddev = 1.0       # passed to the training FaceDataset as age_stddev
batch_size = 128
learning_rate = 1e-3   # initial Adam learning rate
step_size = 20         # StepLR: number of epochs between learning-rate decays
decay_rate = 0.2       # StepLR: multiplicative decay factor (gamma)
num_epochs = 80

In [8]:
# --- Data --------------------------------------------------------------------
# Training data is augmented and shuffled; the last incomplete batch is dropped.
train_dataset = FaceDataset(data_dir, "train", img_size=img_size, augment=True, age_stddev=age_stddev)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)

# Validation data is neither augmented nor shuffled, and every sample is kept.
val_dataset = FaceDataset(data_dir, "valid", img_size=img_size, augment=False)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, drop_last=False)

# --- Model / optimizer -------------------------------------------------------
# Move the model to its device *before* building the optimizer, so that the
# optimizer (and any state restored into it below) lives on the same device as
# the parameters it updates.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
criterion = nn.CrossEntropyLoss().to(device)

# When resuming, restore the optimizer state now that the optimizer exists.
# The checkpoint is read on the CPU; load_state_dict places the restored state
# on the device of the matching parameters.
# NOTE: torch.load unpickles the file -- only resume from checkpoints you trust.
if resume_path and Path(resume_path).is_file():
    resumed_state = torch.load(resume_path, map_location="cpu")
    optimizer.load_state_dict(resumed_state['optimizer_state_dict'])
    del resumed_state  # free the extra copy of the checkpoint

# last_epoch=start_epoch - 1 makes the schedule continue from the resumed epoch
# (it is -1, i.e. a fresh schedule, when start_epoch is 0).
scheduler = StepLR(optimizer, step_size=step_size, gamma=decay_rate, last_epoch=start_epoch - 1)
best_val_mae = 10000.0  # sentinel: any real validation MAE is expected to beat it

# --- TensorBoard (optional) --------------------------------------------------
# Both writers are always bound so later cells can reference them safely.
train_writer = None
val_writer = None

if tensorboard_dir is not None:
    opts_prefix = "_".join(opts)
    # Use the notebook's own `tensorboard_dir`; the previous `args.tensorboard`
    # referred to an `args` object that does not exist in this notebook.
    train_writer = SummaryWriter(log_dir=str(tensorboard_dir) + "/" + opts_prefix + "_train")
    val_writer = SummaryWriter(log_dir=str(tensorboard_dir) + "/" + opts_prefix + "_val")

In [None]:
# Main loop: one training pass and one validation pass per epoch, optional
# TensorBoard logging, a checkpoint whenever the validation MAE improves, and
# one scheduler step per epoch.
# The try/finally guarantees the TensorBoard writers are closed (flushing any
# buffered events) even if the run is interrupted or fails part-way through.
try:
    for epoch in range(start_epoch, num_epochs):
        # train -- returns (loss, accuracy) for the epoch
        train_loss, train_acc = train(train_loader, model, criterion, optimizer, epoch, device)

        # validate -- returns (loss, accuracy, mean absolute error)
        val_loss, val_acc, val_mae = validate(val_loader, model, criterion, epoch, device)

        if tensorboard_dir is not None:
            train_writer.add_scalar("loss", train_loss, epoch)
            train_writer.add_scalar("acc", train_acc, epoch)
            val_writer.add_scalar("loss", val_loss, epoch)
            val_writer.add_scalar("acc", val_acc, epoch)
            val_writer.add_scalar("mae", val_mae, epoch)

        # checkpoint -- only when the validation MAE beats the best seen so far
        if val_mae < best_val_mae:
            print(f"=> [epoch {epoch:03d}] best val mae was improved from {best_val_mae:.3f} to {val_mae:.3f}")
            # With multi_gpu the weights are read from model.module
            # (presumably a DataParallel-style wrapper -- confirm where the model is wrapped).
            model_state_dict = model.module.state_dict() if multi_gpu else model.state_dict()
            torch.save(
                {
                    # stored as epoch + 1: the index of the next epoch to run
                    'epoch': epoch + 1,
                    'arch': cfg.MODEL.ARCH,
                    'state_dict': model_state_dict,
                    'optimizer_state_dict': optimizer.state_dict()
                },
                str(checkpoint_dir.joinpath("epoch{:03d}_{:.5f}_{:.4f}.pth".format(epoch, val_loss, val_mae)))
            )
            best_val_mae = val_mae
        else:
            print(f"=> [epoch {epoch:03d}] best val mae was not improved from {best_val_mae:.3f} ({val_mae:.3f})")

        # adjust learning rate (after the checkpoint, so the saved optimizer
        # state still holds the learning rate used for this epoch)
        scheduler.step()
finally:
    # Writers only exist when TensorBoard logging is enabled.
    if tensorboard_dir is not None:
        train_writer.close()
        val_writer.close()

print("=> training finished")
print(f"additional opts: {opts}")
print(f"best val mae: {best_val_mae:.3f}")

100%|████████| 31/31 [03:17<00:00,  6.36s/it, stage=train, epoch=0, loss=0.0331, acc=0.0358, correct=4, sample_num=128]
100%|████████████| 12/12 [00:26<00:00,  2.17s/it, stage=val, epoch=0, loss=0.036, acc=0.0333, correct=5, sample_num=92]


=> [epoch 000] best val mae was improved from 10000.000 to 11.110


100%|████████| 31/31 [03:28<00:00,  6.71s/it, stage=train, epoch=1, loss=0.0316, acc=0.0381, correct=1, sample_num=128]
100%|████████████| 12/12 [00:24<00:00,  2.06s/it, stage=val, epoch=1, loss=0.0317, acc=0.034, correct=3, sample_num=92]


=> [epoch 001] best val mae was improved from 11.110 to 10.569


100%|████████| 31/31 [03:15<00:00,  6.32s/it, stage=train, epoch=2, loss=0.0307, acc=0.0446, correct=6, sample_num=128]
100%|█████████████| 12/12 [00:24<00:00,  2.07s/it, stage=val, epoch=2, loss=0.0306, acc=0.05, correct=5, sample_num=92]


=> [epoch 002] best val mae was improved from 10.569 to 9.841


100%|██████████| 31/31 [03:16<00:00,  6.35s/it, stage=train, epoch=3, loss=0.03, acc=0.0431, correct=8, sample_num=128]
100%|███████████| 12/12 [00:25<00:00,  2.13s/it, stage=val, epoch=3, loss=0.0304, acc=0.0507, correct=2, sample_num=92]


=> [epoch 003] best val mae was improved from 9.841 to 8.723


100%|████████| 31/31 [03:14<00:00,  6.27s/it, stage=train, epoch=4, loss=0.0295, acc=0.0507, correct=8, sample_num=128]
100%|███████████| 12/12 [00:25<00:00,  2.10s/it, stage=val, epoch=4, loss=0.0293, acc=0.0513, correct=2, sample_num=92]


=> [epoch 004] best val mae was improved from 8.723 to 7.873


100%|████████| 31/31 [03:15<00:00,  6.30s/it, stage=train, epoch=5, loss=0.0293, acc=0.0547, correct=8, sample_num=128]
100%|███████████| 12/12 [00:24<00:00,  2.06s/it, stage=val, epoch=5, loss=0.0295, acc=0.0553, correct=5, sample_num=92]
  0%|                                                                                           | 0/31 [00:00<?, ?it/s]

=> [epoch 005] best val mae was not improved from 7.873 (8.007)


100%|████████| 31/31 [03:13<00:00,  6.23s/it, stage=train, epoch=6, loss=0.0288, acc=0.0585, correct=7, sample_num=128]
100%|████████████| 12/12 [00:24<00:00,  2.06s/it, stage=val, epoch=6, loss=0.0293, acc=0.058, correct=3, sample_num=92]


=> [epoch 006] best val mae was improved from 7.873 to 7.852


100%|████████| 31/31 [03:13<00:00,  6.23s/it, stage=train, epoch=7, loss=0.0285, acc=0.0585, correct=9, sample_num=128]
100%|███████████| 12/12 [00:24<00:00,  2.06s/it, stage=val, epoch=7, loss=0.0289, acc=0.0707, correct=5, sample_num=92]


=> [epoch 007] best val mae was improved from 7.852 to 7.464


100%|████████| 31/31 [03:12<00:00,  6.22s/it, stage=train, epoch=8, loss=0.0282, acc=0.0602, correct=7, sample_num=128]
100%|███████████| 12/12 [00:24<00:00,  2.05s/it, stage=val, epoch=8, loss=0.0287, acc=0.0693, correct=3, sample_num=92]
  0%|                                                                                           | 0/31 [00:00<?, ?it/s]

=> [epoch 008] best val mae was not improved from 7.464 (7.671)


100%|████████| 31/31 [03:12<00:00,  6.22s/it, stage=train, epoch=9, loss=0.0279, acc=0.0628, correct=8, sample_num=128]
100%|███████████| 12/12 [00:24<00:00,  2.06s/it, stage=val, epoch=9, loss=0.0285, acc=0.0713, correct=6, sample_num=92]
  0%|                                                                                           | 0/31 [00:00<?, ?it/s]

=> [epoch 009] best val mae was not improved from 7.464 (7.603)


100%|███████| 31/31 [03:29<00:00,  6.75s/it, stage=train, epoch=10, loss=0.0278, acc=0.0559, correct=6, sample_num=128]
100%|██████████| 12/12 [00:28<00:00,  2.39s/it, stage=val, epoch=10, loss=0.0287, acc=0.0567, correct=2, sample_num=92]
  0%|                                                                                           | 0/31 [00:00<?, ?it/s]

=> [epoch 010] best val mae was not improved from 7.464 (7.871)


 71%|████▉  | 22/31 [02:33<01:02,  6.94s/it, stage=train, epoch=11, loss=0.0277, acc=0.0661, correct=8, sample_num=128]