Skip to content

Commit

Permalink
fix imports hard
Browse files Browse the repository at this point in the history
  • Loading branch information
Egor Baturin committed Jun 5, 2024
1 parent d257124 commit 10f34c4
Show file tree
Hide file tree
Showing 14 changed files with 68 additions and 42 deletions.
9 changes: 5 additions & 4 deletions etna/libs/ts2vec/dilated_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,11 @@
"""
# Note: Copied from ts2vec repository (https://github.com/yuezhihan/ts2vec/tree/main)

import torch
from torch import nn
import torch.nn.functional as F
import numpy as np
from etna import SETTINGS

if SETTINGS.torch_required:
from torch import nn
import torch.nn.functional as F


class SamePadConv(nn.Module):
Expand Down
9 changes: 6 additions & 3 deletions etna/libs/ts2vec/encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,15 @@
"""
# Note: Copied from ts2vec repository (https://github.com/yuezhihan/ts2vec/tree/main)

import torch
from torch import nn
import torch.nn.functional as F
import numpy as np
from etna.libs.ts2vec.dilated_conv import DilatedConvEncoder

from etna import SETTINGS

if SETTINGS.torch_required:
import torch
from torch import nn


def generate_continuous_mask(B, T, n=5, l=0.1):
res = torch.full((B, T), True, dtype=torch.bool)
Expand Down
7 changes: 4 additions & 3 deletions etna/libs/ts2vec/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,11 @@
SOFTWARE.
"""
# Note: Copied from ts2vec repository (https://github.com/yuezhihan/ts2vec/tree/main)
from etna import SETTINGS

import torch
from torch import nn
import torch.nn.functional as F
if SETTINGS.torch_required:
import torch
import torch.nn.functional as F


def hierarchical_contrastive_loss(z1, z2, alpha=0.5, temporal_unit=0):
Expand Down
13 changes: 8 additions & 5 deletions etna/libs/ts2vec/ts2vec.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,18 @@
# Removed skipping training loop when model is already pretrained. Removed "multiscale" encode option.
# Move lr parameter to fit method

import torch
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader
import numpy as np
from etna.libs.ts2vec.encoder import TSEncoder
from etna.loggers import tslogger, ConsoleLogger
from etna.loggers import tslogger
from etna.libs.ts2vec.losses import hierarchical_contrastive_loss
from etna.libs.ts2vec.utils import take_per_row, split_with_nan, centerize_vary_length_series, torch_pad_nan, AveragedModel
import math

from etna import SETTINGS

if SETTINGS.torch_required:
import torch
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader


class TS2Vec:
Expand Down
9 changes: 5 additions & 4 deletions etna/libs/ts2vec/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,13 @@
# Note: Copied from ts2vec repository (https://github.com/yuezhihan/ts2vec/tree/main)

import numpy as np
import pickle
import torch
import random
from datetime import datetime
from copy import deepcopy

from etna import SETTINGS

if SETTINGS.torch_required:
import torch


def torch_pad_nan(arr, left=0, right=0, dim=0):
if left > 0:
Expand Down
11 changes: 7 additions & 4 deletions etna/libs/tstcc/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,13 @@
"""
# Note: Copied from ts-tcc repository (https://github.com/emadeldeen24/TS-TCC/tree/main)

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from etna import SETTINGS

if SETTINGS.torch_required:
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat


class Residual(nn.Module):
Expand Down
6 changes: 5 additions & 1 deletion etna/libs/tstcc/augmentations.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,11 @@
# Fix numpy warning in `permutation` function

import numpy as np
import torch

from etna import SETTINGS

if SETTINGS.torch_required:
import torch


def DataTransform(sample, jitter_scale_ratio, max_seg, jitter_ratio):
Expand Down
8 changes: 5 additions & 3 deletions etna/libs/tstcc/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,13 @@
"""
# Note: Copied from ts-tcc repository (https://github.com/emadeldeen24/TS-TCC/tree/main)

import torch
from torch.utils.data import Dataset

from etna.libs.tstcc.augmentations import DataTransform

from etna import SETTINGS

if SETTINGS.torch_required:
import torch
from torch.utils.data import Dataset

class Load_Dataset(Dataset):
# Initialize your data, download, etc.
Expand Down
5 changes: 4 additions & 1 deletion etna/libs/tstcc/encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,10 @@
# Added ignoring warning about even kernel lengths and odd dilation in nn.Conv1d blocks.
import warnings

from torch import nn
from etna import SETTINGS

if SETTINGS.torch_required:
from torch import nn


class ConvEncoder(nn.Module):
Expand Down
6 changes: 5 additions & 1 deletion etna/libs/tstcc/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,13 @@
"""
# Note: Copied from ts-tcc repository (https://github.com/emadeldeen24/TS-TCC/tree/main)

import torch
import numpy as np

from etna import SETTINGS

if SETTINGS.torch_required:
import torch


class NTXentLoss(torch.nn.Module):

Expand Down
7 changes: 5 additions & 2 deletions etna/libs/tstcc/tc.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,13 @@

import warnings

import torch
import torch.nn as nn
import numpy as np
from etna.libs.tstcc.attention import Seq_Transformer
from etna import SETTINGS

if SETTINGS.torch_required:
import torch
import torch.nn as nn


class TC(nn.Module):
Expand Down
11 changes: 7 additions & 4 deletions etna/libs/tstcc/tstcc.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,13 @@
from etna.libs.tstcc.dataloader import Load_Dataset
from etna.libs.tstcc.loss import NTXentLoss
from etna.loggers import tslogger
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader

from etna import SETTINGS

if SETTINGS.torch_required:
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader


class TSTCC:
Expand Down
5 changes: 1 addition & 4 deletions etna/transforms/embeddings/models/ts2vec.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,8 @@

import numpy as np

from etna import SETTINGS
from etna.transforms.embeddings.models import BaseEmbeddingModel

if SETTINGS.torch_required:
from etna.libs.ts2vec import TS2Vec
from etna.libs.ts2vec import TS2Vec


class TS2VecEmbeddingModel(BaseEmbeddingModel):
Expand Down
4 changes: 1 addition & 3 deletions etna/transforms/embeddings/models/tstcc.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,9 @@

import numpy as np

from etna import SETTINGS
from etna.transforms.embeddings.models import BaseEmbeddingModel

if SETTINGS.torch_required:
from etna.libs.tstcc import TSTCC
from etna.libs.tstcc import TSTCC


class TSTCCEmbeddingModel(BaseEmbeddingModel):
Expand Down

0 comments on commit 10f34c4

Please sign in to comment.