-
Notifications
You must be signed in to change notification settings - Fork 0
/
MultiOrderMarkovChainTest.java
315 lines (261 loc) · 14.2 KB
/
MultiOrderMarkovChainTest.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
package net.joeclark.proceduralgeneration;
import static net.joeclark.proceduralgeneration.MultiOrderMarkovChain.DEFAULT_PRIOR;
import static org.junit.jupiter.api.Assertions.*;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedList;
import java.util.List;
import java.util.Objects;

import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.Nested;
import org.junit.jupiter.api.Test;
@DisplayName("MultiOrderMarkovChain...")
class MultiOrderMarkovChainTest {

    /**
     * Serializable element type used to check that the chain behaves the same with a
     * composite object as with boxed primitives.
     *
     * <p>Declared as a static nested class of the top-level test class. An inner class
     * would hold hidden references to its (non-serializable) enclosing test instances.
     * Pre-Java-16 compilers do not allow a static member class inside the non-static
     * {@code @Nested} class, which is why it lives at this level.
     */
    private static final class WeatherPattern implements Serializable {
        private static final long serialVersionUID = 1L;

        private final String condition;
        private final Integer temperature;
        private final Character windDirection;

        WeatherPattern(String condition, Integer temperature, Character windDirection) {
            this.condition = condition;
            this.temperature = temperature;
            this.windDirection = windDirection;
        }

        public String getCondition() { return condition; }
        public Integer getTemperature() { return temperature; }
        public Character getWindDirection() { return windDirection; }

        @Override
        public boolean equals(Object o) {
            if (this == o) return true;
            if (o == null || getClass() != o.getClass()) return false;
            WeatherPattern that = (WeatherPattern) o;
            return Objects.equals(condition, that.condition)
                    && Objects.equals(temperature, that.temperature)
                    && Objects.equals(windDirection, that.windDirection);
        }

        @Override
        public int hashCode() {
            return Objects.hash(condition, temperature, windDirection);
        }

        @Override
        public String toString() {
            return "{" + condition + "}";
        }
    }

    @Test
    @DisplayName("Can initialize a chain of Characters")
    void CanInitializeAChainOfCharacters() {
        MultiOrderMarkovChain<Character> chain = new MultiOrderMarkovChain<>();
        chain.addSequence(new ArrayList<>(Arrays.asList('V','W','X','Y','Z')));
        chain.addSequence(new ArrayList<>(Arrays.asList('X','Y','Z')));
        // Training should actually produce a model, not merely avoid throwing.
        assertTrue(chain.hasModel(), "training on Character sequences did not produce a model");
    }

    @Test
    @DisplayName("Can define a model by direct specification rather than training")
    void CanDefineAModelByDirectSpecification() {
        MultiOrderMarkovChain<Character> chain = new MultiOrderMarkovChain<>();
        chain.specifyLink(Arrays.asList('A'),'B',1.5D);
        // A second specification of the same link should overwrite, not accumulate.
        chain.specifyLink(Arrays.asList('A'),'B',1.25D);
        chain.specifyLink(Arrays.asList('B'),'C',1.75D);
        assertEquals(1.75D,chain.model.get(Arrays.asList('B')).get('C'),"specifyLink did not correctly specify a link");
        assertEquals(1.25D,chain.model.get(Arrays.asList('A')).get('B'),"specifyLink failed to correctly overwrite a prior specification");
    }

    @Test
    @DisplayName("Can initialize a chain of Integers")
    void CanInitializeAChainOfIntegers() {
        MultiOrderMarkovChain<Integer> chain = new MultiOrderMarkovChain<>();
        chain.addSequence(new ArrayList<>(Arrays.asList(1,1,2,3,5,8,13)));
        assertTrue(chain.hasModel(), "training on an Integer sequence did not produce a model");
    }

    @Test
    @DisplayName("Can initialize and train a chain of strings on string sequences")
    void CanTrainAChainOnStringSequences() {
        MultiOrderMarkovChain<String> chain = new MultiOrderMarkovChain<>();
        // Mix List implementations to show addSequence is not tied to ArrayList.
        chain.addSequence(new ArrayList<>(Arrays.asList("how","much","wood","would","a","woodchuck","chuck")));
        chain.addSequence(new LinkedList<>(Arrays.asList("if","a","woodchuck","could","chuck","wood")));
        assertTrue(chain.hasModel(), "training on String sequences did not produce a model");
    }

    @Test
    @DisplayName("Can set maxOrder with a builder-constructor function")
    void CanSetMaxOrderWithBuilderConstructor() {
        MultiOrderMarkovChain<Float> chain = new MultiOrderMarkovChain<Float>().withMaxOrder(2);
        assertEquals(2, chain.getMaxOrder(), "withMaxOrder did not set maxOrder");
        chain.addSequence(Arrays.asList(0.01F,3.14F,1.0F,1.414F,2.718F));
        assertTrue(chain.hasModel(), "training after withMaxOrder did not produce a model");
    }

    @Test
    @DisplayName("Can be trained on a Stream<List<T>>")
    void CanBeTrainedOnAStream() {
        List<List<Character>> trainingData = Arrays.asList(
                Arrays.asList('h','e','l','l','o'),
                Arrays.asList('w','o','r','l','d')
        );
        MultiOrderMarkovChain<Character> chain = new MultiOrderMarkovChain<Character>().andTrain(trainingData.stream());
        assertTrue(chain.knownStates.contains('h'),"didn't train states in first sequence");
        assertTrue(chain.knownStates.contains('d'),"didn't train states in last sequence");
        // 'l' is followed by 'l' and 'o' in "hello" and by 'd' in "world".
        assertTrue(chain.model.get(Arrays.asList('l')).keySet().containsAll(Arrays.asList('l','o','d')),"missed some transitions in the training data stream");
    }

    @Test
    @DisplayName("Short sequences in a training stream don't blow up a batch of training")
    void ShortSequencesDontBlowUpTraining() {
        List<List<Character>> trainingData = Arrays.asList(
                Arrays.asList('i'),
                Arrays.asList('l','i','k','e'),
                Arrays.asList('s','p','a','m')
        );
        MultiOrderMarkovChain<Character> chain = new MultiOrderMarkovChain<Character>().andTrain(trainingData.stream());
        assertEquals(2,chain.getNumTrainedSequences(),"two sequences should have been ingested, one ignored");
    }

    @Nested
    @DisplayName("When initialized but not trained...")
    class WhenUnTrained {
        private MultiOrderMarkovChain<Float> chain;

        @BeforeEach
        void InitializeButDontTrain() {
            chain = new MultiOrderMarkovChain<>();
        }

        @Test
        @DisplayName("hasModel() should return false")
        void HasModelShouldBeFalse() {
            assertFalse(chain.hasModel(),"hasModel() returned true despite model being empty/untrained");
        }

        @Test
        @DisplayName("allKnownStates should return empty set")
        void AllKnownStatesShouldBeEmpty() {
            assertTrue(chain.allKnownStates().isEmpty());
        }
    }

    @Nested
    @DisplayName("Once trained...")
    class OnceTrained {
        private MultiOrderMarkovChain<String> chain;

        @BeforeEach
        void trainChain() {
            chain = new MultiOrderMarkovChain<>();
            chain.addSequence( Arrays.asList("one","small","step","for","man") );
            chain.addSequence( Arrays.asList("one","giant","leap","for","mankind") );
        }

        /**
         * Checks the effect of {@code addPriors(prior)} on the freshly trained chain:
         * observed links keep their weight, unobserved links from a non-terminal state
         * get {@code prior}, and the terminal state "mankind" gets no outgoing model.
         */
        private void assertPriorsApplied(double prior) {
            assertEquals(1.0D,chain.model.get(Arrays.asList("one")).get("small"),"adding priors should not affect already-observed links");
            assertEquals(prior,chain.model.get(Arrays.asList("one")).get("step"),"priors should be added for unobserved links from non-terminal states");
            assertFalse(chain.model.containsKey(Arrays.asList("mankind")),"adding priors should not create new models for terminal states");
        }

        @Test
        @DisplayName("hasModel() should return true")
        void HasModelShouldBeTrue() {
            assertTrue(chain.hasModel(),"hasModel() returned false despite model being trained");
        }

        @Test
        @DisplayName("allKnownStates should include 'from' states")
        void AllKnownStatesShouldIncludeAllFromStates() {
            assertTrue(chain.allKnownStates().contains("one"));
        }

        @Test
        @DisplayName("allKnownStates should include states that are both 'from' and 'to' states")
        void AllKnownStatesShouldIncludeToAndFromStates() {
            assertTrue(chain.allKnownStates().contains("for"));
        }

        @Test
        @DisplayName("allKnownStates should include all 'to' states")
        void AllKnownStatesShouldIncludeAllToStates() {
            assertTrue(chain.allKnownStates().contains("mankind"));
        }

        @Test
        @DisplayName("unweightedRandomNext choice should be one of the trained links")
        void RandomNextShouldBeOneOfTheTrainedLinks() {
            String next = chain.unweightedRandomNext(Arrays.asList("one"));
            assertTrue(Arrays.asList("small","giant").contains(next), "unexpected next state: " + next);
        }

        @Test
        @DisplayName("weightedRandomNext choice should be one of the trained links (assuming priors haven't been added)")
        void WeightedRandomNextShouldBeOneOfTheTrainedLinks() {
            String next = chain.weightedRandomNext(Arrays.asList("one"));
            assertTrue(Arrays.asList("small","giant").contains(next), "unexpected next state: " + next);
        }

        @Test
        @DisplayName("unweightedRandomNext should throw exception if state doesn't exist")
        void RandomNextShouldThrowExceptionIfStateUnknown() {
            assertThrows(IllegalArgumentException.class,() -> chain.unweightedRandomNext(Arrays.asList("moon")));
        }

        @Test
        @DisplayName("unweightedRandomNext should throw exception if state has no 'to' links")
        void RandomNextShouldThrowExceptionIfStateHasNoToLinks() {
            assertThrows(IllegalStateException.class, () -> chain.unweightedRandomNext(Arrays.asList("mankind")));
        }

        @Test
        @DisplayName("allPossibleNext should throw exception if state doesn't exist")
        void AllPossibleNextShouldThrowExceptionIfStateUnknown() {
            assertThrows(IllegalArgumentException.class,() -> chain.allPossibleNext(Arrays.asList("moon")));
        }

        @Test
        @DisplayName("allPossibleNext should return empty set if state has no 'to' links")
        void AllPossibleNextShouldReturnEmptyIfStateHasNoToLinks() {
            assertTrue(chain.allPossibleNext(Arrays.asList("mankind")).isEmpty());
        }

        @Test
        @DisplayName("priors can be added for links that were not observed")
        void PriorsCanBeAdded() {
            chain.addPriors(0.005D);
            assertPriorsApplied(0.005D);
        }

        @Test
        @DisplayName("priors can be changed by calling removeWeakLinks and then addPriors")
        void PriorsCanBeChanged() {
            chain.addPriors(0.005D);
            assertPriorsApplied(0.005D);

            // Links weighted below 1.0 are exactly the priors just added, so this strips them.
            chain.removeWeakLinks(1D);
            assertFalse(chain.model.get(Arrays.asList("one")).containsKey("step"),"removeWeakLinks failed to remove a prior");
            assertTrue(chain.model.get(Arrays.asList("one")).containsKey("small"),"removeWeakLinks removed a link it shouldn't have removed");

            chain.addPriors(0.001D);
            assertEquals(0.001D,chain.model.get(Arrays.asList("one")).get("step"),"addPriors after removeWeakLinks did not successfully change the prior");
        }

        @Test
        @DisplayName("priors can be set and removed with shortcut functions")
        void PriorsCanBeSetAndRemovedWithShortcuts() {
            chain.addPriors();
            assertEquals(DEFAULT_PRIOR,chain.model.get(Arrays.asList("one")).get("step"),"prior was not set to DEFAULT_PRIOR by the shortcut addPriors()");
            chain.removeWeakLinks();
            assertFalse(chain.model.get(Arrays.asList("one")).containsKey("step"),"shortcut removeWeakLinks() failed to remove a prior");
        }

        @Test
        @DisplayName("Can be serialized and deserialized")
        void CanBeSerializedAndDeserialized() throws IOException, ClassNotFoundException {
            chain.setMaxOrder(2);

            // Round-trip through an in-memory buffer so the test does not depend on the working
            // directory or on a build-tool-specific "target/" folder existing. try-with-resources
            // guarantees the streams are closed even if (de)serialization throws.
            ByteArrayOutputStream serialized = new ByteArrayOutputStream();
            try (ObjectOutputStream objectOutputStream = new ObjectOutputStream(serialized)) {
                objectOutputStream.writeObject(chain);
            }

            MultiOrderMarkovChain<String> loadedChain;
            try (ObjectInputStream objectInputStream =
                         new ObjectInputStream(new ByteArrayInputStream(serialized.toByteArray()))) {
                @SuppressWarnings("unchecked") // we just wrote a MultiOrderMarkovChain<String>
                MultiOrderMarkovChain<String> read = (MultiOrderMarkovChain<String>) objectInputStream.readObject();
                loadedChain = read;
            }

            assertEquals(chain.getMaxOrder(),loadedChain.getMaxOrder(),"an int field was not preserved through serialization-deserialization");
            assertEquals(chain.model,loadedChain.model,"the markov model was not preserved through serialization-deserialization");
            assertEquals(chain.numTrainedSequences,loadedChain.numTrainedSequences,"numTrainedSequences was not preserved through serialization-deserialization");
            assertEquals(chain.knownStates,loadedChain.knownStates,"knownStates was not preserved through serialization-deserialization");
            // Both generators must be in the same state, so their next outputs must match.
            assertEquals(chain.random.nextInt(),loadedChain.random.nextInt(),"random was not preserved through serialization-deserialization");
            assertEquals(chain,loadedChain,"the serialized-deserialized chain is not .equals() to the original chain");
        }
    }

    @Nested
    @DisplayName("With a complex object type")
    class WithAComplexObjectType {

        @Test
        @DisplayName("MultiOrderMarkovChain should behave the same as with primitive types")
        void ShouldBehaveTheSame() {
            MultiOrderMarkovChain<WeatherPattern> weatherChain = new MultiOrderMarkovChain<>();
            WeatherPattern sunny = new WeatherPattern("sunny",75,'W');
            WeatherPattern cloudy = new WeatherPattern("cloudy",55,'N');
            WeatherPattern partlyCloudy = new WeatherPattern("partly cloudy",65,'S');
            WeatherPattern stormy = new WeatherPattern("stormy",50,'E');
            weatherChain.addSequence(Arrays.asList(sunny,partlyCloudy,cloudy,stormy,partlyCloudy,sunny));
            assertTrue(weatherChain.hasModel(), "training on WeatherPattern sequences did not produce a model");

            // "partly cloudy" is followed by "cloudy" once and by "sunny" once.
            WeatherPattern next = weatherChain.weightedRandomNext(Arrays.asList(partlyCloudy));
            assertTrue(Arrays.asList(cloudy,sunny).contains(next), "unexpected next state: " + next);
            // The order-2 context (sunny, partly cloudy) was only ever followed by cloudy.
            assertEquals(cloudy, weatherChain.weightedRandomNext(Arrays.asList(sunny,partlyCloudy)));
        }
    }
}