# 3.2 Use PVN3D on the Pose Estimation Dataset

## TODO 1:
- [X] Downsample the images
- [X] Fix PSPNet padding so the output always has the same size as the input (120×160 works)
- [ ] Downsample the point cloud: project the point cloud onto the original image to get each point's corresponding pixel.


## TODO 2:
- [ ] Make sure PSPNet trains
- [ ] Make sure PointNet trains
- [ ] Make sure PVN3D trains
- [ ] Make sure the loss is correct

In [1]:
import warnings
warnings.filterwarnings("ignore")  # NOTE: hides *all* warnings, including useful ones

from utils.utils_data import *
from lib.pvn3d import *
import logging
from datetime import datetime

# Path to a saved state_dict to resume from, or None to train from scratch.
# `p["checkpoint"]` below mirrors this value so there is only one place to set it.
checkpoint = None

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# device = torch.device('cpu')
# H, W = 360, 640     # Image height and width
H, W = 120, 160       # Working (resized) image height and width
# NOTE(review): this dict shadows any `transforms` module name that the star
# imports above may bring in (e.g. torchvision.transforms) — confirm unused.
transforms = {
    # 'rgb'  : Compose([Resize((H, W)), RandomHorizontalFlip(), ColorJitter()]),
    'rgb'  : Compose([Resize((H, W))]),
    'depth': Compose([Resize((H, W))]),
    'label': Compose([Resize((H, W))])
}
p = {
    "device": device,
    'bz': 2, 'shuffle': False, 'num_workers': 1,  # 1. For loader ('bz' presumably batch size) TODO: Modify this
    'objects': 82, 'keypoints': 16, 'samplepoints': H*W,    # For PVN3D model
    "epochs": 100,  "lr": 1e-5, 'decay': 0,       # 2. For learning
    "scale": 0.5,   "amp": False,                 # 3. Grad scaling
    "checkpoint": checkpoint,                     # same value as the variable above
}

# 1. Initialize Network and Dataloader
# Sizes are read from `p` so the model, the dataset and the logged config cannot
# drift apart when the hyper-parameters above are edited.
model = PVN3D(
    num_classes=p['objects'],
    pcd_input_c=3,                  # logged below as "number of pcd augment channels"
    num_kps=p['keypoints'],
    num_points=p['samplepoints'],
)
model.train()
# NOTE(review): the model is not moved to `device` here — presumably
# train_pvn3d does this via p["device"]; confirm.
# NOTE(review): trains on the 'val_tiny' split, presumably as a quick debug run.
loader_train = get_loader(
    SegDataset(split_name='val_tiny', transforms=transforms,
               one_hot_label=False, N_samples=p['samplepoints']),
    params=p,
)


''' ===== Logging ===== '''
# Load the list of objects present in the training data.
# NOTE(review): pickle.load can execute arbitrary code — only load trusted files.
with open('./training_data/objects_present.pickle', 'rb') as f:
    OBJECTS = list(pickle.load(f))  # [(i_d1, name_1), ..., (id_n, name_n)]
logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
logging.info(f'Using device {device}')
# Author noted 23 objects here while p['objects'] is 82 — presumably 82 covers the
# full id range rather than the objects present; confirm.
logging.info(f'There are [{len(OBJECTS)}] objects in total')
logging.info(f'Network:\n'
                f'\t{model.num_kps}     number of keypoints\n'
                f'\t{model.num_points}  number of sample points\n'
                f'\t{model.pcd_input_c} number of pcd augment channels\n'
                f'\t{model.num_classes} output channels (classes)\n')
if checkpoint:
    # map_location lets weights saved on a GPU load on a CPU-only machine (and
    # vice versa); without it torch.load restores tensors to their saved device.
    model.load_state_dict(torch.load(checkpoint, map_location=device))
    logging.info(f'Resuming from checkpoint {checkpoint}')
''' ===== End Logging ===== '''


# 2. Train model, save checkpoint
import os

# No ':' in the timestamp: it is not a legal filename character on Windows, and
# a failed save here would throw away the whole training run.
now = datetime.now().strftime("%d-%H-%M")
epochs = p['epochs']
# torch.save does not create missing directories.
os.makedirs('./exp', exist_ok=True)
try:
    train_pvn3d(model, loader_train, p)
except KeyboardInterrupt:
    # Manual stop: keep whatever the model has learned so far.
    fname = f'./exp/pvn3d_weight_INTERRUPTED_{now}_.pt'
else:
    fname = f'./exp/pvn3d_weight_epochs{epochs}_{now}.pt'
torch.save(model.state_dict(), fname)
logging.info(f'Saved weights to {fname}')


## 1. Try PointNet++ from [torch-points3d](https://torch-points3d.readthedocs.io/en/latest/index.html)