Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Help me first version #262

Merged
merged 44 commits into from
May 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
44 commits
Select commit Hold shift + click to select a range
b928c3b
feat: help me first version
Ciroye Apr 5, 2024
e329fa5
feat: wonders first version
Ciroye Apr 9, 2024
6e9b3c2
feat: changes nibbler
Ciroye Apr 12, 2024
032a3bc
feat: changes nibbler
Ciroye Apr 12, 2024
9cbfaa7
feat: verification duplicate prompt
Ciroye Apr 15, 2024
b881bc3
feat: verification duplicate prompt
Ciroye Apr 15, 2024
d8c27a7
All models
remg1997 Apr 15, 2024
b95a927
Merge branch 'feat/final-helpme' of github.com:mlcommons/dynabench in…
remg1997 Apr 15, 2024
4ae0e35
feat: last touches nibbler
Ciroye Apr 16, 2024
53c353f
Merge branch 'main' of https://github.com/mlcommons/dynabench into fe…
Ciroye Apr 16, 2024
b702df9
feat: adjust helpme
Ciroye Apr 17, 2024
84546b1
feat: adjust number of images nibbler
Ciroye Apr 19, 2024
ea84e78
Merge multi_gen
remg1997 Apr 26, 2024
a1acad2
Job Service
remg1997 Apr 26, 2024
b2b820f
Include job model
remg1997 Apr 26, 2024
de6aa1f
Basic worker config
remg1997 Apr 26, 2024
8a89044
Task definitions
remg1997 Apr 26, 2024
235e6eb
Include celery dockerfile and reqs
remg1997 Apr 26, 2024
51a1e85
Final fixes
remg1997 Apr 26, 2024
5b33867
feat: adding nibbler jobs fe
Ciroye Apr 27, 2024
789b3c3
Change log path
remg1997 Apr 27, 2024
69fc9b5
feat: new messages
Ciroye Apr 27, 2024
3638890
feat: moving nibbler dev to prod
Ciroye May 1, 2024
3e277f0
feat: wonders changes
Ciroye May 2, 2024
ae30adf
feat: update number of models
Ciroye May 3, 2024
a70fa79
feat: update number of models
Ciroye May 3, 2024
a6183e7
feat: update number of models
Ciroye May 3, 2024
363d053
feat: update number of models
Ciroye May 3, 2024
1438862
feat: update number of models
Ciroye May 3, 2024
41700e7
feat: put a comment
Ciroye May 6, 2024
b3b20a9
feat: error in history nibbler
Ciroye May 7, 2024
cc641c0
feat: wonders and help-me changes
Ciroye May 8, 2024
e95ab5e
feat: fix problem nibbler
Ciroye May 8, 2024
db20315
feat: changes nibbler
Ciroye May 15, 2024
9f05648
Merge branch 'feat/final-helpme' of github.com:mlcommons/dynabench in…
remg1997 May 15, 2024
72efc42
Add model name to example table
remg1997 May 15, 2024
1f756d8
feat: remove split by dot
Ciroye May 15, 2024
b2c09e7
feat: default model help me
Ciroye May 15, 2024
5030f0e
New fixes
remg1997 May 15, 2024
0d3c505
Merge branch 'feat/final-helpme' of github.com:mlcommons/dynabench in…
remg1997 May 15, 2024
a4588d4
Merge branch 'feat/final-helpme' of https://github.com/mlcommons/dyna…
Ciroye May 15, 2024
48b3bd7
feat: fix counter
Ciroye May 15, 2024
8ee7e04
Logging celery tasks
remg1997 May 16, 2024
e8d5638
Change logger path
remg1997 May 16, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
39 changes: 24 additions & 15 deletions backend/app/api/endpoints/base/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import asyncio
import json

from fastapi import APIRouter, WebSocket
from fastapi import APIRouter, File, UploadFile
from fastapi.responses import JSONResponse
from sse_starlette.sse import EventSourceResponse

Expand All @@ -32,21 +32,9 @@ async def get_context_configuration(task_id: int):
return context_config


@router.websocket("/ws/get_generative_contexts")
async def websocket_generative_context(websocket: WebSocket):
await websocket.accept()
model_info = await websocket.receive_json()
model_info = dict(model_info)
while True:
data = ContextService().get_generative_contexts(
model_info["type"], model_info["artifacts"]
)
await websocket.send_json(data)
await asyncio.sleep(1)


@router.post("/get_generative_contexts")
async def get_generative_contexts(model: GetGenerativeContextRequest):
async def get_generative_contexts(model: dict):
model = GetGenerativeContextRequest(**model)
image_list = await ContextService().get_generative_contexts(
model.type, model.artifacts
)
Expand All @@ -57,6 +45,7 @@ async def get_generative_contexts(model: GetGenerativeContextRequest):
async def stream_images(model_info: GetGenerativeContextRequest):
async def event_generator():
iterations = 0
model_info.artifacts["num_images"] = 4
for _ in range(model_info.artifacts["num_batches"]):
data = await ContextService().get_generative_contexts(
model_info.type, model_info.artifacts
Expand All @@ -75,3 +64,23 @@ async def event_generator():
def get_filter_context(model: GetFilterContext):
    """Return a single context for the round that matches the given filters."""
    service = ContextService()
    return service.get_filter_context(model.real_round_id, model.filters)


@router.post("/get_contexts_from_s3")
def get_contexts_from_s3(artifacts: dict):
return ContextService().get_contexts_from_s3(artifacts)


@router.post("/save_contexts_to_s3")
def save_contexts_to_s3(
task_id: int,
language: str,
country: str,
category: str,
concept: str,
description: str,
file: UploadFile = File(...),
):
return ContextService().save_contexts_to_s3(
file, task_id, language, country, description, category, concept
)
19 changes: 19 additions & 0 deletions backend/app/api/endpoints/base/historical_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from app.domain.schemas.base.historical_data import (
DeleteHistoricalDataRequest,
GetHistoricalData,
GetHistoricalDataRequest,
GetSaveHistoricalDataRequest,
)
Expand Down Expand Up @@ -36,6 +37,24 @@ async def save_historical_data(model: GetSaveHistoricalDataRequest):
)


@router.post("/get_occurrences_with_more_than_one_hundred")
async def get_occurrences_with_more_than_one_hundred(model: GetHistoricalData):
history = HistoricalDataService().get_occurrences_with_more_than_one_hundred(
model.task_id
)
return history


@router.post("/check_if_historical_data_exists")
async def check_if_historical_data_exists(
model: GetSaveHistoricalDataRequest,
):
history = HistoricalDataService().check_if_historical_data_exists(
model.task_id, model.user_id, model.data
)
return history


@router.post("/delete_historical_data")
async def delete_historical_data(model: DeleteHistoricalDataRequest):
return HistoricalDataService().delete_historical_data(model.task_id, model.user_id)
2 changes: 1 addition & 1 deletion backend/app/api/endpoints/base/rounduserexample.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,5 +49,5 @@ async def increment_counter_examples_submitted_today(
@router.post("/redirect_to_third_party_provider", response_model={})
async def redirect_to_third_party_provider(model: RedirectThirdPartyProvider):
return RoundUserExampleInfoService().redirect_to_third_party_provider(
model.task_id, model.user_id, model.round_id
model.task_id, model.user_id, model.round_id, model.url
)
4 changes: 4 additions & 0 deletions backend/app/domain/schemas/base/historical_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@ class GetHistoricalDataRequest(BaseModel):
user_id: int


class GetHistoricalData(BaseModel):
    """Request payload for history endpoints that only need a task scope."""

    # ID of the task whose historical data is being queried.
    task_id: int


class GetSaveHistoricalDataRequest(BaseModel):
task_id: int
user_id: int
Expand Down
3 changes: 3 additions & 0 deletions backend/app/domain/schemas/base/rounduserexample.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# Copyright (c) MLCommons and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import Optional

from pydantic import BaseModel


Expand All @@ -23,3 +25,4 @@ class RedirectThirdPartyProvider(BaseModel):
task_id: int
user_id: int
round_id: int
url: Optional[str] = None
176 changes: 119 additions & 57 deletions backend/app/domain/services/base/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@
# LICENSE file in the root directory of this source tree.

import base64
import hashlib
import io
import json
import os
import random
Expand All @@ -13,10 +11,11 @@
import boto3
import yaml
from fastapi import HTTPException
from PIL import Image
from worker.tasks import generate_images

from app.domain.services.base.historical_data import HistoricalDataService
from app.domain.services.base.jobs import JobService
from app.domain.services.base.task import TaskService
from app.domain.services.utils.constant import black_image, forbidden_image
from app.domain.services.utils.llm import (
AlephAlphaProvider,
AnthropicProvider,
Expand All @@ -27,15 +26,17 @@
OpenAIProvider,
ReplicateProvider,
)
from app.domain.services.utils.multi_generator import ImageGenerator, LLMGenerator
from app.domain.services.utils.multi_generator import LLMGenerator
from app.infrastructure.repositories.context import ContextRepository
from app.infrastructure.repositories.round import RoundRepository


class ContextService:
def __init__(self):
self.jobs_service = JobService()
self.context_repository = ContextRepository()
self.round_repository = RoundRepository()
self.historical_data_service = HistoricalDataService()
self.task_service = TaskService()
self.session = boto3.Session(
aws_access_key_id=os.getenv("AWS_ACCESS_KEY_ID"),
Expand Down Expand Up @@ -128,64 +129,88 @@ def get_context_configuration(self, task_id) -> dict:
}
return context_info

def get_nibbler_contexts(
self, prompt: str, user_id: int, num_images: int, models: list, endpoint: str
async def get_nibbler_contexts(
self,
prompt: str,
user_id: int,
num_images: int,
models: list,
endpoint: str,
prompt_already_exists_for_user: bool,
prompt_with_more_than_one_hundred: bool,
task_id: int,
) -> dict:
images = []
start = time.time()
multi_generator = ImageGenerator()
generated_images = multi_generator.generate_all_images(
prompt, num_images, models, endpoint
prompt_already_exists_for_user = (
self.historical_data_service.check_if_historical_data_exists(
task_id, user_id, prompt
)
)
images = []
for generator_dict in generated_images:
if generator_dict:
for image in generator_dict.get("images", []):
image_id = (
generator_dict["generator"]
+ "_"
+ prompt
+ "_"
+ str(user_id)
+ "_"
+ hashlib.md5(image.encode()).hexdigest()
)
print(image_id)
image_bytes = io.BytesIO(base64.b64decode(image))
img = Image.open(image_bytes)
img = img.convert("L")
average_intensity = img.getdata()
average_intensity = sum(average_intensity) / len(average_intensity)
if average_intensity < 10:
print("Image too dark, skipping")
new_dict = {
"image": forbidden_image,
"id": hashlib.md5(forbidden_image.encode()).hexdigest(),
}
images.append(new_dict)

elif black_image in image:
new_dict = {
"image": forbidden_image,
"id": hashlib.md5(forbidden_image.encode()).hexdigest(),
}
images.append(new_dict)

else:
num_of_current_images = 0
print("Prompt already exists for user", prompt_already_exists_for_user)
if prompt_already_exists_for_user:
print("Prompt already exists for user")
# Download the images from the s3 bucket
key = f"adversarial-nibbler/{prompt}/{user_id}"
objects = self.s3.list_objects_v2(Bucket=self.dataperf_bucket, Prefix=key)
if "Contents" in objects:
if len(objects["Contents"]) > 4:
for obj in objects["Contents"]:
image_id = obj["Key"].split("/")[-1].replace(".jpeg", "")
image = self.s3.get_object(
Bucket=self.dataperf_bucket, Key=obj["Key"]
)
image_bytes = image["Body"].read()
image = base64.b64encode(image_bytes).decode("utf-8")
new_dict = {
"image": image,
"id": image_id,
}
filename = f"adversarial-nibbler/{image_id}.jpeg"
self.s3.put_object(
Body=base64.b64decode(image),
Bucket=self.dataperf_bucket,
Key=filename,
)
images.append(new_dict)
random.shuffle(images)
print(f"Time to generate images: {time.time() - start}")
return images
return images
else:
num_of_current_images = len(objects["Contents"])

if prompt_with_more_than_one_hundred:
print("Prompt with less than 100 images")
key = f"adversarial-nibbler/{prompt}"
objects = self.s3.list_objects_v2(Bucket=self.dataperf_bucket, Prefix=key)
users = []
if "Contents" in objects:
users = [obj["Key"] for obj in objects["Contents"]]
users = list({item.split("/")[2] for item in users})
print(f"Users are {users}")
random_user = random.choice(users)
print(f"Random user is {random_user}")
key = f"adversarial-nibbler/{prompt}/{random_user}"
objects = self.s3.list_objects_v2(Bucket=self.dataperf_bucket, Prefix=key)
if "Contents" in objects:
for obj in objects["Contents"]:
image_id = obj["Key"].split("/")[-1].replace(".jpeg", "")
image = self.s3.get_object(
Bucket=self.dataperf_bucket, Key=obj["Key"]
)
image_bytes = image["Body"].read()
image = base64.b64encode(image_bytes).decode("utf-8")
new_dict = {
"image": image,
"id": image_id,
}
images.append(new_dict)
return images
print("generating new images")
self.jobs_service.create_registry({"prompt": prompt, "user_id": user_id})
generate_images.delay(
prompt, num_images, models, endpoint, user_id, num_of_current_images
)
queue_position = self.jobs_service.determine_queue_position(
{"prompt": prompt, "user_id": user_id}
)
return {
"message": "Images are being generated",
"queue_position": queue_position["queue_position"],
"all_positions": queue_position["all_positions"],
}

async def get_perdi_contexts(
self, prompt: str, number_of_samples: int, models: dict
Expand Down Expand Up @@ -221,12 +246,29 @@ async def generate_images_stream(self, model_info):

async def get_generative_contexts(self, type: str, artifacts: dict) -> dict:
if type == "nibbler":
return self.get_nibbler_contexts(
exists = self.jobs_service.metadata_exists(
{"prompt": artifacts["prompt"], "user_id": artifacts["user_id"]}
)
print("Exists is", exists)
if exists:
queue_data = self.jobs_service.determine_queue_position(
{"prompt": artifacts["prompt"], "user_id": artifacts["user_id"]}
)
print("Queue data is", queue_data)
return queue_data
return await self.get_nibbler_contexts(
prompt=artifacts["prompt"],
user_id=artifacts["user_id"],
models=artifacts["model"],
endpoint=artifacts["model"],
num_images=6,
prompt_already_exists_for_user=artifacts[
"prompt_already_exists_for_user"
],
prompt_with_more_than_one_hundred=artifacts[
"prompt_with_more_than_one_hundred"
],
num_images=artifacts.get("num_images", 12),
task_id=artifacts.get("task_id", 59),
)
elif type == "perdi":
return await self.get_perdi_contexts(
Expand All @@ -244,3 +286,23 @@ def get_filter_context(self, real_round_id: int, filters: dict) -> dict:
if context.get(key).lower() == value.lower():
filter_contexts.append(context)
return random.choice(filter_contexts)

def get_contexts_from_s3(self, artifacts: dict):
    """Download the per-country/language concepts JSON for a task from S3.

    ``artifacts`` wraps the real payload under an ``"artifacts"`` key and
    must carry ``task_id``, ``country`` and ``language``. The object key is
    ``<task_code>/Top_<country>-<language>_Concepts.json``.
    """
    payload = artifacts["artifacts"]
    task_code = self.task_service.get_task_code_by_task_id(payload["task_id"])[0]
    object_key = (
        f"{task_code}/Top_{payload['country']}-{payload['language']}_Concepts.json"
    )
    response = self.s3.get_object(Bucket=self.dataperf_bucket, Key=object_key)
    return json.loads(response["Body"].read())

def save_contexts_to_s3(
    self, file, task_id, language, country, description, category, concept
):
    """Upload one context image to S3 and return its object key.

    The key encodes the task code plus the locale/category/concept
    hierarchy. A random numeric suffix reduces the chance that repeated
    uploads with the same description overwrite each other
    (NOTE(review): collisions remain possible — confirm this is acceptable).
    """
    task_code = self.task_service.get_task_code_by_task_id(task_id)[0]
    file_name = f"{description}-{random.randint(0, 100000)}.jpeg"
    key = f"{task_code}/{country}/{language}/{category}/{concept}/{file_name}"
    # Rewind in case the upload stream was already consumed upstream.
    file.file.seek(0)
    self.s3.put_object(Bucket=self.dataperf_bucket, Key=key, Body=file.file)
    return key
2 changes: 1 addition & 1 deletion backend/app/domain/services/base/example.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ def partial_creation_generative_example(
context_id,
user_id,
False,
None,
example_info["select_image"].split("_")[0],
json.dumps(example_info),
json.dumps({}),
json.dumps({}),
Expand Down