Unindent continuous_mean_std buffer #2

Merged: 1 commit, Jan 5, 2021

@spliew commented on Jan 5, 2021

Problem: `continuous_mean_std` is not set as an attribute of `TabTransformer` unless it is passed explicitly as a constructor argument, so calling the model raises an `AttributeError`.
Example reproducing the `AttributeError`:

```python
import torch
import torch.nn as nn
from tab_transformer_pytorch import TabTransformer

model = TabTransformer(
    categories = (10, 5, 6, 5, 8),      # tuple containing the number of unique values within each category
    num_continuous = 10,                # number of continuous values
    dim = 32,                           # dimension, paper set at 32
    dim_out = 1,                        # binary prediction, but could be anything
    depth = 6,                          # depth, paper recommended 6
    heads = 8,                          # heads, paper recommends 8
    attn_dropout = 0.1,                 # post-attention dropout
    ff_dropout = 0.1,                   # feed forward dropout
    mlp_hidden_mults = (4, 2),          # relative multiples of each hidden dimension of the last mlp to logits
    mlp_act = nn.ReLU(),                # activation for final mlp, defaults to relu, but could be anything else (selu etc)
    # continuous_mean_std = cont_mean_std  # (optional) normalize the continuous values before layer norm; deliberately omitted here
)

x_categ = torch.randint(0, 5, (1, 5))     # category values, from 0 to max number of categories, in the order passed into the constructor above
x_cont = torch.randn(1, 10)               # assume continuous values are already normalized individually
pred = model(x_categ, x_cont)             # raises AttributeError before this fix
```

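Solution: un-indent the buffer registration of `continuous_mean_std` so that the buffer is always registered (as `None` when the argument is omitted), not only when the argument is provided.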
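To illustrate why un-indenting is enough, here is a minimal sketch of the pattern. It uses hypothetical `Buggy` and `Fixed` modules rather than the actual `TabTransformer` source, and it assumes for simplicity that `continuous_mean_std` is a `(num_continuous, 2)` tensor of per-feature mean and std. The key point is that `nn.Module.register_buffer` accepts `None`. Registering unconditionally therefore makes `self.continuous_mean_std` always resolve (to `None` when omitted), whereas registering inside the `if` leaves the attribute undefined.

```python
import torch
from torch import nn


class _ContinuousNorm(nn.Module):
    # Hypothetical, simplified stand-in for the continuous-feature path.
    def forward(self, x_cont):
        # Accessing an attribute that was never registered raises AttributeError.
        if self.continuous_mean_std is not None:
            mean, std = self.continuous_mean_std.unbind(dim=-1)
            x_cont = (x_cont - mean) / std
        return self.norm(x_cont)


class Buggy(_ContinuousNorm):
    def __init__(self, num_continuous, continuous_mean_std=None):
        super().__init__()
        if continuous_mean_std is not None:
            assert continuous_mean_std.shape == (num_continuous, 2)  # assumed layout
            # Indented: the buffer only exists when the argument is given.
            self.register_buffer('continuous_mean_std', continuous_mean_std)
        self.norm = nn.LayerNorm(num_continuous)


class Fixed(_ContinuousNorm):
    def __init__(self, num_continuous, continuous_mean_std=None):
        super().__init__()
        if continuous_mean_std is not None:
            assert continuous_mean_std.shape == (num_continuous, 2)  # assumed layout
        # Un-indented: always registered, possibly as None.
        self.register_buffer('continuous_mean_std', continuous_mean_std)
        self.norm = nn.LayerNorm(num_continuous)


x_cont = torch.randn(1, 10)
try:
    Buggy(10)(x_cont)
except AttributeError as err:
    print('Buggy:', err)
print('Fixed:', Fixed(10)(x_cont).shape)  # Fixed: torch.Size([1, 10])
```

A buffer registered as `None` is also left out of the module's `state_dict`, so models built without normalization statistics still save and load cleanly.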

@lucidrains merged commit 3db25cc into lucidrains:main on Jan 5, 2021