# Imports

In [None]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = "0"  # Limit PyTorch to seeing 1 GPU only.
import yaml

import src.data_loader as data_loader
import src.graph_construction as gc
import src.graph_networks as gn
import src.merge_split_networks as msn
import src.delete_network as delnet
import src.network_config as nc
import src.train as train

# Tabletop Object Dataset

In [None]:
# Read the Tabletop Object Dataset (TOD) data-loading options from YAML.
# `yaml.safe_load` is used instead of a bare `yaml.load(f)`: calling `yaml.load`
# without a Loader raises a TypeError on PyYAML >= 6.0 (and warns on 5.1+), and
# safe_load only constructs plain Python types, which is all a config file needs.
with open('configs/TOD.yaml', 'r') as f:
    TOD_data_loading_config = yaml.safe_load(f)

# Build the dataloader used by every training cell below. The config's
# 'TOD_filepath' entry is passed as the first argument, followed by the full
# config dict itself.
dl = data_loader.get_TOD_train_dataloader(
    TOD_data_loading_config['TOD_filepath'],
    TOD_data_loading_config,
    batch_size=1,
    num_workers=8,
    shuffle=True
)

# Joint Training of SplitNet + DeleteNet

In [None]:
# SplitNet: network config paired with its joint-training config.
splitnet_config, splitnet_train_config = (
    nc.get_splitnet_config('configs/splitnet.yaml'),
    nc.get_splitnet_train_config('configs/splitnet_joint_training.yaml'),
)

# DeleteNet: network config paired with its joint-training config.
deletenet_config, deletenet_train_config = (
    nc.get_deletenet_config('configs/deletenet.yaml'),
    nc.get_deletenet_train_config('configs/deletenet_joint_training.yaml'),
)

In [None]:
# ResNet50+FPN model; its trainable layers are named by the SplitNet training config.
# The same `rn50_fpn` object is handed to both trainers below.
rn50_fpn = gc.get_resnet50_fpn_model(pretrained=True, trainable_layer_names=splitnet_train_config['trainable_layer_names'])

# SplitNet: wrapper built from its config, then a trainer around wrapper + rn50_fpn.
sn_wrapper = msn.SplitNetWrapper(splitnet_config)
sn_trainer = train.SplitNetTrainer(
    sn_wrapper,
    rn50_fpn,
    splitnet_train_config,
)

# DeleteNet: wrapper built from its config, then a trainer around wrapper + rn50_fpn.
dn_wrapper = delnet.DeleteNetWrapper(deletenet_config)
dn_trainer = train.DeleteNetTrainer(
    dn_wrapper,
    rn50_fpn,
    deletenet_train_config,
)

In [None]:
# Checkpoint locations for resuming a previous run (all left blank by default).
load_config = dict(
    opt_filename='',  # trainer checkpoint path
    splitnet_wrapper_filename='',  # SplitNet checkpoint path
    deletenet_wrapper_filename='',  # DeleteNet checkpoint path
    rn50_fpn_filename='',  # ResNet50+FPN checkpoint path
)

# Joint trainer over the SplitNet and DeleteNet wrappers/trainers.
# To resume training, additionally pass `load_config` as the fifth positional argument.
trainer = train.JointSplitNetDeleteNetTrainer(sn_wrapper, sn_trainer, dn_wrapper, dn_trainer)

In [None]:
# Run joint SplitNet + DeleteNet training on the TOD dataloader `dl`, then save.
# NOTE(review): where `save()` writes is decided inside the trainer — presumably
# from the training configs; confirm before relying on the output location.
num_epochs = 10
trainer.train(num_epochs, dl)
trainer.save()

# SGS-Net

In [None]:
# SGS-Net: network config paired with its training config.
sgsnet_config, sgsnet_training_config = (
    nc.get_sgsnet_config('configs/sgsnet.yaml'),
    nc.get_sgsnet_train_config('configs/sgsnet_training.yaml'),
)

In [None]:
# Load models
# A new ResNet50+FPN is created here, rebinding `rn50_fpn`: after this cell the
# name no longer refers to the instance given to the SplitNet/DeleteNet trainers.
# No `trainable_layer_names` is passed, so the function's default applies.
rn50_fpn = gc.get_resnet50_fpn_model(pretrained=True)
# SGS-Net wrapper built from its config, then a trainer around wrapper + rn50_fpn.
sgsnet_wrapper = gn.SGSNetWrapper(sgsnet_config)
sgsnet_trainer = train.SGSNetTrainer(sgsnet_wrapper, rn50_fpn, sgsnet_training_config)

In [None]:
# Run SGS-Net training on the same TOD dataloader `dl`, then save.
# Note: this rebinds `num_epochs`, which the joint-training cell set to 10.
num_epochs = 3
sgsnet_trainer.train(num_epochs, dl)
sgsnet_trainer.save()