-
Notifications
You must be signed in to change notification settings - Fork 0
/
CommonCrawlSagemakerReader.java
86 lines (72 loc) · 4.72 KB
/
CommonCrawlSagemakerReader.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
package com.letsdata.commoncrawl.interfaces.implementations.sagemaker;
import com.letsdata.commoncrawl.interfaces.implementations.documents.CompositeIndexRecord;
import com.letsdata.commoncrawl.interfaces.implementations.documents.IndexRecord;
import com.letsdata.commoncrawl.interfaces.implementations.documents.VectorRecord;
import com.letsdata.commoncrawl.model.filerecords.warc.AbstractWARCRecord;
import com.resonance.letsdata.data.documents.interfaces.DocumentInterface;
import com.resonance.letsdata.data.readers.interfaces.sagemaker.SagemakerVectorsInterface;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.HashMap;
import java.util.Map;
/**
 * {@link SagemakerVectorsInterface} implementation for Common Crawl documents.
 *
 * <p>Reads the {@link IndexRecord} wrapped inside a {@link CompositeIndexRecord}. It extracts that
 * record's text elements for vectorization, and it builds a {@link VectorRecord} from the vectors
 * returned for those elements.
 */
public class CommonCrawlSagemakerReader implements SagemakerVectorsInterface {
    private static final Logger logger = LoggerFactory.getLogger(CommonCrawlSagemakerReader.class);

    /** Key for the document text, used both in the extracted elements map and in the vectors map. */
    private static final String DOC_TEXT_KEY = "DocText";

    /** Key for the document description, used both in the extracted elements map and in the vectors map. */
    private static final String DOC_DESCRIPTION_KEY = "DocDescription";

    /** Record type assigned to every constructed {@link VectorRecord}. */
    private static final String VECTOR_RECORD_TYPE = "VECTOR";

    /**
     * Document ids longer than this many chars are replaced by
     * {@code AbstractWARCRecord.computeMD5Hash(documentId)}.
     *
     * <p>Original note: the momento vector index key should be less than 256 bytes, and the 400
     * threshold allows for some testing. NOTE(review): the check counts UTF-16 chars, not bytes, and
     * 400 exceeds the stated 256-byte limit. Confirm the intended limit before tightening it, since
     * that would change the ids of already-stored vectors.
     */
    private static final int MAX_DOCUMENT_ID_LENGTH = 400;

    /**
     * Extracts the text elements of the wrapped {@link IndexRecord} that should be vectorized.
     *
     * @param documentInterface the document to read; must be a {@link CompositeIndexRecord} whose
     *     document list key is an {@link IndexRecord}
     * @return a new mutable map containing {@code "DocText"} and/or {@code "DocDescription"}; each
     *     key is present only when the corresponding value on the index record is non-null
     * @throws IllegalArgumentException if {@code documentInterface} is null, is not a
     *     {@link CompositeIndexRecord}, or does not wrap an {@link IndexRecord}
     * @throws RuntimeException if reading the index record's fields fails
     */
    @Override
    public Map<String, String> extractDocumentElementsForVectorization(DocumentInterface documentInterface) {
        final String method = "extractDocumentElementsForVectorization";
        if (documentInterface == null) {
            throw invalidArgument(method + " - documentInterface is null");
        }
        IndexRecord indexRecord = extractIndexRecord(method, documentInterface);
        try {
            Map<String, String> docMap = new HashMap<>();
            String docText = indexRecord.getDocText();
            if (docText != null) {
                docMap.put(DOC_TEXT_KEY, docText);
            }
            String description = indexRecord.getDescription();
            if (description != null) {
                docMap.put(DOC_DESCRIPTION_KEY, description);
            }
            return docMap;
        } catch (Exception ex) {
            // The throwable is the trailing argument with no placeholder, so SLF4J logs its stack trace.
            logger.error("extractDocumentElementsForVectorization threw an exception - doc: {}", documentInterface, ex);
            throw new RuntimeException("extractDocumentElementsForVectorization threw an exception", ex);
        }
    }

    /**
     * Builds a {@link VectorRecord} for the wrapped {@link IndexRecord} from the computed vectors.
     *
     * <p>The record's url and document metadata are copied from the index record. Its document id is
     * copied too, but ids longer than {@link #MAX_DOCUMENT_ID_LENGTH} chars are replaced with
     * {@code AbstractWARCRecord.computeMD5Hash(documentId)}.
     *
     * @param documentInterface the source document; must be a {@link CompositeIndexRecord} whose
     *     document list key is an {@link IndexRecord}
     * @param vectorsMap vectors keyed by {@code "DocText"} / {@code "DocDescription"}; either entry
     *     may be absent, in which case {@code null} is passed to the {@link VectorRecord}
     * @return the constructed {@link VectorRecord} with record type {@code "VECTOR"}
     * @throws IllegalArgumentException if {@code documentInterface} is null or of the wrong shape,
     *     if {@code vectorsMap} is null or empty, or if the index record has no document id
     */
    @Override
    public DocumentInterface constructVectorDoc(DocumentInterface documentInterface, Map<String, Double[]> vectorsMap) {
        final String method = "constructVectorDoc";
        if (documentInterface == null) {
            throw invalidArgument(method + " - documentInterface is null");
        }
        if (vectorsMap == null || vectorsMap.isEmpty()) {
            throw invalidArgument(method + " - vectorsMap is null or empty");
        }
        IndexRecord indexRecord = extractIndexRecord(method, documentInterface);
        Double[] docTextVectors = vectorsMap.get(DOC_TEXT_KEY);
        Double[] docDescriptionVectors = vectorsMap.get(DOC_DESCRIPTION_KEY);

        String documentId = indexRecord.getDocumentId();
        if (documentId == null) {
            // Previously this surfaced as an NPE on documentId.length(); fail with a descriptive error instead.
            throw invalidArgument(method + " - documentId is null in IndexRecord");
        }
        if (documentId.length() > MAX_DOCUMENT_ID_LENGTH) {
            documentId = AbstractWARCRecord.computeMD5Hash(documentId);
        }
        return new VectorRecord(indexRecord.getUrl(), documentId, VECTOR_RECORD_TYPE,
                indexRecord.getDocumentMetadata(), docTextVectors, docDescriptionVectors);
    }

    /**
     * Returns the {@link IndexRecord} that is the key of a {@link CompositeIndexRecord}'s document list.
     *
     * @param method caller name, used as the prefix of log/exception messages
     * @param documentInterface a non-null document to unwrap
     * @return the wrapped index record
     * @throws IllegalArgumentException if the document is not a {@link CompositeIndexRecord}, or its
     *     document list is missing or does not have an {@link IndexRecord} key
     */
    private static IndexRecord extractIndexRecord(String method, DocumentInterface documentInterface) {
        if (!(documentInterface instanceof CompositeIndexRecord)) {
            throw invalidArgument(method + " - documentInterface is expected to be of type CompositeIndexRecord");
        }
        CompositeIndexRecord compositeIndexRecord = (CompositeIndexRecord) documentInterface;
        // A missing document list is reported like a missing key rather than as an NPE.
        Object key = compositeIndexRecord.getDocumentList() == null
                ? null
                : compositeIndexRecord.getDocumentList().getKey();
        if (!(key instanceof IndexRecord)) {
            throw invalidArgument(method + " - indexRecord not found in CompositeIndexRecord");
        }
        return (IndexRecord) key;
    }

    /**
     * Logs {@code message} at error level and returns an exception carrying the same message.
     * It returns rather than throws so that call sites can write {@code throw invalidArgument(...)}.
     */
    private static IllegalArgumentException invalidArgument(String message) {
        logger.error(message);
        return new IllegalArgumentException(message);
    }
}