In [1]:
# Notebooks have no ``__file__``; define a placeholder so that
# ``Path(__file__).resolve()`` further down resolves relative to the current
# working directory of the kernel.
# NOTE(review): the derived project root therefore depends on where the
# notebook kernel was started — confirm it is launched from the expected folder.
__file__ = "__init__.py"

In [None]:

import os, sys, re, json5
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune
import torch.cuda.amp as amp
import pandas as pd
from pathlib import Path
from rapidfuzz import fuzz
from fuzzywuzzy import process
from sentence_transformers import SentenceTransformer
from typing import List, Dict, Optional
import pandas as pd

# Project root = parent of the directory that contains ``__file__``.
project_root = Path(__file__).resolve().parents[1]
# Make project-local packages (e.g. ``utils``) importable. Guard against
# appending the same entry again each time this notebook cell is re-run.
if str(project_root) not in sys.path:
    sys.path.append(str(project_root))

from utils.ncomp import rlst, srlst, clst, glst, rrlst, dtlst, sslst

# Absolute locations of the project's data, config and model files, each
# given relative to ``project_root`` and normalised with ``os.path.abspath``.
paths = {
    name: os.path.abspath(f"{project_root}/{relative}")
    for name, relative in {
        "processed": "data/storage/processed",
        "odata": "data/storage/processed/final_cleaning.csv",
        "config": "config/model_config.json",
        "models": "models",
    }.items()
}
# Cleaned dataset, loaded eagerly when this cell runs.
odata = pd.read_csv(paths["odata"])

In [None]:
class ComponentExtractor:
    """Extract component values (brand, GPU, CPU, RAM, display specs) from text.

    Candidate values come from the ``BRAND`` column of the cleaned dataset and
    from the ``utils.ncomp`` list helpers. Per-component minimum fuzzy-match
    scores ("thresholds") are persisted as named profiles in
    ``paths["config"]``.
    """

    # Threshold used for any component that has no explicit score.
    DEFAULT_THRESHOLD = 25

    def __init__(
        self,
        thresholds: Optional[Dict[str, float]] = None,
        thresholds_name: str = "default",
        reset_thresholds: bool = False,
        delete_threshold: Optional[str] = None,
        ) -> None:
        """Load the dataset, the candidate lists and a threshold profile.

        Args:
            thresholds: Per-component scores used only when the profile
                ``thresholds_name`` does not exist yet.
            thresholds_name: Name of the threshold profile to use.
            reset_thresholds: Discard every saved profile and start again from
                the built-in ``"default"`` profile.
            delete_threshold: Name of a saved profile to remove.

        Side effects: reads ``paths["odata"]`` and (re)writes ``paths["config"]``.
        """
        self.odata = pd.read_csv(paths["odata"])
        self.components = self._load_components()
        # Must run after ``_load_components``: defaults are built per component.
        self.thresholds = self._load_thresholds(
            thresholds=thresholds,
            load=thresholds_name,
            reset=reset_thresholds,
            delete=delete_threshold,
            )

    def _load_components(self) -> dict:
        """Load and normalise the candidate values for every component.

        Brands come from ``self.odata["BRAND"]``; every other list comes from
        a ``utils.ncomp`` helper and is sorted by string length.
        """
        brands = self.odata["BRAND"].dropna().unique()
        components = {
            # Lower-case, drop missing values (NaN has no ``.lower()``) and
            # collapse case-only duplicates while keeping first-seen order.
            "brand": list(dict.fromkeys(str(br).lower() for br in brands)),
            "gpu": sorted(glst(), key=len, reverse=False),
            "cpu": sorted(clst(), key=len, reverse=False),
            "ram": sorted(rlst(), key=len, reverse=False),
            # NOTE(review): this is the only list sorted longest-first. Because
            # ``basic_extract`` keeps the *last* matching value, the shortest
            # matching resolution wins here — confirm this is intended.
            "resolution": sorted(srlst(), key=len, reverse=True),
            "refresh rate": sorted(rrlst(), key=len, reverse=False),
            "display type": sorted(dtlst(), key=len, reverse=False),
            "screen size": sorted(sslst(), key=len, reverse=False),
        }
        return components

    def _default_thresholds(self) -> Dict[str, float]:
        """Return a fresh profile mapping every component to DEFAULT_THRESHOLD."""
        return {comp: self.DEFAULT_THRESHOLD for comp in self.components}

    def _load_thresholds(
        self,
        thresholds: Optional[Dict[str, float]] = None,
        load: str = "default",
        reset: bool = False,
        delete: Optional[str] = None,
        ) -> dict:
        """Load (or create) the named threshold profile and persist the config.

        The config file stores ``{profile_name: {component: min_score}}``.

        Args:
            thresholds: Scores for a *new* profile. Components it omits get
                DEFAULT_THRESHOLD. Ignored when profile ``load`` already exists.
            load: Name of the profile to return.
            reset: Ignore the saved file and start again from only the
                ``"default"`` profile. The file is overwritten.
            delete: Name of a saved profile to remove. Unknown names are ignored.

        Returns:
            The profile ``load``, with any component missing from the stored
            profile filled in with DEFAULT_THRESHOLD.
        """
        config_path = paths["config"]

        # Read the saved profiles only if the file exists and no reset is asked.
        if os.path.exists(config_path) and not reset:
            with open(config_path, "r", encoding="utf-8") as fh:
                thresholds_load = json5.load(fh)
        else:
            thresholds_load = {"default": self._default_thresholds()}

        if delete is not None:
            thresholds_load.pop(delete, None)

        if load not in thresholds_load:
            # Copy instead of mutating the caller's dict.
            profile = self._default_thresholds()
            if thresholds:
                profile.update(thresholds)
            thresholds_load[load] = profile

        os.makedirs(os.path.dirname(config_path), exist_ok=True)
        # ``with`` guarantees the written config is flushed and closed.
        with open(config_path, "w", encoding="utf-8") as fh:
            json5.dump(thresholds_load, fh, indent=4)

        # Profiles saved before a component existed still get a usable score.
        return {**self._default_thresholds(), **thresholds_load[load]}

    def _fuzzy_match(self, component: str = None, query: str = None) -> Optional[str]:
        """Return the candidate of ``component`` that best fuzzy-matches ``query``.

        Uses fuzzywuzzy's ``process.extractOne`` with rapidfuzz's ``WRatio``
        scorer. Returns ``None`` in three cases: the query is empty, the
        component has no candidates, or the best score is below the
        component's threshold.
        """
        choices = self.components[component]
        if not query or not choices:
            return None
        min_score = self.thresholds.get(component, self.DEFAULT_THRESHOLD)
        best = process.extractOne(query, choices, scorer=fuzz.WRatio)
        # Defensive: treat a missing result like a below-threshold match.
        if best is None or best[1] < min_score:
            return None
        return best[0]

    def basic_extract(self, query: str) -> dict:
        """Find exact (substring) component mentions in ``query``.

        Returns a dict with every component as a key, in ``self.components``
        order. Each value is the matched candidate or ``None``. The query is
        lower-cased first. If several candidates of one component occur, the
        one appearing last in that component's list is kept.
        """
        query = query.lower()
        # Pre-fill every component with None so absent ones are still reported.
        extracted = dict.fromkeys(self.components)
        for comp, values in self.components.items():
            for value in values:
                if value in query:
                    extracted[comp] = value
        return extracted

    def next_extract(self, query: str) -> dict:
        """Extract components, falling back to fuzzy matching for misses.

        Runs ``basic_extract`` first. Every component it could not find is
        then fuzzy-matched against the whole (lower-cased) query. It stays
        ``None`` if no candidate reaches the component's threshold.
        """
        query = query.lower()
        extracted = self.basic_extract(query)
        for comp, value in extracted.items():
            if value is None:
                extracted[comp] = self._fuzzy_match(component=comp, query=query)
        return extracted