diff --git a/3rdparty/phnt/include/ntioapi.h b/3rdparty/phnt/include/ntioapi.h index a08f8775..bfcaf20d 100644 --- a/3rdparty/phnt/include/ntioapi.h +++ b/3rdparty/phnt/include/ntioapi.h @@ -774,22 +774,6 @@ typedef struct _FILE_ID_EXTD_BOTH_DIR_INFORMATION WCHAR FileName[1]; } FILE_ID_EXTD_BOTH_DIR_INFORMATION, *PFILE_ID_EXTD_BOTH_DIR_INFORMATION; -// private -typedef struct _FILE_STAT_INFORMATION -{ - LARGE_INTEGER FileId; - LARGE_INTEGER CreationTime; - LARGE_INTEGER LastAccessTime; - LARGE_INTEGER LastWriteTime; - LARGE_INTEGER ChangeTime; - LARGE_INTEGER AllocationSize; - LARGE_INTEGER EndOfFile; - ULONG FileAttributes; - ULONG ReparseTag; - ULONG NumberOfLinks; - ACCESS_MASK EffectiveAccess; -} FILE_STAT_INFORMATION, *PFILE_STAT_INFORMATION; - // private typedef struct _FILE_MEMORY_PARTITION_INFORMATION { @@ -805,43 +789,6 @@ typedef struct _FILE_MEMORY_PARTITION_INFORMATION } Flags; } FILE_MEMORY_PARTITION_INFORMATION, *PFILE_MEMORY_PARTITION_INFORMATION; -// LxFlags -#define LX_FILE_METADATA_HAS_UID 0x1 -#define LX_FILE_METADATA_HAS_GID 0x2 -#define LX_FILE_METADATA_HAS_MODE 0x4 -#define LX_FILE_METADATA_HAS_DEVICE_ID 0x8 -#define LX_FILE_CASE_SENSITIVE_DIR 0x10 - -// private -typedef struct _FILE_STAT_LX_INFORMATION -{ - LARGE_INTEGER FileId; - LARGE_INTEGER CreationTime; - LARGE_INTEGER LastAccessTime; - LARGE_INTEGER LastWriteTime; - LARGE_INTEGER ChangeTime; - LARGE_INTEGER AllocationSize; - LARGE_INTEGER EndOfFile; - ULONG FileAttributes; - ULONG ReparseTag; - ULONG NumberOfLinks; - ACCESS_MASK EffectiveAccess; - ULONG LxFlags; - ULONG LxUid; - ULONG LxGid; - ULONG LxMode; - ULONG LxDeviceIdMajor; - ULONG LxDeviceIdMinor; -} FILE_STAT_LX_INFORMATION, *PFILE_STAT_LX_INFORMATION; - -#define FILE_CS_FLAG_CASE_SENSITIVE_DIR 0x00000001 - -// private -typedef struct _FILE_CASE_SENSITIVE_INFORMATION -{ - ULONG Flags; -} FILE_CASE_SENSITIVE_INFORMATION, *PFILE_CASE_SENSITIVE_INFORMATION; - // private typedef enum _FILE_KNOWN_FOLDER_TYPE { diff --git a/MemoryModule/Initialize.cpp b/MemoryModule/Initialize.cpp index 42da0ba5..aa023c8e 100644 --- a/MemoryModule/Initialize.cpp +++ b/MemoryModule/Initialize.cpp @@ -438,6 +438,7 @@ NTSTATUS InitializeLockHeld() { status = STATUS_NOT_SUPPORTED; } else { + ++MmpGlobalDataPtr->ReferenceCount; status = STATUS_SUCCESS; } } @@ -448,6 +449,7 @@ NTSTATUS InitializeLockHeld() { MmpGlobalDataPtr->MajorVersion = MEMORY_MODULE_MAJOR_VERSION; MmpGlobalDataPtr->MinorVersion = MEMORY_MODULE_MINOR_VERSION; MmpGlobalDataPtr->BaseAddress = MmpGlobalDataPtr; + MmpGlobalDataPtr->ReferenceCount = 1; GetSystemInfo(&MmpGlobalDataPtr->SystemInfo); @@ -504,16 +506,74 @@ NTSTATUS InitializeLockHeld() { return status; } -NTSTATUS NTAPI Initialize() { +NTSTATUS NTAPI MmInitialize() { NTSTATUS status; - RtlAcquirePebLock(); - status = InitializeLockHeld(); - RtlReleasePebLock(); + PVOID cookie; + LdrLockLoaderLock(LDR_LOCK_LOADER_LOCK_FLAG_RAISE_ON_ERRORS, nullptr, &cookie); + + __try { + status = InitializeLockHeld(); + } + __finally { + LdrUnlockLoaderLock(LDR_UNLOCK_LOADER_LOCK_FLAG_RAISE_ON_ERRORS, cookie); + } return status; } +NTSTATUS CleanupLockHeld() { + + PLIST_ENTRY ListHead = &NtCurrentPeb()->Ldr->InLoadOrderModuleList, ListEntry = ListHead->Flink; + PLDR_DATA_TABLE_ENTRY CurEntry; + + while (ListEntry != ListHead) { + CurEntry = CONTAINING_RECORD(ListEntry, LDR_DATA_TABLE_ENTRY, InLoadOrderLinks); + ListEntry = ListEntry->Flink; + + if (IsValidMemoryModuleHandle((HMEMORYMODULE)CurEntry->DllBase)) { + + // + // Make sure all memory module is unloaded. + // + + return STATUS_NOT_SUPPORTED; + } + } + + if (--MmpGlobalDataPtr->ReferenceCount > 0) { + return STATUS_SUCCESS; + } + + MmpTlsCleanup(); + MmpCleanupDotNetHooks(); + + NtUnmapViewOfSection(NtCurrentProcess(), MmpGlobalDataPtr->BaseAddress); + MmpGlobalDataPtr = nullptr; + return STATUS_SUCCESS; +} + +NTSTATUS NTAPI MmCleanup() { + NTSTATUS status; + PVOID cookie; + LdrLockLoaderLock(LDR_LOCK_LOADER_LOCK_FLAG_RAISE_ON_ERRORS, nullptr, &cookie); + + __try { + + if (MmpGlobalDataPtr == nullptr) { + status = STATUS_ACCESS_VIOLATION; + __leave; + } + + status = CleanupLockHeld(); + } + __finally { + LdrUnlockLoaderLock(LDR_UNLOCK_LOADER_LOCK_FLAG_RAISE_ON_ERRORS, cookie); + } + + return status; +} + #ifdef _USRDLL extern "C" __declspec(dllexport) BOOL WINAPI ReflectiveMapDll(HMODULE hModule) { PIMAGE_NT_HEADERS headers = RtlImageNtHeader(hModule); @@ -542,7 +602,8 @@ extern "C" __declspec(dllexport) BOOL WINAPI ReflectiveMapDll(HMODULE hModule) { BOOL APIENTRY DllMain(HMODULE hModule, DWORD ul_reason_for_call, LPVOID lpReserved) { if (ul_reason_for_call == DLL_PROCESS_ATTACH) { - if (NT_SUCCESS(Initialize())) { +#ifdef _HAS_AUTO_INITIALIZE + if (NT_SUCCESS(MmInitialize())) { if (lpReserved == (PVOID)-1) { if (!ReflectiveMapDll(hModule)) { RtlRaiseStatus(STATUS_NOT_SUPPORTED); @@ -553,10 +614,13 @@ BOOL APIENTRY DllMain(HMODULE hModule, DWORD ul_reason_for_call, LPVOID lpReser } return FALSE; +#endif } return TRUE; } #else -const NTSTATUS Initializer = Initialize(); +#ifdef _HAS_AUTO_INITIALIZE +const NTSTATUS Initializer = MmInitialize(); +#endif #endif diff --git a/MemoryModule/Initialize.h b/MemoryModule/Initialize.h new file mode 100644 index 00000000..655df6ed --- /dev/null +++ b/MemoryModule/Initialize.h @@ -0,0 +1,9 @@ +#pragma once + +NTSTATUS NTAPI MmInitialize(); +NTSTATUS NTAPI MmCleanup(); + +// +// This function is available only if the MMPP is compiled as a DLL. +// +BOOL WINAPI ReflectiveMapDll(HMODULE hModule); diff --git a/MemoryModule/LoadDllMemoryApi.h b/MemoryModule/LoadDllMemoryApi.h index 54d8368a..0a413b50 100644 --- a/MemoryModule/LoadDllMemoryApi.h +++ b/MemoryModule/LoadDllMemoryApi.h @@ -3,6 +3,7 @@ typedef HMODULE HMEMORYMODULE; #include "Loader.h" +#include "Initialize.h" #define MemoryModuleToModule(_hMemoryModule_) (_hMemoryModule_) @@ -10,25 +11,29 @@ typedef HMODULE HMEMORYMODULE; #define NT_SUCCESS(Status) (((NTSTATUS)(Status)) >= 0) #endif -HMEMORYMODULE WINAPI LoadLibraryMemory(_In_ PVOID BufferAddress); - -HMEMORYMODULE WINAPI LoadLibraryMemoryExA( - _In_ PVOID BufferAddress, - _In_ size_t Reserved, - _In_opt_ LPCSTR DllBaseName, - _In_opt_ LPCSTR DllFullName, - _In_ DWORD Flags -); - -HMEMORYMODULE WINAPI LoadLibraryMemoryExW( - _In_ PVOID BufferAddress, - _In_ size_t Reserved, - _In_opt_ LPCWSTR DllBaseName, - _In_opt_ LPCWSTR DllFullName, - _In_ DWORD Flags -); - -BOOL WINAPI FreeLibraryMemory(_In_ HMEMORYMODULE hMemoryModule); +extern "C" { + + HMEMORYMODULE WINAPI LoadLibraryMemory(_In_ PVOID BufferAddress); + + HMEMORYMODULE WINAPI LoadLibraryMemoryExA( + _In_ PVOID BufferAddress, + _In_ size_t Reserved, + _In_opt_ LPCSTR DllBaseName, + _In_opt_ LPCSTR DllFullName, + _In_ DWORD Flags + ); + + HMEMORYMODULE WINAPI LoadLibraryMemoryExW( + _In_ PVOID BufferAddress, + _In_ size_t Reserved, + _In_opt_ LPCWSTR DllBaseName, + _In_opt_ LPCWSTR DllFullName, + _In_ DWORD Flags + ); + + BOOL WINAPI FreeLibraryMemory(_In_ HMEMORYMODULE hMemoryModule); + +} #define NtLoadDllMemory LdrLoadDllMemory #define NtLoadDllMemoryExA LdrLoadDllMemoryExA diff --git a/MemoryModule/Loader.cpp b/MemoryModule/Loader.cpp index 73eb5ffe..846ad818 100644 --- a/MemoryModule/Loader.cpp +++ b/MemoryModule/Loader.cpp @@ -61,7 +61,14 @@ NTSTATUS NTAPI LdrLoadDllMemoryExW( __try { *BaseAddress = nullptr; if (LdrEntry)*LdrEntry = nullptr; - if (!RtlIsValidImageBuffer(BufferAddress, &BufferSize) && !(dwFlags & LOAD_FLAGS_PASS_IMAGE_CHECK))status = STATUS_INVALID_IMAGE_FORMAT; + + if (!RtlIsValidImageBuffer(BufferAddress, &BufferSize) && !(dwFlags & LOAD_FLAGS_PASS_IMAGE_CHECK)) { + status = STATUS_INVALID_IMAGE_FORMAT; + } + + if (MmpGlobalDataPtr == nullptr) { + status = STATUS_INVALID_PARAMETER; + } } __except (EXCEPTION_EXECUTE_HANDLER) { status = GetExceptionCode(); @@ -229,6 +236,11 @@ NTSTATUS NTAPI LdrUnloadDllMemory(_In_ HMEMORYMODULE BaseAddress) { break; } + if (MmpGlobalDataPtr == nullptr) { + status = STATUS_INVALID_PARAMETER; + break; + } + //Mapping dll failed if (!module->MappedDll) { module->underUnload = true; diff --git a/MemoryModule/Loader.h b/MemoryModule/Loader.h index 8bdce287..a71afe7f 100644 --- a/MemoryModule/Loader.h +++ b/MemoryModule/Loader.h @@ -9,9 +9,6 @@ #define MEMORY_FEATURE_LDRP_RELEASE_TLS_ENTRY 0x00000040 #define MEMORY_FEATURE_ALL 0x0000007f -//Get the implementation of the currently running operating system. -NTSTATUS NTAPI LdrQuerySystemMemoryModuleFeatures(_Out_ PDWORD pFeatures); - /* LdrLoadDllMemoryEx dwFlags @@ -47,21 +44,24 @@ NTSTATUS NTAPI LdrQuerySystemMemoryModuleFeatures(_Out_ PDWORD pFeatures); //Hook for dotnet dlls #define LOAD_FLAGS_HOOK_DOT_NET 0x00000010 +extern "C" { -NTSTATUS NTAPI LdrLoadDllMemoryExW( - _Out_ HMEMORYMODULE* BaseAddress, // Output module base address - _Out_opt_ PVOID* LdrEntry, // Receive a pointer to the LDR node of the module - _In_ DWORD dwFlags, // Flags - _In_ LPVOID BufferAddress, // Pointer to the dll file data buffer - _In_ size_t Reserved, // Reserved parameter, must be 0 - _In_opt_ LPCWSTR DllName, // Module file name - _In_opt_ LPCWSTR DllFullName // Module file full path -); + //Get the implementation of the currently running operating system. + NTSTATUS NTAPI LdrQuerySystemMemoryModuleFeatures(_Out_ PDWORD pFeatures); + + NTSTATUS NTAPI LdrLoadDllMemoryExW( + _Out_ HMEMORYMODULE* BaseAddress, // Output module base address + _Out_opt_ PVOID* LdrEntry, // Receive a pointer to the LDR node of the module + _In_ DWORD dwFlags, // Flags + _In_ LPVOID BufferAddress, // Pointer to the dll file data buffer + _In_ size_t Reserved, // Reserved parameter, must be 0 + _In_opt_ LPCWSTR DllName, // Module file name + _In_opt_ LPCWSTR DllFullName // Module file full path + ); -//Unload modules previously loaded from memory -NTSTATUS NTAPI LdrUnloadDllMemory(_In_ HMEMORYMODULE BaseAddress); + //Unload modules previously loaded from memory + NTSTATUS NTAPI LdrUnloadDllMemory(_In_ HMEMORYMODULE BaseAddress); -extern "C" { __declspec(noreturn) VOID NTAPI LdrUnloadDllMemoryAndExitThread( _In_ HMEMORYMODULE BaseAddress, _In_ DWORD dwExitCode diff --git a/MemoryModule/MemoryModule.vcxproj b/MemoryModule/MemoryModule.vcxproj index 5ea04bc9..ebcea8da 100644 --- a/MemoryModule/MemoryModule.vcxproj +++ b/MemoryModule/MemoryModule.vcxproj @@ -152,6 +152,7 @@ + @@ -419,7 +420,7 @@ NotUsing Level3 true - _MEMORY_MODULE;WIN32;_DEBUG;_LIB;%(PreprocessorDefinitions) + _MEMORY_MODULE;_HAS_AUTO_INITIALIZE;WIN32;_DEBUG;_LIB;%(PreprocessorDefinitions) true @@ -437,7 +438,7 @@ NotUsing Level3 true - _MEMORY_MODULE;WIN32;_DEBUG;_LIB;%(PreprocessorDefinitions) + _MEMORY_MODULE;_HAS_AUTO_INITIALIZE;WIN32;_DEBUG;_LIB;%(PreprocessorDefinitions) true @@ -456,7 +457,7 @@ NotUsing Level3 true - _MEMORY_MODULE;WIN32;_DEBUG;_USRDLL;%(PreprocessorDefinitions) + _MEMORY_MODULE;_HAS_AUTO_INITIALIZE;WIN32;_DEBUG;_USRDLL;%(PreprocessorDefinitions) true @@ -475,7 +476,7 @@ NotUsing Level3 true - _MEMORY_MODULE;WIN32;_DEBUG;_USRDLL;%(PreprocessorDefinitions) + _MEMORY_MODULE;_HAS_AUTO_INITIALIZE;WIN32;_DEBUG;_USRDLL;%(PreprocessorDefinitions) true @@ -494,7 +495,7 @@ NotUsing Level3 true - _MEMORY_MODULE;_DEBUG;_LIB;%(PreprocessorDefinitions) + _MEMORY_MODULE;_HAS_AUTO_INITIALIZE;_DEBUG;_LIB;%(PreprocessorDefinitions) true @@ -512,7 +513,7 @@ NotUsing Level3 true - _MEMORY_MODULE;_DEBUG;_LIB;%(PreprocessorDefinitions) + _MEMORY_MODULE;_HAS_AUTO_INITIALIZE;_DEBUG;_LIB;%(PreprocessorDefinitions) true @@ -531,7 +532,7 @@ NotUsing Level3 true - _MEMORY_MODULE;_DEBUG;_USRDLL;%(PreprocessorDefinitions) + _MEMORY_MODULE;_HAS_AUTO_INITIALIZE;_DEBUG;_USRDLL;%(PreprocessorDefinitions) true @@ -550,7 +551,7 @@ NotUsing Level3 true - _MEMORY_MODULE;_DEBUG;_USRDLL;%(PreprocessorDefinitions) + _MEMORY_MODULE;_HAS_AUTO_INITIALIZE;_DEBUG;_USRDLL;%(PreprocessorDefinitions) true @@ -571,7 +572,7 @@ true true true - _MEMORY_MODULE;WIN32;NDEBUG;_LIB;%(PreprocessorDefinitions) + _MEMORY_MODULE;_HAS_AUTO_INITIALIZE;WIN32;NDEBUG;_LIB;%(PreprocessorDefinitions) true @@ -593,7 +594,7 @@ true true true - _MEMORY_MODULE;WIN32;NDEBUG;_LIB;%(PreprocessorDefinitions) + _MEMORY_MODULE;_HAS_AUTO_INITIALIZE;WIN32;NDEBUG;_LIB;%(PreprocessorDefinitions) true @@ -616,7 +617,7 @@ true true true - _MEMORY_MODULE;WIN32;NDEBUG;_USRDLL;%(PreprocessorDefinitions) + _MEMORY_MODULE;_HAS_AUTO_INITIALIZE;WIN32;NDEBUG;_USRDLL;%(PreprocessorDefinitions) true @@ -639,7 +640,7 @@ true true true - _MEMORY_MODULE;WIN32;NDEBUG;_USRDLL;%(PreprocessorDefinitions) + _MEMORY_MODULE;_HAS_AUTO_INITIALIZE;WIN32;NDEBUG;_USRDLL;%(PreprocessorDefinitions) true @@ -662,7 +663,7 @@ true true true - _MEMORY_MODULE;NDEBUG;_LIB;%(PreprocessorDefinitions) + _MEMORY_MODULE;_HAS_AUTO_INITIALIZE;NDEBUG;_LIB;%(PreprocessorDefinitions) true @@ -684,7 +685,7 @@ true true true - _MEMORY_MODULE;NDEBUG;_LIB;%(PreprocessorDefinitions) + _MEMORY_MODULE;_HAS_AUTO_INITIALIZE;NDEBUG;_LIB;%(PreprocessorDefinitions) true @@ -707,7 +708,7 @@ true true true - _MEMORY_MODULE;NDEBUG;_USRDLL;%(PreprocessorDefinitions) + _MEMORY_MODULE;_HAS_AUTO_INITIALIZE;NDEBUG;_USRDLL;%(PreprocessorDefinitions) true @@ -730,7 +731,7 @@ true true true - _MEMORY_MODULE;NDEBUG;_USRDLL;%(PreprocessorDefinitions) + _MEMORY_MODULE;_HAS_AUTO_INITIALIZE;NDEBUG;_USRDLL;%(PreprocessorDefinitions) true diff --git a/MemoryModule/MemoryModule.vcxproj.filters b/MemoryModule/MemoryModule.vcxproj.filters index 9413f3cc..e38f9fdb 100644 --- a/MemoryModule/MemoryModule.vcxproj.filters +++ b/MemoryModule/MemoryModule.vcxproj.filters @@ -260,6 +260,9 @@ Header Files + + Header Files + diff --git a/MemoryModule/MemoryModulePP.def b/MemoryModule/MemoryModulePP.def index 7491bc1e..fcbae96d 100644 --- a/MemoryModule/MemoryModulePP.def +++ b/MemoryModule/MemoryModulePP.def @@ -1,5 +1,9 @@ LIBRARY EXPORTS + +MmInitialize +MmCleanup + LoadLibraryMemory LoadLibraryMemoryExA LoadLibraryMemoryExW diff --git a/MemoryModule/MmpDotNet.cpp b/MemoryModule/MmpDotNet.cpp index 5f5e9585..ca8bcc1c 100644 --- a/MemoryModule/MmpDotNet.cpp +++ b/MemoryModule/MmpDotNet.cpp @@ -443,3 +443,38 @@ BOOL WINAPI MmpInitializeHooksForDotNet() { return FALSE; } + +VOID WINAPI MmpCleanupDotNetHooks() { + EnterCriticalSection(NtCurrentPeb()->FastPebLock); + + if (MmpGlobalDataPtr->MmpDotNet->PreHooked) { + DetourTransactionBegin(); + DetourUpdateThread(NtCurrentThread()); + + DetourDetach((PVOID*)&MmpGlobalDataPtr->MmpDotNet->Hooks.OriginCreateFileW, HookCreateFileW); + DetourDetach((PVOID*)&MmpGlobalDataPtr->MmpDotNet->Hooks.OriginGetFileInformationByHandle, HookGetFileInformationByHandle); + DetourDetach((PVOID*)&MmpGlobalDataPtr->MmpDotNet->Hooks.OriginGetFileAttributesExW, HookGetFileAttributesExW); + DetourDetach((PVOID*)&MmpGlobalDataPtr->MmpDotNet->Hooks.OriginGetFileSize, HookGetFileSize); + DetourDetach((PVOID*)&MmpGlobalDataPtr->MmpDotNet->Hooks.OriginGetFileSizeEx, HookGetFileSizeEx); + DetourDetach((PVOID*)&MmpGlobalDataPtr->MmpDotNet->Hooks.OriginCreateFileMappingW, HookCreateFileMappingW); + DetourDetach((PVOID*)&MmpGlobalDataPtr->MmpDotNet->Hooks.OriginMapViewOfFileEx, HookMapViewOfFileEx); + DetourDetach((PVOID*)&MmpGlobalDataPtr->MmpDotNet->Hooks.OriginMapViewOfFile, HookMapViewOfFile); + DetourDetach((PVOID*)&MmpGlobalDataPtr->MmpDotNet->Hooks.OriginUnmapViewOfFile, HookUnmapViewOfFile); + DetourDetach((PVOID*)&MmpGlobalDataPtr->MmpDotNet->Hooks.OriginCloseHandle, HookCloseHandle); + DetourDetach((PVOID*)&MmpGlobalDataPtr->MmpDotNet->Hooks.OriginGetFileVersion2, HookGetFileVersion); + + DetourTransactionCommit(); + + MmpGlobalDataPtr->MmpDotNet->PreHooked = FALSE; + } + + if (MmpGlobalDataPtr->MmpDotNet->Initialized) { + DetourTransactionBegin(); + DetourUpdateThread(NtCurrentThread()); + DetourDetach((PVOID*)&MmpGlobalDataPtr->MmpDotNet->Hooks.OriginGetFileVersion1, HookGetFileVersion); + DetourTransactionCommit(); + MmpGlobalDataPtr->MmpDotNet->Initialized = FALSE; + } + + LeaveCriticalSection(NtCurrentPeb()->FastPebLock); +} diff --git a/MemoryModule/MmpDotNet.h b/MemoryModule/MmpDotNet.h index 14253e58..be433b04 100644 --- a/MemoryModule/MmpDotNet.h +++ b/MemoryModule/MmpDotNet.h @@ -8,4 +8,5 @@ typedef HRESULT(WINAPI* GetFileVersion_T)( ); BOOL WINAPI MmpPreInitializeHooksForDotNet(); -BOOL WINAPI MmpInitializeHooksForDotNet(); \ No newline at end of file +BOOL WINAPI MmpInitializeHooksForDotNet(); +VOID WINAPI MmpCleanupDotNetHooks(); diff --git a/MemoryModule/MmpGlobalData.h b/MemoryModule/MmpGlobalData.h index e4530a2d..2a986b0e 100644 --- a/MemoryModule/MmpGlobalData.h +++ b/MemoryModule/MmpGlobalData.h @@ -100,7 +100,7 @@ typedef enum class _WINDOWS_VERSION :BYTE { #define MEMORY_MODULE_GET_MINOR_VERSION(MinorVersion) (~0x8000&(MinorVersion)) #define MEMORY_MODULE_MAJOR_VERSION 2 -#define MEMORY_MODULE_MINOR_VERSION MEMORY_MODULE_MAKE_PREVIEW(1) +#define MEMORY_MODULE_MINOR_VERSION MEMORY_MODULE_MAKE_PREVIEW(2) typedef struct _MMP_GLOBAL_DATA { @@ -137,6 +137,8 @@ typedef struct _MMP_GLOBAL_DATA { PMMP_IAT_DATA MmpIat; + DWORD ReferenceCount; + }MMP_GLOBAL_DATA, * PMMP_GLOBAL_DATA; #define MMP_GLOBAL_DATA_SIZE (\ diff --git a/MemoryModule/MmpTls.cpp b/MemoryModule/MmpTls.cpp index 4bc0aa30..71f7d358 100644 --- a/MemoryModule/MmpTls.cpp +++ b/MemoryModule/MmpTls.cpp @@ -7,6 +7,7 @@ #include #include #include <3rdparty/Detours/detours.h> +#include PVOID NTAPI MmpQuerySystemInformation( @@ -402,6 +403,96 @@ BOOL NTAPI PreHookNtSetInformationProcess() { return success; } +int MmpSyncThreadTlsData() { + PSYSTEM_PROCESS_INFORMATION pspi = (PSYSTEM_PROCESS_INFORMATION)MmpQuerySystemInformation(SYSTEM_INFORMATION_CLASS::SystemProcessInformation, nullptr); + PSYSTEM_PROCESS_INFORMATION current = pspi; + std::setthreads; + int count = 0; + + // + // Build thread id set. + // + + PLIST_ENTRY entry = MmpGlobalDataPtr->MmpTls->MmpThreadLocalStoragePointer.Flink; + while (entry != &MmpGlobalDataPtr->MmpTls->MmpThreadLocalStoragePointer) { + PMMP_TLSP_RECORD j = CONTAINING_RECORD(entry, MMP_TLSP_RECORD, InMmpThreadLocalStoragePointer); + threads.insert(j->UniqueThread); + + entry = entry->Flink; + } + + while (pspi) { + + if (current->UniqueProcessId == NtCurrentTeb()->ClientId.UniqueProcess) { + + for (ULONG index = 0; index < current->NumberOfThreads; ++index) { + CLIENT_ID cid = current->Threads[index].ClientId; + + if (threads.find(cid.UniqueThread) == threads.end()) { + + HANDLE hThread; + OBJECT_ATTRIBUTES oa{}; + NTSTATUS status = NtOpenThread(&hThread, THREAD_QUERY_INFORMATION, &oa, &cid); + if (NT_SUCCESS(status)) { + + THREAD_BASIC_INFORMATION tbi{}; + status = NtQueryInformationThread(hThread, THREADINFOCLASS::ThreadBasicInformation, &tbi, sizeof(tbi), nullptr); + if (NT_SUCCESS(status)) { + + PTEB teb = tbi.TebBaseAddress; + if (teb->ThreadLocalStoragePointer) { + + // + // Allocate TLS record + // + + auto record = PMMP_TLSP_RECORD(RtlAllocateHeap(RtlProcessHeap(), 0, sizeof(MMP_TLSP_RECORD))); + if (record) { + record->TlspLdrBlock = (PVOID*)teb->ThreadLocalStoragePointer; + record->TlspMmpBlock = (PVOID*)MmpAllocateTlsp(); + record->UniqueThread = cid.UniqueThread; + if (record->TlspMmpBlock) { + record->TlspMmpBlock = ((PTLS_VECTOR)record->TlspMmpBlock)->ModuleTlsData; + + auto size = CONTAINING_RECORD(record->TlspLdrBlock, TLS_VECTOR, ModuleTlsData)->Length; + if ((HANDLE)(ULONG_PTR)size != record->UniqueThread) { + RtlCopyMemory( + record->TlspMmpBlock, + record->TlspLdrBlock, + size * sizeof(PVOID) + ); + } + + teb->ThreadLocalStoragePointer = record->TlspMmpBlock; + InsertTailList(&MmpGlobalDataPtr->MmpTls->MmpThreadLocalStoragePointer, &record->InMmpThreadLocalStoragePointer); + InterlockedIncrement(&MmpGlobalDataPtr->MmpTls->MmpActiveThreadCount); + + ++count; + } + else { + RtlFreeHeap(RtlProcessHeap(), 0, record); + } + } + } + } + + NtClose(hThread); + } + + } + } + + break; + } + + if (!current->NextEntryOffset)break; + current = (PSYSTEM_PROCESS_INFORMATION)((PBYTE)current + current->NextEntryOffset); + } + + RtlFreeHeap(RtlProcessHeap(), 0, pspi); + return count; +} + NTSTATUS NTAPI HookNtSetInformationProcess( _In_opt_ HANDLE ProcessHandle, _In_ PROCESSINFOCLASS ProcessInformationClass, @@ -423,6 +514,12 @@ NTSTATUS NTAPI HookNtSetInformationProcess( PPROCESS_TLS_INFORMATION Tls = nullptr; NTSTATUS status = STATUS_SUCCESS; + // + // Sync thread data with ntdll!Ldr. + // + + MmpSyncThreadTlsData(); + do { if (ProcessTlsInformation->OperationType >= MaxProcessTlsOperation) { status = STATUS_INVALID_PARAMETER; @@ -456,7 +553,7 @@ NTSTATUS NTAPI HookNtSetInformationProcess( break; } - // reserved 0x50 PVOID for ntdll loader + // reserved 0x80 PVOID for ntdll loader if (ProcessTlsInformation->TlsVectorLength >= MMP_START_TLS_INDEX) { status = STATUS_NO_MEMORY; break; @@ -496,50 +593,55 @@ NTSTATUS NTAPI HookNtSetInformationProcess( // EnterCriticalSection(&MmpGlobalDataPtr->MmpTls->MmpTlspLock); for (ULONG i = 0; i < Tls->ThreadDataCount; ++i) { - BOOL found = FALSE; - PLIST_ENTRY entry = MmpGlobalDataPtr->MmpTls->MmpThreadLocalStoragePointer.Flink; - // Find thread-spec tlsp - while (entry != &MmpGlobalDataPtr->MmpTls->MmpThreadLocalStoragePointer) { + if (Tls->ThreadData[i].Flags == 2) { - PMMP_TLSP_RECORD j = CONTAINING_RECORD(entry, MMP_TLSP_RECORD, InMmpThreadLocalStoragePointer); + BOOL found = FALSE; + PLIST_ENTRY entry = MmpGlobalDataPtr->MmpTls->MmpThreadLocalStoragePointer.Flink; - if (ProcessTlsInformation->OperationType == ProcessTlsReplaceVector) { - if (j->TlspMmpBlock[ProcessTlsInformation->TlsVectorLength] == ProcessTlsInformation->ThreadData[i].TlsVector[ProcessTlsInformation->TlsVectorLength]) { - found = TRUE; + // Find thread-spec tlsp + while (entry != &MmpGlobalDataPtr->MmpTls->MmpThreadLocalStoragePointer) { - // Copy old data to new pointer - RtlCopyMemory( - ProcessTlsInformation->ThreadData[i].TlsVector, - j->TlspMmpBlock, - sizeof(PVOID) * ProcessTlsInformation->TlsVectorLength - ); + PMMP_TLSP_RECORD j = CONTAINING_RECORD(entry, MMP_TLSP_RECORD, InMmpThreadLocalStoragePointer); - // Swap the tlsp - std::swap( - j->TlspLdrBlock, - ProcessTlsInformation->ThreadData[i].TlsVector - ); + if (ProcessTlsInformation->OperationType == ProcessTlsReplaceVector) { + if (j->TlspMmpBlock[ProcessTlsInformation->TlsVectorLength] == ProcessTlsInformation->ThreadData[i].TlsVector[ProcessTlsInformation->TlsVectorLength]) { + found = TRUE; + + // Copy old data to new pointer + RtlCopyMemory( + ProcessTlsInformation->ThreadData[i].TlsVector, + j->TlspMmpBlock, + sizeof(PVOID) * ProcessTlsInformation->TlsVectorLength + ); + + // Swap the tlsp + std::swap( + j->TlspLdrBlock, + ProcessTlsInformation->ThreadData[i].TlsVector + ); + } } - } - else { - if (j->TlspMmpBlock[ProcessTlsInformation->TlsIndex] == ProcessTlsInformation->ThreadData[i].TlsModulePointer) { - found = TRUE; + else { + if (j->TlspMmpBlock[ProcessTlsInformation->TlsIndex] == ProcessTlsInformation->ThreadData[i].TlsModulePointer) { + found = TRUE; + + if (ProcessHandle) { + j->TlspLdrBlock[ProcessTlsInformation->TlsIndex] = ProcessTlsInformation->ThreadData[i].TlsModulePointer; + } - if (ProcessHandle) { - j->TlspLdrBlock[ProcessTlsInformation->TlsIndex] = ProcessTlsInformation->ThreadData[i].TlsModulePointer; + ProcessTlsInformation->ThreadData[i].TlsModulePointer = Tls->ThreadData[i].TlsModulePointer; } - - ProcessTlsInformation->ThreadData[i].TlsModulePointer = Tls->ThreadData[i].TlsModulePointer; } + + if (found)break; + entry = entry->Flink; } - if (found)break; - entry = entry->Flink; + ProcessTlsInformation->ThreadData[i].Flags = Tls->ThreadData[i].Flags; + ProcessTlsInformation->ThreadData[i].ThreadId = Tls->ThreadData[i].ThreadId; } - ProcessTlsInformation->ThreadData[i].Flags = Tls->ThreadData[i].Flags; - ProcessTlsInformation->ThreadData[i].ThreadId = Tls->ThreadData[i].ThreadId; } LeaveCriticalSection(&MmpGlobalDataPtr->MmpTls->MmpTlspLock); @@ -798,4 +900,15 @@ BOOL NTAPI MmpTlsInitialize() { return TRUE; } +VOID NTAPI MmpTlsCleanup() { + + DetourTransactionBegin(); + DetourUpdateThread(NtCurrentThread()); + DetourDetach((PVOID*)&MmpGlobalDataPtr->MmpTls->Hooks.OriginLdrShutdownThread, HookLdrShutdownThread); + DetourDetach((PVOID*)&MmpGlobalDataPtr->MmpTls->Hooks.OriginNtSetInformationProcess, HookNtSetInformationProcess); + DetourDetach((PVOID*)&MmpGlobalDataPtr->MmpTls->Hooks.OriginRtlUserThreadStart, HookRtlUserThreadStart); + DetourTransactionCommit(); + +} + #endif diff --git a/MemoryModule/MmpTls.h b/MemoryModule/MmpTls.h index 33de5c9c..4c2fe0e3 100644 --- a/MemoryModule/MmpTls.h +++ b/MemoryModule/MmpTls.h @@ -2,6 +2,8 @@ BOOL NTAPI MmpTlsInitialize(); +VOID NTAPI MmpTlsCleanup(); + NTSTATUS NTAPI MmpReleaseTlsEntry(_In_ PLDR_DATA_TABLE_ENTRY lpModuleEntry); NTSTATUS NTAPI MmpHandleTlsData(_In_ PLDR_DATA_TABLE_ENTRY lpModuleEntry); diff --git a/MemoryModule/MmpTlsFiber.cpp b/MemoryModule/MmpTlsFiber.cpp index 3b5391e2..7ea4d8c9 100644 --- a/MemoryModule/MmpTlsFiber.cpp +++ b/MemoryModule/MmpTlsFiber.cpp @@ -3,6 +3,7 @@ #include "MmpTlsFiber.h" #include +#include typedef struct _MMP_POSTPONED_TLS { @@ -31,8 +32,7 @@ DWORD WINAPI MmpReleasePostponedTlsWorker(PVOID) { auto iter = MmpPostponedTlsList->begin(); while (iter != MmpPostponedTlsList->end()) { - const auto& item = *iter; - GetExitCodeThread(item.hThread, &code); + GetExitCodeThread(iter->hThread, &code); if (code == STILL_ACTIVE) { ++iter; @@ -41,7 +41,7 @@ DWORD WINAPI MmpReleasePostponedTlsWorker(PVOID) { RtlAcquireSRWLockExclusive(&MmpGlobalDataPtr->MmpTls->MmpTlsListLock); - auto TlspMmpBlock = (PVOID*)item.lpOldTlsVector->ModuleTlsData; + auto TlspMmpBlock = (PVOID*)iter->lpOldTlsVector->ModuleTlsData; auto entry = MmpGlobalDataPtr->MmpTls->MmpTlsList.Flink; while (entry != &MmpGlobalDataPtr->MmpTls->MmpTlsList) { @@ -51,13 +51,13 @@ DWORD WINAPI MmpReleasePostponedTlsWorker(PVOID) { entry = entry->Flink; } - RtlFreeHeap(RtlProcessHeap(), 0, CONTAINING_RECORD(item.lpTlsRecord->TlspLdrBlock, TLS_VECTOR, TLS_VECTOR::ModuleTlsData)); - RtlFreeHeap(RtlProcessHeap(), 0, item.lpTlsRecord); - RtlFreeHeap(RtlProcessHeap(), 0, item.lpOldTlsVector); + RtlFreeHeap(RtlProcessHeap(), 0, CONTAINING_RECORD(iter->lpTlsRecord->TlspLdrBlock, TLS_VECTOR, TLS_VECTOR::ModuleTlsData)); + RtlFreeHeap(RtlProcessHeap(), 0, iter->lpTlsRecord); + RtlFreeHeap(RtlProcessHeap(), 0, iter->lpOldTlsVector); RtlReleaseSRWLockExclusive(&MmpGlobalDataPtr->MmpTls->MmpTlsListLock); - CloseHandle(item.hThread); + CloseHandle(iter->hThread); iter = MmpPostponedTlsList->erase(iter); } @@ -95,6 +95,7 @@ VOID WINAPI MmpQueuePostponedTls(PMMP_TLSP_RECORD record) { ); item.lpOldTlsVector = MmpAllocateTlsp(); + assert(item.lpOldTlsVector); item.lpTlsRecord = record; diff --git a/test/test.cpp b/test/test.cpp index 3ea4e308..e5c964af 100644 --- a/test/test.cpp +++ b/test/test.cpp @@ -129,7 +129,7 @@ int test() { } int main() { - + DisplayStatus(); test(); diff --git a/test/test.vcxproj b/test/test.vcxproj index a7616ce9..aee15d34 100644 --- a/test/test.vcxproj +++ b/test/test.vcxproj @@ -39,7 +39,7 @@ {5B3131BA-178A-4A28-BD54-315A45C97ED1} Win32Proj test - 10.0 + 10.0.22621.0