Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[GR-18163] Support buffer keyword argument to Array#pack #3566

Merged
merged 1 commit into from
May 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ Compatibility:
* Allow null encoding pointer in `rb_enc_interned_str_cstr` (@thomasmarshall).
* Allow anonymous memberless Struct (@simonlevasseur).
* Set `$!` when a `Kernel#at_exit` hook raises an exception (#3535, @andrykonchin).
* Support `:buffer` keyword argument to `Array#pack` (#3559, @andrykonchyn).

Performance:
* Fix inline caching for Regexp creation from Strings (#3492, @andrykonchin, @eregon).
Expand Down
10 changes: 10 additions & 0 deletions spec/ruby/core/array/pack/buffer_spec.rb
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,16 @@
TypeError, "buffer must be String, not Array")
end

it "raise FrozenError if buffer is frozen" do
-> { [65].pack("c", buffer: "frozen-string".freeze) }.should raise_error(FrozenError)
end

it "preserves the encoding of the given buffer" do
buffer = ''.encode(Encoding::ISO_8859_1)
[65, 66, 67].pack("ccc", buffer: buffer)
buffer.encoding.should == Encoding::ISO_8859_1
end

context "offset (@) is specified" do
it 'keeps buffer content if it is longer than offset' do
n = [ 65, 66, 67 ]
Expand Down
6 changes: 0 additions & 6 deletions spec/tags/core/array/pack/buffer_tags.txt

This file was deleted.

52 changes: 40 additions & 12 deletions src/main/java/org/truffleruby/core/array/ArrayNodes.java
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
import org.truffleruby.core.encoding.RubyEncoding;
import org.truffleruby.core.format.BytesResult;
import org.truffleruby.core.format.FormatExceptionTranslator;
import org.truffleruby.core.format.FormatRootNode;
import org.truffleruby.core.format.exceptions.FormatException;
import org.truffleruby.core.format.pack.PackCompiler;
import org.truffleruby.core.hash.HashingNodes;
Expand Down Expand Up @@ -1491,15 +1492,15 @@ public void accept(Node node, CallBlockNode yieldNode, RubyArray array, Object s

}

@CoreMethod(names = "pack", required = 1, split = Split.ALWAYS)
public abstract static class ArrayPackNode extends CoreMethodArrayArgumentsNode {
@Primitive(name = "array_pack", lowerFixnum = 1)
public abstract static class ArrayPackPrimitiveNode extends PrimitiveArrayArgumentsNode {

@Specialization
RubyString pack(RubyArray array, Object format,
RubyString pack(RubyArray array, Object format, Object buffer,
@Cached ToStrNode toStrNode,
@Cached PackNode packNode) {
final var formatAsString = toStrNode.execute(this, format);
return packNode.execute(this, array, formatAsString);
return packNode.execute(this, array, formatAsString, buffer);
}
}

Expand All @@ -1508,28 +1509,35 @@ RubyString pack(RubyArray array, Object format,
@ReportPolymorphism
public abstract static class PackNode extends RubyBaseNode {

public abstract RubyString execute(Node node, RubyArray array, Object format);
public abstract RubyString execute(Node node, RubyArray array, Object format, Object buffer);

@Specialization(
guards = {
"libFormat.isRubyString(format)",
"libBuffer.isRubyString(buffer)",
"equalNode.execute(libFormat, format, cachedFormat, cachedEncoding)" },
limit = "getCacheLimit()")
static RubyString packCached(Node node, RubyArray array, Object format,
static RubyString packCached(Node node, RubyArray array, Object format, Object buffer,
@Cached @Shared InlinedBranchProfile exceptionProfile,
@Cached @Shared InlinedConditionProfile resizeProfile,
@Cached @Shared RubyStringLibrary libFormat,
@Cached @Shared RubyStringLibrary libBuffer,
@Cached @Shared WriteObjectFieldNode writeAssociatedNode,
@Cached @Shared TruffleString.FromByteArrayNode fromByteArrayNode,
@Cached @Shared TruffleString.CopyToByteArrayNode copyToByteArrayNode,
@Cached("asTruffleStringUncached(format)") TruffleString cachedFormat,
@Cached("libFormat.getEncoding(format)") RubyEncoding cachedEncoding,
@Cached("cachedFormat.byteLength(cachedEncoding.tencoding)") int cachedFormatLength,
@Cached("create(compileFormat(node, getJavaString(format)))") DirectCallNode callPackNode,
@Cached("compileFormat(node, getJavaString(format))") RootCallTarget formatCallTarget,
@Cached("create(formatCallTarget)") DirectCallNode callPackNode,
@Cached StringHelperNodes.EqualNode equalNode) {
final byte[] bytes = initOutputBytes(buffer, libBuffer, formatCallTarget, copyToByteArrayNode);

final BytesResult result;

try {
result = (BytesResult) callPackNode.call(
new Object[]{ array.getStore(), array.size, false, null });
new Object[]{ array.getStore(), array.size, bytes, libBuffer.byteLength(buffer) });
} catch (FormatException e) {
exceptionProfile.enter(node);
throw FormatExceptionTranslator.translate(getContext(node), node, e);
Expand All @@ -1538,22 +1546,28 @@ static RubyString packCached(Node node, RubyArray array, Object format,
return finishPack(node, cachedFormatLength, result, resizeProfile, writeAssociatedNode, fromByteArrayNode);
}

@Specialization(guards = { "libFormat.isRubyString(format)" }, replaces = "packCached")
static RubyString packUncached(Node node, RubyArray array, Object format,
@Specialization(guards = { "libFormat.isRubyString(format)", "libBuffer.isRubyString(buffer)" },
replaces = "packCached")
static RubyString packUncached(Node node, RubyArray array, Object format, Object buffer,
@Cached @Shared InlinedBranchProfile exceptionProfile,
@Cached @Shared InlinedConditionProfile resizeProfile,
@Cached @Shared RubyStringLibrary libFormat,
@Cached @Shared RubyStringLibrary libBuffer,
@Cached @Shared WriteObjectFieldNode writeAssociatedNode,
@Cached @Shared TruffleString.FromByteArrayNode fromByteArrayNode,
@Cached @Shared TruffleString.CopyToByteArrayNode copyToByteArrayNode,
@Cached ToJavaStringNode toJavaStringNode,
@Cached IndirectCallNode callPackNode) {
final String formatString = toJavaStringNode.execute(node, format);
final RootCallTarget formatCallTarget = compileFormat(node, formatString);
final byte[] bytes = initOutputBytes(buffer, libBuffer, formatCallTarget, copyToByteArrayNode);

final BytesResult result;

try {
result = (BytesResult) callPackNode.call(
compileFormat(node, formatString),
new Object[]{ array.getStore(), array.size, false, null });
formatCallTarget,
new Object[]{ array.getStore(), array.size, bytes, libBuffer.byteLength(buffer) });
} catch (FormatException e) {
exceptionProfile.enter(node);
throw FormatExceptionTranslator.translate(getContext(node), node, e);
Expand Down Expand Up @@ -1591,6 +1605,20 @@ protected static RootCallTarget compileFormat(Node node, String format) {
}
}

private static byte[] initOutputBytes(Object buffer, RubyStringLibrary libBuffer,
RootCallTarget formatCallTarget, TruffleString.CopyToByteArrayNode copyToByteArrayNode) {
int bufferLength = libBuffer.byteLength(buffer);
var formatRootNode = (FormatRootNode) formatCallTarget.getRootNode();
int expectedLength = formatRootNode.getExpectedLength();

// output buffer should be at least expectedLength to not mess up the expectedLength's logic
final int length = Math.max(bufferLength, expectedLength);
final byte[] bytes = new byte[length];
copyToByteArrayNode.execute(libBuffer.getTString(buffer), 0, bytes, 0, bufferLength,
libBuffer.getTEncoding(buffer));
return bytes;
}

protected int getCacheLimit() {
return getLanguage().options.PACK_CACHE;
}
Expand Down
23 changes: 20 additions & 3 deletions src/main/java/org/truffleruby/core/format/FormatRootNode.java
Original file line number Diff line number Diff line change
Expand Up @@ -30,26 +30,39 @@ public final class FormatRootNode extends RubyBaseRootNode implements InternalRo
@Child private FormatNode child;

@CompilationFinal private int expectedLength = 0;
private final boolean acceptOutput;
private final boolean acceptOutputPosition;

public FormatRootNode(
RubyLanguage language,
SourceSection sourceSection,
FormatEncoding encoding,
FormatNode child) {
FormatNode child,
boolean acceptOutput,
boolean acceptOutputPosition) {
super(language, FormatFrameDescriptor.FRAME_DESCRIPTOR, sourceSection);
this.encoding = encoding;
this.child = child;
this.acceptOutput = acceptOutput;
this.acceptOutputPosition = acceptOutputPosition;
}

/** Accepts the following arguments stored in a frame: source array, its length, output buffer as a bytes array,
* (optional) position in the output buffer to start from */
@SuppressWarnings("unchecked")
@Override
public Object execute(VirtualFrame frame) {
frame.setObject(FormatFrameDescriptor.SOURCE_SLOT, frame.getArguments()[0]);
frame.setInt(FormatFrameDescriptor.SOURCE_END_POSITION_SLOT, (int) frame.getArguments()[1]);
frame.setInt(FormatFrameDescriptor.SOURCE_START_POSITION_SLOT, 0);
frame.setInt(FormatFrameDescriptor.SOURCE_POSITION_SLOT, 0);
frame.setObject(FormatFrameDescriptor.OUTPUT_SLOT, new byte[expectedLength]);
frame.setInt(FormatFrameDescriptor.OUTPUT_POSITION_SLOT, 0);

final byte[] outputInit = acceptOutput ? (byte[]) frame.getArguments()[2] : new byte[expectedLength];
frame.setObject(FormatFrameDescriptor.OUTPUT_SLOT, outputInit);

final int outputPosition = acceptOutputPosition ? (int) frame.getArguments()[3] : 0;
frame.setInt(FormatFrameDescriptor.OUTPUT_POSITION_SLOT, outputPosition);

frame.setObject(FormatFrameDescriptor.ASSOCIATED_SLOT, null);

child.execute(frame);
Expand Down Expand Up @@ -95,6 +108,10 @@ public String getName() {
return "format";
}

public int getExpectedLength() {
return expectedLength;
}

@Override
public String toString() {
return getName();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,9 @@ public RootCallTarget compile(String format) throws DeferredRaiseException {
language,
currentNode.getEncapsulatingSourceSection(),
builder.getEncoding(),
builder.getNode()).getCallTarget();
builder.getNode(),
true,
true).getCallTarget();
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,9 @@ public RootCallTarget compile(AbstractTruffleString tstring, RubyEncoding encodi
language,
currentNode.getEncapsulatingSourceSection(),
new FormatEncoding(encoding),
builder.getNode()).getCallTarget();
builder.getNode(),
false,
false).getCallTarget();
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,9 @@ public RootCallTarget compile(AbstractTruffleString formatTString, RubyEncoding
language,
currentNode.getEncapsulatingSourceSection(),
new FormatEncoding(formatEncoding),
builder.getNode()).getCallTarget();
builder.getNode(),
false,
false).getCallTarget();
}

private static int SIGN = 0x10;
Expand Down
14 changes: 14 additions & 0 deletions src/main/ruby/truffleruby/core/array.rb
Original file line number Diff line number Diff line change
Expand Up @@ -719,6 +719,20 @@ def last(n = undefined)
Array.new self[-n..-1]
end

def pack(format, buffer: nil)
if Primitive.nil? buffer
Primitive.array_pack(self, format, '')
else
unless Primitive.is_a?(buffer, String)
raise TypeError, "buffer must be String, not #{Primitive.class(buffer)}"
end

string = Primitive.array_pack(self, format, buffer)
buffer.replace string.force_encoding(buffer.encoding)
end
end
Truffle::Graal.always_split instance_method(:pack)

def permutation(num = undefined, &block)
unless block_given?
return to_enum(:permutation, num) do
Expand Down
2 changes: 1 addition & 1 deletion test/truffle/integration/backtraces/pack.backtrace
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/pack.rb:13:in `to_str': message (RuntimeError)
from /pack.rb:18:in `pack'
from <internal:core> core/array.rb:LINE:in `pack'
from /pack.rb:18:in `block in <main>'
from /backtraces.rb:17:in `check'
from /pack.rb:17:in `<main>'