In [None]:
# ! pip install git+https://github.com/catalyst-team/catalyst@dev scikit-learn>=0.20 optuna --upgrade

In [None]:
# ! pip install catalyst==21.02rc1 scikit-learn>=0.20 optuna gym==0.17.3 --upgrade

In [None]:
# import os
# os.environ["CUDA_VISIBLE_DEVICES"] = "0"

#  Catalyst 21.xx demo

## Stage 1: Customization is all u need
- 10 minimal examples covering different Catalyst customization use cases

In [None]:
import os
import random
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import make_moons, make_blobs
%matplotlib inline

In [None]:
from typing import *

import torch
from torch import nn
from torch import optim
from torch.nn import functional as F
from torch.utils.data import DataLoader, TensorDataset

from catalyst import dl, metrics, utils

In [None]:
# make up a dataset
def make_dataset(seed=42, n_samples=int(1e3)):
    """Sample a noisy two-moons dataset with labels in {-1, +1}.

    Args:
        seed: value used to seed both NumPy's and Python's global RNGs
            before sampling.
        n_samples: number of points requested from ``make_moons``.

    Returns:
        Tuple ``(X, y)`` from ``make_moons`` with ``y`` mapped to -1 / 1.
    """
    for seed_fn in (np.random.seed, random.seed):
        seed_fn(seed)
    points, labels = make_moons(n_samples=n_samples, noise=0.1)

    # make the labels be -1 or 1
    signed_labels = labels * 2 - 1
    return points, signed_labels

def visualize_dataset(X, y):
    """Scatter-plot the first two columns of ``X``, colouring the points by ``y``."""
    plt.figure(figsize=(5, 5))
    first_coord, second_coord = X[:, 0], X[:, 1]
    plt.scatter(first_coord, second_coord, c=y, s=20, cmap='jet')

# Training data: generated with make_dataset's default seed, then plotted.
X_train, y_train = make_dataset()
visualize_dataset(X_train, y_train)

In [None]:
# Validation data: same generator as the training data, different seed.
X_valid, y_valid = make_dataset(seed=137)
visualize_dataset(X_valid, y_valid)

In [None]:
# A second training set drawn with yet another seed; it is used later as
# the "train_2" loader.
X_train2, y_train2 = make_dataset(seed=1337)
visualize_dataset(X_train2, y_train2)

In [None]:
# Initialize a model:
# a small MLP with two hidden layers of 16 units (three Linear layers in
# total) mapping 2 input features to a single output.
model = nn.Sequential(
    nn.Linear(2, 16), nn.ReLU(), 
    nn.Linear(16, 16), nn.ReLU(), 
    nn.Linear(16, 1)
)
print(model)
# model.parameters() is a generator, so count the elements explicitly:
# print("number of parameters", sum(p.numel() for p in model.parameters()))

In [None]:
def visualize_decision_boundary(X, y, model):
    """Plot the sign of ``model``'s output over a grid, with the data on top.

    A regular grid (step 0.25) is laid over the bounding box of the first two
    columns of ``X`` padded by 1 on each side. The model is evaluated on every
    grid point and the region where its output is positive is filled.

    Args:
        X: 2-D array-like of points; only columns 0 and 1 are used.
        y: per-point values used to colour the scatter plot.
        model: callable taking a float tensor of grid points with two columns
            and returning one score per point.

    Returns:
        The matplotlib figure that was drawn (after ``plt.show()``).
    """
    h = 0.25  # grid step
    x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
    y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
    xx, yy = np.meshgrid(np.arange(x_min, x_max, h),
                         np.arange(y_min, y_max, h))
    Xmesh = np.c_[xx.ravel(), yy.ravel()]

    # Convert the whole mesh in one call instead of row by row.
    inputs = torch.as_tensor(Xmesh, dtype=torch.float32)
    # Plotting needs no gradients; skip building the autograd graph.
    with torch.no_grad():
        scores = model(inputs)

    # Threshold all scores at once; positive score -> True.
    Z = (scores > 0).cpu().numpy().reshape(xx.shape)

    fig = plt.figure()
    plt.contourf(xx, yy, Z, cmap=plt.cm.Spectral, alpha=0.8)
    plt.scatter(X[:, 0], X[:, 1], c=y, s=40, cmap=plt.cm.Spectral)
    plt.xlim(xx.min(), xx.max())
    plt.ylim(yy.min(), yy.max())
    plt.show()
    return fig

In [None]:
# Decision boundary of the freshly initialised model defined above, on the validation data.
_ = visualize_decision_boundary(X_valid, y_valid, model)

In [None]:
# Wrap each (features, labels) pair into a TensorDataset; the -1/1 labels are
# turned into 0./1. float targets via ``> 0``.
t1, t2, v1 = (
    TensorDataset(torch.tensor(features).float(), torch.tensor(labels > 0).float())
    for features, labels in (
        (X_train, y_train),
        (X_train2, y_train2),
        (X_valid, y_valid),
    )
)

# One loader per dataset, all with the same batch size and worker count.
loaders = {
    loader_name: DataLoader(dataset, batch_size=32, num_workers=1)
    for loader_name, dataset in (("train_1", t1), ("train_2", t2), ("valid", v1))
}

---

### Act 1 - ``CustomRunner – batch handling on your own``

In [None]:
class CustomRunner(dl.IRunner):
    """Minimal runner with a hand-written batch step and no callbacks.

    One ``"train"`` stage of length 5 runs on the CPU over the global
    ``loaders``. ``handle_batch`` computes the loss itself and, on train
    loaders, performs the backward pass and optimizer step.
    """

    def get_engine(self) -> dl.IEngine:
        """Run on the CPU."""
        return dl.DeviceEngine("cpu")

    @property
    def stages(self) -> Iterable[str]:
        """Stage names to run: a single ``"train"`` stage."""
        return ["train"]

    def get_stage_len(self, stage: str) -> int:
        """Length of every stage."""
        return 5

    def get_loaders(self, stage: str) -> "OrderedDict[str, DataLoader]":
        """Use the notebook-level ``loaders`` dict for every stage."""
        return loaders

    def get_model(self, stage: str):
        """Build a fresh MLP: 2 inputs -> 16 -> 16 -> 1 output."""
        return nn.Sequential(
            nn.Linear(2, 16), nn.ReLU(),
            nn.Linear(16, 16), nn.ReLU(),
            nn.Linear(16, 1)
        )

    def get_criterion(self, stage: str):
        """No criterion object: the loss is computed inline in ``handle_batch``."""
        return None

    def get_optimizer(self, stage: str, model):
        """Adam over all model parameters."""
        return torch.optim.Adam(model.parameters(), lr=0.02)

    def get_scheduler(self, stage: str, optimizer):
        """No learning-rate scheduler."""
        return None

    def handle_batch(self, batch):
        """Forward pass, loss, periodic progress print and (on train loaders) one update."""
        x, y = batch
        y_hat = self.model(x)

        loss = F.binary_cross_entropy_with_logits(y_hat.view(-1), y)
        self.batch_metrics = {"loss": loss}
        # Report progress every 10th batch.
        if self.loader_batch_step % 10 == 0:
            print(
                f"{self.loader_key} ({self.loader_batch_step}/{self.loader_batch_len}): "
                f"loss {loss.item()}"
            )

        # Only update the weights on training loaders.
        if self.is_train_loader:
            loss.backward()
            self.optimizer.step()
            self.optimizer.zero_grad()

runner = CustomRunner().run()
model = runner.model

In [None]:
# Decision boundary of the model taken from the Act 1 runner.
_ = visualize_decision_boundary(X_valid, y_valid, model)

---

### Act 2 - ``SupervisedRunner – Runner with Callbacks``

In [None]:
class CustomSupervisedRunner(dl.IRunner):
    """Single-stage CPU runner whose training logic lives in callbacks.

    ``handle_batch`` only runs the forward pass and stores the tensors under
    named keys in ``self.batch``; the callbacks from ``get_callbacks`` are
    configured with those key names.
    """

    def get_engine(self) -> dl.IEngine:
        """Run on the CPU."""
        engine = dl.DeviceEngine("cpu")
        return engine

    def get_loggers(self):
        """Console and TensorBoard loggers (TensorBoard under ``./logdir02/tb``)."""
        console_logger = dl.ConsoleLogger()
        # A csv logger could be added as well:
        #   "csv": dl.LogdirLogger(logdir="./logdir02")
        tensorboard_logger = dl.TensorboardLogger(logdir="./logdir02/tb")
        return {"console": console_logger, "tensorboard": tensorboard_logger}

    @property
    def stages(self) -> Iterable[str]:
        """Stage names to run: a single ``"train"`` stage."""
        return ["train"]

    def get_stage_len(self, stage: str) -> int:
        """Length of every stage."""
        return 5

    def get_loaders(self, stage: str) -> "OrderedDict[str, DataLoader]":
        """Use the notebook-level ``loaders`` dict for every stage."""
        return loaders

    def get_model(self, stage: str):
        """Build a fresh MLP: 2 inputs -> 16 -> 16 -> 1 output."""
        layers = [
            nn.Linear(2, 16),
            nn.ReLU(),
            nn.Linear(16, 16),
            nn.ReLU(),
            nn.Linear(16, 1),
        ]
        return nn.Sequential(*layers)

    def get_criterion(self, stage: str):
        """Binary cross-entropy on logits."""
        return nn.BCEWithLogitsLoss()

    def get_optimizer(self, stage: str, model):
        """Adam over all model parameters."""
        return torch.optim.Adam(model.parameters(), lr=0.02)

    def get_scheduler(self, stage: str, optimizer):
        """Multi-step LR schedule with milestones 2 and 4."""
        return torch.optim.lr_scheduler.MultiStepLR(optimizer, [2, 4])

    def get_callbacks(self, stage: str):
        """Assemble the callbacks, keyed by name, in registration order."""
        callbacks = {}
        # AUC serves as an example of a loader-based metric: it should not be
        # computed on each batch.
        callbacks["auc"] = dl.LoaderMetricCallback(
            metric=metrics.AUCMetric(),
            input_key="scores", target_key="targets",
        )
        # CriterionCallback wraps the criterion step logic.
        callbacks["criterion"] = dl.CriterionCallback(
            metric_key="loss",
            input_key="logits",
            target_key="targets"
        )
        # OptimizerCallback wraps the optimizer step logic.
        callbacks["optimizer"] = dl.OptimizerCallback(metric_key="loss")
        # Same idea for the scheduler.
        callbacks["scheduler"] = dl.SchedulerCallback(
            loader_key="valid", metric_key="loss"
        )
        # An lr-finder could be used for lr scheduling instead:
        #   "lr-finder": dl.LRFinder(
        #       final_lr=1.0,
        #       scale="log",
        #       num_steps=None,
        #       optimizer_key=None,
        #   ),
        # Any number of metrics can be selected to checkpoint on.
        callbacks["checkpoint1"] = dl.CheckpointCallback(
            logdir="./logdir02/auc",
            loader_key="valid", metric_key="auc",
            minimize=False, save_n_best=3
        )
        callbacks["checkpoint2"] = dl.CheckpointCallback(
            logdir="./logdir02/loss",
            loader_key="valid", metric_key="loss",
            minimize=True, save_n_best=1
        )
        # tqdm progress output during each loader run.
        callbacks["verbose"] = dl.TqdmCallback()
        return callbacks

    def handle_batch(self, batch):
        """Run the forward pass and publish inputs and outputs in ``self.batch``."""
        features, targets = batch
        logits = self.model(features).view(-1)

        self.batch = {
            "features": features,
            "targets": targets,
            "logits": logits,
            "scores": torch.sigmoid(logits),
        }

runner = CustomSupervisedRunner().run()
model = runner.model

In [None]:
# Decision boundary of the model taken from the Act 2 runner.
_ = visualize_decision_boundary(X_valid, y_valid, model)

---

### Act 3 - ``CustomMetric``

In [None]:
class CustomAccuracyMetric(metrics.ICallbackBatchMetric, metrics.AdditiveValueMetric):
    """Accuracy of scores thresholded at 0.5, accumulated over batches."""

    def update(self, scores: torch.Tensor, targets: torch.Tensor) -> float:
        """Compute the batch accuracy and pass it, with the batch size, to the parent ``update``."""
        hits = (scores > 0.5) == targets
        batch_accuracy = hits.float().mean().item()
        return super().update(batch_accuracy, len(targets))

    def update_key_value(self, scores: torch.Tensor, targets: torch.Tensor) -> Dict[str, float]:
        """Update with one batch and return the result under the ``"accuracy"`` key."""
        return {"accuracy": self.update(scores, targets)}

    def compute_key_value(self) -> Dict[str, float]:
        """Return the accumulated mean and std as ``"accuracy"`` and ``"accuracy/std"``."""
        accumulated_mean, accumulated_std = super().compute()
        return {"accuracy": accumulated_mean, "accuracy/std": accumulated_std}

    
class CustomSupervisedRunner(dl.IRunner):
    """Callback-driven runner that also tracks ``CustomAccuracyMetric`` per batch.

    ``handle_batch`` only runs the forward pass and stores the tensors under
    named keys in ``self.batch``; the callbacks from ``get_callbacks`` are
    configured with those key names.
    """

    def get_engine(self) -> dl.IEngine:
        """Run on the CPU."""
        engine = dl.DeviceEngine("cpu")
        return engine

    def get_loggers(self):
        """Console and TensorBoard loggers (TensorBoard under ``./logdir03/tb``)."""
        console_logger = dl.ConsoleLogger()
        tensorboard_logger = dl.TensorboardLogger(logdir="./logdir03/tb")
        return {"console": console_logger, "tensorboard": tensorboard_logger}

    @property
    def stages(self) -> Iterable[str]:
        """Stage names to run: a single ``"train"`` stage."""
        return ["train"]

    def get_stage_len(self, stage: str) -> int:
        """Length of every stage."""
        return 5

    def get_loaders(self, stage: str) -> "OrderedDict[str, DataLoader]":
        """Use the notebook-level ``loaders`` dict for every stage."""
        return loaders

    def get_model(self, stage: str):
        """Build a fresh MLP: 2 inputs -> 16 -> 16 -> 1 output."""
        layers = [
            nn.Linear(2, 16),
            nn.ReLU(),
            nn.Linear(16, 16),
            nn.ReLU(),
            nn.Linear(16, 1),
        ]
        return nn.Sequential(*layers)

    def get_criterion(self, stage: str):
        """Binary cross-entropy on logits."""
        return nn.BCEWithLogitsLoss()

    def get_optimizer(self, stage: str, model):
        """Adam over all model parameters."""
        return torch.optim.Adam(model.parameters(), lr=0.02)

    def get_scheduler(self, stage: str, optimizer):
        """Multi-step LR schedule with milestones 2 and 4."""
        return torch.optim.lr_scheduler.MultiStepLR(optimizer, [2, 4])

    def get_callbacks(self, stage: str):
        """Assemble the callbacks, keyed by name, in registration order."""
        callbacks = {}
        # The custom batch metric, logged on every batch.
        callbacks["accuracy"] = dl.BatchMetricCallback(
            metric=CustomAccuracyMetric(), log_on_batch=True,
            input_key="scores", target_key="targets",
        )
        callbacks["auc"] = dl.LoaderMetricCallback(
            metric=metrics.AUCMetric(),
            input_key="scores", target_key="targets",
        )
        callbacks["criterion"] = dl.CriterionCallback(
            metric_key="loss",
            input_key="logits",
            target_key="targets"
        )
        callbacks["optimizer"] = dl.OptimizerCallback(metric_key="loss")
        callbacks["scheduler"] = dl.SchedulerCallback(
            loader_key="valid", metric_key="loss"
        )
        # Checkpoint on the custom accuracy (higher is better) ...
        callbacks["checkpoint1"] = dl.CheckpointCallback(
            logdir="./logdir03/accuracy",
            loader_key="valid", metric_key="accuracy",
            minimize=False, save_n_best=3
        )
        # ... and on the loss (lower is better).
        callbacks["checkpoint2"] = dl.CheckpointCallback(
            logdir="./logdir03/loss",
            loader_key="valid", metric_key="loss",
            minimize=True, save_n_best=1
        )
        # tqdm output is left disabled here:
        #   "verbose": dl.TqdmCallback(),
        return callbacks

    def handle_batch(self, batch):
        """Run the forward pass and publish inputs and outputs in ``self.batch``."""
        features, targets = batch
        logits = self.model(features).view(-1)

        self.batch = {
            "features": features,
            "targets": targets,
            "logits": logits,
            "scores": torch.sigmoid(logits),
        }

runner = CustomSupervisedRunner().run()
model = runner.model

In [None]:
# Decision boundary of the model taken from the Act 3 runner.
_ = visualize_decision_boundary(X_valid, y_valid, model)

---

### Act 4 - ``CustomCallback``

In [None]:
# Let's plot the decision boundary after each epoch:
class VisualizationCallback(dl.Callback):
    """Callback that draws the model's decision boundary at the end of every epoch."""

    def __init__(self):
        """Register the callback with the ``External`` callback order."""
        super().__init__(order=dl.CallbackOrder.External)

    def on_epoch_end(self, runner):
        """Plot the boundary of ``runner.model`` on the validation data."""
        fig = visualize_decision_boundary(X_valid, y_valid, runner.model)
        # The figure has already been shown by the helper; close it so that
        # one figure per epoch does not stay registered with pyplot.
        plt.close(fig)


class CustomSupervisedRunner(dl.IRunner):
    """Callback-driven runner that also plots the decision boundary each epoch.

    ``handle_batch`` only runs the forward pass and stores the tensors under
    named keys in ``self.batch``; the callbacks from ``get_callbacks`` are
    configured with those key names.
    """

    def get_engine(self) -> dl.IEngine:
        """Run on the CPU."""
        engine = dl.DeviceEngine("cpu")
        return engine

    def get_loggers(self):
        """Console and TensorBoard loggers (TensorBoard under ``./logdir04/tb``)."""
        console_logger = dl.ConsoleLogger()
        tensorboard_logger = dl.TensorboardLogger(logdir="./logdir04/tb")
        return {"console": console_logger, "tensorboard": tensorboard_logger}

    @property
    def stages(self) -> Iterable[str]:
        """Stage names to run: a single ``"train"`` stage."""
        return ["train"]

    def get_stage_len(self, stage: str) -> int:
        """Length of every stage."""
        return 5

    def get_loaders(self, stage: str) -> "OrderedDict[str, DataLoader]":
        """Use the notebook-level ``loaders`` dict for every stage."""
        return loaders

    def get_model(self, stage: str):
        """Build a fresh MLP: 2 inputs -> 16 -> 16 -> 1 output."""
        layers = [
            nn.Linear(2, 16),
            nn.ReLU(),
            nn.Linear(16, 16),
            nn.ReLU(),
            nn.Linear(16, 1),
        ]
        return nn.Sequential(*layers)

    def get_criterion(self, stage: str):
        """Binary cross-entropy on logits."""
        return nn.BCEWithLogitsLoss()

    def get_optimizer(self, stage: str, model):
        """Adam over all model parameters."""
        return torch.optim.Adam(model.parameters(), lr=0.02)

    def get_scheduler(self, stage: str, optimizer):
        """Multi-step LR schedule with milestones 2 and 4."""
        return torch.optim.lr_scheduler.MultiStepLR(optimizer, [2, 4])

    def get_callbacks(self, stage: str):
        """Assemble the callbacks, keyed by name, in registration order."""
        callbacks = {}
        callbacks["criterion"] = dl.CriterionCallback(
            metric_key="loss",
            input_key="logits",
            target_key="targets"
        )
        callbacks["optimizer"] = dl.OptimizerCallback(metric_key="loss")
        callbacks["scheduler"] = dl.SchedulerCallback(
            loader_key="valid", metric_key="loss"
        )
        callbacks["checkpoint"] = dl.CheckpointCallback(
            logdir="./logdir04/loss",
            loader_key="valid", metric_key="loss",
            minimize=True, save_n_best=1
        )
        # The custom callback is registered like any built-in one.
        callbacks["visualization"] = VisualizationCallback()
        return callbacks

    def handle_batch(self, batch):
        """Run the forward pass and publish inputs and outputs in ``self.batch``."""
        features, targets = batch
        logits = self.model(features).view(-1)

        self.batch = {
            "features": features,
            "targets": targets,
            "logits": logits,
            "scores": torch.sigmoid(logits),
        }

runner = CustomSupervisedRunner().run()
model = runner.model

---

### Act 5 - ``CustomLogger``

In [None]:
import io
import cv2
import numpy as np
import matplotlib.pyplot as plt

def get_img_from_fig(fig, dpi=180):
    """Render ``fig`` to an in-memory PNG and decode it into an RGB image array.

    Args:
        fig: object exposing ``savefig`` (a matplotlib figure in this notebook).
        dpi: resolution passed to ``savefig``.

    Returns:
        The decoded image after a BGR -> RGB conversion.
    """
    with io.BytesIO() as png_buffer:
        fig.savefig(png_buffer, format="png", dpi=dpi)
        png_bytes = np.frombuffer(png_buffer.getvalue(), dtype=np.uint8)

    # imdecode yields BGR channel order; flag IMREAD_COLOR equals 1.
    bgr_image = cv2.imdecode(png_bytes, cv2.IMREAD_COLOR)
    return cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB)

In [None]:
# We need to add only a few lines to log the image to all runner's loggers
class VisualizationCallback(dl.Callback):
    """Callback that plots the decision boundary each epoch and logs it as an image."""

    def __init__(self):
        """Register the callback with the ``External`` callback order."""
        super().__init__(order=dl.CallbackOrder.External)

    def on_epoch_end(self, runner):
        """Draw the boundary, convert the figure to an array and hand it to the runner."""
        fig = visualize_decision_boundary(X_valid, y_valid, runner.model)
        try:
            image = get_img_from_fig(fig)
        finally:
            # The figure is no longer needed once rasterised; close it so that
            # one figure per epoch does not stay registered with pyplot.
            plt.close(fig)
        # runner will propagate it to all loggers
        runner.log_image(tag="decision_boundary", image=image, scope="epoch")


# Let's also add our own Logger to store image on the disk
class VisualizationLogger(dl.ILogger):
    """Logger that stores epoch-scoped images as PNG files inside ``logdir``."""

    def __init__(self, logdir: str):
        """Create ``logdir`` if needed and remember it."""
        os.makedirs(logdir, exist_ok=True)
        self.logdir = logdir

    def log_image(
        self,
        tag: str,
        image: np.ndarray,
        scope: str = None,
        # experiment info
        experiment_key: str = None,
        global_epoch_step: int = 0,
        global_batch_step: int = 0,
        global_sample_step: int = 0,
        # stage info
        stage_key: str = None,
        stage_epoch_len: int = 0,
        stage_epoch_step: int = 0,
        stage_batch_step: int = 0,
        stage_sample_step: int = 0,
        # loader info
        loader_key: str = None,
        loader_batch_len: int = 0,
        loader_sample_len: int = 0,
        loader_batch_step: int = 0,
        loader_sample_step: int = 0,
    ) -> None:
        """Save ``image`` as ``<tag>_<stage_key>_<stage_epoch_step>.png``.

        Only images logged with ``scope == "epoch"`` are written; any other
        scope is ignored.
        """
        if scope != "epoch":
            return
        file_name = f"{tag}_{stage_key}_{stage_epoch_step}.png"
        plt.imsave(os.path.join(self.logdir, file_name), image)


class CustomSupervisedRunner(dl.IRunner):
    """Callback-driven runner that also registers the custom ``VisualizationLogger``.

    ``handle_batch`` only runs the forward pass and stores the tensors under
    named keys in ``self.batch``; the callbacks from ``get_callbacks`` are
    configured with those key names.
    """

    def get_engine(self) -> dl.IEngine:
        """Run on the CPU."""
        engine = dl.DeviceEngine("cpu")
        return engine

    def get_loggers(self):
        """Console, image-to-disk and TensorBoard loggers, all under ``./logdir05``."""
        console_logger = dl.ConsoleLogger()
        image_logger = VisualizationLogger(logdir="./logdir05/visualization")
        tensorboard_logger = dl.TensorboardLogger(logdir="./logdir05/tb")
        return {
            "console": console_logger,
            "visualization": image_logger,
            "tensorboard": tensorboard_logger,
        }

    @property
    def stages(self) -> Iterable[str]:
        """Stage names to run: a single ``"train"`` stage."""
        return ["train"]

    def get_stage_len(self, stage: str) -> int:
        """Length of every stage."""
        return 5

    def get_loaders(self, stage: str) -> "OrderedDict[str, DataLoader]":
        """Use the notebook-level ``loaders`` dict for every stage."""
        return loaders

    def get_model(self, stage: str):
        """Build a fresh MLP: 2 inputs -> 16 -> 16 -> 1 output."""
        layers = [
            nn.Linear(2, 16),
            nn.ReLU(),
            nn.Linear(16, 16),
            nn.ReLU(),
            nn.Linear(16, 1),
        ]
        return nn.Sequential(*layers)

    def get_criterion(self, stage: str):
        """Binary cross-entropy on logits."""
        return nn.BCEWithLogitsLoss()

    def get_optimizer(self, stage: str, model):
        """Adam over all model parameters."""
        return torch.optim.Adam(model.parameters(), lr=0.02)

    def get_scheduler(self, stage: str, optimizer):
        """Multi-step LR schedule with milestones 2 and 4."""
        return torch.optim.lr_scheduler.MultiStepLR(optimizer, [2, 4])

    def get_callbacks(self, stage: str):
        """Assemble the callbacks, keyed by name, in registration order."""
        callbacks = {}
        callbacks["criterion"] = dl.CriterionCallback(
            metric_key="loss",
            input_key="logits",
            target_key="targets"
        )
        callbacks["optimizer"] = dl.OptimizerCallback(metric_key="loss")
        callbacks["scheduler"] = dl.SchedulerCallback(
            loader_key="valid", metric_key="loss"
        )
        callbacks["checkpoint"] = dl.CheckpointCallback(
            logdir="./logdir05/loss",
            loader_key="valid", metric_key="loss",
            minimize=True, save_n_best=1
        )
        callbacks["visualization"] = VisualizationCallback()
        return callbacks

    def handle_batch(self, batch):
        """Run the forward pass and publish inputs and outputs in ``self.batch``."""
        features, targets = batch
        logits = self.model(features).view(-1)

        self.batch = {
            "features": features,
            "targets": targets,
            "logits": logits,
            "scores": torch.sigmoid(logits),
        }

runner = CustomSupervisedRunner().run()
model = runner.model

In [None]:
# Inspect the directories the Act 5 run was configured to write to:
# checkpoints (loss), TensorBoard logs (tb) and saved boundary images (visualization).
! ls ./logdir05
! ls ./logdir05/loss
! ls ./logdir05/tb
! ls ./logdir05/visualization

### Act 6 - ``Multistage Run``

In [None]:
import io
import cv2
import numpy as np
import matplotlib.pyplot as plt

def get_img_from_fig(fig, dpi=180):
    """Render ``fig`` to an in-memory PNG and decode it into an RGB image array.

    Args:
        fig: object exposing ``savefig`` (a matplotlib figure in this notebook).
        dpi: resolution passed to ``savefig``.

    Returns:
        The decoded image after a BGR -> RGB conversion.
    """
    with io.BytesIO() as png_buffer:
        fig.savefig(png_buffer, format="png", dpi=dpi)
        png_bytes = np.frombuffer(png_buffer.getvalue(), dtype=np.uint8)

    # imdecode yields BGR channel order; flag IMREAD_COLOR equals 1.
    bgr_image = cv2.imdecode(png_bytes, cv2.IMREAD_COLOR)
    return cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB)

In [None]:
# Per-stage loaders: each stage gets its own training loader plus a loader
# over the validation set. NOTE: this rebinds the global ``loaders`` to a
# nested ``{stage: {loader_name: loader}}`` mapping.
loaders = {
    stage_name: {
        train_key: DataLoader(train_set, batch_size=32, num_workers=1),
        "valid": DataLoader(v1, batch_size=32, num_workers=1),
    }
    for stage_name, train_key, train_set in (
        ("stage_1", "train_1", t1),
        ("stage_2", "train_2", t2),
    )
}

    
class VisualizationCallback(dl.Callback):
    """Callback that plots the decision boundary each epoch and logs it as an image."""

    def __init__(self):
        """Register the callback with the ``External`` callback order."""
        super().__init__(order=dl.CallbackOrder.External)

    def on_epoch_end(self, runner):
        """Draw the boundary, convert the figure to an array and hand it to the runner."""
        fig = visualize_decision_boundary(X_valid, y_valid, runner.model)
        try:
            image = get_img_from_fig(fig)
        finally:
            # The figure is no longer needed once rasterised; close it so that
            # one figure per epoch does not stay registered with pyplot.
            plt.close(fig)
        # runner will propagate it to all loggers
        runner.log_image(tag="decision_boundary", image=image, scope="epoch")


class VisualizationLogger(dl.ILogger):
    """Logger that stores epoch-scoped images as PNG files inside ``logdir``."""

    def __init__(self, logdir: str):
        """Create ``logdir`` if needed and remember it."""
        os.makedirs(logdir, exist_ok=True)
        self.logdir = logdir

    def log_image(
        self,
        tag: str,
        image: np.ndarray,
        scope: str = None,
        # experiment info
        experiment_key: str = None,
        global_epoch_step: int = 0,
        global_batch_step: int = 0,
        global_sample_step: int = 0,
        # stage info
        stage_key: str = None,
        stage_epoch_len: int = 0,
        stage_epoch_step: int = 0,
        stage_batch_step: int = 0,
        stage_sample_step: int = 0,
        # loader info
        loader_key: str = None,
        loader_batch_len: int = 0,
        loader_sample_len: int = 0,
        loader_batch_step: int = 0,
        loader_sample_step: int = 0,
    ) -> None:
        """Save ``image`` as ``<tag>_<stage_key>_<stage_epoch_step>.png``.

        Only images logged with ``scope == "epoch"`` are written; any other
        scope is ignored.
        """
        if scope != "epoch":
            return
        file_name = f"{tag}_{stage_key}_{stage_epoch_step}.png"
        plt.imsave(os.path.join(self.logdir, file_name), image)

class CustomSupervisedRunner(dl.IRunner):
    """Multi-stage runner: one stage per key of the nested global ``loaders``.

    ``handle_batch`` only runs the forward pass and stores the tensors under
    named keys in ``self.batch``; the callbacks from ``get_callbacks`` are
    configured with those key names.
    """

    def get_engine(self) -> dl.IEngine:
        """Run on the CPU."""
        engine = dl.DeviceEngine("cpu")
        return engine

    def get_loggers(self):
        """Console, image-to-disk and TensorBoard loggers, all under ``./logdir06``."""
        console_logger = dl.ConsoleLogger()
        image_logger = VisualizationLogger(logdir="./logdir06/visualization")
        tensorboard_logger = dl.TensorboardLogger(logdir="./logdir06/tb")
        return {
            "console": console_logger,
            "visualization": image_logger,
            "tensorboard": tensorboard_logger,
        }

    @property
    def stages(self) -> Iterable[str]:
        """Stage names are the top-level keys of the global ``loaders`` mapping."""
        return loaders.keys()

    def get_stage_len(self, stage: str) -> int:
        """Length of every stage."""
        return 5

    def get_loaders(self, stage: str) -> "OrderedDict[str, DataLoader]":
        """Pick the loaders registered for ``stage``."""
        return loaders[stage]

    def get_model(self, stage: str):
        """Build a fresh MLP: 2 inputs -> 16 -> 16 -> 1 output."""
        layers = [
            nn.Linear(2, 16),
            nn.ReLU(),
            nn.Linear(16, 16),
            nn.ReLU(),
            nn.Linear(16, 1),
        ]
        return nn.Sequential(*layers)

    def get_criterion(self, stage: str):
        """Binary cross-entropy on logits."""
        return nn.BCEWithLogitsLoss()

    def get_optimizer(self, stage: str, model):
        """Adam over all model parameters."""
        return torch.optim.Adam(model.parameters(), lr=0.02)

    def get_scheduler(self, stage: str, optimizer):
        """Multi-step LR schedule with milestones 2 and 4."""
        return torch.optim.lr_scheduler.MultiStepLR(optimizer, [2, 4])

    def get_callbacks(self, stage: str):
        """Assemble the callbacks, keyed by name, in registration order."""
        callbacks = {}
        callbacks["auc"] = dl.LoaderMetricCallback(
            metric=metrics.AUCMetric(),
            input_key="scores", target_key="targets",
        )
        callbacks["criterion"] = dl.CriterionCallback(
            metric_key="loss",
            input_key="logits",
            target_key="targets"
        )
        callbacks["optimizer"] = dl.OptimizerCallback(metric_key="loss")
        callbacks["scheduler"] = dl.SchedulerCallback(
            loader_key="valid", metric_key="loss"
        )
        callbacks["checkpoint1"] = dl.CheckpointCallback(
            logdir="./logdir06/auc",
            loader_key="valid", metric_key="auc",
            minimize=False, save_n_best=3
        )
        callbacks["checkpoint2"] = dl.CheckpointCallback(
            logdir="./logdir06/loss",
            loader_key="valid", metric_key="loss",
            minimize=True, save_n_best=1
        )
        callbacks["visualization"] = VisualizationCallback()
        # tqdm output is left disabled here:
        #   "verbose": TqdmCallback(),
        return callbacks

    def handle_batch(self, batch):
        """Run the forward pass and publish inputs and outputs in ``self.batch``."""
        features, targets = batch
        logits = self.model(features).view(-1)

        self.batch = {
            "features": features,
            "targets": targets,
            "logits": logits,
            "scores": torch.sigmoid(logits),
        }

runner = CustomSupervisedRunner().run()
model = runner.model

In [None]:
# List what the Act 6 run wrote under its log directory.
! ls ./logdir06

In [None]:
# Decision boundary of the model taken from the Act 6 (multistage) runner.
_ = visualize_decision_boundary(X_valid, y_valid, model)

---

### Act 7 - ``CustomRunner``

In [None]:
class CustomSupervisedRunner(dl.IRunner):
    """Two-stage runner with stage-specific loaders, optimizer, scheduler and callbacks.

    ``stage_1`` uses the ``train_1`` loader with Adam and checkpoints on the
    validation loss; ``stage_2`` uses ``train_2`` with SGD and checkpoints on
    the validation AUC. Any other stage name raises ``NotImplementedError``.
    """

    def get_engine(self) -> dl.IEngine:
        """Run on the CPU."""
        return dl.DeviceEngine("cpu")

    def get_loggers(self):
        """Console and TensorBoard loggers (TensorBoard under ``./logdir07/tb``)."""
        return {
            "console": dl.ConsoleLogger(),
            "tensorboard": dl.TensorboardLogger(logdir="./logdir07/tb"),
        }

    @property
    def seed(self) -> int:
        """Seed reported by this runner."""
        return 73

    @property
    def name(self) -> str:
        """Name reported by this runner."""
        return "experiment73"

    @property
    def stages(self) -> Iterable[str]:
        """Stage names, in the order they are listed."""
        return ["stage_1", "stage_2"]

    def get_stage_len(self, stage: str) -> int:
        """Length of every stage."""
        return 5

    def get_loaders(self, stage: str) -> "OrderedDict[str, DataLoader]":
        """Return the stage's own training loader plus a validation loader.

        Raises:
            NotImplementedError: if ``stage`` is not a known stage name.
        """
        if stage == "stage_1":
            return {
                "train_1": DataLoader(t1, batch_size=32, num_workers=1),
                "valid": DataLoader(v1, batch_size=32, num_workers=1),
            }
        elif stage == "stage_2":
            return {
                "train_2": DataLoader(t2, batch_size=32, num_workers=1),
                "valid": DataLoader(v1, batch_size=32, num_workers=1),
            }
        else:
            raise NotImplementedError(f"Unknown stage: {stage!r}")

    def get_model(self, stage: str):
        """Return the runner's existing model if there is one, else build a new MLP."""
        # Reuse an already-set model instead of re-initialising it
        # (presumably so that a later stage continues from the earlier one).
        if self.model is not None:
            return self.model
        return nn.Sequential(
            nn.Linear(2, 16), nn.ReLU(),
            nn.Linear(16, 16), nn.ReLU(),
            nn.Linear(16, 1)
        )

    def get_criterion(self, stage: str):
        """Binary cross-entropy on logits."""
        return nn.BCEWithLogitsLoss()

    def get_optimizer(self, stage: str, model):
        """Adam for ``stage_1``, SGD for ``stage_2``.

        Raises:
            NotImplementedError: if ``stage`` is not a known stage name.
        """
        if stage == "stage_1":
            return torch.optim.Adam(model.parameters(), lr=0.02)
        elif stage == "stage_2":
            return torch.optim.SGD(model.parameters(), lr=0.01)
        else:
            raise NotImplementedError(f"Unknown stage: {stage!r}")

    def get_scheduler(self, stage: str, optimizer):
        """Multi-step LR schedule with stage-specific milestones.

        Raises:
            NotImplementedError: if ``stage`` is not a known stage name.
        """
        if stage == "stage_1":
            return torch.optim.lr_scheduler.MultiStepLR(optimizer, [3, 8])
        elif stage == "stage_2":
            return torch.optim.lr_scheduler.MultiStepLR(optimizer, [3, 6])
        else:
            raise NotImplementedError(f"Unknown stage: {stage!r}")

    def get_callbacks(self, stage: str):
        """Return the callbacks for ``stage``.

        Raises:
            NotImplementedError: if ``stage`` is not a known stage name.
        """
        if stage == "stage_1":
            # First stage: optimise and checkpoint on the validation loss.
            return {
                "criterion": dl.CriterionCallback(
                    metric_key="loss",
                    input_key="logits",
                    target_key="targets"
                ),
                "optimizer": dl.OptimizerCallback(metric_key="loss"),
                "scheduler": dl.SchedulerCallback(
                    loader_key="valid", metric_key="loss"
                ),
                "checkpoint": dl.CheckpointCallback(
                    logdir="./logdir07/loss",
                    loader_key="valid", metric_key="loss",
                    minimize=True, save_n_best=3
                ),
            }
        elif stage == "stage_2":
            # Second stage: additionally compute AUC and checkpoint on it.
            return {
                "auc": dl.LoaderMetricCallback(
                    metric=metrics.AUCMetric(),
                    input_key="scores", target_key="targets",
                ),
                "criterion": dl.CriterionCallback(
                    metric_key="loss",
                    input_key="logits",
                    target_key="targets"
                ),
                "optimizer": dl.OptimizerCallback(metric_key="loss"),
                "scheduler": dl.SchedulerCallback(
                    loader_key="valid", metric_key="loss"
                ),
                "checkpoint_auc": dl.CheckpointCallback(
                    logdir="./logdir07/auc",
                    loader_key="valid", metric_key="auc",
                    minimize=False, save_n_best=3
                ),
            }
        else:
            raise NotImplementedError(f"Unknown stage: {stage!r}")

    def handle_batch(self, batch):
        """Run the forward pass and publish inputs and outputs in ``self.batch``."""
        x, y = batch
        y_hat = self.model(x)

        self.batch = {
            "features": x,
            "targets": y,
            "logits": y_hat.view(-1),
            "scores": torch.sigmoid(y_hat.view(-1)),
        }

runner = CustomSupervisedRunner().run()
model = runner.model

---

### Act 8 - integration with hyperparameter search

In [None]:
from datetime import datetime
import optuna    

def objective(trial):
    """Optuna objective: train a two-stage runner and return its best score.

    The two hidden-layer widths are sampled from the trial, a runner is
    defined around them and run, and the ``best_score`` of its
    ``OptunaPruningCallback`` (configured on the ``valid`` loader's ``auc``)
    is returned.

    Args:
        trial: the Optuna trial used to sample hyperparameters; it is also
            exposed to the runner through ``get_trial``.

    Returns:
        The pruning callback's ``best_score``.
    """
    num_epochs = 6
    # Sample integer widths directly on a log scale over [2, 16] instead of
    # truncating a float from the deprecated ``suggest_loguniform``.
    num_hidden1 = trial.suggest_int("num_hidden1", 2, 16, log=True)
    num_hidden2 = trial.suggest_int("num_hidden2", 2, 16, log=True)
    # The timestamp only has one-second resolution; the trial number keeps
    # the log directories of different trials distinct.
    logdir = (
        f"./logdir08/{datetime.now().strftime('%Y%m%d-%H%M%S')}"
        f"-trial{trial.number}"
    )

    loaders = {
        "train_1": DataLoader(t1, batch_size=32, num_workers=1),
        "train_2": DataLoader(t2, batch_size=32, num_workers=1),
        "valid": DataLoader(v1, batch_size=32, num_workers=1),
    }

    class CustomRunner(dl.IRunner):
        """Runner closed over this trial's hyperparameters, loaders and log directory."""

        def get_trial(self):
            """Expose the enclosing Optuna trial."""
            return trial

        def get_engine(self) -> dl.IEngine:
            """Run on the CPU."""
            return dl.DeviceEngine("cpu")

        def get_loggers(self):
            """Console and TensorBoard loggers (TensorBoard under the trial's log directory)."""
            return {
                "console": dl.ConsoleLogger(),
                "tensorboard": dl.TensorboardLogger(logdir=f"{logdir}/tb"),
            }

        @property
        def seed(self) -> int:
            """Seed reported by this runner."""
            return 73

        @property
        def name(self) -> str:
            """Name reported by this runner."""
            return "experiment73"

        @property
        def stages(self) -> Iterable[str]:
            """Stage names, in the order they are listed."""
            return ["stage_1", "stage_2"]

        def get_stage_len(self, stage: str) -> int:
            """Every stage lasts ``num_epochs``."""
            return num_epochs

        def get_loaders(self, stage: str) -> "OrderedDict[str, DataLoader]":
            """Use the same three loaders for every stage."""
            return loaders

        def get_model(self, stage: str):
            """Build an MLP whose hidden widths come from the trial."""
            return nn.Sequential(
                nn.Linear(2, num_hidden1), nn.ReLU(),
                nn.Linear(num_hidden1, num_hidden2), nn.ReLU(),
                nn.Linear(num_hidden2, 1)
            )

        def get_criterion(self, stage: str):
            """Binary cross-entropy on logits."""
            return nn.BCEWithLogitsLoss()

        def get_optimizer(self, stage: str, model):
            """Adam over all model parameters."""
            return torch.optim.Adam(model.parameters(), lr=0.02)

        def get_scheduler(self, stage: str, optimizer):
            """Multi-step LR schedule with milestones 3 and 6."""
            return torch.optim.lr_scheduler.MultiStepLR(optimizer, [3, 6])

        def get_callbacks(self, stage: str):
            """Metric, training, checkpoint and Optuna pruning callbacks."""
            return {
                "auc": dl.LoaderMetricCallback(
                    metric=metrics.AUCMetric(),
                    input_key="scores", target_key="targets",
                ),
                "criterion": dl.CriterionCallback(
                    metric_key="loss",
                    input_key="logits",
                    target_key="targets"
                ),
                "optimizer": dl.OptimizerCallback(metric_key="loss"),
                "scheduler": dl.SchedulerCallback(
                    loader_key="valid", metric_key="loss"
                ),
                "checkpoint": dl.CheckpointCallback(
                    logdir=f"{logdir}/auc",
                    loader_key="valid", metric_key="auc",
                    minimize=False, save_n_best=3
                ),
                "optuna": dl.OptunaPruningCallback(loader_key="valid", metric_key="auc", minimize=False)
            }

        def handle_batch(self, batch):
            """Run the forward pass and publish inputs and outputs in ``self.batch``."""
            x, y = batch
            y_hat = self.model(x)

            self.batch = {
                "features": x,
                "targets": y,
                "logits": y_hat.view(-1),
                "scores": torch.sigmoid(y_hat.view(-1)),
            }

    runner = CustomRunner()
    runner.run()
    score = runner.callbacks["optuna"].best_score

    return score

# Median pruner with no start-up trials and no warm-up steps, checked at every step.
pruner = optuna.pruners.MedianPruner(
    n_startup_trials=0,
    n_warmup_steps=0,
    interval_steps=1,
)
# The objective returns the pruning callback's best score for the "auc"
# metric (configured with minimize=False), so the study maximizes it
# (use direction="minimize" for a metric that should be minimized).
study = optuna.create_study(direction="maximize", pruner=pruner)
study.optimize(objective, n_trials=5, timeout=300)
print(study.best_value, study.best_params)

---

### Act 9 - Confusion Matrix logging - IMetric+ICallback+ILogger

In [None]:
import torch
from torch.utils.data import DataLoader, TensorDataset
from catalyst import dl, metrics, utils

# Problem size for the sample data: number of samples, features per sample
# and target classes.
num_samples, num_features, num_classes = int(1e4), int(1e1), 6
# Returned by the runner's get_stage_len below.
num_epochs = 6

class CustomSupervisedRunner(dl.IRunner):
    """Single-stage multiclass runner on random data with confusion-matrix logging.

    Each ``get_*`` hook builds one component of the experiment, and
    ``handle_batch`` runs the forward pass and publishes the tensors that the
    callbacks look up by key (``logits``, ``scores``, ``probs``, ``targets``).
    Relies on the module-level ``num_samples``, ``num_features``,
    ``num_classes`` and ``num_epochs`` constants.
    """

    def get_engine(self) -> dl.IEngine:
        """Run the experiment on CPU."""
        return dl.DeviceEngine("cpu")

    def get_loggers(self):
        """Log to the console, a CSV file and TensorBoard under ``./logdir09``."""
        return {
            "console": dl.ConsoleLogger(),
            "csv": dl.CSVLogger(logdir="./logdir09"),
            "tensorboard": dl.TensorboardLogger(logdir="./logdir09/tb"),
        }

    @property
    def stages(self) -> Iterable[str]:
        """Names of the stages to run: a single ``"train"`` stage."""
        return ["train"]

    def get_stage_len(self, stage: str) -> int:
        """Number of epochs in ``stage`` (the same for every stage)."""
        return num_epochs

    def get_loaders(self, stage: str) -> "OrderedDict[str, DataLoader]":
        """Build random features / integer class labels and reuse one loader for train and valid."""
        # The data sizes come from the module-level constants so that the data
        # always matches the model built in ``get_model``.
        X = torch.rand(num_samples, num_features)
        # uniform [0, 1) scaled by num_classes and truncated -> ids in [0, num_classes)
        y = (torch.rand(num_samples, ) * num_classes).to(torch.int64)

        # pytorch loaders
        dataset = TensorDataset(X, y)
        loader = DataLoader(dataset, batch_size=32, num_workers=1)
        loaders = {"train": loader, "valid": loader}
        return loaders

    def get_model(self, stage: str):
        """Linear classifier mapping ``num_features`` inputs to ``num_classes`` logits."""
        return torch.nn.Linear(num_features, num_classes)

    def get_criterion(self, stage: str):
        """Cross-entropy loss, applied to the raw logits by the criterion callback."""
        return torch.nn.CrossEntropyLoss()

    def get_optimizer(self, stage: str, model):
        """Adam with default hyperparameters over all model parameters."""
        return torch.optim.Adam(model.parameters())

    def get_scheduler(self, stage: str, optimizer):
        """Multi-step LR schedule with a single milestone at epoch 2."""
        return torch.optim.lr_scheduler.MultiStepLR(optimizer, [2])

    def get_callbacks(self, stage: str):
        """Metrics, optimization, checkpointing and confusion-matrix callbacks."""
        return {
            "accuracy": dl.BatchMetricCallback(
                metric=metrics.AccuracyMetric(num_classes=num_classes),
                input_key="probs", target_key="targets",
            ),
            "auc": dl.LoaderMetricCallback(
                metric=metrics.AUCMetric(),
                input_key="scores", target_key="targets",
            ),
            "criterion": dl.CriterionCallback(
                metric_key="loss",
                input_key="logits",
                target_key="targets",
            ),
            "optimizer": dl.OptimizerCallback(metric_key="loss"),
            "scheduler": dl.SchedulerCallback(
                loader_key="valid", metric_key="loss"
            ),
            # Checkpoint direction must match the metric: lower is better for
            # the loss, higher is better for AUC and accuracy.
            "checkpoint1": dl.CheckpointCallback(
                logdir="./logdir09/loss",
                loader_key="valid", metric_key="loss",
                minimize=True, save_n_best=3
            ),
            "checkpoint2": dl.CheckpointCallback(
                logdir="./logdir09/auc",
                loader_key="valid", metric_key="auc",
                minimize=False, save_n_best=1
            ),
            "checkpoint3": dl.CheckpointCallback(
                logdir="./logdir09/accuracy",
                loader_key="valid", metric_key="accuracy",
                minimize=False, save_n_best=1
            ),
            "verbose": dl.TqdmCallback(),
            "confusion_matrix": dl.ConfusionMatrixCallback(
                input_key="probs",
                target_key="targets",
                prefix="confusion_matrix",
                num_classes=num_classes,
            )
        }

    def handle_batch(self, batch):
        """Forward pass; stores inputs, targets and three views of the model output in ``self.batch``."""
        x, y = batch
        y_hat = self.model(x)

        self.batch = {
            "features": x,
            "targets": y,
            "logits": y_hat,                       # raw outputs, read by the criterion callback
            "scores": torch.sigmoid(y_hat),        # read by the AUC callback
            "probs": torch.softmax(y_hat, dim=1),  # read by accuracy and confusion matrix
        }

# Run the whole experiment; the result of `run()` is used as the runner object below
# NOTE(review): this relies on `IRunner.run()` returning the runner itself — confirm for the installed Catalyst version
runner = CustomSupervisedRunner().run()
# keep a handle on the model held by the runner after the run
model = runner.model

---

In [None]:
### Act 10 - @TODO

---

## Stage 2: PythonAPI is all u need
- 10 minimal examples with different Catalyst user-friendly PythonAPI usecases

In [None]:
# let's start minimal examples section
from catalyst import dl, metrics, utils

### Act 11 - ML - linear regression

In [None]:
import torch
from torch.utils.data import DataLoader, TensorDataset
from catalyst import dl


# data
num_samples, num_features = int(1e4), int(1e1)
# Targets are shaped (num_samples, 1) so they match the (batch, 1) output of the
# linear model; with 1-D targets MSELoss would broadcast (batch, 1) against
# (batch,) into a (batch, batch) matrix and silently optimize the wrong quantity.
X, y = torch.rand(num_samples, num_features), torch.rand(num_samples, 1)
dataset = TensorDataset(X, y)
loader = DataLoader(dataset, batch_size=32, num_workers=1)
# the same loader is reused for both training and validation
loaders = {"train": loader, "valid": loader}

# model, criterion, optimizer, scheduler
model = torch.nn.Linear(num_features, 1)
criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters())
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [3, 6])

# model training
runner = dl.SupervisedRunner(
    input_key="features", output_key="logits", target_key="targets"
)
runner.train(
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheduler,
    loaders=loaders,
    logdir="./logdir11",
    num_epochs=8,
    verbose=True,
    # the best epoch is the one with the lowest loss on the "valid" loader
    valid_loader="valid",
    valid_metric="loss",
    minimize_valid_metric=True,
)

---

### Act 12 - ML - multiclass classification

In [None]:
import torch
from torch.utils.data import DataLoader, TensorDataset
from catalyst import dl, metrics, utils

# sample data: random features and integer class ids in [0, num_classes)
num_samples, num_features, num_classes = int(1e4), int(1e1), 4
X = torch.rand(num_samples, num_features)
y = (torch.rand(num_samples, ) * num_classes).to(torch.int64)

# pytorch loaders (the same loader is reused for training and validation)
dataset = TensorDataset(X, y)
loader = DataLoader(dataset, batch_size=32, num_workers=1)
loaders = {"train": loader, "valid": loader}

# model, criterion, optimizer, scheduler
model = torch.nn.Linear(num_features, num_classes)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters())
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [2])

# model training
runner = dl.SupervisedRunner(
    input_key="features", output_key="logits", target_key="targets"
)
runner.train(
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheduler,
    loaders=loaders,
    logdir="./logdir12",
    num_epochs=6,
    verbose=True,
    # the best epoch is the one with the lowest loss on the "valid" loader
    valid_loader="valid",
    valid_metric="loss",
    minimize_valid_metric=True,
    callbacks=[dl.AccuracyCallback(input_key="logits", target_key="targets", num_classes=num_classes)]
    # Disabled alternative: per-class precision/recall/F1/support instead of accuracy.
#     callbacks={
#         "classification": dl.BatchMetricCallback(
#             metric=metrics.MulticlassPrecisionRecallF1SupportMetric(num_classes=num_classes),
#             input_key="logits", target_key="targets", 
#         ),
#     },
)

----

### Act 13 - ML - multilabel classification

In [None]:
import torch
from torch.utils.data import DataLoader, TensorDataset
# `metrics` is used by the classification callback below, so import it here
# instead of relying on a previous cell having done so.
from catalyst import dl, metrics

# sample data: random features and independent 0/1 labels per class
num_samples, num_features, num_classes = int(1e4), int(1e1), 4
X = torch.rand(num_samples, num_features)
y = (torch.rand(num_samples, num_classes) > 0.5).to(torch.float32)

# pytorch loaders (the same loader is reused for training and validation)
dataset = TensorDataset(X, y)
loader = DataLoader(dataset, batch_size=32, num_workers=1)
loaders = {"train": loader, "valid": loader}

# model, criterion, optimizer, scheduler
model = torch.nn.Linear(num_features, num_classes)
criterion = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters())
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [2])

# model training
runner = dl.SupervisedRunner(
    input_key="features", output_key="logits", target_key="targets"
)
runner.train(
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheduler,
    loaders=loaders,
    # per-act log directory, like the other acts, so runs do not overwrite each other
    logdir="./logdir13",
    num_epochs=3,
    # the best epoch is the one with the lowest loss on the "valid" loader
    valid_loader="valid",
    valid_metric="loss",
    minimize_valid_metric=True,
    callbacks={
        "classification": dl.BatchMetricCallback(
            metric=metrics.MultilabelPrecisionRecallF1SupportMetric(num_classes=num_classes),
            input_key="logits", target_key="targets", 
        ),
    },
)

---

In [None]:
### Act 14 - CV - MNIST classification

In [None]:
### Act 15 - CV - classification with AutoEncoder

In [None]:
### Act 16 - CV - classification with Variational AutoEncoder

### Act 17 - CV - segmentation with classification auxiliary task

In [None]:
import os
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from catalyst import dl, metrics
from catalyst.data.transforms import ToTensor
from catalyst.contrib.datasets import MNIST

class ClassifyUnet(nn.Module):
    """Small two-headed network: a shared conv encoder feeding a linear
    classifier and a conv decoder that maps back to the input shape.

    Args:
        in_channels: number of channels of the input (and of every conv layer).
        in_hw: spatial side length of the (square) input.
        out_features: number of classifier outputs.
    """

    def __init__(self, in_channels, in_hw, out_features):
        super().__init__()
        # 3x3 convolutions with stride 1 and padding 1 keep the spatial size
        conv_kwargs = dict(kernel_size=3, stride=1, padding=1)
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, **conv_kwargs),
            nn.Tanh(),
        )
        self.decoder = nn.Conv2d(in_channels, in_channels, **conv_kwargs)
        flat_features = in_channels * in_hw * in_hw
        self.clf = nn.Linear(flat_features, out_features)

    def forward(self, x):
        """Return ``(class_logits, decoded_map)`` for the batch ``x``."""
        hidden = self.encoder(x)
        # flatten everything but the batch dimension for the linear head
        logits = self.clf(hidden.view(hidden.size(0), -1))
        decoded = self.decoder(hidden)
        return logits, decoded

# 1 input channel, 28x28 images, 10 classes
model = ClassifyUnet(1, 28, 10)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.02)

# MNIST train/test splits, downloaded into the current working directory
loaders = {
    "train": DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=ToTensor()), batch_size=32),
    "valid": DataLoader(MNIST(os.getcwd(), train=False, download=True, transform=ToTensor()), batch_size=32),
}

class CustomRunner(dl.Runner):
    """Runner that feeds noisy images to the two-headed model and exposes
    separate classification / segmentation keys for the callbacks."""

    def handle_batch(self, batch):
        """Corrupt the input with uniform noise, run the model and publish its outputs."""
        images, labels = batch
        # additive uniform noise, clipped back to the [0, 1] range
        noisy_images = torch.clamp(images + torch.rand_like(images), 0, 1)
        class_logits, decoded = self.model(noisy_images)

        # the clean images serve as the segmentation targets
        self.batch = dict(
            clf_targets=labels,
            seg_targets=images,
            clf_logits=class_logits,
            seg_logits=decoded,
        )


runner = CustomRunner()
runner.train(
    loaders=loaders, 
    model=model, 
    criterion=criterion,
    optimizer=optimizer, 
    # per-act log directory (this is Act 17), matching the numbering used by the other acts
    logdir="./logdir17",
    num_epochs=6,
    verbose=True,
    # the best epoch is the one with the lowest loss on the "valid" loader
    valid_loader="valid",
    valid_metric="loss",
    minimize_valid_metric=True,
    callbacks={
        # metrics for the classification head
        "classification": dl.BatchMetricCallback(
            metric=metrics.MulticlassPrecisionRecallF1SupportMetric(num_classes=10),
            input_key="clf_logits", target_key="clf_targets", 
        ),
        # metrics for the decoder head
        "segmentation": dl.BatchMetricCallback(
            metric=metrics.IOUMetric(),
            input_key="seg_logits", target_key="seg_targets", 
        ),
        # only the classification output contributes to the optimized loss
        "criterion": dl.CriterionCallback(
            metric_key="loss", 
            input_key="clf_logits", 
            target_key="clf_targets",
        ), 
        "optimizer": dl.OptimizerCallback(metric_key="loss"), 
    },
)

---

### Act 18 - CV - MNIST with Metric Learning

In [None]:
# from torch.optim import Adam
# from torch.utils.data import DataLoader

# from catalyst import data, dl, utils
# from catalyst.contrib import datasets, models, nn
# import catalyst.contrib.data.cv.transforms.torch as t


# # 1. train and valid datasets
# dataset_root = "."
# transforms = t.Compose([t.ToTensor(), t.Normalize((0.1307,), (0.3081,))])

# dataset_train = datasets.MnistMLDataset(root=dataset_root, download=True, transform=transforms)
# sampler = data.BalanceBatchSampler(labels=dataset_train.get_labels(), p=5, k=10)
# train_loader = DataLoader(dataset=dataset_train, sampler=sampler, batch_size=sampler.batch_size)

# dataset_val = datasets.MnistQGDataset(root=dataset_root, transform=transforms, gallery_fraq=0.2)
# val_loader = DataLoader(dataset=dataset_val, batch_size=1024)

# # 2. model and optimizer
# model = models.SimpleConv(features_dim=16)
# optimizer = Adam(model.parameters(), lr=0.001)

# # 3. criterion with triplets sampling
# sampler_inbatch = data.HardTripletsSampler(norm_required=False)
# criterion = nn.TripletMarginLossWithSampler(margin=0.5, sampler_inbatch=sampler_inbatch)

# # 4. training with catalyst Runner
# callbacks = [
#     dl.ControlFlowCallback(
#         dl.CriterionCallback(input_key="logits", target_key="targets", metric_key="loss"), 
#         loaders="train"
#     ),
#     dl.ControlFlowCallback(dl.CMCScoreCallback(topk_args=[1]), loaders="valid"),
#     dl.PeriodicLoaderCallback(valid=100),
# ]

# runner = dl.SupervisedRunner(
#     input_key="features", output_key="logits", target_key="targets"
# )
# runner.train(
#     model=model,
#     criterion=criterion,
#     optimizer=optimizer,
#     callbacks=callbacks,
#     loaders={"train": train_loader, "valid": val_loader},
#     minimize_metric=False,
#     verbose=True,
#     valid_loader="valid",
#     num_epochs=200,
#     main_metric="cmc01",
# )   

---

### Act 19 - GAN - MNIST, flatten version

In [None]:
import os
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from catalyst import dl
from catalyst.data.transforms import ToTensor
from catalyst.contrib.datasets import MNIST
from catalyst.contrib.nn.modules import Flatten, GlobalMaxPool2d, Lambda

# size of the random vector fed to the generator (also used by CustomRunner.handle_batch)
latent_dim = 128
generator = nn.Sequential(
    # Project the latent vector to 128 * 7 * 7 coefficients and reshape them into a 128x7x7 map.
    # The input size follows `latent_dim` so that changing it keeps the generator
    # compatible with the noise sampled in `handle_batch`.
    nn.Linear(latent_dim, 128 * 7 * 7),
    nn.LeakyReLU(0.2, inplace=True),
    Lambda(lambda x: x.view(x.size(0), 128, 7, 7)),
    # two stride-2 transposed convolutions: 7x7 -> 14x14 -> 28x28
    nn.ConvTranspose2d(128, 128, (4, 4), stride=(2, 2), padding=1),
    nn.LeakyReLU(0.2, inplace=True),
    nn.ConvTranspose2d(128, 128, (4, 4), stride=(2, 2), padding=1),
    nn.LeakyReLU(0.2, inplace=True),
    # single-channel image with values in (0, 1)
    nn.Conv2d(128, 1, (7, 7), padding=3),
    nn.Sigmoid(),
)
discriminator = nn.Sequential(
    nn.Conv2d(1, 64, (3, 3), stride=(2, 2), padding=1),
    nn.LeakyReLU(0.2, inplace=True),
    nn.Conv2d(64, 128, (3, 3), stride=(2, 2), padding=1),
    nn.LeakyReLU(0.2, inplace=True),
    GlobalMaxPool2d(),
    Flatten(),
    # one raw logit per image (the losses use binary_cross_entropy_with_logits)
    nn.Linear(128, 1)
)

# models and optimizers are passed to the runner as dicts with matching keys
model = {"generator": generator, "discriminator": discriminator}
optimizer = {
    "generator": torch.optim.Adam(generator.parameters(), lr=0.0003, betas=(0.5, 0.999)),
    "discriminator": torch.optim.Adam(discriminator.parameters(), lr=0.0003, betas=(0.5, 0.999)),
}
# MNIST train split, downloaded into the current working directory
loaders = {
    "train": DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=ToTensor()), batch_size=32),
}

class CustomRunner(dl.Runner):
    """GAN runner: computes the discriminator and the generator loss for each batch.

    Expects ``self.model`` to be a mapping with ``"generator"`` and
    ``"discriminator"`` entries and reads the module-level ``latent_dim``.
    """

    def handle_batch(self, batch):
        """Compute both losses and store them in ``self.batch_metrics``."""
        real_images, _unused_targets = batch
        generator = self.model["generator"]
        discriminator = self.model["discriminator"]
        num_real = real_images.shape[0]

        # --- discriminator loss ---
        # fake images are detached: this loss does not backpropagate into the generator
        noise = torch.randn(num_real, latent_dim).to(self.device)
        fake_images = generator(noise).detach()
        disc_inputs = torch.cat([fake_images, real_images])
        # label convention: 1 for generated images, 0 for real ones
        disc_targets = torch.cat([
            torch.ones((num_real, 1)), torch.zeros((num_real, 1))
        ]).to(self.device)
        # add a little random noise to the labels
        disc_targets += 0.05 * torch.rand(disc_targets.shape).to(self.device)
        disc_loss = F.binary_cross_entropy_with_logits(
            discriminator(disc_inputs), disc_targets
        )

        # --- generator loss ---
        # fresh noise, labelled 0 ("real" under the convention above), so the
        # generator is rewarded when the discriminator takes its output for real
        noise = torch.randn(num_real, latent_dim).to(self.device)
        gen_targets = torch.zeros((num_real, 1)).to(self.device)
        gen_loss = F.binary_cross_entropy_with_logits(
            discriminator(generator(noise)), gen_targets
        )

        self.batch_metrics.update(
            loss_discriminator=disc_loss, loss_generator=gen_loss
        )

runner = CustomRunner()
runner.train(
    model=model, 
    optimizer=optimizer,
    loaders=loaders,
    # one optimizer callback per network: each pairs a model/optimizer key
    # with the loss that `handle_batch` stores under `metric_key`
    callbacks=[
        dl.OptimizerCallback(
            model_key="generator",
            optimizer_key="generator", 
            metric_key="loss_generator"
        ),
        dl.OptimizerCallback(
            model_key="discriminator", 
            optimizer_key="discriminator", 
            metric_key="loss_discriminator"
        ),
    ],
    # validation-based model selection and on-disk logging are left disabled here
#     valid_loader="train",
#     valid_metric="loss_generator",
#     minimize_valid_metric=True,
    num_epochs=1,
    verbose=True,
#     logdir="./logdir19",
)

---

### Act 20 - AutoML - hyperparameters optimization with Optuna

In [None]:
import os
import optuna
import torch
from torch import nn
from torch.utils.data import DataLoader
from catalyst import dl
from catalyst.data.transforms import ToTensor
from catalyst.contrib.datasets import MNIST
from catalyst.contrib.nn import Flatten
    

def objective(trial):
    """Optuna objective: train a small MNIST classifier and return its best valid accuracy.

    Args:
        trial: the Optuna trial used to sample ``lr`` and ``num_hidden`` and to
            report intermediate values for pruning.

    Returns:
        The best score tracked by the ``"optuna"`` pruning callback
        (``accuracy01`` on the ``valid`` loader, higher is better).
    """
    # log-uniform search; `suggest_float(..., log=True)` replaces the deprecated
    # `suggest_loguniform`, and `suggest_int` records the hidden size that is
    # actually used (an int) instead of a float that was truncated afterwards
    lr = trial.suggest_float("lr", 1e-3, 1e-1, log=True)
    num_hidden = trial.suggest_int("num_hidden", 32, 128, log=True)

    loaders = {
        "train": DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=ToTensor()), batch_size=32),
        "valid": DataLoader(MNIST(os.getcwd(), train=False, download=True, transform=ToTensor()), batch_size=32),
    }
    # 28 * 28 = 784 flattened pixels -> hidden layer -> 10 classes
    model = nn.Sequential(
        Flatten(), nn.Linear(784, num_hidden), nn.ReLU(), nn.Linear(num_hidden, 10)
    )
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    runner = dl.SupervisedRunner(
        input_key="features", output_key="logits", target_key="targets"
    )
    runner.train(
        model=model,
        loaders=loaders,
        criterion=criterion,
        optimizer=optimizer,
        callbacks={
            # the pruning callback watches the metric produced by the accuracy callback
            "optuna": dl.OptunaPruningCallback(loader_key="valid", metric_key="accuracy01", minimize=False, trial=trial),
            "accuracy": dl.AccuracyCallback(input_key="logits", target_key="targets", num_classes=10),
        },
        num_epochs=3,
    )
    score = runner.callbacks["optuna"].best_score
    return score

# `objective` returns a valid-accuracy score (minimize=False in its pruning callback), so maximize it
study = optuna.create_study(
    direction="maximize",
    pruner=optuna.pruners.MedianPruner(
        # one startup trial and no warm-up steps are configured before pruning
        n_startup_trials=1, n_warmup_steps=0, interval_steps=1
    ),
)
# at most 3 trials; `timeout` is presumably a wall-clock budget in seconds — see Optuna docs
study.optimize(objective, n_trials=3, timeout=300)
print(study.best_value, study.best_params)

----

## Stage 3: Offpolicy Reinforcement Learning

In [None]:
from collections import deque, namedtuple
import typing as tp
import random

import numpy as np
import gym

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.data.dataset import IterableDataset

from catalyst import dl, metrics, utils

In [None]:
# One environment step as stored in the replay buffer.
Transition = namedtuple(
    "Transition",
    ["state", "action", "reward", "done", "next_state"],
)


class ReplayBuffer:
    """Fixed-capacity FIFO store of transitions with uniform random sampling.

    Once ``capacity`` transitions are stored, appending a new one drops the
    oldest (``deque(maxlen=capacity)``).
    """

    def __init__(self, capacity: int):
        self.buffer = deque(maxlen=capacity)

    def append(self, transition: "Transition"):
        """Store one ``(state, action, reward, done, next_state)`` transition."""
        self.buffer.append(transition)

    def sample(self, size: int) -> tp.Sequence[np.ndarray]:
        """Draw ``size`` random transitions and return them as five stacked arrays.

        Sampling is without replacement unless ``size`` exceeds the number of
        stored transitions, in which case it falls back to sampling with
        replacement.

        Returns:
            ``(states, actions, rewards, dones, next_states)`` with dtypes
            float32, int64, float32, bool and float32 respectively.
        """
        indices = np.random.choice(
            len(self.buffer),
            size,
            replace=size > len(self.buffer)
        )
        states, actions, rewards, dones, next_states = \
            zip(*[self.buffer[idx] for idx in indices])
        states, actions, rewards, dones, next_states = (
            np.array(states, dtype=np.float32),
            np.array(actions, dtype=np.int64),
            np.array(rewards, dtype=np.float32),
            # the builtin `bool`: the `np.bool` alias was deprecated in
            # NumPy 1.20 and removed in 1.24
            np.array(dones, dtype=bool),
            np.array(next_states, dtype=np.float32)
        )
        return states, actions, rewards, dones, next_states

    def __len__(self) -> int:
        """Number of transitions currently stored."""
        return len(self.buffer)


# RL has no predefined dataset,
# so we have to specify the epoch length ourselves
class ReplayDataset(IterableDataset):
    """Iterable dataset that serves one epoch worth of transitions sampled from a replay buffer."""

    def __init__(self, buffer: ReplayBuffer, epoch_size: int = int(1e3)):
        self.buffer, self.epoch_size = buffer, epoch_size

    def __iter__(self) -> tp.Iterator[tp.Sequence[np.array]]:
        """Sample ``epoch_size`` transitions once, then yield them one by one
        as ``(state, action, reward, done, next_state)`` tuples."""
        columns = self.buffer.sample(self.epoch_size)
        # the five arrays have the same length, so zip walks them row by row
        yield from zip(*columns)

    def __len__(self) -> int:
        """Number of samples served per epoch."""
        return self.epoch_size
    
    
def soft_update(target: nn.Module, source: nn.Module, tau: float):
    """Blend ``source`` parameters into ``target`` in place.

    Each target parameter becomes ``(1 - tau) * target + tau * source``;
    ``tau=1.0`` therefore copies ``source`` into ``target``.
    """
    keep = 1.0 - tau
    for dst, src in zip(target.parameters(), source.parameters()):
        blended = dst.data * keep + src.data * tau
        dst.data.copy_(blended)

### Act 21 - DQN

In [None]:
def get_action(
    env,
    network: nn.Module,
    state: np.array,
    epsilon: float = -1
) ->  int:
    """Epsilon-greedy action selection.

    With probability ``epsilon`` a random action is sampled from
    ``env.action_space``; otherwise the action with the highest network output
    for ``state`` is chosen. The default ``epsilon=-1`` is always greedy.
    """
    explore = np.random.random() < epsilon
    if explore:
        return int(env.action_space.sample())

    # add a leading batch dimension, run the network, take the single row back
    batch = torch.tensor(state[None], dtype=torch.float32)
    q_values = network(batch).detach().cpu().numpy()[0]
    return int(np.argmax(q_values))


def generate_session(
    env,
    network: nn.Module,
    t_max: int = 1000,
    epsilon: float = -1,
    replay_buffer: tp.Optional[ReplayBuffer] = None,
) -> tp.Tuple[float, int]:
    """Play one episode with an epsilon-greedy policy.

    Args:
        env: environment exposing ``reset()`` and ``step(action)``.
        network: network passed to ``get_action`` to pick actions.
        t_max: maximum number of steps to play.
        epsilon: exploration probability forwarded to ``get_action``.
        replay_buffer: if given, every transition is appended to it.

    Returns:
        ``(total_reward, num_steps)`` where ``num_steps`` is the number of
        environment steps actually taken (0 when ``t_max`` is 0).
    """
    total_reward = 0
    # count the steps explicitly: the loop index is one less than the number
    # of steps taken and is undefined when the loop body never runs
    num_steps = 0
    state = env.reset()

    for _step in range(t_max):
        action = get_action(env, network, state=state, epsilon=epsilon)
        next_state, reward, done, _ = env.step(action)
        num_steps += 1

        if replay_buffer is not None:
            transition = Transition(
                state, action, reward, done, next_state)
            replay_buffer.append(transition)

        total_reward += reward
        state = next_state
        if done:
            break

    return total_reward, num_steps

def generate_sessions(
    env,
    network: nn.Module,
    t_max: int = 1000,
    epsilon:float = -1,
    replay_buffer: ReplayBuffer = None,
    num_sessions: int = 100,
) -> tp.Tuple[float, int]:
    """Play ``num_sessions`` episodes and return the summed reward and the
    summed step value reported by ``generate_session``."""
    session_kwargs = dict(
        env=env,
        network=network,
        t_max=t_max,
        epsilon=epsilon,
        replay_buffer=replay_buffer,
    )
    reward_sum, steps_sum = 0, 0
    for _ in range(num_sessions):
        reward, steps = generate_session(**session_kwargs)
        reward_sum = reward_sum + reward
        steps_sum = steps_sum + steps
    return reward_sum, steps_sum

In [None]:
class GameCallback(dl.Callback):
    """Callback that plays the environment to fill the replay buffer and to evaluate the actor.

    Args:
        env: environment to play in.
        replay_buffer: buffer that receives the collected transitions.
        session_period: play one exploration session every this many batches.
        epsilon: initial exploration probability.
        epsilon_k: multiplicative factor applied to ``epsilon`` at every epoch start.
        actor_key: key of the acting network inside ``runner.model``.
    """

    def __init__(
        self, 
        *, 
        env, 
        replay_buffer: ReplayBuffer,
        session_period: int,
        epsilon: float,
        epsilon_k: float,
        actor_key,
    ):
        super().__init__(order=0)
        self.env = env
        self.replay_buffer = replay_buffer
        self.session_period = session_period
        self.epsilon = epsilon
        self.epsilon_k = epsilon_k
        self.actor_key = actor_key
        self._initialized = False

    def on_epoch_start(self, runner: dl.IRunner):
        """Decay epsilon, reset the per-epoch counters and, on the first epoch
        only, pre-fill the replay buffer with 1000 sessions."""
        self.epsilon *= self.epsilon_k
        self.session_counter = 0
        self.session_steps = 0

        if self._initialized:
            return

        self.actor = runner.model[self.actor_key]

        self.actor.eval()
        generate_sessions(
            env=self.env, 
            network=self.actor,
            epsilon=self.epsilon,
            replay_buffer=self.replay_buffer,
            num_sessions=1000,
        )
        self.actor.train()
        self._initialized = True

    def on_batch_end(self, runner: dl.IRunner):
        """Every ``session_period`` batches, play one exploration session and log its reward/steps."""
        if runner.global_batch_step % self.session_period == 0:
            self.actor.eval()

            session_reward, session_steps = generate_session(
                env=self.env, 
                network=self.actor,
                epsilon=self.epsilon,
                replay_buffer=self.replay_buffer
            )

            self.session_counter += 1
            self.session_steps += session_steps

            runner.batch_metrics.update({"s_reward": session_reward})
            runner.batch_metrics.update({"s_steps": session_steps})

            self.actor.train()

    def on_epoch_end(self, runner: dl.IRunner):
        """Evaluate the actor over 100 sessions (default epsilon, no buffer
        writes) and store the epoch-level metrics under ``"_epoch_"``."""
        num_sessions = 100

        self.actor.eval()
        valid_rewards, valid_steps = generate_sessions(
            env=self.env, 
            network=self.actor,
            num_sessions=num_sessions
        )
        self.actor.train()

        valid_rewards /= num_sessions
        runner.epoch_metrics["_epoch_"]["num_samples"] = self.session_steps
        # guard against an epoch in which no exploration steps were recorded,
        # which would otherwise raise ZeroDivisionError
        runner.epoch_metrics["_epoch_"]["updates_per_sample"] = \
            runner.loader_sample_step / max(self.session_steps, 1)
        runner.epoch_metrics["_epoch_"]["v_reward"] = valid_rewards
        runner.epoch_metrics["_epoch_"]["epsilon"] = self.epsilon

In [None]:
def get_network(env, num_hidden: int = 128):
    """Build a network for ``env``: a two-layer ReLU body followed by a
    linear head with one output per discrete action."""
    num_inputs = env.observation_space.shape[0]
    num_actions = env.action_space.n

    body = torch.nn.Sequential(
        nn.Linear(num_inputs, num_hidden),
        nn.ReLU(),
        nn.Linear(num_hidden, num_hidden),
        nn.ReLU(),
    )
    head = nn.Linear(num_hidden, num_actions)

    # body and head use different initialization schemes
    body.apply(utils.get_optimal_inner_init(nn.ReLU))
    head.apply(utils.outer_init)

    return torch.nn.Sequential(body, head)


class CustomRunner(dl.Runner):
    """DQN runner: trains the origin network against a slowly updated target network.

    Args:
        gamma: discount factor applied to the next-state value.
        tau: smoothing coefficient of the target-network soft update.
        tau_period: soft-update period, in batches.
        origin_key: key of the trained network inside ``self.model``.
        target_key: key of the target network inside ``self.model``.
        **kwargs: forwarded to ``dl.Runner``.
    """

    def __init__(
        self, 
        *, 
        gamma: float,
        tau: float,
        tau_period: int = 1,
        origin_key: str = "origin",
        target_key: str = "target",
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.gamma: float = gamma
        self.tau: float = tau
        self.tau_period: int = tau_period
        self.origin_key: str = origin_key
        self.target_key: str = target_key
        # both networks are resolved from ``self.model`` in ``on_stage_start``
        self.origin_network: nn.Module = None
        self.target_network: nn.Module = None
        self._initialized = False

    def on_stage_start(self, runner: dl.IRunner):
        """Resolve both networks and copy the origin weights into the target."""
        super().on_stage_start(runner)
        # NOTE(review): `_initialized` is never set to True in this class, so
        # this guard does not currently short-circuit — confirm the intent.
        if self._initialized:
            return
        self.origin_network = self.model[self.origin_key]
        self.target_network = self.model[self.target_key]
        # tau=1.0 makes the soft update a plain copy
        soft_update(self.target_network, self.origin_network, 1.0)

    def handle_batch(self, batch: tp.Sequence[np.array]):
        """Compute the TD loss for one batch of transitions and, on train
        loaders, take an optimizer step and periodically update the target."""
        states, actions, rewards, dones, next_states = batch
        online_net = self.origin_network
        frozen_net = self.target_network

        # Q(s, a) of the online network for the actions that were taken
        q_all = online_net(states)
        q_taken = q_all.gather(1, actions.unsqueeze(-1)).squeeze(-1)

        # V(s') = max_a Q_target(s', a); terminal states have no successor,
        # so their value is zero and the target reduces to the reward
        with torch.no_grad():
            next_values = frozen_net(next_states).max(1)[0]
            next_values[dones] = 0.0
            next_values = next_values.detach()

        # TD target: r + gamma * V(s')
        td_target = next_values * self.gamma + rewards

        loss = self.criterion(q_taken, td_target.detach())
        self.batch_metrics.update({"loss": loss})

        if not self.is_train_loader:
            return

        loss.backward()
        self.optimizer.step()
        self.optimizer.zero_grad()

        if self.global_batch_step % self.tau_period == 0:
            soft_update(frozen_net, online_net, self.tau)


In [None]:
batch_size = 64
# one epoch = 1000 batches worth of sampled transitions
epoch_size = int(1e3) * batch_size
# maximum number of transitions kept in the replay buffer
buffer_size = int(1e5)
# runner settings, ~training
gamma = 0.99
tau = 0.01
tau_period = 1 # in batches
# callback, ~exploration
session_period = 100 # in batches
epsilon = 0.98
# epsilon is multiplied by this factor at the start of every epoch (see GameCallback)
epsilon_k = 0.9
# optimization
lr = 3e-4

# env_name = "LunarLander-v2"
env_name = "CartPole-v1"
env = gym.make(env_name)
replay_buffer = ReplayBuffer(buffer_size)

# two identically shaped networks; only `network` is passed to the optimizer,
# the target network has its gradients disabled and is updated via soft_update
network, target_network = get_network(env), get_network(env)
utils.set_requires_grad(target_network, requires_grad=False)
# keys match the CustomRunner defaults origin_key="origin" / target_key="target"
models = {"origin": network, "target": target_network}
criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(network.parameters(), lr=lr)
loaders = {
    "train_game": DataLoader(
        ReplayDataset(replay_buffer, epoch_size=epoch_size), 
        batch_size=batch_size,
    ),
}

runner = CustomRunner(gamma=gamma, tau=tau, tau_period=tau_period)
runner.train(
    model=models,
    criterion=criterion,
    optimizer=optimizer,
    loaders=loaders,
    logdir="./logs_dqn",
    num_epochs=10,
    verbose=True,
    # model selection uses the "v_reward" value that GameCallback.on_epoch_end
    # writes into runner.epoch_metrics["_epoch_"]; higher is better
    valid_loader="_epoch_",
    valid_metric="v_reward",
    minimize_valid_metric=False,
    load_best_on_end=True,
    callbacks=[
        GameCallback(
            env=env, 
            replay_buffer=replay_buffer, 
            session_period=session_period,
            epsilon=epsilon,
            epsilon_k=epsilon_k,
            actor_key="origin",
        )
    ]
)

In [None]:
# record sessions
import gym.wrappers


# wrap a fresh environment in a Monitor that writes into "videos_dqn"
# (force=True: presumably allows reusing an existing directory — see gym docs)
env = gym.wrappers.Monitor(
    gym.make(env_name),
    directory="videos_dqn", 
    force=True)
# play 100 sessions with the trained network (default epsilon, no replay buffer)
generate_sessions(
    env=env, 
    network=runner.model["origin"],
    num_sessions=100
)
env.close()

In [None]:
# show video
from IPython.display import HTML
import os

# directory the Monitor wrapper recorded into; the <video> source must point
# to this same directory, not to "./videos/"
video_dir = "./videos_dqn/"
# os.listdir order is arbitrary, so sort to make the selected video deterministic
video_names = sorted(
    filter(lambda s: s.endswith(".mp4"), os.listdir(video_dir)))

HTML("""
<video width="640" height="480" controls>
  <source src="{}" type="video/mp4">
</video>
""".format(video_dir + video_names[-1]))  # last name in sorted order. Try other indices

----

### Act 22 - DDPG

In [None]:
class NormalizedActions(gym.ActionWrapper):
    """Action wrapper that lets the agent act in [-1, 1] and rescales to the
    bounds of the wrapped environment's action space."""

    def action(self, action: np.ndarray) -> np.ndarray:
        """Map an action from [-1, 1] to ``[action_space.low, action_space.high]``."""
        low_bound   = self.action_space.low
        upper_bound = self.action_space.high

        action = low_bound + (action + 1.0) * 0.5 * (upper_bound - low_bound)
        action = np.clip(action, low_bound, upper_bound)

        return action

    def reverse_action(self, action: np.ndarray) -> np.ndarray:
        """Map an environment action from ``[low, high]`` back to [-1, 1]."""
        low_bound   = self.action_space.low
        upper_bound = self.action_space.high

        action = 2 * (action - low_bound) / (upper_bound - low_bound) - 1
        # the result lives in the normalized range, so clip to [-1, 1]
        # rather than to the environment bounds
        action = np.clip(action, -1.0, 1.0)

        return action

    # NOTE(review): gym's ActionWrapper is expected to declare the public
    # `reverse_action` hook — confirm for the installed gym version. The
    # underscore name is kept as an alias for backward compatibility.
    _reverse_action = reverse_action

In [None]:
def get_action(
    env,
    network: nn.Module,
    state: np.array,
    sigma: tp.Optional[float] = None
) -> np.array:
    """Return the network's action for ``state``, optionally perturbed by
    Gaussian noise with standard deviation ``sigma``.

    ``env`` is accepted for signature parity but is not used.
    """
    # add a batch dimension, run the network, take the single row back
    batch = torch.tensor(state, dtype=torch.float32).unsqueeze(0)
    deterministic_action = network(batch).detach().cpu().numpy()[0]
    if sigma is None:
        return deterministic_action
    # exploration noise centred on the deterministic action
    return np.random.normal(deterministic_action, sigma)


def generate_session(
    env,
    network: nn.Module,
    sigma: tp.Optional[float] = None,
    replay_buffer: tp.Optional[ReplayBuffer] = None,
) -> tp.Tuple[float, int]:
    """Play one episode, for at most ``env.spec.max_episode_steps`` steps.

    Args:
        env: environment exposing ``reset()``, ``step(action)`` and ``spec``.
        network: network passed to ``get_action`` to pick actions.
        sigma: exploration noise scale forwarded to ``get_action``.
        replay_buffer: if given, every transition is appended to it.

    Returns:
        ``(total_reward, num_steps)`` where ``num_steps`` is the number of
        environment steps actually taken.
    """
    total_reward = 0
    # count the steps explicitly: the loop index is one less than the number
    # of steps taken and is undefined when the loop body never runs
    num_steps = 0
    state = env.reset()

    for _step in range(env.spec.max_episode_steps):
        action = get_action(env, network, state=state, sigma=sigma)
        next_state, reward, done, _ = env.step(action)
        num_steps += 1

        if replay_buffer is not None:
            transition = Transition(
                state, action, reward, done, next_state)
            replay_buffer.append(transition)

        total_reward += reward
        state = next_state
        if done:
            break

    return total_reward, num_steps


def generate_sessions(
    env,
    network: nn.Module,
    sigma: tp.Optional[float] = None,
    replay_buffer: tp.Optional[ReplayBuffer] = None,
    num_sessions: int = 100,
) -> tp.Tuple[float, int]:
    """Play ``num_sessions`` episodes and return the summed reward and the
    summed step value reported by ``generate_session``."""
    session_kwargs = dict(
        env=env,
        network=network,
        sigma=sigma,
        replay_buffer=replay_buffer,
    )
    reward_sum, steps_sum = 0, 0
    for _ in range(num_sessions):
        reward, steps = generate_session(**session_kwargs)
        reward_sum = reward_sum + reward
        steps_sum = steps_sum + steps
    return reward_sum, steps_sum

In [None]:
class GameCallback(dl.Callback):
    """Callback that plays the environment during training.

    It fills the replay buffer before the stage starts, keeps adding fresh
    sessions while batches are processed, and evaluates the actor without
    exploration noise at the end of every epoch.

    Args:
        env: environment to play in
        replay_buffer: buffer that receives the collected transitions
        session_period: play one session every ``session_period`` global
            batch steps
        sigma: exploration noise scale used for the collected sessions
        actor_key: key of the actor inside ``runner.model``
        num_start_sessions: sessions played at stage start to pre-fill
            the replay buffer
        num_valid_sessions: sessions played at epoch end to measure
            the validation reward
    """

    def __init__(
        self, 
        *,
        env, 
        replay_buffer: ReplayBuffer,
        session_period: int,
        sigma: float,
        actor_key: str,
        num_start_sessions: int = 1000,
        num_valid_sessions: int = 100,
    ):
        super().__init__(order=0)
        self.env = env
        self.replay_buffer = replay_buffer
        self.session_period = session_period
        self.sigma = sigma
        self.actor_key = actor_key
        self.num_start_sessions = num_start_sessions
        self.num_valid_sessions = num_valid_sessions
        self.actor = None
        self.session_counter = 0
        self.session_steps = 0

    def on_stage_start(self, runner: dl.IRunner):
        """Pre-fill the replay buffer with noisy sessions."""
        self.actor = runner.model[self.actor_key]

        self.actor.eval()
        generate_sessions(
            env=self.env, 
            network=self.actor,
            sigma=self.sigma,
            replay_buffer=self.replay_buffer,
            num_sessions=self.num_start_sessions,
        )
        self.actor.train()

    def on_epoch_start(self, runner: dl.IRunner):
        """Reset the per-epoch session statistics."""
        self.session_counter = 0
        self.session_steps = 0

    def on_batch_end(self, runner: dl.IRunner):
        """Periodically play one noisy session and log its reward/length."""
        if runner.global_batch_step % self.session_period == 0:
            self.actor.eval()

            session_reward, session_steps  = generate_session(
                env=self.env, 
                network=self.actor,
                sigma=self.sigma,
                replay_buffer=self.replay_buffer,
            )

            self.session_counter += 1
            self.session_steps += session_steps

            runner.batch_metrics.update({"s_reward": session_reward})
            runner.batch_metrics.update({"s_steps": session_steps})

            self.actor.train()

    def on_epoch_end(self, runner: dl.IRunner):
        """Evaluate the actor without noise and write the epoch metrics."""
        self.actor.eval()
        valid_rewards, _ = generate_sessions(
            env=self.env, 
            network=self.actor,
            num_sessions=self.num_valid_sessions,
        )
        self.actor.train()

        valid_rewards /= self.num_valid_sessions
        runner.epoch_metrics["_epoch_"]["num_samples"] = self.session_steps
        # guard the ratio: no steps may have been collected in this epoch
        runner.epoch_metrics["_epoch_"]["updates_per_sample"] = \
            runner.loader_sample_step / max(self.session_steps, 1)
        runner.epoch_metrics["_epoch_"]["v_reward"] = valid_rewards

In [None]:
class CustomRunner(dl.Runner):
    """Actor-critic runner with target networks.

    ``self.model`` is expected to hold the keys ``"actor"``, ``"critic"``,
    ``"target_actor"`` and ``"target_critic"``; ``self.optimizer`` the keys
    ``"actor"`` and ``"critic"``.

    Args:
        gamma: discount factor applied to the next-state values
        tau: coefficient passed to ``soft_update`` for the target networks
        tau_period: update the target networks every ``tau_period``
            global batch steps
        **kwargs: forwarded to ``dl.Runner``
    """

    def __init__(
        self,
        *,
        gamma: float,
        tau: float,
        tau_period: int = 1,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.gamma = gamma
        self.tau = tau
        self.tau_period = tau_period

    def on_stage_start(self, runner: dl.IRunner):
        """Synchronize the target networks with the online ones."""
        super().on_stage_start(runner)
        # coefficient 1.0 — presumably a full copy of the online weights;
        # confirm against soft_update
        soft_update(self.model["target_actor"], self.model["actor"], 1.0)
        soft_update(self.model["target_critic"], self.model["critic"], 1.0)

    def handle_batch(self, batch: tp.Sequence[torch.Tensor]):
        """Compute actor/critic losses and, on train loaders, update models.

        Args:
            batch: ``(states, actions, rewards, dones, next_states)``
        """
        # model train/valid step
        states, actions, rewards, dones, next_states = batch
        actor, target_actor = self.model["actor"], self.model["target_actor"]
        critic, target_critic = self.model["critic"], self.model["target_critic"]
        actor_optimizer, critic_optimizer = self.optimizer["actor"], self.optimizer["critic"]

        # get actions for the current state
        pred_actions = actor(states)
        # get q-values for the actions in current states
        pred_critic_states = torch.cat([states, pred_actions], 1)
        # use q-values to train the actor model
        policy_loss = (-critic(pred_critic_states)).mean()

        with torch.no_grad():
            # get possible actions for the next states
            next_state_actions = target_actor(next_states)
            # get possible q-values for the next actions
            next_critic_states = torch.cat([next_states, next_state_actions], 1)
            # squeeze only the trailing value dimension: a bare squeeze()
            # would also drop the batch dimension for a batch of size 1
            next_state_values = target_critic(next_critic_states).detach().squeeze(-1)
            # no future value after a terminal transition
            next_state_values[dones] = 0.0

        # compute Bellman's equation value
        target_state_values = next_state_values * self.gamma + rewards
        # compute predicted values
        critic_states = torch.cat([states, actions], 1)
        state_values = critic(critic_states).squeeze(-1)

        # train the critic model
        value_loss = self.criterion(
            state_values,
            target_state_values.detach()
        )

        self.batch_metrics.update({
            "critic_loss": value_loss,
            "actor_loss": policy_loss
        })

        if self.is_train_loader:
            actor.zero_grad()
            actor_optimizer.zero_grad()
            policy_loss.backward()
            actor_optimizer.step()

            # policy_loss.backward() also wrote gradients into the critic;
            # clear them before the critic's own backward pass
            critic.zero_grad()
            critic_optimizer.zero_grad()
            value_loss.backward()
            critic_optimizer.step()

            if self.global_batch_step % self.tau_period == 0:
                soft_update(target_actor, actor, self.tau)
                soft_update(target_critic, critic, self.tau)

In [None]:
def get_network_actor(env):
    """Build the actor: observation -> action squashed by ``tanh``.

    Args:
        env: environment whose ``observation_space.shape[0]`` and
            ``action_space.shape[0]`` give the input and output sizes

    Returns:
        ``nn.Sequential`` of a two-layer ReLU trunk and a ``Linear + Tanh`` head
    """
    inner_fn = utils.get_optimal_inner_init(nn.ReLU)
    outer_fn = utils.outer_init

    # take the action size from the env instead of hard-coding 1, so that
    # multi-dimensional action spaces work as well
    action_dim = env.action_space.shape[0]

    network = torch.nn.Sequential(
        nn.Linear(env.observation_space.shape[0], 400),
        nn.ReLU(),
        nn.Linear(400, 300),
        nn.ReLU(),
    )
    head = torch.nn.Sequential(
        nn.Linear(300, action_dim),
        nn.Tanh()
    )

    network.apply(inner_fn)
    head.apply(outer_fn)

    return torch.nn.Sequential(network, head)

def get_network_critic(env):
    """Build the critic: concatenated (observation, action) -> scalar value.

    Args:
        env: environment whose ``observation_space.shape[0]`` and
            ``action_space.shape[0]`` together give the input size

    Returns:
        ``nn.Sequential`` of a two-layer LeakyReLU trunk and a linear head
        with a single output
    """
    inner_fn = utils.get_optimal_inner_init(nn.LeakyReLU)
    outer_fn = utils.outer_init

    # the critic input is the state concatenated with the action; take the
    # action size from the env instead of assuming a 1-d action
    input_dim = env.observation_space.shape[0] + env.action_space.shape[0]

    network = torch.nn.Sequential(
        nn.Linear(input_dim, 400),
        nn.LeakyReLU(0.01),
        nn.Linear(400, 300),
        nn.LeakyReLU(0.01),
    )
    head = nn.Linear(300, 1)

    network.apply(inner_fn)
    head.apply(outer_fn)

    return torch.nn.Sequential(network, head)

In [None]:
# replay data
batch_size = 64
epoch_size = int(1e3) * batch_size
buffer_size = int(1e5)
# runner settings (training)
gamma = 0.99
tau = 0.01
tau_period = 1
# callback settings (exploration)
session_period = 1
sigma = 0.3
# optimization
lr_actor = 1e-4
lr_critic = 1e-3

# You can change game
# env_name = "LunarLanderContinuous-v2"
env_name = "Pendulum-v0"
env = NormalizedActions(gym.make(env_name))
replay_buffer = ReplayBuffer(buffer_size)

# online networks and their targets (targets are never trained directly)
actor = get_network_actor(env)
target_actor = get_network_actor(env)
critic = get_network_critic(env)
target_critic = get_network_critic(env)
utils.set_requires_grad(target_actor, requires_grad=False)
utils.set_requires_grad(target_critic, requires_grad=False)

models = dict(
    actor=actor,
    critic=critic,
    target_actor=target_actor,
    target_critic=target_critic,
)

criterion = torch.nn.MSELoss()
optimizer = dict(
    actor=torch.optim.Adam(actor.parameters(), lr=lr_actor),
    critic=torch.optim.Adam(critic.parameters(), lr=lr_critic),
)

loaders = dict(
    train_game=DataLoader(
        ReplayDataset(replay_buffer, epoch_size=epoch_size),
        batch_size=batch_size,
    ),
)

runner = CustomRunner(gamma=gamma, tau=tau, tau_period=tau_period)

runner.train(
    model=models,
    criterion=criterion,
    optimizer=optimizer,
    loaders=loaders,
    logdir="./logs_ddpg",
    num_epochs=10,
    verbose=True,
    # the validation reward is written by GameCallback under "_epoch_"
    valid_loader="_epoch_",
    valid_metric="v_reward",
    minimize_valid_metric=False,
    load_best_on_end=True,
    callbacks=[
        GameCallback(
            env=env,
            replay_buffer=replay_buffer,
            session_period=session_period,
            sigma=sigma,
            actor_key="actor",
        ),
    ],
)

In [None]:
import gym.wrappers


# evaluate on the same action-normalized environment the actor was trained
# on; the raw environment would interpret the actor's outputs on a
# different scale
env = gym.wrappers.Monitor(
    NormalizedActions(gym.make(env_name)),
    directory="videos_ddpg", 
    force=True)
try:
    generate_sessions(
        env=env, 
        network=runner.model["actor"],
        num_sessions=100
    )
finally:
    # always release the monitored environment, even if a session fails
    env.close()

In [None]:
# show video
from IPython.display import HTML
import os

# the recordings live in the directory the monitor above wrote to
video_dir = "./videos_ddpg/"
# os.listdir order is arbitrary; sort so that indexing is deterministic
video_names = sorted(
    filter(lambda s: s.endswith(".mp4"), os.listdir(video_dir)))

HTML("""
<video width="640" height="480" controls>
  <source src="{}" type="video/mp4">
</video>
""".format(video_dir + video_names[-1]))  # last recording by file name. Try other indices

----

## Stage 4: On-policy Reinforcement Learning

### Act 23 - REINFORCE

In [None]:
from collections import deque, namedtuple
import typing as tp
import random

import numpy as np
import gym

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.data.dataset import IterableDataset

from catalyst import dl, metrics, utils

In [None]:
# one recorded episode: per-step states, actions and rewards
Rollout = namedtuple("Rollout", ["states", "actions", "rewards"])


class RolloutBuffer:
    """Bounded FIFO storage of rollouts.

    Once ``capacity`` rollouts are stored, appending a new one drops the
    oldest (``deque`` with ``maxlen``).
    """

    def __init__(self, capacity: int):
        self.capacity = capacity
        self.buffer = deque(maxlen=capacity)

    def append(self, rollout: Rollout):
        """Store one more rollout."""
        self.buffer.append(rollout)

    def sample(self, idx: int) -> tp.Sequence[np.array]:
        """Return the rollout at position ``idx`` as numpy arrays.

        Returns:
            ``(states, actions, rewards)`` as float32, int64 and float32
            arrays respectively
        """
        rollout_states, rollout_actions, rollout_rewards = self.buffer[idx]
        return (
            np.array(rollout_states, dtype=np.float32),
            np.array(rollout_actions, dtype=np.int64),
            np.array(rollout_rewards, dtype=np.float32),
        )

    def __len__(self) -> int:
        """Number of rollouts currently stored."""
        return len(self.buffer)


# RL has no predefined dataset, so the data of an epoch
# is whatever rollouts the buffer holds at iteration time
class RolloutDataset(IterableDataset):
    """Iterable view over a ``RolloutBuffer`` that consumes it.

    Every pass yields each stored rollout once and then empties the
    buffer, so the same rollouts are not yielded by a later pass.
    """

    def __init__(self, buffer: RolloutBuffer):
        self.buffer = buffer

    def __iter__(self) -> tp.Iterator[tp.Sequence[np.array]]:
        for rollout_idx in range(len(self.buffer)):
            ep_states, ep_actions, ep_rewards = self.buffer.sample(rollout_idx)
            yield ep_states, ep_actions, ep_rewards
        # drop the rollouts that were just yielded
        self.buffer.buffer.clear()

    def __len__(self) -> int:
        # reports the buffer capacity, not the current number of rollouts
        return self.buffer.capacity

In [None]:
def get_cumulative_rewards(rewards, gamma = 0.99):
    """Compute the discounted return for every step of an episode.

    ``G[t] = rewards[t] + gamma * G[t + 1]``, with ``G[-1] = rewards[-1]``.

    Args:
        rewards: per-step rewards of one episode (any indexable sequence)
        gamma: discount factor

    Returns:
        list with one discounted return per step; empty for an empty episode
    """
    if len(rewards) == 0:
        return []
    # build the returns back-to-front with O(1) appends and reverse once,
    # instead of inserting at the front of the list on every step
    returns = [rewards[-1]]
    for r in reversed(rewards[:-1]):
        returns.append(r + gamma * returns[-1])
    returns.reverse()
    return returns

def to_one_hot(y, n_dims=None):
    """Convert a tensor of integer class ids into a one-hot float matrix.

    Args:
        y: tensor of class ids; it is flattened to one id per row
        n_dims: number of columns; defaults to ``max(y) + 1``

    Returns:
        tensor with one row per id and ``n_dims`` columns, holding 1 in the
        column of the id and 0 elsewhere
    """
    indices = y.type(torch.LongTensor).view(-1, 1)
    if n_dims is None:
        n_dims = int(torch.max(indices)) + 1
    one_hot = torch.zeros(indices.size()[0], n_dims)
    one_hot.scatter_(1, indices, 1)
    return one_hot

def get_action(
    env,
    network: nn.Module,
    state: np.array,
    epsilon: float = -1
) ->  int:
    """Sample a discrete action from the policy's softmax distribution.

    Args:
        env: environment the action is meant for (unused in this function)
        network: policy mapping a batch of states to a batch of logits
        state: a single, unbatched observation
        epsilon: currently ignored; no epsilon-greedy exploration is applied

    Returns:
        index of the sampled action
    """
    state = torch.tensor(state[None], dtype=torch.float32)
    # run on the device the network lives on; a CPU-built state would
    # otherwise fail as soon as the model has been moved to GPU
    first_param = next(network.parameters(), None)
    if first_param is not None:
        state = state.to(first_param.device)
    # pure inference: no need to build an autograd graph
    with torch.no_grad():
        logits = network(state)
    probas = F.softmax(logits, -1).cpu().numpy()[0]
    action = np.random.choice(len(probas), p=probas)
    return int(action)


def generate_session(
    env,
    network: nn.Module,
    t_max: int = 1000,
    epsilon: float = -1,
    rollout_buffer: tp.Optional[RolloutBuffer] = None,
) -> tp.Tuple[float, int]:
    """Play a single episode and optionally store it as one rollout.

    The episode runs for at most ``t_max`` steps and stops early when the
    environment reports ``done``.

    Args:
        env: environment exposing ``reset`` and ``step``
        network: policy used to pick the actions
        t_max: maximum number of steps to play
        epsilon: forwarded to ``get_action``
        rollout_buffer: if given, the whole episode is appended to it

    Returns:
        ``(total_reward, num_steps)`` where ``num_steps`` is the number of
        environment steps actually taken
    """
    total_reward = 0
    # explicit counter: the loop index is one less than the number of steps
    # taken and would be undefined for t_max == 0
    num_steps = 0
    states, actions, rewards = [], [], []
    state = env.reset()

    for _step in range(t_max):
        action = get_action(env, network, state=state, epsilon=epsilon)
        next_state, reward, done, _ = env.step(action)
        num_steps += 1

        # record session history to train later
        states.append(state)
        actions.append(action)
        rewards.append(reward)

        total_reward += reward
        state = next_state
        if done:
            break
    if rollout_buffer is not None:
        rollout_buffer.append(Rollout(states, actions, rewards))

    return total_reward, num_steps

def generate_sessions(
    env,
    network: nn.Module,
    t_max: int = 1000,
    epsilon:float = -1,
    rollout_buffer: tp.Optional[RolloutBuffer] = None,
    num_sessions: int = 100,
) -> tp.Tuple[float, int]:
    """Play several episodes in a row and accumulate their statistics.

    Args:
        env: environment to play in
        network: policy used to pick the actions
        t_max: per-episode step limit forwarded to ``generate_session``
        epsilon: forwarded to ``generate_session``
        rollout_buffer: optional buffer forwarded to ``generate_session``
        num_sessions: how many episodes to play

    Returns:
        the sums, over all episodes, of the two values returned by
        ``generate_session`` (reward and step count)
    """
    reward_sum = 0
    step_sum = 0
    for _ in range(num_sessions):
        episode_reward, episode_steps = generate_session(
            env=env,
            network=network,
            t_max=t_max,
            epsilon=epsilon,
            rollout_buffer=rollout_buffer,
        )
        reward_sum += episode_reward
        step_sum += episode_steps
    return reward_sum, step_sum

In [None]:
class GameCallback(dl.Callback):
    """Callback that collects rollouts before, and evaluates after, an epoch.

    Args:
        env: environment to play in
        rollout_buffer: buffer that receives the collected rollouts
        num_train_sessions: sessions played at epoch start to fill
            the rollout buffer
        num_valid_sessions: sessions played at epoch end to measure
            the validation reward
    """

    def __init__(
        self,
        *,
        env,
        rollout_buffer: RolloutBuffer,
        num_train_sessions: int = 100,
        num_valid_sessions: int = 100,
    ):
        super().__init__(order=0)
        self.env = env
        self.rollout_buffer = rollout_buffer
        self.num_train_sessions = num_train_sessions
        self.num_valid_sessions = num_valid_sessions
        self.actor = None

    def on_epoch_start(self, runner: dl.IRunner):
        """Play sessions with the current model and store them as rollouts."""
        self.actor = runner.model

        self.actor.eval()
        generate_sessions(
            env=self.env, 
            network=self.actor,
            rollout_buffer=self.rollout_buffer,
            num_sessions=self.num_train_sessions,
        )
        self.actor.train()

    def on_epoch_end(self, runner: dl.IRunner):
        """Play validation sessions and log their mean reward."""
        self.actor.eval()
        valid_rewards, _ = generate_sessions(
            env=self.env, 
            network=self.actor,
            num_sessions=self.num_valid_sessions,
        )
        self.actor.train()

        valid_rewards /= self.num_valid_sessions
        runner.epoch_metrics["_epoch_"]["v_reward"] = valid_rewards

In [None]:
def get_network(env, num_hidden: int = 128):
    """Build the policy network: observation -> one logit per action.

    Args:
        env: environment whose ``observation_space.shape[0]`` is the input
            size and whose ``action_space.n`` is the number of logits
        num_hidden: width of both hidden layers

    Returns:
        ``nn.Sequential`` of a two-layer ReLU trunk and a linear head
    """
    trunk = torch.nn.Sequential(
        nn.Linear(env.observation_space.shape[0], num_hidden),
        nn.ReLU(),
        nn.Linear(num_hidden, num_hidden),
        nn.ReLU(),
    )
    logits_head = nn.Linear(num_hidden, env.action_space.n)

    # hidden layers and the output layer get different initializations
    trunk.apply(utils.get_optimal_inner_init(nn.ReLU))
    logits_head.apply(utils.outer_init)

    return torch.nn.Sequential(trunk, logits_head)


class CustomRunner(dl.Runner):
    """Policy-gradient runner that trains on one rollout per batch.

    Args:
        gamma: discount factor used for the cumulative returns
        **kwargs: forwarded to ``dl.Runner``
    """

    def __init__(
        self, 
        *, 
        gamma: float,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.gamma: float = gamma
        self._initialized = False

    def handle_batch(self, batch: tp.Sequence[np.array]):
        """Compute the policy-gradient loss for one rollout and optimize.

        Args:
            batch: ``(states, actions, rewards)`` holding a single rollout
        """
        # model train/valid step
        # ATTENTION: 
        # because of different trajectories lens
        # ONLY batch_size==1 supported
        states, actions, rewards = batch
        states, actions, rewards = states[0], actions[0], rewards[0]
        # use the runner's own discount factor, not a module-level global
        cumulative_returns = torch.tensor(
            get_cumulative_rewards(rewards, self.gamma))
        network = self.model

        logits = network(states)
        probas = F.softmax(logits, -1)
        logprobas = F.log_softmax(logits, -1)
        n_actions = probas.shape[1]
        # log-probability of the action that was actually taken at each step
        logprobas_for_actions = torch.sum(
            logprobas * to_one_hot(actions, n_dims=n_actions), dim=1)

        J_hat = torch.mean(logprobas_for_actions * cumulative_returns)
        entropy_reg = - torch.mean(torch.sum(probas * logprobas, dim = 1))
        # maximize the return estimate plus a small entropy bonus
        loss = - J_hat - 0.1 * entropy_reg

        self.batch_metrics.update({"loss": loss})
        if self.is_train_loader:
            loss.backward()
            self.optimizer.step()
            self.optimizer.zero_grad()

In [None]:
# every batch is one whole rollout (rollouts differ in length)
batch_size = 1
epoch_size = int(1e3) * batch_size
buffer_size = int(1e2)
# runner settings
gamma = 0.99
# optimization
lr = 3e-4

# env_name = "LunarLander-v2"
env_name = "CartPole-v1"
env = gym.make(env_name)
rollout_buffer = RolloutBuffer(buffer_size)

model = get_network(env)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
loaders = {
    "train_game": DataLoader(
        RolloutDataset(rollout_buffer), 
        batch_size=batch_size,
    ),
}

runner = CustomRunner(gamma=gamma)
runner.train(
    model=model,
    optimizer=optimizer,
    loaders=loaders,
    # dedicated log directory for this experiment (was a DQN leftover)
    logdir="./logs_reinforce",
    num_epochs=10,
    verbose=True,
    # the validation reward is written by GameCallback under "_epoch_"
    valid_loader="_epoch_",
    valid_metric="v_reward",
    minimize_valid_metric=False,
    load_best_on_end=True,
    callbacks=[
        GameCallback(
            env=env, 
            rollout_buffer=rollout_buffer, 
        )
    ]
)

In [None]:
import gym.wrappers


# play the trained policy in a monitored environment
env = gym.make(env_name)
env = gym.wrappers.Monitor(env, directory="videos_reinforce", force=True)
generate_sessions(env=env, network=model, num_sessions=100)
env.close()

In [None]:
# show video
from IPython.display import HTML
import os

# the recordings live in the directory the monitor above wrote to
video_dir = "./videos_reinforce/"
# os.listdir order is arbitrary; sort so that indexing is deterministic
video_names = sorted(
    filter(lambda s: s.endswith(".mp4"), os.listdir(video_dir)))

HTML("""
<video width="640" height="480" controls>
  <source src="{}" type="video/mp4">
</video>
""".format(video_dir + video_names[-1]))  # last recording by file name. Try other indices