-
Notifications
You must be signed in to change notification settings - Fork 0
/
main_carrada.py
147 lines (124 loc) · 5.96 KB
/
main_carrada.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
import argparse
import yaml
import os
import torch
from torch.utils.data import DataLoader, ConcatDataset
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, LearningRateMonitor
from pytorch_lightning.loggers import TensorBoardLogger
from utils import update_config_dict, get_transformations, get_models
from datasets import SequenceCarradaDataset, CarradaDataset, Carrada, CarradaDatasetRangeDoppler, CarradaDatasetRangeAngle
def parse_args():
    """Parse the command-line options for the MV-RECORD training script.

    Returns:
        argparse.Namespace with attributes ``config`` (path to the YAML
        configuration), ``deterministic`` (bool flag) and ``resume_ckpt``
        (optional checkpoint path to resume from).
    """
    parser = argparse.ArgumentParser(description='MV-RECORD')
    # Data-driven registration keeps flag definitions in one place.
    options = (
        ('--config', dict(type=str, help='configuration file path')),
        ('--deterministic', dict(action='store_true',
                                 help='Apply deterministic CUDA ops for reproducibility')),
        ('--resume_ckpt', dict(type=str,
                               help='Path to the checkpoint to resume the training')),
    )
    for flag, kwargs in options:
        parser.add_argument(flag, **kwargs)
    return parser.parse_args()
# --- CLI arguments, seeding and configuration ---------------------------
args = parse_args()
# Honor the --deterministic flag: the original hard-coded False here,
# silently ignoring the CLI option declared in parse_args().
deterministic = args.deterministic
seed = 42  # fixed seed so runs are reproducible
pl.seed_everything(seed=seed, workers=True)
# Context manager ensures the config file handle is closed after parsing.
with open(args.config, 'r') as cfg_file:
    config = yaml.load(cfg_file, Loader=yaml.FullLoader)
config = update_config_dict(config, args)
model_cfg = config['model_cfg']
train_cfg = config['train_cfg']
dataset_cfg = config['dataset_cfg']
n_frames = model_cfg['win_size']  # number of consecutive frames per sample
# Load model
model_instance = get_models(model_cfg)
model_name = model_cfg['name']
# Load datasets: pick the per-sequence dataset class matching the model's view(s).
if model_name == 'RECORD-RD':
    dataset_loader = CarradaDatasetRangeDoppler
elif model_name == 'RECORD-RA':
    dataset_loader = CarradaDatasetRangeAngle
elif model_name in ('MV-RECORD', 'MV-RECORD-S'):
    dataset_loader = CarradaDataset
else:
    # Message added: a bare ValueError gave no hint about the offending config.
    raise ValueError(f'Unsupported model name for dataset selection: {model_name}')
# Train dataset: one frame-level dataset per CARRADA sequence, concatenated.
train_dataset = Carrada(config).get('Train')
# batch_size=1 here: this loader only enumerates sequences, not frames.
seq_dataloader = DataLoader(SequenceCarradaDataset(train_dataset), batch_size=1,
                            shuffle=True, num_workers=0)
all_datasets = []
transform_names = config['train_cfg']['transformations'].split(',')
transformations = get_transformations(transform_names=transform_names, sizes=(config['model_cfg']['w_size'], config['model_cfg']['h_size']))
# enumerate() dropped: its index was never used.
for data in seq_dataloader:
    seq_name, seq = data
    path_to_frames = os.path.join(dataset_cfg['carrada'], seq_name[0])
    all_datasets.append(dataset_loader(seq,
                                       'dense',
                                       path_to_frames,
                                       process_signal=True, transformations=transformations,
                                       n_frames=n_frames, add_temp=True))
train_dataloader = DataLoader(ConcatDataset(all_datasets), batch_size=train_cfg['batch_size'], shuffle=True,
                              num_workers=0, pin_memory=True)
# Val dataset: same per-sequence construction as training, but no shuffling
# and no `transformations` — presumably augmentation is train-only; confirm.
val_dataset = Carrada(config).get('Validation')
seq_dataloader = DataLoader(SequenceCarradaDataset(val_dataset), batch_size=1,
                            shuffle=False, num_workers=0)
all_datasets = []
# enumerate() dropped: its index was never used.
for data in seq_dataloader:
    seq_name, seq = data
    path_to_frames = os.path.join(dataset_cfg['carrada'], seq_name[0])
    all_datasets.append(dataset_loader(seq,
                                       'dense',
                                       path_to_frames,
                                       process_signal=True,
                                       n_frames=n_frames, add_temp=True))
val_dataloader = DataLoader(ConcatDataset(all_datasets), batch_size=train_cfg['batch_size'], shuffle=False,
                            num_workers=0, pin_memory=True)
# Logger
log_dir = train_cfg['ckpt_dir']
# exist_ok=True is idempotent and avoids the check-then-create race of the
# original `if not os.path.exists(...)` guard.
os.makedirs(log_dir, exist_ok=True)
logger = TensorBoardLogger(save_dir=train_cfg['ckpt_dir'], name=model_name, default_hp_metric=False)
# Add some entries to the configuration dict to get the logs
run_dir = logger.experiment.log_dir
config['train_cfg']['run_dir'] = run_dir
# Load the PyTorch Lightning executor wrapping the raw model, and choose
# the validation metric that checkpointing/early-stopping will monitor.
# NOTE(review): 'MV-RECORD-S' is accepted by the dataset dispatch above but
# has no executor branch here, so it raises — confirm whether it should map
# to MVRECORDExecutor as well.
if model_name == 'MV-RECORD':
    from executors import MVRECORDExecutor as Model
    model = Model(config, model=model_instance)
    metric_to_monitor = 'val_metrics/global_prec'
elif model_name in ('RECORD-RD', 'RECORD-RA'):
    from executors import SVRECORDExecutor as Model
    model = Model(config, model=model_instance, view=model_cfg['view'])
    metric_to_monitor = 'val_metrics/rd_prec' if model_cfg['view'] == 'range_doppler' else 'val_metrics/ra_prec'
else:
    # Message added: a bare ValueError gave no hint about the offending config.
    raise ValueError(f'No executor available for model name: {model_name}')
# Callbacks: keep the top-3 checkpoints (plus last) on the monitored metric,
# track the learning rate, and stop after 20 epochs without improvement.
best_ckpt = ModelCheckpoint(dirpath=run_dir, monitor=metric_to_monitor, mode="max",
                            save_last=True, save_top_k=3)
stopper = EarlyStopping(monitor=metric_to_monitor, patience=20, mode='max')
callbacks = [best_ckpt, LearningRateMonitor(), stopper]
# Hardware selection: single device, GPU when available.
if torch.cuda.is_available():
    print('CUDA available, use GPU')
    accelerator = 'gpu'
else:
    print('WARNING: CUDA not available, use CPU')
    accelerator = 'cpu'
trainer = pl.Trainer(logger=logger, callbacks=callbacks, accelerator=accelerator, devices=1,
                     max_epochs=train_cfg['n_epoch'],
                     accumulate_grad_batches=train_cfg['accumulate_grad'],
                     # Wire the `deterministic` setting into the Trainer; it
                     # was computed above but never used (PL default is False,
                     # so this is backward-compatible when the flag is unset).
                     deterministic=deterministic)
print('Start training')
# ckpt_path=None (the default when --resume_ckpt is omitted) starts fresh.
trainer.fit(model, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader, ckpt_path=args.resume_ckpt)
print('Test model')
# Test dataset: mirrors the validation pipeline (no transformations).
test_dataset = Carrada(config).get('Test')
seq_dataloader = DataLoader(SequenceCarradaDataset(test_dataset), batch_size=1,
                            shuffle=False, num_workers=4)
all_datasets = []
# enumerate() dropped: its index was never used.
for data in seq_dataloader:
    seq_name, seq = data
    path_to_frames = os.path.join(dataset_cfg['carrada'], seq_name[0])
    all_datasets.append(dataset_loader(seq,
                                       'dense',
                                       path_to_frames,
                                       process_signal=True,
                                       n_frames=n_frames, add_temp=True))
test_dataloader = DataLoader(ConcatDataset(all_datasets), batch_size=train_cfg['batch_size'], shuffle=False,
                             num_workers=4, pin_memory=True)
# ckpt_path='best' evaluates the top checkpoint tracked by ModelCheckpoint.
trainer.test(model, dataloaders=test_dataloader, ckpt_path='best')