# Multimodal model
This notebook lets you instantiate a model and run a forward pass on a multimodal sample. It has one prerequisite:
- The S3DIS dataset must be downloaded and preprocessed. If it is not, instantiating the dataset will do both for you. You may edit this code to load any other multimodal dataset you have on your machine.

In [None]:
# Select your GPU
I_GPU = 0

In [None]:
# Uncomment to use autoreload
# %load_ext autoreload
# %autoreload 2

import os
import sys
import numpy as np
import torch
from time import time
from omegaconf import OmegaConf
start = time()
import warnings
warnings.filterwarnings('ignore')

torch.cuda.set_device(I_GPU)
DIR = os.path.dirname(os.getcwd())
ROOT = os.path.join(DIR, "..")
sys.path.insert(0, ROOT)
sys.path.insert(0, DIR)

from torch_points3d.utils.config import hydra_read
from torch_points3d.core.multimodal.data import MMBatch
from torch_points3d.datasets.segmentation.multimodal.s3dis import S3DISFusedDataset
from torch_points3d.models.model_factory import instantiate_model

## Dataset and model configuration

The dataset and model configurations are parsed in the following cell using Hydra. For the **multimodal semantic segmentation** task, dataset configs live in `conf/data/segmentation/multimodal` and model configs live in `conf/models/segmentation/multimodal`. You can create a new model there and run a forward pass on it in this notebook to debug it.

For now, supported multimodal datasets are [S3DIS](http://buildingparser.stanford.edu/dataset.html), [ScanNet](http://www.scan-net.org/) and [KITTI-360](http://www.cvlibs.net/datasets/kitti-360), while supported multimodal architectures are based on [MinkowskiNet](https://arxiv.org/abs/1904.08755)-like backbones. It would be relatively straightforward, using Torch-Points3D, to extend the same models to other backbones such as [PointNet++](https://arxiv.org/abs/1706.02413), [KP-Conv](https://arxiv.org/abs/1904.08889), etc.
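
Hydra overrides are `key=value` strings in which a dotted key addresses a nested config entry. The sketch below is a simplified, standard-library illustration of that idea only. `apply_overrides` is a hypothetical helper, not part of Hydra or Torch-Points3D, and it treats every value as a plain string, whereas in real Hydra an override on a config group such as `data=...` selects a config file.

```python
def apply_overrides(config, overrides):
    """Apply 'dotted.key=value' strings to a nested dict, in place."""
    for item in overrides:
        key, value = item.split("=", 1)
        *parents, leaf = key.split(".")
        node = config
        # Walk down the nested dicts, creating missing levels on the way
        for name in parents:
            node = node.setdefault(name, {})
        node[leaf] = value
    return config


# Hypothetical config and paths, for illustration only
cfg_sketch = apply_overrides(
    {"data": {"dataroot": "unset"}},
    ["task=segmentation", "data.dataroot=/tmp/s3dis", "model_name=my_model"],
)
print(cfg_sketch)
```

The later override wins over the default, which is why `data.dataroot` can be set from the notebook without editing the dataset config file.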

In [None]:
# Set your dataset root directory, where the data was/will be downloaded
DATA_ROOT = '/path/to/your/dataset/root/directory'

config_file = 'segmentation/multimodal/s3disfused-sparse'  # dataset config, S3DIS here 
models_config = 'segmentation/multimodal/sparseconv3d'     # family of models based on sparseconv3d backbone
model_name = 'Res16UNet34-L4-early-ade20k-interpolate'     # name of the specific model we want to use

overrides = [
    'task=segmentation',
    f'data={config_file}',
    f'models={models_config}',
    f'model_name={model_name}',
    f'data.dataroot={DATA_ROOT}',
]

cfg = hydra_read(overrides)
# print(OmegaConf.to_yaml(cfg))

## Dataset creation

The dataset will now be created. If you have not downloaded or preprocessed the dataset before, both steps will run here, which takes some time. Otherwise, it should load within a few seconds.

In [None]:
# Dataset instantiation
start = time()
dataset = S3DISFusedDataset(cfg.data)
# print(dataset)
print(f"Time = {time() - start:0.1f} sec.")

## Model creation

The following cell will instantiate the model, based on the config. In Torch-Points3D, instantiating a model often requires information about the dataset (*e.g.* the number of classes). For this reason, `instantiate_model` requires the `dataset` to be passed as input.
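
To illustrate why the dataset is needed at instantiation time, here is a toy, framework-free sketch. `ToyDataset`, `build_layer_shapes` and `count_parameters` are hypothetical names, not Torch-Points3D API, and the feature dimension and hidden sizes are arbitrary. The only point is that the input and output widths of the network come from the dataset rather than from the model config.

```python
from dataclasses import dataclass


@dataclass
class ToyDataset:
    # Hypothetical stand-in for a dataset object exposing its dimensions
    feature_dimension: int
    num_classes: int


def build_layer_shapes(hidden_sizes, dataset):
    """Return (in, out) shapes of a fully connected stack sized from the dataset."""
    sizes = [dataset.feature_dimension, *hidden_sizes, dataset.num_classes]
    return list(zip(sizes[:-1], sizes[1:]))


def count_parameters(layer_shapes):
    """Count weights plus biases for each (in, out) linear layer."""
    return sum(n_in * n_out + n_out for n_in, n_out in layer_shapes)


# 13 is the number of S3DIS semantic classes; 6 input features is arbitrary
toy_dataset = ToyDataset(feature_dimension=6, num_classes=13)
layer_shapes = build_layer_shapes([32, 64], toy_dataset)
print(layer_shapes)
print(count_parameters(layer_shapes))
```

Changing `num_classes` changes both the last layer shape and the parameter count, which is the same dependency that makes `instantiate_model` take the dataset as an argument.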

In [None]:
# Model instantiation
print(f"Model: {cfg.model_name}")
model = instantiate_model(cfg, dataset)
model = model.train().cuda()
n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Model parameters : {n_params / 10**6:0.1f} M")
# print(model)

## Forward pass on a multimodal batch

We can now create a batch of `batch_size` multimodal samples from `dataset` and run a forward and backward pass on `model`. This can help us debug the model before launching a full training experiment.
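
One quick sanity check when debugging is to compare the first loss value against what a classifier with no preference between classes would produce. The sketch below assumes the segmentation loss is a plain, unweighted cross-entropy; class weights or ignored labels would shift the value, and a freshly initialized model is only approximately uniform. `cross_entropy` here is a small standard-library helper written for this illustration, not the loss used by the model.

```python
import math


def cross_entropy(logits, target):
    """Cross-entropy of one sample from raw logits (numerically stable)."""
    peak = max(logits)
    log_sum_exp = peak + math.log(sum(math.exp(x - peak) for x in logits))
    return log_sum_exp - logits[target]


num_classes = 13  # number of S3DIS semantic classes
uniform_logits = [0.0] * num_classes  # no preference between classes
baseline = cross_entropy(uniform_logits, target=3)
print(f"Baseline loss for {num_classes} classes: {baseline:.3f}")
```

The baseline equals `log(num_classes)`. A first loss far above that value may point to a problem in the labels or in the loss setup, under the assumptions stated above.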

In [None]:
batch_size = 4

# Create a batch of multimodal samples
print(f"\nBatch creation")
batch = MMBatch.from_mm_data_list([dataset.train_dataset[i] for i in range(batch_size)])
# print(batch)

# Set some model attributes based on the input batch. Moves batch to 
# device
print(f"\nForward pass")
model.set_input(batch, model.device)

# Forward pass. Output will be stored in model attributes
batch = model(batch)

# Loss belongs to the model attributes and is automatically computed 
# when running forward pass
print(f"\nLoss")
print(model.loss_seg)

# Backward pass
print(f"\nBackward pass")
model.backward()

print(f"\nOK")

del batch