-
Notifications
You must be signed in to change notification settings - Fork 11
Rnn scorer #190
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Rnn scorer #190
Conversation
|
autointent/modules/scoring/_rnn.py
Outdated
|
|
||
| def __init__( | ||
| self, | ||
| rnn_config: RNNConfig | str | dict[str, Any] | None = None, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Лучше делать не так.лусге чтобы все параметры были непосредственными параметрами конструктора. Без этого не будет удобно указывать серч спейс
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
class RNNConfig(BaseModel):
model_config = ConfigDict(extra="forbid")
model_name: str = Field("rnn", description="Name of the RNN model.")
embed_dim: int = Field(128, description="Dimension of word embeddings.")
hidden_dim: int = Field(512, description="Dimension of hidden states in RNN.")
n_layers: int = Field(2, description="Number of RNN layers.")
dropout: float = Field(0.1, description="Dropout rate.")
device: str = Field(None, description="Torch notation for CPU or CUDA.")
max_seq_length: int = Field(128, description="Maximum sequence length.")
padding_idx: int = Field(0, description="Index used for padding.")
pretrained_embs: Any = Field(None, description="Pretrained embedding weights if available.")
batch_size: PositiveInt = Field(32, description="Batch size for model inference.")
@classmethod
def from_search_config(cls, values: dict[str, Any] | str | BaseModel | None) -> Self:
if values is None:
return cls()
if isinstance(values, BaseModel):
return values # type: ignore[return-value]
if isinstance(values, str):
return cls(model_name=values)
return cls(**values)
Я правильно понял, что я должен удалить обратно этот класс, и все параметры продублировать в самом классе везде, где нужно?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
тогда init станет очень большим и from_context тоже.
Одни и те же параметры будут перечисляться 3 раза подряд в коде. Сначала в аргументах init, потом в реализации init, затем в аргументах from_context.
ruff запрещает в аргументы функции пихать больше 10 параметров
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
в целом верно, если что будем игнорить ошибку ruff, потому что серч спейс важнее
но я вот смотрю, некоторые параметры это не совсем гиперпараметры а действительно конфиг. я бы сделал так:
- отнести в конфиг: device, max_seq_length, padding_idx
- отнести в init: embed_dim, hidden_dim, n_layers, dropout
- убрать: model_name, pretrained_embs
|
Дампер уже реализован в пр про CNNScorer, но кажется он не работает до конца. Можете обсудить с Лерой и заколлабиться |
No description provided.