In [2]:
import os
# Number GPUs by PCI bus order (the same order nvidia-smi uses) and expose only
# device 1 to this process. Both variables must be set before torch (imported
# below) initializes CUDA, otherwise they have no effect.
os.environ.update(CUDA_DEVICE_ORDER='PCI_BUS_ID', CUDA_VISIBLE_DEVICES='1')
from os.path import abspath, dirname, join, basename, isdir
import json
from addict import Dict
import glob
import yaml
import numpy as np
import torch
import torch.nn as nn
import torch.functional as F
from torch.utils.data import DataLoader, DistributedSampler
from torch.distributed import init_process_group, destroy_process_group
import torch.multiprocessing as mp
import timm
import shutil
import time
from dataset import TotalSegmentatorData
from train import Train

In [3]:
# YAML experiment config; this relative path is resolved against the kernel's current working directory.
CONFIG_FILE: str = "config/demo_config.yaml"

In [6]:
# Load the YAML config into an attribute-accessible Dict and expose each section
# as its own notebook-level name. The `with` block closes the file handle, which
# the bare open() call previously leaked on every run of this cell.
# NOTE(review): yaml.Loader can construct arbitrary Python objects from YAML tags.
# That is acceptable for a trusted local file; switch to yaml.safe_load if the
# config could ever come from an untrusted source.
with open(CONFIG_FILE, "r") as config_fh:
    cfgs = Dict(yaml.load(config_fh, Loader=yaml.Loader))
paths = cfgs.paths
data_cfgs = cfgs.dataset_params
optim_cfgs = cfgs.optimizer_params
schedule_cfgs = cfgs.scheduler_params
model_cfgs = cfgs.model_params
train_cfgs = cfgs.train_params
test_cfgs = cfgs.test_params

## Training

In [7]:
def SaveConfigFile(src, paths):
    """Copy the config file into the checkpoint's directory, tagged with the model id.

    The copy is named ``<src stem>_<model id>.yaml``. The model id is the last
    two characters of the checkpoint file name before its first ".". For
    example, ``results/model_01.pt`` gives ``01``.

    Args:
        src: Path of the config file to copy.
        paths: Object exposing a ``model_ckpts_dest`` attribute (the checkpoint path).

    Returns:
        The path of the written copy.
    """
    ckpt_path = paths.model_ckpts_dest
    results_path = dirname(abspath(ckpt_path))
    # Derive the id from the checkpoint *file name* only. Splitting the full path
    # on its first "." gave a wrong id for "./results/..." or dotted directories.
    model_id = basename(ckpt_path).split(".", 1)[0][-2:]
    filename = basename(src).split(".", 1)[0] + "_" + model_id + ".yaml"
    cp_path = join(results_path, filename)

    # exist_ok avoids the isdir/makedirs check-then-act race.
    os.makedirs(results_path, exist_ok=True)
    shutil.copy(src, cp_path)
    return cp_path

def _format_elapsed(seconds):
    """Return ``(value, unit)`` for a duration: hours when over one hour, else minutes."""
    if seconds > 3600.:
        return seconds / 3600., "hr"
    return seconds / 60., "min"


def main(args):
    """Run DDP training for the config in ``CONFIG_FILE`` and print timing.

    Steps:
    1. Clear the CUDA cache (when CUDA is available) and seed torch.
    2. Reload the config from disk and archive a copy via ``SaveConfigFile``.
    3. Call ``Train(cfgs).RunDDP()``.
    4. Print the start time, end time and total duration.

    Args:
        args: Unused; kept for the conventional ``main(args)`` entry-point signature.
    """
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    torch.manual_seed(42)
    # Re-read the config from disk rather than reusing the notebook-level `cfgs`.
    # The `with` block closes the handle, which a bare open() leaked.
    with open(abspath(CONFIG_FILE), "r") as config_fh:
        cfgs = Dict(yaml.load(config_fh, Loader=yaml.Loader))
    SaveConfigFile(CONFIG_FILE, cfgs.paths)

    train = Train(cfgs)
    start = time.time()          # wall clock, used only for the printed timestamps
    t0 = time.perf_counter()     # monotonic clock, used for the elapsed duration
    train.RunDDP()
    time_taken = time.perf_counter() - t0
    end = time.time()
    print("\n")
    start_s = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(start))
    end_s = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(end))
    amount, suffix = _format_elapsed(time_taken)
    print(f"Start Time: {start_s}, End Time: {end_s}, Total Time Taken: {amount:.3f} {suffix}")

## Testing