Skip to content

Commit

Permalink
CUDA対応
Browse files Browse the repository at this point in the history
Closes #18
  • Loading branch information
johtani committed Jul 20, 2023
1 parent d5613c5 commit 10a4741
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 1 deletion.
4 changes: 4 additions & 0 deletions .devcontainer/devcontainer.json
Expand Up @@ -2,8 +2,12 @@
"name": "search-research",
"dockerComposeFile": "../docker-compose.yml",
"service": "backend",
"runServices": ["backend"],
"workspaceFolder": "/workspace/search-research",
"postCreateCommand": "/bin/sh ./.devcontainer/postCreateCommand.sh",
"features": {
"ghcr.io/devcontainers/features/nvidia-cuda:1": {}
},
"customizations": {
"vscode": {
"settings": {
Expand Down
4 changes: 4 additions & 0 deletions backend/es/processors.py
@@ -1,3 +1,4 @@
import logging
from typing import Any, Dict

import japanese_clip as ja_clip
Expand All @@ -21,10 +22,13 @@ class JaClipEncodeProcessor(Processor):
設定先は、target_fieldで指定
"""

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
_MODEL_NAME = "rinna/japanese-clip-vit-b-16"

def __init__(self, target_field: str):
self.device = "cuda" if torch.cuda.is_available() else "cpu"
self.logger.info(f"device is {self.device}")
self.model, preprocess = ja_clip.load(self._MODEL_NAME, device=self.device)
self.tokenizer = ja_clip.load_tokenizer()
self.target_field = target_field
Expand Down
9 changes: 8 additions & 1 deletion docker-compose.yml
@@ -1,4 +1,4 @@
version: "2.2"
version: "3.8"

services:
backend:
Expand All @@ -9,6 +9,13 @@ services:
environment:
- TZ=Asia/Tokyo
command: sleep infinity
deploy:
resources:
reservations:
devices:
- driver: nvidia
count: 1
capabilities: [gpu]
volumes:
- .:/workspace/search-research:cached
- venv-search-research-backend:/workspace/search-research/.venv
Expand Down

0 comments on commit 10a4741

Please sign in to comment.