
nn.Conv2d output mismatch with torch.nn.Conv2d #66

Closed · dashesy opened this issue Nov 1, 2022 · 7 comments
@dashesy (Contributor) commented Nov 1, 2022

This is a full repro:

import torch
import numpy as np

from collections import OrderedDict

from aitemplate.testing import detect_target
from aitemplate.frontend import nn, Tensor
from aitemplate.compiler import compile_model

def map_pt_params(ait_model, pt_model):
    ait_model.name_parameter_tensor()
    pt_params = dict(pt_model.named_parameters())
    mapped_pt_params = OrderedDict()
    # names should be valid C++ variables
    for name, _ in ait_model.named_parameters():
        ait_name = name.replace(".", "_")
        assert name in pt_params
        params = pt_params[name]
        if len(params.shape) == 4:
            # NCHW -> NHWC
            params = params.permute(0, 2, 3, 1).contiguous()
            # Pad 3-channel weights out to 4 channels
            if params.shape[-1] == 3:
                print(f"pad {name}")
                params = torch.nn.functional.pad(params, (0, 1))
        mapped_pt_params[ait_name] = params
    return mapped_pt_params

def mark_output(Y):
    Y._attrs["is_output"] = True
    Y._attrs["name"] = "Y"


def get_input(shape=None):
    X = Tensor(
        shape=shape,
        name="X",
        dtype="float16",
        is_input=True,
    )
    return X

EPS = 1e-1

class ConvEmbed(nn.Module):
    """ Image to Patch Embedding
    """

    def __init__(
        self,
        in_chans=3,
        embed_dim=64,
        patch_size=7,
        stride=4,
        padding=2,
    ):
        super().__init__()
        self.patch_size = patch_size

        self.proj = nn.Conv2dBias(
            in_chans, embed_dim, kernel_size=patch_size, stride=stride, padding=padding
        )


    def forward(self, x):
        x = self.proj(x)
        return x

class ConvEmbedPt(torch.nn.Module):
    """ Image to Patch Embedding
    """

    def __init__(
        self,
        in_chans=3,
        embed_dim=64,
        patch_size=7,
        stride=4,
        padding=2,
    ):
        super().__init__()
        self.patch_size = patch_size

        self.proj = torch.nn.Conv2d(
            in_chans, embed_dim, kernel_size=patch_size, stride=stride, padding=padding
        )


    def forward(self, x):
        x = self.proj(x)
        return x

def build_convembed0():
    ait_model = ConvEmbed(in_chans=4, embed_dim=256, patch_size=7, stride=4, padding=3)
    ait_model.name_parameter_tensor()
    X = get_input(shape=[1,384,384,4])
    Y = ait_model(X)
    mark_output(Y)
    return ait_model, Y

x = torch.rand(1,4,384,384).cuda().half()
m = ConvEmbedPt(4,256,patch_size=7,stride=4,padding=3).cuda().half()
with torch.no_grad():
    y_pt = m(x)
ait_model, Y = build_convembed0()
weights = map_pt_params(ait_model, m)
target = detect_target()
module = compile_model(Y, target, "./output", "repro", constants=weights)

inputs = [x.permute((0, 2, 3, 1)).contiguous()]
ys = []
num_outputs = len(module.get_output_name_to_index_map())
for i in range(num_outputs):
    shape = module.get_output_maximum_shape(i)
    ys.append(torch.empty(shape).cuda().half())

module.run_with_tensors(inputs, ys)
print((y_pt.permute(0,2,3,1) - ys[0]).abs().max())
np.testing.assert_allclose(
    y_pt.transpose(1,-1).detach().cpu().numpy(),
    ys[0].cpu().numpy(),
    atol=0.1,
    rtol=0.1,
)
  1. I tried both nn.Conv2dBias and nn.Conv2d, which did not help (the conv has a bias, but it was not clear whether nn.Conv2dBias is the same as nn.Conv2d plus a bias).
  2. I also tried not transposing the weights; that did not help either (see the sanity-check sketch below).
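One way to isolate the bias term (a sketch reusing x, m, and y_pt from the repro above; y_manual is just a name made up here):

# Recompute the conv without the bias, add the bias back by hand, and
# confirm it matches the reference output. If it does, the NHWC weight/bias
# mapping on the AIT side is the remaining suspect.
with torch.no_grad():
    y_nobias = torch.nn.functional.conv2d(
        x, m.proj.weight, bias=None, stride=4, padding=3
    )
    y_manual = y_nobias + m.proj.bias.view(1, -1, 1, 1)
print((y_manual - y_pt).abs().max())  # expect ~0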
@antinucleon (Contributor)
Could you directly modify this unittest to see if there is anything abnormal? https://github.com/facebookincubator/AITemplate/blob/main/tests/unittest/ops/test_conv.py
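For reference, assuming the repo's tests run under pytest, something like this should exercise just that case:

python -m pytest tests/unittest/ops/test_conv.py -k test_fp16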

@dashesy (Contributor, Author) commented Nov 1, 2022

When I change that UT, it passes:

class ConvTestCase(unittest.TestCase):
    def test_fp16(self, batch=1):
        target = detect_target()
        X = Tensor(
            shape=[1, 384, 384, 4],
            dtype="float16",
            name="input_0",
            is_input=True,
        )
        W = Tensor(
            shape=[256, 7, 7, 4], dtype="float16", name="input_1", is_input=True
        )
        OP = ops.conv2d(stride=4, pad=3, dilate=1)
        Y = OP(X, W)
        Y._attrs["name"] = "output_0"
        Y._attrs["is_output"] = True
        module = compile_model(Y, target, "./tmp", "conv2d")

        X_pt = torch.randn(1, 4, 384, 384).cuda().half()
        W_pt = torch.randn(256, 4, 7, 7).cuda().half()
        Y_pt = torch.nn.functional.conv2d(X_pt, W_pt, padding=3, stride=4)
        x = X_pt.permute((0, 2, 3, 1)).contiguous()
        w = W_pt.permute((0, 2, 3, 1)).contiguous()
        y = torch.empty([1, 96, 96, 256]).cuda().half()
        module.run_with_tensors({"input_0": x, "input_1": w}, [y])
        y_transpose = y.permute((0, 3, 1, 2))
        if target.name() == "cuda":
            self.assertTrue(torch.allclose(Y_pt, y_transpose, atol=1e-2, rtol=1e-2))
        else:
            self.assertTrue(torch.allclose(Y_pt, y_transpose, atol=1.25e-1, rtol=1e-1))

I will have to look further to see what the difference is between the module and OP versions.
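As far as I can tell, the two paths differ mainly in how the weights enter the graph. A minimal contrast, reusing the calls already shown in this thread (a sketch, not a definitive statement about the internals):

# Module version: weights are passed as compile-time constants and bound
# into the compiled artifact.
module = compile_model(Y, target, "./output", "repro", constants=weights)

# OP version (as in the unit test): weights arrive as ordinary runtime inputs.
module.run_with_tensors({"input_0": x, "input_1": w}, [y])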

@antinucleon (Contributor)
This visualization tool is very helpful for investigating: https://facebookincubator.github.io/AITemplate/tutorial/how_to_visualize.html
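From that page, the core call is roughly the following (a sketch; the module path and signature are taken from the linked tutorial, so double-check them against your version):

from aitemplate.utils.visualization import plot_graph

# Dump the graph rooted at the output tensor to an interactive HTML page.
plot_graph(Y, file_path="conv_embed.html")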

@dashesy (Contributor, Author) commented Nov 2, 2022

The visualization is pretty simple

[attached screenshot: the visualized graph]

@antinucleon (Contributor) commented Nov 2, 2022 via email

@dashesy (Contributor, Author) commented Nov 2, 2022

The only attribute is op_type. For nn.Conv2dBias it is conv2d_bias, so I changed test_conv_bias accordingly, but it too passes the UT.

class ConvBiasTestCase(unittest.TestCase):
    def test_fp16(self, batch=4):
        target = detect_target()
        X = Tensor(
            shape=[1, 384, 384, 4],
            dtype="float16",
            name="input_0",
            is_input=True,
        )
        W = Tensor(
            shape=[256, 7, 7, 4], dtype="float16", name="input_1", is_input=True
        )
        B = Tensor(shape=[256], dtype="float16", name="input_2", is_input=True)
        OP = ops.conv2d_bias(stride=4, pad=3, dilate=1)
        Y = OP(X, W, B)
        Y._attrs["name"] = "output_0"
        Y._attrs["is_output"] = True
        module = compile_model(Y, target, "./tmp", "conv2d_bias")

        X_pt = torch.randn(1, 4, 384, 384).cuda().half()
        W_pt = torch.randn(256, 4, 7, 7).cuda().half()
        B_pt = torch.randn(1, 256, 1, 1).cuda().half()
        Y_pt = torch.nn.functional.conv2d(X_pt, W_pt, padding=3, stride=4)
        Y_pt = Y_pt + B_pt
        x = X_pt.permute((0, 2, 3, 1)).contiguous()
        w = W_pt.permute((0, 2, 3, 1)).contiguous()
        inputs = {"input_0": x, "input_1": w, "input_2": B_pt.squeeze()}
        y = torch.empty([1, 96, 96, 256]).cuda().half()
        module.run_with_tensors(inputs, [y])
        y_transpose = y.permute((0, 3, 1, 2))
        if target.name() == "cuda":
            self.assertTrue(torch.allclose(Y_pt, y_transpose, atol=1e-2, rtol=1e-2))
        else:
            self.assertTrue(torch.allclose(Y_pt, y_transpose, atol=1.25e-1, rtol=1e-1))

@dashesy (Contributor, Author) commented Nov 2, 2022

I think this was caused by a stale temp folder; after deleting it and recompiling, the mismatch went away.
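In case anyone else hits this: clearing the build directory before compiling avoids reusing stale artifacts (a minimal sketch using only the standard library):

import shutil

# Remove any previous build so compile_model cannot pick up stale artifacts
# from an earlier run with different shapes or weights.
shutil.rmtree("./output", ignore_errors=True)
module = compile_model(Y, target, "./output", "repro", constants=weights)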

dashesy closed this as completed on Nov 2, 2022.