Skip to content

Commit

Permalink
Allow to change device, batch_size and num_workers of embedding…
Browse files Browse the repository at this point in the history
… models (#396)

* move device

* fix

* reformat batch_size and num_workers params

* chore: update changelog

* fix notebook

* fix docs

* fix docs

* lints

* lints

* fix docs

---------

Co-authored-by: Egor Baturin <egoriyaa@github.com>
  • Loading branch information
egoriyaa and Egor Baturin committed Jun 19, 2024
1 parent 1318ff3 commit 7429c39
Show file tree
Hide file tree
Showing 7 changed files with 206 additions and 150 deletions.
4 changes: 2 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Add `get_anomalies_isolation_forest` method for anomaly detection ([#375](https://github.com/etna-team/etna/pull/375))
- Add `IForestOutlierTransform` ([#381](https://github.com/etna-team/etna/pull/381))
- Add `IQROutlierTransform` ([#387](https://github.com/etna-team/etna/pull/387))
-
- Add `num_workers` parameter to `TS2VecEmbeddingModel` ([#396](https://github.com/etna-team/etna/pull/396))
-
-

### Changed
-
- Allow to change `device`, `batch_size` and `num_workers` of embedding models ([#396](https://github.com/etna-team/etna/pull/396))
-
-
-
Expand Down
52 changes: 30 additions & 22 deletions etna/libs/ts2vec/ts2vec.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@
"""
# Note: Copied from ts2vec repository (https://github.com/yuezhihan/ts2vec/tree/main)
# Removed skipping training loop when model is already pretrained. Removed "multiscale" encode option.
# Move lr parameter to fit method
# Move lr, device, batch_size parameters to fit method
# Move device, batch_size parameters to encode method

import torch
import torch.nn.functional as F
Expand All @@ -45,8 +46,6 @@ def __init__(
output_dims=320,
hidden_dims=64,
depth=10,
device='cuda',
batch_size=16,
max_train_length=None,
temporal_unit=0,
after_iter_callback=None,
Expand All @@ -59,22 +58,17 @@ def __init__(
output_dims (int): The representation dimension.
hidden_dims (int): The hidden dimension of the encoder.
depth (int): The number of hidden residual blocks in the encoder.
device (str): The gpu used for training and inference.
batch_size (int): The batch size.
max_train_length (Union[int, NoneType]): The maximum allowed sequence length for training. For sequence with a length greater than <max_train_length>, it would be cropped into some sequences, each of which has a length less than <max_train_length>.
temporal_unit (int): The minimum unit to perform temporal contrast. When training on a very long sequence, this param helps to reduce the cost of time and memory.
after_iter_callback (Union[Callable, NoneType]): A callback function that would be called after each iteration.
after_epoch_callback (Union[Callable, NoneType]): A callback function that would be called after each epoch.
'''

super().__init__()
self.device = device
self.batch_size = batch_size
self.max_train_length = max_train_length
self.temporal_unit = temporal_unit

self._net = TSEncoder(input_dims=input_dims, output_dims=output_dims, hidden_dims=hidden_dims, depth=depth).to(
self.device)
self._net = TSEncoder(input_dims=input_dims, output_dims=output_dims, hidden_dims=hidden_dims, depth=depth)
self.net = AveragedModel(self._net)
self.net.update_parameters(self._net)

Expand All @@ -84,7 +78,7 @@ def __init__(
self.n_epochs = 0
self.n_iters = 0

def fit(self, train_data, lr=0.001, n_epochs=None, n_iters=None, verbose=False):
def fit(self, train_data, lr=0.001, n_epochs=None, n_iters=None, verbose=False, device="cpu", batch_size=16, num_workers=0):
''' Training the TS2Vec model.
Args:
Expand All @@ -93,6 +87,9 @@ def fit(self, train_data, lr=0.001, n_epochs=None, n_iters=None, verbose=False):
n_epochs (Union[int, NoneType]): The number of epochs. When this reaches, the training stops.
n_iters (Union[int, NoneType]): The number of iterations. When this reaches, the training stops. If both n_epochs and n_iters are not specified, a default setting would be used that sets n_iters to 200 for a dataset with size <= 100000, 600 otherwise.
verbose (bool): Whether to print the training loss after each epoch.
device (str): The gpu used for training and inference.
batch_size (int): The batch size.
num_workers (int): How many subprocesses to use for data loading
Returns:
loss_log: a list containing the training losses on each epoch.
Expand All @@ -115,15 +112,18 @@ def fit(self, train_data, lr=0.001, n_epochs=None, n_iters=None, verbose=False):
train_data = train_data[~np.isnan(train_data).all(axis=2).all(axis=1)]

train_dataset = TensorDataset(torch.from_numpy(train_data).to(torch.float))
train_loader = DataLoader(train_dataset, batch_size=min(self.batch_size, len(train_dataset)), shuffle=True,
drop_last=True)
train_loader = DataLoader(train_dataset, batch_size=min(batch_size, len(train_dataset)), shuffle=True,
drop_last=True, num_workers=num_workers)

optimizer = torch.optim.AdamW(self._net.parameters(), lr=lr)

loss_log = []

cur_epoch = 0
cur_iter = 0

self._net.to(device)
self.net.to(device)
while True:
if n_epochs is not None and cur_epoch >= n_epochs:
break
Expand All @@ -141,7 +141,7 @@ def fit(self, train_data, lr=0.001, n_epochs=None, n_iters=None, verbose=False):
if self.max_train_length is not None and x.size(1) > self.max_train_length:
window_offset = np.random.randint(x.size(1) - self.max_train_length + 1)
x = x[:, window_offset: window_offset + self.max_train_length]
x = x.to(self.device)
x = x.to(device)

ts_l = x.size(1)
crop_l = np.random.randint(low=2 ** (self.temporal_unit + 1), high=ts_l + 1)
Expand Down Expand Up @@ -191,8 +191,8 @@ def fit(self, train_data, lr=0.001, n_epochs=None, n_iters=None, verbose=False):

return loss_log

def _eval_with_pooling(self, x, mask=None, slicing=None, encoding_window=None):
out = self.net(x.to(self.device, non_blocking=True), mask)
def _eval_with_pooling(self, x, mask=None, slicing=None, encoding_window=None, device="cpu"):
out = self.net(x.to(device, non_blocking=True), mask)
if encoding_window == 'full_series':
if slicing is not None:
out = out[:, slicing]
Expand Down Expand Up @@ -220,7 +220,7 @@ def _eval_with_pooling(self, x, mask=None, slicing=None, encoding_window=None):
return out.cpu()

def encode(self, data, mask=None, encoding_window=None, causal=False, sliding_length=None, sliding_padding=0,
batch_size=None):
batch_size=None, device="cpu", num_workers=0):
''' Compute representations using the model.
Args:
Expand All @@ -231,6 +231,8 @@ def encode(self, data, mask=None, encoding_window=None, causal=False, sliding_le
sliding_length (Union[int, NoneType]): The length of sliding window. When this param is specified, a sliding inference would be applied on the time series.
sliding_padding (int): This param specifies the contextual data length used for inference every sliding windows.
batch_size (Union[int, NoneType]): The batch size used for inference. If not specified, this would be the same batch size as training.
device (str): The gpu used for training and inference.
num_workers (int): How many subprocesses to use for data loading
Returns:
repr: The representations for data.
Expand All @@ -242,10 +244,13 @@ def encode(self, data, mask=None, encoding_window=None, causal=False, sliding_le
n_samples, ts_l, _ = data.shape

org_training = self.net.training

self._net.to(device)
self.net.to(device)
self.net.eval()

dataset = TensorDataset(torch.from_numpy(data).to(torch.float))
loader = DataLoader(dataset, batch_size=batch_size)
loader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers)

with torch.no_grad():
output = []
Expand All @@ -271,7 +276,8 @@ def encode(self, data, mask=None, encoding_window=None, causal=False, sliding_le
torch.cat(calc_buffer, dim=0),
mask,
slicing=slice(sliding_padding, sliding_padding + sliding_length),
encoding_window=encoding_window
encoding_window=encoding_window,
device=device
)
reprs += torch.split(out, n_samples)
calc_buffer = []
Expand All @@ -283,7 +289,8 @@ def encode(self, data, mask=None, encoding_window=None, causal=False, sliding_le
x_sliding,
mask,
slicing=slice(sliding_padding, sliding_padding + sliding_length),
encoding_window=encoding_window
encoding_window=encoding_window,
device=device
)
reprs.append(out)

Expand All @@ -293,7 +300,8 @@ def encode(self, data, mask=None, encoding_window=None, causal=False, sliding_le
torch.cat(calc_buffer, dim=0),
mask,
slicing=slice(sliding_padding, sliding_padding + sliding_length),
encoding_window=encoding_window
encoding_window=encoding_window,
device=device
)
reprs += torch.split(out, n_samples)
calc_buffer = []
Expand All @@ -306,7 +314,7 @@ def encode(self, data, mask=None, encoding_window=None, causal=False, sliding_le
kernel_size=out.size(1),
).squeeze(1)
else:
out = self._eval_with_pooling(x, mask, encoding_window=encoding_window)
out = self._eval_with_pooling(x, mask, encoding_window=encoding_window, device=device)
if encoding_window == 'full_series':
out = out.squeeze(1)

Expand All @@ -333,5 +341,5 @@ def load(self, fn):
Args:
fn (str): filename.
'''
state_dict = torch.load(fn, map_location=self.device)
state_dict = torch.load(fn, map_location=torch.device("cpu"))
self.net.load_state_dict(state_dict)
10 changes: 4 additions & 6 deletions etna/libs/tstcc/tc.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@ def __init__(
hidden_dim,
heads,
depth,
device,
n_seq_steps
):
super(TC, self).__init__()
Expand All @@ -51,7 +50,6 @@ def __init__(
self.depth = depth
self.Wk = nn.ModuleList([nn.Linear(hidden_dim, self.num_channels) for i in range(self.timestep)])
self.lsoftmax = nn.LogSoftmax(dim=1)
self.device = device
self.n_seq_steps = n_seq_steps
with warnings.catch_warnings():
warnings.simplefilter("ignore")
Expand All @@ -65,7 +63,7 @@ def __init__(
self.seq_transformer = Seq_Transformer(patch_size=self.num_channels, dim=self.hidden_dim, depth=self.depth,
heads=self.heads, mlp_dim=64)

def forward(self, features_aug1, features_aug2):
def forward(self, features_aug1, features_aug2, device):
z_aug1 = features_aug1 # features are (batch_size, #channels, seq_len)
seq_len = z_aug1.shape[2]
z_aug1 = z_aug1.transpose(1, 2)
Expand All @@ -75,10 +73,10 @@ def forward(self, features_aug1, features_aug2):

batch = z_aug1.shape[0]
t_samples = torch.randint(seq_len - self.timestep, size=(1,)).long().to(
self.device) # randomly pick time stamps
device) # randomly pick time stamps

score = 0 # average over timestep and batch
encode_samples = torch.empty((self.timestep, batch, self.num_channels)).float().to(self.device)
encode_samples = torch.empty((self.timestep, batch, self.num_channels)).float().to(device)

for i in np.arange(1, self.timestep + 1):
encode_samples[i - 1] = z_aug2[:, t_samples + i, :].view(batch, self.num_channels)
Expand All @@ -87,7 +85,7 @@ def forward(self, features_aug1, features_aug2):

c_t = self.seq_transformer(forward_seq)

pred = torch.empty((self.timestep, batch, self.num_channels)).float().to(self.device)
pred = torch.empty((self.timestep, batch, self.num_channels)).float().to(device)
for i in np.arange(0, self.timestep):
linear = self.Wk[i]
pred[i] = linear(c_t)
Expand Down

0 comments on commit 7429c39

Please sign in to comment.