/
BasicClassFactory.cs
105 lines (89 loc) · 3.33 KB
/
BasicClassFactory.cs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
using System;
using System.Runtime.InteropServices;
namespace COMRegistration
{
// Managed import of the native COM IClassFactory interface (IID 00000001-0000-0000-C000-000000000046).
// https://docs.microsoft.com/windows/win32/api/unknwn/nn-unknwn-iclassfactory
// InterfaceIsIUnknown: the runtime supplies the three IUnknown slots itself, so the methods
// declared below occupy the following vtable slots in declaration order. Do not reorder them
// or change their marshalling attributes; either would break binary compatibility.
// No [PreserveSig] is used, so the native HRESULT return is hidden: implementations signal
// failure by throwing, and the exception's HResult becomes the returned HRESULT.
[ComImport]
[ComVisible(false)]
[Guid("00000001-0000-0000-C000-000000000046")]
[InterfaceType(ComInterfaceType.InterfaceIsIUnknown)]
internal interface IClassFactory
{
/// <summary>Creates an object and returns the interface identified by <paramref name="riid"/>.</summary>
/// <param name="pUnkOuter">Controlling outer IUnknown when the object is being aggregated; otherwise null.</param>
/// <param name="riid">IID of the interface the caller wants.</param>
/// <param name="ppvObject">Receives the requested interface pointer.</param>
void CreateInstance(
[MarshalAs(UnmanagedType.Interface)] object pUnkOuter,
ref Guid riid,
out IntPtr ppvObject);
/// <summary>Locks (<paramref name="fLock"/> = true) or unlocks the server; marshalled as a 4-byte Win32 BOOL.</summary>
void LockServer([MarshalAs(UnmanagedType.Bool)] bool fLock);
}
/// <summary>
/// Minimal <see cref="IClassFactory"/> that creates instances of <typeparamref name="T"/>
/// through its parameterless constructor. It hands each instance out as the COM interface
/// requested by the caller, optionally aggregated inside a caller-supplied outer object.
/// </summary>
/// <typeparam name="T">The managed class to instantiate; must have a public parameterless constructor.</typeparam>
[ComVisible(true)]
internal class BasicClassFactory<T> : IClassFactory where T : new()
{
    // IID of IUnknown. Requesting it asks for the object's IUnknown pointer, and it is the
    // only IID permitted when aggregating.
    private static readonly Guid IID_IUnknown = new Guid("00000000-0000-0000-C000-000000000046");

    // HRESULT returned when aggregation is requested with an IID other than IUnknown.
    private const int CLASS_E_NOAGGREGATION = unchecked((int)0x80040110);

    /// <summary>
    /// Creates a new <typeparamref name="T"/> and returns a COM pointer for the interface
    /// identified by <paramref name="riid"/>.
    /// </summary>
    /// <param name="pUnkOuter">Controlling outer object when aggregating; otherwise null.</param>
    /// <param name="riid">IID of the interface requested by the caller.</param>
    /// <param name="ppvObject">
    /// Receives the requested interface pointer. It is set to <see cref="IntPtr.Zero"/>
    /// before any work starts, so it stays zero if this method throws.
    /// </param>
    /// <exception cref="COMException">
    /// HRESULT CLASS_E_NOAGGREGATION when <paramref name="pUnkOuter"/> is non-null and
    /// <paramref name="riid"/> is not IID_IUnknown.
    /// </exception>
    /// <exception cref="InvalidCastException">
    /// (E_NOINTERFACE) when <typeparamref name="T"/> does not implement the requested interface.
    /// </exception>
    public void CreateInstance(
        [MarshalAs(UnmanagedType.Interface)] object pUnkOuter,
        ref Guid riid,
        out IntPtr ppvObject)
    {
        // IClassFactory::CreateInstance must leave *ppvObject NULL on failure. Clear it
        // before any of the steps below can throw.
        ppvObject = IntPtr.Zero;

        // Validate the request first, so a bad IID or an illegal aggregation request
        // never runs T's constructor.
        Type interfaceType = GetValidatedInterfaceType(typeof(T), riid, pUnkOuter);

        object instance = new T();
        if (pUnkOuter != null)
        {
            instance = CreateAggregatedObject(pUnkOuter, instance);
        }

        ppvObject = GetObjectAsInterface(instance, interfaceType);
    }

    /// <summary>
    /// Accepted but ignored: this factory keeps no server lock count.
    /// </summary>
    /// <param name="fLock">Lock (true) or unlock (false) request; unused.</param>
    public void LockServer([MarshalAs(UnmanagedType.Bool)] bool fLock)
    {
        // Intentionally empty. NOTE(review): if the hosting server's lifetime should be
        // pinned by LockServer(true), that accounting would belong here.
    }

    /// <summary>
    /// Maps <paramref name="riid"/> to the managed type used to marshal the new object.
    /// </summary>
    /// <param name="classType">The class being instantiated.</param>
    /// <param name="riid">Requested IID.</param>
    /// <param name="outer">Outer object for aggregation, or null.</param>
    /// <returns>
    /// <c>typeof(object)</c> for IID_IUnknown (the sentinel for "return IUnknown");
    /// otherwise the first interface of <paramref name="classType"/> whose GUID equals <paramref name="riid"/>.
    /// </returns>
    /// <exception cref="COMException">CLASS_E_NOAGGREGATION when aggregating with a non-IUnknown IID.</exception>
    /// <exception cref="InvalidCastException">(E_NOINTERFACE) when no implemented interface matches.</exception>
    private static Type GetValidatedInterfaceType(Type classType, Guid riid, object outer)
    {
        if (riid == IID_IUnknown)
        {
            return typeof(object);
        }

        // Aggregation can only be done when requesting IUnknown.
        if (outer != null)
        {
            throw new COMException(string.Empty, CLASS_E_NOAGGREGATION);
        }

        // Verify the class implements the desired interface.
        Type match = Array.Find(classType.GetInterfaces(), candidate => candidate.GUID == riid);
        if (match == null)
        {
            // InvalidCastException carries HRESULT COR_E_INVALIDCAST, which equals E_NOINTERFACE (0x80004002).
            throw new InvalidCastException();
        }

        return match;
    }

    /// <summary>
    /// Returns a COM pointer for <paramref name="obj"/> that the caller must eventually Release.
    /// The pointer is its IUnknown when <paramref name="interfaceType"/> is <c>typeof(object)</c>,
    /// otherwise the COM interface for <paramref name="interfaceType"/>.
    /// </summary>
    /// <exception cref="InvalidCastException">(E_NOINTERFACE) when no interface pointer is produced.</exception>
    private static IntPtr GetObjectAsInterface(object obj, Type interfaceType)
    {
        // The "interface type" object is the sentinel for IUnknown.
        if (interfaceType == typeof(object))
        {
            return Marshal.GetIUnknownForObject(obj);
        }

        // CustomQueryInterfaceMode.Ignore: any ICustomQueryInterface implementation on obj
        // is bypassed for this lookup.
        IntPtr interfacePtr = Marshal.GetComInterfaceForObject(obj, interfaceType, CustomQueryInterfaceMode.Ignore);
        if (interfacePtr == IntPtr.Zero)
        {
            // E_NOINTERFACE
            throw new InvalidCastException();
        }

        return interfacePtr;
    }

    /// <summary>
    /// Aggregates <paramref name="comObject"/> inside <paramref name="pUnkOuter"/>. Returns a
    /// managed wrapper over the inner IUnknown produced by the runtime.
    /// </summary>
    /// <param name="pUnkOuter">The controlling outer object.</param>
    /// <param name="comObject">The freshly created managed object to aggregate.</param>
    private static object CreateAggregatedObject(object pUnkOuter, object comObject)
    {
        IntPtr outerPtr = Marshal.GetIUnknownForObject(pUnkOuter);
        try
        {
            IntPtr innerPtr = Marshal.CreateAggregatedObject(outerPtr, comObject);
            // NOTE(review): confirm whether the reference on innerPtr handed back by
            // Marshal.CreateAggregatedObject should be released once the wrapper
            // below holds its own reference.
            return Marshal.GetObjectForIUnknown(innerPtr);
        }
        finally
        {
            // Balance the reference taken by Marshal.GetIUnknownForObject above.
            Marshal.Release(outerPtr);
        }
    }
}
}