Skip to content

Commit

Permalink
implement Layer interface
Browse files Browse the repository at this point in the history
  • Loading branch information
haifengl committed Mar 27, 2024
1 parent 4363b82 commit 6ef9c63
Showing 1 changed file with 13 additions and 6 deletions.
19 changes: 13 additions & 6 deletions deep/src/main/java/smile/llm/PositionalEncoding.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package smile.llm;

import org.bytedeco.pytorch.Module;
import smile.deep.layer.Layer;
import smile.deep.tensor.Device;
import smile.deep.tensor.Tensor;
import static smile.deep.tensor.Index.*;
Expand All @@ -31,7 +32,7 @@
*
* @author Haifeng Li
*/
public class PositionalEncoding {
public class PositionalEncoding implements Layer {
/** The module to register the buffer. */
private Module module;
/** The dropout probability. */
Expand Down Expand Up @@ -66,11 +67,7 @@ public PositionalEncoding(int dModel, double dropout, int maxLen) {
module.register_buffer("pe", pe.asTorch());
}

/**
* Returns the positional encoding of a sequence.
* @param x the sequence fed to the positional encoder model.
* @return the encoded tensor.
*/
@Override
public Tensor forward(Tensor x) {
Tensor p = pe.get(
slice(null, x.size(0)),
Expand All @@ -80,6 +77,16 @@ public Tensor forward(Tensor x) {
return Tensor.of(torch.dropout(xp.asTorch(), dropout, true));
}

@Override
public void register(String name, Layer parent) {
module = parent.asTorch().register_module(name, module);
}

@Override
public Module asTorch() {
return module;
}

/**
* Moves the encoder to a device.
* @param device the compute device.
Expand Down

0 comments on commit 6ef9c63

Please sign in to comment.