Skip to content

Commit

Permalink
Merge pull request #745 from CCInc/api_refactor2
Browse files Browse the repository at this point in the history
 Refactor common ModelFactory methods
  • Loading branch information
CCInc committed Apr 4, 2022
2 parents 61ef9cb + f3ed2be commit 297f0e0
Show file tree
Hide file tree
Showing 8 changed files with 46 additions and 143 deletions.
21 changes: 2 additions & 19 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,9 @@ repos:
- id: check-yaml

- repo: https://github.com/psf/black
rev: 22.1.0
rev: 22.3.0
hooks:
- id: black
language_version: python3.8

- repo: https://github.com/PyCQA/autoflake
rev: v1.4
Expand All @@ -27,25 +26,9 @@ repos:
"--ignore-init-module-imports",
"--imports=torch,torch_geometric,torch_scatter,torch_cluster,numpy,sklearn,scipy,torch_sparse,torch_points_kernels",
]

- repo: https://github.com/kynan/nbstripout
rev: 0.5.0
hooks:
- id: nbstripout
files: ".ipynb"

- repo: local
hooks:
- id: requirements.txt
name: Generate requirements.txt
entry: poetry export
args:
[
"-f",
"requirements.txt",
"-o",
"requirements.txt",
"--without-hashes",
]
pass_filenames: false
language: system
files: "poetry.lock"
3 changes: 0 additions & 3 deletions test/test_api.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,9 @@
from torch_points3d.datasets.object_detection.scannet import ScannetDataset
from torch_points3d.core.data_transform import GridSampling3D
from torch_points3d.applications.pretrained_api import PretainedRegistry
from test.mockdatasets import MockDatasetGeometric, MockDataset
import os
import sys
import unittest
import torch
from omegaconf import OmegaConf

ROOT = os.path.join(os.path.dirname(os.path.realpath(__file__)), "..")
sys.path.insert(0, ROOT)
Expand Down
24 changes: 2 additions & 22 deletions torch_points3d/applications/kpconv.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import os
from omegaconf import DictConfig, OmegaConf
import logging

Expand All @@ -11,11 +10,6 @@
from torch_points3d.core.common_modules.base_modules import MLP
from .utils import extract_output_nc


CUR_FILE = os.path.realpath(__file__)
DIR_PATH = os.path.dirname(os.path.realpath(__file__))
PATH_TO_CONFIG = os.path.join(DIR_PATH, "conf/kpconv")

log = logging.getLogger(__name__)


Expand Down Expand Up @@ -50,24 +44,10 @@ def KPConv(

class KPConvFactory(ModelFactory):
def _build_unet(self):
if self._config:
model_config = self._config
else:
path_to_model = os.path.join(PATH_TO_CONFIG, "unet_{}.yaml".format(self.num_layers))
model_config = OmegaConf.load(path_to_model)
ModelFactory.resolve_model(model_config, self.num_features, self._kwargs)
modules_lib = sys.modules[__name__]
return KPConvUnet(model_config, None, None, modules_lib, **self.kwargs)
return self._build_unet_base(KPConvUnet, "conf/kpconv", __name__)

def _build_encoder(self):
if self._config:
model_config = self._config
else:
path_to_model = os.path.join(PATH_TO_CONFIG, "encoder_{}.yaml".format(self.num_layers))
model_config = OmegaConf.load(path_to_model)
ModelFactory.resolve_model(model_config, self.num_features, self._kwargs)
modules_lib = sys.modules[__name__]
return KPConvEncoder(model_config, None, None, modules_lib, **self.kwargs)
return self._build_encoder_base(KPConvEncoder, "conf/kpconv", __name__)


class BaseKPConv(UnwrappedUnetBasedModel):
Expand Down
27 changes: 2 additions & 25 deletions torch_points3d/applications/minkowski.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import os
import sys
from omegaconf import DictConfig, OmegaConf
import logging
import torch
Expand All @@ -15,10 +13,6 @@
from .utils import extract_output_nc


CUR_FILE = os.path.realpath(__file__)
DIR_PATH = os.path.dirname(os.path.realpath(__file__))
PATH_TO_CONFIG = os.path.join(DIR_PATH, "conf/sparseconv3d")

log = logging.getLogger(__name__)


Expand Down Expand Up @@ -56,27 +50,10 @@ def Minkowski(

class MinkowskiFactory(ModelFactory):
def _build_unet(self):
if self._config:
model_config = self._config
else:
path_to_model = os.path.join(PATH_TO_CONFIG, "unet_{}.yaml".format(self.num_layers))
model_config = OmegaConf.load(path_to_model)
ModelFactory.resolve_model(model_config, self.num_features, self._kwargs)
modules_lib = sys.modules[__name__]
return MinkowskiUnet(model_config, None, None, modules_lib, **self.kwargs)
return self._build_unet_base(MinkowskiUnet, "conf/sparseconv3d", __name__)

def _build_encoder(self):
if self._config:
model_config = self._config
else:
path_to_model = os.path.join(
PATH_TO_CONFIG,
"encoder_{}.yaml".format(self.num_layers),
)
model_config = OmegaConf.load(path_to_model)
ModelFactory.resolve_model(model_config, self.num_features, self._kwargs)
modules_lib = sys.modules[__name__]
return MinkowskiEncoder(model_config, None, None, modules_lib, **self.kwargs)
return self._build_encoder_base(MinkowskiEncoder, "conf/sparseconv3d", __name__)


class BaseMinkowski(UnwrappedUnetBasedModel):
Expand Down
33 changes: 32 additions & 1 deletion torch_points3d/applications/modelfactory.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
from enum import Enum
from omegaconf import DictConfig
import os
import sys
from omegaconf import DictConfig, OmegaConf
import logging

from torch_points3d.utils.model_building_utils.model_definition_resolver import resolve

log = logging.getLogger(__name__)

CUR_FILE = os.path.realpath(__file__)
DIR_PATH = os.path.dirname(os.path.realpath(__file__))


class ModelArchitectures(Enum):
UNET = "unet"
Expand Down Expand Up @@ -62,9 +67,35 @@ def num_features(self):
def _build_unet(self):
raise NotImplementedError

def _build_unet_base(self, unet_class, config_dir, module_name, config_file=None):
PATH_TO_CONFIG = os.path.join(DIR_PATH, config_dir)
if self._config:
model_config = self._config
else:
if config_file is None:
config_file = "unet_{}.yaml".format(self.num_layers)
path_to_model = os.path.join(PATH_TO_CONFIG, config_file)
model_config = OmegaConf.load(path_to_model)
ModelFactory.resolve_model(model_config, self.num_features, self._kwargs)
modules_lib = sys.modules[module_name]
return unet_class(model_config, None, None, modules_lib, **self.kwargs)

def _build_encoder(self):
raise NotImplementedError

def _build_encoder_base(self, encoder_class, config_dir, module_name, config_file=None):
PATH_TO_CONFIG = os.path.join(DIR_PATH, config_dir)
if self._config:
model_config = self._config
else:
if config_file is None:
config_file = "encoder_{}.yaml".format(self.num_layers)
path_to_model = os.path.join(PATH_TO_CONFIG, config_file)
model_config = OmegaConf.load(path_to_model)
ModelFactory.resolve_model(model_config, self.num_features, self._kwargs)
modules_lib = sys.modules[module_name]
return encoder_class(model_config, None, None, modules_lib, **self.kwargs)

def _build_decoder(self):
raise NotImplementedError

Expand Down
31 changes: 4 additions & 27 deletions torch_points3d/applications/pointnet2.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import os
import sys
from omegaconf import DictConfig, OmegaConf
import logging

Expand All @@ -12,10 +10,6 @@
from torch_points3d.core.common_modules.base_modules import Seq
from .utils import extract_output_nc

CUR_FILE = os.path.realpath(__file__)
DIR_PATH = os.path.dirname(os.path.realpath(__file__))
PATH_TO_CONFIG = os.path.join(DIR_PATH, "conf/pointnet2")

log = logging.getLogger(__name__)


Expand Down Expand Up @@ -57,29 +51,12 @@ def PointNet2(

class PointNet2Factory(ModelFactory):
def _build_unet(self):
if self._config:
model_config = self._config
else:
path_to_model = os.path.join(
PATH_TO_CONFIG, "unet_{}_{}.yaml".format(self.num_layers, "ms" if self.kwargs["multiscale"] else "ss")
)
model_config = OmegaConf.load(path_to_model)
ModelFactory.resolve_model(model_config, self.num_features, self._kwargs)
modules_lib = sys.modules[__name__]
return PointNet2Unet(model_config, None, None, modules_lib, **self.kwargs)
config_file = "unet_{}_{}.yaml".format(self.num_layers, "ms" if self.kwargs["multiscale"] else "ss")
return self._build_unet_base(PointNet2Unet, "conf/pointnet2", __name__, config_file)

def _build_encoder(self):
if self._config:
model_config = self._config
else:
path_to_model = os.path.join(
PATH_TO_CONFIG,
"encoder_{}_{}.yaml".format(self.num_layers, "ms" if self.kwargs["multiscale"] else "ss"),
)
model_config = OmegaConf.load(path_to_model)
ModelFactory.resolve_model(model_config, self.num_features, self._kwargs)
modules_lib = sys.modules[__name__]
return PointNet2Encoder(model_config, None, None, modules_lib, **self.kwargs)
config_file = "encoder_{}_{}.yaml".format(self.num_layers, "ms" if self.kwargs["multiscale"] else "ss")
return self._build_encoder_base(PointNet2Encoder, "conf/pointnet2", __name__, config_file)


class BasePointnet2(UnwrappedUnetBasedModel):
Expand Down
24 changes: 2 additions & 22 deletions torch_points3d/applications/rsconv.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import os
import sys
import queue
from omegaconf import DictConfig, OmegaConf
import logging
Expand All @@ -13,10 +11,6 @@
from torch_points3d.core.common_modules.base_modules import Seq
from .utils import extract_output_nc

CUR_FILE = os.path.realpath(__file__)
DIR_PATH = os.path.dirname(os.path.realpath(__file__))
PATH_TO_CONFIG = os.path.join(DIR_PATH, "conf/rsconv")

log = logging.getLogger(__name__)


Expand Down Expand Up @@ -47,24 +41,10 @@ def RSConv(

class RSConvFactory(ModelFactory):
def _build_unet(self):
if self._config:
model_config = self._config
else:
path_to_model = os.path.join(PATH_TO_CONFIG, "unet_{}.yaml".format(self.num_layers))
model_config = OmegaConf.load(path_to_model)
ModelFactory.resolve_model(model_config, self.num_features, self._kwargs)
modules_lib = sys.modules[__name__]
return RSConvUnet(model_config, None, None, modules_lib, **self.kwargs)
return self._build_unet_base(RSConvUnet, "conf/rsconv", __name__)

def _build_encoder(self):
if self._config:
model_config = self._config
else:
path_to_model = os.path.join(PATH_TO_CONFIG, "encoder_{}.yaml".format(self.num_layers))
model_config = OmegaConf.load(path_to_model)
ModelFactory.resolve_model(model_config, self.num_features, self._kwargs)
modules_lib = sys.modules[__name__]
return RSConvEncoder(model_config, None, None, modules_lib, **self.kwargs)
return self._build_encoder_base(RSConvEncoder, "conf/rsconv", __name__)


class RSConvBase(UnwrappedUnetBasedModel):
Expand Down
26 changes: 2 additions & 24 deletions torch_points3d/applications/sparseconv3d.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import os
import sys
from omegaconf import DictConfig, OmegaConf
import logging
import torch
Expand All @@ -16,10 +15,6 @@
from .utils import extract_output_nc


CUR_FILE = os.path.realpath(__file__)
DIR_PATH = os.path.dirname(os.path.realpath(__file__))
PATH_TO_CONFIG = os.path.join(DIR_PATH, "conf/sparseconv3d")

log = logging.getLogger(__name__)


Expand Down Expand Up @@ -69,27 +64,10 @@ def SparseConv3d(

class SparseConv3dFactory(ModelFactory):
def _build_unet(self):
if self._config:
model_config = self._config
else:
path_to_model = os.path.join(PATH_TO_CONFIG, "unet_{}.yaml".format(self.num_layers))
model_config = OmegaConf.load(path_to_model)
ModelFactory.resolve_model(model_config, self.num_features, self._kwargs)
modules_lib = sys.modules[__name__]
return SparseConv3dUnet(model_config, None, None, modules_lib, **self.kwargs)
return self._build_unet_base(SparseConv3dUnet, "conf/sparseconv3d", __name__)

def _build_encoder(self):
if self._config:
model_config = self._config
else:
path_to_model = os.path.join(
PATH_TO_CONFIG,
"encoder_{}.yaml".format(self.num_layers),
)
model_config = OmegaConf.load(path_to_model)
ModelFactory.resolve_model(model_config, self.num_features, self._kwargs)
modules_lib = sys.modules[__name__]
return SparseConv3dEncoder(model_config, None, None, modules_lib, **self.kwargs)
return self._build_encoder_base(SparseConv3dEncoder, "conf/sparseconv3d", __name__)


class BaseSparseConv3d(UnwrappedUnetBasedModel):
Expand Down

0 comments on commit 297f0e0

Please sign in to comment.