In [None]:
# 动态并行的两个重要 API 是：

# torch.jit.fork(fn : Callable[..., T], *args, **kwargs) -> torch.jit.Future[T]

# torch.jit.wait(fut : torch.jit.Future[T]) -> T

In [453]:
import torch

def foo(x):
    """Return the element-wise negation of the tensor `x`."""
    return x.neg()


# @torch.jit.script 是 PyTorch 中用于将 Python 函数或类转换为 TorchScript 的装饰器。TorchScript 是一种中间表示形式，允许 PyTorch 模型在不依赖 Python 解释器的情况下进行优化和执行。这使得模型可以在生产环境中更高效地运行，尤其是在部署到 C++ 环境或移动设备时。
@torch.jit.script
def example(x):
    """Negate `x` twice, once inline and once as a forked task, and return both.

    Returns a pair `(inline_result, forked_result)`.
    """
    # Launch `foo(x)` as an asynchronous task. fork() hands back a Future
    # rather than the value itself.
    task = torch.jit.fork(foo, x)

    # Do the same computation inline; the forked task may be running in
    # parallel with this call.
    direct = foo(x)

    # Block until the forked task's value is available. Any work placed
    # between fork() and wait() can overlap with the task.
    forked = torch.jit.wait(task)

    return direct, forked

print(example(torch.ones(1))) # prints (tensor([-1.]), tensor([-1.]))

(tensor([-1.]), tensor([-1.]))


In [454]:
# fork() 接受可调用对象 fn 以及传给它的参数 args、kwargs，并创建一个异步任务来执行 fn。fn 可以是函数、方法或 Module 实例。fork() 返回对该执行结果值的引用，称为 Future。由于 fork() 在创建异步任务后立即返回，因此在执行 fork() 调用之后的代码行时，fn 可能尚未执行。因此需要使用 wait() 来等待异步任务完成并取回其返回值。


In [455]:
# 这些结构可用于重叠函数内的语句的执行（如示例部分所示）或与其他语言结构（如循环）组合
import torch
from typing import List

def foo(x):
    """Return `x` with the sign of every element flipped."""
    return -x

@torch.jit.script
def example(x):
    """Fork 100 `foo(x)` tasks, wait on each, and return the sum of all results."""
    # Launch phase: queue up every task before waiting on any of them.
    pending : List[torch.jit.Future[torch.Tensor]] = []
    for _ in range(100):
        pending.append(torch.jit.fork(foo, x))

    # Collection phase: retrieve the results in launch order.
    negated = []
    for i in range(len(pending)):
        negated.append(torch.jit.wait(pending[i]))

    return torch.stack(negated).sum()

# Scalar input of 1: the 100 forked negations are stacked and summed, so this
# prints tensor(-100.)
print(example(torch.ones([])))

tensor(-100.)


In [456]:
# 应用示例：双向 LSTM 集成
import torch, time

# Sizes used by the RNN demo below:
#   T - number of time-steps
#   B - batch size
#   C - hidden size / number of "channels"
T = 50
B = 50
C = 1024

# A module that defines a single "bidirectional LSTM". This is simply two
# LSTMs applied to the same sequence, but one in reverse
class BidirectionalRecurrentLSTM(torch.nn.Module):
    """Two independent LSTMs over the same sequence, one of them run in reverse.

    Both cells are created with `input_size` and `hidden_size` taken from the
    module-level constant `C`.
    """

    def __init__(self):
        super().__init__()
        self.cell_f = torch.nn.LSTM(input_size=C, hidden_size=C)
        self.cell_b = torch.nn.LSTM(input_size=C, hidden_size=C)

    def forward(self, x : torch.Tensor) -> torch.Tensor:
        """Run both directions on `x` and concatenate their outputs on dim 2.

        Args:
            x: Input sequence. Dim 0 is treated as the time axis (it is the
                one that gets flipped) and the outputs are joined on dim 2;
                the notebook feeds a `(T, B, C)` tensor.

        Returns:
            The forward-direction output concatenated along dim 2 with the
            backward-direction output, the latter flipped back on dim 0.
        """
        # Forward layer
        output_f, _ = self.cell_f(x)

        # Backward layer. Flip input in the time dimension (dim 0), apply the
        # layer, then flip the outputs in the time dimension. The flipped
        # input is computed once and reused rather than flipped a second time.
        x_rev = torch.flip(x, dims=[0])
        output_b, _ = self.cell_b(x_rev)
        output_b_rev = torch.flip(output_b, dims=[0])

        return torch.cat((output_f, output_b_rev), dim=2)


# An "ensemble" of `BidirectionalRecurrentLSTM` modules. The modules in the
# ensemble are run one-by-one on the same input then their results are
# stacked and summed together, returning the combined result.
class LSTMEnsemble(torch.nn.Module):
    """A fixed-size collection of `BidirectionalRecurrentLSTM` members whose
    outputs on a shared input are summed element-wise."""

    def __init__(self, n_models):
        super().__init__()
        self.n_models = n_models
        # Build the members first, then register them as one ModuleList.
        members = []
        for _ in range(self.n_models):
            members.append(BidirectionalRecurrentLSTM())
        self.models = torch.nn.ModuleList(members)

    def forward(self, x : torch.Tensor) -> torch.Tensor:
        """Apply each member to `x` one after another and sum the outputs."""
        outputs = []
        for member in self.models:
            outputs.append(member(x))
        # Stack along a new leading dim, then reduce that dim away.
        stacked = torch.stack(outputs)
        return torch.sum(stacked, dim=0)

# For a head-to-head comparison to what we're going to do with fork/wait, let's
# instantiate the model and compile it with TorchScript
# For a head-to-head comparison to what we're going to do with fork/wait, let's
# instantiate the model and compile it with TorchScript
ens = torch.jit.script(LSTMEnsemble(n_models=4))

# Normally you would pull this input out of an embedding table, but for the
# purpose of this demo let's just use random data.
x = torch.rand(T, B, C)

# Let's run the model once to warm up things like the memory allocator
ens(x)

# Fresh random input for the timed run
x = torch.rand(T, B, C)

# Let's see how fast it runs!
# perf_counter() is a monotonic, high-resolution clock meant for measuring
# short durations; time.time() is wall-clock time and can jump if the system
# clock is adjusted.
s = time.perf_counter()
ens(x)
print('Inference took', time.perf_counter() - s, ' seconds')
# 并行化前向层和后向层
def forward(self, x : torch.Tensor) -> torch.Tensor:
    """Bidirectional pass with the forward-direction cell forked as a task.

    Reads `self.cell_f` and `self.cell_b`. Dim 0 of `x` is the axis that gets
    flipped for the backward direction, and the two outputs are concatenated
    on dim 2.

    Returns:
        The forward-direction output concatenated along dim 2 with the
        backward-direction output, the latter flipped back on dim 0.
    """
    # Forward layer - fork() so this can run in parallel to the backward
    # layer
    future_f = torch.jit.fork(self.cell_f, x)

    # Backward layer. Flip input in the time dimension (dim 0), apply the
    # layer, then flip the outputs in the time dimension. The flipped input
    # is computed once and reused rather than flipped a second time.
    x_rev = torch.flip(x, dims=[0])
    output_b, _ = self.cell_b(x_rev)
    output_b_rev = torch.flip(output_b, dims=[0])

    # Retrieve the output from the forward layer. Note this needs to happen
    # *after* the stuff we want to parallelize with
    output_f, _ = torch.jit.wait(future_f)

    return torch.cat((output_f, output_b_rev), dim=2)

Inference took 0.6468088626861572  seconds


In [457]:
# 集成中的并行模型
def forward(self, x : torch.Tensor) -> torch.Tensor:
    """Fork every model in `self.models` on `x`, then sum the gathered outputs."""
    # Kick off one asynchronous task per ensemble member
    pending : List[torch.jit.Future[torch.Tensor]] = []
    for member in self.models:
        pending.append(torch.jit.fork(member, x))

    # Wait on the tasks in launch order and gather what they produced
    outputs : List[torch.Tensor] = []
    for task in pending:
        outputs.append(torch.jit.wait(task))

    # Stack along a new leading dim, then reduce that dim away
    stacked = torch.stack(outputs)
    return torch.sum(stacked, dim=0)
# 简洁版
def forward(self, x : torch.Tensor) -> torch.Tensor:
    """Compact form: fork all models on `x`, wait on each, sum the stacked outputs."""
    # Every task is launched before the first wait() is issued.
    tasks = [torch.jit.fork(member, x) for member in self.models]
    return torch.sum(torch.stack([torch.jit.wait(task) for task in tasks]), dim=0)