Skip to content

Commit

Permalink
fix invalid packet types due to state mismatch when calling packet ev…
Browse files Browse the repository at this point in the history
…ents (#2568)
  • Loading branch information
derklaro committed Oct 25, 2023
1 parent 03d7be1 commit af33a2a
Show file tree
Hide file tree
Showing 10 changed files with 218 additions and 38 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ public interface ListenerInvoker {
*
* @param packet - the packet.
* @return The packet type.
* @deprecated use {@link com.comphenix.protocol.injector.packet.PacketRegistry#getPacketType(PacketType.Protocol, Class)} instead.
*/
@Deprecated
PacketType getPacketType(Object packet);
}
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,9 @@ public static Object newPacket(PacketType type) {
*
* @param packetType - packet type.
* @return A structure modifier.
* @deprecated use {@link #getStructure(PacketType)} instead.
*/
@Deprecated
public static StructureModifier<Object> getStructure(Class<?> packetType) {
// Get the ID from the class
PacketType type = PacketRegistry.getPacketType(packetType);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package com.comphenix.protocol.injector.netty;

import com.comphenix.protocol.PacketType;
import com.comphenix.protocol.PacketType.Protocol;
import com.comphenix.protocol.events.NetworkMarker;
import org.bukkit.entity.Player;
Expand Down Expand Up @@ -49,8 +50,21 @@ public interface Injector {
* Retrieve the current protocol state.
*
* @return The current protocol.
* @deprecated use {@link #getCurrentProtocol(PacketType.Sender)} instead.
*/
Protocol getCurrentProtocol();
@Deprecated
default Protocol getCurrentProtocol() {
return this.getCurrentProtocol(PacketType.Sender.SERVER);
}

/**
* Retrieve the current protocol state. Note that since 1.20.2 the client and server direction can be in different
* protocol states.
*
* @param sender the side for which the state should be resolved.
* @return The current protocol.
*/
Protocol getCurrentProtocol(PacketType.Sender sender);

/**
* Retrieve the network marker associated with a given packet.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
package com.comphenix.protocol.injector.netty.channel;

import com.comphenix.protocol.PacketType;
import com.comphenix.protocol.reflect.FuzzyReflection;
import com.comphenix.protocol.reflect.accessors.Accessors;
import com.comphenix.protocol.reflect.accessors.FieldAccessor;
import com.comphenix.protocol.reflect.fuzzy.FuzzyFieldContract;
import com.comphenix.protocol.utility.MinecraftReflection;
import io.netty.channel.Channel;
import io.netty.util.AttributeKey;

import java.lang.reflect.Field;
import java.lang.reflect.Modifier;
import java.util.List;
import java.util.function.BiFunction;

@SuppressWarnings("unchecked")
final class ChannelProtocolUtil {

public static final BiFunction<Channel, PacketType.Sender, PacketType.Protocol> PROTOCOL_RESOLVER;

static {
Class<?> networkManagerClass = MinecraftReflection.getNetworkManagerClass();
List<Field> attributeKeys = FuzzyReflection.fromClass(networkManagerClass, true).getFieldList(FuzzyFieldContract.newBuilder()
.typeExact(AttributeKey.class)
.requireModifier(Modifier.STATIC)
.declaringClassExactType(networkManagerClass)
.build());

BiFunction<Channel, PacketType.Sender, Object> baseResolver = null;
if (attributeKeys.size() == 1) {
// if there is only one attribute key we can assume it's the correct one (1.8 - 1.20.1)
Object protocolKey = Accessors.getFieldAccessor(attributeKeys.get(0)).get(null);
baseResolver = new Pre1_20_2DirectResolver((AttributeKey<Object>) protocolKey);
} else if (attributeKeys.size() > 1) {
// most likely 1.20.2+: 1 protocol key per protocol direction
AttributeKey<Object> serverBoundKey = null;
AttributeKey<Object> clientBoundKey = null;

for (Field keyField : attributeKeys) {
AttributeKey<Object> key = (AttributeKey<Object>) Accessors.getFieldAccessor(keyField).get(null);
if (key.name().equals("protocol")) {
// legacy (pre 1.20.2 name) - fall back to the old behaviour
baseResolver = new Pre1_20_2DirectResolver(key);
break;
}

if (key.name().contains("protocol")) {
// one of the two protocol keys for 1.20.2
if (key.name().contains("server")) {
serverBoundKey = key;
} else {
clientBoundKey = key;
}
}
}

if (baseResolver == null) {
if ((serverBoundKey == null || clientBoundKey == null)) {
// neither pre 1.20.2 key nor 1.20.2+ keys are available
throw new ExceptionInInitializerError("Unable to resolve protocol state attribute keys");
} else {
baseResolver = new Post1_20_2WrappedResolver(serverBoundKey, clientBoundKey);
}
}
} else {
throw new ExceptionInInitializerError("Unable to resolve protocol state attribute key(s)");
}

// decorate the base resolver by wrapping its return value into our packet type value
PROTOCOL_RESOLVER = baseResolver.andThen(nmsProtocol -> PacketType.Protocol.fromVanilla((Enum<?>) nmsProtocol));
}

private static final class Pre1_20_2DirectResolver implements BiFunction<Channel, PacketType.Sender, Object> {

private final AttributeKey<Object> attributeKey;

public Pre1_20_2DirectResolver(AttributeKey<Object> attributeKey) {
this.attributeKey = attributeKey;
}

@Override
public Object apply(Channel channel, PacketType.Sender sender) {
return channel.attr(this.attributeKey).get();
}
}

private static final class Post1_20_2WrappedResolver implements BiFunction<Channel, PacketType.Sender, Object> {

private final AttributeKey<Object> serverBoundKey;
private final AttributeKey<Object> clientBoundKey;

// lazy initialized when needed
private FieldAccessor protocolAccessor;

public Post1_20_2WrappedResolver(AttributeKey<Object> serverBoundKey, AttributeKey<Object> clientBoundKey) {
this.serverBoundKey = serverBoundKey;
this.clientBoundKey = clientBoundKey;
}

@Override
public Object apply(Channel channel, PacketType.Sender sender) {
AttributeKey<Object> key = this.getKeyForSender(sender);
Object codecData = channel.attr(key).get();
if (codecData == null) {
return null;
}

FieldAccessor protocolAccessor = this.getProtocolAccessor(codecData.getClass());
return protocolAccessor.get(codecData);
}

private AttributeKey<Object> getKeyForSender(PacketType.Sender sender) {
switch (sender) {
case SERVER:
return this.clientBoundKey;
case CLIENT:
return this.serverBoundKey;
default:
throw new IllegalArgumentException("Illegal packet sender " + sender.name());
}
}

private FieldAccessor getProtocolAccessor(Class<?> codecClass) {
if (this.protocolAccessor == null) {
Class<?> enumProtocolClass = MinecraftReflection.getEnumProtocolClass();
this.protocolAccessor = Accessors.getFieldAccessor(codecClass, enumProtocolClass, true);
}

return this.protocolAccessor;
}
}
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package com.comphenix.protocol.injector.netty.channel;

import com.comphenix.protocol.PacketType;
import com.comphenix.protocol.PacketType.Protocol;
import com.comphenix.protocol.events.NetworkMarker;
import com.comphenix.protocol.injector.netty.Injector;
Expand Down Expand Up @@ -42,7 +43,7 @@ public void receiveClientPacket(Object packet) {
}

@Override
public Protocol getCurrentProtocol() {
public Protocol getCurrentProtocol(PacketType.Sender sender) {
return Protocol.HANDSHAKING;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@

import java.lang.reflect.Field;
import java.lang.reflect.Modifier;
import java.util.Collections;
import java.util.HashSet;
import java.util.Map;
import java.util.Map.Entry;
import java.util.NoSuchElementException;
Expand Down Expand Up @@ -110,7 +108,6 @@ public Field getField() {

// lazy initialized fields, if we don't need them we don't bother about them
private Object playerConnection;
private FieldAccessor protocolAccessor;

public NettyChannelInjector(
Player player,
Expand Down Expand Up @@ -322,17 +319,8 @@ public void receiveClientPacket(Object packet) {
}

@Override
public Protocol getCurrentProtocol() {
// ensure that the accessor to the protocol field is available
if (this.protocolAccessor == null) {
this.protocolAccessor = Accessors.getFieldAccessor(
this.networkManager.getClass(),
MinecraftReflection.getEnumProtocolClass(),
true);
}

Object nmsProtocol = this.protocolAccessor.get(this.networkManager);
return Protocol.fromVanilla((Enum<?>) nmsProtocol);
public Protocol getCurrentProtocol(PacketType.Sender sender) {
return ChannelProtocolUtil.PROTOCOL_RESOLVER.apply(this.wrappedChannel, sender);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
import com.comphenix.protocol.reflect.fuzzy.FuzzyFieldContract;
import com.comphenix.protocol.reflect.fuzzy.FuzzyMethodContract;
import com.comphenix.protocol.utility.MinecraftReflection;
import com.comphenix.protocol.utility.Util;
import com.comphenix.protocol.wrappers.Pair;
import io.netty.channel.ChannelFuture;
import org.bukkit.Server;
Expand Down Expand Up @@ -93,7 +92,8 @@ public PacketEvent onPacketSending(Injector injector, Object packet, NetworkMark
Class<?> packetClass = packet.getClass();
if (marker != null || MinecraftReflection.isBundlePacket(packetClass) || outboundListeners.contains(packetClass)) {
// wrap packet and construct the event
PacketContainer container = new PacketContainer(PacketRegistry.getPacketType(packetClass), packet);
PacketType.Protocol currentProtocol = injector.getCurrentProtocol(PacketType.Sender.SERVER);
PacketContainer container = new PacketContainer(PacketRegistry.getPacketType(currentProtocol, packetClass), packet);
PacketEvent packetEvent = PacketEvent.fromServer(this, container, marker, injector.getPlayer());

// post to all listeners, then return the packet event we constructed
Expand All @@ -111,7 +111,8 @@ public PacketEvent onPacketReceiving(Injector injector, Object packet, NetworkMa
Class<?> packetClass = packet.getClass();
if (marker != null || inboundListeners.contains(packetClass)) {
// wrap the packet and construct the event
PacketContainer container = new PacketContainer(PacketRegistry.getPacketType(packetClass), packet);
PacketType.Protocol currentProtocol = injector.getCurrentProtocol(PacketType.Sender.CLIENT);
PacketContainer container = new PacketContainer(PacketRegistry.getPacketType(currentProtocol, packetClass), packet);
PacketEvent packetEvent = PacketEvent.fromClient(this, container, marker, injector.getPlayer());

// post to all listeners, then return the packet event we constructed
Expand Down Expand Up @@ -238,7 +239,6 @@ public void close() {
// just reset to the list we wrapped originally
ListeningList ourList = (ListeningList) currentFieldValue;
List<Object> original = ourList.getOriginal();
//noinspection SynchronizationOnLocalVariableOrMethodParameter
synchronized (original) {
// revert the injection from all values of the list
ourList.unProcessAll();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,9 @@ public class PacketRegistry {
protected static class Register {
// The main lookup table
final Map<PacketType, Optional<Class<?>>> typeToClass = new ConcurrentHashMap<>();

final Map<Class<?>, PacketType> classToType = new ConcurrentHashMap<>();
final Map<PacketType.Protocol, Map<Class<?>, PacketType>> protocolClassToType = new ConcurrentHashMap<>();

volatile Set<PacketType> serverPackets = new HashSet<>();
volatile Set<PacketType> clientPackets = new HashSet<>();
Expand All @@ -58,7 +60,10 @@ public Register() {}

public void registerPacket(PacketType type, Class<?> clazz, Sender sender) {
typeToClass.put(type, Optional.of(clazz));

classToType.put(clazz, type);
protocolClassToType.computeIfAbsent(type.getProtocol(), __ -> new ConcurrentHashMap<>()).put(clazz, type);

if (sender == Sender.CLIENT) {
clientPackets.add(type);
} else {
Expand Down Expand Up @@ -430,7 +435,9 @@ public static Class<?> getPacketClassFromType(PacketType type) {
* Retrieve the packet type of a given packet.
* @param packet - the class of the packet.
* @return The packet type, or NULL if not found.
* @deprecated major issues due to packets with shared classes being registered in multiple states.
*/
@Deprecated
public static PacketType getPacketType(Class<?> packet) {
initialize();

Expand All @@ -440,7 +447,24 @@ public static PacketType getPacketType(Class<?> packet) {

return REGISTER.classToType.get(packet);
}


/**
* Retrieve the associated packet type for a packet class in the given protocol state.
*
* @param protocol the protocol state to retrieve the packet from.
* @param packet the class identifying the packet type.
* @return the packet type associated with the given class in the given protocol state, or null if not found.
*/
public static PacketType getPacketType(PacketType.Protocol protocol, Class<?> packet) {
initialize();
if (MinecraftReflection.isBundlePacket(packet)) {
return PacketType.Play.Server.BUNDLE;
}

Map<Class<?>, PacketType> classToTypesForProtocol = REGISTER.protocolClassToType.get(protocol);
return classToTypesForProtocol == null ? null : classToTypesForProtocol.get(packet);
}

/**
* Retrieve the packet type of a given packet.
* @param packet - the class of the packet.
Expand Down
17 changes: 0 additions & 17 deletions src/main/java/com/comphenix/protocol/reflect/ObjectWriter.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,6 @@

package com.comphenix.protocol.reflect;

import com.comphenix.protocol.PacketType;
import com.comphenix.protocol.injector.StructureCache;
import com.comphenix.protocol.injector.packet.PacketRegistry;
import com.comphenix.protocol.utility.MinecraftReflection;
import com.comphenix.protocol.utility.StreamSerializer;
import java.lang.reflect.Field;
import java.lang.reflect.Modifier;
import java.util.HashMap;
Expand All @@ -46,18 +41,6 @@ public class ObjectWriter {
* @return A structure modifier for the given type.
*/
private StructureModifier<Object> getModifier(Class<?> type) {
Class<?> packetClass = MinecraftReflection.getPacketClass();

// Handle subclasses of the packet class with our custom structure cache, if possible
if (!type.equals(packetClass) && packetClass.isAssignableFrom(type)) {
// might be a packet, but some packets are not registered (for example PacketPlayInFlying, only the subtypes are present)
PacketType packetType = PacketRegistry.getPacketType(type);
if (packetType != null) {
// packet is present, delegate to the cache
return StructureCache.getStructure(packetType);
}
}

// Create the structure modifier if we haven't already
StructureModifier<Object> modifier = CACHE.get(type);
if (modifier == null) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
package com.comphenix.protocol.injector.netty.channel;

import com.comphenix.protocol.BukkitInitialization;
import com.comphenix.protocol.PacketType;
import io.netty.channel.Channel;
import io.netty.channel.local.LocalServerChannel;
import net.minecraft.network.EnumProtocol;
import net.minecraft.network.NetworkManager;
import net.minecraft.network.protocol.EnumProtocolDirection;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;

public class ChannelProtocolUtilTest {

@BeforeAll
public static void beforeClass() {
BukkitInitialization.initializeAll();
}

@Test
public void testProtocolResolving() {
Channel channel = new LocalServerChannel();
channel.attr(NetworkManager.e).set(EnumProtocol.e.b(EnumProtocolDirection.a)); // ATTRIBUTE_SERVERBOUND_PROTOCOL -> Protocol.CONFIG.codec(SERVERBOUND)
channel.attr(NetworkManager.f).set(EnumProtocol.b.b(EnumProtocolDirection.b)); // ATTRIBUTE_CLIENTBOUND_PROTOCOL -> Protocol.PLAY.codec(CLIENTBOUND)

PacketType.Protocol serverBoundProtocol = ChannelProtocolUtil.PROTOCOL_RESOLVER.apply(channel, PacketType.Sender.CLIENT);
Assertions.assertEquals(PacketType.Protocol.CONFIGURATION, serverBoundProtocol);

PacketType.Protocol clientBoundProtocol = ChannelProtocolUtil.PROTOCOL_RESOLVER.apply(channel, PacketType.Sender.SERVER);
Assertions.assertEquals(PacketType.Protocol.PLAY, clientBoundProtocol);
}
}

0 comments on commit af33a2a

Please sign in to comment.