In [None]:
# Nazarov A.I.

In [None]:
import warnings
warnings.simplefilter("ignore", category=FutureWarning)

# Core libraries
import pymc as pm
import arviz as az
import numpy as np
import pandas as pd
import seaborn as sns
from scipy import stats
from matplotlib import pyplot as plt

# PyTensor (replaces Theano)
import pytensor
import pytensor.tensor as pt

# --- 1D convolution using pytensor.scan ---
def conv1d(x, w, border_mode="valid"):
    """True 1D convolution (kernel flipped), built symbolically with PyTensor.

    Output index ``i`` of the "full" result is ``sum_k x[k] * w[i - k]``, i.e.
    the same as ``numpy.convolve(x[0], w[0], mode="full")``. "valid" keeps only
    the indices ``n_w - 1 .. n_x - 1`` where the kernel fully overlaps the
    signal (same as numpy's "valid" when ``n_x >= n_w``; when ``n_w > n_x``
    there are no output steps, whereas numpy would swap the arguments).

    Parameters
    ----------
    x, w : symbolic tensors
        Only the first row of each is used (``x[0]``, ``w[0]``); presumably
        shaped (1, n_x) and (1, n_w) — TODO confirm for new callers.
    border_mode : {"valid", "full"}

    Returns
    -------
    Symbolic tensor of shape (1, out_len).

    Raises
    ------
    ValueError
        If ``border_mode`` is not "valid" or "full".
    """
    if border_mode not in ("valid", "full"):
        raise ValueError("border_mode must be 'valid' or 'full'")

    x = x[0]  # input sequence
    w = w[0]  # convolution kernel
    n_x = x.shape[0]
    n_w = w.shape[0]

    if border_mode == "valid":
        # valid output m is full-convolution index m + n_w - 1
        i_seq = pt.arange(n_w - 1, n_x)
    else:
        i_seq = pt.arange(n_x + n_w - 1)

    def step(i, x, w):
        start = pt.maximum(0, i - n_w + 1)   # first signal index overlapping the kernel
        end = pt.minimum(i + 1, n_x)         # one past the last overlapping signal index
        x_slice = x[start:end]
        # x[k] must meet w[i - k]: kernel indices run from i-start down to i-end+1
        w_slice = w[i - end + 1 : i - start + 1][::-1]
        return pt.sum(x_slice * w_slice)

    outputs, _ = pytensor.scan(fn=step, sequences=i_seq, non_sequences=[x, w])
    return outputs[None, :]  # restore batch dimension


# Read dataset

In [None]:
data = (
    pd.read_csv('data/owid-covid-data.csv')
    .assign(date=lambda frame: pd.to_datetime(frame['date']))  # overwrite 'date' in place
)
print('Size of data:', data.shape)
data.head()

Размер данных: (201018, 67)
Количество местоположений: 244


Unnamed: 0,iso_code,continent,location,date,total_cases,new_cases,new_cases_smoothed,total_deaths,new_deaths,new_deaths_smoothed,...,female_smokers,male_smokers,handwashing_facilities,hospital_beds_per_thousand,life_expectancy,human_development_index,excess_mortality_cumulative_absolute,excess_mortality_cumulative,excess_mortality,excess_mortality_cumulative_per_million
0,AFG,Asia,Afghanistan,2020-02-24,5.0,5.0,,,,,...,,,37.746,0.5,64.83,0.511,,,,
1,AFG,Asia,Afghanistan,2020-02-25,5.0,0.0,,,,,...,,,37.746,0.5,64.83,0.511,,,,
2,AFG,Asia,Afghanistan,2020-02-26,5.0,0.0,,,,,...,,,37.746,0.5,64.83,0.511,,,,
3,AFG,Asia,Afghanistan,2020-02-27,5.0,0.0,,,,,...,,,37.746,0.5,64.83,0.511,,,,
4,AFG,Asia,Afghanistan,2020-02-28,5.0,0.0,,,,,...,,,37.746,0.5,64.83,0.511,,,,


In [4]:
def process_data(data: pd.DataFrame, locations=('Russia', 'Italy', 'Germany', 'France')):
    """Select, rename and clean the OWID columns used by the R_t model.

    Parameters
    ----------
    data : pd.DataFrame
        Raw OWID frame with at least ``location``, ``date``, ``total_cases``,
        ``new_cases`` and ``reproduction_rate`` columns. It is not modified.
    locations : str or iterable of str
        Location name(s) to keep; defaults to the four countries analysed here.

    Returns
    -------
    pd.DataFrame
        Indexed by ``date`` with columns ``location``, ``total`` (from
        ``total_cases``), ``positive`` (from ``new_cases``) and
        ``reproduction_rate``. Infinite and missing values are replaced by 0;
        ``total`` and ``positive`` are float32.
    """
    if isinstance(locations, str):
        # a bare string would otherwise be split into characters by list()
        locations = [locations]

    columns = ['location', 'date', 'total_cases', 'new_cases', 'reproduction_rate']
    processed = (
        data.loc[data['location'].isin(list(locations)), columns]
        .rename(columns={'total_cases': 'total', 'new_cases': 'positive'})
        .set_index('date')
        .replace([np.inf, -np.inf], np.nan)
        # NOTE: a missing reproduction_rate also becomes 0 here, so downstream
        # plots should treat 0 as "not reported" rather than R_t = 0.
        .fillna(0)
    )
    return processed.astype({'positive': 'float32', 'total': 'float32'})


In [5]:
# Keep only what the model needs: location, date, cumulative cases, new cases and
# the reported reproduction number R_t (kept so we can compare it with our estimates).
data = data.pipe(process_data)
data.head()

Unnamed: 0_level_0,location,total,positive,reproduction_rate
date,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
2020-01-24,France,2.0,2.0,0.0
2020-01-25,France,3.0,1.0,0.0
2020-01-26,France,3.0,0.0,0.0
2020-01-27,France,3.0,0.0,0.0
2020-01-28,France,4.0,1.0,0.0


# PyMC5 Generative model

In [None]:
# ==============================================================
#  Imports
# ==============================================================

import pymc as pm                 # Bayesian modeling and inference
import arviz as az                # Analysis and diagnostics of PyMC traces
import numpy as np                # Numerical array handling
import pandas as pd               # Time-series and tabular data handling
from scipy import stats as sps    # Statistical distributions (for delay & gen. time)
import pytensor                   # Symbolic math backend (successor of Theano)
import pytensor.tensor as pt      # Tensor operations (similar to NumPy ops in graphs)


# ==============================================================
#  Delay Distribution
# ==============================================================

def get_delay_distribution(mean_delay=8.0, std_delay=2.0, n_days=20):
    """Discretized log-normal delay between infection and case confirmation.

    Parameters
    ----------
    mean_delay, std_delay : float
        Mean and standard deviation of the delay in days (natural scale).
    n_days : int
        Length of the discrete support. Entry ``k`` holds the probability mass
        of the interval (k-1, k] days; entry 0 is P(delay <= 0) = 0. Mass beyond
        ``n_days - 1`` is dropped and the rest renormalized.

    Returns
    -------
    np.ndarray
        float32 array of shape (n_days,) summing to 1.

    Raises
    ------
    ValueError
        If ``n_days < 2`` (the only entry would be 0 and normalization fails).
    """
    if n_days < 2:
        raise ValueError("n_days must be >= 2")

    # moment-match the natural-scale mean/std to log-normal mu and sigma
    mu = np.log(mean_delay**2 / np.sqrt(std_delay**2 + mean_delay**2))
    sigma = np.sqrt(np.log(std_delay**2 / mean_delay**2 + 1))
    dist = sps.lognorm(scale=np.exp(mu), s=sigma)

    # daily probability mass from CDF differences on the 0..n_days-1 grid
    days = np.arange(0, n_days)
    pdf = np.diff(dist.cdf(days), prepend=0)
    pdf /= pdf.sum()

    # float32 to match the other PyTensor inputs of the model
    return pdf.astype("float32")


# ==============================================================
#  GenerativeModel: Bayesian time-varying Rₜ estimation
# ==============================================================

class GenerativeModel:
    """Probabilistic generative model linking latent infections to reported cases.

    Structure (see ``build``): log R_t follows a Gaussian random walk; daily
    infections follow a renewal equation driven by R_t and a discretized
    generation-time interval; infections are convolved with an
    infection→confirmation delay, scaled by an exposure term, and matched to
    reported positives through a negative-binomial likelihood.
    """
    version = "2.0.0-pymc5"       # model version string, stored in posterior attrs

    # ----------------------------------------------------------
    #  Initialization
    # ----------------------------------------------------------
    def __init__(self, region: str, observed: pd.DataFrame, buffer_days=10):
        """Trim leading zero days and prepend ``buffer_days`` of zeros.

        Parameters
        ----------
        region : str
            Name of the geographic region (labelling only).
        observed : pd.DataFrame
            Daily data with a datetime index and at least ``positive`` and
            ``total`` columns.
        buffer_days : int
            Zero-filled days prepended before the first non-zero ``positive``
            value, giving the model room to seed infections.

        Raises
        ------
        ValueError
            If ``observed`` has no rows.
        """
        if observed.empty:
            raise ValueError("observed must contain at least one row")

        # positional index of the first non-zero 'positive' (0 if all are zero)
        first_index = observed.positive.ne(0).argmax()
        observed = observed.iloc[first_index:]

        # complete daily timeline starting buffer_days before the first case;
        # every day absent from the input (including the buffer) is filled with 0
        new_index = pd.date_range(
            start=observed.index[0] - pd.Timedelta(days=buffer_days),
            end=observed.index[-1],
            freq="D",
        )
        observed = observed.reindex(new_index, fill_value=0)

        self.region = region
        self.observed = observed
        self._trace = None            # result of pm.sample (InferenceData in PyMC 5)
        self._inference_data = None   # cached posterior + posterior predictive
        self.model = None             # pm.Model, set by build()

    # ----------------------------------------------------------
    #  Diagnostics / Posterior Access
    # ----------------------------------------------------------
    @property
    def trace(self):
        """Raw sampling result from ``sample()`` (None before sampling)."""
        return self._trace

    @property
    def n_divergences(self):
        """Total number of divergent NUTS transitions across all chains."""
        assert self._trace is not None, "Must run sample() first"
        return int(self._trace.sample_stats["diverging"].values.sum())

    @property
    def inference_data(self):
        """Posterior samples extended with posterior-predictive draws.

        Computed once per call to ``sample()`` and cached; the raw trace is
        copied first so ``self._trace`` itself is left untouched.
        """
        assert self._trace is not None, "Must run sample() first"
        if self._inference_data is None:
            idata = self._trace.copy()
            with self.model:
                pm.sample_posterior_predictive(idata, extend_inferencedata=True)
            idata.posterior.attrs["model_version"] = self.version
            self._inference_data = idata
        return self._inference_data

    # ----------------------------------------------------------
    #  Internal helper: scaling
    # ----------------------------------------------------------
    def _scale_to_positives(self, data):
        """Scale an arbitrary series so its mean equals the mean observed positives."""
        scale_factor = self.observed.positive.mean() / np.mean(data)
        return scale_factor * data

    # ----------------------------------------------------------
    #  Internal helper: generation-time distribution
    # ----------------------------------------------------------
    def _get_generation_time_interval(self):
        """Discretized log-normal generation time (mean 4.7, std 2.9 days), 20-day support."""
        mean_si, std_si = 4.7, 2.9

        # moment-match to log-normal mu, sigma
        mu_si = np.log(mean_si**2 / np.sqrt(std_si**2 + mean_si**2))
        sigma_si = np.sqrt(np.log(std_si**2 / mean_si**2 + 1))
        dist = sps.lognorm(scale=np.exp(mu_si), s=sigma_si)

        # probability mass of (k-1, k] days from CDF differences, renormalized
        g_range = np.arange(0, 20)
        gt = np.diff(dist.cdf(g_range), prepend=0)
        gt /= gt.sum()
        return gt.astype("float32")

    # ----------------------------------------------------------
    #  Precomputed kernels
    # ----------------------------------------------------------
    def _get_convolution_ready_gt(self, len_obs):
        """Generation-time matrix for the renewal scan.

        Row ``t - 1`` holds ``gt[t - s]`` in column ``s`` for the past days
        ``s < t`` within the kernel support, so the scan step for day ``t`` is a
        single weighted sum over the infection vector.
        """
        gt = self._get_generation_time_interval()
        conv_ready = np.zeros((len_obs - 1, len_obs), dtype="float32")

        for t in range(1, len_obs):
            begin = np.maximum(0, t - len(gt) + 1)
            slice_update = gt[1 : t - begin + 1][::-1]   # gt[t-begin], ..., gt[1]
            conv_ready[t - 1, begin : begin + len(slice_update)] = slice_update

        return pytensor.shared(conv_ready)

    @staticmethod
    def _get_delay_convolution_matrix(p_delay, len_obs):
        """Dense Toeplitz matrix ``D`` with ``D[t, k] = p_delay[t - k]`` (0 outside support).

        ``D @ infections`` equals ``numpy.convolve(infections, p_delay)[:len_obs]``:
        infections on day ``k`` are reported ``t - k`` days later.
        """
        p_delay = np.asarray(p_delay, dtype="float32")
        lag = np.arange(len_obs)[:, None] - np.arange(len_obs)[None, :]   # t - k
        in_support = (lag >= 0) & (lag < len(p_delay))
        matrix = np.zeros((len_obs, len_obs), dtype="float32")
        matrix[in_support] = p_delay[lag[in_support]]
        return matrix

    # ----------------------------------------------------------
    #  Model Construction
    # ----------------------------------------------------------
    def build(self):
        """Define the full PyMC model; stores it on ``self.model`` and returns it."""
        p_delay = get_delay_distribution()
        len_obs = len(self.observed)
        conv_gt = self._get_convolution_ready_gt(len_obs)
        # the infection→report convolution becomes one matmul, which is typically
        # much cheaper to evaluate and differentiate than a nested scan
        delay_matrix = pytensor.shared(self._get_delay_convolution_matrix(p_delay, len_obs))
        nonzero_days = self.observed.total.gt(0)   # days kept in the likelihood

        coords = {
            "date": self.observed.index.values,
            "nonzero_date": self.observed.index.values[nonzero_days],
        }

        with pm.Model(coords=coords) as self.model:

            # --- (1) Latent reproduction number: random walk in log space ---
            log_r_t = pm.GaussianRandomWalk("log_r_t", sigma=0.035, dims=["date"])
            r_t = pm.Deterministic("r_t", pm.math.exp(log_r_t), dims=["date"])

            # --- (2) Initial infection seed on the first (buffer) day ---
            seed = pm.Exponential("seed", 1 / 0.02)
            y0 = pt.zeros(len_obs, dtype="float32")
            y0 = pt.set_subtensor(y0[0], seed)

            # --- (3) Renewal equation: y[t] = sum_s r_t[s] * y[s] * gt[t - s] ---
            def recurrence(t, gt_row, y, r_t):
                return pt.set_subtensor(y[t], pt.sum(r_t * y * gt_row))

            outputs, _ = pytensor.scan(
                fn=recurrence,
                sequences=[pt.arange(1, len_obs), conv_gt],
                outputs_info=y0,
                non_sequences=r_t,
                n_steps=len_obs - 1,
            )
            infections = pm.Deterministic("infections", outputs[-1], dims=["date"])

            # --- (4) Delay: expected reports[t] = sum_k infections[k] * p_delay[t - k] ---
            test_adj_pos = pm.Deterministic(
                "test_adjusted_positive",
                pt.dot(delay_matrix, infections),
                dims=["date"],
            )

            # --- (5) Exposure and expected positives ---
            # NOTE(review): in this notebook `total` comes from OWID `total_cases`
            # (cumulative cases), not test counts as in rt.live — confirm intent.
            tests = pm.Data("tests", self.observed.total.values.astype("float32"), dims=["date"])
            exposure = pm.Deterministic(
                "exposure",
                pm.math.clip(tests, float(self.observed.total.max()) * 0.1, 1e9),  # floor at 10% of max
                dims=["date"],
            )
            positive = pm.Deterministic("positive", exposure * test_adj_pos, dims=["date"])

            # --- (6) Observed data stored on the model ---
            pm.Data("observed_positive", self.observed.positive.values.astype("float32"), dims=["date"])
            nonzero_obs_pos = pm.Data(
                "nonzero_observed_positive",
                self.observed.positive[nonzero_days.values].values.astype("float32"),
                dims=["nonzero_date"],
            )

            # --- (7) Likelihood on days with total > 0 ---
            pm.NegativeBinomial(
                "nonzero_positive",
                mu=positive[nonzero_days.values],
                alpha=pm.Gamma("alpha", mu=6, sigma=1),
                observed=nonzero_obs_pos,
                dims=["nonzero_date"],
            )

        return self.model

    # ----------------------------------------------------------
    #  Sampling (MCMC inference)
    # ----------------------------------------------------------
    def sample(
        self,
        cores=4,
        chains=4,
        tune=700,
        draws=200,
        target_accept=0.95,
        init="jitter+adapt_diag",
        random_seed=None,
    ):
        """Run NUTS; builds the model first if needed. Returns ``self``.

        ``random_seed`` is forwarded to ``pm.sample`` for reproducible runs.
        """
        if self.model is None:
            self.build()

        with self.model:
            self._trace = pm.sample(
                draws=draws,
                tune=tune,
                chains=chains,
                cores=cores,
                target_accept=target_accept,
                init=init,
                random_seed=random_seed,
            )
        # a new trace invalidates any cached posterior-predictive data
        self._inference_data = None
        return self


# Utils, plotting

In [None]:
# ==============================================================
#  Inference Summary Helper
# ==============================================================

def get_result(model, hdi_mass=80):
    """Per-date posterior summary of a fitted GenerativeModel, ready for plotting.

    Built directly from ``model.inference_data``. The helper the previous
    version relied on (``summarize_inference_data``) is not defined in this
    notebook.

    Parameters
    ----------
    model : GenerativeModel
        A model on which ``sample()`` has been run.
    hdi_mass : int
        Percent mass of the R_t HDI; names the ``lower_``/``upper_`` columns.

    Returns
    -------
    pd.DataFrame
        Indexed by ``date`` (the model timeline), with columns ``mean``,
        ``median``, ``lower_<hdi_mass>``, ``upper_<hdi_mass>`` (R_t), plus
        ``infections``, ``test_adjusted_positive`` and
        ``test_adjusted_positive_raw``, each rescaled to the mean of the
        observed positives, and the raw ``positive`` and ``tests`` series.
    """
    idata = model.inference_data
    posterior = idata.posterior

    observed_positive = model.observed["positive"].to_numpy(dtype="float64")
    tests = model.observed["total"].to_numpy(dtype="float64")

    r_t = posterior["r_t"]
    r_t_hdi = az.hdi(idata, var_names=["r_t"], hdi_prob=hdi_mass / 100)["r_t"]

    # mirror the model's exposure floor: tests clipped at 10% of their maximum
    exposure = np.clip(tests, 0.1 * tests.max(), None)
    normalized_positive = observed_positive / exposure

    def posterior_mean(name):
        return posterior[name].mean(dim=("chain", "draw")).values

    return pd.DataFrame(
        {
            "mean": r_t.mean(dim=("chain", "draw")).values,
            "median": r_t.median(dim=("chain", "draw")).values,
            f"lower_{hdi_mass}": r_t_hdi.sel(hdi="lower").values,
            f"upper_{hdi_mass}": r_t_hdi.sel(hdi="higher").values,
            "infections": model._scale_to_positives(posterior_mean("infections")),
            "test_adjusted_positive": model._scale_to_positives(
                posterior_mean("test_adjusted_positive")
            ),
            "test_adjusted_positive_raw": model._scale_to_positives(normalized_positive),
            "positive": observed_positive,
            "tests": tests,
        },
        index=pd.Index(model.observed.index, name="date"),
    )



# ==============================================================
#  Data Extraction
# ==============================================================

def get_data_by_region(df: pd.DataFrame, region: str, min_cases: int = 100) -> pd.DataFrame:
    """Rows for ``region`` from the first day with ``positive >= min_cases`` onward.

    Parameters
    ----------
    df : pd.DataFrame
        Frame with ``location`` and ``positive`` (daily new cases) columns.
    region : str
        Value of ``location`` to select.
    min_cases : int
        Daily-case threshold that marks the start of the returned window.

    Returns
    -------
    pd.DataFrame
        A copy of the region's rows, starting at the first row reaching the
        threshold (row order of ``df`` preserved).

    Raises
    ------
    ValueError
        If ``region`` has no rows, or never reaches ``min_cases`` (previously the
        whole untrimmed series was returned silently in that case).
    """
    region_df = df[df["location"] == region]
    if region_df.empty:
        raise ValueError(f"no rows for region {region!r}")

    reached = region_df["positive"].ge(min_cases).to_numpy()
    if not reached.any():
        raise ValueError(f"{region!r} never reaches {min_cases} new cases in a day")

    # positional slicing: independent of duplicate or unsorted index labels
    return region_df.iloc[int(reached.argmax()):].copy()


# ==============================================================
#  Visualization Functions
# ==============================================================

def plot_infections_and_tests(region, model, result):
    """Overlay model-implied infections/positives on the reported positive counts.

    Parameters
    ----------
    region : str
        Region name used in the title.
    model : GenerativeModel
        Supplies ``observed.positive`` (reported daily cases).
    result : pd.DataFrame
        Must contain ``infections``, ``test_adjusted_positive`` and
        ``test_adjusted_positive_raw`` columns (e.g. from ``get_result``).
    """
    fig, ax = plt.subplots(figsize=(12, 8))

    # (column, plotting options) for each model-derived curve
    model_curves = [
        ("infections", dict(c="C2", label="Expected primary infections")),
        ("test_adjusted_positive", dict(c="C0", label="Expected positives (constant testing)")),
        ("test_adjusted_positive_raw",
         dict(c="C1", alpha=0.5, style="--", label="Expected positives (raw)")),
    ]
    for column, options in model_curves:
        result[column].plot(ax=ax, **options)

    model.observed.positive.plot(c="C7", alpha=0.7, label="Reported positive tests", ax=ax)

    fig.set_facecolor("w")
    ax.legend()
    ax.set(title=f"rt.live model inference for {region}", ylabel="number of cases")
    sns.despine()
    plt.show()


def plot_effective_r_t(region, data_region, model, result):
    """Plot posterior R_t (median + nested percentile bands) against reported R_t.

    Parameters
    ----------
    region : str
        Region name used in the title.
    data_region : pd.DataFrame
        Date-indexed frame with a ``reproduction_rate`` column.
    model : GenerativeModel
        A sampled model; R_t draws are read from ``model._trace.posterior``.
    result : pd.DataFrame
        Indexed by the model dates with a ``median`` column (e.g. ``get_result``).
    """
    fig, ax = plt.subplots(figsize=(12, 8))
    ax.set(title=f"Effective reproduction number for {region}", ylabel="$R_e(t)$")

    # posterior R_t as a (date, sample) array with chains pooled
    samples = (
        model._trace.posterior["r_t"]
        .stack(sample=("chain", "draw"))
        .transpose("date", "sample")
        .values
    )
    x = result.index

    # one band per percentile level; wider bands get lighter colours
    percs = np.linspace(51, 99, 40)
    colors = (percs - percs.min()) / (percs.max() - percs.min())
    cmap = plt.get_cmap("Reds")

    # widest band first so narrower (darker) bands are drawn on top of it
    for i, p in enumerate(percs[::-1]):
        upper = np.percentile(samples, p, axis=1)
        lower = np.percentile(samples, 100 - p, axis=1)
        ax.fill_between(x, upper, lower, color=cmap(colors[i]), alpha=0.8)

    # median drawn after the opaque bands so it stays visible
    result["median"].plot(c="k", ls="-", label="Posterior median Rₜ", ax=ax)

    ax.axhline(1.0, c="k", lw=1, linestyle="--")
    # process_data fills missing reproduction_rate with 0; hide those days
    reported = data_region["reproduction_rate"]
    reported = reported.where(reported > 0)
    ax.plot(reported.index, reported.values, label="Reported Rₜ", c="C1", alpha=0.6)
    ax.legend()
    sns.despine()
    plt.show()


def plot_posterior_predictive(region, data_region, model, ylim=(0, 300_000)):
    """Plot posterior-predictive draws of positives against observed case counts.

    Parameters
    ----------
    region : str
        Region name used in the title.
    data_region : pd.DataFrame
        Date-indexed frame with a ``positive`` column.
    model : GenerativeModel
        A sampled model (uses ``model.model`` and ``model._trace``).
    ylim : tuple
        y-axis limits (default matches the previous hard-coded range).
    """
    with model.model:
        posterior_predictive = pm.sample_posterior_predictive(model._trace)

    # simulated likelihood draws as (nonzero_date, sample), chains pooled
    simulated = posterior_predictive.posterior_predictive["nonzero_positive"]
    draws = simulated.stack(sample=("chain", "draw")).transpose("nonzero_date", "sample")
    dates = pd.to_datetime(simulated["nonzero_date"].values)

    fig, ax = plt.subplots(figsize=(12, 8))
    # plot both series against real dates so simulated and observed days line up
    ax.plot(dates, draws.values, color="0.5", alpha=0.05)
    ax.plot(data_region.index, data_region["positive"].values, color="r", label="Observed")
    ax.set(ylim=ylim)
    ax.set_title(f"Posterior nonzero positive cases for {region}")
    ax.legend()
    sns.despine()
    plt.show()


In [None]:
# ==============================================================
#  Imports and setup
# ==============================================================

import os
import matplotlib.pyplot as plt
import arviz as az
import numpy as np
import pandas as pd
from sklearn.metrics import mean_absolute_error, mean_squared_error

# every country writes into results/<country>/, so create the parent up front
os.makedirs("results", exist_ok=True)

# countries fitted one after another by the pipeline below
countries = "Russia Italy Germany France".split()

# country -> {"idata": ..., "summary": ...}; filled in by fit_country
results = dict()


# ==============================================================
#  Core fitting routine for one country
# ==============================================================

def _prepare_region_data(data, name, min_cases=100):
    """Region rows from ``min_cases`` onward, inf/NaN → 0, model columns as float32."""
    region_data = get_data_by_region(data, name, min_cases=min_cases)
    region_data = region_data.replace([np.inf, -np.inf], np.nan).fillna(0)
    for col in ["positive", "total", "reproduction_rate"]:
        if col in region_data.columns:
            region_data[col] = region_data[col].astype("float32")
    return region_data


def _posterior_band(idata, var_name):
    """Per-date posterior mean and 95% HDI of a (chain, draw, date) variable.

    Columns: date, mean, hdi_2.5, hdi_97.5. NOTE: an HDI is not an
    equal-tailed interval; the 2.5/97.5 names are kept for the existing CSVs.
    """
    draws = idata.posterior[var_name]
    # az.hdi returns a Dataset whose variables gain an "hdi" dim ("lower", "higher")
    hdi = az.hdi(idata, var_names=[var_name], hdi_prob=0.95)[var_name]
    return pd.DataFrame({
        "date": pd.to_datetime(draws["date"].values),
        "mean": draws.mean(dim=("chain", "draw")).values,
        "hdi_2.5": hdi.sel(hdi="lower").values,
        "hdi_97.5": hdi.sel(hdi="higher").values,
    })


def _plot_country_summary(name, r_t_df, df_fit, region_data, path):
    """Save a two-panel figure (R_t band; observed vs fitted cases) to ``path``."""
    fig, (ax_rt, ax_cases) = plt.subplots(1, 2, figsize=(12, 5))

    # --- (1) Effective R_t over time ---
    ax_rt.plot(r_t_df["date"], r_t_df["mean_r_t"], label="Mean Rₜ")
    ax_rt.fill_between(r_t_df["date"], r_t_df["hdi_2.5"], r_t_df["hdi_97.5"], alpha=0.2)
    if "reproduction_rate" in region_data.columns:
        # missing reported R_t was filled with 0 upstream; hide those days
        reported = region_data["reproduction_rate"].where(region_data["reproduction_rate"] > 0)
        ax_rt.plot(region_data.index, reported, "--", label="Reported Rₜ")
    ax_rt.axhline(1, color="r", linestyle=":")
    ax_rt.set_title(f"{name} — Effective Rₜ")
    ax_rt.legend()

    # --- (2) Observed vs modeled new cases ---
    ax_cases.plot(df_fit["date"], df_fit["observed"], label="Observed", color="black")
    ax_cases.plot(df_fit["date"], df_fit["fitted_mean"], label="Model mean")
    ax_cases.fill_between(df_fit["date"], df_fit["hdi_2.5"], df_fit["hdi_97.5"], alpha=0.2)
    ax_cases.set_title(f"{name} — New Cases")
    ax_cases.legend()

    fig.tight_layout()
    fig.savefig(path)
    plt.close(fig)


def _fit_error_metrics(observed, fitted):
    """Return (MAE, RMSE, MAPE %); the MAPE denominator is floored at 1 case."""
    observed = np.asarray(observed, dtype="float64")
    fitted = np.asarray(fitted, dtype="float64")
    mae = float(mean_absolute_error(observed, fitted))
    # sqrt(MSE) rather than `squared=False`, which recent scikit-learn removed
    rmse = float(np.sqrt(mean_squared_error(observed, fitted)))
    mape = float(np.mean(np.abs((observed - fitted) / np.clip(observed, 1, None))) * 100)
    return mae, rmse, mape


def fit_country(name, data, draws=1000, tune=1000, chains=4, cores=8, target_accept=0.95):
    """Fit the GenerativeModel for one country and save outputs/plots.

    Writes into ``results/<name.lower()>/``: ``r_t_*.csv`` (R_t mean/HDI on the
    model timeline), ``fit_*.csv`` (observed vs fitted new cases on the
    region's dates), ``summary_*.png``, ``forecast_metrics_*.csv`` (only when
    the 02–14 Dec 2020 window has data), ``idata_*.nc`` and ``summary_*.csv``.
    Also stores ``{"idata", "summary"}`` in the module-level ``results`` dict.

    Parameters
    ----------
    name : str
        Location name as it appears in ``data["location"]``.
    data : pd.DataFrame
        Output of ``process_data``.
    draws, tune, chains, cores, target_accept
        Forwarded to ``GenerativeModel.sample``.
    """
    tag = name.lower()

    # --- data preparation: start at >= 100 daily cases, clean, cast ---
    region_data = _prepare_region_data(data, name, min_cases=100)

    # --- model build and sampling ---
    gm = GenerativeModel(region=name, observed=region_data, buffer_days=10)
    gm.build()
    gm.sample(draws=draws, tune=tune, chains=chains, cores=cores, target_accept=target_accept)

    idata = gm.inference_data
    summary = az.summary(idata, var_names=["r_t", "alpha"]).round(3)

    out_dir = os.path.join("results", tag)
    os.makedirs(out_dir, exist_ok=True)

    # --- R_t posterior summary (model timeline, leading buffer days included) ---
    r_t_df = _posterior_band(idata, "r_t").rename(columns={"mean": "mean_r_t"})
    r_t_df.to_csv(os.path.join(out_dir, f"r_t_{tag}.csv"), index=False)

    # --- fitted new cases; gm.observed shares the posterior's date axis ---
    df_fit = _posterior_band(idata, "positive").rename(columns={"mean": "fitted_mean"})
    df_fit.insert(1, "observed", gm.observed["positive"].to_numpy())
    # drop the zero-filled buffer days so the table covers the region's own dates
    df_fit = df_fit[df_fit["date"].isin(region_data.index)].reset_index(drop=True)
    df_fit.to_csv(os.path.join(out_dir, f"fit_{tag}.csv"), index=False)

    _plot_country_summary(name, r_t_df, df_fit, region_data,
                          os.path.join(out_dir, f"summary_{tag}.png"))

    # --- error over 02–14 Dec 2020 ---
    # NOTE: the model is fit on the whole series, including this window, so these
    # are in-sample fit errors rather than out-of-sample forecast errors.
    window = df_fit[df_fit["date"].isin(pd.date_range("2020-12-02", "2020-12-14"))]
    if not window.empty:
        mae, rmse, mape = _fit_error_metrics(window["observed"], window["fitted_mean"])
        metrics = pd.DataFrame({"MAE": [mae], "RMSE": [rmse], "MAPE(%)": [mape]})
        metrics.to_csv(os.path.join(out_dir, f"forecast_metrics_{tag}.csv"), index=False)
        print(f"{name}: MAE={mae:.2f}, RMSE={rmse:.2f}, MAPE={mape:.1f}%")

    # --- inference artifacts ---
    az.to_netcdf(idata, os.path.join(out_dir, f"idata_{tag}.nc"))
    summary.to_csv(os.path.join(out_dir, f"summary_{tag}.csv"))
    results[name] = {"idata": idata, "summary": summary}

    print(f"✓ {name} done — results saved to {out_dir}\n")


# ==============================================================
#  Run full pipeline for all countries
# ==============================================================

for country in countries:
    print(f"\n=== {country} ===")
    fit_country(country, data)  # fits one country and writes its results/<country>/ files


# ==============================================================
#  Aggregate forecast metrics across all countries
# ==============================================================

# (country, metrics file) pairs; a file is absent when fit_country wrote no metrics
metric_files = [
    (c, f"results/{c.lower()}/forecast_metrics_{c.lower()}.csv") for c in countries
]

all_metrics = []
for c, metrics_path in metric_files:
    if not os.path.exists(metrics_path):
        continue
    m = pd.read_csv(metrics_path)
    m.insert(0, "Country", c)          # tag rows with their country
    all_metrics.append(m)

# one combined CSV, written only when at least one country produced metrics
if all_metrics:
    combined = pd.concat(all_metrics)
    combined.to_csv("results/forecast_summary.csv", index=False)
    print("Combined forecast metrics saved to results/forecast_summary.csv")



=== Russia ===


Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [log_r_t, seed, alpha]


  t_fn, n_steps = scan_perform_ext.perform(
  t_fn, n_steps = scan_perform_ext.perform(
  t_fn, n_steps = scan_perform_ext.perform(
  t_fn, n_steps = scan_perform_ext.perform(


In [None]:
# Sampling was very slow, so conv1d was rewritten by hand; the computation still takes very long and did not finish.