Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix cloader #615

Merged
merged 3 commits into from
May 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import java.util.jar.JarFile;
import java.util.jar.Manifest;

import javax.management.openmbean.CompositeDataInvocationHandler;

import com.security.smithloader.MemCheck;
import com.security.smithloader.common.JarUtil;
import com.security.smithloader.common.ParseParameter;
Expand All @@ -13,11 +15,16 @@
import java.lang.instrument.Instrumentation;
import java.lang.reflect.Constructor;
import java.lang.reflect.Method;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.concurrent.Callable;
import java.util.concurrent.FutureTask;
import java.util.concurrent.locks.ReentrantLock;

public class SmithAgent {
private static ReentrantLock xLoaderLock = new ReentrantLock();
private static SmithLoader xLoader = null;
private static Object xLoader = null;
private static Class<?> SmithProberClazz = null;
private static Object SmithProberObj = null;
private static Object SmithProberProxyObj = null;
Expand All @@ -26,6 +33,7 @@ public class SmithAgent {
private static String probeVersion;
private static String checksumStr;
private static String proberPath;
private static Instrumentation instrumentation = null;

public static Object getClassLoader() {
return xLoader;
Expand Down Expand Up @@ -73,9 +81,16 @@ private static boolean loadSmithProber(String proberPath, Instrumentation inst)
SmithAgentLogger.logger.info("loadSmithProber Entry");

try {
xLoader = new SmithLoader(proberPath, null);
SmithProberClazz = xLoader.loadClass("com.security.smith.SmithProbe");

Class<?> smithLoaderClazz = ClassLoader.getSystemClassLoader().loadClass("com.security.smithloader.SmithLoader");
Constructor<?> xconstructor = smithLoaderClazz.getConstructor(String.class, ClassLoader.class);
xLoader = xconstructor.newInstance(proberPath,null);

String smithProbeClassName = "com.security.smith.SmithProbe";
Class<?>[] loadclassargType = new Class[]{String.class};
SmithProberClazz = (Class<?>)Reflection.invokeMethod(xLoader,"loadClass", loadclassargType,smithProbeClassName);

SmithAgentLogger.logger.info("SmithProbe ClassLoader:"+SmithProberClazz.getClassLoader());

Class<?>[] emptyArgTypes = new Class[]{};
if (SmithProberClazz != null) {
Constructor<?> constructor = SmithProberClazz.getDeclaredConstructor();
Expand Down Expand Up @@ -182,6 +197,47 @@ private static String getProberVersion(String jarFilePath) {

return null;
}
private static class MyCallable implements Callable<String> {
@Override
public String call() throws Exception {
xLoaderLock.lock();
try {
if(xLoader != null) {
String agent = System.getProperty("rasp.probe");

if(unLoadSmithProber()) {
System.setProperty("smith.status", "detach");
}
if (agent != null) {
System.clearProperty("rasp.probe");
}
xLoader = null;
SmithProberObj = null;
SmithProberClazz = null;
}

System.setProperty("smith.rasp", "");
if (!checkMemoryAvailable()) {
System.setProperty("smith.status", "memory not enough");
SmithAgentLogger.logger.warning("checkMemory failed");
} else {
if(!loadSmithProber(proberPath,instrumentation)) {
System.setProperty("smith.status",proberPath + " loading fail");
SmithAgentLogger.logger.warning(proberPath + " loading fail!");
}
else {
System.setProperty("smith.rasp", probeVersion+"-"+checksumStr);
System.setProperty("smith.status", "attach");
System.setProperty("rasp.probe", "smith");
}
}
}
finally {
xLoaderLock.unlock();
}
return "SmithProbeLoader";
}
}

public static void premain(String agentArgs, Instrumentation inst) {
agentmain(agentArgs, inst);
Expand Down Expand Up @@ -214,57 +270,27 @@ public static void agentmain(String agentArgs, Instrumentation inst) {
SmithAgentLogger.logger.warning(proberPath + " check fail!");
return ;
}


try {
inst.appendToBootstrapClassLoaderSearch(new JarFile(proberPath));
}
catch(Exception e) {
SmithAgentLogger.exception(e);
if(instrumentation == null) {
instrumentation = inst;
}

probeVersion = getProberVersion(proberPath);
SmithAgentLogger.logger.info("proberVersion:" + probeVersion);

xLoaderLock.lock();
try {
if(xLoader != null) {
if(unLoadSmithProber()) {
System.setProperty("smith.status", "detach");
}
if (agent != null) {
System.clearProperty("rasp.probe");
}
xLoader = null;
SmithProberObj = null;
SmithProberClazz = null;
}

System.setProperty("smith.rasp", "");
if (!checkMemoryAvailable()) {
System.setProperty("smith.status", "memory not enough");
SmithAgentLogger.logger.warning("checkMemory failed");
} else {
if(!loadSmithProber(proberPath,inst)) {
System.setProperty("smith.status",proberPath + " loading fail");
SmithAgentLogger.logger.warning(proberPath + " loading fail!");
}
else {
System.setProperty("smith.rasp", probeVersion+"-"+checksumStr);
System.setProperty("smith.status", "attach");
System.setProperty("rasp.probe", "smith");
}
}
}
finally {
xLoaderLock.unlock();
}
Callable<String> callable = new MyCallable();

FutureTask<String> futureTask = new FutureTask<>(callable);
Thread newThread = new Thread(futureTask, "SmithProbeLoader Thread");
newThread.setContextClassLoader(ClassLoader.getSystemClassLoader());
newThread.start();
}
else if(cmd.equals("detach")) {
xLoaderLock.lock();
try {
if(xLoader != null) {
if(unLoadSmithProber()) {
SmithAgentLogger.logger.warning("SmithProber detach success!");
System.setProperty("smith.status", "detach");
}
else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import java.io.File;
import java.io.IOException;
import java.io.ByteArrayOutputStream;
import java.io.InputStream;
import java.net.URL;
import java.util.Enumeration;
Expand All @@ -12,9 +13,11 @@
import java.util.zip.ZipEntry;

public class SmithLoader extends ClassLoader {
private File file;
private JarFile jarFile;
public SmithLoader(String jarFilePath, ClassLoader parent) throws IOException {
this.jarFile = new JarFile(new File(jarFilePath));
file = new File(jarFilePath);
this.jarFile = new JarFile(file);
}

@Override
Expand All @@ -32,42 +35,47 @@ protected Class<?> findClass(String name) throws ClassNotFoundException {
} catch (ClassNotFoundException e) {
// If the class is not found in JAR file,try to load from parent class loader
return super.findClass(name);
//throw e;
}

return null;
}

private byte[] readAllBytes(InputStream inputStream) throws IOException {
ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
byte[] buffer = new byte[4096];
int bytesRead;
while ((bytesRead = inputStream.read(buffer)) != -1) {
outputStream.write(buffer, 0, bytesRead);
}
return outputStream.toByteArray();
}

private byte[] loadClassData(String className) throws IOException {
byte[] bytes = null;
byte[] data = null;

try {
ZipEntry zEntry = jarFile.getEntry(className);
if(zEntry == null) {
throw new IOException("class not found");
}

InputStream inputStream = jarFile.getInputStream(zEntry);
if(inputStream == null) {
throw new IOException("class not found");
}

bytes = new byte[inputStream.available()];

int bytesRead = inputStream.read(bytes);
if (bytesRead != bytes.length) {
throw new IOException("get byte array fail");
try (InputStream inputStream = jarFile.getInputStream(zEntry)) {
data = readAllBytes(inputStream);
inputStream.close();
}
}
catch(Exception e) {
throw e;
}

return bytes;
return data;
}

@Override
protected void finalize() throws Throwable {
try {
jarFile.close();
jarFile = null;
} finally {
super.finalize();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -607,7 +607,7 @@ public byte[] transform(ClassLoader loader, String className, Class<?> classBein
classReader.accept(classVisitor, ClassReader.EXPAND_FRAMES);

return classWriter.toByteArray();
} catch (Exception e) {
} catch (Throwable e) {
SmithLogger.exception(e);
}

Expand Down
Loading