In [1]:
%matplotlib inline
%reload_ext autoreload
%autoreload 2

In [2]:
from imports import *

from datasets import OpenImagesDataset
from retinanet import RetinaNet
from loss import FocalLoss
from utils.torch_utils import save_checkpoint, AverageMeter

In [3]:
# Input pipeline: convert the image to a tensor, then normalise each channel
# with the commonly used ImageNet RGB mean / std values.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(
            mean=(0.485, 0.456, 0.406),
            std=(0.229, 0.224, 0.225),
        ),
    ]
)

In [4]:
# Training split. The list file presumably holds the per-image box annotations
# (inferred from its name) — confirm against OpenImagesDataset.
train_dset = OpenImagesDataset(
    root='./data/train',
    list_file='./data/tmp/train_images_bbox.csv',
    transform=transform,
    train=True,
    input_size=600,
)

In [5]:
# The dataset's own collate_fn builds each batch; train_one_epoch unpacks the
# batches it produces as (inputs, loc_targets, cls_targets).
train_loader = data.DataLoader(
    train_dset,
    batch_size=3,
    shuffle=True,
    num_workers=8,
    collate_fn=train_dset.collate_fn,
)

In [6]:
# Build the detector, restore its saved weights, and move model + loss to the GPU.
net = RetinaNet()
# map_location='cpu' restores the tensors on the CPU first, so the checkpoint loads
# regardless of which device it was saved from; net.cuda() below moves them to the GPU.
net.load_state_dict(torch.load('./model/net.pth', map_location='cpu'))
criterion = FocalLoss()
net.cuda()
criterion.cuda()
# Create the optimizer only after net.cuda(), so it holds the GPU parameters.
optimizer = optim.SGD(net.parameters(), lr=1e-4, momentum=0.9, weight_decay=1e-4)

In [7]:
def train_one_epoch(train_loader, model, loss_fn, opt, epoch, interval, checkpoint_interval=5000):
    """Train ``model`` for one pass over ``train_loader``.

    Each batch is moved to the GPU, run through ``model`` and ``loss_fn``, and one
    optimizer step is taken. Progress is printed every ``interval`` batches, and a
    checkpoint is written every ``checkpoint_interval`` batches (batch 0 included).

    Args:
        train_loader: iterable of ``(inputs, loc_targets, cls_targets)`` batches that
            also supports ``len()`` (e.g. a ``DataLoader``).
        model: network called as ``model(inputs)`` and returning ``(loc_preds, cls_preds)``.
        loss_fn: called as ``loss_fn(loc_preds, loc_targets, cls_preds, cls_targets)``;
            must return a single-element tensor.
        opt: optimizer for ``model``'s parameters.
        epoch: epoch index, stored in checkpoints and used in their file names.
        interval: print progress every ``interval`` batches.
        checkpoint_interval: save a checkpoint every ``checkpoint_interval`` batches.

    Returns:
        Mean training loss over the epoch, or 0.0 if the loader yielded no batches.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()

    model.train()
    train_loss = 0.0
    num_seen = 0
    # len() gives the exact batch count; int(num_samples / batch_size) + 1
    # over-counted whenever the sample count divided evenly.
    num_batches = len(train_loader)

    end = time.time()
    for batch_idx, (inputs, loc_targets, cls_targets) in enumerate(train_loader):
        # Time spent waiting for the loader to produce this batch.
        data_time.update(time.time() - end)

        inputs = inputs.cuda()
        loc_targets = loc_targets.cuda()
        cls_targets = cls_targets.cuda()

        opt.zero_grad()
        loc_preds, cls_preds = model(inputs)
        loss = loss_fn(loc_preds, loc_targets, cls_preds, cls_targets)
        loss.backward()
        opt.step()

        # Whole iteration: data wait + forward/backward/step.
        batch_time.update(time.time() - end)
        end = time.time()

        # .item() turns the 0-dim loss tensor into a Python float; indexing it with
        # .data[0] raises an error on current PyTorch.
        train_loss += loss.item()
        num_seen = batch_idx + 1
        avg_loss = train_loss / num_seen

        if batch_idx % interval == 0:
            # Implicit concatenation: a backslash inside the f-string used to embed
            # the next line's indentation into every log line.
            print(f'Train -> Batch : [{batch_idx}/{num_batches}]| Batch avg time :{batch_time.avg} '
                  f'| Data_avg_time: {data_time.avg} | avg_loss: {avg_loss}')

        if batch_idx % checkpoint_interval == 0:
            # Save the objects passed in, not the notebook globals `net` / `optimizer`.
            save_checkpoint({
                'epoch': epoch,
                'state_dict': model.state_dict(),
                # NOTE(review): despite the key name this is the running *training*
                # loss; the key is kept so existing checkpoint readers keep working.
                'best_val_loss': avg_loss,
                'optimizer': opt.state_dict(),
            }, is_best=True, fname=f'checkpoint_{epoch}_{batch_idx}.pth.tar')
            # NOTE(review): is_best is always True, as before; how save_checkpoint
            # uses it is not visible here.

    return train_loss / num_seen if num_seen else 0.0

In [None]:
# Run one epoch (epoch index 0), printing progress every 100 batches.
train_one_epoch(train_loader, net, criterion, optimizer, epoch=0, interval=100)

  "See the documentation of nn.Upsample for details.".format(mode))


Train -> Batch : [0/558327]| Batch avg time :2.4711053371429443             | Data_avg_time: 1.101675271987915 | avg_loss: 4.198276042938232
Train -> Batch : [100/558327]| Batch avg time :1.2880447264945154             | Data_avg_time: 0.02347696653687128 | avg_loss: 6.6026411056518555
Train -> Batch : [200/558327]| Batch avg time :1.2801134693088816             | Data_avg_time: 0.01420973426667019 | avg_loss: 6.595991611480713
Train -> Batch : [300/558327]| Batch avg time :1.2787376519453486             | Data_avg_time: 0.011103416994164552 | avg_loss: 6.398797512054443
Train -> Batch : [400/558327]| Batch avg time :1.2777136222382732             | Data_avg_time: 0.009557149059457374 | avg_loss: 6.264936447143555
Train -> Batch : [500/558327]| Batch avg time :1.2783410006654476             | Data_avg_time: 0.00865890784653837 | avg_loss: 6.201728820800781
Train -> Batch : [600/558327]| Batch avg time :1.2772861479126079             | Data_avg_time: 0.008036359177651302 | avg_loss: 6.1