/
DefaultKeycloakSessionFactory.java
executable file
·136 lines (110 loc) · 5.02 KB
/
DefaultKeycloakSessionFactory.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
package org.keycloak.services;
import org.jboss.logging.Logger;
import org.keycloak.Config;
import org.keycloak.models.KeycloakSession;
import org.keycloak.models.KeycloakSessionFactory;
import org.keycloak.models.RealmModel;
import org.keycloak.models.RealmProvider;
import org.keycloak.provider.Provider;
import org.keycloak.provider.ProviderEvent;
import org.keycloak.provider.ProviderEventListener;
import org.keycloak.provider.ProviderFactory;
import org.keycloak.provider.ProviderManager;
import org.keycloak.provider.Spi;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.ServiceLoader;
import java.util.Set;
import java.util.concurrent.CopyOnWriteArrayList;
/**
 * Default {@link KeycloakSessionFactory}. It discovers every {@link Spi} through {@link ServiceLoader},
 * loads and initializes the matching {@link ProviderFactory} instances, and creates
 * {@link DefaultKeycloakSession}s bound to this factory.
 *
 * <p>The provider maps are plain {@link HashMap}s populated by {@link #init()} without any
 * synchronization; presumably {@code init()} runs once before the factory is shared across
 * threads (TODO confirm). The listener list is a {@link CopyOnWriteArrayList}, so registering,
 * unregistering and publishing may happen concurrently.
 */
public class DefaultKeycloakSessionFactory implements KeycloakSessionFactory {

    private static final Logger log = Logger.getLogger(DefaultKeycloakSessionFactory.class);

    /** Default provider id per SPI provider class; only set when a default could be determined. */
    private Map<Class<? extends Provider>, String> provider = new HashMap<Class<? extends Provider>, String>();

    /** Every loaded factory per SPI provider class, keyed by the factory's id. */
    private Map<Class<? extends Provider>, Map<String, ProviderFactory>> factoriesMap = new HashMap<Class<? extends Provider>, Map<String, ProviderFactory>>();

    /** Registered listeners; copy-on-write so publish() iterates a stable snapshot. */
    protected CopyOnWriteArrayList<ProviderEventListener> listeners = new CopyOnWriteArrayList<ProviderEventListener>();

    @Override
    public void register(ProviderEventListener listener) {
        listeners.add(listener);
    }

    @Override
    public void unregister(ProviderEventListener listener) {
        listeners.remove(listener);
    }

    /**
     * Delivers {@code event} to every registered listener in registration order.
     * An exception thrown by a listener propagates and skips the remaining listeners.
     *
     * @param event the event to deliver
     */
    @Override
    public void publish(ProviderEvent event) {
        for (ProviderEventListener listener : listeners) {
            listener.onEvent(event);
        }
    }

    /**
     * Discovers all SPIs and loads their provider factories.
     *
     * <p>For each SPI, if {@code Config.getProvider(spiName)} names a provider, only that provider's
     * factory is loaded and it becomes the default. Otherwise every available factory is loaded,
     * and if exactly one is found it becomes the default.
     *
     * @throws RuntimeException if an explicitly configured provider cannot be found
     */
    public void init() {
        ProviderManager pm = new ProviderManager(getClass().getClassLoader(), Config.scope().getArray("providers"));
        for (Spi spi : ServiceLoader.load(Spi.class)) {
            Map<String, ProviderFactory> factories = new HashMap<String, ProviderFactory>();
            factoriesMap.put(spi.getProviderClass(), factories);

            String configuredId = Config.getProvider(spi.getName());
            if (configuredId != null) {
                loadConfiguredProvider(pm, spi, configuredId, factories);
            } else {
                loadAllProviders(pm, spi, factories);
            }
        }
    }

    /** Loads, initializes and registers the single explicitly configured provider for {@code spi}. */
    private void loadConfiguredProvider(ProviderManager pm, Spi spi, String providerId, Map<String, ProviderFactory> factories) {
        ProviderFactory factory = pm.load(spi, providerId);
        if (factory == null) {
            throw new RuntimeException("Failed to find provider " + providerId + " for " + spi.getName());
        }
        Config.Scope scope = Config.scope(spi.getName(), providerId);
        factory.init(scope);
        factories.put(factory.getId(), factory);
        // Record the default only once the factory was actually found and initialized.
        provider.put(spi.getProviderClass(), providerId);
        log.debugv("Loaded SPI {0} (provider = {1})", spi.getName(), providerId);
    }

    /** Loads and initializes every available factory for {@code spi}; a lone factory becomes the default. */
    private void loadAllProviders(ProviderManager pm, Spi spi, Map<String, ProviderFactory> factories) {
        for (ProviderFactory factory : pm.load(spi)) {
            Config.Scope scope = Config.scope(spi.getName(), factory.getId());
            factory.init(scope);
            factories.put(factory.getId(), factory);
        }
        if (factories.size() == 1) {
            String providerId = factories.values().iterator().next().getId();
            provider.put(spi.getProviderClass(), providerId);
            log.debugv("Loaded SPI {0} (provider = {1})", spi.getName(), providerId);
        } else {
            log.debugv("Loaded SPI {0} (providers = {1})", spi.getName(), factories.keySet());
        }
    }

    /** Creates a new {@link DefaultKeycloakSession} backed by this factory. */
    public KeycloakSession create() {
        return new DefaultKeycloakSession(this);
    }

    /**
     * Returns the default provider id for {@code clazz}.
     *
     * @return the id, or {@code null} if no default was determined during {@link #init()}
     */
    <T extends Provider> String getDefaultProvider(Class<T> clazz) {
        return provider.get(clazz);
    }

    /**
     * Returns the default factory for {@code clazz}.
     *
     * @return the factory, or {@code null} if there is no default or the SPI class is unknown
     */
    @Override
    public <T extends Provider> ProviderFactory<T> getProviderFactory(Class<T> clazz) {
        return getProviderFactory(clazz, provider.get(clazz));
    }

    /**
     * Returns the factory registered under {@code id} for {@code clazz}.
     *
     * @return the factory, or {@code null} if either the SPI class or the id is unknown
     */
    @Override
    @SuppressWarnings("unchecked") // factories are stored raw; they were loaded for this SPI's provider class
    public <T extends Provider> ProviderFactory<T> getProviderFactory(Class<T> clazz, String id) {
        Map<String, ProviderFactory> factories = factoriesMap.get(clazz);
        if (factories == null) {
            // An SPI class that was never loaded is treated like an unknown id instead of throwing NPE.
            return null;
        }
        return factories.get(id);
    }

    /**
     * Returns a new, mutable list of every factory loaded for {@code clazz}.
     *
     * @return the factories, or an empty list if the SPI class is unknown
     */
    @Override
    public List<ProviderFactory> getProviderFactories(Class<? extends Provider> clazz) {
        Map<String, ProviderFactory> factories = factoriesMap.get(clazz);
        if (factories == null) {
            return new ArrayList<ProviderFactory>();
        }
        return new ArrayList<ProviderFactory>(factories.values());
    }

    /**
     * Returns the ids of every factory loaded for {@code clazz}.
     *
     * @return a new mutable set of ids, or an empty set if the SPI class is unknown
     */
    <T extends Provider> Set<String> getAllProviderIds(Class<T> clazz) {
        Map<String, ProviderFactory> factories = factoriesMap.get(clazz);
        if (factories == null) {
            return new HashSet<String>();
        }
        // init() registers each factory under its own getId(), so the key set is exactly the id set.
        return new HashSet<String>(factories.keySet());
    }

    /**
     * Closes every loaded factory. A failing {@code close()} does not prevent the remaining
     * factories from being closed. Each failure is logged, and the first one is rethrown at the end.
     */
    public void close() {
        RuntimeException firstFailure = null;
        for (Map<String, ProviderFactory> factories : factoriesMap.values()) {
            for (ProviderFactory factory : factories.values()) {
                try {
                    factory.close();
                } catch (RuntimeException e) {
                    log.warn("Failed to close provider factory " + factory.getId(), e);
                    if (firstFailure == null) {
                        firstFailure = e;
                    }
                }
            }
        }
        if (firstFailure != null) {
            throw firstFailure;
        }
    }
}