In [1]:
import torch
import time
import tqdm

import fast_rcnn
import datasets
import config
import utils

def _to_device(images, targets, device):
    """Move one batch of images and per-image target dicts onto ``device``.

    Args:
        images: Iterable of image tensors for the batch.
        targets: Iterable of dicts (one per image) whose values are tensors.
        device: Destination accepted by ``Tensor.to``.

    Returns:
        A ``(images, targets)`` pair of new lists whose tensors live on ``device``.
    """
    images = [image.to(device) for image in images]
    targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
    return images, targets


def _total_loss(loss_dict):
    """Return the sum of every loss term in ``loss_dict`` as a single tensor."""
    return sum(loss for loss in loss_dict.values())


if __name__ == '__main__':

    train_dataset = datasets.get_train_dataset()
    train_data_loader = datasets.create_train_loader(train_dataset)
    valid_dataset = datasets.get_test_dataset()
    valid_data_loader = datasets.create_test_loader(valid_dataset)

    model = fast_rcnn.create_model(num_classes=config.NUM_CLASSES)
    model = model.to(config.DEVICE)
    # Only optimize parameters that are not frozen.
    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.SGD(params, lr=0.0005, momentum=0.9, weight_decay=0.0005)

    # Per-epoch running averages (reset at the start of every epoch).
    train_loss_hist = utils.Averager()
    val_loss_hist = utils.Averager()
    save_best_model = utils.SaveBestModel()

    # Per-iteration losses across all epochs, used for the loss plot.
    train_loss_list = []
    val_loss_list = []

    # NOTE(review): not used in this cell; kept in case later notebook cells
    # reference it.
    MODEL_NAME = 'FasterRCNN_ResNet50'

    for epoch in range(config.NUM_EPOCHS):
        epoch_num = epoch + 1  # 1-based number for every user-facing message
        print(f"\nEPOCH {epoch_num} of {config.NUM_EPOCHS}")

        train_loss_hist.reset()
        val_loss_hist.reset()

        start = time.time()

        # ---- training ----
        print('Training')
        progress_bar = tqdm.tqdm(train_data_loader, total=len(train_data_loader))
        for images, targets in progress_bar:
            optimizer.zero_grad()
            images, targets = _to_device(images, targets, config.DEVICE)

            loss_dict = model(images, targets)
            losses = _total_loss(loss_dict)
            loss_value = losses.item()
            train_loss_list.append(loss_value)
            train_loss_hist.send(loss_value)

            losses.backward()
            optimizer.step()

            # Show the latest batch loss beside the progress bar.
            progress_bar.set_description(desc=f"Loss: {loss_value:.4f}")

        # ---- validation ----
        # model.eval() is intentionally not called: this loop needs
        # model(images, targets) to return a loss dict. Presumably
        # fast_rcnn.create_model builds a torchvision detection model, which
        # returns losses only in training mode (detections in eval mode) --
        # TODO confirm. torch.no_grad() still disables gradient tracking.
        print('Validating')
        progress_bar = tqdm.tqdm(valid_data_loader, total=len(valid_data_loader))
        for images, targets in progress_bar:
            images, targets = _to_device(images, targets, config.DEVICE)

            with torch.no_grad():
                loss_dict = model(images, targets)
                losses = _total_loss(loss_dict)

            loss_value = losses.item()
            val_loss_list.append(loss_value)
            val_loss_hist.send(loss_value)

            progress_bar.set_description(desc=f"Loss: {loss_value:.4f}")

        print(f"Epoch #{epoch_num} train loss: {train_loss_hist.value:.3f}")
        print(f"Epoch #{epoch_num} validation loss: {val_loss_hist.value:.3f}")

        end = time.time()
        # Previously printed the 0-based index, disagreeing with the lines above.
        print(f"Took {((end - start) / 60):.3f} minutes for epoch {epoch_num}")

        # The utils save helpers are passed the 0-based index; their printed
        # output ("Saving best model for epoch: 1" after the first epoch)
        # suggests they add 1 themselves.
        save_best_model(
            val_loss_hist.value, epoch, model, optimizer
        )
        utils.save_model(epoch, model, optimizer)
        utils.save_loss_plot(config.OUT_DIR, train_loss_list, val_loss_list)

        # NOTE(review): fixed 2-second pause between epochs; its purpose is
        # not evident from this cell -- consider removing.
        time.sleep(2)


EPOCH 1 of 3
Training


Loss: 0.1625: 100%|██████████| 25/25 [02:55<00:00,  7.01s/it]


Validating


Loss: 0.2602: 100%|██████████| 25/25 [01:48<00:00,  4.32s/it]


Epoch #1 train loss: 0.333
Epoch #1 validation loss: 0.223
Took 4.722 minutes for epoch 0

Best validation loss: 0.22314531683921815

Saving best model for epoch: 1

SAVING PLOTS COMPLETE...

EPOCH 2 of 3
Training


Loss: 0.2462: 100%|██████████| 25/25 [02:58<00:00,  7.15s/it]


Validating


Loss: 0.2038: 100%|██████████| 25/25 [01:47<00:00,  4.31s/it]


Epoch #2 train loss: 0.203
Epoch #2 validation loss: 0.180
Took 4.776 minutes for epoch 1

Best validation loss: 0.17974496513605118

Saving best model for epoch: 2

SAVING PLOTS COMPLETE...

EPOCH 3 of 3
Training


Loss: 0.2608: 100%|██████████| 25/25 [02:59<00:00,  7.18s/it]


Validating


Loss: 0.2017: 100%|██████████| 25/25 [01:47<00:00,  4.29s/it]


Epoch #3 train loss: 0.203
Epoch #3 validation loss: 0.185
Took 4.780 minutes for epoch 2
SAVING PLOTS COMPLETE...
