# Tutorial 08: Extending `electricmayhem`

All of the pipeline stages are subclasses of `PipelineBase`, which in turn inherits from `torch.nn.Module`, so by default you should follow the best practices for `torch.nn`. I recommend reading the `torch.nn` documentation before you start.

A word on memory management: any attribute you add to your object as a parameter (`torch.nn.Parameter`, `torch.nn.ParameterDict`, etc.), as a registered buffer (via `register_buffer()`), or as a submodule will be copied over to the GPU when you call `.cuda()`. An attribute that's just a plain Python list of tensors (or any other unregistered container) will not. If you're creating a pipeline stage with a high memory footprint (such as an implanter designed for a large dataset), consider the tradeoff between storing more data on the GPU and the overhead of copying a subset over for each batch.
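
Here's a small sketch showing which attributes `torch.nn.Module` tracks (the class name is made up for illustration). It calls `.to(torch.float64)` instead of `.cuda()` so it runs without a GPU; both go through the same internal mechanism, so whatever gets cast here is exactly what would be moved to the GPU.

```python
import torch
from torch import nn


class MemoryDemo(nn.Module):
    """Toy module contrasting registered and unregistered tensors."""

    def __init__(self):
        super().__init__()
        # registered as a parameter: moved/cast by .cuda() and .to()
        self.weight = nn.Parameter(torch.zeros(3))
        # registered as a buffer: also moved/cast, but not trained
        self.register_buffer("lookup", torch.zeros(3))
        # plain Python list: invisible to nn.Module, so it stays where it is
        self.images = [torch.zeros(3), torch.zeros(3)]


m = MemoryDemo()
# casting to float64 uses the same machinery as .cuda(), so it shows
# which tensors would be moved, without needing a GPU
m = m.to(torch.float64)
print(m.weight.dtype, m.lookup.dtype, m.images[0].dtype)
# torch.float64 torch.float64 torch.float32
```

If a stage needs a large fixed set of tensors on the GPU, stacking them into one tensor and registering it as a buffer is one side of the tradeoff; keeping them in a plain list on the CPU and copying a subset over each batch is the other.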

### Required steps

* Add a `name` attribute that's a string giving the name of the module (for example, how it will appear in MLFlow)
* Add an `__init__()` method that starts with a call to `super().__init__()`. This is a requirement from `torch.nn` that sets up a bunch of the machinery the `Module` object relies on.
* Any keyword arguments you need to re-initialize the step should be captured in a JSON/YAML-serializable dict in `self.params`.
* The main behavior for the stage should be written as a `forward()` method that:
  * Takes as input the output of the previous stage (generally a batch of image tensors or a dictionary of image batches)
  * Has an `evaluate` kwarg; if `evaluate=True`, it runs an evaluation batch (for example, possibly using holdout images or a separate model)
  * Has a `control` kwarg; if `control=True`, it runs a control batch (same configuration as the previous batch, but without the patch). This only needs to work with `evaluate=True`.
  * Optionally accepts a dictionary of parameters through the `params` kwarg; these values override any randomly sampled parameters.
  * Accepts a `**kwargs` dictionary containing arbitrary metadata created by previous stages
  * Returns a 2-tuple containing the stage's output and the input `kwargs` dictionary (possibly with more information added to it)
* There should be a `get_last_sample_as_dict()` method. It should return any stochastic parameters sampled for the last batch as a dictionary containing lists or 1D `numpy` arrays of length `batchsize`. You should be able to pass this dictionary directly back to the `params` argument of `forward`.
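
Here's a minimal sketch of a stage that satisfies the required steps, using a hypothetical `RandomBrightness` stage that adds a random offset to each image. To keep it self-contained it aliases `torch.nn.Module` as a stand-in for `PipelineBase`, and it assumes `x` is a single image batch of shape `(N, C, H, W)` with values in [0, 1]. Treating a control batch as a replay of the previous batch's sampled parameters is also an assumption of this sketch. The usage at the bottom checks the key contract: passing the output of `get_last_sample_as_dict()` back into `params` reproduces the same output.

```python
import json

import torch
from torch import nn

# Hypothetical stand-in so the sketch runs without electricmayhem;
# a real stage would subclass electricmayhem's PipelineBase instead.
PipelineBase = nn.Module


class RandomBrightness(PipelineBase):
    """Hypothetical stage: add a random offset in [low, high] to each image."""

    name = "RandomBrightness"

    def __init__(self, low=-0.1, high=0.1):
        super().__init__()
        # everything needed to rebuild the stage, as plain JSON types
        self.params = {"low": low, "high": high}
        self.lastsample = {}

    def forward(self, x, control=False, evaluate=False, params=None, **kwargs):
        # control batch: reuse the previous batch's configuration (assumption)
        if control and params is None:
            params = self.get_last_sample_as_dict()
        if params is not None:
            # caller-supplied values override random sampling
            offset = torch.as_tensor(params["offset"], dtype=x.dtype, device=x.device)
        else:
            offset = torch.empty(x.shape[0], dtype=x.dtype, device=x.device)
            offset.uniform_(self.params["low"], self.params["high"])
        # record what was used, as a 1D numpy array of length batchsize
        self.lastsample = {"offset": offset.detach().cpu().numpy()}
        # one offset per image, broadcast over (C, H, W); `evaluate` is unused here
        y = (x + offset.view(-1, 1, 1, 1)).clamp(0, 1)
        return y, kwargs

    def get_last_sample_as_dict(self):
        return {k: v.copy() for k, v in self.lastsample.items()}


stage = RandomBrightness(low=-0.2, high=0.2)
x = torch.rand(4, 3, 8, 8)
y, meta = stage(x, note="hello")  # extra kwargs pass straight through
replay, _ = stage(x, params=stage.get_last_sample_as_dict())
print(torch.equal(y, replay))  # True

# self.params survives a JSON round trip, so the stage can be rebuilt from it
clone = RandomBrightness(**json.loads(json.dumps(stage.params)))
```

Storing the last sample as small CPU `numpy` arrays keeps it easy to log and serialize, at the cost of one tiny device-to-host copy per batch.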

### Optional steps

If you want to get fancy:

* Override the `get_description()` method to generate a more useful markdown description for MLFlow.
* Override the `log_vizualizations()` method with any diagnostics that would be useful to log to TensorBoard. This method is called whenever `pipeline.evaluate()` runs. It takes the following inputs:
  * **x:** batch of evaluation data (usually a batch of image tensors; sometimes a dict of image batches)
  * **x_control:** the corresponding control batch for **x** (for before-and-after visualizations)
  * **writer:** a TensorBoard `SummaryWriter` object; use this to write diagnostics
  * **step:** integer; current training step (needed both for TensorBoard and logging MLFlow metrics)
  * **logging_to_mlflow:** Boolean; whether MLFlow logging is active.
* Override the `validate()` method to check for anything specific that could go wrong with that stage. When the user calls `Pipeline.validate()`, it runs the `validate()` method for each stage. Use the `logging` library to record check results at the `info` or `warning` level. `validate()` takes:
  * **x:** batch of evaluation data (usually a batch of image tensors; sometimes a dict of image batches)
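
The optional `get_description()` and `validate()` hooks might look like this for a hypothetical `ClampStage` that expects pixel values in [0, 1]. As before, `torch.nn.Module` stands in for `PipelineBase` so the sketch runs on its own. Returning a Boolean from `validate()` is a convenience added here; the description above only asks you to log the results.

```python
import logging

import torch
from torch import nn

# Hypothetical stand-in for electricmayhem's PipelineBase, so this runs alone
PipelineBase = nn.Module


class ClampStage(PipelineBase):
    """Hypothetical stage that expects pixel values in [0, 1]."""

    name = "ClampStage"

    def __init__(self):
        super().__init__()
        self.params = {}

    def forward(self, x, control=False, evaluate=False, params=None, **kwargs):
        return x.clamp(0, 1), kwargs

    def get_last_sample_as_dict(self):
        return {}  # nothing stochastic to report

    def get_description(self):
        # markdown description for MLFlow
        return f"**{self.name}:** clamps pixel values to [0, 1]"

    def validate(self, x):
        logger = logging.getLogger(__name__)
        lo, hi = float(x.min()), float(x.max())
        if lo < 0 or hi > 1:
            logger.warning("%s: input range [%.3f, %.3f] is outside [0, 1]",
                           self.name, lo, hi)
            return False  # returning a Boolean is this sketch's own convention
        logger.info("%s: input range looks OK", self.name)
        return True


logging.basicConfig(level=logging.INFO)
stage = ClampStage()
stage.validate(torch.rand(2, 3, 4, 4))         # logs at INFO
stage.validate(torch.full((1, 3, 2, 2), 2.0))  # logs a WARNING
```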


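```python
import torch

# Assumes PipelineBase has already been imported from electricmayhem
class MyPipelineStage(PipelineBase):
    name = "MyPipelineStage"

    def __init__(self, foo, bar):
        super().__init__()
        # kwargs needed to re-create this stage (must be JSON/YAML-serializable)
        self.params = {"foo": foo, "bar": bar}
        # stochastic parameters sampled for the most recent batch
        self.lastsample = {}

    def forward(self, x, control=False, evaluate=False, params=None, **kwargs):
        # sample any stochastic parameters here (or use the values passed in
        # through `params`) and record them in self.lastsample
        y = x  # replace with your stage's transformation of x
        return y, kwargs

    def get_last_sample_as_dict(self):
        # lists or 1D numpy arrays of length batchsize, keyed by parameter name
        return dict(self.lastsample)

    def log_vizualizations(self, x, x_control, writer, step, logging_to_mlflow):
        """
        Write diagnostics to TensorBoard; called whenever pipeline.evaluate() runs
        """
        # example: control and evaluation images side by side, assuming both
        # are image batches of shape (N, C, H, W)
        stacked = torch.cat([x_control[0], x[0]], dim=2)
        writer.add_image("stacked_patch", stacked, global_step=step)

    def get_description(self):
        return "**MyPipelineStage** and some details that would be helpful in MLFlow"
```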