What is the equivalent to a Flatten layer in MLX? #1308
I am trying to implement a simple LeNet. Torch offers `nn.Flatten` for this, but I could not find an equivalent in MLX. Can someone give directions regarding this?
There is no `Flatten` layer in MLX; you can call the `flatten` op directly inside the model instead:

```python
import mlx.nn as nn

class MLP(nn.Module):
    def __init__(self, out_dims):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(1, 20, 5),   # input channels, output channels, kernel size
            nn.ReLU(),
            nn.MaxPool2d(2, 2),    # kernel size, stride
            nn.Conv2d(20, 50, 5),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
        )
        self.mlp = nn.Sequential(
            nn.Linear(800, 500),
            nn.ReLU(),
            nn.Linear(500, out_dims),  # use the constructor argument for the number of classes
        )

    def __call__(self, x):
        x = self.conv(x)
        # MLX convolutions use NHWC layout, so the features to
        # flatten are the last three axes (H, W, C).
        x = x.flatten(-3, -1)
        x = self.mlp(x)
        return x
```

At some point we considered adding one.
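If you would rather have a drop-in layer like Torch's, one is easy to write yourself. The `Flatten` class below is a minimal sketch, not part of `mlx.nn`:

```python
import mlx.core as mx
import mlx.nn as nn

class Flatten(nn.Module):
    """Hypothetical analogue of torch.nn.Flatten, built on mx.flatten."""

    def __init__(self, start_axis: int = 1, end_axis: int = -1):
        super().__init__()
        self.start_axis = start_axis
        self.end_axis = end_axis

    def __call__(self, x: mx.array) -> mx.array:
        # Collapse axes start_axis..end_axis into a single axis.
        return mx.flatten(x, self.start_axis, self.end_axis)
```

With that, the flattening step can live inside a single `nn.Sequential`, e.g. `nn.Sequential(..., nn.MaxPool2d(2, 2), Flatten(-3, -1), nn.Linear(800, 500), ...)`.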
@awni Thanks, I have tried this, but I've run into another bug. Here is a minimal example:

The parameters become …
That I have no idea about. You'd need to share more code to fully reproduce this so we can help debug. Also make sure you are using the latest MLX.
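(If it helps, one quick way to check the installed version, assuming a pip-managed install:)

```python
# To upgrade, run in a shell:  pip install -U mlx
import mlx.core as mx
print(mx.__version__)  # prints the installed MLX version
```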
@awni I'm attaching my code here. I am using version 0.16.1.
You are converting the labels to one-hot, which is why your batch iteration no longer sees the right dataset size. My recommendation would be to not convert the labels to one-hot; just use them as-is, which works with `nn.losses.cross_entropy`. Alternatively you could change your batch iteration to get the right dataset size:

```python
import numpy as np
import mlx.core as mx

def batch_iterate(batch_size, X, y):
    # Shuffle indices over the true number of samples.
    perm = mx.array(np.random.permutation(y.shape[0]))
    for s in range(0, y.shape[0], batch_size):
        ids = perm[s : s + batch_size]
        yield X[ids], y[ids]
```
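(For reference on the first suggestion, a minimal sketch: `nn.losses.cross_entropy` accepts plain integer class indices as targets, so no one-hot encoding is needed.)

```python
import mlx.core as mx
import mlx.nn as nn

logits = mx.random.normal((4, 10))  # batch of 4 samples, 10 classes
targets = mx.array([3, 1, 0, 7])    # integer class indices, no one-hot needed
loss = nn.losses.cross_entropy(logits, targets, reduction="mean")
print(loss.item())
```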
Great catch. I overlooked that.