
What is the equivalent to a Flatten layer in MLX? #1308

Closed

s4m13337 opened this issue Aug 4, 2024 · 6 comments

s4m13337 commented Aug 4, 2024

I am trying to implement a simple LeNet:

import mlx.nn as nn


class MLP(nn.Module):

    def __init__(self, out_dims):
        super().__init__()
        self.layers = [
            nn.Conv2d(1, 20, 5),    # input channels, output channels, kernel size
            nn.ReLU(),
            nn.MaxPool2d(2, 2),    # kernel size, stride length
            nn.Conv2d(20, 50, 5),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            ### Need a flatten layer here ###
            nn.Linear(800, 500),
            nn.ReLU(),
            nn.Linear(500, 10),
        ]

    def __call__(self, x):
        for l in self.layers:
            x = l(x)
        return x

Torch offers nn.Flatten for this, but I could not find an equivalent in MLX. Could someone point me in the right direction?

awni (Member) commented Aug 4, 2024

There is no Flatten layer yet. You would have to redo the computation like so:

class MLP(nn.Module):

    def __init__(self, out_dims):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(1, 20, 5),    # input channels, output channels, kernel size
            nn.ReLU(),
            nn.MaxPool2d(2, 2),    # kernel size, stride length
            nn.Conv2d(20, 50, 5),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
        )
        self.mlp = nn.Sequential(
            nn.Linear(800, 500),
            nn.ReLU(),
            nn.Linear(500, 10),
        )

    def __call__(self, x):
        x = self.conv(x)
        x = x.flatten(-3, -1)
        x = self.mlp(x)
        return x

At some point we considered adding Flatten, but we decided we'd prefer not to mirror every op with an nn equivalent, and the workaround above is not so onerous.
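
If you really want a drop-in layer, a custom Flatten module is only a few lines. This is just a sketch built on array.flatten, not something that ships with MLX:

import mlx.nn as nn


class Flatten(nn.Module):
    """Flatten the last `num_dims` axes into one (not an MLX built-in)."""

    def __init__(self, num_dims=3):
        super().__init__()
        self.num_dims = num_dims

    def __call__(self, x):
        # e.g. collapse (B, H, W, C) -> (B, H * W * C) when num_dims == 3
        return x.flatten(-self.num_dims, -1)

With that, Flatten() could be dropped into the original self.layers list between the last MaxPool2d and the first Linear.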

awni closed this as completed Aug 4, 2024
s4m13337 (Author) commented Aug 4, 2024

@awni Thanks, I have tried this, but I've run into another bug.

Here is a minimal example:

i = 0
for X, y in batch_iterate(64, train_images, train_labels):
    i += 1
    loss, grads = loss_and_grad_fn(model, X, y)
    optimizer.update(model, grads)
    mx.eval(model.parameters(), optimizer.state)
    if i == 100:
        break
print(model.parameters())

The parameters become NaN at some point between the 100th and 150th batch. It varies on every run but is usually within this range. Is this related to #1277 or #319?

awni (Member) commented Aug 4, 2024

That I have no idea about. You’d need to share more code to fully reproduce this so we can help debug.

Also, make sure you are using the latest MLX.

s4m13337 (Author) commented Aug 4, 2024

LeNet MLX.ipynb.zip

@awni I'm attaching my code here. I am using version 0.16.1.

awni (Member) commented Aug 5, 2024

You are converting train_labels to one-hot and then using its size to determine the size of the dataset. That is a bug because the size will be a factor of 10 too large, so your batch_iterate function will be reading lots of uninitialized memory.

My recommendation would be to not convert the labels to one-hot; just use them as is, which works with cross_entropy and is more efficient.
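
For example, a loss function that takes the integer labels directly might look like this (a sketch; loss_fn is just an illustrative name, and the model is assumed to output logits):

import mlx.nn as nn


def loss_fn(model, X, y):
    # y holds integer class indices, not one-hot vectors;
    # nn.losses.cross_entropy accepts them directly.
    return nn.losses.cross_entropy(model(X), y, reduction="mean")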

Alternatively, you could change your batch iteration to use the right dataset size:

import numpy as np
import mlx.core as mx


def batch_iterate(batch_size, X, y):
    # Use the label array's length, i.e. the true dataset size
    perm = mx.array(np.random.permutation(y.shape[0]))
    for s in range(0, y.shape[0], batch_size):
        ids = perm[s : s + batch_size]
        yield X[ids], y[ids]
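
Either way, the training loop from your minimal example would then be wired up roughly like this (a sketch; model, optimizer, and the data arrays are assumed to come from your notebook):

loss_and_grad_fn = nn.value_and_grad(model, loss_fn)

for X, y in batch_iterate(64, train_images, train_labels):
    loss, grads = loss_and_grad_fn(model, X, y)
    optimizer.update(model, grads)
    mx.eval(model.parameters(), optimizer.state)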

s4m13337 (Author) commented Aug 6, 2024

Great catch. I overlooked y.size. Thank you for the help.

This issue was closed.