-
Notifications
You must be signed in to change notification settings - Fork 213
Closed
Description
URL: https://github.com/salesforce/DeepTime/blob/main/models/modules/regressors.py
Code:
def get_weights(self, X: Tensor, Y: Tensor, reg_coeff: float) -> "tuple[Tensor, Tensor]":
    """Solve batched ridge regression in closed form; return weights and bias.

    A column of ones is appended to ``X`` so the bias is fitted jointly with
    the weights. The ridge penalty is added to every diagonal entry of the
    system matrix, including the one for the bias column, so the bias is
    regularized as well.

    Args:
        X: Inputs of shape ``(batch, n_samples, n_dim)``.
        Y: Targets of shape ``(batch, n_samples, n_out)``.
        reg_coeff: Ridge penalty added to the diagonal; a positive value
            guarantees the linear system is solvable.

    Returns:
        ``(weights, bias)`` with shapes ``(batch, n_dim, n_out)`` and
        ``(batch, 1, n_out)``.
    """
    batch_size, n_samples, _ = X.shape
    # Append the bias column. ``new_ones`` matches X's dtype and device, so the
    # concatenation neither fails nor silently promotes (e.g. fp16 -> fp32).
    X = torch.cat([X, X.new_ones(batch_size, n_samples, 1)], dim=-1)
    # Pick the smaller linear system using the *augmented* width, which
    # includes the bias column just appended.
    n_features = X.shape[-1]
    if n_samples >= n_features:
        # Primal form: (X^T X + reg*I) W = X^T Y, an n_features x n_features system.
        A = torch.bmm(X.mT, X)
        A.diagonal(dim1=-2, dim2=-1).add_(reg_coeff)
        B = torch.bmm(X.mT, Y)
        weights = torch.linalg.solve(A, B)
    else:
        # Dual (Woodbury / push-through) form: W = X^T (X X^T + reg*I)^{-1} Y,
        # an n_samples x n_samples system with the same solution as the primal.
        A = torch.bmm(X, X.mT)
        A.diagonal(dim1=-2, dim2=-1).add_(reg_coeff)
        weights = torch.bmm(X.mT, torch.linalg.solve(A, Y))
    # The last row of the stacked solution belongs to the bias column.
    return weights[:, :-1], weights[:, -1:]
Metadata
Assignees
Labels
No labels