In [1]:
import torch
import torch.nn as nn
import torch.optim as optim

import segmentation_models_pytorch as smp
from segmentation_models_pytorch.encoders.efficientnet import EfficientNetEncoder

In [2]:
class ToyModel(nn.Module):
    
    def __init__(self, dev0, dev1):
        super(ToyModel, self).__init__()
        self.dev0 = dev0
        self.dev1 = dev1
        self.net1 = torch.nn.Linear(10, 10).to(self.dev0)
        self.relu = torch.nn.ReLU()
        self.net2 = torch.nn.Linear(10, 5).to(self.dev1)

    def forward(self, x):
        x = self.relu(self.net1(x.to(self.dev0)))
        return self.net2(x.to(self.dev1))

In [3]:
model = ToyModel('cuda:0', 'cuda:1')
loss_fn = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.001)

optimizer.zero_grad()
outputs = model(torch.randn(20, 10))
labels = torch.randn(20, 5).to('cuda:1')
loss_fn(outputs, labels).backward()
optimizer.step()

In [2]:
def relu_fn(x):
    """ Swish activation function """
    return x * torch.sigmoid(x)

class ParallelEfficientNetEncoder(EfficientNetEncoder):
    
    def __init__(self,
        dev0, dev1,
        width_coeff,
        depth_coeff,
        image_size,
        dropout_rate,
        drop_connect_rate,
        block_chunks,
        in_channels = 3  # rgb
    ):
        super().__init__(width_coeff, depth_coeff, image_size, dropout_rate,
                         drop_connect_rate, block_chunks, in_channels)
        self.dev0, self.dev1 = dev0, dev1
        
        # split layers across devices
        self._bn0 = self._bn0.to(dev0)
        self._conv_stem = self._conv_stem.to(dev0)
        self.blocks_to_device(self.block_chunks[0], self.block_chunks[1], dev1)
        self.blocks_to_device(self.block_chunks[1], self.block_chunks[2], dev0)
        self.blocks_to_device(self.block_chunks[2], self.block_chunks[3], dev1)
        self.blocks_to_device(self.block_chunks[3], self.block_chunks[4], dev0)
        
    def blocks_to_device(self, start_idx, end_idx, device):
        for idx in range(start_idx, end_idx):
            self._blocks[idx].to(device)
            
    def forward(self, x):
        x0 = relu_fn(self._bn0(self._conv_stem(x.to(self.dev0))))
        x1 = self.forward_blocks(x0.to(self.dev1), self.block_chunks[0], self.block_chunks[1])
        x2 = self.forward_blocks(x1.to(self.dev0), self.block_chunks[1], self.block_chunks[2])
        x3 = self.forward_blocks(x2.to(self.dev1), self.block_chunks[2], self.block_chunks[3])
        x4 = self.forward_blocks(x3.to(self.dev0), self.block_chunks[3], self.block_chunks[4])
        return [x4.to(self.dev1), x3, x2.to(self.dev1), x1, x0.to(self.dev1)]
    
class PipelineEfficientNetEncoder(EfficientNetEncoder):
    
    def __init__(self,
        dev0, dev1, split_size,
        width_coeff,
        depth_coeff,
        image_size,
        dropout_rate,
        drop_connect_rate,
        block_chunks,
        in_channels = 3  # rgb
    ):
        super().__init__(width_coeff, depth_coeff, image_size, dropout_rate,
                         drop_connect_rate, block_chunks, in_channels)
        self.dev0, self.dev1 = dev0, dev1
        self.split_size = split_size
        
        # split layers across devices
        self._bn0 = self._bn0.to(dev0)
        self._conv_stem = self._conv_stem.to(dev0)
        self.blocks_to_device(self.block_chunks[0], self.block_chunks[1], dev1)
        self.blocks_to_device(self.block_chunks[1], self.block_chunks[2], dev0)
        self.blocks_to_device(self.block_chunks[2], self.block_chunks[3], dev1)
        self.blocks_to_device(self.block_chunks[3], self.block_chunks[4], dev0)
        
    def blocks_to_device(self, start_idx, end_idx, device):
        for idx in range(start_idx, end_idx):
            self._blocks[idx].to(device)
            
    def forward(self, x):
        x0 = relu_fn(self._bn0(self._conv_stem(x.to(self.dev0))))
        x1 = self.forward_blocks(x0.to(self.dev1), self.block_chunks[0], self.block_chunks[1])
        x2 = self.forward_blocks(x1.to(self.dev0), self.block_chunks[1], self.block_chunks[2])
        x3 = self.forward_blocks(x2.to(self.dev1), self.block_chunks[2], self.block_chunks[3])
        x4 = self.forward_blocks(x3.to(self.dev0), self.block_chunks[3], self.block_chunks[4])
        return [x4.to(self.dev1), x3, x2.to(self.dev1), x1, x0.to(self.dev1)]

In [3]:
ctor_kwargs = {
    'dev0': 'cuda:0',
    'dev1': 'cuda:1',
    'width_coeff': 1.1,
    'depth_coeff': 1.2,
    'image_size': 260,
    'dropout_rate': 0.3,
    'drop_connect_rate': 0.2,
    'block_chunks': [0, 3, 8, 16, 23]
}

model = ParallelEfficientNetEncoder(**ctor_kwargs)
out_shapes = model._out_shapes
out_shapes

[352, 120, 48, 24, 32]

In [5]:
model.__class__.__name__

'ParallelEfficientNetEncoder'

In [12]:
loss_fn = nn.BCEWithLogitsLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001)

In [21]:
B, C, H, W = 16, 3, 256, 256
# dummy data
data = torch.rand((B, C, H, W))
t0 = torch.rand((B, out_shapes[4], H // 2, W // 2)).to('cuda:1')
t1 = torch.rand((B, out_shapes[3], H // 4, W // 4)).to('cuda:1')
t2 = torch.rand((B, out_shapes[2], H // 8, W // 8)).to('cuda:1')
t3 = torch.rand((B, out_shapes[1], H // 16, W // 16)).to('cuda:1')
t4 = torch.rand((B, out_shapes[0], H // 32, W // 32)).to('cuda:1')
print(data.size(), t0.size(), t1.size(), t2.size(), t3.size(), t4.size())
targets = [t4, t3, t2, t1, t0]

torch.Size([16, 3, 256, 256]) torch.Size([16, 32, 128, 128]) torch.Size([16, 24, 64, 64]) torch.Size([16, 40, 32, 32]) torch.Size([16, 112, 16, 16]) torch.Size([16, 320, 8, 8])


In [28]:
splits = iter(data.split(4, dim=0))
?splits

In [30]:
s_next = next(splits)
?s_next

In [24]:
def train(batches):
    for _ in range(batches):
        optimizer.zero_grad()
        outputs = model(data)
        for i in range(5):
            loss_fn(outputs[i], targets[i]).backward(retain_graph=True)
        optimizer.step()

In [25]:
%%timeit
train(10)

1.87 s ± 11.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [None]:
# 1.87 s ± 11.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

In [5]:
unet = smp.Unet('efficientnet-b0')
unet

Load result: None


Unet(
  (encoder): EfficientNetEncoder(
    (_conv_stem): Conv2dStaticSamePadding(
      3, 32, kernel_size=(3, 3), stride=(2, 2), bias=False
      (static_padding): ZeroPad2d(padding=(0, 1, 0, 1), value=0.0)
    )
    (_bn0): BatchNorm2d(32, eps=0.001, momentum=0.010000000000000009, affine=True, track_running_stats=True)
    (_blocks): ModuleList(
      (0): MBConvBlock(
        (_depthwise_conv): Conv2dStaticSamePadding(
          32, 32, kernel_size=(3, 3), stride=[1, 1], groups=32, bias=False
          (static_padding): ZeroPad2d(padding=(1, 1, 1, 1), value=0.0)
        )
        (_bn1): BatchNorm2d(32, eps=0.001, momentum=0.010000000000000009, affine=True, track_running_stats=True)
        (_se_reduce): Conv2dStaticSamePadding(
          32, 8, kernel_size=(1, 1), stride=(1, 1)
          (static_padding): Identity()
        )
        (_se_expand): Conv2dStaticSamePadding(
          8, 32, kernel_size=(1, 1), stride=(1, 1)
          (static_padding): Identity()
        )
        (_