Skip to content

Commit

Permalink
Merge branch 'master' of https://github.com/iorixxx/RecSys-MPD
Browse files Browse the repository at this point in the history
  • Loading branch information
aliyurekli committed Oct 13, 2018
2 parents a3f10b3 + 9cf9bbd commit 2d3b9cc
Show file tree
Hide file tree
Showing 6 changed files with 63 additions and 181 deletions.
15 changes: 2 additions & 13 deletions pom.xml
Expand Up @@ -13,17 +13,6 @@
<lucene.version>7.3.0</lucene.version>
</properties>

<repositories>
<repository>
<id>apache.snapshots</id>
<name>Apache Snapshot Repository</name>
<url>https://repository.apache.org/content/repositories/snapshots</url>
<releases>
<enabled>false</enabled>
</releases>
</repository>
</repositories>

<build>
<sourceDirectory>src/main/java</sourceDirectory>
<testSourceDirectory>src/test/java</testSourceDirectory>
Expand Down Expand Up @@ -55,7 +44,7 @@
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-jar-plugin</artifactId>
<version>3.0.2</version>
<version>3.1.0</version>
<configuration>
<archive>
<manifest>
Expand All @@ -69,7 +58,7 @@
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-shade-plugin</artifactId>
<version>3.1.1</version>
<version>3.2.0</version>
<executions>
<execution>
<phase>package</phase>
Expand Down
210 changes: 45 additions & 165 deletions src/main/java/edu/anadolu/BestSearcher.java
Expand Up @@ -6,17 +6,12 @@
import org.apache.lucene.document.Document;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.Term;
import org.apache.lucene.queryparser.classic.ParseException;
import org.apache.lucene.queryparser.classic.QueryParser;
import org.apache.lucene.queryparser.classic.QueryParserBase;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.spans.SpanFirstQuery;
import org.apache.lucene.search.spans.SpanNearQuery;
import org.apache.lucene.search.spans.SpanQuery;
import org.apache.lucene.search.spans.SpanTermQuery;
import org.apache.lucene.store.FSDirectory;

import java.io.BufferedReader;
Expand All @@ -28,6 +23,7 @@
import java.nio.file.Path;
import java.util.*;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Function;
import java.util.regex.Pattern;

/**
Expand All @@ -39,26 +35,23 @@ public class BestSearcher implements Closeable {

private final MPD challenge;
private final IndexReader reader;
private final IndexSearcher searcher;

private final AtomicReference<PrintWriter> out;
private final SimilarityConfig similarityConfig;

private final Integer maxPlaylist;

private final Integer maxTrack;

private final CustomSorter sorter;

public BestSearcher(Path indexPath, Path challengePath, Path resultPath, SimilarityConfig similarityConfig, Integer maxPlaylist, Integer maxTrack, CustomSorterConfig sorterConfig) throws Exception {
public BestSearcher(Path indexPath, Path challengePath, SimilarityConfig similarityConfig, Integer maxPlaylist, Integer maxTrack, CustomSorterConfig sorterConfig) throws Exception {
if (!Files.exists(indexPath) || !Files.isDirectory(indexPath) || !Files.isReadable(indexPath)) {
throw new IllegalArgumentException(indexPath + " does not exist or is not a directory.");
}

final Gson GSON = new Gson();

this.reader = DirectoryReader.open(FSDirectory.open(indexPath));
this.searcher = new IndexSearcher(reader);
this.out = new AtomicReference<>(new PrintWriter(Files.newBufferedWriter(resultPath, StandardCharsets.US_ASCII)));
this.maxPlaylist = maxPlaylist;
this.maxTrack = maxTrack;
this.sorter = sorterConfig.getCustomSorter();
Expand All @@ -67,32 +60,18 @@ public BestSearcher(Path indexPath, Path challengePath, Path resultPath, Similar
this.challenge = GSON.fromJson(reader, MPD.class);
}

this.searcher.setSimilarity(similarityConfig.getSimilarity());
this.similarityConfig = similarityConfig;
}

public void search() {
Arrays.stream(this.challenge.playlists).parallel().forEach(playlist -> {

try {
tracksOnly(playlist.tracks, playlist.pid);
public void search(Path resultPath) throws IOException {

/*
HashSet<String> results = new LinkedHashSet<>();
if (playlist.tracks.length == 0) {
titleOnly(playlist.name, playlist.pid, results);
}
else {
if (playlist.isSequential()) {
firstNTracks(playlist.tracks, playlist.pid, results);
}
else {
tracksOnly(playlist.tracks, playlist.pid);
}
}
final AtomicReference<PrintWriter> out = new AtomicReference<>(new PrintWriter(Files.newBufferedWriter(resultPath, StandardCharsets.US_ASCII)));

results.clear();
*/
Arrays.stream(this.challenge.playlists).parallel().forEach(playlist -> {
try {
tracksOnly(playlist.tracks, playlist.pid, out, Track::track_uri, "track_uris");
// tracksOnly(playlist.tracks, playlist.pid, out, Track::artist_uri, "artist_uris");
// tracksOnly(playlist.tracks, playlist.pid, out, Track::album_uri, "album_uris");
} catch (IOException | ParseException e) {
throw new RuntimeException(e);
}
Expand All @@ -107,52 +86,9 @@ public void close() throws IOException {
reader.close();
}

/**
* Predict tracks for a playlist given its title only.
*/
private void titleOnly(String title, int playlistID, HashSet<String> results) throws ParseException, IOException {

QueryParser queryParser = new QueryParser("name", Indexer.icu());
queryParser.setDefaultOperator(QueryParser.Operator.AND);

Query query = queryParser.parse(QueryParserBase.escape(title));

ScoreDoc[] hits = searcher.search(query, maxPlaylist).scoreDocs;

/*
* Try with OR operator, relaxed mode.
*/
if (hits.length == 0) {
queryParser.setDefaultOperator(QueryParser.Operator.OR);

hits = searcher.search(query, maxPlaylist).scoreDocs;
}

for (ScoreDoc hit : hits) {
int docID = hit.doc, pos = -1;
private void tracksOnly(Track[] tracks, int playlistID, AtomicReference<PrintWriter> out, Function<Track, String> map, String field) throws ParseException, IOException {

Document doc = searcher.doc(docID);

if (Integer.parseInt(doc.get("id")) == playlistID) continue;

String[] trackURIs = whiteSpaceSplitter.split(doc.get("track_uris"));

for (String trackURI : trackURIs) {
if (!results.contains(trackURI)) {
if (results.size() < maxTrack) {
pos++;
results.add(trackURI);

//export(playlistID, trackURI, hit.score, trackURIs.length - pos);
} else break;
}
}
}
}

private void tracksOnly(Track[] tracks, int playlistID) throws ParseException, IOException {

QueryParser queryParser = new QueryParser("track_uris", new WhitespaceAnalyzer());
QueryParser queryParser = new QueryParser(field, new WhitespaceAnalyzer());
queryParser.setDefaultOperator(QueryParser.Operator.OR);

HashSet<String> seeds = new HashSet<>(100);
Expand All @@ -164,7 +100,7 @@ private void tracksOnly(Track[] tracks, int playlistID) throws ParseException, I

if (seeds.contains(trackURI)) continue;

builder.append(trackURI).append(' ');
builder.append(map.apply(track)).append(' ');

seeds.add(trackURI);
}
Expand All @@ -173,8 +109,16 @@ private void tracksOnly(Track[] tracks, int playlistID) throws ParseException, I

Query query = queryParser.parse(QueryParserBase.escape(builder.toString().trim()));

IndexSearcher searcher = new IndexSearcher(reader);

searcher.setSimilarity(similarityConfig.getSimilarity());

ScoreDoc[] hits = searcher.search(query, maxPlaylist).scoreDocs;

if (hits.length == 0) {
//TODO handle such cases
}

for (ScoreDoc hit : hits) {
int docID = hit.doc;

Expand All @@ -195,111 +139,47 @@ private void tracksOnly(Track[] tracks, int playlistID) throws ParseException, I
if (rt.maxScore < hit.score) {
rt.maxScore = hit.score;
rt.pos = pos;
rt.playlistId = Integer.parseInt(doc.get("id"));
rt.luceneId = hit.doc;
}

recommendations.putIfAbsent(trackURI, rt);
}

pos --;
pos--;
}
}

List<RecommendedTrack> recommendedTracks = new ArrayList<>(recommendations.values());

sorter.sort(recommendedTracks);

int count = 0;

for (RecommendedTrack rt : recommendedTracks) {
count++;

export(playlistID, rt);

if (count == maxTrack)
break;
}
if (recommendedTracks.size() > maxTrack)
export(playlistID, recommendedTracks.subList(0, maxTrack), out.get());
else
export(playlistID, recommendedTracks, out.get());

System.out.println("Tracks only search for pid: " + playlistID);
recommendedTracks.clear();
}

/**
* Predict tracks for a playlist given its first N tracks, where N can equal 1, 5, 10, 25, or 100.
*/
private void firstNTracks(Track[] tracks, int playlistID, HashSet<String> results) throws ParseException, IOException {

LinkedHashSet<String> seeds = new LinkedHashSet<>(100);

ArrayList<SpanQuery> clauses = new ArrayList<>(tracks.length);


for (Track track : tracks) {

// skip duplicate tracks in the playlist. Only consider the first occurrence of the track.
if (seeds.contains(track.track_uri)) continue;

seeds.add(track.track_uri);
clauses.add(new SpanTermQuery(new Term("track_uris", track.track_uri)));
}

//TODO try to figure out n from tracks.length

final int n;
if (tracks.length < 6)
n = tracks.length + 2; // for n=1 and n=5 use 2 and 7
else if (tracks.length < 26)
n = (int) (tracks.length * 1.5); // for n=10 and n=25 use 15 and 37
else
n = (int) (tracks.length * 1.25); // for n=100 use 125

//TODO experiment with SpanOrQuery or SpanNearQuery. Which one performs better?
final SpanFirstQuery
spanFirstQuery = new SpanFirstQuery(
clauses.size() == 1 ? clauses.get(0) : new SpanNearQuery(clauses.toArray(new SpanQuery[clauses.size()]), 0, true), n);


// todo ScoreDoc[] hits = searcher.search(spanFirstQuery, Integer.MAX_VALUE).scoreDocs;
ScoreDoc[] hits = searcher.search(spanFirstQuery, maxPlaylist).scoreDocs;
private void album(List<RecommendedTrack> recommendedTracks) {

if (hits.length == 0) {
System.out.println("SpanFirst found zero result found for playlistID : " + playlistID + " first " + tracks.length);
tracksOnly(tracks, playlistID);
}

for (ScoreDoc hit : hits) {
int docID = hit.doc, pos = -1;

Document doc = searcher.doc(docID);

if (Integer.parseInt(doc.get("id")) == playlistID) continue;

String[] trackURIs = whiteSpaceSplitter.split(doc.get("track_uris"));

for (String trackURI : trackURIs) {
if (!results.contains(trackURI)) {
if (results.size() < maxTrack) {
pos++;
results.add(trackURI);

//export(playlistID, trackURI, hit.score, trackURIs.length - pos);
} else break;
}
}
}

seeds.clear();
clauses.clear();
}

private synchronized void export(int playlistID, RecommendedTrack track) {
out.get().print(playlistID);
out.get().print(",");
out.get().print(track.trackURI);
out.get().print(",");
out.get().print(track.searchResultFrequency);
out.get().print(",");
out.get().print(track.maxScore);
out.get().print(",");
out.get().print(track.pos);
out.get().println();

/**
 * Writes one comma-separated line per recommended track to {@code out}.
 *
 * <p>Each line has the form
 * {@code playlistID,trackURI,searchResultFrequency,maxScore,pos}.
 * The method is {@code static synchronized}, so concurrent callers are
 * serialized on the class monitor and the lines of one call are not
 * interleaved with those of another call to this method.
 *
 * @param playlistID id of the playlist the recommendations belong to
 * @param tracks     recommendations to write, emitted in list order
 * @param out        destination writer; it is neither flushed nor closed here
 */
private static synchronized void export(int playlistID, List<RecommendedTrack> tracks, PrintWriter out) {
    for (RecommendedTrack recommendation : tracks) {
        // String concatenation converts each value exactly like PrintWriter.print would.
        final String line = playlistID
                + "," + recommendation.trackURI
                + "," + recommendation.searchResultFrequency
                + "," + recommendation.maxScore
                + "," + recommendation.pos;
        out.println(line);
    }
}
}
1 change: 0 additions & 1 deletion src/main/java/edu/anadolu/Indexer.java
Expand Up @@ -125,7 +125,6 @@ public int index(Path indexPath, Path mpdPath) throws IOException {
document.add(new TextField("track_uris", builder.toString().trim(), Field.Store.YES));
document.add(new TextField("album_uris", album.toString().trim(), Field.Store.NO));
document.add(new TextField("artist_uris", artist.toString().trim(), Field.Store.NO));
document.add(new TextField(ShingleFilter.DEFAULT_TOKEN_TYPE, builder.toString().trim(), Field.Store.NO));
seeds.clear();
writer.addDocument(document);
}
Expand Down
3 changes: 3 additions & 0 deletions src/main/java/edu/anadolu/RecommendedTrack.java
Expand Up @@ -12,6 +12,9 @@ public class RecommendedTrack {

int pos;

int playlistId;
int luceneId;

RecommendedTrack(String trackURI) {
this.trackURI = trackURI;
this.searchResultFrequency = 0;
Expand Down
11 changes: 11 additions & 0 deletions src/main/java/edu/anadolu/Track.java
Expand Up @@ -9,4 +9,15 @@ public class Track {
String album_uri;
int duration_ms;

/**
 * Accessor for the {@code track_uri} field, so the value can be obtained
 * through a method reference such as {@code Track::track_uri}.
 *
 * @return the current value of {@code track_uri}
 */
String track_uri() {
    return this.track_uri;
}

/**
 * Accessor for the {@code artist_uri} field, so the value can be obtained
 * through a method reference such as {@code Track::artist_uri}.
 *
 * @return the current value of {@code artist_uri}
 */
String artist_uri() {
    return this.artist_uri;
}

/**
 * Accessor for the {@code album_uri} field, so the value can be obtained
 * through a method reference such as {@code Track::album_uri}.
 *
 * @return the current value of {@code album_uri}
 */
String album_uri() {
    return this.album_uri;
}
}
4 changes: 2 additions & 2 deletions src/main/java/edu/anadolu/app/BestSearchApp.java
Expand Up @@ -22,8 +22,8 @@ public static void main(String[] args) {
Integer maxTrack = Integer.valueOf(args[5]);
CustomSorterConfig sorterConfig = CustomSorterConfig.valueOf(args[6]);

try (BestSearcher searcher = new BestSearcher(indexPath, challengePath, resultPath, similarityConfig, maxPlaylist, maxTrack, sorterConfig)) {
searcher.search();
try (BestSearcher searcher = new BestSearcher(indexPath, challengePath, similarityConfig, maxPlaylist, maxTrack, sorterConfig)) {
searcher.search(resultPath);
} catch (Exception e) {
e.printStackTrace();
}
Expand Down

0 comments on commit 2d3b9cc

Please sign in to comment.