# ETDS Explained


Paper: Equivalent Transformation and Dual Stream Network Construction for Mobile Image Super-Resolution [[Link]](https://openaccess.thecvf.com/content/CVPR2023/papers/Chao_Equivalent_Transformation_and_Dual_Stream_Network_Construction_for_Mobile_Image_CVPR_2023_paper.pdf)

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

## Example of Horizontal Concatenation - Eqn.2

In [2]:
# Two independent inputs and two kernels, each mapping 32 -> 64 channels.
x = torch.randn(1, 32, 64, 64)
y = torch.randn(1, 32, 64, 64)
w1 = torch.randn(64, 32, 3, 3)
w2 = torch.randn(64, 32, 3, 3)

# Reference: convolve each input with its own kernel, then add the results.
res_1 = F.conv2d(x, w1, padding=1) + F.conv2d(y, w2, padding=1)

# Horizontal concatenation: stack the inputs and the kernels along the
# input-channel axis so that one convolution performs both convs and the sum.
xy_stacked = torch.cat([x, y], dim=1)
w_horizontal = torch.cat([w1, w2], dim=1)
res_2 = F.conv2d(xy_stacked, w_horizontal, padding=1)

torch.allclose(res_1, res_2, atol=1e-3)

True

## Example of Vertical Concatenation - Eqn.3

In [3]:
# One shared input and two kernels, each mapping 32 -> 32 channels.
x = torch.randn(1, 32, 64, 64)
w1 = torch.randn(32, 32, 3, 3)
w2 = torch.randn(32, 32, 3, 3)

# Reference: run both convolutions on x and concatenate the outputs channel-wise.
branch_outputs = [F.conv2d(x, kernel, padding=1) for kernel in (w1, w2)]
res_1 = torch.cat(branch_outputs, dim=1)

# Vertical concatenation: stack the kernels along the output-channel axis so
# that a single convolution produces both branches at once.
w_vertical = torch.cat([w1, w2], dim=0)
res_2 = F.conv2d(x, w_vertical, padding=1)

torch.allclose(res_1, res_2, atol=1e-3)

True

## ET for Repeat Operation - Eqn.4

In [4]:
x = torch.randn(1, 32, 64, 64)
n_concat = 4

# 1x1 identity kernel: convolving with it returns the input unchanged.
I = torch.eye(32).view(32, 32, 1, 1)

# Reference: tile x n_concat times along the channel axis.
x_concat_1 = x.repeat(1, n_concat, 1, 1)

# Vertical concatenation: n_concat identity kernels stacked along the
# output-channel axis reproduce the tiling with one 1x1 convolution.
I_stacked = I.repeat(n_concat, 1, 1, 1)
x_concat_2 = F.conv2d(x, I_stacked, padding=0)

torch.allclose(x_concat_1, x_concat_2, atol=1e-3)

True

## ET for Add Operation - Eqn.5

In [5]:
x = torch.randn(1, 32, 64, 64)
y = torch.randn(1, 32, 64, 64)

# 1x1 identity kernel.
I = torch.eye(32).view(32, 32, 1, 1)

# Reference: plain element-wise addition.
sum_1 = x + y

# Horizontal concatenation: [I, I] applied to concat(x, y) computes I*x + I*y.
xy_stacked = torch.cat([x, y], dim=1)
I_pair = I.repeat(1, 2, 1, 1)
sum_2 = F.conv2d(xy_stacked, I_pair, padding=0)

torch.allclose(sum_1, sum_2, atol=1e-3)

True

## ET for Concat Operation - Eqn.6

In [6]:
x = torch.randn(1, 32, 64, 64)
y = torch.randn(1, 32, 64, 64)

# Each branch has its own 3x3 kernel and bias.
w1 = torch.randn(32, 32, 3, 3)
b1 = torch.randn(32)
w2 = torch.randn(32, 32, 3, 3)
b2 = torch.randn(32)

# All-zero 3x3 kernel used for the off-diagonal blocks.
O = torch.zeros(32, 32, 3, 3)

# Reference: convolve each input separately, then concatenate channel-wise.
branch_x = F.conv2d(x, w1, padding=1) + b1.view(1, -1, 1, 1)
branch_y = F.conv2d(y, w2, padding=1) + b2.view(1, -1, 1, 1)
cat_1 = torch.cat([branch_x, branch_y], dim=1)

# Horizontal concatenation + vertical concatenation: the block-diagonal
# kernel [[w1, O], [O, w2]] applied to concat(x, y) yields both branches.
w_block_diag = torch.cat(
    [
        torch.cat([w1, O], dim=1),
        torch.cat([O, w2], dim=1),
    ],
    dim=0,
)
b_stacked = torch.cat([b1, b2], dim=0).view(1, -1, 1, 1)
cat_2 = F.conv2d(torch.cat([x, y], dim=1), w_block_diag, padding=1) + b_stacked

torch.allclose(cat_1, cat_2, atol=1e-3)

True

## ET for Concat Operation (where one branch has no convolution) - Eqn.8

In [7]:
x = torch.randn(1, 32, 64, 64)
y = torch.randn(1, 32, 64, 64)

w1 = torch.randn(32, 32, 3, 3)
b1 = torch.randn(32)

# Identity kernel, zero-padded from 1x1 to 3x3 so it can share a conv with w1.
I = F.pad(torch.eye(32).view(32, 32, 1, 1), (1, 1, 1, 1))
# All-zero 3x3 kernel used for the off-diagonal blocks.
O = torch.zeros(32, 32, 3, 3)

# Reference: only x is convolved; y is concatenated as-is.
conv_branch = F.conv2d(x, w1, padding=1) + b1.view(1, -1, 1, 1)
cat_1 = torch.cat([conv_branch, y], dim=1)

# Horizontal concatenation + vertical concatenation: [[w1, O], [O, I]]
# convolves x and passes y through; the pass-through half gets zero bias.
w_block_diag = torch.cat(
    [
        torch.cat([w1, O], dim=1),
        torch.cat([O, I], dim=1),
    ],
    dim=0,
)
b_stacked = torch.cat([b1, torch.zeros_like(b1)], dim=0).view(1, -1, 1, 1)
cat_2 = F.conv2d(torch.cat([x, y], dim=1), w_block_diag, padding=1) + b_stacked

torch.allclose(cat_1, cat_2, atol=1e-3)

True

## ET for Concat Operation (where one branch has no convolution and x equals y) - Eqn.9

In [8]:
x = torch.randn(1, 32, 64, 64)
y = x.clone()  # both branches see the same input

w1 = torch.randn(32, 32, 3, 3)
b1 = torch.randn(32)

# Identity kernel, zero-padded from 1x1 to 3x3.
I = F.pad(torch.eye(32).view(32, 32, 1, 1), (1, 1, 1, 1))

# Reference: concatenate conv(x) with the untouched copy of x.
conv_branch = F.conv2d(x, w1, padding=1) + b1.view(1, -1, 1, 1)
cat_1 = torch.cat([conv_branch, y], dim=1)

# Vertical concatenation: since both branches read x, stacking w1 and I along
# the output-channel axis is enough; no input concat is needed.
w_vertical = torch.cat([w1, I], dim=0)
b_stacked = torch.cat([b1, torch.zeros_like(b1)], dim=0).view(1, -1, 1, 1)
cat_2 = F.conv2d(x, w_vertical, padding=1) + b_stacked

torch.allclose(cat_1, cat_2, atol=1e-3)

True

## ET for Clip Operation - Eqn.10

In [9]:
# Scale the samples so that many values fall outside [0, 255].
x = torch.randn(1, 32, 64, 64) * 255

# 1x1 identity kernel; its negation implements "-x" as a convolution.
I = torch.eye(32).view(32, 32, 1, 1)
neg_I = -I

# Reference clip.
clip_1 = torch.clip(x, 0, 255)

# clip(x, 0, 255) equals ReLU(-ReLU(-x + 255) + 255):
#   the inner ReLU caps values above 255, the outer ReLU zeroes values below 0.
y = F.relu(F.conv2d(x, neg_I, padding=0) + 255)
clip_2 = F.relu(F.conv2d(y, neg_I, padding=0) + 255)

torch.allclose(clip_1, clip_2, atol=1e-3)

True

## Let's Convert PlainSR using ET

In [19]:
def conv_1x1(x, w):
    """Convolve ``x`` with kernel ``w`` without padding (used for 1x1 kernels)."""
    out = F.conv2d(x, w, padding=0)
    return out

def conv_3x3(x, w):
    """Convolve ``x`` with kernel ``w`` using one pixel of zero padding.

    With a 3x3 kernel this keeps the spatial size of ``x`` unchanged.
    """
    out = F.conv2d(x, w, padding=1)
    return out

def relu(x):
    """Element-wise ReLU activation.

    PlainSR actually uses PReLU; plain ReLU is used here for convenience.
    """
    return torch.relu(x)


# This is the M3C32 example. For convenience, biases are omitted.
def PlainSR(x, w_proj, w_1, w_2, w_3, w_out, scale=4):
    """Reference PlainSR forward pass that every ET step below is checked against.

    Args:
        x: input batch; its channels are tiled ``scale ** 2`` times for the skip.
        w_proj, w_1, w_2, w_3: kernels of the conv + ReLU layers, applied in order.
        w_out: kernel of the final conv; its output is added to the tiled skip,
            so it must produce ``x.shape[1] * scale ** 2`` channels.
        scale: upscaling factor passed to ``F.pixel_shuffle``.

    Returns:
        The clipped, pixel-shuffled output tensor.
    """
    skip = x.clone()

    # body: projection followed by the hidden conv + ReLU layers
    feat = x
    for kernel in (w_proj, w_1, w_2, w_3):
        feat = relu(conv_3x3(feat, kernel))
    feat = conv_3x3(feat, w_out)

    # skip connection: tile the input channels scale ** 2 times and add.
    # NOTE(review): `repeat` tiles the channels as [c0, c1, c2, c0, c1, ...],
    # whereas pixel_shuffle folds each run of scale ** 2 *consecutive* channels
    # into one output channel. This is therefore NOT the same as adding
    # nn.Upsample(scale_factor=scale, mode='nearest')(x_skip) after the
    # pixel shuffle; that would need `repeat_interleave(scale ** 2, dim=1)`.
    # The tiled form is kept because the ET steps below reproduce exactly it.
    feat = feat + skip.repeat(1, scale ** 2, 1, 1)

    # clip to [0, 255]
    feat = torch.clip(feat, 0, 255)

    # upsample
    return F.pixel_shuffle(feat, scale)


# Random RGB-like input in [0, 1); non-negative values matter later, when the
# skip branch is routed through ReLU layers.
# torch.rand replaces the legacy torch.FloatTensor(...).uniform_(0, 1) constructor.
x = torch.rand(1, 3, 64, 64, dtype=torch.float32)

# Random kernels for the M3C32 network (32 feature channels, 3 body convs).
w_proj = torch.randn(32, 3, 3, 3)
w_1 = torch.randn(32, 32, 3, 3)
w_2 = torch.randn(32, 32, 3, 3)
w_3 = torch.randn(32, 32, 3, 3)
w_out = torch.randn(3 * 4 ** 2, 32, 3, 3)  # Colors * Scale ** 2, inp_planes, kernel_size, kernel_size

# Reference output that every converted variant below is compared against.
res_1 = PlainSR(x, w_proj, w_1, w_2, w_3, w_out)

print(res_1.shape)

torch.Size([1, 3, 256, 256])


## 1. Convert Repeat Operation into Convolution

In [22]:
def PlainSR_repeat_to_conv(x, w_proj, w_1, w_2, w_3, w_out, scale=4):
    """PlainSR with the skip-branch channel repeat rewritten as a 1x1 conv (Eqn.4).

    Args:
        x: input batch; ``x.shape[1]`` sizes the identity kernel.
        w_proj, w_1, w_2, w_3, w_out: kernels passed to ``conv_3x3``, in order.
        scale: upscaling factor; the skip branch is repeated ``scale ** 2`` times.

    Returns:
        The clipped, pixel-shuffled output (matches ``PlainSR`` up to
        floating-point error).
    """
    x_skip = x.clone()
    n_in = x.shape[1]
    # Created with x's dtype/device so the conv also works when the inputs are
    # not default CPU float32 tensors.
    eye_in = torch.eye(n_in, dtype=x.dtype, device=x.device).view(n_in, n_in, 1, 1)
    # scale ** 2 identities stacked along the output-channel axis
    # (vertical concatenation) == x_skip.repeat(1, scale ** 2, 1, 1).
    I_input = torch.cat([eye_in] * scale ** 2, dim=0)

    # body
    x = relu(conv_3x3(x, w_proj))
    x = relu(conv_3x3(x, w_1))
    x = relu(conv_3x3(x, w_2))
    x = relu(conv_3x3(x, w_3))
    x = conv_3x3(x, w_out)

    # skip connection
    # NEW in this step: the repeat is performed by a 1x1 convolution.
    x_skip = conv_1x1(x_skip, I_input)
    x = x + x_skip

    # clip
    x = torch.clip(x, 0, 255)

    # upsample
    x = F.pixel_shuffle(x, scale)
    return x

# Check: the repeat-as-conv variant must reproduce the reference output res_1.
res_2 = PlainSR_repeat_to_conv(x, w_proj, w_1, w_2, w_3, w_out)

torch.allclose(res_1, res_2, atol=1e-3)

True

## 2. Convert Add Operation into Concat and Convolution

In [24]:
def PlainSR_add_to_conv(x, w_proj, w_1, w_2, w_3, w_out, scale=4):
    """PlainSR with the residual add rewritten as concat + 1x1 conv (Eqn.5).

    Args:
        x: input batch; ``x.shape[1]`` sizes the skip identity kernel.
        w_proj, w_1, w_2, w_3, w_out: kernels passed to ``conv_3x3``, in order.
        scale: upscaling factor; the skip branch is repeated ``scale ** 2`` times.

    Returns:
        The clipped, pixel-shuffled output (matches ``PlainSR`` up to
        floating-point error).
    """
    x_skip = x.clone()
    n_in, n_out = x.shape[1], w_out.shape[0]
    # Identity kernels follow x's dtype/device so the convs also work when the
    # inputs are not default CPU float32 tensors.
    eye_in = torch.eye(n_in, dtype=x.dtype, device=x.device).view(n_in, n_in, 1, 1)
    eye_out = torch.eye(n_out, dtype=x.dtype, device=x.device).view(n_out, n_out, 1, 1)

    I_input = torch.cat([eye_in] * scale ** 2, dim=0)
    # NEW in this step: [I, I] side by side (horizontal concatenation), so that
    # conv(concat(a, b), I_out) == a + b.
    I_out = eye_out.repeat(1, 2, 1, 1)

    # body
    x = relu(conv_3x3(x, w_proj))
    x = relu(conv_3x3(x, w_1))
    x = relu(conv_3x3(x, w_2))
    x = relu(conv_3x3(x, w_3))
    x = conv_3x3(x, w_out)

    # skip connection
    x_skip = conv_1x1(x_skip, I_input)
    # NEW in this step: the add becomes a concat followed by a 1x1 convolution.
    x_cat = torch.cat([x, x_skip], dim=1)
    x = conv_1x1(x_cat, I_out)

    # clip
    x = torch.clip(x, 0, 255)

    # upsample
    x = F.pixel_shuffle(x, scale)
    return x

# Check: the add-as-conv variant must reproduce the reference output res_1.
res_3 = PlainSR_add_to_conv(x, w_proj, w_1, w_2, w_3, w_out)

torch.allclose(res_1, res_3, atol=1e-3)

True

## 3. Convert Concat Operation into Convolution

In [27]:
def PlainSR_concat_to_conv(x, w_proj, w_1, w_2, w_3, w_out, scale=4):
    """PlainSR with the output concat folded into one block-diagonal conv (Eqn.6).

    The last body conv (``w_out``) and the skip repeat (``I_input``) are merged
    into the kernel ``[[w_out, 0], [0, I_input]]``, applied to the
    concatenation of the features and the skip input.

    Args:
        x: input batch; ``x.shape[1]`` sizes the skip identity kernel.
        w_proj, w_1, w_2, w_3, w_out: kernels of the plain network, in order.
        scale: upscaling factor; the skip branch is repeated ``scale ** 2`` times.

    Returns:
        The clipped, pixel-shuffled output (matches ``PlainSR`` up to
        floating-point error).
    """
    x_skip = x.clone()
    n_in, n_out = x.shape[1], w_out.shape[0]
    # Identity kernels follow x's dtype/device so the convs also work when the
    # inputs are not default CPU float32 tensors.
    eye_in = torch.eye(n_in, dtype=x.dtype, device=x.device).view(n_in, n_in, 1, 1)
    eye_out = torch.eye(n_out, dtype=x.dtype, device=x.device).view(n_out, n_out, 1, 1)

    I_input = torch.cat([eye_in] * scale ** 2, dim=0)
    # NEW in this step: zero-pad the 1x1 repeat kernel to 3x3 so it can share
    # a convolution with w_out.
    I_input = F.pad(I_input, (1, 1, 1, 1))
    I_out = eye_out.repeat(1, 2, 1, 1)

    # Block-diagonal kernel: features go through w_out, the skip through I_input.
    w_out_et = torch.cat(
        [
            torch.cat([w_out, torch.zeros_like(I_input)], dim=1),
            torch.cat([torch.zeros_like(w_out), I_input], dim=1),
        ],
        dim=0,
    )

    # body
    x = relu(conv_3x3(x, w_proj))
    x = relu(conv_3x3(x, w_1))
    x = relu(conv_3x3(x, w_2))
    x = relu(conv_3x3(x, w_3))
    # NEW in this step: concat first, then one conv replaces conv_3x3(x, w_out)
    # and the separate skip conv.
    x_cat_bef = torch.cat([x, x_skip], dim=1)
    x_cat = conv_3x3(x_cat_bef, w_out_et)

    # skip connection (the add, expressed as a 1x1 conv)
    x = conv_1x1(x_cat, I_out)

    # clip
    x = torch.clip(x, 0, 255)

    # upsample
    x = F.pixel_shuffle(x, scale)

    return x

# Check: the concat-as-conv variant must reproduce the reference output res_1.
res_4 = PlainSR_concat_to_conv(x, w_proj, w_1, w_2, w_3, w_out)

torch.allclose(res_1, res_4, atol=1e-3)

True

## 4. Structurally Re-parameterize the Concat Conv and the Skip Conv

In [56]:
def merge_sequential_conv(w1, w2):
    """Fold a k x k conv followed by a 1x1 conv into a single k x k kernel.

    Args:
        w1: kernel of the first conv, shape ``(mid, in, k, k)``.
        w2: kernel of the second conv, shape ``(out, mid, 1, 1)``.

    Returns:
        Kernel of shape ``(out, in, k, k)`` such that, for bias-free convs,
        ``conv(x, merged) == conv(conv(x, w1), w2)``.

    Raises:
        ValueError: if ``w2`` is not a 1x1 kernel. The matrix product below is
            only a valid merge in that case; a k x k ``w2`` would otherwise be
            multiplied position by position and silently give a wrong kernel.
    """
    if w2.dim() != 4 or tuple(w2.shape[2:]) != (1, 1):
        raise ValueError(
            f"w2 must be a 1x1 conv kernel of shape (out, mid, 1, 1), got {tuple(w2.shape)}"
        )
    w1 = w1.permute(2, 3, 1, 0)  # (k, k, in, mid)
    w2 = w2.permute(2, 3, 1, 0)  # (1, 1, mid, out)
    # Broadcast matmul over the spatial positions: (k, k, in, out)
    k = w1 @ w2
    k = k.permute(3, 2, 0, 1)  # back to (out, in, k, k)
    return k

def PlainSR_reparam_concat_and_skip(x, w_proj, w_1, w_2, w_3, w_out, scale=4):
    """PlainSR with the concat conv and the 1x1 add conv merged into one kernel.

    Args:
        x: input batch; ``x.shape[1]`` sizes the skip identity kernel.
        w_proj, w_1, w_2, w_3, w_out: kernels of the plain network, in order.
        scale: upscaling factor; the skip branch is repeated ``scale ** 2`` times.

    Returns:
        The clipped, pixel-shuffled output (matches ``PlainSR`` up to
        floating-point error).
    """
    x_skip = x.clone()
    n_in, n_out = x.shape[1], w_out.shape[0]
    # Identity kernels follow x's dtype/device so the convs also work when the
    # inputs are not default CPU float32 tensors.
    eye_in = torch.eye(n_in, dtype=x.dtype, device=x.device).view(n_in, n_in, 1, 1)
    eye_out = torch.eye(n_out, dtype=x.dtype, device=x.device).view(n_out, n_out, 1, 1)

    # Skip repeat kernel (padded to 3x3) and the [I, I] kernel implementing the add.
    I_input = F.pad(torch.cat([eye_in] * scale ** 2, dim=0), (1, 1, 1, 1))
    I_out = eye_out.repeat(1, 2, 1, 1)

    # Block-diagonal kernel: features go through w_out, the skip through I_input.
    w_out_et = torch.cat(
        [
            torch.cat([w_out, torch.zeros_like(I_input)], dim=1),
            torch.cat([torch.zeros_like(w_out), I_input], dim=1),
        ],
        dim=0,
    )

    # NEW in this step: fold the 1x1 add conv (I_out) into the 3x3 conv (w_out_et).
    w_concat_and_skip_et = merge_sequential_conv(w_out_et, I_out)

    # body
    x = relu(conv_3x3(x, w_proj))
    x = relu(conv_3x3(x, w_1))
    x = relu(conv_3x3(x, w_2))
    x = relu(conv_3x3(x, w_3))
    x_cat_bef = torch.cat([x, x_skip], dim=1)

    # concat + skip
    # NEW in this step: one conv replaces conv_3x3(x_cat_bef, w_out_et)
    # followed by conv_1x1(x_cat, I_out).
    x = conv_3x3(x_cat_bef, w_concat_and_skip_et)

    # clip
    x = torch.clip(x, 0, 255)

    # upsample
    x = F.pixel_shuffle(x, scale)

    return x


# Check: the re-parameterized variant must reproduce the reference output res_1.
res_5 = PlainSR_reparam_concat_and_skip(x, w_proj, w_1, w_2, w_3, w_out)

torch.allclose(res_1, res_5, atol=1e-3)

True

## 5. Convert Concat Operation into Convolution (where one branch has no convolution)

In [57]:
def PlainSR_concat_noconvres_to_conv(x, w_proj, w_1, w_2, w_3, w_out, scale=4):
    """PlainSR where the skip input is carried through the ``w_3`` layer (Eqn.8).

    The concat is moved in front of the last body layer, whose kernel becomes
    ``[[w_3, 0], [0, I]]``. The skip channels then pass through that layer's
    ReLU, so the result is exact only for non-negative ``x``.

    Args:
        x: input batch; ``x.shape[1]`` sizes the identity kernels.
        w_proj, w_1, w_2, w_3, w_out: kernels of the plain network, in order.
        scale: upscaling factor; the skip branch is repeated ``scale ** 2`` times.

    Returns:
        The clipped, pixel-shuffled output (matches ``PlainSR`` up to
        floating-point error for non-negative inputs).
    """
    x_skip = x.clone()
    n_in, n_out = x.shape[1], w_out.shape[0]
    # Identity kernels follow x's dtype/device so the convs also work when the
    # inputs are not default CPU float32 tensors.
    eye_in = torch.eye(n_in, dtype=x.dtype, device=x.device).view(n_in, n_in, 1, 1)
    eye_out = torch.eye(n_out, dtype=x.dtype, device=x.device).view(n_out, n_out, 1, 1)

    # NEW in this step: 3x3 identity kernel that passes the skip input through.
    I = F.pad(eye_in, (1, 1, 1, 1))
    I_input = F.pad(torch.cat([eye_in] * scale ** 2, dim=0), (1, 1, 1, 1))
    I_out = eye_out.repeat(1, 2, 1, 1)

    def _carry_skip(w):
        # Block-diagonal kernel [[w, 0], [0, I]]: features through w, skip unchanged.
        upper = torch.cat([w, w.new_zeros((w.shape[0], I.shape[1], 3, 3))], dim=1)
        lower = torch.cat([w.new_zeros((I.shape[0], w.shape[1], 3, 3)), I], dim=1)
        return torch.cat([upper, lower], dim=0)

    w_3_et = _carry_skip(w_3)
    w_out_et = torch.cat(
        [
            torch.cat([w_out, torch.zeros_like(I_input)], dim=1),
            torch.cat([torch.zeros_like(w_out), I_input], dim=1),
        ],
        dim=0,
    )

    w_concat_and_skip_et = merge_sequential_conv(w_out_et, I_out)

    # body
    x = relu(conv_3x3(x, w_proj))
    x = relu(conv_3x3(x, w_1))
    x = relu(conv_3x3(x, w_2))
    # NEW in this step: concat before the w_3 layer instead of after it; this
    # replaces relu(conv_3x3(x, w_3)) followed by torch.cat([x, x_skip], dim=1).
    x_cat_bef_3 = torch.cat([x, x_skip], dim=1)
    x_cat_bef = relu(conv_3x3(x_cat_bef_3, w_3_et))

    # concat + skip
    x = conv_3x3(x_cat_bef, w_concat_and_skip_et)

    # clip
    x = torch.clip(x, 0, 255)

    # upsample
    x = F.pixel_shuffle(x, scale)

    return x


# Check: carrying the skip through the w_3 layer must reproduce res_1.
res_6 = PlainSR_concat_noconvres_to_conv(x, w_proj, w_1, w_2, w_3, w_out)

torch.allclose(res_1, res_6, atol=1e-3)

True

## 5-2. Repeat this for w_2 and w_1

In [58]:
def PlainSR_concat_noconvres_to_conv_2(x, w_proj, w_1, w_2, w_3, w_out, scale=4):
    """PlainSR where the skip input is carried through the ``w_2`` and ``w_3`` layers.

    Both layers use block-diagonal kernels ``[[w, 0], [0, I]]``. The skip
    channels pass through their ReLUs, so the result is exact only for
    non-negative ``x``.

    Args:
        x: input batch; ``x.shape[1]`` sizes the identity kernels.
        w_proj, w_1, w_2, w_3, w_out: kernels of the plain network, in order.
        scale: upscaling factor; the skip branch is repeated ``scale ** 2`` times.

    Returns:
        The clipped, pixel-shuffled output (matches ``PlainSR`` up to
        floating-point error for non-negative inputs).
    """
    x_skip = x.clone()
    n_in, n_out = x.shape[1], w_out.shape[0]
    # Identity kernels follow x's dtype/device so the convs also work when the
    # inputs are not default CPU float32 tensors.
    eye_in = torch.eye(n_in, dtype=x.dtype, device=x.device).view(n_in, n_in, 1, 1)
    eye_out = torch.eye(n_out, dtype=x.dtype, device=x.device).view(n_out, n_out, 1, 1)

    I = F.pad(eye_in, (1, 1, 1, 1))
    I_input = F.pad(torch.cat([eye_in] * scale ** 2, dim=0), (1, 1, 1, 1))
    I_out = eye_out.repeat(1, 2, 1, 1)

    def _carry_skip(w):
        # Block-diagonal kernel [[w, 0], [0, I]]: features through w, skip unchanged.
        upper = torch.cat([w, w.new_zeros((w.shape[0], I.shape[1], 3, 3))], dim=1)
        lower = torch.cat([w.new_zeros((I.shape[0], w.shape[1], 3, 3)), I], dim=1)
        return torch.cat([upper, lower], dim=0)

    # NEW in this step: w_2 also gets the identity block.
    w_2_et = _carry_skip(w_2)
    w_3_et = _carry_skip(w_3)
    w_out_et = torch.cat(
        [
            torch.cat([w_out, torch.zeros_like(I_input)], dim=1),
            torch.cat([torch.zeros_like(w_out), I_input], dim=1),
        ],
        dim=0,
    )

    w_concat_and_skip_et = merge_sequential_conv(w_out_et, I_out)

    # body
    x = relu(conv_3x3(x, w_proj))
    x = relu(conv_3x3(x, w_1))
    # NEW in this step: concat before the w_2 layer; this replaces
    # relu(conv_3x3(x, w_2)) followed by torch.cat([x, x_skip], dim=1).
    x_cat_bef_2 = torch.cat([x, x_skip], dim=1)
    x_cat_bef_3 = relu(conv_3x3(x_cat_bef_2, w_2_et))
    x_cat_bef = relu(conv_3x3(x_cat_bef_3, w_3_et))

    # concat + skip
    x = conv_3x3(x_cat_bef, w_concat_and_skip_et)

    # clip
    x = torch.clip(x, 0, 255)

    # upsample
    x = F.pixel_shuffle(x, scale)

    return x

# Check: carrying the skip through the w_2 and w_3 layers must reproduce res_1.
res_7 = PlainSR_concat_noconvres_to_conv_2(x, w_proj, w_1, w_2, w_3, w_out)

torch.allclose(res_1, res_7, atol=1e-3)

True

In [59]:
def PlainSR_concat_noconvres_to_conv_3(x, w_proj, w_1, w_2, w_3, w_out, scale=4):
    """PlainSR where the skip input is carried through all three body layers.

    ``w_1``, ``w_2`` and ``w_3`` all use block-diagonal kernels
    ``[[w, 0], [0, I]]``. The skip channels pass through their ReLUs, so the
    result is exact only for non-negative ``x``.

    Args:
        x: input batch; ``x.shape[1]`` sizes the identity kernels.
        w_proj, w_1, w_2, w_3, w_out: kernels of the plain network, in order.
        scale: upscaling factor; the skip branch is repeated ``scale ** 2`` times.

    Returns:
        The clipped, pixel-shuffled output (matches ``PlainSR`` up to
        floating-point error for non-negative inputs).
    """
    x_skip = x.clone()
    n_in, n_out = x.shape[1], w_out.shape[0]
    # Identity kernels follow x's dtype/device so the convs also work when the
    # inputs are not default CPU float32 tensors.
    eye_in = torch.eye(n_in, dtype=x.dtype, device=x.device).view(n_in, n_in, 1, 1)
    eye_out = torch.eye(n_out, dtype=x.dtype, device=x.device).view(n_out, n_out, 1, 1)

    I = F.pad(eye_in, (1, 1, 1, 1))
    I_input = F.pad(torch.cat([eye_in] * scale ** 2, dim=0), (1, 1, 1, 1))
    I_out = eye_out.repeat(1, 2, 1, 1)

    def _carry_skip(w):
        # Block-diagonal kernel [[w, 0], [0, I]]: features through w, skip unchanged.
        upper = torch.cat([w, w.new_zeros((w.shape[0], I.shape[1], 3, 3))], dim=1)
        lower = torch.cat([w.new_zeros((I.shape[0], w.shape[1], 3, 3)), I], dim=1)
        return torch.cat([upper, lower], dim=0)

    # NEW in this step: w_1 also gets the identity block.
    w_1_et = _carry_skip(w_1)
    w_2_et = _carry_skip(w_2)
    w_3_et = _carry_skip(w_3)
    w_out_et = torch.cat(
        [
            torch.cat([w_out, torch.zeros_like(I_input)], dim=1),
            torch.cat([torch.zeros_like(w_out), I_input], dim=1),
        ],
        dim=0,
    )

    w_concat_and_skip_et = merge_sequential_conv(w_out_et, I_out)

    # body
    x = relu(conv_3x3(x, w_proj))
    # NEW in this step: concat right after the projection layer.
    x_cat_bef_1 = torch.cat([x, x_skip], dim=1)
    x_cat_bef_2 = relu(conv_3x3(x_cat_bef_1, w_1_et))
    x_cat_bef_3 = relu(conv_3x3(x_cat_bef_2, w_2_et))
    x_cat_bef = relu(conv_3x3(x_cat_bef_3, w_3_et))

    # concat + skip
    x = conv_3x3(x_cat_bef, w_concat_and_skip_et)

    # clip
    x = torch.clip(x, 0, 255)

    # upsample
    x = F.pixel_shuffle(x, scale)

    return x

# Check: carrying the skip through all three body layers must reproduce res_1.
res_8 = PlainSR_concat_noconvres_to_conv_3(x, w_proj, w_1, w_2, w_3, w_out)

torch.allclose(res_1, res_8, atol=1e-3)

True

## 6. Convert Input Concat Operation into Convolution

In [60]:
def PlainSR_concat_input_to_conv(x, w_proj, w_1, w_2, w_3, w_out, scale=4):
    """PlainSR with the input concat folded into the projection conv (Eqn.9).

    Both branches of the first concat read the same input, so stacking
    ``w_proj`` and an identity kernel along the output-channel axis removes
    the explicit concat and the separate copy of ``x``. The skip channels pass
    through every ReLU, so the result is exact only for non-negative ``x``.

    Args:
        x: input batch; ``x.shape[1]`` sizes the identity kernels.
        w_proj, w_1, w_2, w_3, w_out: kernels of the plain network, in order.
        scale: upscaling factor; the skip branch is repeated ``scale ** 2`` times.

    Returns:
        The clipped, pixel-shuffled output (matches ``PlainSR`` up to
        floating-point error for non-negative inputs).
    """
    n_in, n_out = x.shape[1], w_out.shape[0]
    # Identity kernels follow x's dtype/device so the convs also work when the
    # inputs are not default CPU float32 tensors.
    eye_in = torch.eye(n_in, dtype=x.dtype, device=x.device).view(n_in, n_in, 1, 1)
    eye_out = torch.eye(n_out, dtype=x.dtype, device=x.device).view(n_out, n_out, 1, 1)

    I = F.pad(eye_in, (1, 1, 1, 1))
    I_input = F.pad(torch.cat([eye_in] * scale ** 2, dim=0), (1, 1, 1, 1))
    I_out = eye_out.repeat(1, 2, 1, 1)

    def _carry_skip(w):
        # Block-diagonal kernel [[w, 0], [0, I]]: features through w, skip unchanged.
        upper = torch.cat([w, w.new_zeros((w.shape[0], I.shape[1], 3, 3))], dim=1)
        lower = torch.cat([w.new_zeros((I.shape[0], w.shape[1], 3, 3)), I], dim=1)
        return torch.cat([upper, lower], dim=0)

    # NEW in this step: vertical concatenation of w_proj and I; the x_skip
    # clone is no longer needed.
    w_proj_et = torch.cat([w_proj, I], dim=0)
    w_1_et = _carry_skip(w_1)
    w_2_et = _carry_skip(w_2)
    w_3_et = _carry_skip(w_3)
    w_out_et = torch.cat(
        [
            torch.cat([w_out, torch.zeros_like(I_input)], dim=1),
            torch.cat([torch.zeros_like(w_out), I_input], dim=1),
        ],
        dim=0,
    )

    w_concat_and_skip_et = merge_sequential_conv(w_out_et, I_out)

    # body
    # NEW in this step: one conv replaces relu(conv_3x3(x, w_proj)) followed by
    # torch.cat([x, x_skip], dim=1).
    x_cat_bef_1 = relu(conv_3x3(x, w_proj_et))
    x_cat_bef_2 = relu(conv_3x3(x_cat_bef_1, w_1_et))
    x_cat_bef_3 = relu(conv_3x3(x_cat_bef_2, w_2_et))
    x_cat_bef = relu(conv_3x3(x_cat_bef_3, w_3_et))

    # concat + skip
    x = conv_3x3(x_cat_bef, w_concat_and_skip_et)

    # clip
    x = torch.clip(x, 0, 255)

    # upsample
    x = F.pixel_shuffle(x, scale)

    return x

# Check: folding the input concat into the projection conv must reproduce res_1.
# NOTE: this reuses the name res_8 and overwrites the result of the previous cell.
res_8 = PlainSR_concat_input_to_conv(x, w_proj, w_1, w_2, w_3, w_out)

torch.allclose(res_1, res_8, atol=1e-3)

True

## 7. Convert Clip Operation into Convolution

In [64]:
def PlainSR_clip_to_conv(x, w_proj, w_1, w_2, w_3, w_out, scale=4):
    """Fully converted PlainSR: the clip is also expressed as conv + ReLU (Eqn.10).

    ``clip(x, 0, 255)`` is computed as ``ReLU(-ReLU(-x + 255) + 255)``, with the
    negations done by 1x1 convolutions using a negated identity kernel. The
    skip channels pass through every ReLU of the body, so the result is exact
    only for non-negative ``x``.

    Args:
        x: input batch; ``x.shape[1]`` sizes the identity kernels.
        w_proj, w_1, w_2, w_3, w_out: kernels of the plain network, in order.
        scale: upscaling factor; the skip branch is repeated ``scale ** 2`` times.

    Returns:
        The clipped, pixel-shuffled output (matches ``PlainSR`` up to
        floating-point error for non-negative inputs).
    """
    n_in, n_out = x.shape[1], w_out.shape[0]
    # Identity kernels follow x's dtype/device so the convs also work when the
    # inputs are not default CPU float32 tensors.
    eye_in = torch.eye(n_in, dtype=x.dtype, device=x.device).view(n_in, n_in, 1, 1)
    eye_out = torch.eye(n_out, dtype=x.dtype, device=x.device).view(n_out, n_out, 1, 1)

    I = F.pad(eye_in, (1, 1, 1, 1))
    I_input = F.pad(torch.cat([eye_in] * scale ** 2, dim=0), (1, 1, 1, 1))
    I_out = eye_out.repeat(1, 2, 1, 1)

    def _carry_skip(w):
        # Block-diagonal kernel [[w, 0], [0, I]]: features through w, skip unchanged.
        upper = torch.cat([w, w.new_zeros((w.shape[0], I.shape[1], 3, 3))], dim=1)
        lower = torch.cat([w.new_zeros((I.shape[0], w.shape[1], 3, 3)), I], dim=1)
        return torch.cat([upper, lower], dim=0)

    w_proj_et = torch.cat([w_proj, I], dim=0)
    w_1_et = _carry_skip(w_1)
    w_2_et = _carry_skip(w_2)
    w_3_et = _carry_skip(w_3)
    w_out_et = torch.cat(
        [
            torch.cat([w_out, torch.zeros_like(I_input)], dim=1),
            torch.cat([torch.zeros_like(w_out), I_input], dim=1),
        ],
        dim=0,
    )

    w_concat_and_skip_et = merge_sequential_conv(w_out_et, I_out)

    # body
    x_cat_bef_1 = relu(conv_3x3(x, w_proj_et))
    x_cat_bef_2 = relu(conv_3x3(x_cat_bef_1, w_1_et))
    x_cat_bef_3 = relu(conv_3x3(x_cat_bef_2, w_2_et))
    x_cat_bef = relu(conv_3x3(x_cat_bef_3, w_3_et))

    # concat + skip
    x = conv_3x3(x_cat_bef, w_concat_and_skip_et)

    # clip
    # NEW in this step: the inner ReLU caps values above 255, the outer ReLU
    # zeroes values below 0.
    n_clip = x.shape[1]
    negative_I_clip = -torch.eye(n_clip, dtype=x.dtype, device=x.device).view(n_clip, n_clip, 1, 1)
    y = relu(conv_1x1(x, negative_I_clip) + 255)
    x = relu(conv_1x1(y, negative_I_clip) + 255)

    # upsample
    x = F.pixel_shuffle(x, scale)

    return x


# Check: the fully converted network must reproduce the reference output res_1.
res_9 = PlainSR_clip_to_conv(x, w_proj, w_1, w_2, w_3, w_out)

torch.allclose(res_1, res_9, atol=1e-3)

True

The converted PlainSR contains no repeat, add, concat, or clip operations, all of which slow down inference on mobile SoCs.