Replaced Map<Object, Object> with Map<Object, Double> when possible.
datumbox committed Dec 28, 2016
1 parent e5e6377 commit 81ebda1
Showing 4 changed files with 8 additions and 9 deletions.
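
To make the change easier to follow, here is a minimal, self-contained Java sketch of the pattern this commit applies. It is illustrative only: LooseWrapper, NarrowedValueTypeExample and the variable names are made up, standing in for AssociativeArray and the maps touched in the hunks below. Narrowing the value type from Object to Double lets the compiler check the stored values and removes casts at the read sites; a raw (Map) cast is only needed where a consumer still declares Map<Object, Object>.

import java.util.HashMap;
import java.util.Map;

// Illustrative only: LooseWrapper stands in for a consumer (such as a
// constructor) that is declared against the loosely typed Map<Object, Object>.
class LooseWrapper {
    private final Map<Object, Object> internalData;

    LooseWrapper(Map<Object, Object> internalData) {
        this.internalData = internalData;
    }

    Object get(Object key) {
        return internalData.get(key);
    }
}

public class NarrowedValueTypeExample {
    @SuppressWarnings("unchecked")
    public static void main(String[] args) {
        // Before: values typed as Object, so every read needs an explicit cast.
        Map<Object, Object> loose = new HashMap<>();
        loose.put("x1", 0.5);
        double a = (Double) loose.get("x1");

        // After: values typed as Double; reads are checked by the compiler.
        Map<Object, Double> narrowed = new HashMap<>();
        narrowed.put("x1", 0.5);
        double b = narrowed.get("x1");

        // Map<Object, Double> is not a subtype of Map<Object, Object>, so a raw
        // (Map) cast bridges the two where the loosely typed consumer is kept.
        LooseWrapper wrapper = new LooseWrapper((Map) narrowed);

        System.out.println(a + " " + b + " " + wrapper.get("x1"));
    }
}

The (Map)tmp_minClusterDistance and (Map)condProbCiGivenXiAndOtherCi casts added below follow the same shape: the local declaration gets the precise value type, and only the hand-off to AssociativeArray stays raw.
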
@@ -712,8 +712,8 @@ else if(initializationMethod==TrainingParameters.Initialization.PLUS_PLUS) {
StorageEngine storageEngine = knowledgeBase.getStorageEngine();
Set<Integer> alreadyAddedPoints = new HashSet(); //this is small. equal to k
for(int i = 0; i < k; ++i) {
- Map<Object, Object> tmp_minClusterDistance = storageEngine.getBigMap("tmp_minClusterDistance", Object.class, Object.class, MapType.HASHMAP, StorageHint.IN_MEMORY, true, true);
- AssociativeArray minClusterDistanceArray = new AssociativeArray(tmp_minClusterDistance);
+ Map<Object, Double> tmp_minClusterDistance = storageEngine.getBigMap("tmp_minClusterDistance", Object.class, Double.class, MapType.HASHMAP, StorageHint.IN_MEMORY, true, true);
+ AssociativeArray minClusterDistanceArray = new AssociativeArray((Map)tmp_minClusterDistance);

streamExecutor.forEach(StreamMethods.stream(trainingData.entries(), isParallelized()), e -> {
Integer rId = e.getKey();
@@ -739,7 +739,6 @@ else if(initializationMethod==TrainingParameters.Initialization.PLUS_PLUS) {
Integer selectedRecordId = (Integer) SimpleRandomSampling.weightedSampling(minClusterDistanceArray, 1, true).iterator().next();

storageEngine.dropBigMap("tmp_minClusterDistance", tmp_minClusterDistance);
- //minClusterDistanceArray = null;


alreadyAddedPoints.add(selectedRecordId);
@@ -31,10 +31,10 @@
import com.datumbox.framework.core.statistics.descriptivestatistics.Descriptives;
import com.datumbox.framework.core.statistics.sampling.SimpleRandomSampling;

+ import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
- import java.util.concurrent.ConcurrentHashMap;


/**
@@ -527,7 +527,7 @@ private int collapsedGibbsSampling(Dataframe dataset) {
}

private AssociativeArray clusterProbabilities(Record r, int n, Map<Integer, CL> clusterMap) {
- Map<Object, Object> condProbCiGivenXiAndOtherCi = new ConcurrentHashMap<>();
+ Map<Integer, Double> condProbCiGivenXiAndOtherCi = new HashMap<>();
double alpha = knowledgeBase.getTrainingParameters().getAlpha();

//Probabilities that appear on https://www.cs.cmu.edu/~kbe/dp_tutorial.pdf
@@ -543,7 +543,7 @@ private AssociativeArray clusterProbabilities(Record r, int n, Map<Integer, CL>
condProbCiGivenXiAndOtherCi.put(clusterId, marginalLogLikelihoodXi+Math.log(mixingXi)); //concurrent map and non-overlapping keys for each thread
}

- return new AssociativeArray(condProbCiGivenXiAndOtherCi);
+ return new AssociativeArray((Map)condProbCiGivenXiAndOtherCi);
}

private Object getSelectedClusterFromScores(AssociativeArray clusterScores) {
@@ -109,7 +109,7 @@ protected void _fit(Dataframe trainingData) {
for(Map.Entry<Object, Object> entry: r.getX().entrySet()) {
Object column = entry.getKey();
if(covert2dummy(columnTypes.get(column))) {
- referenceLevels.putIfAbsent(column, entry.getValue()); //This Map is an implementation of ConcurrentHashMap and we don't need a synchronized is needed.
+ referenceLevels.putIfAbsent(column, entry.getValue()); //This Map is thread safe, so no explicit synchronization is needed.
}
}
}
@@ -21,7 +21,7 @@
import org.junit.Test;

import java.util.Arrays;
- import java.util.concurrent.ConcurrentSkipListMap;
+ import java.util.HashMap;

import static org.junit.Assert.assertEquals;

@@ -40,7 +40,7 @@ public void testGetRanksFromValues() {
logger.info("getRanksFromValues");
FlatDataList flatDataCollection = new FlatDataList(Arrays.asList(new Object[]{50,10,10,30,40}));
FlatDataList expResult = new FlatDataList(Arrays.asList(new Object[]{5.0,1.5,1.5,3.0,4.0}));
- AssociativeArray expResult2 = new AssociativeArray(new ConcurrentSkipListMap<>());
+ AssociativeArray expResult2 = new AssociativeArray(new HashMap<>());
expResult2.put(10, 2);
AssociativeArray tiesCounter = Ranks.getRanksFromValues(flatDataCollection);
assertEquals(expResult, flatDataCollection);
