## Import required packages and functions

In [1]:
import os
import torch
from datetime import datetime
import argparse
from torch.utils.tensorboard import SummaryWriter

from DatasetClass import CarlaUnsupervised
from train import train, run_val
from utils_train import get_dataloaders, setup_logger

## Define Command Line Arguments + Define Setup Function (can be left unchanged)

In [6]:
def parse():
    """Build the command-line argument parser used for training.

    Every help text uses argparse's ``%(default)s`` placeholder, so the
    default shown by ``--help`` is always the value actually configured here
    and cannot drift out of sync with the code.

    Returns:
        argparse.ArgumentParser: the configured parser. Call ``parse_args``
        on it to obtain the namespace.
    """
    parser = argparse.ArgumentParser()

    # optimisation settings
    parser.add_argument("--lr", default=5e-4, type=float, help="Learning rate - default: %(default)s")
    parser.add_argument("--batch_size", default=1, type=int, help="Batch size - default: %(default)s")
    parser.add_argument("--epochs", default=10, type=int, help="Number of epochs - default: %(default)s")
    # NOTE(review): parsed as float although the default is an int - confirm what the consumer expects
    parser.add_argument("--patience", default=6, type=float, help="Patience - default: %(default)s")
    parser.add_argument("--lr_scheduler_factor", default=0.5, type=float, help="Learning rate multiplier - default: %(default)s")

    # loss settings
    parser.add_argument("--alpha", default=0.25, type=float, help="Focal loss alpha - default: %(default)s")
    parser.add_argument("--gamma", default=2.0, type=float, help="Focal loss gamma - default: %(default)s")
    parser.add_argument("--l_M", default=0.005, type=float, help="hyper-param for motion seg loss - default: %(default)s")
    parser.add_argument("--l_C", default=0.3, type=float, help="hyper-param for consensus loss - default: %(default)s")
    parser.add_argument("--l_S", default=1.0, type=float, help="hyper-param for regularization - default: %(default)s")

    # checkpointing and data settings
    parser.add_argument("--load_chkpt", '-chkpt', default='0', type=str, help="Loading entire checkpoint path for inference/continue training - default: %(default)s")
    parser.add_argument("--dataset_fraction", default=0.001, type=float, help="fraction of dataset to be used - default: %(default)s")
    return parser

def train_setup(args):
    """Prepare logging, data loaders and the TensorBoard writer for a run.

    Reads ``args.root``, ``args.batch_size``, ``args.lr``, ``args.epochs``,
    ``args.patience``, ``args.lr_scheduler_factor``, ``args.alpha``,
    ``args.gamma`` and ``args.device``. Sets ``args.root_tb``, ``args.now``
    and ``args.writer`` on the namespace; ``setup_logger`` returns the
    namespace it was given (possibly modified) together with the logger.

    Args:
        args: namespace holding the parsed options plus ``root`` and
            ``device``, which the caller must set beforehand.

    Returns:
        tuple: ``(args, logger, train_loader, val_loader)``. The third loader
        returned by ``get_dataloaders`` is discarded.
    """
    data_root = os.path.join(args.root, "datasets/CARLA/")
    log_root = os.path.join(args.root, "logs/")
    root_tb = os.path.join(args.root, "runs_temp/")
    args.root_tb = root_tb

    # define string needed for logging
    args.now = datetime.now()
    # Format only the timestamp with strftime and append the hyper-parameters
    # afterwards, so their text is never interpreted as strftime directives.
    timestamp = args.now.strftime("%d-%m-%Y_%H-%M")
    now_string = f"{timestamp}_{args.batch_size}_{args.lr}_{args.epochs}"

    # setup logging
    args, logger = setup_logger(args, log_root, now_string)

    # log general info
    logger.info(f"running with lr={args.lr}, batch_size={args.batch_size}, epochs={args.epochs}, patience={args.patience}, lr_scheduler_factor={args.lr_scheduler_factor} alpha={args.alpha}, gamma={args.gamma}")
    logger.info(f"running on '{args.device}'")

    # define dataset and get data loaders
    # test=True kwarg needed for plotting ground truth in tensorboard
    dataset = CarlaUnsupervised(data_root, test=True)
    # the test loader is not used by this setup and is dropped here
    train_loader, val_loader, _ = get_dataloaders(dataset, args)

    # initialize tensorboard; the run directory shares the name of the log string
    args.writer = SummaryWriter(os.path.join(root_tb, now_string))

    return args, logger, train_loader, val_loader

## Train with default args (note the very small default `dataset_fraction`)

In [7]:
# Parse with an empty argument list so only the parser defaults are used
# (the notebook kernel's own sys.argv must not be parsed).
args = parse().parse_args([])
args.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# Project root holding datasets/, logs/ and runs_temp/. It can be overridden
# through the MOTION_SEG_ROOT environment variable so the notebook also runs
# on machines where the default storage path does not exist.
args.root = os.environ.get("MOTION_SEG_ROOT", "/storage/remote/atcremers40/motion_seg/")

args, logger, train_loader, val_loader = train_setup(args)
# No test loader is passed (None) for this short training run.
train(args, train_loader, val_loader, None, logger)

[INFO] running with lr=0.0005, batch_size=1, epochs=10, patience=6, lr_scheduler_factor=0.5 alpha=0.25, gamma=2.0
[INFO] running on 'cuda:0'
[INFO] loaded model of type: <class 'ModelClass.UNET'>
train network ...
[INFO] Epoch [1/10] with lr 0.0005, train loss: 2356.51831, val loss: 2533.63501, IoU: 0.0919, ETA: 0.01 hrs
[INFO] Epoch [2/10] with lr 0.0005, train loss: 2137.64221, val loss: 2561.77686, IoU: 0.0919, ETA: 0.01 hrs
[INFO] Epoch [3/10] with lr 0.0005, train loss: 1879.51855, val loss: 2500.43164, IoU: 0.09785, ETA: 0.01 hrs
[INFO] Epoch [4/10] with lr 0.0005, train loss: 1651.33173, val loss: 3617.70996, IoU: 0.0925, ETA: 0.01 hrs
[INFO] Epoch [5/10] with lr 0.0005, train loss: 1457.88318, val loss: 10790.37793, IoU: 0.09295, ETA: 0.0 hrs
[INFO] Epoch [6/10] with lr 0.0005, train loss: 1348.54993, val loss: 11014.74902, IoU: 0.09504, ETA: 0.0 hrs
[INFO] Epoch [7/10] with lr 0.0005, train loss: 1220.72748, val loss: 3682.89233, IoU: 0.10493, ETA: 0.0 hrs
[INFO] Epoch [8/10] 