
Implement TS2VecModel #253

Merged · 12 commits · Mar 6, 2024

Changes from 6 commits
26 changes: 26 additions & 0 deletions etna/libs/ts2vec/__init__.py
@@ -0,0 +1,26 @@
"""
MIT License

Copyright (c) 2022 Zhihan Yue

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
"""
# Note: Copied from ts2vec repository (https://github.com/yuezhihan/ts2vec/tree/main)

from etna.libs.ts2vec.ts2vec import TS2Vec
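
Note (not part of the diff): downstream code imports `TS2Vec` from this package path. A usage sketch, assuming the upstream `fit`/`encode` interface on arrays of shape `(n_instances, n_timestamps, n_features)` is copied unchanged; the signatures below are taken from the upstream README and may differ in this vendored copy:

```python
import numpy as np

from etna.libs.ts2vec import TS2Vec

# Hypothetical toy data: 8 series, 64 timestamps, 1 feature.
train_data = np.random.randn(8, 64, 1)

# Constructor/method signatures assumed from the upstream ts2vec repo.
model = TS2Vec(input_dims=1, output_dims=16, device='cpu')
model.fit(train_data, n_epochs=1)
reprs = model.encode(train_data)  # expected shape: (8, 64, 16)
```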
58 changes: 58 additions & 0 deletions etna/libs/ts2vec/dilated_conv.py
@@ -0,0 +1,58 @@
import torch
from torch import nn
import torch.nn.functional as F
import numpy as np


class SamePadConv(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, dilation=1, groups=1):
super().__init__()
self.receptive_field = (kernel_size - 1) * dilation + 1
padding = self.receptive_field // 2
self.conv = nn.Conv1d(
in_channels, out_channels, kernel_size,
padding=padding,
dilation=dilation,
groups=groups
)
self.remove = 1 if self.receptive_field % 2 == 0 else 0

def forward(self, x):
out = self.conv(x)
if self.remove > 0:
            out = out[:, :, : -self.remove]
return out
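
Note (not part of the diff): `SamePadConv` pads so that the output length always equals the input length, trimming one trailing step when the receptive field is even. A quick sanity-check sketch:

```python
import torch

# Assumes SamePadConv as defined above.
conv_odd = SamePadConv(in_channels=4, out_channels=8, kernel_size=3, dilation=4)
conv_even = SamePadConv(in_channels=4, out_channels=8, kernel_size=2, dilation=1)

x = torch.randn(2, 4, 50)                # B x C x T
assert conv_odd(x).shape == (2, 8, 50)   # receptive field 9 (odd): padding alone suffices
assert conv_even(x).shape == (2, 8, 50)  # receptive field 2 (even): one trailing step trimmed
```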


class ConvBlock(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, dilation, final=False):
super().__init__()
self.conv1 = SamePadConv(in_channels, out_channels, kernel_size, dilation=dilation)
self.conv2 = SamePadConv(out_channels, out_channels, kernel_size, dilation=dilation)
self.projector = nn.Conv1d(in_channels, out_channels, 1) if in_channels != out_channels or final else None

def forward(self, x):
residual = x if self.projector is None else self.projector(x)
x = F.gelu(x)
x = self.conv1(x)
x = F.gelu(x)
x = self.conv2(x)
return x + residual


class DilatedConvEncoder(nn.Module):
def __init__(self, in_channels, channels, kernel_size):
super().__init__()
self.net = nn.Sequential(*[
ConvBlock(
channels[i - 1] if i > 0 else in_channels,
channels[i],
kernel_size=kernel_size,
dilation=2 ** i,
final=(i == len(channels) - 1)
)
for i in range(len(channels))
])

def forward(self, x):
return self.net(x)
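
Note (not part of the diff): dilations double per block (`1, 2, 4, ...`), so the receptive field grows exponentially with depth while `SamePadConv` keeps the time length fixed. A minimal usage sketch:

```python
import torch

# Assumes DilatedConvEncoder as defined above.
# Depth-3 stack: dilations 1, 2, 4; the last block gets the 1x1 projector.
encoder = DilatedConvEncoder(in_channels=16, channels=[32, 32, 8], kernel_size=3)
x = torch.randn(2, 16, 100)             # B x Ch x T (channels-first, as TSEncoder feeds it)
assert encoder(x).shape == (2, 8, 100)  # time length preserved, channels projected to 8
```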
76 changes: 76 additions & 0 deletions etna/libs/ts2vec/encoder.py
@@ -0,0 +1,76 @@
import torch
from torch import nn
import torch.nn.functional as F
import numpy as np
from etna.libs.ts2vec.dilated_conv import DilatedConvEncoder


def generate_continuous_mask(B, T, n=5, l=0.1):
res = torch.full((B, T), True, dtype=torch.bool)
if isinstance(n, float):
n = int(n * T)
n = max(min(n, T // 2), 1)

if isinstance(l, float):
l = int(l * T)
l = max(l, 1)

for i in range(B):
for _ in range(n):
t = np.random.randint(T - l + 1)
res[i, t:t + l] = False
return res


def generate_binomial_mask(B, T, p=0.5):
return torch.from_numpy(np.random.binomial(1, p, size=(B, T))).to(torch.bool)
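
Note (not part of the diff): both generators produce a `B x T` boolean keep-mask, where `True` means the timestamp is kept. A small sketch of their shapes and semantics:

```python
import torch

# Assumes the mask generators defined above.
bin_mask = generate_binomial_mask(4, 10, p=0.5)        # each step kept with prob. 0.5
cont_mask = generate_continuous_mask(4, 10, n=1, l=3)  # one False span of length 3 per row

assert bin_mask.shape == cont_mask.shape == (4, 10)
assert bin_mask.dtype == cont_mask.dtype == torch.bool
assert int((~cont_mask).sum()) == 4 * 3  # exactly one length-3 dropped span per row
```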


class TSEncoder(nn.Module):
def __init__(self, input_dims, output_dims, hidden_dims=64, depth=10, mask_mode='binomial'):
super().__init__()
self.input_dims = input_dims
self.output_dims = output_dims
self.hidden_dims = hidden_dims
self.mask_mode = mask_mode
self.input_fc = nn.Linear(input_dims, hidden_dims)
self.feature_extractor = DilatedConvEncoder(
hidden_dims,
[hidden_dims] * depth + [output_dims],
kernel_size=3
)
self.repr_dropout = nn.Dropout(p=0.1)

def forward(self, x, mask=None): # x: B x T x input_dims
nan_mask = ~x.isnan().any(axis=-1)
x[~nan_mask] = 0
x = self.input_fc(x) # B x T x Ch

# generate & apply mask
if mask is None:
if self.training:
mask = self.mask_mode
else:
mask = 'all_true'

if mask == 'binomial':
mask = generate_binomial_mask(x.size(0), x.size(1)).to(x.device)
elif mask == 'continuous':
            mask = generate_continuous_mask(x.size(0), x.size(1)).to(x.device)
elif mask == 'all_true':
mask = x.new_full((x.size(0), x.size(1)), True, dtype=torch.bool)
elif mask == 'all_false':
mask = x.new_full((x.size(0), x.size(1)), False, dtype=torch.bool)
elif mask == 'mask_last':
mask = x.new_full((x.size(0), x.size(1)), True, dtype=torch.bool)
mask[:, -1] = False

mask &= nan_mask
x[~mask] = 0

# conv encoder
x = x.transpose(1, 2) # B x Ch x T
x = self.repr_dropout(self.feature_extractor(x)) # B x Co x T
x = x.transpose(1, 2) # B x T x Co

return x
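
Note (not part of the diff): `TSEncoder` maps `B x T x input_dims` to `B x T x output_dims`, zero-filling NaN timestamps and force-masking them via `nan_mask`. A minimal sketch in eval mode:

```python
import torch

# Assumes TSEncoder as defined above.
enc = TSEncoder(input_dims=3, output_dims=8, hidden_dims=16, depth=4)
x = torch.randn(2, 50, 3)
x[0, 10:15, :] = float('nan')  # NaN steps are zero-filled and force-masked

enc.eval()                     # eval mode selects the 'all_true' mask
with torch.no_grad():
    z = enc(x)                 # note: forward zero-fills NaNs in x in place
assert z.shape == (2, 50, 8)
assert not torch.isnan(z).any()
```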
53 changes: 53 additions & 0 deletions etna/libs/ts2vec/losses.py
@@ -0,0 +1,53 @@
import torch
from torch import nn
import torch.nn.functional as F


def hierarchical_contrastive_loss(z1, z2, alpha=0.5, temporal_unit=0):
loss = torch.tensor(0., device=z1.device)
d = 0
while z1.size(1) > 1:
if alpha != 0:
loss += alpha * instance_contrastive_loss(z1, z2)
if d >= temporal_unit:
if 1 - alpha != 0:
loss += (1 - alpha) * temporal_contrastive_loss(z1, z2)
d += 1
z1 = F.max_pool1d(z1.transpose(1, 2), kernel_size=2).transpose(1, 2)
z2 = F.max_pool1d(z2.transpose(1, 2), kernel_size=2).transpose(1, 2)
if z1.size(1) == 1:
if alpha != 0:
loss += alpha * instance_contrastive_loss(z1, z2)
d += 1
return loss / d
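
Note (not part of the diff): the loss is accumulated over a max-pooling pyramid that halves the time dimension at each level, so the loop runs roughly `log2(T)` times; `temporal_unit` suppresses the temporal term at the finest scales. A small sketch:

```python
import torch

# Assumes hierarchical_contrastive_loss as defined above.
# With T = 8 the loop visits lengths 8, 4 and 2 (instance + temporal terms
# each, given temporal_unit=0), then adds one instance-only term at length 1,
# so the accumulated loss is averaged over d = 4 levels.
z1 = torch.randn(4, 8, 16)  # B x T x C: view 1 of the same crops
z2 = torch.randn(4, 8, 16)  # view 2 (different augmentation/masking)
loss = hierarchical_contrastive_loss(z1, z2, alpha=0.5, temporal_unit=0)
assert loss.dim() == 0      # scalar training loss
```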


def instance_contrastive_loss(z1, z2):
B, T = z1.size(0), z1.size(1)
if B == 1:
        return z1.new_tensor(0.)
z = torch.cat([z1, z2], dim=0) # 2B x T x C
z = z.transpose(0, 1) # T x 2B x C
sim = torch.matmul(z, z.transpose(1, 2)) # T x 2B x 2B
logits = torch.tril(sim, diagonal=-1)[:, :, :-1] # T x 2B x (2B-1)
logits += torch.triu(sim, diagonal=1)[:, :, 1:]
logits = -F.log_softmax(logits, dim=-1)

i = torch.arange(B, device=z1.device)
loss = (logits[:, i, B + i - 1].mean() + logits[:, B + i, i].mean()) / 2
return loss
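
Note (not part of the diff): the `tril`/`triu` indexing removes each row's self-similarity entry and shifts the columns to its right left by one, which is why the positive for row `i` lands at column `B + i - 1` while the positive for row `B + i` stays at column `i`. A 2-D analogue that verifies the index arithmetic:

```python
import torch

# Build a 2B x 2B similarity matrix whose entries encode (row, col),
# so surviving entries can be identified by value: sim[r, c] = 10*r + c.
B = 3
idx = torch.arange(2 * B)
sim = (idx[:, None] * 10 + idx[None, :]).float()
logits = torch.tril(sim, diagonal=-1)[:, :-1] + torch.triu(sim, diagonal=1)[:, 1:]
# logits[r, k] == sim[r, k] for k < r and sim[r, k + 1] for k >= r: the
# self-similarity diagonal is gone, columns right of it shift left by one.
i = torch.arange(B)
assert torch.equal(logits[i, B + i - 1], (10 * i + (B + i)).float())  # rows 0..B-1
assert torch.equal(logits[B + i, i], (10 * (B + i) + i).float())      # rows B..2B-1
```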


def temporal_contrastive_loss(z1, z2):
B, T = z1.size(0), z1.size(1)
if T == 1:
        return z1.new_tensor(0.)
z = torch.cat([z1, z2], dim=1) # B x 2T x C
sim = torch.matmul(z, z.transpose(1, 2)) # B x 2T x 2T
logits = torch.tril(sim, diagonal=-1)[:, :, :-1] # B x 2T x (2T-1)
logits += torch.triu(sim, diagonal=1)[:, :, 1:]
logits = -F.log_softmax(logits, dim=-1)

t = torch.arange(T, device=z1.device)
loss = (logits[:, t, T + t - 1].mean() + logits[:, T + t, t].mean()) / 2
return loss
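
Note (not part of the diff): a minimal sketch of how `TSEncoder` (from `encoder.py` above) and this loss combine in one training step. The real loop in `ts2vec.py` additionally samples random overlapping crops and optimizes with AdamW, per the upstream repository:

```python
import torch

# Assumes TSEncoder and hierarchical_contrastive_loss from the files above.
enc = TSEncoder(input_dims=1, output_dims=8, hidden_dims=16, depth=3)
enc.train()                # training mode -> stochastic binomial masking

x = torch.randn(4, 32, 1)  # B x T x input_dims
z1 = enc(x.clone())        # two differently-masked views of the same input
z2 = enc(x.clone())        # (clone because forward zero-fills x in place)
loss = hierarchical_contrastive_loss(z1, z2)
loss.backward()
```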