In [1]:
import numpy as np 
import math
import argparse
import os

import torch
from torch import nn, optim
import torch.utils.tensorboard as tensorboard

from torch.utils.data import DataLoader, Dataset
from torchvision import transforms

from dataset.config import *
from dataset import *
from dataset.config import voc as cfg
from dataset.voc0712 import VOCDetection
from ssd.ssd import build_ssd
from ssd.multiloss import MultiLoss
from transform.augmentation import SSDAugmentation



In [2]:
def train():
    """Run one optimisation pass over ``train_loader``.

    Relies on the module-level ``train_loader``, ``device``, ``model``,
    ``criterion``, ``optimizer``, ``args`` and ``len_train``.

    Each batch is moved to ``device``, pushed through ``model``, and the two
    terms returned by ``criterion`` are combined as
    ``loss_loc + args.alpha * loss_conf`` before a gradient step is taken.

    Returns:
        A ``(total, localisation, confidence)`` tuple: the per-batch loss
        values summed over the pass and divided by ``len_train``.
    """
    total = total_loc = total_conf = 0

    for images, annotations in train_loader:
        batch = images.to(device)
        # Every target is a dict; move each of its values to the same device.
        annotations = [
            {key: value.to(device) for key, value in entry.items()}
            for entry in annotations
        ]

        loc_term, conf_term = criterion(model(batch), annotations)
        combined = loc_term + args.alpha * conf_term

        # Clear stale gradients, back-propagate, then update the weights.
        model.zero_grad()
        combined.backward()
        optimizer.step()

        # Accumulate plain Python floats so no graph is kept alive.
        total_loc += loc_term.item()
        total_conf += conf_term.item()
        total += combined.item()

    return total / len_train, total_loc / len_train, total_conf / len_train

In [3]:
def valid():
    """Measure the loss on ``test_loader`` without touching the weights.

    Relies on the module-level ``test_loader``, ``device``, ``model``,
    ``criterion``, ``args`` and ``len_test``.

    The pass runs under ``torch.no_grad()`` and performs no backward pass or
    optimizer step, so the validation data never influences the parameters.

    Returns:
        A ``(total, localisation, confidence)`` tuple: the per-batch loss
        values summed over the pass and divided by ``len_test``.
    """
    losses, losses_loc, losses_conf = 0, 0, 0

    # NOTE(review): the module's train/eval mode is deliberately left as-is;
    # if the network has dropout/batch-norm and its forward pass does not
    # branch on the mode, wrapping this loop in model.eval() would be
    # preferable -- confirm against the SSD implementation.
    with torch.no_grad():
        for imgs, targets in test_loader:
            img = imgs.to(device)
            # Every target is a dict; move each of its values to the device.
            targets = [{k: y.to(device) for k, y in t.items()} for t in targets]

            outputs = model(img)
            loss_loc, loss_conf = criterion(outputs, targets)
            loss = loss_loc + args.alpha * loss_conf

            # Evaluation only: no zero_grad / backward / optimizer.step here.
            losses_loc += loss_loc.item()
            losses_conf += loss_conf.item()
            losses += loss.item()

    return losses / len_test, losses_loc / len_test, losses_conf / len_test

In [4]:
def init_weight(model):
    """Re-initialise every ``nn.Conv2d`` and ``nn.Linear`` inside ``model``.

    Conv2d weights get Xavier-uniform values; Linear weights are drawn from a
    normal distribution with mean 0 and std 0.01.  Biases are set to zero when
    the layer has one; layers created with ``bias=False`` are left without a
    bias instead of raising.

    Args:
        model: An ``nn.Module``; ``model.modules()`` is walked, so the module
            itself and all of its descendants are covered.
    """
    for m in model.modules():
        if isinstance(m, nn.Conv2d):
            nn.init.xavier_uniform_(m.weight)
            # bias is None for layers built with bias=False.
            if m.bias is not None:
                nn.init.zeros_(m.bias)

        elif isinstance(m, nn.Linear):
            nn.init.normal_(m.weight, 0, .01)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

In [5]:
def param_decay(epoch):
    """Learning-rate multiplier for ``optim.lr_scheduler.LambdaLR``.

    ``LambdaLR`` multiplies the optimizer's initial learning rate by the value
    returned here, so this must be a *factor*, not an absolute rate.
    Returning ``args.lr`` itself (as before) squared the configured rate.

    Args:
        epoch: Zero-based epoch counter supplied by the scheduler.

    Returns:
        ``1.0`` for epochs below 100, ``0.1`` for epochs 100-149 and ``0.01``
        from epoch 150 onwards.
    """
    if epoch < 100:
        return 1.
    elif epoch < 150:
        return .1
    else:
        return .01

In [None]:
if __name__=='__main__':
    # Entry point: parse options, build data/model/optimizer, then alternate
    # train() and valid() once per epoch while logging and checkpointing.
    import json  # local import: only this entry-point block needs it

    def _str2bool(text):
        # argparse's type=bool treats every non-empty string (even "False")
        # as True, so parse the usual spellings explicitly.
        return text.lower() in ('1', 'true', 'yes', 'y')

    parser = argparse.ArgumentParser()
    parser.add_argument('-root', type=str, default='/Users/miyasatotakaya/Datasets/pascalvoc/VOCdevkit/')
    # NOTE: -gpu is parsed but not consulted below; the device is chosen from
    # torch.cuda.is_available() alone.
    parser.add_argument('-gpu', type=_str2bool, default=False, help='use gpu or not')
    parser.add_argument('-lr', type=float, default=.01, help='initial learning rate')
    parser.add_argument('-alpha', type=float, default=1., help='weight of the confidence loss')
    # The loaders used a hard-coded batch size of 4 and ignored this flag; the
    # default is now 4 so the out-of-the-box behaviour is unchanged.
    parser.add_argument('-batch', type=int, default=4, help='batch size for dataloader')
    parser.add_argument('-size', type=int, default=32, help='image size for datasets')
    parser.add_argument('-ch', type=int, default=3, help='input channels')
    parser.add_argument('-class_num', type=int, default=10, help='data class')
    parser.add_argument('-epoch', type=int, default=200, help='training epoch')
    parser.add_argument('-resume', type=str, help='load weight')
    parser.add_argument('-save', type=str, default='save/ckpt.pth', help='saved models')
    parser.add_argument('-fmodel', type=str, default='save/fmodel', help='final saved models')
    parser.add_argument('-lr_scheduler')
    args = parser.parse_args()

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    os.makedirs('save', exist_ok=True)
    # Also make sure the directories of user-supplied output paths exist.
    for out_path in (args.save, args.fmodel):
        out_dir = os.path.dirname(out_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)

    ## Loading_data
    train_data = VOCDetection(args.root, transform=SSDAugmentation(cfg['min_dim'], MEANS), image_sets=[('2012', 'train')])
    test_data = VOCDetection(args.root, transform=SSDAugmentation(cfg['min_dim'], MEANS), image_sets=[('2012', 'val')])

    train_loader = DataLoader(train_data, batch_size=args.batch, shuffle=True, num_workers=2, collate_fn=detection_collate)
    test_loader = DataLoader(test_data, batch_size=args.batch, shuffle=False, num_workers=2, collate_fn=detection_collate)

    # train()/valid() divide their accumulated losses by these sample counts.
    len_train = len(train_data)
    len_test = len(test_data)

    ## Build model
    model = build_ssd('train', cfg['min_dim'], cfg['class_num'])

    ## load weights or init weights
    if args.resume:
        # map_location lets a checkpoint written on GPU be restored on CPU.
        checkpoints = torch.load(args.resume, map_location=device)
        weights = checkpoints['weights']
        # The stored value is the index of the last *completed* epoch, so
        # continue with the one after it instead of repeating it.
        start = checkpoints['epoch'] + 1
        max_loss = checkpoints['loss']
        model.load_state_dict(weights)

    else:
        model.apply(init_weight)

        start = 0
        max_loss = math.inf

    model = model.to(device)

    ##  setting
    criterion = MultiLoss(num_classes=cfg['class_num'], device=device)
    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=.9)
    # A fresh scheduler counts from 0, so offset by `start` to keep the decay
    # schedule aligned with the real epoch number after a resume.
    scheduler = optim.lr_scheduler.LambdaLR(optimizer=optimizer, lr_lambda=lambda e: param_decay(start + e))

    ## training model
    writer = tensorboard.SummaryWriter()
    history = []  # one dict of scalars per epoch, dumped to JSON at the end
    try:
        for epoch in range(start, args.epoch):
            tl, tl_loc, tl_conf = train()
            vl, vl_loc, vl_conf = valid()
            # Advance the schedule once per epoch; without this the learning
            # rate never decays.
            scheduler.step()

            print(epoch + 1, tl, tl_loc, tl_conf, vl, vl_loc, vl_conf)

            scalars = {'train/loss': tl, 'train/loss_loc': tl_loc, 'train/loss_conf': tl_conf,
                       'val/loss': vl, 'val/loss_loc': vl_loc, 'val/loss_conf': vl_conf}
            writer.add_scalars('data/loss', scalars, epoch + 1)
            history.append(dict(scalars, epoch=epoch + 1))

            torch.save({ 'weights': model.state_dict(), 'loss': vl, 'epoch': epoch}, args.save)
            if max_loss > vl:
                # Track the best validation loss so far; otherwise the "best"
                # checkpoint would be overwritten on every epoch.
                max_loss = vl
                torch.save({ 'weights': model.state_dict(), 'loss': vl, 'epoch': epoch}, args.fmodel + '.pth')

        ##  loss history
        # torch.utils.tensorboard.SummaryWriter has no export_scalars_to_json,
        # so the per-epoch scalars are written out directly.
        with open(args.fmodel + '.json', 'w') as fp:
            json.dump(history, fp, indent=2)
    finally:
        writer.close()

[{'boxes': tensor([[0.2200, 0.1365, 0.8271, 0.9948]]), 'labels': tensor([14])}, {'boxes': tensor([[0.3534, 0.4939, 0.4007, 0.7475],
        [0.3981, 0.4839, 0.4875, 0.7497],
        [0.4325, 0.4861, 0.4867, 0.7653],
        [0.3646, 0.5217, 0.4282, 0.7620],
        [0.2941, 0.4939, 0.3594, 0.7631],
        [0.3052, 0.5239, 0.3663, 0.7709]]), 'labels': tensor([14, 14, 14, 14, 14, 14])}, {'boxes': tensor([[0.6203, 0.3220, 1.0000, 0.8452],
        [0.6842, 0.0588, 1.0000, 0.3406]]), 'labels': tensor([14, 15])}, {'boxes': tensor([[0.0000, 0.6970, 0.3854, 1.0000],
        [0.9101, 0.9129, 0.9465, 1.0000]]), 'labels': tensor([ 5, 14])}]
torch.Size([4, 512, 40, 40])
torch.Size([4, 1024, 20, 20])
torch.Size([4, 512, 10, 10])
torch.Size([4, 256, 5, 5])
torch.Size([4, 256, 3, 3])
torch.Size([4, 256, 1, 1])
torch.Size([4, 6400, 21]) torch.Size([4, 6400, 4])
torch.Size([4, 2400, 21]) torch.Size([4, 2400, 4])
torch.Size([4, 600, 21]) torch.Size([4, 600, 4])
torch.Size([4, 150, 21]) torch.Size([4, 1