"""Train a RAFT model on the KITTI dataset with EzFlow.

Running this file as a script builds the training and validation
dataloaders, creates the model, and starts the trainer.
"""

from ezflow.models import build_model
from ezflow.data import DataloaderCreator
from ezflow.engine import get_training_cfg
from ezflow.engine import Trainer

# Root directory of the KITTI dataset on the local machine.
DATASET_PATH = "/home/s/Dataset/KITTI_2015/"


def main():
    """Build the dataloaders, model and trainer, then run training."""
    # Settings shared by the training and validation dataloader creators.
    loader_kwargs = {
        "batch_size": 16,
        "num_workers": 1,
        "pin_memory": True,
    }

    # Instantiate the dataloader creators; only the training data is shuffled.
    train_loader_creator = DataloaderCreator(shuffle=True, **loader_kwargs)
    val_loader_creator = DataloaderCreator(shuffle=False, **loader_kwargs)

    # Add dataset(s) to the dataloader creators.
    # Training samples are randomly cropped and augmented.
    train_loader_creator.add_Kitti(
        root_dir=DATASET_PATH,
        split="training",
        crop=True,
        crop_size=(384, 384),
        crop_type="random",
        augment=True,
        aug_params={
            "color_aug_params": {"aug_prob": 0.3, "contrast": 0.5},
            "spatial_aug_params": {"aug_prob": 0.2, "flip": True},
        },
    )
    # Validation samples are neither cropped nor augmented.
    val_loader_creator.add_Kitti(
        root_dir=DATASET_PATH,
        split="validation",
        crop=False,
        augment=False,
    )

    # Create the dataloaders.
    train_loader = train_loader_creator.get_dataloader()
    val_loader = val_loader_creator.get_dataloader()

    # Create the model.
    model = build_model("RAFT", default=True)

    # Create the trainer from configs/trainers.
    training_cfg = get_training_cfg(cfg_path="base.yaml", custom=False)
    training_cfg.CKPT_DIR = "./checkpoints"

    trainer = Trainer(
        cfg=training_cfg,
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
    )

    # Train.
    trainer.train()


# NOTE(review): the loaders request num_workers=1; presumably they are backed
# by worker processes (confirm against DataloaderCreator). On platforms that
# start workers by re-importing the main module, training must not run at
# import time, hence the guard.
if __name__ == "__main__":
    main()