# BaseCNN Refactor Submission

[![tests](https://github.com/omaib/coding-test/workflows/test/badge.svg)](https://github.com/omaib/coding-test/actions/workflows/test.yml)
[![codecov](https://codecov.io/gh/omaib/coding-test/branch/main/graph/badge.svg)](https://codecov.io/gh/omaib/coding-test)
[![GitHub license](https://img.shields.io/badge/license-MIT-blue.svg)](https://github.com/omaib/coding-test/blob/main/LICENSE)
[![Python](https://img.shields.io/badge/python-3.10%20%7C%203.11%20%7C%203.12-blue)](https://www.python.org)
This README summarizes the design, implementation, and testing of the `BaseCNN` refactor for the OMAIB coding task. All changes are self-contained and submitted as a `.zip` file per instructions.

## Summary
This submission introduces a `BaseCNN` class to consolidate shared CNN logic and refactors all five existing CNN models (`SmallCNNFeature`, `SignalVAEEncoder`, `ProteinCNN`, `LeNet`, `ImageVAEEncoder`) to inherit from it. The refactoring eliminates ~50 lines of duplicate code while maintaining 100% backward compatibility with existing APIs and behavior.

**Key Changes:**
- Created `BaseCNN` base class with reusable utility methods
- Refactored all 5 CNN classes to directly inherit from `BaseCNN`
- All 6 existing tests pass with identical behavior preserved
- Zero code quality violations (Black, isort, flake8, mypy compliant)

## BaseCNN Design

### Design Decisions For The `BaseCNN` Implementation

**1. Modular Helper Methods**
The `BaseCNN` class provides 3 specialised helper methods, each serving a distinct architectural pattern:

- **`_create_sequential_conv_blocks()`**: creates multiple convolutional layers with explicit channel lists. Used by 4 classes (SmallCNNFeature, SignalVAEEncoder, ProteinCNN, ImageVAEEncoder) that need precise control over channel progression.

- **`_create_doubling_conv_blocks()`**: creates convolutional layers with automatic channel doubling (base → 2×base → 4×base, etc.). Specifically designed for LeNet's architectural pattern where channels double at each layer.

- **Common Utilities**: `_apply_activation()`, `_flatten_features()`, and `_initialize_weights()` provide operations used frequently across all models.
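To illustrate the division of labour described above, here is a rough sketch (the method bodies, signatures, and the `TinyCNN` subclass are simplified assumptions for illustration, not the actual implementation):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class BaseCNN(nn.Module):
    """Illustrative base class; bodies are simplified assumptions."""

    def _create_sequential_conv_blocks(self, channels, kernel_size=3, padding=1):
        # Build Conv2d layers from an explicit channel list,
        # e.g. [3, 16, 32] -> Conv2d(3, 16), Conv2d(16, 32).
        return nn.ModuleList(
            nn.Conv2d(c_in, c_out, kernel_size, padding=padding)
            for c_in, c_out in zip(channels[:-1], channels[1:])
        )

    def _apply_activation(self, x):
        return F.relu(x)

    def _flatten_features(self, x):
        # Collapse all dimensions except the batch dimension.
        return x.flatten(start_dim=1)


class TinyCNN(BaseCNN):
    """Hypothetical subclass: only architecture-specific choices live here."""

    def __init__(self):
        super().__init__()
        self.convs = self._create_sequential_conv_blocks([3, 8, 16])

    def forward(self, x):
        for conv in self.convs:
            x = self._apply_activation(conv(x))
        return self._flatten_features(x)
```

The subclass supplies only its channel list and forward pass; layer construction, activation, and flattening come from the base class.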

**2. Flexibility Through Optional Parameters**
All helper methods accept optional parameters to accommodate diverse architectural requirements:
- Support for both 1D (signals) and 2D (images) convolutions via `conv_type` parameter
- Optional batch normalisation via `use_batch_norm` flag
- Configurable kernel sizes, strides, and paddings via list parameters
- Type-safe returns with `Optional[nn.ModuleList]` for conditional components

**3. Type Safety for All Methods**
- Extensive type hints provided throughout (using `typing.List`, `Optional`, `Tuple`, and `Union`)
- Mypy-compliant type assertions in forward passes where batch norms are guaranteed to exist
- Clear return type annotations on all methods

### How The Design Addresses Scalability, Flexibility, and Modularity

**1. Scalability:**
- Helper methods use `nn.ModuleList` for dynamic layer creation, supporting arbitrary network depths
- ProteinCNN demonstrates scalability: accepts `num_filters` list of any length and automatically creates corresponding layers
- LeNet demonstrates configurability: `additional_layers` parameter allows arbitrary depth expansion
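As a sketch of that scalability pattern (the helper name and signature here are illustrative, not the actual `ProteinCNN` code), a filter list of any length drives layer creation:

```python
import torch.nn as nn


def build_conv1d_stack(in_channels: int, num_filters: list, kernel_size: int = 3) -> nn.ModuleList:
    """Hypothetical sketch: network depth scales with len(num_filters)."""
    layers = nn.ModuleList()
    prev = in_channels
    for out_channels in num_filters:
        # "same" padding so sequence length is preserved for odd kernels
        layers.append(nn.Conv1d(prev, out_channels, kernel_size, padding=kernel_size // 2))
        prev = out_channels
    return layers
```

A four-element list yields four layers, a ten-element list ten, with no change to the helper itself.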

**2. Flexibility:**
- Single helper method (`_create_sequential_conv_blocks`) supports both 1D and 2D convolutions through `conv_type` parameter
- Batch normalization is optional, enabling both standard CNNs and VAE encoders (which typically don't use batch norm)
- Architecture-specific patterns (sequential vs. doubling) are captured in separate helpers, not forced into a single rigid structure

**3. Modularity:**
- Each helper method has a single, well-defined responsibility
- Subclasses only override what's unique to them; common logic stays in the base class
- Clean separation: base class handles layer creation, subclasses define architecture-specific configurations
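A minimal sketch of the dimension-agnostic idea (the function name and signature are hypothetical, standing in for the branch inside `_create_sequential_conv_blocks`):

```python
import torch.nn as nn


def make_conv_block(conv_type: str, in_ch: int, out_ch: int, use_batch_norm: bool = True):
    """Hypothetical sketch: one code path selects 1D or 2D modules."""
    if conv_type == "1d":
        conv, bn = nn.Conv1d(in_ch, out_ch, 3, padding=1), nn.BatchNorm1d(out_ch)
    elif conv_type == "2d":
        conv, bn = nn.Conv2d(in_ch, out_ch, 3, padding=1), nn.BatchNorm2d(out_ch)
    else:
        raise ValueError(f"unsupported conv_type: {conv_type}")
    # VAE encoders typically skip batch norm, so it is optional.
    return conv, (bn if use_batch_norm else None)
```

Callers pick the dimensionality and normalisation once, at construction time, instead of duplicating near-identical branches in every model.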

### Specific Challenges Encountered and Their Resolutions

**Challenge 1: Type Safety with Optional Batch Normalization**
- **Problem**: When using `mypy`, the helper methods return `Optional[nn.ModuleList]` for batch norms (can be `None` when `use_batch_norm=False`), causing mypy type errors when subclasses index into `batch_norms`.
- **Solution**: I added type guard assertions (`assert self.batch_norms is not None`) in forward methods where batch norms are guaranteed to exist, informing mypy that `None` is impossible in that context.
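In miniature, the pattern looks like this (the `Encoder` class is illustrative; only the `batch_norms` attribute and the assertion mirror the description above):

```python
from typing import Optional

import torch
import torch.nn as nn


class Encoder(nn.Module):
    def __init__(self, use_batch_norm: bool = True) -> None:
        super().__init__()
        self.conv = nn.Conv2d(3, 8, 3, padding=1)
        # Optional component: None when batch norm is disabled.
        self.batch_norms: Optional[nn.ModuleList] = (
            nn.ModuleList([nn.BatchNorm2d(8)]) if use_batch_norm else None
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.conv(x)
        # Type guard: narrows Optional[nn.ModuleList] to nn.ModuleList
        # for mypy in configurations that always build batch norms.
        assert self.batch_norms is not None
        return self.batch_norms[0](x)
```

Without the assertion, mypy flags `self.batch_norms[0]` because the attribute's declared type admits `None`.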

**Challenge 2: Avoiding Over-Abstraction**
- **Problem**: I initially attempted to create `BaseVAEEncoder` as an intermediate abstraction for `SignalVAEEncoder` and `ImageVAEEncoder`.
- **Resolution**: I removed the intermediate layer to comply with the task requirement that "Each existing CNN class should inherit from BaseCNN". VAE encoders now inherit directly from `BaseCNN` while still benefiting from the `_create_sequential_conv_blocks()` helper.

**Challenge 3: Diverse Architectural Patterns**
- **Problem**: The initial `LeNet` implementation used channel doubling (64→128→256), while the other models use explicit channel lists (16→32→64).
- **Solution**: I created two specialised helpers rather than forcing all models into one pattern:
- `_create_sequential_conv_blocks()`: For explicit channel control
- `_create_doubling_conv_blocks()`: For automatic doubling pattern with pooling layers
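The channel progression produced by the doubling pattern can be computed as follows (a small illustrative helper, not the actual method):

```python
def doubling_channels(base: int, num_layers: int) -> list:
    """Channel sequence base, 2*base, 4*base, ... used by the doubling pattern."""
    return [base * (2 ** i) for i in range(num_layers)]
```

For example, `doubling_channels(64, 3)` gives the LeNet-style `[64, 128, 256]` progression.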

**Challenge 4: Code Quality Compliance**
- **Problem**: Long docstring lines (>120 chars) failing flake8 checks.
- **Solution**: Reformatted docstrings with appropriate line breaks while maintaining readability and clarity.

---

## Extensibility

### Supporting Additional Modalities
The implementation can be extended to support additional modalities in several ways. A few recommended approaches:

**1. For New Convolutional Architectures:**

To add a new CNN model (for example, for 3D volumetric data or graph convolutions), subclass `BaseCNN`:

```python
class VolumetricCNN(BaseCNN):
    def __init__(self, input_channels: int, ...):
        super().__init__()
        # Use existing helpers or create volume-specific layers
        # Can extend _create_sequential_conv_blocks to support conv_type="3d"

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Implement forward pass using base class utilities
        x = self._apply_activation(...)
        x = self._flatten_features(...)
        return x
```

**2. For New Activation Functions:**

To add a new activation function (for example, `tanh` or `softmax`), extend `_apply_activation()` in `BaseCNN`:

```python
def _apply_activation(self, x: torch.Tensor, activation_type: str = "relu") -> torch.Tensor:
    if activation_type == "relu":
        return F.relu(x)
    elif activation_type == "sigmoid":
        return torch.sigmoid(x)
    elif activation_type == "tanh":  # New addition
        return torch.tanh(x)
    elif activation_type == "softmax":  # New addition
        return F.softmax(x, dim=1)
    # ... add more as needed
    raise ValueError(f"Unknown activation: {activation_type}")
```

**3. For New Helper Methods:**

To add specialised helpers to `BaseCNN` for new architectural patterns, such as ResNet-style residual connections:

```python
def _create_residual_blocks(
    self,
    in_channels: int,
    num_blocks: int,
    ...
) -> nn.ModuleList:
    """Create residual blocks for ResNet-style architectures."""
    # Implementation
```

**4. For Extending to 3D Convolutions:**

Modify `_create_sequential_conv_blocks()` to accept `conv_type="3d"`:

```python
# In BaseCNN._create_sequential_conv_blocks():
if conv_type == "1d":
    conv_layers.append(nn.Conv1d(...))
    if use_batch_norm and batch_norms is not None:
        batch_norms.append(nn.BatchNorm1d(...))
elif conv_type == "2d":
    conv_layers.append(nn.Conv2d(...))
    if use_batch_norm and batch_norms is not None:
        batch_norms.append(nn.BatchNorm2d(...))
elif conv_type == "3d":  # New addition
    conv_layers.append(nn.Conv3d(...))
    if use_batch_norm and batch_norms is not None:
        batch_norms.append(nn.BatchNorm3d(...))
```
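A quick sanity check that the 3D branch produces shape-compatible modules (the tensor sizes are arbitrary examples, not requirements of the implementation):

```python
import torch
import torch.nn as nn

# One 3D block of the kind the extended helper would append.
conv3d = nn.Conv3d(in_channels=1, out_channels=4, kernel_size=3, padding=1)
bn3d = nn.BatchNorm3d(4)

volume = torch.randn(2, 1, 8, 8, 8)  # (batch, channels, depth, height, width)
out = bn3d(conv3d(volume))  # kernel 3 with padding=1 preserves spatial size
```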

### Conventions to Follow

**1. Inheritance:**
- Always inherit directly from `BaseCNN` (no intermediate base classes)
- Call `super().__init__()` first in your `__init__` method

**2. Helper Method Usage:**
- Use `_create_sequential_conv_blocks()` for explicit channel control
- Use `_create_doubling_conv_blocks()` for automatic doubling patterns
- Create new helpers in `BaseCNN` if you discover a new reusable pattern

**3. Type Hints:**
- Add comprehensive type hints to all methods
- Use `Optional[...]` for components that may not be initialized (e.g., batch norms)
- Add type guard assertions in forward passes where optionals are guaranteed non-None

**4. Documentation:**
- Provide clear docstrings with Args, Returns, and Example sections
- Document architectural assumptions (e.g., input size requirements)
- Explain any special configuration requirements

**5. Testing:**
- Ensure existing tests pass without modification (backward compatibility)
- Add new tests for new models in `test_cnn.py`
- Test both forward pass outputs and expected tensor shapes
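A minimal shape test in this style might look like the following (the model and test names are placeholders, not the actual contents of `test_cnn.py`):

```python
import torch
import torch.nn as nn


class DummyModel(nn.Module):
    """Placeholder standing in for a refactored CNN."""

    def __init__(self) -> None:
        super().__init__()
        self.conv = nn.Conv2d(3, 8, 3, padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.conv(x).flatten(start_dim=1)


def test_forward_shape() -> None:
    out = DummyModel()(torch.randn(4, 3, 16, 16))
    # padding=1 with kernel 3 preserves the 16x16 spatial size
    assert out.shape == (4, 8 * 16 * 16)
```

Run with `pytest` to exercise both the forward pass and the expected tensor shape.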

**6. Code Quality:**
- Run pre-commit hooks before committing (`pre-commit run --all-files`)
- Maintain Black formatting (120 char line length)
- Keep flake8 violations at zero
- Ensure mypy type checking passes

## Checklist

- [x] Code runs without errors
- [x] Code is well-documented with extensive docstrings and type hints coverage
- [x] Tests are included and pass for all required 6 tests
- [x] Follows project coding standards (Black, isort, flake8, mypy compliant)

## Additional Notes

### Implementation Highlights

1. **100% Backward Compatibility**: All existing APIs preserved exactly; no changes to constructor signatures, forward-pass signatures, or output formats.

2. **Code Reduction**: Eliminated approximately 50 lines of duplicate code through helper methods while improving maintainability.

3. **Helper Method Adoption**: 100% adoption rate - all 5 CNN classes now use at least one helper method from `BaseCNN`.

4. **Type Safety**: Full mypy compliance with proper type annotations and runtime assertions.

### Files Modified

- `cnn.py`: Added `BaseCNN` class, refactored all 5 CNN models
