Remove id from Record
datumbox committed Apr 5, 2015
1 parent 7e3b70b commit 6d30b8c
Showing 51 changed files with 337 additions and 281 deletions.
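Every file below follows the same migration: Dataset now implements Iterable&lt;Integer&gt; and yields record ids instead of Record objects, so each former for(Record r : dataset) loop becomes an id loop plus a get() lookup, and Record itself loses its id field. A minimal before/after sketch of the pattern (process() is a hypothetical placeholder, not a method from the codebase):

    // Before this commit: Dataset was Iterable<Record> and each Record carried its own id.
    for (Record r : dataset) {
        process(r.getId(), r);
    }

    // After this commit: Dataset is Iterable<Integer>; the id is the loop variable
    // and the Record is looked up through Dataset.get(id).
    for (Integer id : dataset) {
        Record r = dataset.get(id);
        process(id, r);
    }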
8 changes: 5 additions & 3 deletions src/main/java/com/datumbox/applications/nlp/CETR.java
@@ -141,7 +141,8 @@ private List<Integer> selectRows(List<String> rows, Parameters parameters) {

Map<Object, Double> avgTTRscorePerCluster = new HashMap<>();
Map<Object, Integer> clusterCounts = new HashMap<>();
-for(Record r : dataset) {
+for(Integer rId : dataset) {
+Record r = dataset.get(rId);
Integer clusterId = (Integer)r.getYPredicted();
Double ttr = r.getX().getDouble(0); //the first value is always set the TTR as you can see above

@@ -170,11 +171,12 @@ private List<Integer> selectRows(List<String> rows, Parameters parameters) {
Integer nonContentClusterId = (Integer)entry.getKey();

List<Integer> selectedRows = new ArrayList<>();
-for(Record r : dataset) {
+for(Integer rId : dataset) {
+Record r = dataset.get(rId);
Integer clusterId = (Integer)r.getYPredicted();
//if the point is not classified as non-content add it in the selected list
if(!Objects.equals(clusterId, nonContentClusterId)) {
-selectedRows.add(r.getId());
+selectedRows.add(rId);
}
}

@@ -195,7 +195,8 @@ public List<Object> predict(List<String> text) {

//extract responses
List<Object> predictedClasses = new LinkedList<>();
-for(Record r : predictedDataset) {
+for(Integer rId : predictedDataset) {
+Record r = predictedDataset.get(rId);
predictedClasses.add(r.getYPredicted());
}
predictedDataset = null;
@@ -224,7 +225,8 @@ public List<AssociativeArray> predictProbabilities(List<String> text) {

//extract responses
List<AssociativeArray> predictedClassProbabilities = new LinkedList<>();
-for(Record r : predictedDataset) {
+for(Integer rId : predictedDataset) {
+Record r = predictedDataset.get(rId);
predictedClassProbabilities.add(r.getYPredictedProbabilities());
}
predictedDataset = null;
49 changes: 23 additions & 26 deletions src/main/java/com/datumbox/common/dataobjects/Dataset.java
@@ -27,7 +27,7 @@
*
* @author bbriniotis
*/
-public final class Dataset implements Serializable, Iterable<Record> {
+public final class Dataset implements Serializable, Iterable<Integer> {

private final Map<Integer, Record> recordList;

@@ -118,7 +118,8 @@ public boolean isEmpty() {
public FlatDataList extractColumnValues(Object column) {
FlatDataList flatDataList = new FlatDataList();

-for(Record r : recordList.values()) {
+for(Integer rId : this) {
+Record r = recordList.get(rId);
flatDataList.add(r.getX().get(column));
}

@@ -134,7 +135,8 @@ public FlatDataList extractColumnValues(Object column) {
public FlatDataList extractYValues() {
FlatDataList flatDataList = new FlatDataList();

-for(Record r : recordList.values()) {
+for(Integer rId : this) {
+Record r = recordList.get(rId);
flatDataList.add(r.getY());
}

@@ -153,7 +155,8 @@ public FlatDataList extractYValues() {
public TransposeDataList extractColumnValuesByY(Object column) {
TransposeDataList transposeDataList = new TransposeDataList();

-for(Record r : recordList.values()) {
+for(Integer rId : this) {
+Record r = recordList.get(rId);
if(!transposeDataList.containsKey(r.getY())) {
transposeDataList.put(r.getY(), new FlatDataList(new ArrayList<>()) );
}
@@ -199,7 +202,8 @@ public Record get(Integer id) {
*/
public boolean removeColumn(Object column) {
if(columns.remove(column)!=null) { //try to remove it from the columns and it if it removed remove it from the list too
-for(Record r : recordList.values()) {
+for(Integer rId : this) {
+Record r = recordList.get(rId);
r.getX().remove(column); //TODO: do we need to store the record again in the map?
}

@@ -227,8 +231,8 @@ private void updateMeta(Record r) {

public void resetMeta() {
columns.clear();
-for(Record r: this) {
-updateMeta(r);
+for(Integer id: this) {
+updateMeta(recordList.get(id));
}
}

@@ -238,30 +242,23 @@ public void resetMeta() {
* @param d
*/
public void merge(Dataset d) {
-//does not modify the ids of the records of the Dataset d
-for(Record r : d) {
-this.add(r);
+for(Integer id : d) {
+this.add(d.get(id));
}
//TODO: do we still need merge after changing the PCA algorithm and the Dataset?
}

/**
-* Adds the record in the dataset. The original record is shallow copied
-* and its id is updated (this does not affect the id of the original record).
-* The add method returns the id of the new record.
+* Adds the record in the dataset. The add method returns the id of the new record.
*
-* @param original
+* @param r
* @return
*/
-public Integer add(Record original) {
-Record newRecord = original.quickCopy();
-
+public Integer add(Record r) {
Integer newId=(Integer) recordList.size();
-newRecord.setId(newId);
-recordList.put(newId, newRecord);
-updateMeta(newRecord);
+recordList.put(newId, r);
+updateMeta(r);

-return newRecord.getId();
+return newId;
}

/**
@@ -279,8 +276,8 @@ public void clear() {
* @return
*/
@Override
-public Iterator<Record> iterator() {
-return new Iterator<Record>() {
+public Iterator<Integer> iterator() {
+return new Iterator<Integer>() {
private Iterator<Integer> it = recordList.keySet().iterator();

@Override
@@ -289,8 +286,8 @@ public boolean hasNext() {
}

@Override
-public Record next() {
-return recordList.get(it.next());
+public Integer next() {
+return it.next();
}

@Override
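A note on the Dataset.java hunks above: add() no longer quick-copies the incoming Record or stamps an id onto it; it stores the caller's instance directly and returns a fresh id equal to recordList.size(), so ids stay contiguous from 0 in an append-only dataset. This is also why merge() now passes d.get(id) straight through: the two datasets end up sharing Record instances instead of holding copies. A hedged sketch of the new contract (demoAdd() and the values are illustrative, not from the repository):

    void demoAdd(Dataset dataset) {
        Record r = Record.newDataVector(new Object[]{1.0, 2.0}, "someLabel");
        Integer id = dataset.add(r);   // id equals the record count before the add
        assert dataset.get(id) == r;   // same instance: add() no longer copies
    }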
14 changes: 8 additions & 6 deletions src/main/java/com/datumbox/common/dataobjects/MatrixDataset.java
@@ -79,7 +79,7 @@ public static MatrixDataset newInstance(Dataset dataset, boolean addConstantColu
return m;
}

-boolean extractY=(Dataset.value2ColumnType(dataset.iterator().next().getY())==Dataset.ColumnType.NUMERICAL);
+boolean extractY=(Dataset.value2ColumnType(dataset.get(dataset.iterator().next()).getY())==Dataset.ColumnType.NUMERICAL);

int previousFeatureId=0;
if(addConstantColumn) {
@@ -90,8 +90,9 @@ public static MatrixDataset newInstance(Dataset dataset, boolean addConstantColu
++previousFeatureId;
}

-for(Record r : dataset) {
-int row = r.getId();
+for(Integer id : dataset) {
+Record r = dataset.get(id);
+int row = id;

if(extractY) {
m.Y.setEntry(row, TypeConversions.toDouble(r.getY()));
@@ -143,12 +144,13 @@ public static MatrixDataset parseDataset(Dataset newDataset, Map<Object, Integer
return m;
}

-boolean extractY=(Dataset.value2ColumnType(newDataset.iterator().next().getY())==Dataset.ColumnType.NUMERICAL);
+boolean extractY=(Dataset.value2ColumnType(newDataset.get(newDataset.iterator().next()).getY())==Dataset.ColumnType.NUMERICAL);

boolean addConstantColumn = m.feature2ColumnId.containsKey(Dataset.constantColumnName);

-for(Record r : newDataset) {
-int row = r.getId();
+for(Integer id : newDataset) {
+Record r = newDataset.get(id);
+int row = id;

if(extractY) {
m.Y.setEntry(row, TypeConversions.toDouble(r.getY()));
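The int row = id assignment in both MatrixDataset hunks leans on the contiguity noted above: because add() hands out ids as recordList.size(), the ids form the range 0..n-1 and can double as 0-based matrix row indices. A sketch of the invariant being relied on (the matrix write is commented out because it simply mirrors the hunks above):

    // Assumes the Dataset was filled only through add(), so ids run 0..n-1.
    for (Integer id : dataset) {
        Record r = dataset.get(id);
        int row = id;  // the record id doubles as the matrix row index
        // m.Y.setEntry(row, TypeConversions.toDouble(r.getY()));
    }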
46 changes: 10 additions & 36 deletions src/main/java/com/datumbox/common/dataobjects/Record.java
@@ -24,8 +24,6 @@
* @author bbriniotis
*/
public final class Record implements Serializable {
-/* The numeric id of the Record. Used for identification perposes. */
-private Integer id;

/* The X vector of the Record */
private AssociativeArray x;
@@ -53,27 +51,6 @@ public static <T> Record newDataVector(T[] xArray, Object y) {
return r;
}

-public Record quickCopy() {
-//shallow copy of Record. It is used in order to avoid modifying the ids of the records when assigned to different datasets
-Record r = new Record();
-//r.id = id;
-r.x = x; //shallow copies
-r.y = y;
-r.yPredicted = yPredicted;
-r.yPredictedProbabilities = yPredictedProbabilities; //shallow copies
-
-return r;
-}
-
-
-// Getters and Setters
-public Integer getId() {
-return id;
-}
-
-protected void setId(Integer id) {
-this.id = id;
-}

public AssociativeArray getX() {
return x;
@@ -106,9 +83,15 @@ public AssociativeArray getYPredictedProbabilities() {
public void setYPredictedProbabilities(AssociativeArray yPredictedProbabilities) {
this.yPredictedProbabilities = yPredictedProbabilities;
}


-// Internal methods
+@Override
+public int hashCode() {
+int hash = 7;
+hash = 23 * hash + Objects.hashCode(this.x);
+hash = 23 * hash + Objects.hashCode(this.y);
+return hash;
+}

@Override
public boolean equals(Object obj) {
if (obj == null) {
@@ -118,23 +101,14 @@ public boolean equals(Object obj) {
return false;
}
final Record other = (Record) obj;
-if (!Objects.equals(this.id, other.id)) {
+if (!Objects.equals(this.x, other.x)) {
return false;
}
if (!Objects.equals(this.y, other.y)) {
return false;
}
-if (this.x.equals(other.x)) {
-return false;
-}
return true;
}

-@Override
-public int hashCode() {
-int hash = 7;
-hash = 59 * hash + Objects.hashCode(this.id);
-return hash;
-}

}
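With the id field gone, Record identity becomes value-based: the new hashCode() folds in x and y, and equals() compares them via Objects.equals, which also fixes the old inverted check if (this.x.equals(other.x)) { return false; } that rejected records precisely when their x vectors matched. Note that yPredicted and yPredictedProbabilities take no part in equality, so a record compares the same before and after prediction. One consequence, sketched under the assumption that AssociativeArray implements value equality (the values are illustrative):

    Record a = Record.newDataVector(new Double[]{1.0, 2.0}, "spam");
    Record b = Record.newDataVector(new Double[]{1.0, 2.0}, "spam");
    assert a.equals(b);                  // equal by value; no id left to tell them apart
    assert a.hashCode() == b.hashCode(); // consistent with equals()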
@@ -96,7 +96,8 @@ protected void predictDataset(Dataset newData) {

Map<Object, Double> cachedLogPriors = new HashMap<>(logPriors); //this is small. Size equal to class numbers. We cache it because we don't want to load it again and again from the DB

-for(Record r : newData) {
+for(Integer id : newData) {
+Record r = newData.get(id);
//Build new map here! reinitialize the prediction scores with the scores of the classes
AssociativeArray predictionScores = new AssociativeArray(new HashMap<>(cachedLogPriors));

@@ -178,7 +179,8 @@ protected void _fit(Dataset trainingData) {

//calculate first statistics about the classes
AssociativeArray totalFeatureOccurrencesForEachClass = new AssociativeArray();
-for(Record r : trainingData) {
+for(Integer rId : trainingData) {
+Record r = trainingData.get(rId);
Object theClass=r.getY();

Double classCount = logPriors.get(theClass);
@@ -195,7 +197,8 @@ }
}

//now calculate the statistics of features
-for(Record r : trainingData) {
+for(Integer rId : trainingData) {
+Record r = trainingData.get(rId);

//store the occurrances of the features
for(Map.Entry<Object, Object> entry : r.getX().entrySet()) {
@@ -100,7 +100,8 @@ public MaximumEntropy(String dbName, DatabaseConfiguration dbConf) {
protected void predictDataset(Dataset newData) {
Set<Object> classesSet = knowledgeBase.getModelParameters().getClasses();

-for(Record r : newData) {
+for(Integer rId : newData) {
+Record r = newData.get(rId);
AssociativeArray predictionScores = new AssociativeArray();
for(Object theClass : classesSet) {
predictionScores.put(theClass, calculateClassScore(r.getX(),theClass));
@@ -134,7 +135,8 @@ protected void _fit(Dataset trainingData) {
Set<Object> classesSet = modelParameters.getClasses();

//first we need to find all the classes
-for(Record r : trainingData) {
+for(Integer rId : trainingData) {
+Record r = trainingData.get(rId);
Object theClass=r.getY();

classesSet.add(theClass);
@@ -152,7 +154,8 @@
double increment = 1.0/n; //this is done for speed reasons. We don't want to repeat the same division over and over

//then we calculate the observed probabilities in training set
-for(Record r : trainingData) {
+for(Integer rId : trainingData) {
+Record r = trainingData.get(rId);
int activeFeatures=0; //counts the number of non-zero (active) features of the record

//store the occurrances of the features
@@ -231,7 +234,8 @@ private void IIS(Dataset trainingData, Map<List<Object>, Double> EpFj_observed,
}

//calculate the model probabilities
-for(Record r : trainingData) {
+for(Integer rId : trainingData) {
+Record r = trainingData.get(rId);

AssociativeArray classScores = new AssociativeArray();

