Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

provide more granular way to manage embeddings' cache #1

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Original file line number Diff line number Diff line change
@@ -1,17 +1,25 @@
package com.exadel.frs.core.trainservice.cache;

import com.exadel.frs.commonservice.entity.Embedding;
import com.exadel.frs.commonservice.projection.EmbeddingProjection;
import com.exadel.frs.core.trainservice.dto.CacheActionDto;
import com.exadel.frs.core.trainservice.dto.CacheActionDto.AddEmbeddings;
import com.exadel.frs.core.trainservice.dto.CacheActionDto.CacheAction;
import com.exadel.frs.core.trainservice.dto.CacheActionDto.RemoveEmbeddings;
import com.exadel.frs.core.trainservice.dto.CacheActionDto.RemoveSubjects;
import com.exadel.frs.core.trainservice.dto.CacheActionDto.RenameSubjects;
import com.exadel.frs.core.trainservice.service.EmbeddingService;
import com.exadel.frs.core.trainservice.service.NotificationSenderService;
import com.google.common.cache.Cache;
import com.google.common.cache.CacheBuilder;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Component;

import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.TimeUnit;
import java.util.function.Consumer;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Component;

import static com.exadel.frs.core.trainservice.system.global.Constants.SERVER_UUID;

Expand All @@ -34,34 +42,81 @@ public class EmbeddingCacheProvider {
.build();

public EmbeddingCollection getOrLoad(final String apiKey) {

var result = cache.getIfPresent(apiKey);

if (result == null) {
result = embeddingService.doWithEnhancedEmbeddingProjectionStream(apiKey, EmbeddingCollection::from);

cache.put(apiKey, result);

notifyCacheEvent("UPDATE", apiKey);
}

return result;
}

public void ifPresent(String apiKey, Consumer<EmbeddingCollection> consumer) {
public void removeEmbedding(String apiKey, EmbeddingProjection embedding) {
Optional.ofNullable(cache.getIfPresent(apiKey))
.ifPresent(consumer);
.ifPresent(
ec -> {
ec.removeEmbedding(embedding);
notifyCacheEvent(
CacheAction.REMOVE_EMBEDDINGS,
apiKey,
new RemoveEmbeddings(Map.of(embedding.subjectName(), List.of(embedding.embeddingId())))
);
}
);
}

cache.getIfPresent(apiKey);
notifyCacheEvent("UPDATE", apiKey);
public void updateSubjectName(String apiKey, String oldSubjectName, String newSubjectName) {
Optional.ofNullable(cache.getIfPresent(apiKey))
.ifPresent(
ec -> {
ec.updateSubjectName(oldSubjectName, newSubjectName);
notifyCacheEvent(CacheAction.RENAME_SUBJECTS, apiKey, new RenameSubjects(Map.of(oldSubjectName, newSubjectName)));
}
);
}

public void removeBySubjectName(String apiKey, String subjectName) {
Optional.ofNullable(cache.getIfPresent(apiKey))
.ifPresent(
ec -> {
ec.removeEmbeddingsBySubjectName(subjectName);
notifyCacheEvent(CacheAction.REMOVE_SUBJECTS, apiKey, new RemoveSubjects(List.of(subjectName)));
}
);
}


public void addEmbedding(String apiKey, Embedding embedding) {
Optional.ofNullable(cache.getIfPresent(apiKey))
.ifPresent(
ec -> {
ec.addEmbedding(embedding);
notifyCacheEvent(CacheAction.ADD_EMBEDDINGS, apiKey, new AddEmbeddings(List.of(embedding.getId())));
}
);
}

/**
* Method can be used to make changes in cache without sending notification.
* Use it carefully, because changes you do will not be visible for other compreface-api instances
*
* @param apiKey domain
* @param action what to do with {@link EmbeddingCollection}
*/
public void expose(String apiKey, Consumer<EmbeddingCollection> action) {
Optional.ofNullable(cache.getIfPresent(apiKey))
.ifPresent(action);
}

public void invalidate(final String apiKey) {
cache.invalidate(apiKey);
notifyCacheEvent("DELETE", apiKey);
notifyCacheEvent(CacheAction.INVALIDATE, apiKey, null);
}


/**
* @deprecated
* See {@link com.exadel.frs.core.trainservice.service.NotificationHandler#handleUpdate(CacheActionDto)}
*/
@Deprecated(forRemoval = true)
public void receivePutOnCache(String apiKey) {
var result = embeddingService.doWithEnhancedEmbeddingProjectionStream(apiKey, EmbeddingCollection::from);
cache.put(apiKey, result);
Expand All @@ -71,8 +126,8 @@ public void receiveInvalidateCache(final String apiKey) {
cache.invalidate(apiKey);
}

private void notifyCacheEvent(String event, String apiKey) {
CacheActionDto cacheActionDto = new CacheActionDto(event, apiKey, SERVER_UUID);
private <T> void notifyCacheEvent(CacheAction event, String apiKey, T action) {
CacheActionDto<T> cacheActionDto = new CacheActionDto<>(event, apiKey, SERVER_UUID, action);
notificationSenderService.notifyCacheChange(cacheActionDto);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@ public Collection<String> getSubjectNames(final String apiKey) {
return subjectRepository.getSubjectNames(apiKey);
}

public List<Embedding> loadAllEmbeddingsByIds(Iterable<UUID> ids) {
return embeddingRepository.findAllById(ids);
}

@Transactional
public Subject deleteSubjectByName(final String apiKey, final String subjectName) {
final Optional<Subject> subjectOptional = subjectRepository.findByApiKeyAndSubjectNameIgnoreCase(apiKey, subjectName);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,21 +1,62 @@
package com.exadel.frs.core.trainservice.dto;

import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonProperty;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
import java.util.List;
import java.util.Map;
import java.util.UUID;

@Data
@AllArgsConstructor
@NoArgsConstructor
public class CacheActionDto {
@JsonIgnoreProperties(ignoreUnknown = true) // here and below "ignoreUnknown = true" for backward compatibility
public record CacheActionDto<T>(
CacheAction cacheAction,
String apiKey,
@JsonProperty("uuid")
UUID serverUUID,
T payload
) {
public <S> CacheActionDto<S> withPayload(S payload) {
return new CacheActionDto<>(
cacheAction,
apiKey,
serverUUID,
payload
);
}

@JsonProperty("cacheAction")
private String cacheAction;
public enum CacheAction {
// UPDATE and DELETE stays here to support rolling update
@Deprecated
UPDATE,
@Deprecated
DELETE,
REMOVE_EMBEDDINGS,
REMOVE_SUBJECTS,
ADD_EMBEDDINGS,
RENAME_SUBJECTS,
INVALIDATE
}

@JsonProperty("apiKey")
private String apiKey;
@JsonIgnoreProperties(ignoreUnknown = true)
public record RemoveEmbeddings(
Map<String, List<UUID>> embeddings
) {
}

@JsonProperty("uuid")
private String serverUUID;
@JsonIgnoreProperties(ignoreUnknown = true)
public record RemoveSubjects(
List<String> subjects
) {
}

@JsonIgnoreProperties(ignoreUnknown = true)
public record AddEmbeddings(
List<UUID> embeddings
) {
}

@JsonIgnoreProperties(ignoreUnknown = true)
public record RenameSubjects(
Map<String, String> subjectsNamesMapping
) {
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import com.exadel.frs.core.trainservice.system.global.Constants;
import java.util.stream.Stream;
import lombok.RequiredArgsConstructor;
import lombok.val;
import org.springframework.data.domain.Page;
import org.springframework.data.domain.Pageable;
import org.springframework.stereotype.Service;
Expand All @@ -32,9 +31,9 @@ public int updateEmbedding(UUID embeddingId, double[] embedding, String calculat
return embeddingRepository.updateEmbedding(embeddingId, embedding, calculator);
}

@Transactional
@org.springframework.transaction.annotation.Transactional(readOnly = true)
public <T> T doWithEnhancedEmbeddingProjectionStream(String apiKey, Function<Stream<EnhancedEmbeddingProjection>, T> func) {
try (val stream = embeddingRepository.findBySubjectApiKey(apiKey)) {
try (var stream = embeddingRepository.findBySubjectApiKey(apiKey)) {
return func.apply(stream);
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
package com.exadel.frs.core.trainservice.service;

import com.exadel.frs.commonservice.projection.EmbeddingProjection;
import com.exadel.frs.core.trainservice.cache.EmbeddingCacheProvider;
import com.exadel.frs.core.trainservice.dto.CacheActionDto;
import com.exadel.frs.core.trainservice.dto.CacheActionDto.AddEmbeddings;
import com.exadel.frs.core.trainservice.dto.CacheActionDto.RemoveEmbeddings;
import com.exadel.frs.core.trainservice.dto.CacheActionDto.RemoveSubjects;
import com.exadel.frs.core.trainservice.dto.CacheActionDto.RenameSubjects;
import java.util.Objects;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.stereotype.Service;

@Slf4j
@Service
@RequiredArgsConstructor
public class NotificationHandler {
private final EmbeddingCacheProvider cacheProvider;
private final SubjectService subjectService;

public void removeEmbeddings(CacheActionDto<RemoveEmbeddings> action) {
action.payload().embeddings()
.entrySet()
.stream()
.filter(e -> StringUtils.isNotBlank(e.getKey()))
.filter(e -> Objects.nonNull(e.getValue()))
.filter(e -> !e.getValue().isEmpty())
.flatMap(e -> e.getValue().stream().filter(Objects::nonNull).map(id -> new EmbeddingProjection(id, e.getKey())))
.forEach(
em -> cacheProvider.expose(
action.apiKey(),
c -> c.removeEmbedding(em)
)
);
}

public void removeSubjects(CacheActionDto<RemoveSubjects> action) {
action.payload().subjects()
.stream()
.filter(StringUtils::isNotBlank)
.forEach(
s -> cacheProvider.expose(
action.apiKey(),
c -> c.removeEmbeddingsBySubjectName(s)
)
);
}


public void addEmbeddings(CacheActionDto<AddEmbeddings> action) {
var filtered = action.payload().embeddings()
.stream()
.filter(Objects::nonNull)
.toList();
subjectService.loadEmbeddingsById(filtered)
.forEach(
em -> cacheProvider.expose(
action.apiKey(),
c -> c.addEmbedding(em)
)
);
}

public void renameSubjects(CacheActionDto<RenameSubjects> action) {
action.payload().subjectsNamesMapping()
.entrySet()
.stream()
.filter(e -> StringUtils.isNotBlank(e.getKey()))
.filter(e -> StringUtils.isNotBlank(e.getValue()))
.forEach(
e -> cacheProvider.expose(
action.apiKey(),
c -> c.updateSubjectName(e.getKey(), e.getValue())
)
);
}

public <T> void invalidate(CacheActionDto<T> action) {
cacheProvider.expose(
action.apiKey(),
e -> cacheProvider.receiveInvalidateCache(action.apiKey())
);
}

/**
* @param action cacheAction
* @deprecated in favour more granular cache managing.
* See {@link CacheActionDto}.
* Stays here to support rolling update
*/
@Deprecated(forRemoval = true)
public <T> void handleDelete(CacheActionDto<T> action) {
cacheProvider.receiveInvalidateCache(action.apiKey());
}

/**
* @param action cacheAction
* @deprecated in favour more granular cache managing.
* See {@link CacheActionDto}.
* Stays here to support rolling update
*/
@Deprecated(forRemoval = true)
public <T> void handleUpdate(CacheActionDto<T> action) {
cacheProvider.receivePutOnCache(action.apiKey());
}
}
Loading