Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: typing issue raising warnings in mypy and pylance #118

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion beir/retrieval/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

class EvaluateRetrieval:

def __init__(self, retriever: Union[Type[DRES], Type[DRFS], Type[BM25], Type[SS]] = None, k_values: List[int] = [1,3,5,10,100,1000], score_function: str = "cos_sim"):
def __init__(self, retriever: Union[DRES, DRFS, BM25, SS] = None, k_values: List[int] = [1,3,5,10,100,1000], score_function: str = "cos_sim"):
self.k_values = k_values
self.top_k = max(k_values)
self.retriever = retriever
Expand Down
12 changes: 6 additions & 6 deletions beir/retrieval/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,12 @@

class TrainRetriever:

def __init__(self, model: Type[SentenceTransformer], batch_size: int = 64):
def __init__(self, model: SentenceTransformer, batch_size: int = 64):
self.model = model
self.batch_size = batch_size

def load_train(self, corpus: Dict[str, Dict[str, str]], queries: Dict[str, str],
qrels: Dict[str, Dict[str, int]]) -> List[Type[InputExample]]:
qrels: Dict[str, Dict[str, int]]) -> List[InputExample]:

query_ids = list(queries.keys())
train_samples = []
Expand All @@ -40,7 +40,7 @@ def load_train(self, corpus: Dict[str, Dict[str, str]], queries: Dict[str, str],
logger.info("Loaded {} training pairs.".format(len(train_samples)))
return train_samples

def load_train_triplets(self, triplets: List[Tuple[str, str, str]]) -> List[Type[InputExample]]:
def load_train_triplets(self, triplets: List[Tuple[str, str, str]]) -> List[InputExample]:

train_samples = []

Expand All @@ -53,15 +53,15 @@ def load_train_triplets(self, triplets: List[Tuple[str, str, str]]) -> List[Type
logger.info("Loaded {} training pairs.".format(len(train_samples)))
return train_samples

def prepare_train(self, train_dataset: List[Type[InputExample]], shuffle: bool = True, dataset_present: bool = False) -> DataLoader:
def prepare_train(self, train_dataset: List[InputExample], shuffle: bool = True, dataset_present: bool = False) -> DataLoader:

if not dataset_present:
train_dataset = SentencesDataset(train_dataset, model=self.model)

train_dataloader = DataLoader(train_dataset, shuffle=shuffle, batch_size=self.batch_size)
return train_dataloader

def prepare_train_triplets(self, train_dataset: List[Type[InputExample]]) -> DataLoader:
def prepare_train_triplets(self, train_dataset: List[InputExample]) -> DataLoader:

train_dataloader = datasets.NoDuplicatesDataLoader(train_dataset, batch_size=self.batch_size)
return train_dataloader
Expand Down Expand Up @@ -117,7 +117,7 @@ def fit(self,
steps_per_epoch = None,
scheduler: str = 'WarmupLinear',
warmup_steps: int = 10000,
optimizer_class: Type[Optimizer] = AdamW,
optimizer_class: Optimizer = AdamW,
optimizer_params : Dict[str, object]= {'lr': 2e-5, 'eps': 1e-6, 'correct_bias': False},
weight_decay: float = 0.01,
evaluation_steps: int = 0,
Expand Down